# EdgeLLM Flash Attention Benchmark

This notebook benchmarks our Flash Attention implementation on a Tesla T4 GPU.

**Target:** Demonstrate 2-3x speedup over naive attention to help hit 630+ tok/s

**Tests:**
1. Correctness validation vs naive attention
2. Forward pass performance at various sequence lengths
3. Decode performance with KV cache (inference critical path)
4. Memory efficiency comparison

## 1. Environment Setup

In [None]:
import subprocess
import os

# Check GPU
# Ask nvidia-smi for name, compute capability and total memory as header-less
# CSV (one line per GPU). A missing binary or a failing query must not abort
# the notebook: fall back to the 'Unknown' / '0.0' defaults instead.
try:
    result = subprocess.run(
        ['nvidia-smi', '--query-gpu=name,compute_cap,memory.total',
         '--format=csv,noheader'],
        capture_output=True, text=True)
except OSError:
    # nvidia-smi is not installed or not executable on this machine.
    result = None

# Only trust stdout when the query succeeded; on failure it may hold an
# error message rather than CSV.
if result is not None and result.returncode == 0:
    gpu_info = result.stdout.strip()
else:
    gpu_info = ''
print(f"GPU: {gpu_info if gpu_info else 'not detected (nvidia-smi unavailable or failed)'}")

# Parse compute capability
# Use the first GPU's line only, and strip each field so the parse does not
# depend on the exact ", " separator. An empty query result yields no fields
# (note that ''.split(',') would yield [''] and defeat the fallbacks).
gpu_lines = gpu_info.splitlines()
parts = [field.strip() for field in gpu_lines[0].split(',')] if gpu_lines else []
GPU_NAME = parts[0] if len(parts) > 0 and parts[0] else 'Unknown'
COMPUTE_CAP = parts[1] if len(parts) > 1 and parts[1] else '0.0'
print(f"Compute Capability: {COMPUTE_CAP}")

In [None]:
# Clone repository
# Remove any checkout left over from a previous run so the clone starts clean.
!rm -rf ollama-api-gateway
# Shallow clone: only the latest commit is fetched.
!git clone --depth 1 https://github.com/umerkhan95/ollama-api-gateway.git
# Move the notebook's working directory into the mojo-gateway subproject.
# All paths here are relative, so this cell is meant to be run once from the
# notebook's starting directory (re-running it after a later %cd will not
# resolve the same locations).
%cd ollama-api-gateway/mojo-gateway

## 2. Build Flash Attention

In [None]:
# Check CUDA version
# Prints the version of the nvcc compiler found on PATH; the later cells
# invoke nvcc directly to build the benchmark programs.
!nvcc --version

In [None]:
# Build Flash Attention for T4
# Relative to the mojo-gateway directory entered by the clone cell.
%cd src/kernels/cuda
!make clean
# Invokes the Makefile's `t4` and `flash` targets.
!make t4 flash
# NOTE: `!` shell lines do not raise on a non-zero exit status, so this
# message is printed even if the make step above failed - check its output.
print("\nBuild complete!")
# List any shared libraries under lib/; ls errors are silenced and a
# placeholder message is echoed when nothing matches.
!ls -la ../../../lib/*.so 2>/dev/null || echo "Checking for libraries..."

## 3. Run Tests

In [None]:
# Build and run Flash Attention tests
# Runs the Makefile's `test-flash` target from the current working directory
# (src/kernels/cuda after the build cell's %cd).
!make test-flash

## 4. Custom Benchmark

In [None]:
%%writefile custom_fa_bench.cu
/**
 * Custom Flash Attention Benchmark for SmolLM-135M
 */

#include <stdio.h>
#include <stdlib.h>
#include <chrono>
#include <cuda_runtime.h>
#include "flash_attention.h"

#define WARMUP 10
#define RUNS 100

// Overwrite data[0 .. size) with pseudo-random floats in [-1, 1].
// Values come from rand(), so the sequence depends on the current srand() seed.
void fill_random(float* data, int size) {
    int idx = 0;
    while (idx < size) {
        // rand() / RAND_MAX lies in [0, 1]; recentre and stretch to [-1, 1].
        const float unit = (float)rand() / RAND_MAX;
        data[idx] = (unit - 0.5f) * 2.0f;
        ++idx;
    }
}

/**
 * Decode-path micro-benchmark for a SmolLM-135M-shaped attention layer.
 *
 * Fills the KV cache with 256 random tokens, times RUNS calls to
 * flash_attention_decode at the last cached position, and extrapolates a
 * per-token attention cost across all transformer layers. Results are
 * printed as text followed by a small JSON object.
 *
 * Returns 0 on success, 1 if a host allocation or the CUDA device
 * synchronisation fails.
 */
int main() {
    // SmolLM-135M config
    const int batch = 1;
    const int num_heads = 9;
    const int head_dim = 64;
    const int batch_heads = batch * num_heads;
    // The published SmolLM-135M configuration has 30 transformer layers
    // (9 is the attention-head count). Using 9 here understated per-token
    // attention time and overstated tok/s by roughly 3.3x.
    const int num_layers = 30;
    const int max_seq_len = 2048;
    const int cache_len = 256;
    const int decode_pos = cache_len - 1;  // last position written to the cache

    printf("\n=== SmolLM-135M Flash Attention Benchmark ===");
    printf("\n\nConfiguration:\n");
    printf("  Model: SmolLM-135M\n");
    printf("  Heads: %d\n", num_heads);
    printf("  Head dim: %d\n", head_dim);
    printf("  Hidden: %d\n\n", num_heads * head_dim);

    // Initialize
    // NOTE(review): return values are not inspected because their type is
    // declared in flash_attention.h, which is not visible here - confirm
    // whether these calls report failure.
    flash_attention_init(batch, num_heads, max_seq_len, head_dim);
    flash_attention_init_kv_cache(batch, num_heads, max_seq_len, head_dim);

    // One token's worth of data per buffer: [batch_heads, head_dim].
    const int single_size = batch_heads * head_dim;
    const size_t single_bytes = (size_t)single_size * sizeof(float);
    float* Q = (float*)malloc(single_bytes);
    float* K = (float*)malloc(single_bytes);
    float* V = (float*)malloc(single_bytes);
    float* O = (float*)malloc(single_bytes);
    if (!Q || !K || !V || !O) {
        fprintf(stderr, "Host allocation failed (%zu bytes per buffer)\n", single_bytes);
        free(Q);
        free(K);
        free(V);
        free(O);
        flash_attention_cleanup();
        return 1;
    }

    // Fixed seed so every run benchmarks the same data.
    srand(42);
    fill_random(Q, single_size);
    fill_random(K, single_size);
    fill_random(V, single_size);

    // Fill cache to cache_len tokens, one freshly randomised K/V row per position.
    for (int pos = 0; pos < cache_len; pos++) {
        fill_random(K, single_size);
        fill_random(V, single_size);
        flash_attention_update_kv_cache(K, V, batch_heads, pos, 1, head_dim);
    }

    printf("Decode benchmark (cache_len=%d):\n", cache_len);
    printf("----------------------------------\n");

    // Warmup
    for (int i = 0; i < WARMUP; i++) {
        flash_attention_decode(Q, K, V, O, batch_heads, decode_pos, head_dim);
    }
    // Drain any warm-up work still queued on the device so it is not
    // charged to the timed region below.
    cudaDeviceSynchronize();

    // Benchmark single attention layer
    auto start = std::chrono::high_resolution_clock::now();
    for (int i = 0; i < RUNS; i++) {
        flash_attention_decode(Q, K, V, O, batch_heads, decode_pos, head_dim);
    }
    const cudaError_t sync_status = cudaDeviceSynchronize();
    auto end = std::chrono::high_resolution_clock::now();

    if (sync_status != cudaSuccess) {
        // A failed run would otherwise be reported as an (invalid) timing.
        fprintf(stderr, "CUDA error during benchmark: %s\n",
                cudaGetErrorString(sync_status));
        free(Q);
        free(K);
        free(V);
        free(O);
        flash_attention_cleanup();
        return 1;
    }

    double total_ms = std::chrono::duration<double, std::milli>(end - start).count();
    double per_layer_ms = total_ms / RUNS;
    double per_token_ms = per_layer_ms * num_layers;  // one attention call per layer
    double tokens_per_sec = 1000.0 / per_token_ms;

    printf("  Per-layer attention: %.4f ms\n", per_layer_ms);
    printf("  Per-token (%d layers): %.4f ms\n", num_layers, per_token_ms);
    printf("\n");
    printf("Attention-only throughput: %.1f tok/s\n", tokens_per_sec);
    printf("\n");

    // Estimate total inference throughput
    // Attention is typically 30-40% of total inference time; 35% is used as a
    // rough mid-point, so the figures below are an extrapolation, not a measurement.
    double attention_fraction = 0.35;
    double estimated_total_ms = per_token_ms / attention_fraction;
    double estimated_throughput = 1000.0 / estimated_total_ms;

    printf("Estimated full inference throughput:\n");
    printf("  (assuming attention = 35%% of compute)\n");
    printf("  Per-token total: %.3f ms\n", estimated_total_ms);
    printf("  Throughput: %.1f tok/s\n", estimated_throughput);
    printf("\n");

    printf("Targets:\n");
    printf("  Current target: 630 tok/s\n");
    printf("  Ollama baseline: 423 tok/s\n");
    printf("\n");

    // JSON output
    printf("JSON Results:\n");
    printf("{\n");
    printf("  \"per_layer_attention_ms\": %.4f,\n", per_layer_ms);
    printf("  \"per_token_attention_ms\": %.4f,\n", per_token_ms);
    printf("  \"attention_throughput\": %.1f,\n", tokens_per_sec);
    printf("  \"estimated_total_throughput\": %.1f,\n", estimated_throughput);
    printf("  \"target_throughput\": 630,\n");
    printf("  \"ollama_baseline\": 423\n");
    printf("}\n");

    free(Q);
    free(K);
    free(V);
    free(O);
    flash_attention_cleanup();

    return 0;
}

In [None]:
# Compile and run custom benchmark
# Targets compute capability 7.5 (compute_75 / sm_75) only. Links against
# flash_attention.o, which must already exist in the current directory
# (presumably produced by the earlier make step - verify if linking fails).
!nvcc -O3 -gencode arch=compute_75,code=sm_75 \
    -o custom_fa_bench custom_fa_bench.cu flash_attention.o -lcudart
!./custom_fa_bench

## 5. Compare: Flash Attention vs Naive Attention

In [None]:
%%writefile compare_attention.cu
/**
 * Compare Flash Attention vs Naive Attention
 */

#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#include <chrono>
#include <cuda_runtime.h>
#include "flash_attention.h"

// Atomically raise *addr to at least `val`, for IEEE-754 floats.
// Non-negative floats order like signed ints; negative floats order in
// reverse as unsigned ints. So use a signed max for the former and an
// unsigned min for the latter. A plain integer atomicMax on the bit
// pattern mis-orders negative values.
static __device__ __forceinline__ void atomic_max_float(float* addr, float val) {
    if (__float_as_int(val) >= 0) {
        atomicMax((int*)addr, __float_as_int(val));
    } else {
        atomicMin((unsigned int*)addr, __float_as_uint(val));
    }
}

// Naive attention kernel (for comparison)
//
// One thread block computes one query row (row = blockIdx.x):
//   scores = softmax(scale * Q[row] . K^T),  O[row] = scores @ V
// Q, K, V and O are indexed as row-major [seq_len, head_dim] for one head.
// No masking is applied. The launch must provide seq_len * sizeof(float)
// bytes of dynamic shared memory for the score row.
__global__ void naive_attention_kernel(
    const float* Q, const float* K, const float* V, float* O,
    int seq_len, int head_dim, float scale
) {
    int row = blockIdx.x;
    int tid = threadIdx.x;

    // Uniform per block (depends only on blockIdx), so returning before the
    // barriers below is safe.
    if (row >= seq_len) return;

    extern __shared__ float smem[];
    float* scores = smem;  // [seq_len]

    // Compute Q[row] @ K^T
    for (int j = tid; j < seq_len; j += blockDim.x) {
        float dot = 0.0f;
        for (int d = 0; d < head_dim; d++) {
            dot += Q[row * head_dim + d] * K[j * head_dim + d];
        }
        scores[j] = dot * scale;
    }
    __syncthreads();

    // Find max (reduction): per-thread partial max, then a block-wide
    // combine through a shared float.
    float max_val = -INFINITY;
    for (int j = tid; j < seq_len; j += blockDim.x) {
        max_val = fmaxf(max_val, scores[j]);
    }
    __shared__ float smax;
    if (tid == 0) smax = -INFINITY;
    __syncthreads();
    // Sign-aware float atomic max: stays correct when every score in the
    // row is negative (the previous integer atomicMax left smax at its
    // sentinel in that case, and the softmax overflowed to inf/NaN).
    atomic_max_float(&smax, max_val);
    __syncthreads();
    max_val = smax;

    // Softmax: subtract the row max before exponentiating for stability.
    float sum = 0.0f;
    for (int j = tid; j < seq_len; j += blockDim.x) {
        scores[j] = expf(scores[j] - max_val);
        sum += scores[j];
    }
    __shared__ float ssum;
    if (tid == 0) ssum = 0.0f;
    __syncthreads();
    atomicAdd(&ssum, sum);
    __syncthreads();

    for (int j = tid; j < seq_len; j += blockDim.x) {
        scores[j] /= ssum;
    }
    __syncthreads();

    // Output = scores @ V
    for (int d = tid; d < head_dim; d += blockDim.x) {
        float val = 0.0f;
        for (int j = 0; j < seq_len; j++) {
            val += scores[j] * V[j * head_dim + d];
        }
        O[row * head_dim + d] = val;
    }
}

// Fill the first `size` entries of `data` with pseudo-random values in [-1, 1].
// Draws one rand() sample per element, in index order.
void fill_random(float* data, int size) {
    for (int k = 0; k != size && k < size; ++k) {
        const float sample01 = (float)rand() / RAND_MAX;  // scaled into [0, 1]
        const float centred = sample01 - 0.5f;            // shifted to [-0.5, 0.5]
        data[k] = centred * 2.0f;
    }
}

// Abort with file/line context when a CUDA runtime call fails, so a failed
// allocation, copy or launch cannot be silently reported as a timing.
#define BENCH_CUDA_CHECK(call)                                              \
    do {                                                                    \
        cudaError_t bench_err_ = (call);                                    \
        if (bench_err_ != cudaSuccess) {                                    \
            fprintf(stderr, "CUDA error at %s:%d: %s\n", __FILE__, __LINE__, \
                    cudaGetErrorString(bench_err_));                        \
            exit(EXIT_FAILURE);                                             \
        }                                                                   \
    } while (0)

// Enqueue one naive-attention pass: one kernel launch per head, each with
// seq_len blocks of 256 threads and seq_len floats of dynamic shared memory.
// Does not synchronise; the caller decides when to wait.
static void launch_naive_pass(const float* d_Q, const float* d_K, const float* d_V,
                              float* d_O, int batch_heads, int seq_len,
                              int head_dim, float scale) {
    const size_t head_stride = (size_t)seq_len * head_dim;
    const size_t smem_bytes = (size_t)seq_len * sizeof(float);
    for (int bh = 0; bh < batch_heads; bh++) {
        naive_attention_kernel<<<seq_len, 256, smem_bytes>>>(
            d_Q + bh * head_stride,
            d_K + bh * head_stride,
            d_V + bh * head_stride,
            d_O + bh * head_stride,
            seq_len, head_dim, scale);
    }
    // Surface launch-configuration errors from the loop above.
    BENCH_CUDA_CHECK(cudaGetLastError());
}

/**
 * Times the naive attention kernel against flash_attention_forward for
 * several sequence lengths and prints a table of mean milliseconds per pass
 * and the resulting speedup.
 *
 * Exits with a non-zero status if a host allocation or CUDA call fails;
 * otherwise returns 0.
 */
int main() {
    printf("\n=== Flash Attention vs Naive Attention ===");
    printf("\n\n");

    const int batch_heads = 9;
    const int head_dim = 64;
    const int seq_lengths[] = {64, 128, 256, 512};
    const int num_tests = (int)(sizeof(seq_lengths) / sizeof(seq_lengths[0]));
    const int warmup_iters = 5;
    const int timed_iters = 50;

    printf("%-10s | %-12s | %-12s | %-10s\n",
           "Seq Len", "Naive (ms)", "Flash (ms)", "Speedup");
    printf("-----------|--------------|--------------|------------\n");

    for (int t = 0; t < num_tests; t++) {
        const int seq_len = seq_lengths[t];
        const int size = batch_heads * seq_len * head_dim;
        const size_t bytes = (size_t)size * sizeof(float);

        float* h_Q = (float*)malloc(bytes);
        float* h_K = (float*)malloc(bytes);
        float* h_V = (float*)malloc(bytes);
        float* h_O = (float*)malloc(bytes);
        if (!h_Q || !h_K || !h_V || !h_O) {
            fprintf(stderr, "Host allocation failed (%zu bytes per buffer)\n", bytes);
            exit(EXIT_FAILURE);
        }

        fill_random(h_Q, size);
        fill_random(h_K, size);
        fill_random(h_V, size);

        float *d_Q, *d_K, *d_V, *d_O;
        BENCH_CUDA_CHECK(cudaMalloc(&d_Q, bytes));
        BENCH_CUDA_CHECK(cudaMalloc(&d_K, bytes));
        BENCH_CUDA_CHECK(cudaMalloc(&d_V, bytes));
        BENCH_CUDA_CHECK(cudaMalloc(&d_O, bytes));

        BENCH_CUDA_CHECK(cudaMemcpy(d_Q, h_Q, bytes, cudaMemcpyHostToDevice));
        BENCH_CUDA_CHECK(cudaMemcpy(d_K, h_K, bytes, cudaMemcpyHostToDevice));
        BENCH_CUDA_CHECK(cudaMemcpy(d_V, h_V, bytes, cudaMemcpyHostToDevice));

        const float scale = 1.0f / sqrtf(head_dim);

        // Benchmark naive: warm up, drain the device, then time.
        for (int i = 0; i < warmup_iters; i++) {
            launch_naive_pass(d_Q, d_K, d_V, d_O, batch_heads, seq_len, head_dim, scale);
        }
        BENCH_CUDA_CHECK(cudaDeviceSynchronize());

        auto start = std::chrono::high_resolution_clock::now();
        for (int i = 0; i < timed_iters; i++) {
            launch_naive_pass(d_Q, d_K, d_V, d_O, batch_heads, seq_len, head_dim, scale);
        }
        BENCH_CUDA_CHECK(cudaDeviceSynchronize());
        auto end = std::chrono::high_resolution_clock::now();
        double naive_ms =
            std::chrono::duration<double, std::milli>(end - start).count() / timed_iters;

        // Benchmark Flash Attention
        // NOTE(review): flash_attention_forward is given host pointers, so if
        // it copies to/from the device internally those transfers are part of
        // flash_ms, while the naive timing above uses device-resident
        // buffers - confirm against flash_attention.h. The trailing `1` is
        // presumably a causal-mask flag, whereas the naive kernel applies no
        // mask - verify the two paths compute comparable work.
        flash_attention_init(1, batch_heads, seq_len, head_dim);

        for (int i = 0; i < warmup_iters; i++) {
            flash_attention_forward(h_Q, h_K, h_V, h_O, batch_heads, seq_len, head_dim, 1);
        }
        // Drain warm-up work so it is not charged to the timed region.
        BENCH_CUDA_CHECK(cudaDeviceSynchronize());

        start = std::chrono::high_resolution_clock::now();
        for (int i = 0; i < timed_iters; i++) {
            flash_attention_forward(h_Q, h_K, h_V, h_O, batch_heads, seq_len, head_dim, 1);
        }
        // Wait for any device work still in flight before stopping the clock,
        // mirroring the naive measurement.
        BENCH_CUDA_CHECK(cudaDeviceSynchronize());
        end = std::chrono::high_resolution_clock::now();
        double flash_ms =
            std::chrono::duration<double, std::milli>(end - start).count() / timed_iters;

        double speedup = naive_ms / flash_ms;

        printf("%-10d | %-12.3f | %-12.3f | %-10.2fx\n",
               seq_len, naive_ms, flash_ms, speedup);

        BENCH_CUDA_CHECK(cudaFree(d_Q));
        BENCH_CUDA_CHECK(cudaFree(d_K));
        BENCH_CUDA_CHECK(cudaFree(d_V));
        BENCH_CUDA_CHECK(cudaFree(d_O));
        free(h_Q);
        free(h_K);
        free(h_V);
        free(h_O);
        // Initialised per sequence length above, so release per iteration too.
        flash_attention_cleanup();
    }

    printf("\n");
    return 0;
}

In [None]:
# Compile and run the Flash-vs-naive comparison.
# Same flags as the custom benchmark: targets compute capability 7.5 only and
# links against flash_attention.o, which must exist in the current directory.
!nvcc -O3 -gencode arch=compute_75,code=sm_75 \
    -o compare_attention compare_attention.cu flash_attention.o -lcudart
!./compare_attention

## 6. Summary

In [None]:
import json
from datetime import datetime

# Assemble the run description. GPU_NAME and COMPUTE_CAP come from the
# environment-setup cell; every other field is a fixed label or target, so no
# measured timing ends up in this report.
summary = dict(
    timestamp=datetime.now().isoformat(),
    gpu=GPU_NAME,
    compute_capability=COMPUTE_CAP,
    benchmark="Flash Attention",
    implementation="EdgeLLM custom CUDA",
    features=[
        "Tiled computation",
        "Online softmax (O(N) memory)",
        "KV cache support",
        "Causal masking",
        "Optimized for SmolLM-135M (head_dim=64)",
    ],
    target_throughput=630,
    ollama_baseline=423,
)

# Show the report in the notebook output.
print(json.dumps(summary, indent=2))

# Persist the same JSON document in the current working directory.
with open('flash_attention_benchmark_report.json', 'w') as f:
    f.write(json.dumps(summary, indent=2))
print("\nReport saved!")

## 7. Cleanup

In [None]:
# Remove the two benchmark binaries and their generated .cu sources from the
# current directory. The JSON report and the cloned repository are left in place.
!rm -f custom_fa_bench compare_attention custom_fa_bench.cu compare_attention.cu
print("Cleanup complete!")