# FlashAttention-2 Benchmark

Comparing attention implementations on Tesla T4:
- **FP32 Flash Attention** - Basic CUDA implementation
- **INT8 Tensor Core** - WMMA-based INT8 attention
- **FlashAttention-2** - Tiled with online softmax

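**Target:** 630+ tok/s (vs. the Ollama baseline of 423 tok/s). The benchmark below estimates throughput from the attention kernels alone, so its tok/s figure is an upper bound on end-to-end decoding speed.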

## 1. Environment Setup

In [None]:
import subprocess
import os


def parse_gpu_info(text):
    """Return (name, compute_capability) for the first GPU in nvidia-smi CSV output.

    `text` is the stdout of
    ``nvidia-smi --query-gpu=name,compute_cap,memory.total --format=csv,noheader``,
    one comma-separated line per GPU. Only the first line (first GPU) is parsed.
    Missing or empty fields fall back to 'Unknown' and '0.0' -- including the
    empty-output case, which previously produced an empty GPU name.
    """
    lines = text.strip().splitlines()
    fields = [f.strip() for f in lines[0].split(',')] if lines else []
    name = fields[0] if len(fields) > 0 and fields[0] else 'Unknown'
    cap = fields[1] if len(fields) > 1 and fields[1] else '0.0'
    return name, cap


# Check GPU. nvidia-smi may be missing (CPU-only runtime) or exit non-zero
# (e.g. driver problem, in which case stdout holds an error message rather than
# CSV); fall back to placeholder values instead of crashing or misparsing.
try:
    result = subprocess.run(['nvidia-smi', '--query-gpu=name,compute_cap,memory.total',
                             '--format=csv,noheader'], capture_output=True, text=True)
    gpu_info = result.stdout.strip() if result.returncode == 0 else ''
    if result.returncode != 0:
        print(f"nvidia-smi failed (exit {result.returncode}): {result.stderr.strip()}")
except OSError as exc:
    gpu_info = ''
    print(f"nvidia-smi unavailable: {exc}")
print(f"GPU: {gpu_info}")

GPU_NAME, COMPUTE_CAP = parse_gpu_info(gpu_info)
print(f"Compute Capability: {COMPUTE_CAP}")

In [None]:
# Clone repository
# Shallow clone (--depth 1: latest commit only); the preceding rm -rf makes the
# step repeatable by discarding any previous checkout first.
# NOTE(review): both paths are relative to the current working directory. The
# later %cd commands leave the kernel inside the checkout, so re-running this
# cell without restarting clones into a nested directory.
!rm -rf ollama-api-gateway
!git clone --depth 1 https://github.com/umerkhan95/ollama-api-gateway.git
%cd ollama-api-gateway/mojo-gateway

## 2. Build All Attention Kernels

In [None]:
# Move into the CUDA kernel sources (relative to mojo-gateway/) and clean old builds.
%cd src/kernels/cuda
!make clean

# Build all attention variants for T4
# Each make target is built separately with CUDA_ARCH / NVCC_FLAGS_COMMON set on
# the command line for sm_75 (Turing / T4). Only the INT8 target adds
# --expt-relaxed-constexpr.
!make CUDA_ARCH="-gencode arch=compute_75,code=sm_75" \
      NVCC_FLAGS_COMMON="-O3 -Xcompiler -fPIC -Xcompiler -Wall" \
      flash

!make CUDA_ARCH="-gencode arch=compute_75,code=sm_75" \
      NVCC_FLAGS_COMMON="-O3 -Xcompiler -fPIC -Xcompiler -Wall --expt-relaxed-constexpr" \
      int8

!make CUDA_ARCH="-gencode arch=compute_75,code=sm_75" \
      NVCC_FLAGS_COMMON="-O3 -Xcompiler -fPIC -Xcompiler -Wall" \
      fa2

# NOTE(review): "!" shell commands do not abort the cell when they fail, so this
# message prints even if a make target failed -- check the make output and the
# listing below.
print("\nBuild complete!")
# ../../../lib from src/kernels/cuda is mojo-gateway/lib -- presumably where the
# Makefile places the shared libraries; verify against the Makefile.
!ls -la ../../../lib/*.so

## 3. Comprehensive Benchmark

In [None]:
%%writefile fa2_benchmark.cu
#include <stdio.h>
#include <stdlib.h>
#include <chrono>
#include <cuda_runtime.h>
#include "flash_attention.h"
#include "flash_attention_int8.h"
#include "flash_attention_v2.h"

// Benchmark iteration counts: each kernel is called WARMUP times untimed, then
// RUNS times timed, and the mean time per call over RUNS is reported.
// Both can be overridden at compile time, e.g. `nvcc -DRUNS=100 ...`.
#ifndef WARMUP
#define WARMUP 50
#endif
#ifndef RUNS
#define RUNS 500
#endif
#if WARMUP < 0
#error "WARMUP must be non-negative"
#endif
#if RUNS < 1
#error "RUNS must be positive (the mean time divides by RUNS)"
#endif

// Overwrite data[0 .. size) with pseudo-random floats in [-1, 1] taken from the
// C library rand() stream (call srand() first for a reproducible sequence).
// A non-positive size leaves the buffer untouched.
void fill_random(float* data, int size) {
    for (int left = size; left > 0; --left) {
        const float unit = (float)rand() / RAND_MAX;  // uniform in [0, 1]
        *data++ = (unit - 0.5f) * 2.0f;               // shift/scale to [-1, 1]
    }
}

// Shared timing loop for the three decode benchmarks.
// Calls `step()` WARMUP times untimed, then RUNS times between two
// cudaDeviceSynchronize() calls, and returns the mean wall-clock milliseconds
// per call. Because timing is measured on the host around device-wide syncs, it
// includes any host-side work the wrapper performs per call.
// Any CUDA error reported after warmup or after the timed runs aborts the
// process, so a failed launch cannot masquerade as an impossibly fast timing.
template <typename DecodeStep>
static double time_decode_step(DecodeStep step, const char* label) {
    auto fail_on_error = [label](cudaError_t err, const char* phase) {
        if (err == cudaSuccess) err = cudaGetLastError();  // catch launch errors too
        if (err != cudaSuccess) {
            fprintf(stderr, "%s benchmark: CUDA error during %s: %s\n",
                    label, phase, cudaGetErrorString(err));
            exit(EXIT_FAILURE);
        }
    };

    for (int i = 0; i < WARMUP; i++) step();
    fail_on_error(cudaDeviceSynchronize(), "warmup");

    auto start = std::chrono::high_resolution_clock::now();
    for (int i = 0; i < RUNS; i++) step();
    cudaError_t sync_err = cudaDeviceSynchronize();
    auto end = std::chrono::high_resolution_clock::now();  // stop clock before error handling
    fail_on_error(sync_err, "timed runs");

    return std::chrono::duration<double, std::milli>(end - start).count() / RUNS;
}

// Mean ms per FP32 flash_attention_decode call. The position argument is
// cache_len - 1, i.e. the last index written when positions 0..cache_len-1 were
// filled. Requires cache_len >= 1.
double benchmark_fp32(float* Q, float* K, float* V, float* O,
                      int batch_heads, int cache_len, int head_dim) {
    const int pos = cache_len - 1;
    return time_decode_step(
        [&] { flash_attention_decode(Q, K, V, O, batch_heads, pos, head_dim); },
        "FP32");
}

// Mean ms per flash_attention_int8_decode_fp32 call at position cache_len - 1.
// Requires cache_len >= 1.
double benchmark_int8(float* Q, float* K, float* V, float* O,
                      int batch_heads, int cache_len, int head_dim) {
    const int pos = cache_len - 1;
    return time_decode_step(
        [&] { flash_attention_int8_decode_fp32(Q, K, V, O, batch_heads, pos, head_dim); },
        "INT8 TC");
}

// Mean ms per flash_attention_v2_decode call at position cache_len - 1.
// Requires cache_len >= 1.
double benchmark_fa2(float* Q, float* K, float* V, float* O,
                     int batch_heads, int cache_len, int head_dim) {
    const int pos = cache_len - 1;
    return time_decode_step(
        [&] { flash_attention_v2_decode(Q, K, V, O, batch_heads, pos, head_dim); },
        "FA2");
}

// Benchmark driver. For each cache length it:
//   1. fills the FP32, INT8 and FA2 attention paths with random K/V for
//      positions 0..cache_len-1;
//   2. times one decode step of each implementation;
//   3. prints a markdown table row.
// It then prints the best attention-only tok/s estimate at
// throughput_cache_len, both as text and as a JSON line.
// Returns 0 on success, EXIT_FAILURE on a configuration or allocation error.
int main() {
    printf("\n" "=" "=" "=" "=" "=" "=" "=" "=" "=" "=" "=" "=" "=" "=" "=" "=" "=" "=" "=" "=" "=" "=" "=" "=" "=" "=" "=" "=" "=" "=" "=" "=" "=" "=" "=" "=" "=" "=" "=" "=" "=" "=" "=" "=" "=" "=" "=" "=" "=" "=" "\n");
    printf("  FlashAttention-2 Benchmark - SmolLM-135M on T4\n");
    printf("=" "=" "=" "=" "=" "=" "=" "=" "=" "=" "=" "=" "=" "=" "=" "=" "=" "=" "=" "=" "=" "=" "=" "=" "=" "=" "=" "=" "=" "=" "=" "=" "=" "=" "=" "=" "=" "=" "=" "=" "=" "=" "=" "=" "=" "=" "=" "=" "=" "=" "\n\n");

    // SmolLM-135M configuration
    int batch_heads = 9, head_dim = 64, num_layers = 9;
    const int max_cache_len = 2048;         // capacity passed to every *_init call
    const int throughput_cache_len = 256;   // cache length used for the tok/s estimate
    int cache_lengths[] = {64, 128, 256, 512};
    const int num_cache_lengths = (int)(sizeof(cache_lengths) / sizeof(cache_lengths[0]));

    printf("Configuration: heads=%d, head_dim=%d, layers=%d\n\n", batch_heads, head_dim, num_layers);

    // Every benchmarked length must fit the initialized capacity and yield a
    // valid decode position (cache_len - 1 >= 0).
    for (int c = 0; c < num_cache_lengths; c++) {
        if (cache_lengths[c] < 1 || cache_lengths[c] > max_cache_len) {
            fprintf(stderr, "cache length %d is outside [1, %d]\n",
                    cache_lengths[c], max_cache_len);
            return EXIT_FAILURE;
        }
    }

    // Host buffers for one decode step: batch_heads * head_dim floats each.
    const int single_size = batch_heads * head_dim;
    float* Q = (float*)malloc(single_size * sizeof(float));
    float* K = (float*)malloc(single_size * sizeof(float));
    float* V = (float*)malloc(single_size * sizeof(float));
    float* O = (float*)malloc(single_size * sizeof(float));
    if (Q == NULL || K == NULL || V == NULL || O == NULL) {
        fprintf(stderr, "host allocation of %d floats failed\n", single_size);
        free(Q); free(K); free(V); free(O);
        return EXIT_FAILURE;
    }

    // Initialize all attention implementations.
    // NOTE(review): the FP32 calls take two leading sizes (1, batch_heads) while the
    // INT8/FA2 calls take batch_heads alone -- presumably (batch, heads) vs. the
    // fused batch*heads count; confirm against the headers.
    flash_attention_init(1, batch_heads, max_cache_len, head_dim);
    flash_attention_init_kv_cache(1, batch_heads, max_cache_len, head_dim);
    flash_attention_int8_init(batch_heads, max_cache_len, head_dim);
    flash_attention_v2_init(batch_heads, max_cache_len, head_dim);

    srand(42);

    printf("| Cache Len |  FP32 (ms) | INT8 TC (ms) |   FA2 (ms) | Best Speedup |\n");
    printf("|-----------|------------|--------------|------------|--------------|\n");

    double best_tok = 0;
    const char* best_impl = "";

    for (int c = 0; c < num_cache_lengths; c++) {
        int cache_len = cache_lengths[c];

        // Q must hold defined values before the fill loop: the INT8/FA2 decode
        // calls below read it. (Previously Q was filled only after the loop, so
        // the first cache length passed uninitialized malloc memory.)
        fill_random(Q, single_size);

        // Fill cache. The FP32 path gets an explicit cache update; the INT8/FA2
        // paths are driven through their decode calls with the new K/V at each
        // position (presumably appending to their own caches -- verify).
        for (int pos = 0; pos < cache_len; pos++) {
            fill_random(K, single_size);
            fill_random(V, single_size);
            flash_attention_update_kv_cache(K, V, batch_heads, pos, 1, head_dim);
            flash_attention_int8_decode_fp32(Q, K, V, O, batch_heads, pos, head_dim);
            flash_attention_v2_decode(Q, K, V, O, batch_heads, pos, head_dim);
        }

        double fp32_ms = benchmark_fp32(Q, K, V, O, batch_heads, cache_len, head_dim);
        double int8_ms = benchmark_int8(Q, K, V, O, batch_heads, cache_len, head_dim);
        double fa2_ms = benchmark_fa2(Q, K, V, O, batch_heads, cache_len, head_dim);

        double best_ms = fmin(fmin(fp32_ms, int8_ms), fa2_ms);
        double speedup = fp32_ms / best_ms;

        printf("| %9d | %10.4f | %12.4f | %10.4f | %10.2fx  |\n",
               cache_len, fp32_ms, int8_ms, fa2_ms, speedup);

        // Attention-only throughput: one token needs one decode per layer.
        if (cache_len == throughput_cache_len) {
            double fp32_tok = 1000.0 / (fp32_ms * num_layers);
            double int8_tok = 1000.0 / (int8_ms * num_layers);
            double fa2_tok = 1000.0 / (fa2_ms * num_layers);

            if (fp32_tok > best_tok) { best_tok = fp32_tok; best_impl = "FP32"; }
            if (int8_tok > best_tok) { best_tok = int8_tok; best_impl = "INT8 TC"; }
            if (fa2_tok > best_tok) { best_tok = fa2_tok; best_impl = "FA2"; }
        }
    }

    printf("\n");
    printf("Throughput Estimate (cache_len=%d, attention only):\n", throughput_cache_len);
    printf("  Best: %.1f tok/s (%s)\n", best_tok, best_impl);
    printf("  Target: 630 tok/s\n");
    printf("  Ollama: 423 tok/s\n");
    printf("\n");

    // JSON output
    printf("JSON: {\"best_tok\":%.1f,\"best_impl\":\"%s\",\"target\":630,\"ollama\":423}\n",
           best_tok, best_impl);

    free(Q); free(K); free(V); free(O);
    flash_attention_cleanup();
    flash_attention_int8_cleanup();
    flash_attention_v2_cleanup();

    return 0;
}

In [None]:
# Compile and run benchmark
# Builds the harness written by the %%writefile cell for sm_75 and links it with
# the three kernel object files, then runs it.
# NOTE(review): assumes the make targets in the build cell left flash_attention.o,
# flash_attention_int8.o and flash_attention_v2.o (and the matching headers) in
# the current directory, src/kernels/cuda -- verify against the Makefile.
!nvcc -O3 -gencode arch=compute_75,code=sm_75 --expt-relaxed-constexpr \
    -o fa2_benchmark fa2_benchmark.cu \
    flash_attention.o flash_attention_int8.o flash_attention_v2.o -lcudart
!./fa2_benchmark

## 4. Summary

In [None]:
import json
from datetime import datetime

# Run metadata: GPU identity (GPU_NAME / COMPUTE_CAP from the setup cell) plus
# the fixed throughput goals, printed as indented JSON.
summary = dict(
    timestamp=datetime.now().isoformat(),
    gpu=GPU_NAME,
    compute_capability=COMPUTE_CAP,
    benchmark="FlashAttention-2 Comparison",
    implementations=["FP32", "INT8 Tensor Core", "FlashAttention-2"],
    target_throughput=630,
    ollama_baseline=423,
)
rendered = json.dumps(summary, indent=2)
print(rendered)

In [None]:
# Remove the benchmark binary and the source written by the %%writefile cell.
# NOTE(review): relative paths -- assumes the working directory is still
# src/kernels/cuda from the build cell's %cd.
!rm -f fa2_benchmark fa2_benchmark.cu
print("Cleanup complete!")