# EdgeLLM INT8 Tensor Core Flash Attention Benchmark

This notebook benchmarks our **INT8 Tensor Core** Flash Attention implementation on a Tesla T4 GPU (compute capability 7.5).

**Key Features:**
- INT8 quantized Q, K, V matrices
- WMMA (Warp Matrix Multiply Accumulate) for Tensor Core acceleration
- ~8x theoretical speedup over FP32

**Target:** 630+ tok/s (vs Ollama 423 tok/s baseline)

## 1. Environment Setup

In [None]:
import subprocess
import os

# Check GPU
result = subprocess.run(['nvidia-smi', '--query-gpu=name,compute_cap,memory.total',
                        '--format=csv,noheader'], capture_output=True, text=True)
gpu_info = result.stdout.strip()
print(f"GPU: {gpu_info}")

# Parse compute capability
parts = gpu_info.split(', ')
GPU_NAME = parts[0] if len(parts) > 0 else 'Unknown'
COMPUTE_CAP = parts[1] if len(parts) > 1 else '0.0'
print(f"Compute Capability: {COMPUTE_CAP}")

# Verify Tensor Core support (sm_75+)
major, minor = COMPUTE_CAP.split('.')
if int(major) >= 7 and int(minor) >= 5:
    print("INT8 Tensor Cores: SUPPORTED")
else:
    print("WARNING: INT8 Tensor Cores require sm_75+ (Turing or newer)")

In [None]:
# Clone repository
# NOTE(review): `--depth 1` clones whatever the default branch currently is, so results
# are not pinned to a specific commit; consider checking out a fixed tag/commit.
# All paths here are relative to the current directory, and later cells `%cd` further
# down (into src/kernels/cuda). Re-running this cell out of order would therefore
# delete/clone relative to that deeper directory — restart the kernel and run from the top.
!rm -rf ollama-api-gateway
!git clone --depth 1 https://github.com/umerkhan95/ollama-api-gateway.git
%cd ollama-api-gateway/mojo-gateway

## 2. Build INT8 Tensor Core Kernels

In [None]:
# Check CUDA version: prints the version of the nvcc compiler used by the builds below.
!nvcc --version

In [None]:
# Build INT8 Tensor Core Flash Attention for T4
# Assumes the cwd is mojo-gateway/ (set by the clone cell); the %cd below is relative,
# so this cell is not safe to re-run out of order.
%cd src/kernels/cuda
!make clean

# Build with T4-specific optimizations (sm_75)
# CUDA_ARCH / NVCC_FLAGS_COMMON are given on the make command line, which overrides
# same-named variables assigned in the Makefile (presumably its defaults).
# `code=sm_75` embeds sm_75 machine code only (no PTX), so the result will not load on
# GPUs of a different major architecture (e.g. sm_80+).
!make CUDA_ARCH="-gencode arch=compute_75,code=sm_75" \
      NVCC_FLAGS_COMMON="-O3 -Xcompiler -fPIC -Xcompiler -Wall --expt-relaxed-constexpr" \
      int8

# Also build FP32 for comparison (the `flash` target, same arch/flags)
!make CUDA_ARCH="-gencode arch=compute_75,code=sm_75" \
      NVCC_FLAGS_COMMON="-O3 -Xcompiler -fPIC -Xcompiler -Wall --expt-relaxed-constexpr" \
      flash

# NOTE(review): `!` commands do not raise on a non-zero exit status, so this message is
# printed even if a make invocation above failed — check the make output.
print("\nBuild complete!")
# List shared libraries in ../../../lib (mojo-gateway/lib relative to this directory);
# if no .so file matches, only the echo message is printed.
!ls -la ../../../lib/*.so 2>/dev/null || echo "Checking for libraries..."

## 3. Run INT8 Tests

In [None]:
# Build and run INT8 Flash Attention tests via the Makefile's `test-int8` target, using
# the same sm_75 arch and nvcc flags as the build cell (cwd: src/kernels/cuda).
# NOTE(review): a failing test does not stop the notebook (`!` does not raise) — read
# the output for failures.
!make CUDA_ARCH="-gencode arch=compute_75,code=sm_75" \
      NVCC_FLAGS_COMMON="-O3 -Xcompiler -fPIC -Xcompiler -Wall --expt-relaxed-constexpr" \
      test-int8

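## 4. Custom INT8 vs FP32 Benchmark

This benchmark measures the mean latency of one single-token decode attention call for both the FP32 and INT8 kernels, using 9 heads, head dim 64 and 256 cached tokens. The reported tok/s values are extrapolated from that latency across 9 layers. They assume attention accounts for 35% of per-token decode time, so they are estimates, not end-to-end measurements.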

In [None]:
%%writefile custom_int8_bench.cu
#include <stdio.h>
#include <stdlib.h>
#include <chrono>
#include <cuda_runtime.h>
#include "flash_attention_int8.h"
#include "flash_attention.h"

#define WARMUP 20
#define RUNS 200

/* Overwrite data[0..size) with pseudo-random floats in [-1, 1] drawn from rand().
 * A non-positive size writes nothing. Seed with srand() for reproducible data. */
void fill_random(float* data, int size) {
    float* out = data;
    for (int left = size; left > 0; --left) {
        const float unit = (float)rand() / RAND_MAX;  /* uniform in [0, 1] */
        *out++ = (unit - 0.5f) * 2.0f;                /* rescale to [-1, 1] */
    }
}

// Synchronizes the device and aborts if any CUDA work issued so far has failed.
// Without this, a failed launch (e.g. no sm_75 kernel image for the current GPU)
// returns almost immediately and would be reported as a bogus near-zero latency.
static void sync_or_die(const char* stage) {
    cudaError_t err = cudaDeviceSynchronize();
    if (err == cudaSuccess) {
        err = cudaGetLastError();
    }
    if (err != cudaSuccess) {
        fprintf(stderr, "CUDA error during %s: %s\n", stage, cudaGetErrorString(err));
        exit(EXIT_FAILURE);
    }
}

// Calls `step` WARMUP times untimed, then RUNS times timed, and returns the mean
// host-side elapsed milliseconds per call. The device is synchronized (and checked
// for errors) after the warmup and after the timed loop.
template <typename Step>
static double mean_ms_per_call(Step step, const char* label) {
    for (int i = 0; i < WARMUP; i++) step();
    sync_or_die(label);

    // steady_clock is monotonic, unlike high_resolution_clock on some platforms.
    const auto start = std::chrono::steady_clock::now();
    for (int i = 0; i < RUNS; i++) step();
    sync_or_die(label);
    const auto end = std::chrono::steady_clock::now();

    return std::chrono::duration<double, std::milli>(end - start).count() / RUNS;
}

// Benchmarks single-token decode attention (FP32 vs INT8 Tensor Core) for a
// SmolLM-135M-shaped layer, then extrapolates per-token throughput.
int main() {
    printf("\n=== INT8 Tensor Core vs FP32 Flash Attention ===");
    printf("\n\nSmolLM-135M Configuration:\n");

    const int batch = 1, num_heads = 9, head_dim = 64, num_layers = 9;
    const int max_seq_len = 2048;  // capacity passed to the init calls
    const int cache_len = 256;     // tokens pre-filled into the KV caches
    const int batch_heads = batch * num_heads;
    // Assumed share of per-token decode time spent in attention. Only attention is
    // timed here, so the tok/s figures below are extrapolations, not measurements.
    const double attention_share = 0.35;

    printf("  Heads: %d, Head dim: %d, Layers: %d\n\n", num_heads, head_dim, num_layers);

    flash_attention_init(batch, num_heads, max_seq_len, head_dim);
    flash_attention_init_kv_cache(batch, num_heads, max_seq_len, head_dim);
    flash_attention_int8_init(batch_heads, max_seq_len, head_dim);

    // One token's worth of Q/K/V/O across all heads (host buffers).
    const int single_size = batch_heads * head_dim;
    float* Q = (float*)malloc(single_size * sizeof(float));
    float* K = (float*)malloc(single_size * sizeof(float));
    float* V = (float*)malloc(single_size * sizeof(float));
    float* O = (float*)malloc(single_size * sizeof(float));
    if (!Q || !K || !V || !O) {
        fprintf(stderr, "Host allocation failed\n");
        free(Q); free(K); free(V); free(O);
        flash_attention_cleanup();
        flash_attention_int8_cleanup();
        return 1;
    }

    // Pre-fill both KV caches with cache_len random tokens. Q is passed to the INT8
    // decode call here too, so fill it rather than handing over uninitialized memory.
    srand(42);
    for (int pos = 0; pos < cache_len; pos++) {
        fill_random(Q, single_size);
        fill_random(K, single_size);
        fill_random(V, single_size);
        flash_attention_update_kv_cache(K, V, batch_heads, pos, 1, head_dim);
        flash_attention_int8_decode_fp32(Q, K, V, O, batch_heads, pos, head_dim);
    }
    sync_or_die("KV-cache prefill");
    fill_random(Q, single_size);

    // Both kernels decode at the last cached position.
    const int decode_pos = cache_len - 1;

    // FP32 benchmark
    const double fp32_ms = mean_ms_per_call([&] {
        flash_attention_decode(Q, K, V, O, batch_heads, decode_pos, head_dim);
    }, "FP32 decode");

    // INT8 benchmark
    const double int8_ms = mean_ms_per_call([&] {
        flash_attention_int8_decode_fp32(Q, K, V, O, batch_heads, decode_pos, head_dim);
    }, "INT8 decode");

    const double speedup = fp32_ms / int8_ms;
    printf("Results (cache_len=%d):\n", cache_len);
    printf("  FP32 per-layer: %.4f ms\n", fp32_ms);
    printf("  INT8 per-layer: %.4f ms\n", int8_ms);
    printf("  Speedup: %.2fx\n\n", speedup);

    // Per-token time ~= (attention ms per layer * layers) / attention_share.
    const double fp32_tok = 1000.0 / (fp32_ms * num_layers / attention_share);
    const double int8_tok = 1000.0 / (int8_ms * num_layers / attention_share);
    printf("Estimated throughput:\n");
    printf("  FP32: %.1f tok/s\n", fp32_tok);
    printf("  INT8: %.1f tok/s\n", int8_tok);
    printf("  Target: 630 tok/s, Ollama: 423 tok/s\n\n");

    printf("JSON: {\"fp32_ms\":%.4f,\"int8_ms\":%.4f,\"speedup\":%.2f,\"fp32_tok\":%.1f,\"int8_tok\":%.1f}\n",
           fp32_ms, int8_ms, speedup, fp32_tok, int8_tok);

    free(Q); free(K); free(V); free(O);
    flash_attention_cleanup();
    flash_attention_int8_cleanup();
    return 0;
}

In [None]:
# Compile the benchmark against the object files produced by `make` above, then run it.
# `&&` ensures a failed compile does not silently run a stale ./custom_int8_bench
# left over from an earlier session.
!nvcc -O3 -gencode arch=compute_75,code=sm_75 --expt-relaxed-constexpr \
    -o custom_int8_bench custom_int8_bench.cu \
    flash_attention_int8.o flash_attention.o -lcudart \
    && ./custom_int8_bench

## 5. Summary

In [None]:
import json
from datetime import datetime

# Run metadata plus the throughput targets quoted in the introduction.
# GPU_NAME / COMPUTE_CAP are set by the environment-setup cell at the top.
summary = dict(
    timestamp=datetime.now().isoformat(),
    gpu=GPU_NAME,
    compute_capability=COMPUTE_CAP,
    benchmark="INT8 Tensor Core Flash Attention",
    target_throughput=630,
    ollama_baseline=423,
)
print(json.dumps(summary, indent=2))

In [None]:
# Remove the benchmark source written by %%writefile and the binary compiled from it
# (relative to the current directory, src/kernels/cuda after the build cell).
!rm -f custom_int8_bench custom_int8_bench.cu
print("Cleanup complete!")