In [None]:
# Run the local triton-gpu-check.py script in a subprocess; its output is shown below the cell.
# NOTE(review): this executes before the Triton install cell that follows — confirm the ordering is intended.
!python triton-gpu-check.py

In [None]:
# Install Triton from the local ./triton/python source tree.
# `%pip` (instead of `!pip`) installs into the environment of the running kernel,
# so the `import triton` below picks up this build.
%pip install ./triton/python

In [None]:
import torch
import triton
import triton.language as tl
import matplotlib.pyplot as plt
import time

# Print the installed torch and triton versions alongside the benchmark output.
print("Torch version:", torch.__version__)
print("Triton version:", triton.__version__)

## Flash Attention Benchmark (PyTorch SDPA vs vLLM Kernel)
This notebook benchmarks PyTorch's `scaled_dot_product_attention` (SDPA) against the vLLM Triton-based flash attention kernel and compares their outputs.

In [None]:
# Assumes the `flash_attention` module (the vLLM Triton flash-attention kernel) is importable from the notebook's directory or sys.path
from flash_attention import triton_attention as vllm_flash_attention
from flash_attention import benchmark_flash_attention as vllm_benchmark


In [None]:
# List /workspace/.triton/cache before this notebook has launched the attention kernel.
# NOTE(review): presumably this is Triton's on-disk kernel cache directory — confirm it matches TRITON_CACHE_DIR.
!ls /workspace/.triton/cache

In [None]:
def run_pytorch_sdpa(q, k, v):
    """Reference attention: PyTorch's built-in scaled dot-product attention.

    Called with no attention mask and no dropout; every other option is left
    at the PyTorch default.

    Args:
        q, k, v: query, key and value tensors, forwarded unchanged.

    Returns:
        The tensor returned by ``scaled_dot_product_attention``.
    """
    sdpa = torch.nn.functional.scaled_dot_product_attention
    attention_out = sdpa(q, k, v, attn_mask=None, dropout_p=0.0)
    return attention_out

In [None]:
# List /workspace/.triton/cache again; only function definitions have run since the previous
# listing, so no attention kernel has been launched by this notebook yet.
!ls /workspace/.triton/cache

In [None]:
def run_vllm_flash_attention(q, k, v, seqlen):
    """Run the vLLM Triton flash-attention kernel on SDPA-layout inputs.

    The inputs are converted from ``(batch, nheads, seqlen, head_dim)`` to the
    packed ``(batch * seqlen, nheads, head_dim)`` layout the kernel is called
    with, and the result is converted back to the input layout so it can be
    compared element-wise with ``run_pytorch_sdpa``.

    Args:
        q, k, v: 4-D tensors laid out as ``(batch, nheads, seqlen, head_dim)``.
        seqlen: length of every sequence in the batch (all sequences are
            treated as equal length); expected to equal ``q.shape[2]``.

    Returns:
        Attention output with the same ``(batch, nheads, seqlen, head_dim)``
        layout as ``q``.
    """
    batch, nheads, _, head_dim = q.shape

    def _pack(x):
        # (batch, nheads, seqlen, head_dim) -> (batch * seqlen, nheads, head_dim)
        return x.permute(0, 2, 1, 3).reshape(-1, x.shape[1], x.shape[3])

    q_flat, k_flat, v_flat = _pack(q), _pack(k), _pack(v)

    # Cumulative sequence boundaries: [0, seqlen, 2 * seqlen, ..., batch * seqlen].
    cu_seqlens_q = torch.arange(0, batch + 1, dtype=torch.int32, device=q.device) * seqlen
    cu_seqlens_k = cu_seqlens_q.clone()

    # The PyTorch SDPA reference scales logits by 1/sqrt(head_dim) by default.
    # Passing 1.0 here (as before) computes a different softmax, so the two
    # outputs could never match.
    # NOTE(review): assumes the 10th positional argument of the kernel is the
    # softmax scale (between `causal` and `bias`) — confirm against flash_attention.
    sm_scale = head_dim ** -0.5

    o, _ = vllm_flash_attention(
        q_flat, k_flat, v_flat, None,
        cu_seqlens_q, cu_seqlens_k, seqlen, seqlen,
        False, sm_scale, None,
    )
    # (batch * seqlen, nheads, head_dim) -> (batch, nheads, seqlen, head_dim)
    return o.view(batch, seqlen, nheads, head_dim).permute(0, 2, 1, 3)

In [None]:
def benchmark_flash_attention(batch, nheads, head_dim, seqlen, warmup=0, iters=1):
    """Time PyTorch SDPA and the vLLM kernel on identical random inputs.

    Args:
        batch, nheads, head_dim, seqlen: dimensions of the random q/k/v
            tensors, created on the CUDA device as
            ``(batch, nheads, seqlen, head_dim)``.
        warmup: number of untimed calls made before timing each
            implementation. With the default of 0 the timed call is the first
            call, so any one-off cost (e.g. kernel compilation) is included.
        iters: number of timed calls; the reported time is the mean per call.

    Returns:
        ``(pytorch_time, vllm_time, diff_vllm)``: seconds per call for each
        implementation and the maximum absolute element-wise difference
        between their outputs.

    Raises:
        ValueError: if ``warmup`` is negative or ``iters`` is less than 1.
    """
    if warmup < 0:
        raise ValueError("warmup must be >= 0")
    if iters < 1:
        raise ValueError("iters must be >= 1")

    def _timed(fn, *args):
        # Untimed warm-up calls (none by default).
        for _ in range(warmup):
            fn(*args)
        # CUDA launches are asynchronous: synchronize before starting and
        # before stopping the clock so the interval covers the GPU work.
        torch.cuda.synchronize()
        # perf_counter is monotonic and high-resolution, unlike time.time().
        start = time.perf_counter()
        for _ in range(iters):
            result = fn(*args)
        torch.cuda.synchronize()
        return result, (time.perf_counter() - start) / iters

    shape = (batch, nheads, seqlen, head_dim)
    q = torch.randn(shape, device='cuda')
    k = torch.randn(shape, device='cuda')
    v = torch.randn(shape, device='cuda')

    out_torch, pytorch_time = _timed(run_pytorch_sdpa, q, k, v)
    out_vllm, vllm_time = _timed(run_vllm_flash_attention, q, k, v, seqlen)

    diff_vllm = torch.max(torch.abs(out_torch - out_vllm)).item()
    return pytorch_time, vllm_time, diff_vllm

In [None]:
# Sweep configuration: fixed batch / head layout, varying sequence length.
batch = 32
nheads = 8
head_dim = 64
seqlens = [128, 256, 512, 1024]

# One entry per sequence length, in the same order as `seqlens`.
pytorch_times = []
vllm_times = []
vllm_diffs = []

for seqlen in seqlens:
    measurement = benchmark_flash_attention(batch, nheads, head_dim, seqlen)
    t_pt, t_vllm, d_vllm = measurement
    for series, value in zip((pytorch_times, vllm_times, vllm_diffs), measurement):
        series.append(value)
    print(f"Seqlen={seqlen}: PyTorch CUDA={t_pt:.4f}s, vLLM CUDA={t_vllm:.4f}s, Diff(vLLM)={d_vllm:.2e}")

In [None]:
# List /workspace/.triton/cache after the benchmark sweep has called the vLLM kernel,
# for comparison with the earlier listings.
!ls /workspace/.triton/cache

In [None]:
# Wall-clock time per sequence length for both implementations (first-pass measurements).
fig, ax = plt.subplots()
for series, series_label in (
    (pytorch_times, "PyTorch SDPA (CUDA)"),
    (vllm_times, "vLLM Flash Attention (CUDA)"),
):
    ax.plot(seqlens, series, label=series_label)
ax.set_xlabel("Sequence Length")
ax.set_ylabel("Time (s)")
ax.set_title("Flash Attention Performance: PyTorch vs vLLM on CUDA")
ax.legend()
ax.grid()
plt.show()

## What is Triton Autotuning?
Triton allows kernels to be **autotuned**, meaning it will try multiple kernel configurations (block sizes, warp counts, pipeline stages) to find the optimal setup for your specific GPU hardware and workload shape.

Autotuning can significantly improve performance, because the fastest configuration depends on both the GPU and the shape of the workload.

**How does it work?**  
- Triton runs benchmarks internally with different configurations.  
- It measures which configurations are fastest.  
- The result is cached, so later calls with the same tuning-key values reuse the best-found configuration instead of searching again.

**Why do we re-run tuning?**  
- Hardware setups or driver versions may change.  
- Workload shapes (sequence lengths, batch sizes) might differ from defaults.  
- We want to confirm we’re using the best configuration for *this exact benchmark*.

In the next cell, we trigger this autotuning pass.


In [None]:
# Run the benchmark object exported by the flash_attention module, printing its results
# (print_data=True) without drawing plots (show_plots=False).
# NOTE(review): presumably a triton.testing.perf_report benchmark whose kernel launches
# trigger autotuning (reusing cached results where available) — confirm in flash_attention.
vllm_benchmark.run(show_plots=False, print_data=True)

In [None]:
# Re-time the vLLM kernel for every sequence length, now that it has already
# been invoked earlier in the notebook. One entry per element of `seqlens`.
vllm_tuned_times = []

for seqlen in seqlens:
    # Fresh random inputs with the same shape as the first sweep.
    q = torch.randn(batch, nheads, seqlen, head_dim, device='cuda')
    k = torch.randn(batch, nheads, seqlen, head_dim, device='cuda')
    v = torch.randn(batch, nheads, seqlen, head_dim, device='cuda')

    # CUDA launches are asynchronous: synchronize on both sides of the timed
    # region. perf_counter is monotonic and high-resolution, which makes it
    # the appropriate clock for short intervals (time.time() is neither).
    torch.cuda.synchronize()
    start = time.perf_counter()
    out_vllm_tuned = run_vllm_flash_attention(q, k, v, seqlen)
    torch.cuda.synchronize()
    tuned_time = time.perf_counter() - start

    vllm_tuned_times.append(tuned_time)
    print(f"Seqlen={seqlen}: Tuned vLLM CUDA={tuned_time:.4f}s")

In [None]:
# Summary table: one row per sequence length with the PyTorch time, the re-timed
# vLLM time, and their ratio (values above 1 mean the vLLM kernel was faster).
table_header = f"{'SeqLen':>8} | {'PyTorch Time (s)':>18} | {'vLLM Tuned Time (s)':>20} | {'Speedup (PyTorch/vLLM)':>24}"
table_rule = "-" * 75
print(table_header)
print(table_rule)

table_rows = zip(seqlens, pytorch_times, vllm_tuned_times)
for seqlen, pt_time, tuned_time in table_rows:
    speedup = pt_time / tuned_time
    print(f"{seqlen:8} | {pt_time:18.6f} | {tuned_time:20.6f} | {speedup:24.2f}x")


In [None]:
# Overlay the first-pass timings and the re-timed vLLM run on a single set of axes.
fig, ax = plt.subplots()
for series, series_label in (
    (pytorch_times, "PyTorch SDPA (CUDA)"),
    (vllm_times, "vLLM (Original)"),
    (vllm_tuned_times, "vLLM (Autotuned)"),
):
    ax.plot(seqlens, series, label=series_label)
ax.set_xlabel("Sequence Length")
ax.set_ylabel("Time (s)")
ax.set_title("Flash Attention Benchmark: PyTorch vs vLLM (Before & After Autotune)")
ax.legend()
ax.grid()
plt.show()