## Summary

This notebook benchmarks PyTorch's `scaled_dot_product_attention` (SDPA) against vLLM's Triton-based flash attention kernel.

Key highlights:
- Environment Setup: GPU checks and Triton installation.
- Baseline Performance: Measure PyTorch SDPA runtimes for various sequence lengths.
- vLLM Triton Kernel Benchmark: Compare initial vLLM kernel performance vs. PyTorch.
  - Triton Autotuning & Caching:
    - The first run triggers autotuning (testing multiple configurations), making it slower.
    - The best configuration is cached for future runs.
    - Subsequent runs reuse the cached kernel and run significantly faster without re-tuning.
- Visualization: Plots compare runtimes before and after autotuning.
- Speedup Summary: A table and plots demonstrate consistent 2-4x speedups compared to PyTorch after caching.


In [None]:
!python triton-gpu-check.py

In [None]:
!cd triton && pip install ./python && cd -

In [None]:
import torch
import triton
import triton.language as tl
import matplotlib.pyplot as plt
import time

print("Torch version:", torch.__version__)
print("Triton version:", triton.__version__)

## Flash Attention Benchmark (PyTorch SDPA vs vLLM Kernel)
This notebook benchmarks the PyTorch `scaled_dot_product_attention` against the vLLM Triton-based flash attention kernel.
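For reference, both implementations compute `softmax(Q Kᵀ · scale) V` per head, and PyTorch's SDPA uses `scale = 1/sqrt(head_dim)` by default. The sketch below is a naive single-head version in pure Python (no masking, no dropout) that can be used to sanity-check tiny inputs by hand. `reference_attention` is our own illustrative helper, not part of PyTorch or vLLM.

```python
import math

def reference_attention(q, k, v, scale=None):
    """Naive single-head attention: softmax(q @ k^T * scale) @ v.

    q: list of query rows; k, v: lists of key/value rows (lists of floats).
    scale defaults to 1/sqrt(head_dim), matching SDPA's default.
    """
    d = len(q[0])
    if scale is None:
        scale = 1.0 / math.sqrt(d)
    out = []
    for qi in q:
        # Scaled dot-product score of this query against every key
        scores = [scale * sum(a * b for a, b in zip(qi, kj)) for kj in k]
        # Numerically stable softmax: subtract the max before exponentiating
        m = max(scores)
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        weights = [w / total for w in weights]
        # Output row = attention-weighted sum of the value rows
        out.append([sum(w * vj[c] for w, vj in zip(weights, v)) for c in range(len(v[0]))])
    return out

# Both keys score equally against the query, so the output is the mean of the values.
q_demo = [[1.0, 0.0]]
k_demo = [[0.0, 1.0], [0.0, -1.0]]
v_demo = [[2.0, 4.0], [4.0, 8.0]]
print(reference_attention(q_demo, k_demo, v_demo))  # [[3.0, 6.0]]
```

Passing a different `scale` changes the softmax weights, which is why two attention implementations only agree when they use the same scaling factor.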

In [None]:
# Assumes flash_attention.py (vLLM's Triton flash attention kernel) is in the working directory or on PYTHONPATH
from flash_attention import triton_attention as vllm_flash_attention
from flash_attention import benchmark_flash_attention as vllm_benchmark
from flash_attention import attn_fwd


In [None]:
def run_pytorch_sdpa(q, k, v):
    return torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0)

In [None]:
!ls /workspace/.triton/cache

In [None]:
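def run_vllm_flash_attention(q, k, v, seqlen):
    # SDPA layout (batch, nheads, seqlen, head_dim) -> packed varlen layout (batch * seqlen, nheads, head_dim)
    batch, nheads, _, head_dim = q.shape
    q_flat = q.permute(0, 2, 1, 3).reshape(-1, nheads, head_dim)
    k_flat = k.permute(0, 2, 1, 3).reshape(-1, nheads, head_dim)
    v_flat = v.permute(0, 2, 1, 3).reshape(-1, nheads, head_dim)
    # Cumulative sequence offsets [0, seqlen, 2*seqlen, ...]; every sequence has the same length here
    cu_seqlens = torch.arange(0, batch + 1, dtype=torch.int32, device=q.device) * seqlen
    # Match SDPA's default softmax scale of 1/sqrt(head_dim); with 1.0 the two outputs would not agree
    sm_scale = head_dim ** -0.5
    # Positional args: o (None: let the wrapper allocate it), cu_seqlens_q, cu_seqlens_k,
    # max_seqlen_q, max_seqlen_k, causal, sm_scale, bias
    o, _ = vllm_flash_attention(q_flat, k_flat, v_flat, None, cu_seqlens, cu_seqlens, seqlen, seqlen, False, sm_scale, None)
    # Back to (batch, nheads, seqlen, head_dim) for comparison with SDPA
    return o.view(batch, seqlen, nheads, head_dim).permute(0, 2, 1, 3)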
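The wrapper above converts SDPA's padded `(batch, nheads, seqlen, head_dim)` layout into the packed, variable-length layout the vLLM kernel takes: all tokens are concatenated into `(total_tokens, nheads, head_dim)`, and `cu_seqlens` holds cumulative offsets, so sequence `i` occupies rows `cu_seqlens[i]:cu_seqlens[i+1]`. When every sequence has the same length, `arange(batch + 1) * seqlen` equals that cumulative sum. A pure-Python sketch of the bookkeeping (both helper names are illustrative):

```python
from itertools import accumulate

def cu_seqlens_from_lengths(lengths):
    """Cumulative sequence offsets: [0, l0, l0 + l1, ...]."""
    return [0] + list(accumulate(lengths))

def sequence_slices(cu_seqlens):
    """(start, end) row ranges in the packed tensor, one per sequence."""
    return list(zip(cu_seqlens[:-1], cu_seqlens[1:]))

# Variable-length batch: three sequences of 3, 5 and 2 tokens
cu_demo = cu_seqlens_from_lengths([3, 5, 2])
print(cu_demo)                   # [0, 3, 8, 10]
print(sequence_slices(cu_demo))  # [(0, 3), (3, 8), (8, 10)]

# Uniform lengths, as in this benchmark: identical to arange(n_seqs + 1) * seq_len
n_seqs, seq_len = 4, 128
assert cu_seqlens_from_lengths([seq_len] * n_seqs) == [i * seq_len for i in range(n_seqs + 1)]
```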

In [None]:
def benchmark_flash_attention(batch, nheads, head_dim, seqlen):
    q = torch.randn(batch, nheads, seqlen, head_dim, device='cuda')
    k = torch.randn(batch, nheads, seqlen, head_dim, device='cuda')
    v = torch.randn(batch, nheads, seqlen, head_dim, device='cuda')

    # One timed call per backend. CUDA kernels launch asynchronously, so synchronize
    # before reading the clock. The first vLLM call for a new shape also includes
    # Triton's JIT compilation and autotuning.
    torch.cuda.synchronize()
    start = time.perf_counter()
    out_torch = run_pytorch_sdpa(q, k, v)
    torch.cuda.synchronize()
    pytorch_time = time.perf_counter() - start

    torch.cuda.synchronize()
    start = time.perf_counter()
    out_vllm = run_vllm_flash_attention(q, k, v, seqlen)
    torch.cuda.synchronize()
    vllm_time = time.perf_counter() - start

    # Largest absolute elementwise difference between the two outputs
    diff_vllm = (out_torch - out_vllm).abs().max().item()
    return pytorch_time, vllm_time, diff_vllm
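Because `benchmark_flash_attention` times a single call, each measurement is sensitive to noise, and the vLLM number for a new shape includes compilation and autotuning. A common alternative is to warm up first and report the median of several runs. The sketch below is framework-agnostic; `median_runtime` and its `sync` hook are our own names. For GPU work you would pass `torch.cuda.synchronize` as `sync`, or use `triton.testing.do_bench`, which handles warmup and repetition itself.

```python
import statistics
import time

def median_runtime(fn, warmup=3, iters=10, sync=None):
    """Median wall-clock time of fn() in seconds, measured after warmup calls.

    sync: optional callable that blocks until queued work has finished
    (e.g. torch.cuda.synchronize, since GPU kernels launch asynchronously).
    """
    if iters < 1:
        raise ValueError("iters must be >= 1")
    for _ in range(warmup):  # absorbs one-time costs such as JIT compilation or autotuning
        fn()
    if sync is not None:
        sync()
    times = []
    for _ in range(iters):
        start = time.perf_counter()
        fn()
        if sync is not None:
            sync()
        times.append(time.perf_counter() - start)
    return statistics.median(times)

# CPU-only usage; on GPU this would look like
# median_runtime(lambda: run_pytorch_sdpa(q, k, v), sync=torch.cuda.synchronize)
t_demo = median_runtime(lambda: sum(range(10_000)))
print(f"median: {t_demo * 1e6:.1f} us")
```

The median is less affected than the mean by occasional slow outliers (for example, a run that overlaps with other GPU work).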

In [None]:
seqlens = [128, 256, 512, 1024]
batch, nheads, head_dim = 32, 8, 64
pytorch_times, vllm_times, vllm_diffs = [], [], []

for seqlen in seqlens:
    t_pt, t_vllm, d_vllm = benchmark_flash_attention(batch, nheads, head_dim, seqlen)
    pytorch_times.append(t_pt)
    vllm_times.append(t_vllm)
    vllm_diffs.append(d_vllm)
    print(f"Seqlen={seqlen}: PyTorch CUDA={t_pt:.4f}s, vLLM CUDA={t_vllm:.4f}s, Diff(vLLM)={d_vllm:.2e}")

In [None]:
!ls /workspace/.triton/cache

In [None]:
import os
import json
from pathlib import Path
import time
import pandas as pd
from tabulate import tabulate  # also needed by DataFrame.to_markdown

# Compiled-kernel binaries Triton stores next to each metadata JSON (NVIDIA: .cubin, AMD: .hsaco)
BINARY_SUFFIXES = (".cubin", ".hsaco")

def print_triton_cache_info_from_json():
    """
    Scans the Triton cache directory, reads every kernel metadata JSON file,
    prints one row per cached kernel (newest first), and exports
    CSV/Markdown/HTML reports.
    """
    # Triton uses $TRITON_CACHE_DIR if set, otherwise ~/.triton/cache
    triton_cache_dir = Path(os.getenv('TRITON_CACHE_DIR', Path.home() / '.triton' / 'cache'))

    if not triton_cache_dir.exists():
        print(f"Triton cache directory not found at {triton_cache_dir}")
        return

    json_files = list(triton_cache_dir.glob('**/*.json'))

    if not json_files:
        print("No kernel metadata (.json) files found.")
        return

    table_data = []

    for json_file in json_files:
        if json_file.name.startswith("__grp__"):
            continue  # Skip group metadata files

        try:
            with open(json_file, 'r') as f:
                metadata = json.load(f)

            # The kernel binary shares the JSON file's stem (e.g. attn_fwd.json -> attn_fwd.cubin)
            binary = next((json_file.with_suffix(s) for s in BINARY_SUFFIXES
                           if json_file.with_suffix(s).exists()), None)
            stat = json_file.stat()

            row = {
                "JSON File": json_file.name,
                "Created": time.ctime(stat.st_ctime),
                "Binary?": "Yes" if binary is not None else "No",
                "Binary Size (KB)": (binary.stat().st_size / 1024) if binary is not None else None,
                "Kernel Name": metadata.get("name", "-"),
                "Device": metadata.get("device", "-"),
                "Cache Key": metadata.get("cache_key", "-"),
                "Signature": metadata.get("signature", "-"),
                "Num Warps": metadata.get("num_warps", "-"),
                "Num Stages": metadata.get("num_stages", "-"),
                "_ctime": stat.st_ctime,  # numeric timestamp, used only for sorting
            }

            table_data.append(row)

        except Exception as e:
            print(f"Error reading {json_file}: {e}")

    if not table_data:
        print("No readable kernel metadata found.")
        return

    # Keep every kernel as its own row, newest first
    df = pd.DataFrame(table_data)
    df = df.sort_values("_ctime", ascending=False).drop(columns=["_ctime"])

    print(tabulate(df, headers="keys", tablefmt="fancy_grid", showindex=False))

    # Export formats
    output_dir = triton_cache_dir / "cache_report"
    output_dir.mkdir(exist_ok=True)
    df.to_csv(output_dir / "triton_cache_report.csv", index=False)
    df.to_markdown(output_dir / "triton_cache_report.md", index=False)
    df.to_html(output_dir / "triton_cache_report.html", index=False)

    print(f"\nReports saved to: {output_dir}")
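Triton's on-disk cache is typically organized as one subdirectory per compiled-kernel hash, holding the metadata JSON next to intermediate and binary artifacts (for example `.ttir`, `.ptx` and `.cubin` on NVIDIA, or `.hsaco` on AMD); the exact file set varies across Triton versions. Assuming that layout, the sketch below summarizes which artifact types each cache entry contains. `summarize_cache_entries` is illustrative, not a Triton API.

```python
import os
from collections import defaultdict
from pathlib import Path

def summarize_cache_entries(cache_dir):
    """Map each cache subdirectory name to the sorted file suffixes it contains."""
    summary = defaultdict(set)
    for path in Path(cache_dir).rglob("*"):
        if path.is_file():
            # Group by the directory that directly contains the file (one entry per hash)
            summary[path.parent.name].add(path.suffix or "<none>")
    return {entry: sorted(suffixes) for entry, suffixes in sorted(summary.items())}

# Usage: same directory resolution as the report function above
cache_dir = Path(os.getenv("TRITON_CACHE_DIR", Path.home() / ".triton" / "cache"))
if cache_dir.exists():
    for entry, suffixes in summarize_cache_entries(cache_dir).items():
        print(entry, suffixes)
```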


In [None]:
print_triton_cache_info_from_json()

In [None]:
plt.figure()
plt.plot(seqlens, pytorch_times, label="PyTorch SDPA (CUDA)")
plt.plot(seqlens, vllm_times, label="vLLM Flash Attention (CUDA)")
plt.xlabel("Sequence Length")
plt.ylabel("Time (s)")
plt.title("Flash Attention Performance: PyTorch vs vLLM on CUDA")
plt.legend()
plt.grid()
plt.show()

## What is Triton Autotuning?
Triton kernels can be **autotuned**: Triton tries several candidate configurations (block sizes, warp counts, pipeline stages) and keeps the fastest one for your specific GPU and workload shape.

Because the best configuration depends on both the hardware and the input shape, autotuning can noticeably improve performance over a single fixed configuration. It only picks the best of the candidates it is given, though, so it does not guarantee optimal GPU utilization.

**How does it work?**  
- The first time the kernel is called with a new value of its autotune `key` arguments, Triton compiles and benchmarks each candidate configuration.  
- It records which configuration is fastest.  
- The winner is cached under that key, so later calls with the same key reuse it without re-benchmarking.
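The loop described above can be illustrated without a GPU. The sketch below is a toy, pure-Python stand-in for `@triton.autotune`, not Triton's implementation: it times every candidate config once per distinct key and memoizes the winner in a dict, similar in spirit to the `cache` attribute inspected on `attn_fwd` later in this notebook. `ToyAutotuner`, `key_fn`, `fake_kernel`, and the config dicts are hypothetical.

```python
import time

class ToyAutotuner:
    """Benchmark each candidate config once per key and cache the fastest one."""

    def __init__(self, fn, configs, key_fn):
        self.fn = fn            # fn(*args, **config) does the actual work
        self.configs = configs  # list of keyword-argument dicts to try
        self.key_fn = key_fn    # maps call arguments to the tuning key
        self.cache = {}         # key -> best config
        self.benchmarks_run = 0

    def _time(self, args, config):
        start = time.perf_counter()
        self.fn(*args, **config)
        return time.perf_counter() - start

    def __call__(self, *args):
        key = self.key_fn(*args)
        if key not in self.cache:
            # First call with this key: time every candidate and keep the fastest
            timings = {i: self._time(args, cfg) for i, cfg in enumerate(self.configs)}
            self.benchmarks_run += len(self.configs)
            self.cache[key] = self.configs[min(timings, key=timings.get)]
        # Later calls with the same key reuse the cached winner
        return self.fn(*args, **self.cache[key])

def fake_kernel(n, block):
    # Pretend block=64 is optimal: cost grows with distance from 64
    time.sleep(abs(block - 64) * 2e-4)
    return n

toy_tuner = ToyAutotuner(fake_kernel, [{"block": 32}, {"block": 64}, {"block": 128}], key_fn=lambda n: n)
toy_tuner(1024)  # new key: benchmarks all three configs
toy_tuner(1024)  # same key: cache hit, no extra benchmarking
print(toy_tuner.cache, toy_tuner.benchmarks_run)
```

A new key (for example `toy_tuner(2048)`) triggers another full benchmarking pass, which mirrors why new shapes or flags make the first Triton call slow again.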

**Why do we re-run tuning?**  
- Hardware setups or driver versions may change.  
- Workload shapes (sequence lengths, batch sizes) might differ from defaults.  
- We want to confirm we’re using the best configuration for *this exact benchmark*.

In the next cell, we trigger this autotuning pass.


## Note on Triton Autotuning and Caching Example

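- On the **first run**, when the kernel is called with a combination of its autotune key arguments it has not seen before, **Triton triggers autotuning**. Which arguments count is set by the `key` list passed to `@triton.autotune`; for this kernel it may include shape parameters such as the head dimension and flags such as causal masking.  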
   - This process tries multiple kernel configurations in the background and picks the fastest one.
   - As a result, the **first run may be significantly slower** due to this tuning process.

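- Once the best-performing configuration is found, it is **cached**: the choice itself is held in the autotuner's in-memory cache (`attn_fwd.cache`) for the current Python process, and the compiled kernel binaries are written to Triton's on-disk cache (`$TRITON_CACHE_DIR` if set, otherwise `~/.triton/cache`; `/workspace/.triton/cache` in this environment).

- On **subsequent runs** with the same key in the same process:
   - Triton **reuses the cached configuration** and skips tuning.
   - This leads to **consistently fast kernel launches and execution** without re-tuning overhead.
   - In a fresh process, the compiled binaries are loaded from disk, so nothing is recompiled; by default, however, the candidates are benchmarked again unless your Triton version is configured to persist autotuning results.

- If you clear the on-disk cache and restart the Python process, the next run will re-trigger full autotuning, including compilation.

> In short:  
> - First run = autotuning + execution (slow but smart)  
> - Later runs with the same key = cached config + execution (fast and efficient)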
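To force a fresh autotuning pass, you can point Triton at an empty cache directory or delete the existing one. Triton reads the `TRITON_CACHE_DIR` environment variable; setting it before importing Triton is the safest choice. The autotuner's in-memory results live in the Python process, so restart the notebook kernel as well. A minimal sketch of a guarded cleanup helper, where `resolve_triton_cache_dir` and `clear_triton_cache` are hypothetical helpers rather than Triton APIs:

```python
import os
import shutil
from pathlib import Path

def resolve_triton_cache_dir():
    """$TRITON_CACHE_DIR if set, otherwise Triton's usual default ~/.triton/cache."""
    return Path(os.getenv("TRITON_CACHE_DIR", Path.home() / ".triton" / "cache"))

def clear_triton_cache(cache_dir=None, dry_run=True):
    """Delete the contents of the Triton cache directory (the directory itself is kept).

    Returns the top-level entries that were removed, or that would be removed
    when dry_run=True. Nothing is deleted unless dry_run=False.
    """
    cache_dir = Path(cache_dir) if cache_dir is not None else resolve_triton_cache_dir()
    if not cache_dir.is_dir():
        return []
    entries = sorted(cache_dir.iterdir())
    if not dry_run:
        for entry in entries:
            if entry.is_dir() and not entry.is_symlink():
                shutil.rmtree(entry)
            else:
                entry.unlink()
    return entries

# Preview first, then delete for real:
# print(clear_triton_cache())
# clear_triton_cache(dry_run=False)
```

Defaulting to a dry run avoids accidentally wiping a shared cache directory.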


In [None]:
# Trigger re-tuning (will reuse cached or search if needed)
vllm_benchmark.run(show_plots=False, print_data=True)

In [None]:
# Minimal script to trigger autotuning for a few configurations.
# Uses tune_* names so the benchmark globals (batch, nheads, head_dim, seqlen) are not overwritten.
import torch

tune_batch = 2
tune_nheads = 2
tune_head_dims = [64, 128]
tune_seqlen = 64

for tune_head_dim in tune_head_dims:
    for causal in [False, True]:
        # The wrapper call below takes no dropout argument, so dropout is not varied here
        total_tokens = tune_batch * tune_seqlen
        q = torch.randn((total_tokens, tune_nheads, tune_head_dim), device='cuda', dtype=torch.float16)
        k = torch.randn_like(q)
        v = torch.randn_like(q)
        o = torch.empty_like(q)
        cu_seqlens = torch.arange(0, tune_batch + 1, device='cuda', dtype=torch.int32) * tune_seqlen

        # This calls attn_fwd under the hood and triggers autotuning for each new key
        _ = vllm_flash_attention(q, k, v, o, cu_seqlens, cu_seqlens, tune_seqlen, tune_seqlen, causal, 1.0, None)

# Now inspect the best config chosen for each autotune key:
for key, config in attn_fwd.cache.items():
    print(f"Best config for {key}: {config.kwargs}, num_warps={config.num_warps}")


In [None]:
def inspect_autotune_cache(kernel, kernel_name=""):
    # Autotuner objects keep their chosen configs in `.cache` (tuning key -> triton.Config)
    if not hasattr(kernel, 'cache') or not kernel.cache:
        print(f"{kernel_name or 'This kernel'} does not have autotuning results.")
        return

    name = kernel_name or getattr(kernel, '__name__', type(kernel).__name__)
    print(f"\n=== Autotune Cache (Best Configs) for {name} ===")
    for key, config in kernel.cache.items():
        print(f"Key: {key}  ->  Best Config: {config.kwargs}, num_warps: {config.num_warps}")

    # `_benchmarked_configs` is not part of Triton's documented API; on a stock install this
    # branch usually won't run. It expects a list of (config, seconds) pairs per key.
    if hasattr(kernel, '_benchmarked_configs') and kernel._benchmarked_configs:
        print("\n=== Benchmarked Configs (all candidates) ===")
        for key, runs in kernel._benchmarked_configs.items():
            print(f"\nInput Key: {key}")
            sorted_runs = sorted(runs, key=lambda x: x[1])
            for config, timing in sorted_runs:
                print(f"  Config: {config.kwargs}, warps: {config.num_warps}, time: {timing * 1e3:.3f} ms")
            best = sorted_runs[0]
            print(f" Best: {best[0].kwargs} (time: {best[1] * 1e3:.3f} ms)")
    else:
        print("\nWARNING: benchmark history (_benchmarked_configs) not found or empty. This might mean:")
        print(" - The kernel has not been invoked yet.")
        print(" - This Triton version does not record per-config timings under that name.")


In [None]:
inspect_autotune_cache(attn_fwd, "attn_fwd")

In [None]:
# Use the same shapes as the baseline run so the comparison with pytorch_times is like-for-like
batch, nheads, head_dim = 32, 8, 64
vllm_tuned_times = []

for seqlen in seqlens:
    q = torch.randn(batch, nheads, seqlen, head_dim, device='cuda')
    k = torch.randn(batch, nheads, seqlen, head_dim, device='cuda')
    v = torch.randn(batch, nheads, seqlen, head_dim, device='cuda')

    torch.cuda.synchronize()
    start = time.perf_counter()
    out_vllm_tuned = run_vllm_flash_attention(q, k, v, seqlen)
    torch.cuda.synchronize()
    tuned_time = time.perf_counter() - start
    vllm_tuned_times.append(tuned_time)
    print(f"Seqlen={seqlen}: Tuned vLLM CUDA={tuned_time:.4f}s")

In [None]:
print(f"{'SeqLen':>8} | {'PyTorch Time (s)':>18} | {'vLLM Tuned Time (s)':>20} | {'Speedup (PyTorch/vLLM)':>24}")
print("-" * 79)
for seqlen, pt_time, tuned_time in zip(seqlens, pytorch_times, vllm_tuned_times):
    speedup = pt_time / tuned_time
    print(f"{seqlen:8} | {pt_time:18.6f} | {tuned_time:20.6f} | {speedup:24.2f}x")
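The speedup column is the ratio `t_pytorch / t_vllm`, so values above 1 mean the vLLM kernel is faster. To condense ratios across sequence lengths into one number, the geometric mean is the usual choice, because it treats a 2x speedup and a 2x slowdown symmetrically. A small sketch (`summarize_speedups` is an illustrative helper; Python 3.8+ also provides `statistics.geometric_mean`):

```python
import math

def summarize_speedups(baseline_times, candidate_times):
    """Per-entry speedups (baseline / candidate) and their geometric mean."""
    if not baseline_times or len(baseline_times) != len(candidate_times):
        raise ValueError("need two non-empty lists of equal length")
    speedups = [b / c for b, c in zip(baseline_times, candidate_times)]
    geo_mean = math.exp(sum(math.log(s) for s in speedups) / len(speedups))
    return speedups, geo_mean

# In this notebook you would call: summarize_speedups(pytorch_times, vllm_tuned_times)
speedups_demo, gmean_demo = summarize_speedups([2.0, 1.0], [1.0, 2.0])
print(speedups_demo, round(gmean_demo, 6))  # [2.0, 0.5] 1.0
```

An arithmetic mean of the same two ratios would report 1.25x, overstating a result where one case is twice as fast and the other twice as slow.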

In [None]:
for key, config in attn_fwd.cache.items():
    print(f"Best config for {key}: {config.kwargs}, num_warps={config.num_warps}")

In [None]:
plt.figure()
plt.plot(seqlens, pytorch_times, label="PyTorch SDPA (CUDA)")
plt.plot(seqlens, vllm_times, label="vLLM (Original)")
plt.plot(seqlens, vllm_tuned_times, label="vLLM (Autotuned)")
plt.xlabel("Sequence Length")
plt.ylabel("Time (s)")
plt.title("Flash Attention Benchmark: PyTorch vs vLLM (Before & After Autotune)")
plt.legend()
plt.grid()
plt.show()

In [None]:
plt.figure()
plt.plot(seqlens, pytorch_times, label="PyTorch SDPA (CUDA)")
plt.plot(seqlens, vllm_tuned_times, label="vLLM (Autotuned)")
plt.xlabel("Sequence Length")
plt.ylabel("Time (s)")
plt.title("Flash Attention Benchmark: PyTorch vs vLLM (After Autotune)")
plt.legend()
plt.grid()
plt.show()