## Overall plan
# Setup
    Variables: the input/output token length, bytes_per_value, types of layers (Done)
    Variables: parallelism (P4)
    Mixed-precision (P6)
# Read config.json
    Read from local file (Done)
    Read from HuggingFace (P5)
# Read hardware specification
    Read from Excel spreadsheet (P2)
# Breakdown the models into different kernels/tensors
    Manually breakdown the GQA models (done)
    MOE models (P3)
    Automatic model conversion (P4)
# Calculate time per kernels 
    Simple roofline (done)
    Theoretical tiling (done)
    Hierarchical cache (P3)
    Realistic tiling (P4)
# Output 
    Breakdown between different kernels and the shape of the tensors: Mw_Ma, Ma_Ma, Mw_Va, Vw_Va, Va_Va (Done)
    Estimated execution time for different kernels (Done)
    Estimated activities for different kernels for power analysis (P6)

Setup

In [41]:
#### Setup
# Variables: the input/output token length
token_lengths = [
    (1024, 1024),
    (8192, 1024)
]
# Variables: workload characteristics
bytes_per_value = 2
# To-do list: mixed precision between weights and activations (P6)

# Variables: parallelism (P4)

# Variables: types of layers
Ma_Mw = "Ma_Mw" # The first matrix is the activations, the second matrix is the weights
Ma_Mqa = "Ma_Mqa" # The first matrix is the activations, the second matrix is the grouped activations (special usages for QGA)
Mw_Ma = "Mw_Ma" # The first matrix is the weights, the second matrix is the activations
Ma_Ma = "Ma_Ma" # The first matrix is the activations, the second matrix is the activations

Va_Mw = "Va_Mw" # The first vector is the activations, the second matrix is the weights
Va_Mqa = "Va_Mqa" # The first vector is the activations, the second matrix is the grouped activations (special usages for QGA)
Vw_Va = "Vw_Va" # The first vector is the weights, the second vector is the activations
Va_Va = "Va_Va" # The first vector is the activations, the second vector is the activations

KVCache = "KVCache" # This is to capture the retrieval of the key/value cache
Async_KVCache = "Async_KVCache" # This is to capture the store (and some retrieval) of the key/value cache

# Flags for debug and detailed print outputs
DEBUG_PRINT = True
DETAIL_PRINT = True 

Utility functions

In [42]:
# Converts an integer number into a size string with M, G, T, P suffix.
def convert_number_to_string(count):
    # Define the suffixes for bytes, terabytes, and petabytes
    suffixes = ["M", "G", "T", "P"]
    # Define the corresponding byte multiples
    multiples = [1024**2, 1024**3, 1024**4, 1024**5]

    # Iterate over the multiples in reverse to find the largest fitting multiple
    for i in reversed(range(len(multiples))):
        if count >= multiples[i]:
            size = count / multiples[i]  # Calculate the size in the corresponding unit
            return f"{size:.1f}{suffixes[i]}"  # Return the formatted size string
    # If the count is less than 1B, return it as a string
    return str(count)

# Function to print a list with a given name
def print_list(name, list):
    print("\t==", name, "==")
    for item in list:
        print("\t", item)

Read model configuration

In [43]:
# To-Do: Read from HuggingFace (P5)

import os
import json

# Read config.json from local disk 
# Specify the input file folder
input_folder = r"C:\Users\ychen4\OneDrive - Intel Corporation\Desktop\devtool\modelsize"
config_file = "config-8B.json"
config_path = os.path.join(input_folder, config_file)

with open(config_path, "r") as file:
    data = json.load(file)  # Parse the JSON file into a Python dictionary

# Extract relevant values from the JSON data
class LLMConfig:
    def __init__(self, data):
        self.hidden_size = data["hidden_size"]
        self.intermediate_size = data["intermediate_size"]
        self.num_q_heads = data["num_attention_heads"]
        self.num_kv_heads = data["num_key_value_heads"]
        self.num_hidden_layers = data["num_hidden_layers"]
        self.sub_dmodel = self.hidden_size // self.num_q_heads
        self.grouped_q = self.num_q_heads / self.num_kv_heads

llm = LLMConfig(data)

Read hardware specification

In [44]:
# To-Do: Read detailed hardware specification from spreadsheet (P2)

# NVL-AX
HW_SPEC = {
    "Num_Cores": 32,                        # Number of cores
    "Frequency": 3 * 1e9,                 # Frequency in Hz
    #"MACs_per_cycle": 2*1024,               # 2K MACs per cycle per core for FP16 (depth = 8)
    "MACs_per_cycle": 1*1024,               # 2K MACs per cycle per core for FP16 (depth = 4)
    "Bandwidth": 342*1024*1024*1024,        # 342 GB/s 
    "L3_cache_capacity": 36*1024*1024,      # 36 MB
    }

""" 
# JGS
HW_SPEC = {
    "Num_Cores": 192,                       # Number of cores
    "Frequency": 1.5 * 1e9,                 # Frequency in Hz
    "MACs_per_cycle": 8*1024,               # 8K MACs per cycle per core for FP32
    "Bandwidth": 30*1024*1024*1024*1024,    # 30 TB/s 
    "L3_cache_capacity": 120*1024*1024,     # 120 MB
    }
"""    
HW_SPEC["MACs_per_second"] = HW_SPEC["Num_Cores"] * HW_SPEC["Frequency"] * HW_SPEC["MACs_per_cycle"]


Breakdown the models into different kernels/tensors

In [45]:
# To-Do: (1) MOE model, (2) auto-conversion, (3) parallelism 

class WorkloadProfile:
    def __init__(self):
        self.layers = []
        self.compute = []
        self.weight = []
        self.kv_cache = []
        self.total_compute = 0
        self.total_weight = 0
        self.total_kv_cache = 0

    # Function to add compute usage to a list
    def add_compute(self, list, name, m, k, n, batch_size):
        # Calculate the number of computations
        computations = m * k * n * batch_size
        # Append the result to the list
        list.append([computations, name, m, k, n, batch_size])

    # Function to add memory usage to a list
    def add_wmemory(self, list, name, m, n):
        # Calculate the memory usage
        memory = m * n
        # Append the result to the list
        list.append([memory, name, m, n])

    # Function to add memory usage to a list
    def add_kv_cache(self, list, name, m, n, batch_size):
        # Calculate the memory usage
        cache = m * n * batch_size
        # Append the result to the list
        list.append([cache, name, m, n, batch_size])

    def add_layer(self, layer_type, name, *args, **kwargs):
        dim_m = args[0]
        dim_k = args[1]
        dim_n = args[2]
        batch_size = args[3]
        grouped_q = kwargs.get('grouped_q', 1)
        # Check if the layer type is valid
        if grouped_q > 1 and layer_type not in ["Ma_Mqa", "Va_Mqa"]:
            raise ValueError(f"Invalid layer type {layer_type} for grouped_q > 1")

        match layer_type:
            case "Ma_Mw":
                self.add_compute(self.compute, name, dim_m, dim_k, dim_n, batch_size)
                self.add_wmemory(self.weight, name, dim_k, dim_n)
            case "Ma_Mqa":
                self.add_compute(self.compute, name, dim_m*grouped_q, dim_k, dim_n, batch_size)
                # more optimization opportunity (P1)
            case "Mw_Ma":
                self.add_compute(self.compute, name, dim_m, dim_k, dim_n, batch_size)
                self.add_wmemory(self.weight, name, dim_m, dim_k)
            case "Ma_Ma":
                self.add_compute(self.compute, name, dim_m, dim_k, dim_n, batch_size)
            case "Va_Mw":
                self.add_compute(self.compute, name, dim_m, dim_k, dim_n, batch_size)
                self.add_wmemory(self.weight, name, dim_k, dim_n)
            case "Va_Mqa":
                self.add_compute(self.compute, name, dim_m*grouped_q, dim_k, dim_n, batch_size)
                # more optimization opportunity (P1)
            case "Vw_Va":
                self.add_compute(self.compute, name, dim_m, dim_k, dim_n, batch_size)
                self.add_wmemory(self.weight, name, dim_m, dim_k)
            case "Va_Va":
                self.add_compute(self.compute, name, dim_m, dim_k, dim_n, batch_size)
            case "KVCache":
                self.add_kv_cache(self.kv_cache, name, dim_m, dim_k*dim_n, batch_size)
            case "Async_KVCache":
                self.add_kv_cache(self.kv_cache, name, dim_m, dim_k*dim_n, batch_size)
            case _:
                raise ValueError(f"Unknown layer_type: {layer_type}")

        layer = {
            "type": layer_type,
            "name": name,
            "args": args,
            "kwargs": kwargs,
        }
        self.layers.append(layer)

def perform_prefill(llm, seq_length, p_batch_size):
    # Note some parts of the transformer are not modeled, e.g., 
    # Input and position embedding
    # Softmax
    # Residual connection and layer normalization
    # Task-specific output layer

    # Create a WorkloadProfile instance
    wl = WorkloadProfile()

    # Add layers using the WorkloadProfile object
    wl.add_layer(Ma_Mw, "Q/K/V*W", seq_length, llm.hidden_size, llm.sub_dmodel * (llm.num_q_heads+llm.num_kv_heads*2), p_batch_size)
    wl.add_layer(Ma_Mqa, "Q*gK", seq_length, llm.hidden_size, seq_length, p_batch_size, qrouped_q=llm.grouped_q)
    wl.add_layer(Ma_Mqa, "Q*gK*gV", seq_length, llm.hidden_size, seq_length, p_batch_size, qrouped_q=llm.grouped_q)
    wl.add_layer(Ma_Mw, "O*W", seq_length, llm.hidden_size, llm.hidden_size, p_batch_size)
    wl.add_layer(Ma_Mw, "FFN_up", seq_length, llm.hidden_size, llm.intermediate_size, p_batch_size)
    wl.add_layer(Ma_Mw, "FFN_gate", seq_length, llm.hidden_size, llm.intermediate_size, p_batch_size)
    wl.add_layer(Ma_Mw, "FFN_down", seq_length, llm.intermediate_size, llm.hidden_size, p_batch_size)
    # add memory to KV cache #### More optimization opportunity (P1)
    wl.add_layer(Async_KVCache, "KVCache_Store", seq_length, llm.sub_dmodel, 2*llm.num_kv_heads, p_batch_size)

    # Sum up the compute and memory usage
    wl.total_compute = sum(item[0] for item in wl.compute)
    wl.total_weight = sum(item[0] for item in wl.weight)
    wl.total_kv_cache = sum(item[0] for item in wl.kv_cache)

    if DEBUG_PRINT:
        print("Prefill Compute", wl.total_compute)
        print("Prefill Weights", wl.total_weight)
        print("Prefill KV cache", wl.total_kv_cache)
        print_list("Prefill Compute", wl.compute)
        print_list("Prefill Weights", wl.weight)
        print_list("Prefill KV cache", wl.kv_cache)

    # Print detailed information if the flag is set
    if DETAIL_PRINT:
        print("")
        print("   Details:")
        print("   === Input parameters ===")
        print("   Context length:", seq_length)
        print("   Bytes per FP value:", bytes_per_value) # To-Do: this is set globally, not specific to the model or workload (P2)
        print("   Layers:", llm.num_hidden_layers)

        print("   Total Compute (MACs):", convert_number_to_string(wl.total_compute * llm.num_hidden_layers))
        print("   Total Weights Footprint (Bytes):", convert_number_to_string(wl.total_weight * llm.num_hidden_layers))
        print("   Total KV Cache Footprint (Bytes):", convert_number_to_string(wl.total_kv_cache * llm.num_hidden_layers))

    return wl

def perform_decode(llm, context_length, d_batch_size): # Decoding one token at a time
    # Create a WorkloadProfile instance
    wl = WorkloadProfile()

    # Add layers using the WorkloadProfile object
    wl.add_layer(Va_Mw, "Q/K/V*W", 1, llm.hidden_size, llm.sub_dmodel * (llm.num_q_heads+2*llm.num_kv_heads), d_batch_size)
    wl.add_layer(KVCache, "KVCache_Read", context_length, llm.sub_dmodel, 2*llm.num_kv_heads, d_batch_size)
    wl.add_layer(Va_Mqa, "Q*gK", 1, llm.hidden_size, context_length, d_batch_size, qrouped_q=llm.grouped_q)
    wl.add_layer(Va_Mqa, "Q*gK*gV", 1, llm.hidden_size, context_length, d_batch_size, qrouped_q=llm.grouped_q)
    wl.add_layer(Va_Mw, "O*W", 1, llm.hidden_size, llm.hidden_size, d_batch_size)
    wl.add_layer(Va_Mw, "FFN_up", 1, llm.hidden_size, llm.intermediate_size, d_batch_size)
    wl.add_layer(Va_Mw, "FFN_gate", 1, llm.hidden_size, llm.intermediate_size, d_batch_size)
    wl.add_layer(Va_Mw, "FFN_down", 1, llm.intermediate_size, llm.hidden_size, d_batch_size)
    # add the last token to KV cache #### More optimization opportunity (P1)

    # Sum up the compute and memory usage
    wl.total_compute = sum(item[0] for item in wl.compute)
    wl.total_weight = sum(item[0] for item in wl.weight)
    wl.total_kv_cache = sum(item[0] for item in wl.kv_cache)

    if DEBUG_PRINT:
        print("Decode Compute", wl.total_compute)
        print("Decode Weights", wl.total_weight)
        print("Decode KV cache", wl.total_kv_cache)
        print_list("Decode Compute", wl.compute)
        print_list("Decode Weights", wl.weight)
        print_list("Decode KV cache", wl.kv_cache)

    # Print detailed information if the flag is set
    if DETAIL_PRINT:
        print("")
        print("   Details:")
        print("   === Input parameters ===")
        print("   Context length:", context_length)
        print("   Bytes per FP value:", bytes_per_value) # To-Do: this is set globally, not specific to the model or workload (P2)
        print("   Layers:", llm.num_hidden_layers)

        print("   Total Compute (MACs):", convert_number_to_string(wl.total_compute * llm.num_hidden_layers))
        print("   Total Weights Footprint (Bytes):", convert_number_to_string(wl.total_weight * llm.num_hidden_layers))
        print("   Total KV Cache Footprint (Bytes):", convert_number_to_string(wl.total_kv_cache * llm.num_hidden_layers))

    return wl

def perform_chunkedprefill(llm, p_context_length, p_tokens, d_context_length, d_tokens): # Decoding one token at a time
    # Create a WorkloadProfile instance
    wl = WorkloadProfile()

    # Add layers using the profile object
    wl.add_layer(Ma_Mw, "Q/K/V*W", (p_tokens+d_tokens), llm.hidden_size, llm.sub_dmodel * (llm.num_q_heads+llm.num_kv_heads*2), 1)

    # decode token attention
    wl.add_layer(KVCache, "dKVCache_Read", d_context_length, llm.sub_dmodel, 2*llm.num_kv_heads, d_tokens)
    wl.add_layer(Va_Mqa, "dQ*gK", 1, llm.hidden_size, llm.sub_dmodel * llm.num_kv_heads, d_tokens, qrouped_q=llm.grouped_q)
    wl.add_layer(Va_Mqa, "dQ*gK*gV", 1, llm.hidden_size, llm.sub_dmodel * llm.num_kv_heads, d_tokens, qrouped_q=llm.grouped_q)

    # prefill token attention
    wl.add_layer(Async_KVCache, "pKVCache_Read", p_context_length, llm.sub_dmodel, 2*llm.num_kv_heads, 1)
    wl.add_layer(Ma_Mqa, "pQ*gK", p_tokens, llm.hidden_size, p_context_length, 1, qrouped_q=llm.grouped_q)
    wl.add_layer(Ma_Mqa, "pQ*gK*gV", p_tokens, llm.hidden_size, p_context_length, 1, qrouped_q=llm.grouped_q)
    wl.add_layer(Async_KVCache, "KVCache_Store", p_tokens, llm.sub_dmodel, 2*llm.num_kv_heads, 1)

    wl.add_layer(Ma_Mw, "O*W", (p_tokens+d_tokens), llm.hidden_size, llm.hidden_size, 1)
    wl.add_layer(Ma_Mw, "FFN_up", (p_tokens+d_tokens), llm.hidden_size, llm.intermediate_size, 1)
    wl.add_layer(Ma_Mw, "FFN_gate", (p_tokens+d_tokens), llm.hidden_size, llm.intermediate_size, 1)
    wl.add_layer(Ma_Mw, "FFN_down", (p_tokens+d_tokens), llm.intermediate_size, llm.hidden_size, 1)

    # Sum up the compute and memory usage
    wl.total_compute = sum(item[0] for item in wl.compute)
    wl.total_weight = sum(item[0] for item in wl.weight)
    wl.total_kv_cache = sum(item[0] for item in wl.kv_cache)

    if DEBUG_PRINT:
        print("Chunked Compute", wl.total_compute)
        print("Chunked Weights", wl.total_weight)
        print("Chunked KV cache", wl.total_kv_cache)
        print_list("Chunked Compute", wl.compute)
        print_list("Chunked Weights", wl.weight)
        print_list("Chunked KV cache", wl.kv_cache)

    # Print detailed information if the flag is set
    if DETAIL_PRINT:
        print("")
        print("   Details:")
        print("   === Input parameters ===")
        print("   Prefill context length:", p_context_length)
        print("   Decode context length:", d_context_length)
        print("   Bytes per FP value:", bytes_per_value) # To-Do: this is set globally, not specific to the model or workload (P2)
        print("   Layers:", llm.num_hidden_layers)

        print("   Total Compute (MACs):", convert_number_to_string(wl.total_compute * llm.num_hidden_layers))
        print("   Total Weights Footprint (Bytes):", convert_number_to_string(wl.total_weight * llm.num_hidden_layers))
        print("   Total KV Cache Footprint (Bytes):", convert_number_to_string(wl.total_kv_cache * llm.num_hidden_layers))

    return wl

Kernel breakdowns

In [46]:
# Use the first tuple in token_lengths

if 0:
#for token_length in token_lengths:
    # Unpack the tuple into two variables
    input_length, output_length = token_length
    print("Input length:", input_length)
    print("Output length:", output_length)
    p_batch_size = 1
    d_batch_size = 256
    chunksize = 256

    # Perform prefill
    perform_prefill(llm, input_length, p_batch_size)
    # Perform decode
    perform_decode(llm, input_length, d_batch_size)    
    # Perform chunked prefill
    d_tokens = chunksize * output_length // (input_length + output_length)
    p_tokens = chunksize - d_tokens
    d_context_length = input_length + output_length // 2    # This is used to approximately the average decode computation
    p_context_length = int(input_length / 1.73 - p_tokens//2)             # 1.73 = sqrt(3) This is used to approximately the average prefill computation

    perform_chunkedprefill(llm, p_context_length, p_tokens, d_context_length, d_tokens)    

In [None]:
import math

#### Calculate time per kernels 
# Realistic tiling (P3)

# Simple roofline: If matrix-matrix operations, assume compute bound.  If matrix-vector operations, assume memory bound.
# Assume wl is a WorkloadProfile instance (from perform_prefill, perform_decode, or perform_chunkedprefill)
# HW_SPEC should define MACs_per_second, Bandwidth, and Frequency
def estimate_kernel_times(wl, HW_SPEC):
    kernel_times = []
    for layer in wl.layers:
        layer_type = layer["type"]
        name = layer["name"]
        args = layer["args"]

        # Matrix-matrix operations (compute bound)
        if layer_type in [Ma_Mw, Ma_Mqa, Mw_Ma, Ma_Ma]:
            # Find the corresponding compute entry
            compute_entry = next((c for c in wl.compute if c[1] == name), None)
            if compute_entry:
                compute = compute_entry[0]
                exec_cycle = compute / HW_SPEC["MACs_per_second"] * HW_SPEC["Frequency"]
                bound = "Compute"
            else:
                exec_cycle = None
                bound = "N/A"
        # Matrix-vector operations (memory bound)
        elif layer_type in [Va_Mw, Va_Mqa, Vw_Va, Va_Va]:
            # Find the corresponding weight entry
            weight_entry = next((w for w in wl.weight if w[1] == name), None)
            if weight_entry:
                memory = weight_entry[0] * bytes_per_value
                exec_cycle = memory / HW_SPEC["Bandwidth"] * HW_SPEC["Frequency"]
                bound = "Memory"
            else:
                exec_cycle = None
                bound = "N/A"
        elif layer_type == KVCache:
            # Find the corresponding weight entry
            cache_entry = next((w for w in wl.kv_cache if w[1] == name), None)
            if cache_entry:
                memory = cache_entry[0] * bytes_per_value
                exec_cycle = memory / HW_SPEC["Bandwidth"] * HW_SPEC["Frequency"]
                bound = "Memory"
            else:
                exec_cycle = None
                bound = "N/A"
        elif layer_type == Async_KVCache:
            exec_cycle = None
            bound = None
        else:
            exec_cycle = None
            bound = "Unknown"

        kernel_times.append({
            "name": name,
            "type": layer_type,
            "bound": bound,
            "exec_time in cycles": exec_cycle,
        })

        if DEBUG_PRINT:
            print(f"   Layer: {name}, Type: {layer_type}, Bound: {bound}, Exec Time: {exec_cycle:.0f} cycles" if exec_cycle is not None else f"      Layer: {name}, Type: {layer_type}, Bound: {bound}, Exec Time: N/A")

    return kernel_times

import math

def estimate_tiling_sizes(cache_size, bytes_per_value, num_fixed_dims, *dims):
    """
    Estimate tiling sizes for matrix multiplication given cache_size, bytes_per_value, and some fixed dimensions.
    If num_fixed_dims > 0, dims should provide the fixed dimensions in order (dim_1, dim_2, dim_3).
    The remaining dimensions will be estimated to fit the cache.

    Args:
        cache_size: total cache size available (e.g., L3 cache)
        bytes_per_value: bytes per matrix element
        num_fixed_dims: number of fixed dimensions provided (0-2)
        *dims: the fixed dimension values (dim_m, dim_k, dim_n) in order

    Returns:
        tuple: (dim_m, dim_k, dim_n) estimated tile sizes
    """
    match num_fixed_dims:
        case 0:
            # Assume cache is split equally among the three matrices
            size_per_matrix = cache_size / 3
            # For cubic root, use **(1/3)
            perfect_dim = (size_per_matrix / bytes_per_value) ** (1/2)
            return perfect_dim
        case 1:
            # Solve for x in: x * x + 2 * fixed_dim * x < cache_size / bytes_per_value
            # Rearranged: x^2 + 2 * fixed_dim * x - (cache_size / bytes_per_value) < 0
            # Use quadratic formula: x = (-b + sqrt(b^2 - 4ac)) / 2a, where a=1, b=2*fixed_dim, c=-(cache_size / bytes_per_value)
            b = dims[0]*2
            discriminant = b**2 + 4 * (cache_size / bytes_per_value)
            dim_2 = (-b + math.sqrt(discriminant)) / 2
            return dim_2
        case 2:
            # Because the third dimension is not fixed, we can use the remaining cache space
            # to estimate the third dimension
            remaining_capacity = (cache_size - (dims[0] * dims[1] * bytes_per_value))
            dim_3 = remaining_capacity / (bytes_per_value * (dims[0]+dims[1]))
            return dim_3
        
# Simple roofline: If matrix-matrix operations, assume compute bound.  If matrix-vector operations, assume memory bound.
# Assume wl is a WorkloadProfile instance (from perform_prefill, perform_decode, or perform_chunkedprefill)
# HW_SPEC should define MACs_per_second, L3_cache_capacity, Bandwidth, and Frequency
def estimate_kernel_times_by_theoretical_tiler(wl, HW_SPEC):

    # Theoretical tiling: the best tiling used all the cache space for matrix A, matrix B, and matrix C
    
    kernel_times = []
    for layer in wl.layers:
        layer_type = layer["type"]
        name = layer["name"]
        args = layer["args"]

        # Matrix-matrix operations or weight is associated in the layer type
        if "M" in layer_type or "w" in layer_type:
            # Find the corresponding compute entry
            compute_entry = next((c for c in wl.compute if c[1] == name), None)

            if compute_entry:
                # Unpack dimensions from args
                # Unpack dimensions from compute_entry if available, else fallback to args
                # compute_entry = (compute, name, dim_m, dim_k, dim_n, batch_size)
                L3_cache_capacity = HW_SPEC["L3_cache_capacity"]
                wldims = [compute_entry[2], compute_entry[3], compute_entry[4]]
                batch_size = compute_entry[5]

                # Step 0: use batch size to adjust dimensions so that the output matrix is as closer to square as possible
                if batch_size > 1:
                    # Adjust dimensions based on batch size to maximize cache reuse
                    if wldims[0] > wldims[2] and wldims[0] / wldims[2] > batch_size:
                        wldims[2] = wldims[2] * batch_size
                    elif wldims[2] > wldims[0] and wldims[2] / wldims[0] > batch_size:
                        wldims[0] = wldims[0] * batch_size
                    else:
                        ratio = wldims[0] / wldims[2]
                        batch_size = batch_size / ratio
                        wldims[2] = wldims[0] = wldims[0] * ((batch_size) ** (1/2))

                # Step 1: find the tiling dims when all 3 dims are flexible
                t_dim = estimate_tiling_sizes(L3_cache_capacity, bytes_per_value, 0)
                dims = [(0, wldims[0]), (1, wldims[1]), (2, wldims[2])]                
                dims_sorted = sorted(dims, key=lambda x: x[1])
                if dims_sorted[0][1] >= t_dim:
                    tile_dims = [(0, t_dim), (1, t_dim), (2, t_dim)]
                else: # Step 2: find the tiling dims when 1 dim (the smallest) is fixed
                    t_dim = estimate_tiling_sizes(L3_cache_capacity, bytes_per_value, 1, dims_sorted[0][1])
                    if dims_sorted[1][1] >= t_dim:
                        dims_sorted[1] = (dims_sorted[1][0], t_dim)
                        dims_sorted[2] = (dims_sorted[2][0], t_dim)
                        tile_dims = sorted(dims_sorted, key=lambda x: x[0])
                    else: # Step 3: find the last tiling dim when the smallest 2 dims are fixed
                        t_dim = estimate_tiling_sizes(L3_cache_capacity, bytes_per_value, 2, dims_sorted[0][1], dims_sorted[1][1])
                        if dims_sorted[2][1] >= t_dim:
                            dims_sorted[2] = (dims_sorted[2][0], t_dim) # fix the last dim due to the capacity limit
                            tile_dims = sorted(dims_sorted, key=lambda x: x[0])
                        else:
                            tile_dims = sorted(dims_sorted, key=lambda x: x[0])
                
                print(f"   WL_dims: {wldims[0]}x{wldims[1]}x{wldims[2]}, Best tiler: {tile_dims[0][1]:.1f}x{tile_dims[1][1]:.1f}x{tile_dims[2][1]:.1f}")
                compute = tile_dims[0][1] * tile_dims[1][1] * tile_dims[2][1]
                cmp_cycle = compute / HW_SPEC["MACs_per_second"] * HW_SPEC["Frequency"]
                memory = (tile_dims[0][1] * tile_dims[1][1] + tile_dims[2][1] * tile_dims[1][1]) * bytes_per_value
                mem_cycle = memory / HW_SPEC["Bandwidth"] * HW_SPEC["Frequency"]

                # Step 4: compare the cmp_cycle and the mem_cycle and decide the bound
                print(f"   cmp_cycle: {cmp_cycle*compute_entry[0]/compute:0.1f}, mem_cycle: {mem_cycle*compute_entry[0]/compute:0.1f}, mem_cycle-to-cmp_cycle: {(mem_cycle/cmp_cycle):0.1f}")
                if cmp_cycle > mem_cycle:
                    exec_cycle = cmp_cycle*compute_entry[0]/compute
                    bound = "Compute"
                else:
                    exec_cycle = mem_cycle*compute_entry[0]/compute
                    bound = "Memory"
            else:
                exec_cycle = None
                bound = "N/A"
        elif layer_type == KVCache:
            # Find the corresponding weight entry
            cache_entry = next((w for w in wl.kv_cache if w[1] == name), None)
            if cache_entry:
                memory = cache_entry[0] * bytes_per_value
                exec_cycle = memory / HW_SPEC["Bandwidth"] * HW_SPEC["Frequency"]
                bound = "Memory"
            else:
                exec_cycle = None
                bound = "N/A"
        elif layer_type == Async_KVCache:
            exec_cycle = None
            bound = None
        else:
            print (f"Unhandled layer type: {layer_type}")
            exec_cycle = None
            bound = "Unknown"

        kernel_times.append({
            "name": name,
            "type": layer_type,
            "bound": bound,
            "exec_time in cycles": exec_cycle,
        })

        if DEBUG_PRINT:
            print(f"   Layer: {name}, Type: {layer_type}, Bound: {bound}, Exec Time: {exec_cycle:.0f} cycles" if exec_cycle is not None else f"      Layer: {name}, Type: {layer_type}, Bound: {bound}, Exec Time: N/A")

    return kernel_times

In [48]:
#### Output 
# Estimated execution time for different kernels (P2)
# Estimated activities for different kernels for power analysis (P6)

# Breakdown between different kernels and the shape of the tensors: Mw_Ma, Ma_Ma, Mw_Va, Vw_Va, Va_Va
input_length = 1024
output_length = 512
p_batch_size = 1
d_batch_size = 256
chunksize = 256
d_tokens = chunksize * output_length // (input_length + output_length)
p_tokens = chunksize - d_tokens
d_context_length = input_length + output_length // 2    # This is used to approximately the average decode computation
p_context_length = int(input_length / 1.73 - p_tokens//2)             # 1.73 = sqrt(3) This is used to approximately the average prefill computation

wl = perform_prefill(llm, input_length, p_batch_size)
#wl = perform_decode(llm, input_length, d_batch_size)    
#wl = perform_chunkedprefill(llm, p_context_length, p_tokens, d_context_length, d_tokens)    
kernel_times = estimate_kernel_times_by_theoretical_tiler(wl, HW_SPEC)

compute_cycles = sum(k["exec_time in cycles"] for k in kernel_times if k["bound"] == "Compute" and k["exec_time in cycles"] is not None)
memory_cycles = sum(k["exec_time in cycles"] for k in kernel_times if k["bound"] == "Memory" and k["exec_time in cycles"] is not None)

print(f"Total compute-bound cycles: {compute_cycles:.0f}")
print(f"Total memory-bound cycles: {memory_cycles:.0f}")

Prefill Compute 231928233984
Prefill Weights 218103808
Prefill KV cache 2097152
	== Prefill Compute ==
	 [25769803776, 'Q/K/V*W', 1024, 4096, 6144, 1]
	 [4294967296, 'Q*gK', 1024, 4096, 1024, 1]
	 [4294967296, 'Q*gK*gV', 1024, 4096, 1024, 1]
	 [17179869184, 'O*W', 1024, 4096, 4096, 1]
	 [60129542144, 'FFN_up', 1024, 4096, 14336, 1]
	 [60129542144, 'FFN_gate', 1024, 4096, 14336, 1]
	 [60129542144, 'FFN_down', 1024, 14336, 4096, 1]
	== Prefill Weights ==
	 [25165824, 'Q/K/V*W', 4096, 6144]
	 [16777216, 'O*W', 4096, 4096]
	 [58720256, 'FFN_up', 4096, 14336]
	 [58720256, 'FFN_gate', 4096, 14336]
	 [58720256, 'FFN_down', 14336, 4096]
	== Prefill KV cache ==
	 [2097152, 'KVCache_Store', 1024, 2048, 1]

   Details:
   === Input parameters ===
   Context length: 1024
   Bytes per FP value: 2
   Layers: 32
   Total Compute (MACs): 6.8T
   Total Weights Footprint (Bytes): 6.5G
   Total KV Cache Footprint (Bytes): 64.0M
   WL_dims: 1024x4096x6144, Best tiler: 1024.0x3439.5x3439.5
   cmp_cycle: 78