## Overall plan
# Setup
    Variables: the input/output token length, bytes_per_value, types of layers (Done)
    Variables: parallelism (P4)
    Mixed-precision (P6)
# Read config.json
    Read from local file (Done)
    Read from HuggingFace (P5)
# Read hardware specification
    Read from Excel spreadsheet (P2)
# Break down the models into different kernels/tensors
    Manually break down the GQA models (done)
    MOE models (P3)
    Automatic model conversion (P4)
# Calculate time per kernel
    Simple roofline (done)
    Theoretical tiling (done)
    Hierarchical cache (P3)
    Realistic tiling (P4)
# Output 
    Breakdown between different kernels and the shape of the tensors: Mw_Ma, Ma_Ma, Mw_Va, Vw_Va, Va_Va (Done)
    Estimated execution time for different kernels (Done)
    Estimated activities for different kernels for power analysis (P6)

Setup

In [34]:
#### Setup
# Variables: the (input, output) token-length pairs to evaluate
token_lengths = [
    (1024, 1024),
    (8192, 1024)
]
# Variables: workload characteristics
bytes_per_value = 2  # Bytes per stored value; one global width shared by weights, activations and KV cache
# To-do list: mixed precision between weights and activations (P6)

# Variables: parallelism (P4)

# Variables: types of layers
# Naming scheme: M = matrix operand, V = vector operand,
# a = activations, w = weights, qa = grouped activations (grouped-query attention, GQA).
Ma_Mw = "Ma_Mw" # The first matrix is the activations, the second matrix is the weights
Ma_Mqa = "Ma_Mqa" # The first matrix is the activations, the second matrix is the grouped activations (used for GQA)
Mw_Ma = "Mw_Ma" # The first matrix is the weights, the second matrix is the activations
Ma_Ma = "Ma_Ma" # Both matrices are activations

Va_Mw = "Va_Mw" # The first operand is an activation vector, the second is the weight matrix
Va_Mqa = "Va_Mqa" # The first operand is an activation vector, the second is the grouped-activation matrix (used for GQA)
Vw_Va = "Vw_Va" # The first vector is the weights, the second vector is the activations
Va_Va = "Va_Va" # Both vectors are activations

KVCache = "KVCache" # This is to capture the retrieval of the key/value cache
Async_KVCache = "Async_KVCache" # This is to capture the store (and some retrieval) of the key/value cache

# Flags for debug and detailed print outputs
DEBUG_PRINT = True
DETAIL_PRINT = True

Utility functions

In [35]:
# Formats a count with a binary-scaled M, G, T or P suffix.
def convert_number_to_string(count):
    """Return *count* as a short string scaled by powers of 1024.

    The largest unit among M (1024**2), G (1024**3), T (1024**4) and
    P (1024**5) that does not exceed *count* is used, with one decimal
    place. Anything below 1024**2 is returned unscaled via ``str``.
    """
    # Ordered largest-first so the first match is the biggest unit that fits.
    units = (
        ("P", 1024**5),
        ("T", 1024**4),
        ("G", 1024**3),
        ("M", 1024**2),
    )
    for suffix, factor in units:
        if count >= factor:
            return f"{count / factor:.1f}{suffix}"
    # Too small for any suffix: hand back the plain number.
    return str(count)

# Print a titled, tab-indented listing of the given items.
def print_list(name, list):
    """Print a ``== name ==`` header followed by one tab-indented line per item.

    The parameter is called ``list`` (shadowing the builtin) only to keep the
    existing call signature intact.
    """
    print("\t== " + str(name) + " ==")
    for entry in list:
        print("\t " + str(entry))

Read model configuration

In [36]:
# Download the model config from HuggingFace.
# If you get a permission error see: https://huggingface.co/docs/transformers.js/en/guides/private
#!pip install --upgrade huggingface_hub

# Repository whose config.json describes the model to analyze
model_repo_id = "meta-llama/Llama-3.1-8B" # https://huggingface.co/meta-llama/Llama-3.1-8B

import json
from huggingface_hub import HfApi, hf_hub_download, get_collection
api = HfApi()  # NOTE(review): `api` and `get_collection` are not used in the visible part of this file

# Fetch config.json for the selected repo; the returned path is opened right below
downloaded_path = hf_hub_download(repo_id=model_repo_id, filename="config.json")
with open(downloaded_path, "r") as file:
    data = json.load(file)  # Parse the JSON file into a Python dictionary

# Extract relevant values from the JSON data
class LLMConfig:
    """Transformer dimensions extracted from a HuggingFace-style config dictionary.

    Attributes set by the constructor:
        hidden_size, intermediate_size, num_hidden_layers: copied from the config.
        num_q_heads: the config's ``num_attention_heads``.
        num_kv_heads: the config's ``num_key_value_heads``; when that key is
            missing or null, falls back to ``num_q_heads`` (one K/V head per
            query head).
        sub_dmodel: per-head width, ``hidden_size // num_q_heads``.
        grouped_q: query heads per K/V head, ``num_q_heads / num_kv_heads``
            (true division, so the value is a float).

    Raises:
        KeyError: if a required key (hidden_size, intermediate_size,
            num_attention_heads, num_hidden_layers) is absent.
    """

    def __init__(self, data):
        self.hidden_size = data["hidden_size"]
        self.intermediate_size = data["intermediate_size"]
        self.num_q_heads = data["num_attention_heads"]
        # Configs without grouped-query attention may not carry this key at all;
        # treat that as "as many K/V heads as query heads" instead of failing.
        kv_heads = data.get("num_key_value_heads")
        self.num_kv_heads = self.num_q_heads if kv_heads is None else kv_heads
        self.num_hidden_layers = data["num_hidden_layers"]
        self.sub_dmodel = self.hidden_size // self.num_q_heads
        self.grouped_q = self.num_q_heads / self.num_kv_heads

llm = LLMConfig(data)  # Model dimensions parsed from the downloaded config; passed to the perform_* builders

In [37]:
# TODO: Print model structure.

Read hardware specification

In [38]:
# To-Do: Read detailed hardware specification from spreadsheet (P2)

import pandas as pd

# Define hardware specs for multiple architectures.
# The outer dict keys become DataFrame columns (one per architecture) and the
# inner keys become rows (one per spec), so a spec row is read with HW_SPEC_DF.loc[<spec>].
HW_SPEC_DF = pd.DataFrame({
    "NVL-AX": {
        "Num_Cores": 32,                        # Number of cores
        "Frequency": 2.5 * 1e9,                 # Frequency in Hz
        "MACs_per_cycle": 2*1024,               # 2K MACs per cycle per core for FP16 (depth = 8)
        "Bandwidth": 342*1024*1024*1024,        # 342 GiB/s (binary multiple), stored in bytes per second
        "L3_cache_capacity": 36*1024*1024,      # 36 MiB, stored in bytes
    },
    "JGS": {
        "Num_Cores": 192,                       # Number of cores
        "Frequency": 1.5 * 1e9,                 # Frequency in Hz
        "MACs_per_cycle": 8*1024,               # 8K MACs per cycle per core for FP32
        "Bandwidth": 30*1024*1024*1024*1024,    # 30 TiB/s (binary multiple), stored in bytes per second
        "L3_cache_capacity": 120*1024*1024,     # 120 MiB, stored in bytes
    }
})

# Add derived row for MACs_per_second: cores x clock (Hz) x MACs per cycle per core
HW_SPEC_DF.loc["MACs_per_second"] = (
    HW_SPEC_DF.loc["Num_Cores"] * HW_SPEC_DF.loc["Frequency"] * HW_SPEC_DF.loc["MACs_per_cycle"]
)

# Select active architecture
arch = "NVL-AX"  # or "JGS"
HW_SPEC = HW_SPEC_DF[arch]  # Column selection: a Series indexed by spec name

# Example usage:
if DEBUG_PRINT:
    print("Selected architecture:", arch)
    print(HW_SPEC)
    print("MACs_per_second:", HW_SPEC["MACs_per_second"])


Selected architecture: NVL-AX
Num_Cores            3.200000e+01
Frequency            2.500000e+09
MACs_per_cycle       2.048000e+03
Bandwidth            3.672197e+11
L3_cache_capacity    3.774874e+07
MACs_per_second      1.638400e+14
Name: NVL-AX, dtype: float64
MACs_per_second: 163840000000000.0


Break down the models into different kernels/tensors

In [39]:
!pip install pandas


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.0.1[0m[39;49m -> [0m[32;49m25.1.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


In [40]:
# To-Do: (1) MOE model, (2) auto-conversion, (3) parallelism 

import pandas as pd

class WorkloadProfile:
    def __init__(self):
        # Set max_colwidth to None to display full content without truncation
        pd.set_option('display.max_colwidth', 40)
        self.layers = pd.DataFrame(columns=[
            "type", "name", "m", "k", "n", "batch_size", "grouped_q", "compute", "weight", "kv_cache"
        ])
        self.total_compute = 0
        self.total_weight = 0
        self.total_kv_cache = 0

    def add_layer(self, layer_type, name, *args, **kwargs):
        dim_m = args[0]
        dim_k = args[1]
        dim_n = args[2]
        batch_size = args[3]
        grouped_q = kwargs.get('grouped_q', 1)
        compute = None
        weight = None
        kv_cache = None
        if grouped_q > 1 and layer_type not in ["Ma_Mqa", "Va_Mqa"]:
            raise ValueError(f"Invalid layer type {layer_type} for grouped_q > 1")
        match layer_type:
            case "Ma_Mw":
                compute = dim_m * dim_k * dim_n * batch_size
                weight = dim_k * dim_n
            case "Ma_Mqa":
                compute = dim_m * grouped_q * dim_k * dim_n * batch_size
            case "Mw_Ma":
                compute = dim_m * dim_k * dim_n * batch_size
                weight = dim_m * dim_k
            case "Ma_Ma":
                compute = dim_m * dim_k * dim_n * batch_size
            case "Va_Mw":
                compute = dim_m * dim_k * dim_n * batch_size
                weight = dim_k * dim_n
            case "Va_Mqa":
                compute = dim_m * grouped_q * dim_k * dim_n * batch_size
            case "Vw_Va":
                compute = dim_m * dim_k * dim_n * batch_size
                weight = dim_m * dim_k
            case "Va_Va":
                compute = dim_m * dim_k * dim_n * batch_size
            case "KVCache":
                kv_cache = dim_m * dim_k * dim_n * batch_size
            case "Async_KVCache":
                kv_cache = dim_m * dim_k * dim_n * batch_size
            case _:
                raise ValueError(f"Unknown layer_type: {layer_type}")
        new_row = {
            "type": layer_type,
            "name": name,
            "m": dim_m,
            "k": dim_k,
            "n": dim_n,
            "batch_size": batch_size,
            "grouped_q": grouped_q,
            "compute": compute,
            "weight": weight,
            "kv_cache": kv_cache
        }
        self.layers = pd.concat([self.layers, pd.DataFrame([new_row])], ignore_index=True)

    def sum_totals(self):
        self.layers["compute"] = self.layers["compute"].infer_objects(copy=False)
        self.total_compute = self.layers["compute"].fillna(0).sum()
        self.layers["weight"] = self.layers["weight"].infer_objects(copy=False)
        self.total_weight = self.layers["weight"].fillna(0).sum()
        self.layers["kv_cache"] = self.layers["kv_cache"].infer_objects(copy=False)
        self.total_kv_cache = self.layers["kv_cache"].fillna(0).sum()

def perform_prefill(llm, seq_length, p_batch_size):
    """Build the workload profile of one prefill pass over ``seq_length`` tokens.

    The rows added describe a single transformer layer; the detail printout
    scales the totals by ``llm.num_hidden_layers``.

    Not modeled: input and position embedding, softmax, residual connection
    and layer normalization, and the task-specific output layer.

    Args:
        llm: model dimensions (hidden_size, intermediate_size, sub_dmodel,
            num_q_heads, num_kv_heads, grouped_q, num_hidden_layers).
        seq_length: number of prompt tokens processed.
        p_batch_size: prefill batch size.

    Returns:
        The populated WorkloadProfile, with totals already summed.
    """
    # Create a WorkloadProfile instance
    wl = WorkloadProfile()

    # Add layers using the WorkloadProfile object
    wl.add_layer(Ma_Mw, "Q/K/V*W", seq_length, llm.hidden_size, llm.sub_dmodel * (llm.num_q_heads+llm.num_kv_heads*2), p_batch_size)
    # NOTE(review): the keyword below is spelled `qrouped_q`, but add_layer only
    # reads `grouped_q`, so the value is ignored and these rows use grouped_q = 1.
    # Do not simply correct the spelling: with k = hidden_size the product m*k*n
    # already spans every query head, so applying grouped_q on top would scale
    # the attention compute by that factor again. Confirm the intended k first.
    wl.add_layer(Ma_Mqa, "Q*gK", seq_length, llm.hidden_size, seq_length, p_batch_size, qrouped_q=llm.grouped_q)
    wl.add_layer(Ma_Mqa, "Q*gK*gV", seq_length, llm.hidden_size, seq_length, p_batch_size, qrouped_q=llm.grouped_q)
    wl.add_layer(Ma_Mw, "O*W", seq_length, llm.hidden_size, llm.hidden_size, p_batch_size)
    wl.add_layer(Ma_Mw, "FFN_up", seq_length, llm.hidden_size, llm.intermediate_size, p_batch_size)
    wl.add_layer(Ma_Mw, "FFN_gate", seq_length, llm.hidden_size, llm.intermediate_size, p_batch_size)
    wl.add_layer(Ma_Mw, "FFN_down", seq_length, llm.intermediate_size, llm.hidden_size, p_batch_size)
    # add memory to KV cache #### More optimization opportunity (P1)
    wl.add_layer(Async_KVCache, "KVCache_Store", seq_length, llm.sub_dmodel, 2*llm.num_kv_heads, p_batch_size)

    # Sum up the compute and memory usage
    wl.sum_totals()

    if DEBUG_PRINT:
        print("Prefill Compute", wl.total_compute)
        print("Prefill Weights", wl.total_weight)
        print("Prefill KV cache", wl.total_kv_cache)
        print(wl.layers)


    # Print detailed information if the flag is set
    if DETAIL_PRINT:
        print("")
        print("   Details:")
        print("   === Input parameters ===")
        print("   Context length:", seq_length)
        print("   Bytes per FP value:", bytes_per_value) # To-Do: this is set globally, not specific to the model or workload (P2)
        print("   Layers:", llm.num_hidden_layers)

        print("   Total Compute (MACs):", convert_number_to_string(wl.total_compute * llm.num_hidden_layers))
        # The totals are element counts; scale by bytes_per_value so the
        # figures match their "(Bytes)" labels.
        print("   Total Weights Footprint (Bytes):", convert_number_to_string(wl.total_weight * llm.num_hidden_layers * bytes_per_value))
        print("   Total KV Cache Footprint (Bytes):", convert_number_to_string(wl.total_kv_cache * llm.num_hidden_layers * bytes_per_value))

    return wl

def perform_decode(llm, context_length, d_batch_size): # Decoding one token at a time
    """Build the workload profile of one decode step (one new token per sequence).

    The rows added describe a single transformer layer; the detail printout
    scales the totals by ``llm.num_hidden_layers``. Storing the new token's
    K/V into the cache is not modeled yet.

    Args:
        llm: model dimensions (hidden_size, intermediate_size, sub_dmodel,
            num_q_heads, num_kv_heads, grouped_q, num_hidden_layers).
        context_length: number of cached tokens attended to.
        d_batch_size: decode batch size.

    Returns:
        The populated WorkloadProfile, with totals already summed.
    """
    # Create a WorkloadProfile instance
    wl = WorkloadProfile()

    # Add layers using the WorkloadProfile object
    wl.add_layer(Va_Mw, "Q/K/V*W", 1, llm.hidden_size, llm.sub_dmodel * (llm.num_q_heads+2*llm.num_kv_heads), d_batch_size)
    wl.add_layer(KVCache, "KVCache_Read", context_length, llm.sub_dmodel, 2*llm.num_kv_heads, d_batch_size)
    # NOTE(review): the keyword below is spelled `qrouped_q`, but add_layer only
    # reads `grouped_q`, so the value is ignored and these rows use grouped_q = 1.
    # Do not simply correct the spelling: with k = hidden_size the product m*k*n
    # already spans every query head, so applying grouped_q on top would scale
    # the attention compute by that factor again. Confirm the intended k first.
    wl.add_layer(Va_Mqa, "Q*gK", 1, llm.hidden_size, context_length, d_batch_size, qrouped_q=llm.grouped_q)
    wl.add_layer(Va_Mqa, "Q*gK*gV", 1, llm.hidden_size, context_length, d_batch_size, qrouped_q=llm.grouped_q)
    wl.add_layer(Va_Mw, "O*W", 1, llm.hidden_size, llm.hidden_size, d_batch_size)
    wl.add_layer(Va_Mw, "FFN_up", 1, llm.hidden_size, llm.intermediate_size, d_batch_size)
    wl.add_layer(Va_Mw, "FFN_gate", 1, llm.hidden_size, llm.intermediate_size, d_batch_size)
    wl.add_layer(Va_Mw, "FFN_down", 1, llm.intermediate_size, llm.hidden_size, d_batch_size)
    # add the last token to KV cache #### More optimization opportunity (P1)

    # Sum up the compute and memory usage
    wl.sum_totals()

    if DEBUG_PRINT:
        print("Decode Compute", wl.total_compute)
        print("Decode Weights", wl.total_weight)
        print("Decode KV cache", wl.total_kv_cache)
        print(wl.layers)


    # Print detailed information if the flag is set
    if DETAIL_PRINT:
        print("")
        print("   Details:")
        print("   === Input parameters ===")
        print("   Context length:", context_length)
        print("   Bytes per FP value:", bytes_per_value) # To-Do: this is set globally, not specific to the model or workload (P2)
        print("   Layers:", llm.num_hidden_layers)

        print("   Total Compute (MACs):", convert_number_to_string(wl.total_compute * llm.num_hidden_layers))
        # The totals are element counts; scale by bytes_per_value so the
        # figures match their "(Bytes)" labels.
        print("   Total Weights Footprint (Bytes):", convert_number_to_string(wl.total_weight * llm.num_hidden_layers * bytes_per_value))
        print("   Total KV Cache Footprint (Bytes):", convert_number_to_string(wl.total_kv_cache * llm.num_hidden_layers * bytes_per_value))

    return wl

def perform_chunkedprefill(llm, p_context_length, p_tokens, d_context_length, d_tokens):
    """Build the workload profile of one chunked-prefill step.

    A chunk mixes ``p_tokens`` prefill tokens with ``d_tokens`` decode tokens.
    The projection and FFN rows process all ``p_tokens + d_tokens`` rows in one
    batch-1 GEMM, while attention is profiled separately for the decode and
    prefill parts. The rows describe a single transformer layer; the detail
    printout scales the totals by ``llm.num_hidden_layers``.

    Args:
        llm: model dimensions (hidden_size, intermediate_size, sub_dmodel,
            num_q_heads, num_kv_heads, grouped_q, num_hidden_layers).
        p_context_length: context length used for the prefill-token attention.
        p_tokens: number of prefill tokens in the chunk.
        d_context_length: context length used for the decode KV-cache read.
        d_tokens: number of decode tokens in the chunk.

    Returns:
        The populated WorkloadProfile, with totals already summed.
    """
    # Create a WorkloadProfile instance
    wl = WorkloadProfile()

    # Add layers using the profile object
    wl.add_layer(Ma_Mw, "Q/K/V*W", (p_tokens+d_tokens), llm.hidden_size, llm.sub_dmodel * (llm.num_q_heads+llm.num_kv_heads*2), 1)

    # NOTE(review): every attention row below passes `qrouped_q`, but add_layer
    # only reads `grouped_q`, so the value is ignored and the rows use
    # grouped_q = 1. Do not simply correct the spelling: with k = hidden_size
    # the product m*k*n already spans every query head, so applying grouped_q
    # on top would scale the attention compute by that factor again.

    # decode token attention
    wl.add_layer(KVCache, "dKVCache_Read", d_context_length, llm.sub_dmodel, 2*llm.num_kv_heads, d_tokens)
    # NOTE(review): n here is sub_dmodel * num_kv_heads, whereas perform_decode
    # uses the context length for the same kernels, so this compute does not
    # depend on d_context_length — confirm that is intended.
    wl.add_layer(Va_Mqa, "dQ*gK", 1, llm.hidden_size, llm.sub_dmodel * llm.num_kv_heads, d_tokens, qrouped_q=llm.grouped_q)
    wl.add_layer(Va_Mqa, "dQ*gK*gV", 1, llm.hidden_size, llm.sub_dmodel * llm.num_kv_heads, d_tokens, qrouped_q=llm.grouped_q)

    # prefill token attention
    wl.add_layer(Async_KVCache, "pKVCache_Read", p_context_length, llm.sub_dmodel, 2*llm.num_kv_heads, 1)
    wl.add_layer(Ma_Mqa, "pQ*gK", p_tokens, llm.hidden_size, p_context_length, 1, qrouped_q=llm.grouped_q)
    wl.add_layer(Ma_Mqa, "pQ*gK*gV", p_tokens, llm.hidden_size, p_context_length, 1, qrouped_q=llm.grouped_q)
    wl.add_layer(Async_KVCache, "KVCache_Store", p_tokens, llm.sub_dmodel, 2*llm.num_kv_heads, 1)

    wl.add_layer(Ma_Mw, "O*W", (p_tokens+d_tokens), llm.hidden_size, llm.hidden_size, 1)
    wl.add_layer(Ma_Mw, "FFN_up", (p_tokens+d_tokens), llm.hidden_size, llm.intermediate_size, 1)
    wl.add_layer(Ma_Mw, "FFN_gate", (p_tokens+d_tokens), llm.hidden_size, llm.intermediate_size, 1)
    wl.add_layer(Ma_Mw, "FFN_down", (p_tokens+d_tokens), llm.intermediate_size, llm.hidden_size, 1)

    # Sum up the compute and memory usage
    wl.sum_totals()

    if DEBUG_PRINT:
        print("Chunked Compute", wl.total_compute)
        print("Chunked Weights", wl.total_weight)
        print("Chunked KV cache", wl.total_kv_cache)
        print(wl.layers)

    # Print detailed information if the flag is set
    if DETAIL_PRINT:
        print("")
        print("   Details:")
        print("   === Input parameters ===")
        print("   Prefill context length:", p_context_length)
        print("   Decode context length:", d_context_length)
        print("   Bytes per FP value:", bytes_per_value) # To-Do: this is set globally, not specific to the model or workload (P2)
        print("   Layers:", llm.num_hidden_layers)

        print("   Total Compute (MACs):", convert_number_to_string(wl.total_compute * llm.num_hidden_layers))
        # The totals are element counts; scale by bytes_per_value so the
        # figures match their "(Bytes)" labels.
        print("   Total Weights Footprint (Bytes):", convert_number_to_string(wl.total_weight * llm.num_hidden_layers * bytes_per_value))
        print("   Total KV Cache Footprint (Bytes):", convert_number_to_string(wl.total_kv_cache * llm.num_hidden_layers * bytes_per_value))

    return wl

def perform_GEMM(dim_m, dim_k, dim_n):
    """Build a workload profile holding one activation-by-activation GEMM.

    Args:
        dim_m, dim_k, dim_n: GEMM dimensions (an m x k operand times a k x n
            operand), profiled with batch size 1.

    Returns:
        A WorkloadProfile with a single "Ma_Ma" row named "GEMM" and its
        totals summed.
    """
    # Create a WorkloadProfile instance
    wl = WorkloadProfile()

    # Add layers using the profile object
    wl.add_layer(Ma_Ma, "GEMM", dim_m, dim_k, dim_n, 1)

    # Sum up the compute and memory usage, as the other perform_* builders do;
    # without this the total_* attributes would stay at their initial 0.
    wl.sum_totals()

    return wl


Kernel breakdowns

# DataFrame-based kernel time estimation and output

def estimate_kernel_times_df(wl, HW_SPEC):
    """Return a copy of ``wl.layers`` extended with simple roofline timings.

    Added columns:
        compute_time_s: compute (missing counts as 0) / HW_SPEC['MACs_per_second'].
        memory_time_s: (weight + kv_cache, missing counts as 0) scaled by the
            global bytes_per_value, divided by HW_SPEC['Bandwidth'].
        roofline_time_s: the larger of the two per row.

    ``wl.layers`` itself is left untouched.
    """
    table = wl.layers.copy()

    macs = table['compute'].fillna(0)
    moved_values = table['weight'].fillna(0) + table['kv_cache'].fillna(0)

    table['compute_time_s'] = macs / HW_SPEC['MACs_per_second']
    table['memory_time_s'] = moved_values * bytes_per_value / HW_SPEC['Bandwidth']

    # Roofline: each kernel takes as long as its slower resource.
    timing_columns = ['compute_time_s', 'memory_time_s']
    table['roofline_time_s'] = table[timing_columns].max(axis=1)
    return table

# Example usage for DataFrame-based output:
# profile one prefill pass (1024 tokens, batch size 1) and estimate per-kernel times.
wl = perform_prefill(llm, 1024, 1)
kernel_times_df = estimate_kernel_times_df(wl, HW_SPEC)

print("Layer-wise kernel time estimates (DataFrame-based):")
print(kernel_times_df[['type', 'name', 'compute', 'weight', 'kv_cache', 'compute_time_s', 'memory_time_s', 'roofline_time_s']])

# NOTE(review): nothing here multiplies by llm.num_hidden_layers, so this total
# appears to cover a single transformer layer — confirm that is intended.
total_time = kernel_times_df['roofline_time_s'].sum()
print(f"Total estimated time: {total_time:.6f} seconds")

In [41]:
# Sweep over token_lengths (currently disabled).
# NOTE(review): the body reads `token_length`, which is only bound by the
# commented-out `for` line; re-enable by swapping `if 0:` for that loop.

if 0:
#for token_length in token_lengths:
    # Unpack the tuple into two variables
    input_length, output_length = token_length
    print("Input length:", input_length)
    print("Output length:", output_length)
    p_batch_size = 1
    d_batch_size = 256
    chunksize = 256

    # Perform prefill
    perform_prefill(llm, input_length, p_batch_size)
    # Perform decode
    perform_decode(llm, input_length, d_batch_size)    
    # Perform chunked prefill
    # Split the chunk between decode and prefill tokens in proportion to output vs. total length
    d_tokens = chunksize * output_length // (input_length + output_length)
    p_tokens = chunksize - d_tokens
    d_context_length = input_length + output_length // 2    # Used to approximate the average decode computation
    p_context_length = int(input_length / 1.73 - p_tokens//2)             # 1.73 = sqrt(3); used to approximate the average prefill computation

    perform_chunkedprefill(llm, p_context_length, p_tokens, d_context_length, d_tokens)    

In [42]:
import math

#### Calculate time per kernels 
# Realistic tiling (P3)

def estimate_tiling_sizes(cache_size, bytes_per_value, num_fixed_dims, *dims):
    """
    Estimate tiling sizes for matrix multiplication given cache_size, bytes_per_value, and some fixed dimensions.
    If num_fixed_dims > 0, dims should provide the fixed dimensions in order (dim_1, dim_2, dim_3).
    The remaining dimensions will be estimated to fit the cache.

    Args:
        cache_size: total cache size available (e.g., L3 cache)
        bytes_per_value: bytes per matrix element
        num_fixed_dims: number of fixed dimensions provided (0-2)
        *dims: the fixed dimension values (dim_m, dim_k, dim_n) in order

    Returns:
        tuple: (dim_m, dim_k, dim_n) estimated tile sizes
    """
    match num_fixed_dims:
        case 0:
            # Assume cache is split equally among the three matrices
            size_per_matrix = cache_size / 3
            # For cubic root, use **(1/3)
            perfect_dim = (size_per_matrix / bytes_per_value) ** (1/2)
            return perfect_dim
        case 1:
            # Solve for x in: x * x + 2 * fixed_dim * x < cache_size / bytes_per_value
            # Rearranged: x^2 + 2 * fixed_dim * x - (cache_size / bytes_per_value) < 0
            # Use quadratic formula: x = (-b + sqrt(b^2 - 4ac)) / 2a, where a=1, b=2*fixed_dim, c=-(cache_size / bytes_per_value)
            b = dims[0]*2
            discriminant = b**2 + 4 * (cache_size / bytes_per_value)
            dim_2 = (-b + math.sqrt(discriminant)) / 2
            return dim_2
        case 2:
            # Because the third dimension is not fixed, we can use the remaining cache space
            # to estimate the third dimension
            remaining_capacity = (cache_size - (dims[0] * dims[1] * bytes_per_value))
            dim_3 = remaining_capacity / (bytes_per_value * (dims[0]+dims[1]))
            return dim_3
        
# Simple roofline: If matrix-matrix operations, assume compute bound.  If matrix-vector operations, assume memory bound.
# Assume wl is a WorkloadProfile instance (from perform_prefill, perform_decode, or perform_chunkedprefill)
# HW_SPEC should define MACs_per_second, L3_cache_capacity, Bandwidth, and Frequency
import numpy as np
def estimate_kernel_times_by_theoretical_tiler(wl, HW_SPEC):
    """Annotate every row of ``wl.layers`` with a tiling-based cycle estimate.

    For rows whose type contains "M" or "w" and that have a compute value, a
    tile shape is chosen so the tiles fill HW_SPEC["L3_cache_capacity"], then
    compute cycles and memory cycles are derived and the larger one becomes
    ``exec_cycle``. "KVCache" rows are charged pure memory cycles.
    "Async_KVCache" rows and any other type get NaN.

    Columns written (created when missing): best_tiler, exec_cycle, cmp_cycle,
    mem_cycle, mem_cycle_to_cmp_cycle, bound.

    Args:
        wl: object exposing a ``layers`` DataFrame with columns type, name,
            m, k, n, batch_size, compute and kv_cache.
        HW_SPEC: mapping providing MACs_per_second, L3_cache_capacity,
            Bandwidth and Frequency.

    Returns:
        ``wl.layers`` itself — the frame is modified in place, not copied.
    """
    # Theoretical tiling: the best tiling used all the cache space for matrix A, matrix B, and matrix C
    # Add new columns if not present
    for col, dtype in [
        ("best_tiler", object),
        ("exec_cycle", float),
        ("cmp_cycle", float),
        ("mem_cycle", float),
        ("mem_cycle_to_cmp_cycle", float),
        ("bound", object)
    ]:
        if col not in wl.layers.columns:
            wl.layers[col] = pd.Series([None]*len(wl.layers), dtype=dtype)
    for idx, layer in wl.layers.iterrows():
        layer_type = layer["type"]
        layer_name = layer["name"]  # NOTE(review): not used below
        best_tiler = None
        cmp_cycle = None
        mem_cycle = None
        mem_cycle_to_cmp_cycle = None
        exec_cycle = None
        layer_bound = None
        # Substring test on the type string: matches every matrix type and the
        # weight-vector type "Vw_Va", but not "Va_Va", "KVCache" or "Async_KVCache".
        if "M" in layer_type or "w" in layer_type:
            if not pd.isna(layer["compute"]):
                L3_cache_capacity = HW_SPEC["L3_cache_capacity"]
                wldims = [layer["m"], layer["k"], layer["n"]]
                batch_size = layer["batch_size"]

                # Step 0: use batch size to adjust dimensions so that the output matrix is as close to square as possible
                if batch_size > 1:
                    if wldims[0] > wldims[2] and wldims[0] / wldims[2] > batch_size:
                        # m dominates even after batching: fold the whole batch into n
                        wldims[2] = wldims[2] * batch_size
                    elif wldims[2] > wldims[0] and wldims[2] / wldims[0] > batch_size:
                        # n dominates even after batching: fold the whole batch into m
                        wldims[0] = wldims[0] * batch_size
                    else:
                        # The batch can square the output: m = n = sqrt(m * n * batch_size)
                        ratio = wldims[0] / wldims[2]
                        batch_size = batch_size / ratio
                        wldims[2] = wldims[0] = wldims[0] * ((batch_size) ** (1/2))

                # Step 1: find the tiling dims when all 3 dims are flexible
                t_dim = estimate_tiling_sizes(L3_cache_capacity, bytes_per_value, 0)
                # Pair each size with its axis index (0=m, 1=k, 2=n) so the
                # original order can be restored after sorting by size
                dims = [(0, wldims[0]), (1, wldims[1]), (2, wldims[2])]
                dims_sorted = sorted(dims, key=lambda x: x[1])
                if dims_sorted[0][1] >= t_dim:
                    tile_dims = [(0, t_dim), (1, t_dim), (2, t_dim)]
                else: # Step 2: find the tiling dims when 1 dim (the smallest) is fixed
                    t_dim = estimate_tiling_sizes(L3_cache_capacity, bytes_per_value, 1, dims_sorted[0][1])
                    if dims_sorted[1][1] >= t_dim:
                        dims_sorted[1] = (dims_sorted[1][0], t_dim)
                        dims_sorted[2] = (dims_sorted[2][0], t_dim)
                        tile_dims = sorted(dims_sorted, key=lambda x: x[0])
                    else: # Step 3: find the last tiling dim when the smallest 2 dims are fixed
                        t_dim = estimate_tiling_sizes(L3_cache_capacity, bytes_per_value, 2, dims_sorted[0][1], dims_sorted[1][1])
                        if dims_sorted[2][1] >= t_dim:
                            dims_sorted[2] = (dims_sorted[2][0], t_dim)
                            tile_dims = sorted(dims_sorted, key=lambda x: x[0])
                        else:
                            # All three dimensions fit as they are: no tiling needed
                            tile_dims = sorted(dims_sorted, key=lambda x: x[0])

                # Tile shape formatted as "m x k x n" with one decimal place
                best_tiler = f"{tile_dims[0][1]:.1f}x{tile_dims[1][1]:.1f}x{tile_dims[2][1]:.1f}"

                # Step 4: compare the cmp_cycle and the mem_cycle and decide the bound
                # Per-tile cost times the number of tiles (layer compute / block_compute);
                # seconds are converted to cycles by multiplying with the clock frequency.
                block_compute = tile_dims[0][1] * tile_dims[1][1] * tile_dims[2][1]
                cmp_cycle = block_compute / HW_SPEC["MACs_per_second"] * HW_SPEC["Frequency"] * layer["compute"]/block_compute
                # Traffic per tile counts the two operand tiles (m x k and k x n); the m x n output tile is not included
                memory = (tile_dims[0][1] * tile_dims[1][1] + tile_dims[2][1] * tile_dims[1][1]) * bytes_per_value
                mem_cycle = memory / HW_SPEC["Bandwidth"] * HW_SPEC["Frequency"] * layer["compute"]/block_compute
                mem_cycle_to_cmp_cycle = mem_cycle / cmp_cycle if cmp_cycle else np.nan

                # A tie is classified as memory-bound
                if cmp_cycle > mem_cycle:
                    exec_cycle = cmp_cycle
                    layer_bound = "Compute"
                else:
                    exec_cycle = mem_cycle
                    layer_bound = "Memory"
            else:
                # Matrix/vector row without a compute value: nothing to estimate
                exec_cycle = np.nan
                best_tiler = None
                cmp_cycle = np.nan
                mem_cycle = np.nan
                mem_cycle_to_cmp_cycle = np.nan
        elif layer_type == KVCache:
            if not pd.isna(layer["kv_cache"]):
                # Cache reads are pure memory traffic: bytes / bandwidth, expressed in cycles
                memory = layer["kv_cache"] * bytes_per_value
                exec_cycle = memory / HW_SPEC["Bandwidth"] * HW_SPEC["Frequency"]
                layer_bound = "Memory"
            else:
                exec_cycle = np.nan
        elif layer_type == Async_KVCache:
            # Asynchronous cache traffic is given no execution time here
            exec_cycle = np.nan
        else:
            exec_cycle = np.nan
        # Update DataFrame in place
        wl.layers.at[idx, "best_tiler"] = best_tiler
        wl.layers.at[idx, "exec_cycle"] = exec_cycle
        wl.layers.at[idx, "cmp_cycle"] = cmp_cycle
        wl.layers.at[idx, "mem_cycle"] = mem_cycle
        wl.layers.at[idx, "mem_cycle_to_cmp_cycle"] = mem_cycle_to_cmp_cycle
        wl.layers.at[idx, "bound"] = layer_bound
    return wl.layers

In [43]:
### Output 

# Estimated execution time for different kernels (P2)
# Estimated activities for different kernels for power analysis (P6)

# Breakdown between different kernels and the shape of the tensors: Mw_Ma, Ma_Ma, Mw_Va, Vw_Va, Va_Va
input_length = 8192
output_length = 512
p_batch_size = 1
d_batch_size = 256
chunksize = 256
# Split the chunk between decode and prefill tokens in proportion to output vs. total length
d_tokens = chunksize * output_length // (input_length + output_length)
p_tokens = chunksize - d_tokens
d_context_length = input_length + output_length // 2    # Used to approximate the average decode computation
p_context_length = int(input_length / 1.73 - p_tokens//2)             # 1.73 = sqrt(3); used to approximate the average prefill computation

# Scenario 1: prefill
print("Input length:", input_length)
wl = perform_prefill(llm, input_length, p_batch_size)
kernel_times = estimate_kernel_times_by_theoretical_tiler(wl, HW_SPEC)
print (wl.layers)

# Scenario 2: decode
wl = perform_decode(llm, d_context_length, d_batch_size)    
kernel_times = estimate_kernel_times_by_theoretical_tiler(wl, HW_SPEC)
print (wl.layers)

# Scenario 3: chunked prefill
wl = perform_chunkedprefill(llm, p_context_length, p_tokens, d_context_length, d_tokens)    
kernel_times = estimate_kernel_times_by_theoretical_tiler(wl, HW_SPEC)
print (wl.layers)

# NOTE(review): kernel_times is reassigned for every scenario, so the sums below
# cover only the last one (chunked prefill) — confirm that is intended.
compute_cycles = kernel_times.loc[kernel_times["bound"] == "Compute", "exec_cycle"].sum()
memory_cycles = kernel_times.loc[kernel_times["bound"] == "Memory", "exec_cycle"].sum()

print(f"Total compute-bound cycles: {compute_cycles:.0f}")
print(f"Total memory-bound cycles: {memory_cycles:.0f}")

Input length: 8192
Prefill Compute 2336462209024.0
Prefill Weights 218103808.0
Prefill KV cache 16777216.0
            type           name     m      k      n batch_size grouped_q  \
0          Ma_Mw        Q/K/V*W  8192   4096   6144          1         1   
1         Ma_Mqa           Q*gK  8192   4096   8192          1         1   
2         Ma_Mqa        Q*gK*gV  8192   4096   8192          1         1   
3          Ma_Mw            O*W  8192   4096   4096          1         1   
4          Ma_Mw         FFN_up  8192   4096  14336          1         1   
5          Ma_Mw       FFN_gate  8192   4096  14336          1         1   
6          Ma_Mw       FFN_down  8192  14336   4096          1         1   
7  Async_KVCache  KVCache_Store  8192    128     16          1         1   

        compute      weight    kv_cache  
0  2.061584e+11  25165824.0         NaN  
1  2.748779e+11         NaN         NaN  
2  2.748779e+11         NaN         NaN  
3  1.374390e+11  16777216.0         NaN 