In [1]:
! pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu118

Looking in indexes: https://download.pytorch.org/whl/nightly/cu118


In [2]:
import torch
from torch.nn.attention.flex_attention import flex_attention

In [3]:
#Step1: Define the modifying function.
def no_op(score, b, h, q_idx, kv_idx):
    return score

#Step2: Set up the Q,K,V vectors.
batch_size = 8
seq_len = 8

Q = torch.randn(size=(batch_size, seq_len, 128, 128), requires_grad=True, device="cuda")
K = torch.randn(size=(batch_size, seq_len, 128, 128), requires_grad=True, device="cuda")
V = torch.randn(size=(batch_size, seq_len, 128, 128), requires_grad=True, device="cuda")

In [16]:
%%file test_benchmark.py
import torch
from torch.nn.attention.flex_attention import flex_attention
from functools import lru_cache

#Step1: Define the modifying function.
@torch.compile 
def no_op(score, b, h, q_idx, kv_idx):
    return score

#Step2: Set up the Q,K,V vectors.
batch_size = 8
seq_len = 8

Q = torch.randn(size=(batch_size, seq_len, 128, 128), requires_grad=True, device="cuda")
K = torch.randn(size=(batch_size, seq_len, 128, 128), requires_grad=True, device="cuda")
V = torch.randn(size=(batch_size, seq_len, 128, 128), requires_grad=True, device="cuda")

# define the functions as callable.
def sdpa():
  out_sdpa = torch.nn.functional.scaled_dot_product_attention(Q, K, V)
  return out_sdpa

def flex_sdpa():
  out_flex = flex_attention(Q, K, V, score_mod=no_op)
  return out_flex

# make the benchmark functions.
def test_torch_sdpa(benchmark):
  result = benchmark(sdpa)

def test_flex_attention_no_op(benchmark):
  result = benchmark(flex_sdpa)

Writing test_benchmark.py


In [31]:
import torch
from torch.utils._triton import has_triton

print(has_triton())  # Without Triton, we can't use the optimizations we want?

False


In [17]:
! pytest test_benchmark.py --benchmark-compare

platform linux -- Python 3.10.12, pytest-8.3.4, pluggy-1.5.0
benchmark: 5.1.0 (defaults: timer=time.perf_counter disable_gc=False min_rounds=5 min_time=0.000005 max_time=1.0 calibration_precision=10 warmup=False warmup_iterations=100000)
rootdir: /kaggle/working
plugins: benchmark-5.1.0, typeguard-4.4.1, anyio-3.7.1
collected 2 items                                                                                  [0m[1m

test_benchmark.py [32m.[0m[32m.[0m[32m                                                                         [100%][0m


[33m----------------------------------------------------------------------------------------------- benchmark: 2 tests -----------------------------------------------------------------------------------------------[0m
Name (time in us)                    Min                    Max                  Mean              StdDev                Median                 IQR            Outliers          OPS            Rounds  Iterations
[33m--------

In [None]:
! pip install pytest-benchmark

In [None]:
# Lets define a helpful benchmarking function:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.benchmark as benchmark
def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
    t0 = benchmark.Timer(
        stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
    )
    return t0.blocked_autorange().mean * 1e6

# Lets define the hyper-parameters of our input
query = torch.randn(size=(seq_len, 128, 128), requires_grad=True, device="cpu")
key = torch.randn(size=(seq_len, 128, 128), requires_grad=True, device="cpu")
value = torch.randn(size=(seq_len, 128, 128), requires_grad=True, device="cpu")


print(f"The default implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds")

# Lets explore the speed of each of the 3 implementations
from torch.nn.attention import SDPBackend, sdpa_kernel


with sdpa_kernel(SDPBackend.MATH):
    math_time=benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, Q, K, V)
    print(f"The math implementation runs in {math_time:.3f} microseconds")

with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    try:
        flash_time=benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, Q, K, V)
        print(f"The flash attention implementation runs in {flash_time:.3f} microseconds")
    except RuntimeError as e:
        print(f"FlashAttention is not supported. See warnings for reasons: {e}")

with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
    try:
        efficient_time=benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, Q, K, V)
        print(f"The memory efficient implementation runs in {efficient_time:.3f} microseconds")
    except RuntimeError:
        print("EfficientAttention is not supported. See warnings for reasons.")

In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.benchmark as benchmark

class SmolAttention(nn.Module):
    def __init__(self):
        super().__init__()
        self.q = torch.nn.Linear(128,128, device="cpu")
        self.k = torch.nn.Linear(128,128, device="cpu")
        self.v = torch.nn.Linear(128,128, device="cpu")
        

    def forward(self, query, key, value):
        q = self.q(query)
        k = self.q(key)
        v = self.q(value)
        out = q @ k.transpose(-2, -1)
        probs = F.softmax(out)
        out = probs @ v
        return out


sa = SmolAttention()

batch_size = 8
seq_len = 8

q = torch.randn(size=(seq_len, 128, 128), requires_grad=True, device="cpu")
k = torch.randn(size=(seq_len, 128, 128), requires_grad=True, device="cpu")
v = torch.randn(size=(seq_len, 128, 128), requires_grad=True, device="cpu")

sa.eval()

out = sa(q,k,v) # This is working, now need to see how well the quantized model performs
# print(out.dtype) # fp32


sa_q = torch.ao.quantization.quantize_dynamic(
    sa,  # the original model
    {nn.Linear, nn.Linear, nn.Linear},  # a set of layers to dynamically quantize
    dtype=torch.qint8)  # the target dtype for quantized weights

# run the model. This works!! The values are close, but not enough s
out_q = sa_q(q,k,v)


  probs = F.softmax(out)


## What Next?

Before moving forward, let's have a brief recap:

1. FlexAttention is up and working. The `torch.compile` flex attention is not working. I suppose this is because we don't have triton, so we can't write optimized kernels. If we had a local cude machine, I would've narrowed down the error and give conclusive answer.

2. I didn't make use of block_mask in FlexAttention as sparsity is not a bottleneck I've explored as of now. Might do in future.

3. Apart from all these, I got to know the `F.scaled_dot_product_attention`, for CUDA tensor inputs makes use of 3 different techniques, namely: memory efficient attention, Flash attention, and a native C++ implementation. When this function is called, the best performing version is used. We can also isolate all these 3 methods(ptrblck to the rescue once again) and use any one of the specific methods.

4. Making quantization work was a bit of a hassle. PTDQ for now supports only `Linear` and `Recurrent` layers, so directly using `MultiheadAttention` was not working. A workaround was to implement the SDPA with Q,K,V as linear layers, and it seemed to be working(PTDQ, Eager mode)


Now, let us think what is something we can do given both of these axes are up and running.

1. Understand the KV Cache mechanism in Pytorch and how can we implement it with SDPA.
2. Running a lot of benchmark tests. If we are able to get all these 3(Quantization, FlexAttn, KV Cache) together with SDPA, out problem reduces to a search problem which we can optimize. In order to optimize, we need to run a lot of benchmark tests. Learn about Pytorch profiler, it'll come in real handy.
3. Start thinking about the API of the package if the benchmark tests are giving some results.

In [111]:
## KV Cache. Claude generated.

### v2
import torch
import torch.nn as nn
import torch.nn.functional as F

class CachedSDPA(nn.Module):
    def __init__(self, max_seq_len, head_dim):
        super().__init__()
        self.max_seq_len = max_seq_len
        self.head_dim = head_dim
        self.cache_k = None
        self.cache_v = None
        self.cur_len = 0
    
    def forward(self, q, k, v, is_causal=True):
        # q, k, v: [batch, heads, seq_len, head_dim]
        
        # Handle incremental state
        if self.cache_k is not None:
            k = torch.cat([self.cache_k, k], dim=2)
            v = torch.cat([self.cache_v, v], dim=2)
        
        # Update cache
        self.cache_k = k
        self.cache_v = v
        self.cur_len = k.shape[2]
        
        # Use PyTorch's native SDPA with incremental state
        out = F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=None,  # PyTorch handles causal mask internally when is_causal=True
            dropout_p=0.0,
            is_causal=is_causal
        )
        
        return out

    def reset_cache(self):
        self.cache_k = None
        self.cache_v = None
        self.cur_len = 0

Let's get to profiling. We'll see how things pan out.

In [6]:
from torch.profiler import profile, record_function, ProfilerActivity 

In [8]:
# test with quantized and non-quantized models.
with profile(activities=[ProfilerActivity.CPU], record_shapes=True, profile_memory=True) as prof1:
    with record_function("model_inference"):
        sa(q,k,v)


with profile(activities=[ProfilerActivity.CPU], record_shapes=True, profile_memory=True) as prof2:
    with record_function("model_inference"):
        sa_q(q,k,v)

  probs = F.softmax(out)


In [9]:
print("For Non-Quantized Model:")
print(prof1.key_averages().table(sort_by="cpu_time_total"))

For Non-Quantized Model:
------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                    Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg       CPU Mem  Self CPU Mem    # of Calls  
------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
         model_inference        40.31%       6.145ms       100.00%      15.245ms      15.245ms           0 b      -3.00 Mb             1  
            aten::linear        22.69%       3.459ms        40.99%       6.249ms       2.083ms       1.50 Mb           0 b             3  
             aten::addmm        12.35%       1.883ms        17.42%       2.656ms     885.465us       1.50 Mb       1.50 Mb             3  
            aten::matmul         0.29%      44.463us        15.21%       2.319ms       1.159ms       1.00 Mb           0 b             2  
  

In [28]:
print("For Quantized Model:")
print(prof2.key_averages().table(sort_by="cpu_time_total"))

For Quantized Model:
-----------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                         Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg       CPU Mem  Self CPU Mem    # of Calls  
-----------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
              model_inference        14.15%     859.224us       100.00%       6.073ms       6.073ms           0 b      -3.00 Mb             1  
    quantized::linear_dynamic        44.86%       2.724ms        46.98%       2.853ms     951.002us       1.50 Mb      -1.50 Mb             3  
                 aten::matmul         0.97%      59.103us        28.70%       1.743ms     871.582us       1.00 Mb           0 b             2  
                    aten::bmm        26.21%       1.592ms        26.23%       1.593ms     796.580us       1.00 Mb  

In [15]:
with profile(activities=[ProfilerActivity.CUDA], profile_memory=True) as prof3:
    with record_function("model_inference"):
        m(Q,K,V)

In [None]:
print("For KVCached Model:")
print(prof3.key_averages().table(sort_by="cpu_time_total"))

In [34]:
with profile(activities=[ProfilerActivity.CUDA], record_shapes=True, profile_memory=True) as prof4:
    with record_function("model_inference"):
        flex_attention(Q, K, V, score_mod=no_op)

In [None]:
print("For Flex Attention Model:")
print(prof4.key_averages().table(sort_by="cuda_time_total"))

I am able to do profiling for all the models, now I just need to figure out how to make graphs and such, or if Pytorch provides a experimentation framework so that we can get to a more detailed study.

In [37]:
prof1.export_chrome_trace("trace1.json")  # we can also export it to json and see it in chrome.

Now, we'll figure out how to run benchmarking(for timing, not memory). We can use Pytorch's built in utils benchmarking facility. 

That is good, but let me figure things out using the module `pytorch-benchmark`

Update: pytorch-bencmark does not contain the things I want, so I'll not be using that.

In [55]:
import torch
import matplotlib.pyplot as plt
import pandas as pd
from typing import List, Dict
from dataclasses import dataclass

@dataclass
class ProfilingResult:
    """Store profiling data for multiple runs"""
    name: str
    avg_time: float
    cpu_time: float
    #cuda_time: float
    cpu_memory: float
    #cuda_memory: float

class ProfileAnalyzer:
    def __init__(self):
        self.runs: Dict[str, List[ProfilingResult]] = {}
    
    def analyze_profile(self, prof: torch.profiler.profile, run_name: str = "default"):
        """Extract key metrics from profiler output"""
        if run_name not in self.runs:
            self.runs[run_name] = []
            
        for event in prof.key_averages():
            result = ProfilingResult(
                name=event.key,
                avg_time=event.cpu_time_total / 1000,  # convert to ms
                cpu_time=event.cpu_time_total / 1000,
                #cuda_time=event.cuda_time_total / 1000 if event.cuda_time_total else 0,
                cpu_memory=event.cpu_memory_usage / 1024 / 1024,  # convert to MB
                #cuda_memory=event.cuda_memory_usage / 1024 / 1024 if event.cuda_memory_usage else 0
            )
            self.runs[run_name].append(result)
    
    def plot_comparison(self, metric: str = "avg_time", top_k: int = 20):
        """Plot comparison of specified metric across runs"""
        plt.figure(figsize=(12, 6))
        
        # Convert data to DataFrame for easier plotting
        data = []
        for run_name, results in self.runs.items():
            for result in results:
                data.append({
                    "run": run_name,
                    "operation": result.name,
                    metric: getattr(result, metric)
                })
        
        df = pd.DataFrame(data)
        
        # Get top k operations by total time across all runs
        top_ops = df.groupby("operation")[metric].sum().nlargest(top_k).index
        df_filtered = df[df["operation"].isin(top_ops)]
        
        # Plot
        #df_pivot = df_filtered.pivot(index="operation", columns="run", values=metric)
        df_filtered.plot(kind="bar", ax=plt.gca())
        
        plt.title(f"Top {top_k} Operations by {metric}")
        plt.xticks(rotation=45, ha="right")
        plt.tight_layout()
        return plt.gcf()

# Example usage and bottleneck analysis
def analyze_bottlenecks(model, q, k, v, batch_size=32):
    """Example of identifying and analyzing bottlenecks"""
    analyzer = ProfileAnalyzer()
    
    # Profile original model
    with torch.profiler.profile(
        activities=[
            torch.profiler.ProfilerActivity.CPU,
            torch.profiler.ProfilerActivity.CUDA,
        ],
        record_shapes=True,
        with_stack=False
    ) as prof:
        output = model(q, k, v)
    
    analyzer.analyze_profile(prof, "original")
    
    # Example optimization: Add batch norm fusion
    model.eval()
    with torch.profiler.profile(
        activities=[
            torch.profiler.ProfilerActivity.CPU,
            torch.profiler.ProfilerActivity.CUDA,
        ],
        record_shapes=True,
        with_stack=False
    ) as prof:
        with record_function("model_inference"):
            output = model(q, k, v)
            
    analyzer.analyze_profile(prof, "optimized")
    
    # Plot comparisons
    analyzer.plot_comparison(metric="avg_time", top_k=5)
    return analyzer

In [None]:
# Profile your model
analyzer = analyze_bottlenecks(m, Q,K,V)

# Check most time-consuming operations
analyzer.plot_comparison(metric="avg_time") 
# Shows top operations by time
#analyzer.plot_comparison(metric="cuda_memory")  # Shows memory usage

In [57]:
## Claude Generated

import torch
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
from typing import List, Dict
from dataclasses import dataclass

@dataclass
class OperationMetrics:
    """Detailed metrics for a single operation"""
    name: str
    self_cpu_time: float  # ms
    cpu_memory: float     # MB
    cuda_time: float      # ms
    cuda_memory: float    # MB
    calls: int
    input_shapes: List[str]
    stack_context: str

class PerformanceAnalyzer:
    def __init__(self):
        self.operations: Dict[str, List[OperationMetrics]] = {}
    
    def analyze_profile(self, prof: torch.profiler.profile, run_name: str = "default"):
        """Extract detailed performance metrics with context"""
        self.operations[run_name] = []
        
        for event in prof.key_averages():
            # Get stack trace for context
            stack = event.stack if event.stack else []
            stack_context = "\n".join(str(frame) for frame in stack[-3:])  # Last 3 frames
            
            metrics = OperationMetrics(
                name=event.key,
                self_cpu_time=event.self_cpu_time_total / 1000,
                cpu_memory=event.cpu_memory_usage / 1024 / 1024,
                cuda_time=event.cuda_time_total / 1000 if event.cuda_time_total else 0,
                cuda_memory=event.cuda_memory_usage / 1024 / 1024 if event.cuda_memory_usage else 0,
                calls=event.count,
                input_shapes=[str(shape) for shape in event.input_shapes],
                stack_context=stack_context
            )
            self.operations[run_name].append(metrics)

    def plot_bottleneck_analysis(self, run_name: str = "default"):
        """Create actionable visualization of performance bottlenecks"""
        ops = self.operations[run_name]
        
        # Create figure with subplots
        fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(15, 10))
        
        # 1. Time per call analysis
        df_time = pd.DataFrame([{
            'Operation': op.name,
            'CPU Time/Call': op.self_cpu_time / op.calls,
            'CUDA Time/Call': op.cuda_time / op.calls,
            'Calls': op.calls
        } for op in ops])
        
        # Sort by total time per call
        df_time['Total Time/Call'] = df_time['CPU Time/Call'] + df_time['CUDA Time/Call']
        df_time = df_time.nlargest(10, 'Total Time/Call')
        
        # Plot time distribution
        df_time.plot(kind='barh', x='Operation', 
                    y=['CPU Time/Call', 'CUDA Time/Call'], 
                    ax=ax1, stacked=True)
        
        # Add call count annotations
        for i, calls in enumerate(df_time['Calls']):
            ax1.text(df_time['Total Time/Call'].max() * 1.05, i, 
                    f'Calls: {calls}', va='center')
        
        ax1.set_title('Top 10 Time-Consuming Operations (per call)')
        ax1.set_xlabel('Time (ms)')
        
        # 2. Memory impact visualization
        df_mem = pd.DataFrame([{
            'Operation': op.name,
            'CPU Memory (MB)': op.cpu_memory,
            'CUDA Memory (MB)': op.cuda_memory,
            'Input Shapes': '\n'.join(op.input_shapes[:2])  # Show first 2 shapes
        } for op in ops])
        
        df_mem['Total Memory'] = df_mem['CPU Memory (MB)'] + df_mem['CUDA Memory (MB)']
        df_mem = df_mem.nlargest(10, 'Total Memory')
        
        # Plot memory usage
        df_mem.plot(kind='barh', x='Operation', 
                   y=['CPU Memory (MB)', 'CUDA Memory (MB)'], 
                   ax=ax2, stacked=True)
        
        # Add input shape annotations
        for i, shapes in enumerate(df_mem['Input Shapes']):
            ax2.text(df_mem['Total Memory'].max() * 1.05, i, 
                    f'Shapes: {shapes}', va='center')
        
        ax2.set_title('Top 10 Memory-Intensive Operations')
        ax2.set_xlabel('Memory (MB)')
        
        plt.tight_layout()
        return fig

    def get_bottleneck_report(self, run_name: str = "default") -> str:
        """Generate actionable report of potential bottlenecks"""
        ops = self.operations[run_name]
        report = []
        
        # Find operations with high time/call ratio
        time_heavy = sorted(ops, 
                          key=lambda x: (x.self_cpu_time + x.cuda_time) / x.calls, 
                          reverse=True)[:5]
        
        report.append("Top 5 Time-Intensive Operations (per call):")
        for op in time_heavy:
            report.append(f"\n{op.name}:")
            report.append(f"- Time per call: {(op.self_cpu_time + op.cuda_time)/op.calls:.2f}ms")
            report.append(f"- Called {op.calls} times")
            report.append(f"- Input shapes: {', '.join(op.input_shapes[:2])}")
            if op.stack_context:
                report.append(f"- Context: {op.stack_context}")
        
        return "\n".join(report)

In [None]:
analyzer = PerformanceAnalyzer()

# Profile your model
with torch.profiler.profile(...) as prof:
    output = model(input)

analyzer.analyze_profile(prof, "baseline")

# Get visual and textual analysis
analyzer.plot_bottleneck_analysis("baseline")
print(analyzer.get_bottleneck_report("baseline"))

Now, we'll try to see Pytorch benchmark and see how it pans. The timeit module there can do multiple things, and we need to explore the thread facility in it :)

In [10]:
import torch.utils.benchmark as benchmark

In [12]:
num_threads = torch.get_num_threads()
num_threads

2

In [13]:
t0 = benchmark.Timer(
    stmt='flex_attention(Q,K,V, score_mod=no_op)',
    setup = "from __main__ import flex_attention, no_op",
    globals={'Q':Q, 'K': K, 'V':V},
    num_threads=num_threads,
    label = "Multithreaded SDPA - FlexAttention"
)

print(f"{t0.timeit(100)}")

<torch.utils.benchmark.utils.common.Measurement object at 0x7a00a2430c40>
Multithreaded SDPA - FlexAttention
setup: from __main__ import flex_attention, no_op
  8.17 ms
  1 measurement, 100 runs , 2 threads


In [None]:
## Pyorch code to compare benchmark performances:
## Will use later if required.
from itertools import product

# Compare takes a list of measurements which we'll save in results.
results = []

sizes = [1, 64, 1024, 10000]
for b, n in product(sizes, sizes):
    # label and sub_label are the rows
    # description is the column
    label = 'Batched dot'
    sub_label = f'[{b}, {n}]'
    x = torch.ones((b, n))
    for num_threads in [1, 4, 16, 32]:
        results.append(benchmark.Timer(
            stmt='batched_dot_mul_sum(x, x)',
            setup='from __main__ import batched_dot_mul_sum',
            globals={'x': x},
            num_threads=num_threads,
            label=label,
            sub_label=sub_label,
            description='mul/sum',
        ).blocked_autorange(min_run_time=1))
        results.append(benchmark.Timer(
            stmt='batched_dot_bmm(x, x)',
            setup='from __main__ import batched_dot_bmm',
            globals={'x': x},
            num_threads=num_threads,
            label=label,
            sub_label=sub_label,
            description='bmm',
        ).blocked_autorange(min_run_time=1))

compare = benchmark.Compare(results)
compare.print()

## Updates

Things are going and can go well depending of the next steps. Till now,

1. Able run profilers on different attention variants.
2. Run benchmarks on different variants, with option to compare different benchmarks too.
3. Able to run code on different threads, and note performance gains.
4. Have a custom library for graphing and understanding memeory and speed throughput.

Pytorch is really wonderful.

Now, just writing a testing suite is required. After that, only comparision is left.

Worklog soon!!

In [2]:
## There is real alpha in this.
import torch.utils.benchmark as benchmark
from torch.profiler import profile, record_function, ProfilerActivity 

class PerformanceAnalysis:
    def __init__(self, func, inputs, *args):
        self.m = func # this could be a nn.Module, function, or anything else.
        self.q, self.k, self.v = inputs #unpack  the inputs, we need it in this format only.
        self.args = args
        self.setup()
        
    def setup(self):
        pass 

    def profile(self):
        # Update, can add multiple settings.
        if self.q.device == "cuda":
            act = ProfilerActivity.CUDA
        else:
            act = ProfilerActivity.CPU
        
        with profile(activities=[act], record_shapes=True, profile_memory=True) as prof4:
            with record_function("model_inference"):
                if self.args:
                    self.m(self.q, self.k, self.v, self.args[0])
                else:
                    self.m(self.q, self.k, self.v)
        
        return prof4.key_averages().table()
    
    def benchmark(self, use_threads=True, num_exprs=100):
        # Custom Logic to make the statement for Benchmark Timer

        # If we figure out the class/function, setup part is done. Just need to fire out how the Q,K,V names are made.
        # Update: No need. We did it lol.
        # Update: Stuck on the benchmark class/function thing. 
        # Update: Made it work after scraping lol.
        
        import inspect
        if inspect.isfunction(self.m):
            func_name = f"{self.m.__name__}"
            # print(func_name)
            name = func_name
            if self.args:
                module_name = f"{self.args[0].__name__}"
                stmt_str = f"M(Q, K, V, {module_name})"
                setup_str = f"from __main__ import {module_name}"
            else:
                stmt_str = f"M(Q, K, V)"
                setup_str = f"from __main__ import {func_name}"
        else: # it must be a class. Add checks, but for now we can proceed.
            class_name = f"{self.m.__class__.__name__}"
            name = class_name
            if self.args:
                module_name = f"{self.args[0]}"
                stmt_str = f"M(Q, K, V, {module_name})"
                setup_str = f"from __main__ import {class_name}, {module_name}"
            else:
                stmt_str = f"M(Q, K, V)"
                setup_str = f"from __main__ import {class_name}"

        if use_threads:
            # Sorted
            num_threads = torch.get_num_threads()
            t0 = benchmark.Timer(
                stmt = stmt_str,
                setup = setup_str,
                globals={'M':self.m, 'Q':self.q, 'K': self.k, 'V':self.v},
                num_threads=num_threads,
                label = f"Multi Threaded SDPA - {name}"
            )
        else:
            t0 = benchmark.Timer(
                stmt = stmt_str,
                setup = setup_str,
                globals={'M': self.m, 'Q':self.q, 'K': self.k, 'V':self.v},
                label = f"Single Threaded SDPA - {name}"
            )

        return t0.timeit(num_exprs)
        

    def run(self):
        pass 

    def report(self):
        return f"""
=======================================================================================================================
Performance Analysis - Memory and Benchmark Report:
=======================================================================================================================
***********************************************************************************************************************
Memory Profile Report:
***********************************************************************************************************************

{self.profile()}
        
***********************************************************************************************************************
Benchmark Profile Report:
***********************************************************************************************************************

{self.benchmark()}"""

In [None]:
pa = PerformanceAnalysis(flex_attention, (Q, K, V), no_op)

print(pa.report())

Updates:

**Note:** The thing with pytorch built in benchmark is that it is just cumbersome. You write strings and such and it has a confusing API where we can easily run into scoping problems.

**Update:** After one initial ray of brilliance, I'm able to get the pytorch built in benchmark compatible with my class. It is going to be easy, a LOT!! 
 
Don't wanna procrastinate on writing tests and such. We should just get to it, and save the results in some file idk. Just to have a sanity check, I'll start off by running profiler and benchmark for:

1. KVCache SDPA
2. Quantized/Non Quantized SDPA. 
3. Flex Attention with different score modifiers.

How should we go about it? 
For reproducability, set the seed for Pytorch.

In a single cell, write the test for each variant -> Write the cell into a file -> Pytest.

In the cell, model+inputs -> profiler -> benchmarker -> return/print results. 

# KV Cache Tests

This work slaps ngl. The work we did to make the `PerformanceAnalysis` class reduced a lot of boilerplate, and the overhead from this was reduced a lot. We just need to make it more general to that we can pass on hyperparameters and can test with more options, but first, we had to make it work.

**Note:** We are running frequently into CUDA out of memory errors, so stay in the in zone till we figure out ways to deal with it.

In [3]:
## Actual Definition:

import torch
import torch.nn as nn
import torch.nn.functional as F

class CachedSDPA(nn.Module):
    def __init__(self, max_seq_len, head_dim):
        super().__init__()
        self.max_seq_len = max_seq_len
        self.head_dim = head_dim
        self.cache_k = None
        self.cache_v = None
        self.cur_len = 0
    
    def forward(self, q, k, v, is_causal=True):
        # q, k, v: [batch, heads, seq_len, head_dim]
        
        # Handle incremental state
        if self.cache_k is not None:
            k = torch.cat([self.cache_k, k], dim=1)
            v = torch.cat([self.cache_v, v], dim=1)
        
        # Update cache
        self.cache_k = k
        self.cache_v = v
        self.cur_len = k.shape[2]
        
        # Use PyTorch's native SDPA with incremental state
        out = F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=None,  # PyTorch handles causal mask internally when is_causal=True
            dropout_p=0.0,
            is_causal=is_causal
        )
        
        return out

    def reset_cache(self):
        self.cache_k = None
        self.cache_v = None
        self.cur_len = 0

# Now, these functions on their own won't make too much sense. Changing hyperparams, and similar things will actually be nice.
def test_cached_sdpa_memory():
    pa.profile()

def test_cached_sdpa_benchmark():
    pa.benchmark()

seq_lengths = [4, 8, 10, 12] 
hidden_dims = [16, 32, 64, 128]

# We are running into CUDA out of memory erros frequently, so stay in the required zone.

print("*********************** Cached SDPA Tests *************************")

for i, (sl, hd) in enumerate(zip(seq_lengths, hidden_dims)):
    print(f"=============== Experiment {i+1}: seq_len={sl}, hidden_dim={hd} =========================")
    q = torch.randn(size=(sl, hd, hd), requires_grad=True, device="cuda")
    k = torch.randn(size=(sl, hd, hd), requires_grad=True, device="cuda")
    v = torch.randn(size=(sl, hd, hd), requires_grad=True, device="cuda")

    m = CachedSDPA(sl, hd)

    pa = PerformanceAnalysis(m, (q,k,v))
    print("=============== Memory Profile: ===========================")
    print(pa.profile())
    
    print("=============== Benchmark Report: =========================")
    print(pa.benchmark())

    torch.cuda.empty_cache()
    del q,k,v,m,pa

*********************** Cached SDPA Tests *************************
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg       CPU Mem  Self CPU Mem    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                        model_inference         1.69%       5.734ms       100.00%     338.528ms     338.528ms           0 b           0 b             1  
                     aten::scaled_dot_product_attention         3.19%      10.803ms        98.31%     332.793ms     332.793ms           0 b           0 b             1  
               aten::_scaled_dot_product_attention_math         3.85%      13.026m

## Quantized/Non Quantized Tests

If we were able to do the previous test, this will be ez af.

In [8]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.benchmark as benchmark

class SmolAttention(nn.Module):
    def __init__(self, max_seq_len, hidden_dims):
        super().__init__()
        self.q = nn.Linear(hidden_dims, hidden_dims)
        self.k = nn.Linear(hidden_dims, hidden_dims)
        self.v = nn.Linear(hidden_dims, hidden_dims)

    def forward(self, query, key, value):
        q = self.q(query)
        k = self.k(query)
        v = self.v(query)
        return F.scaled_dot_product_attention(q, k, v)

seq_lengths = [4, 8, 10, 12] 
hidden_dims = [16, 32, 64, 128]


print("*********************** Non - Quantized SDPA Tests *************************")

for i, (sl, hd) in enumerate(zip(seq_lengths, hidden_dims)):
    print(f"=============== Experiment {i+1}: seq_len={sl}, hidden_dim={hd} =========================")
    q = torch.randn(size=(sl, hd, hd), requires_grad=True, device="cpu")
    k = torch.randn(size=(sl, hd, hd), requires_grad=True, device="cpu")
    v = torch.randn(size=(sl, hd, hd), requires_grad=True, device="cpu")

    m = SmolAttention(sl, hd)
    m_q = torch.ao.quantization.quantize_dynamic(m,{nn.Linear, nn.Linear, nn.Linear},dtype=torch.qint8)

    print("Non-Quantized Model: ")
    pa = PerformanceAnalysis(m, (q,k,v))
    print("=============== Memory Profile: ===========================")
    print(pa.profile())
    
    print("=============== Benchmark Report: =========================")
    print(pa.benchmark())
    
    print("Quantized Model: ")
    pa = PerformanceAnalysis(m_q, (q,k,v))
    print("=============== Memory Profile: ===========================")
    print(pa.profile())
    
    print("=============== Benchmark Report: =========================")
    print(pa.benchmark())

    del q,k,v,m,m_q,pa

*********************** Non - Quantized SDPA Tests *************************
Non-Quantized Model: 
--------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                        Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg       CPU Mem  Self CPU Mem    # of Calls  
--------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                             model_inference        17.26%       1.217ms       100.00%       7.047ms       7.047ms           0 b     -28.00 Kb             1  
                                aten::linear         2.24%     157.620us        76.38%       5.382ms       1.794ms      12.00 Kb           0 b             3  
                               aten::reshape         0.43%      30.023us         1.10%      77.865us      

## FlexAttention Tests
Now, we move on to the difficult stuff. If we are able to do this, things will be the easiest.

**Update:** This was easier than I though it would be, I was able to do it one shot. Now, we move on it difficult stuff. Just run experiments and see how to approach this difficult problem.

In [5]:
## Actual FlexAttention Call

import torch
from torch.nn.attention.flex_attention import flex_attention


@torch.compile 
def no_op(score, b, h, q_idx, kv_idx):
    return score

print("*********************** FlexAttention SDPA Tests *************************")

batch_sizes = [8, 16, 32]
seq_lengths = [8, 16, 32]
hidden_dims = [64, 128, 256]


for i, (bs, sl, hd) in enumerate(zip(batch_sizes, seq_lengths, hidden_dims)):
    print(f"=============== Experiment {i+1}: batch_size={bs} seq_len={sl}, hidden_dim={hd} =========================")
    q = torch.randn(size=(bs, sl, hd, hd), requires_grad=True, device="cuda")
    k = torch.randn(size=(bs, sl, hd, hd), requires_grad=True, device="cuda")
    v = torch.randn(size=(bs, sl, hd, hd), requires_grad=True, device="cuda")

    
    pa = PerformanceAnalysis(flex_attention, (q,k,v), no_op)    # extra score_mod is added
    print("=============== Memory Profile: ===========================")
    print(pa.profile())
    
    print("=============== Benchmark Report: =========================")
    print(pa.benchmark())
    
    
    del q,k,v,pa

*********************** FlexAttention SDPA Tests *************************
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg       CPU Mem  Self CPU Mem    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                        model_inference        13.21%      68.260ms       100.00%     516.692ms     516.692ms           0 b           0 b             1  
                                             aten::ones         0.00%      25.530us         0.16%     815.273us     815.273us           0 b           0 b             1  
                                            aten::empty         0.21%      

## Part n-1 Completed!!

Wow. I was actually able to break down problem and attend to each of them individually, the next big task now is to find the "goldiocks" zone, and use a combination of approaches for which we have experimented. Before going there, if we are able to report only the important metrics from the `ProfileAnalysis.report()` function, we can reduce a lot of useless print statements. 