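# INT8 `__dp4a` Flash Attention Benchmark

Benchmarks the INT8 single-token (decode) attention kernel. The kernel uses the `__dp4a` intrinsic, which computes a 4-element INT8 dot product in one instruction (up to 4x the multiply-accumulate throughput of scalar code). Note that `__dp4a` executes on the regular CUDA cores, not on Tensor Cores.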

**Optimizations:**
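- `__dp4a`: 4-element INT8 dot product in a single instruction
- Vectorized INT8 loads (four INT8 values packed into one 32-bit word)
- Separate CUDA streams for copies and compute. The decode path currently synchronizes the copy stream before launching the kernel, so copies and compute do not yet overlap.
- Warp-shuffle (`__shfl_down_sync`) reductions for the softmax max and sum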

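**Expected Results:**
- Up to 4x higher arithmetic throughput for the Q@K^T dot products. Single-token decode is typically memory-bound, so end-to-end speedups are smaller.
- Low run-to-run jitter, reported below as std/median of the per-call latency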

In [None]:
# Check GPU: name, total memory and compute capability
# (the compile cell below builds for sm_75)
!nvidia-smi --query-gpu=name,memory.total,compute_cap --format=csv

In [None]:
# Create the directory the %%writefile cells below write into
# (-p: no error if it already exists)
!mkdir -p cuda_kernels

In [None]:
%%writefile cuda_kernels/flash_attention_int8.h
/**
 * INT8 __dp4a Flash Attention - public C interface.
 *
 * Every function has C linkage. Functions returning int report 0 on success
 * and -1 on failure. All state (device buffers, streams, K/V cache) lives in
 * a single module-wide instance created by flash_attention_int8_init.
 */
#ifndef FLASH_ATTENTION_INT8_H
#define FLASH_ATTENTION_INT8_H

#include <stdint.h>

#ifdef __cplusplus
extern "C" {
#endif

/* --- Lifetime ------------------------------------------------------------ */

/* Allocate device buffers and streams for up to max_batch_heads heads and
 * max_cache_len cached positions. head_dim must be a multiple of 4 (packed
 * __dp4a loads). Calling it again releases the previous state first. */
int flash_attention_int8_init(int max_batch_heads, int max_cache_len,
                              int head_dim);

/* Release every device buffer and stream owned by the module. */
void flash_attention_int8_cleanup(void);

/* Currently a no-op. */
void flash_attention_int8_reset(void);

/* --- Quantization -------------------------------------------------------- */

/* Quantize `size` host floats to INT8 on the GPU with one symmetric scale for
 * the whole tensor; *scale receives the dequantization factor max|x| / 127. */
int quantize_tensor_int8(const float* input, int8_t* output, float* scale,
                         int size);

/* --- K/V cache and decode ------------------------------------------------ */

/* Copy one INT8 K row and one V row per head (batch_heads rows of head_dim
 * values each, packed back to back on the host) into position cache_pos. */
int flash_attention_int8_update_cache(
    const int8_t* K_int8, const int8_t* V_int8,
    float scale_k, float scale_v,
    int batch_heads, int cache_pos);

/* Upload Q, append K/V at cache_pos, attend over positions [0, cache_pos]
 * and write batch_heads * head_dim floats into the host buffer O.
 * Returns only after O has been written. */
int flash_attention_int8_decode(
    const int8_t* Q_int8, const int8_t* K_int8, const int8_t* V_int8, float* O,
    float scale_q, float scale_k, float scale_v,
    int batch_heads, int cache_pos, int head_dim);

/* Same as flash_attention_int8_decode, but takes FP32 host inputs and
 * quantizes Q, K_new and V_new to INT8 first. */
int flash_attention_int8_decode_fp32(
    const float* Q, const float* K_new, const float* V_new, float* O,
    int batch_heads, int cache_pos, int head_dim);

/* Launch the kernel on data already resident on the GPU, over the first
 * cache_len positions. No host copies and no wait; pair with
 * flash_attention_int8_sync. */
int flash_attention_int8_decode_gpu(
    int batch_heads, int cache_len,
    float scale_q, float scale_k, float scale_v);

/* Wait for all work queued on the module's compute stream. */
void flash_attention_int8_sync(void);

/* Report initialization state and configured limits; NULL outputs are skipped. */
void flash_attention_int8_info(int* initialized, int* max_cache, int* max_bh,
                               int* h_dim);

#ifdef __cplusplus
}
#endif

#endif /* FLASH_ATTENTION_INT8_H */

In [None]:
%%writefile cuda_kernels/flash_attention_int8.cu
/**
 * INT8 Tensor Core Flash Attention with __dp4a optimization
 */
#include <cuda_runtime.h>
#include <mma.h>
#include <cuda_fp16.h>
#include <stdio.h>
#include <math.h>
#include <float.h>

using namespace nvcuda;

// Tile shape for nvcuda::wmma fragments (not referenced by the kernels in this file).
#define WMMA_M 16
#define WMMA_N 16
#define WMMA_K 16
// Threads per block for the decode kernel (4 warps of 32).
#define TC_THREADS 128
// Largest INT8 code magnitude: q = rint(clamp(x * 127 / max|x|, -127, 127)).
#define QUANT_SCALE 127.0f

// Evaluate a CUDA runtime call. On failure, print file, line and the error
// string to stderr, then `return -1` from the *enclosing* function, so the
// macro is only usable in functions whose return type accepts -1.
#define CUDA_CHECK(call)                                                     \
    do {                                                                     \
        const cudaError_t cuda_check_status_ = (call);                       \
        if (cuda_check_status_ != cudaSuccess) {                             \
            fprintf(stderr, "CUDA error at %s:%d: %s\n", __FILE__, __LINE__, \
                    cudaGetErrorString(cuda_check_status_));                 \
            return -1;                                                       \
        }                                                                    \
    } while (0)

// Symmetric per-tensor FP32 -> INT8 quantization kernel.
//
// Launch: any grid size; blockDim.x must be a power of two <= 256, because
// of the tree reduction and the fixed 256-entry shared buffer. The host
// wrapper uses 256.
//
// Every block scans the *entire* input to find max|x|. All blocks therefore
// quantize with the same scale, which a single kernel cannot otherwise
// guarantee without a grid-wide barrier. The cost is gridDim.x extra reads of
// the input, which is acceptable for the small per-token tensors quantized in
// this file.
//
// Writes output[i] = rint(clamp(input[i] * 127 / max|x|, -127, 127)). Block 0
// also writes *scale = max|x| / 127, the dequantization factor; an all-zero
// input yields scale 0 and all-zero codes.
__global__ void quantize_fp32_to_int8_kernel(
    const float* __restrict__ input, int8_t* __restrict__ output,
    float* __restrict__ scale, int size
) {
    __shared__ float shared_max[256];
    const int tid = threadIdx.x;

    // Phase 1: block-strided scan over the whole tensor, so the result is
    // independent of how many blocks were launched.
    float local_max = 0.0f;
    for (int i = tid; i < size; i += blockDim.x) {
        local_max = fmaxf(local_max, fabsf(input[i]));
    }
    shared_max[tid] = local_max;
    __syncthreads();

    // Tree reduction in shared memory (blockDim.x assumed a power of two).
    for (int s = blockDim.x / 2; s > 0; s >>= 1) {
        if (tid < s) shared_max[tid] = fmaxf(shared_max[tid], shared_max[tid + s]);
        __syncthreads();
    }

    const float max_val = shared_max[0];
    const float quant_scale = (max_val > 0.0f) ? (QUANT_SCALE / max_val) : 1.0f;

    if (tid == 0 && blockIdx.x == 0) *scale = max_val / QUANT_SCALE;

    // Phase 2: each element is quantized exactly once, grid-strided.
    const int stride = gridDim.x * blockDim.x;
    for (int i = blockIdx.x * blockDim.x + tid; i < size; i += stride) {
        float val = input[i] * quant_scale;
        val = fminf(fmaxf(val, -127.0f), 127.0f);
        output[i] = (int8_t)rintf(val);
    }
}

// ---------------------------------------------------------------------------
// __dp4a INT8 attention decode kernel: one query token per (batch, head).
//
// Launch: grid = (batch_heads), block = (TC_THREADS). The block-stride loops
// and warp bookkeeping assume blockDim.x == TC_THREADS.
// Dynamic shared memory: round_up(cache_len, 16) floats for the scores,
// followed by round_up(head_dim / 4, 4) ints for the packed query.
// Preconditions: head_dim % 4 == 0 and every Q/K row 4-byte aligned, because
// rows are reinterpreted as arrays of packed int8x4 words for __dp4a.
//
// Cache layout: K/V are [batch_heads][cache_stride][head_dim] INT8, where
// cache_stride is the number of rows allocated per head. Only the first
// cache_len rows of each head are read.
//
// Dequantization:
//   score_j = (q . k_j) * scale_q * sk_j * attn_scale
//   out     = sum_j softmax_j * v_j * sv_j
// sk_j / sv_j come from k_scales / v_scales (one entry per cache position)
// when those pointers are non-null, and from scale_k / scale_v otherwise.
// ---------------------------------------------------------------------------
__global__ void flash_attention_int8_decode_dp4a_kernel(
    const int8_t* __restrict__ Q_int8,
    const int8_t* __restrict__ K_cache_int8,
    const int8_t* __restrict__ V_cache_int8,
    float* __restrict__ O,
    const float scale_q, const float scale_k, const float scale_v,
    const int cache_len, const int head_dim, const float attn_scale,
    const int cache_stride,
    const float* __restrict__ k_scales,
    const float* __restrict__ v_scales
) {
    constexpr int kNumWarps = TC_THREADS / 32;

    const int batch_head_idx = blockIdx.x;
    const int tid = threadIdx.x;
    const int warp_id = tid / 32;
    const int lane_id = tid % 32;

    extern __shared__ char shared_mem[];
    float* s_scores = (float*)shared_mem;
    int* s_Q_packed = (int*)(s_scores + ((cache_len + 15) / 16) * 16);
    // One partial per warp; sized from TC_THREADS so the block size can change.
    __shared__ float s_block_max[kNumWarps];
    __shared__ float s_block_sum[kNumWarps];

    // Heads are cache_stride rows apart, matching how the cache is filled.
    // Offsets are 64-bit so large caches cannot overflow int arithmetic.
    const size_t head_offset = (size_t)batch_head_idx * (size_t)cache_stride * (size_t)head_dim;
    const int8_t* Q_ptr = Q_int8 + (size_t)batch_head_idx * head_dim;
    const int8_t* K_ptr = K_cache_int8 + head_offset;
    const int8_t* V_ptr = V_cache_int8 + head_offset;
    float* O_ptr = O + (size_t)batch_head_idx * head_dim;

    // Stage the packed query in shared memory; every thread reuses it.
    const int packed_dim = head_dim / 4;
    const int* Q_packed = (const int*)Q_ptr;
    for (int d = tid; d < packed_dim; d += TC_THREADS) {
        s_Q_packed[d] = Q_packed[d];
    }
    __syncthreads();

    // Phase 1: Q @ K^T using __dp4a (4 INT8 multiply-accumulates per call).
    float local_max = -FLT_MAX;
    for (int k_idx = tid; k_idx < cache_len; k_idx += TC_THREADS) {
        const int* K_row_packed = (const int*)(K_ptr + (size_t)k_idx * head_dim);
        int32_t dot = 0;

        #pragma unroll 4
        for (int d = 0; d < packed_dim; d++) {
            dot = __dp4a(s_Q_packed[d], K_row_packed[d], dot);
        }

        const float sk = k_scales ? k_scales[k_idx] : scale_k;
        const float score = (float)dot * scale_q * sk * attn_scale;
        s_scores[k_idx] = score;
        local_max = fmaxf(local_max, score);
    }
    __syncthreads();

    // Block max: warp shuffle, then one partial per warp combined by thread 0.
    #pragma unroll
    for (int offset = 16; offset > 0; offset /= 2) {
        local_max = fmaxf(local_max, __shfl_down_sync(0xffffffff, local_max, offset));
    }
    if (lane_id == 0) s_block_max[warp_id] = local_max;
    __syncthreads();

    if (tid == 0) {
        float block_max = s_block_max[0];
        for (int i = 1; i < kNumWarps; i++) block_max = fmaxf(block_max, s_block_max[i]);
        s_block_max[0] = block_max;
    }
    __syncthreads();
    const float global_max = s_block_max[0];

    // Softmax numerators exp(s - max) (max-subtracted for stability) and their sum.
    float local_sum = 0.0f;
    for (int k_idx = tid; k_idx < cache_len; k_idx += TC_THREADS) {
        const float exp_score = expf(s_scores[k_idx] - global_max);
        s_scores[k_idx] = exp_score;
        local_sum += exp_score;
    }
    __syncthreads();

    #pragma unroll
    for (int offset = 16; offset > 0; offset /= 2) {
        local_sum += __shfl_down_sync(0xffffffff, local_sum, offset);
    }
    if (lane_id == 0) s_block_sum[warp_id] = local_sum;
    __syncthreads();

    if (tid == 0) {
        float total_sum = 0.0f;
        for (int i = 0; i < kNumWarps; i++) total_sum += s_block_sum[i];
        s_block_sum[0] = total_sum;
    }
    __syncthreads();

    // Normalize, folding a per-position V scale into the weight when present.
    const float inv_sum = 1.0f / s_block_sum[0];
    for (int k_idx = tid; k_idx < cache_len; k_idx += TC_THREADS) {
        const float sv = v_scales ? v_scales[k_idx] : 1.0f;
        s_scores[k_idx] *= inv_sum * sv;
    }
    __syncthreads();
    const float out_scale = v_scales ? 1.0f : scale_v;

    // Phase 2: softmax @ V. Each thread owns output dims d, so consecutive
    // threads read consecutive bytes of each V row.
    for (int d = tid; d < head_dim; d += TC_THREADS) {
        float acc = 0.0f;
        int k_idx = 0;

        for (; k_idx + 3 < cache_len; k_idx += 4) {
            const int8_t* v_col = V_ptr + (size_t)k_idx * head_dim + d;
            acc += s_scores[k_idx]     * (float)v_col[0];
            acc += s_scores[k_idx + 1] * (float)v_col[head_dim];
            acc += s_scores[k_idx + 2] * (float)v_col[2 * head_dim];
            acc += s_scores[k_idx + 3] * (float)v_col[3 * head_dim];
        }

        for (; k_idx < cache_len; k_idx++) {
            acc += s_scores[k_idx] * (float)V_ptr[(size_t)k_idx * head_dim + d];
        }

        O_ptr[d] = acc * out_scale;
    }
}

// C Interface
extern "C" {

// Module state: a single global instance created by flash_attention_int8_init.
static int8_t* d_Q_int8 = nullptr;            // [max_batch_heads][head_dim]
static cudaStream_t int8_compute_stream = nullptr;
static cudaStream_t int8_copy_stream = nullptr;
static int8_t* d_K_cache_int8 = nullptr;      // [max_batch_heads][max_cache_len][head_dim]
static int8_t* d_V_cache_int8 = nullptr;      // same layout as d_K_cache_int8
static float* d_O_float = nullptr;            // [max_batch_heads][head_dim]
// Per-position dequantization scales ([max_cache_len] each). They are shared
// by all heads because quantize_tensor_int8 yields one scale per tensor.
static float* d_K_scales = nullptr;
static float* d_V_scales = nullptr;
static int int8_initialized = 0;
static int int8_max_cache_len = 0;
static int int8_max_batch_heads = 0;
static int int8_head_dim = 0;

void flash_attention_int8_cleanup(void);

// Allocate and zero every device buffer and create both streams. Returns -1
// at the first failing CUDA call; the caller releases any partial state.
static int int8_allocate_state(int max_batch_heads, int max_cache_len, int head_dim) {
    const size_t q_size = (size_t)max_batch_heads * head_dim * sizeof(int8_t);
    const size_t cache_size = (size_t)max_batch_heads * max_cache_len * head_dim * sizeof(int8_t);
    const size_t output_size = (size_t)max_batch_heads * head_dim * sizeof(float);
    const size_t scales_size = (size_t)max_cache_len * sizeof(float);

    CUDA_CHECK(cudaMalloc(&d_Q_int8, q_size));
    CUDA_CHECK(cudaMalloc(&d_K_cache_int8, cache_size));
    CUDA_CHECK(cudaMalloc(&d_V_cache_int8, cache_size));
    CUDA_CHECK(cudaMalloc(&d_O_float, output_size));
    CUDA_CHECK(cudaMalloc(&d_K_scales, scales_size));
    CUDA_CHECK(cudaMalloc(&d_V_scales, scales_size));

    // Zero the cache so never-written positions are deterministic, not garbage.
    CUDA_CHECK(cudaMemset(d_K_cache_int8, 0, cache_size));
    CUDA_CHECK(cudaMemset(d_V_cache_int8, 0, cache_size));
    CUDA_CHECK(cudaMemset(d_K_scales, 0, scales_size));
    CUDA_CHECK(cudaMemset(d_V_scales, 0, scales_size));

    CUDA_CHECK(cudaStreamCreate(&int8_compute_stream));
    CUDA_CHECK(cudaStreamCreate(&int8_copy_stream));
    return 0;
}

int flash_attention_int8_init(int max_batch_heads, int max_cache_len, int head_dim) {
    if (int8_initialized) flash_attention_int8_cleanup();

    if (max_batch_heads <= 0 || max_cache_len <= 0 || head_dim <= 0) {
        fprintf(stderr, "flash_attention_int8_init: dimensions must be positive\n");
        return -1;
    }
    if (head_dim % 4 != 0) {
        fprintf(stderr, "head_dim must be multiple of 4 for __dp4a\n");
        return -1;
    }

    if (int8_allocate_state(max_batch_heads, max_cache_len, head_dim) != 0) {
        flash_attention_int8_cleanup();  // release whatever was allocated
        return -1;
    }

    int8_max_batch_heads = max_batch_heads;
    int8_max_cache_len = max_cache_len;
    int8_head_dim = head_dim;
    int8_initialized = 1;

    printf("INT8 __dp4a Flash Attention initialized\n");
    int device = 0;
    cudaDeviceProp prop;
    if (cudaGetDevice(&device) == cudaSuccess &&
        cudaGetDeviceProperties(&prop, device) == cudaSuccess) {
        printf("  GPU: %s (SM %d.%d)\n", prop.name, prop.major, prop.minor);
    }
    printf("  Config: batch_heads=%d, max_cache=%d, head_dim=%d\n",
           max_batch_heads, max_cache_len, head_dim);

    return 0;
}

// Free all device buffers and streams; safe to call on partially initialized state.
void flash_attention_int8_cleanup(void) {
    if (d_Q_int8) { cudaFree(d_Q_int8); d_Q_int8 = nullptr; }
    if (d_K_cache_int8) { cudaFree(d_K_cache_int8); d_K_cache_int8 = nullptr; }
    if (d_V_cache_int8) { cudaFree(d_V_cache_int8); d_V_cache_int8 = nullptr; }
    if (d_O_float) { cudaFree(d_O_float); d_O_float = nullptr; }
    if (d_K_scales) { cudaFree(d_K_scales); d_K_scales = nullptr; }
    if (d_V_scales) { cudaFree(d_V_scales); d_V_scales = nullptr; }
    if (int8_compute_stream) { cudaStreamDestroy(int8_compute_stream); int8_compute_stream = nullptr; }
    if (int8_copy_stream) { cudaStreamDestroy(int8_copy_stream); int8_copy_stream = nullptr; }
    int8_initialized = 0;
}

// Currently a no-op; cached K/V rows are left untouched.
void flash_attention_int8_reset(void) {}

// Upload `size` floats, quantize them and download codes and scale. Device
// buffers are returned through d_* so the caller frees them on every path.
static int int8_quantize_on_device(const float* input, int8_t* output, float* scale, int size,
                                   float** d_input, int8_t** d_output, float** d_scale) {
    CUDA_CHECK(cudaMalloc(d_input, (size_t)size * sizeof(float)));
    CUDA_CHECK(cudaMalloc(d_output, (size_t)size * sizeof(int8_t)));
    CUDA_CHECK(cudaMalloc(d_scale, sizeof(float)));

    CUDA_CHECK(cudaMemcpy(*d_input, input, (size_t)size * sizeof(float), cudaMemcpyHostToDevice));

    dim3 block(256);
    dim3 grid((size + 255) / 256);
    quantize_fp32_to_int8_kernel<<<grid, block>>>(*d_input, *d_output, *d_scale, size);
    CUDA_CHECK(cudaGetLastError());

    CUDA_CHECK(cudaMemcpy(output, *d_output, (size_t)size * sizeof(int8_t), cudaMemcpyDeviceToHost));
    CUDA_CHECK(cudaMemcpy(scale, *d_scale, sizeof(float), cudaMemcpyDeviceToHost));
    return 0;
}

int quantize_tensor_int8(const float* input, int8_t* output, float* scale, int size) {
    if (!input || !output || !scale || size <= 0) return -1;

    float* d_input = nullptr;
    int8_t* d_output = nullptr;
    float* d_scale = nullptr;
    const int rc = int8_quantize_on_device(input, output, scale, size,
                                           &d_input, &d_output, &d_scale);

    // Freed on success and failure alike (cudaFree(nullptr) is a no-op).
    cudaFree(d_input);
    cudaFree(d_output);
    cudaFree(d_scale);
    return rc;
}

// Queue copies (on int8_copy_stream) of one K and one V row per head into
// position cache_pos, plus that position's dequantization scales. Does not wait.
int flash_attention_int8_update_cache(
    const int8_t* K_int8, const int8_t* V_int8,
    float scale_k, float scale_v,
    int batch_heads, int cache_pos
) {
    if (!int8_initialized) return -1;
    if (!K_int8 || !V_int8) return -1;
    if (batch_heads <= 0 || batch_heads > int8_max_batch_heads) return -1;
    if (cache_pos < 0 || cache_pos >= int8_max_cache_len) return -1;

    const size_t row_bytes = (size_t)int8_head_dim * sizeof(int8_t);
    const size_t head_pitch = (size_t)int8_max_cache_len * row_bytes;
    const size_t first_row = (size_t)cache_pos * int8_head_dim;

    // One strided copy per tensor: batch_heads host rows packed back to back,
    // scattered on the device with a pitch of max_cache_len rows per head.
    CUDA_CHECK(cudaMemcpy2DAsync(d_K_cache_int8 + first_row, head_pitch,
                                 K_int8, row_bytes, row_bytes, (size_t)batch_heads,
                                 cudaMemcpyHostToDevice, int8_copy_stream));
    CUDA_CHECK(cudaMemcpy2DAsync(d_V_cache_int8 + first_row, head_pitch,
                                 V_int8, row_bytes, row_bytes, (size_t)batch_heads,
                                 cudaMemcpyHostToDevice, int8_copy_stream));

    // Pageable host->device copies are staged before cudaMemcpyAsync returns,
    // so taking the address of the by-value parameters is safe.
    CUDA_CHECK(cudaMemcpyAsync(d_K_scales + cache_pos, &scale_k, sizeof(float),
                               cudaMemcpyHostToDevice, int8_copy_stream));
    CUDA_CHECK(cudaMemcpyAsync(d_V_scales + cache_pos, &scale_v, sizeof(float),
                               cudaMemcpyHostToDevice, int8_copy_stream));

    return 0;
}

// Launch the decode kernel on int8_compute_stream over the first cache_len
// positions of every head. Does not synchronize.
static int int8_launch_decode(int batch_heads, int cache_len,
                              float scale_q, float scale_k, float scale_v,
                              const float* k_scales, const float* v_scales) {
    const float attn_scale = 1.0f / sqrtf((float)int8_head_dim);
    const size_t scores_size = ((size_t)(cache_len + 15) / 16 * 16) * sizeof(float);
    const size_t q_packed_size = (size_t)(((int8_head_dim / 4) + 3) / 4 * 4) * sizeof(int);
    const size_t smem_size = scores_size + q_packed_size;

    dim3 grid(batch_heads);
    dim3 block(TC_THREADS);

    flash_attention_int8_decode_dp4a_kernel<<<grid, block, smem_size, int8_compute_stream>>>(
        d_Q_int8, d_K_cache_int8, d_V_cache_int8, d_O_float,
        scale_q, scale_k, scale_v, cache_len, int8_head_dim, attn_scale,
        int8_max_cache_len, k_scales, v_scales);

    CUDA_CHECK(cudaGetLastError());
    return 0;
}

int flash_attention_int8_decode(
    const int8_t* Q_int8, const int8_t* K_int8, const int8_t* V_int8, float* O,
    float scale_q, float scale_k, float scale_v,
    int batch_heads, int cache_pos, int head_dim
) {
    if (!int8_initialized) return -1;
    if (!Q_int8 || !O) return -1;
    if (head_dim != int8_head_dim) {
        fprintf(stderr, "flash_attention_int8_decode: head_dim %d != initialized %d\n",
                head_dim, int8_head_dim);
        return -1;
    }
    if (batch_heads <= 0 || batch_heads > int8_max_batch_heads) return -1;
    if (cache_pos < 0 || cache_pos >= int8_max_cache_len) return -1;

    CUDA_CHECK(cudaMemcpyAsync(d_Q_int8, Q_int8, (size_t)batch_heads * head_dim * sizeof(int8_t),
                               cudaMemcpyHostToDevice, int8_copy_stream));

    if (flash_attention_int8_update_cache(K_int8, V_int8, scale_k, scale_v,
                                          batch_heads, cache_pos) != 0) {
        return -1;
    }
    CUDA_CHECK(cudaStreamSynchronize(int8_copy_stream));

    // Every cached row is dequantized with the scale it was stored with.
    const int cache_len = cache_pos + 1;
    if (int8_launch_decode(batch_heads, cache_len, scale_q, scale_k, scale_v,
                           d_K_scales, d_V_scales) != 0) {
        return -1;
    }

    CUDA_CHECK(cudaMemcpyAsync(O, d_O_float, (size_t)batch_heads * head_dim * sizeof(float),
                               cudaMemcpyDeviceToHost, int8_compute_stream));
    CUDA_CHECK(cudaStreamSynchronize(int8_compute_stream));

    return 0;
}

int flash_attention_int8_decode_fp32(
    const float* Q, const float* K_new, const float* V_new, float* O,
    int batch_heads, int cache_pos, int head_dim
) {
    if (!int8_initialized) return -1;
    if (!Q || !K_new || !V_new || !O || batch_heads <= 0 || head_dim <= 0) return -1;

    const int single_size = batch_heads * head_dim;

    int8_t* h_Q_int8 = (int8_t*)malloc(single_size);
    int8_t* h_K_int8 = (int8_t*)malloc(single_size);
    int8_t* h_V_int8 = (int8_t*)malloc(single_size);
    float scale_q = 0.0f, scale_k = 0.0f, scale_v = 0.0f;

    int ret = -1;
    if (h_Q_int8 && h_K_int8 && h_V_int8 &&
        quantize_tensor_int8(Q, h_Q_int8, &scale_q, single_size) == 0 &&
        quantize_tensor_int8(K_new, h_K_int8, &scale_k, single_size) == 0 &&
        quantize_tensor_int8(V_new, h_V_int8, &scale_v, single_size) == 0) {
        ret = flash_attention_int8_decode(
            h_Q_int8, h_K_int8, h_V_int8, O,
            scale_q, scale_k, scale_v,
            batch_heads, cache_pos, head_dim);
    }

    free(h_Q_int8);
    free(h_K_int8);
    free(h_V_int8);

    return ret;
}

// Launch on already-resident data without copies or waiting. Dequantizes all
// cached positions with the uniform scale_k / scale_v passed here.
int flash_attention_int8_decode_gpu(int batch_heads, int cache_len, float scale_q, float scale_k, float scale_v) {
    if (!int8_initialized) return -1;
    if (batch_heads <= 0 || batch_heads > int8_max_batch_heads) return -1;
    if (cache_len <= 0 || cache_len > int8_max_cache_len) return -1;

    return int8_launch_decode(batch_heads, cache_len, scale_q, scale_k, scale_v,
                              nullptr, nullptr);
}

void flash_attention_int8_sync(void) {
    if (int8_compute_stream) cudaStreamSynchronize(int8_compute_stream);
}

void flash_attention_int8_info(int* initialized, int* max_cache, int* max_bh, int* h_dim) {
    if (initialized) *initialized = int8_initialized;
    if (max_cache) *max_cache = int8_max_cache_len;
    if (max_bh) *max_bh = int8_max_batch_heads;
    if (h_dim) *h_dim = int8_head_dim;
}

} // extern "C"

In [None]:
%%writefile cuda_kernels/benchmark.cu
/**
 * INT8 __dp4a Flash Attention Benchmark
 */
#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#include <chrono>
#include <algorithm>
#include <cuda_runtime.h>
#include "flash_attention_int8.h"

#define WARMUP_RUNS 50      // untimed calls before measuring
#define BENCHMARK_RUNS 200  // timed calls per configuration

// Overwrite data[0 .. size) with pseudo-random floats in [-1, 1] drawn from
// rand(), so the sequence is reproducible for a given srand() seed.
void fill_random(float* data, int size) {
    int i = 0;
    while (i < size) {
        const float unit = (float)rand() / RAND_MAX;  // in [0, 1]
        data[i] = (unit - 0.5f) * 2.0f;
        ++i;
    }
}

// Print latency statistics and derived throughput for one configuration.
// `latencies` holds `runs` samples in microseconds and is sorted in place.
static void print_latency_report(const char* name, int num_layers, double* latencies, int runs) {
    // Sort for percentiles
    std::sort(latencies, latencies + runs);

    const double median_us = latencies[runs / 2];
    const double p95_us = latencies[(int)(runs * 0.95)];
    const double p99_us = latencies[(int)(runs * 0.99)];

    double mean_us = 0;
    for (int i = 0; i < runs; i++) mean_us += latencies[i];
    mean_us /= runs;

    double variance = 0;
    for (int i = 0; i < runs; i++) {
        const double diff = latencies[i] - mean_us;
        variance += diff * diff;
    }
    const double std_us = sqrt(variance / runs);

    const double median_ms = median_us / 1000.0;
    // One token runs attention once per layer.
    const double per_token_attn_ms = median_ms * num_layers;
    const double attn_throughput = 1000.0 / per_token_attn_ms;
    // Heuristic: assumes attention is ~35% of total per-token time (hard-coded, not measured).
    const double estimated_throughput = 1000.0 / (per_token_attn_ms / 0.35);

    printf("\n  Results:\n");
    printf("    Layer latency (median): %.2f us (%.4f ms)\n", median_us, median_ms);
    printf("    Layer latency (mean):   %.2f +/- %.2f us\n", mean_us, std_us);
    printf("    Layer latency (P95):    %.2f us\n", p95_us);
    printf("    Layer latency (P99):    %.2f us\n", p99_us);
    printf("    Jitter (std/median):    %.2f%%\n", (std_us / median_us) * 100);
    printf("\n");
    printf("    Per-token attention:    %.3f ms\n", per_token_attn_ms);
    printf("    Attn-only throughput:   %.1f tok/s\n", attn_throughput);
    printf("    Est. total throughput:  %.1f tok/s\n", estimated_throughput);

    printf("\n  JSON:\n");
    printf("  {\n");
    printf("    \"model\": \"%s\",\n", name);
    printf("    \"kernel\": \"INT8_dp4a\",\n");
    printf("    \"layer_latency_us\": %.2f,\n", median_us);
    printf("    \"std_us\": %.2f,\n", std_us);
    printf("    \"p95_us\": %.2f,\n", p95_us);
    printf("    \"p99_us\": %.2f,\n", p99_us);
    printf("    \"attn_throughput\": %.1f,\n", attn_throughput);
    printf("    \"estimated_throughput\": %.1f\n", estimated_throughput);
    printf("  }\n");
}

// Benchmark flash_attention_int8_decode_fp32 for several model shapes.
// Exit code is 0 only if every configuration completed.
int main() {
    printf("\n=== INT8 __dp4a Flash Attention Benchmark ===\n\n");

    // Initialized so a failed query (e.g. missing driver) counts as "no GPU".
    int device_count = 0;
    if (cudaGetDeviceCount(&device_count) != cudaSuccess || device_count == 0) {
        printf("ERROR: No CUDA GPUs found\n");
        return 1;
    }

    cudaDeviceProp prop;
    cudaGetDeviceProperties(&prop, 0);
    printf("GPU: %s (SM %d.%d)\n", prop.name, prop.major, prop.minor);
    printf("Memory: %.2f GB\n\n", prop.totalGlobalMem / 1e9);

    // Test configurations
    struct Config {
        const char* name;
        int num_heads;
        int head_dim;
        int num_layers;
    };

    const Config configs[] = {
        {"SmolLM-135M", 9, 64, 9},
        {"Qwen-0.5B", 14, 64, 24},
        {"Qwen-1.5B", 12, 128, 28}
    };

    const int cache_len = 256;       // positions attended over in the timed calls
    const int max_cache_len = 2048;  // cache capacity requested from init
    int failures = 0;

    for (const Config& cfg : configs) {
        printf("\n=== %s ===\n", cfg.name);
        printf("  Heads: %d, Head dim: %d, Layers: %d\n", cfg.num_heads, cfg.head_dim, cfg.num_layers);

        const int batch_heads = cfg.num_heads;
        if (flash_attention_int8_init(batch_heads, max_cache_len, cfg.head_dim) != 0) {
            printf("  ERROR: init failed, skipping %s\n", cfg.name);
            failures++;
            continue;
        }

        const int single_size = batch_heads * cfg.head_dim;
        // Q is zero-initialized: the cache-fill calls read it before it is
        // first filled with random values (their outputs are discarded).
        float* Q = (float*)calloc(single_size, sizeof(float));
        float* K = (float*)malloc(single_size * sizeof(float));
        float* V = (float*)malloc(single_size * sizeof(float));
        float* O = (float*)malloc(single_size * sizeof(float));
        double* latencies = (double*)malloc(BENCHMARK_RUNS * sizeof(double));
        bool ok = Q && K && V && O && latencies;

        srand(42);

        // Fill cache positions [0, cache_len) with random K/V rows.
        for (int pos = 0; ok && pos < cache_len; pos++) {
            fill_random(K, single_size);
            fill_random(V, single_size);
            ok = flash_attention_int8_decode_fp32(Q, K, V, O, batch_heads, pos, cfg.head_dim) == 0;
        }

        if (ok) {
            fill_random(Q, single_size);
            fill_random(K, single_size);
            fill_random(V, single_size);

            // Warmup
            printf("  Warmup: %d runs...\n", WARMUP_RUNS);
            for (int i = 0; ok && i < WARMUP_RUNS; i++) {
                ok = flash_attention_int8_decode_fp32(Q, K, V, O, batch_heads, cache_len - 1, cfg.head_dim) == 0;
            }
            cudaDeviceSynchronize();
        }

        if (ok) {
            // Benchmark. Each sample times the whole decode_fp32 call: host-side
            // quantization, host<->device copies and the kernel, not the kernel alone.
            printf("  Benchmark: %d runs...\n", BENCHMARK_RUNS);
            for (int i = 0; ok && i < BENCHMARK_RUNS; i++) {
                auto start = std::chrono::high_resolution_clock::now();
                ok = flash_attention_int8_decode_fp32(Q, K, V, O, batch_heads, cache_len - 1, cfg.head_dim) == 0;
                cudaDeviceSynchronize();
                auto end = std::chrono::high_resolution_clock::now();
                latencies[i] = std::chrono::duration<double, std::micro>(end - start).count();
            }
        }

        if (ok) {
            print_latency_report(cfg.name, cfg.num_layers, latencies, BENCHMARK_RUNS);
        } else {
            printf("  ERROR: benchmark failed for %s\n", cfg.name);
            failures++;
        }

        free(Q);
        free(K);
        free(V);
        free(O);
        free(latencies);
        flash_attention_int8_cleanup();
    }

    printf("\n=== Benchmark Complete ===\n");
    printf("\nTargets: 630 tok/s (EdgeLLM), Baseline: 423 tok/s (Ollama)\n");

    return failures == 0 ? 0 : 1;
}

In [None]:
# Compile: sm_75 machine code plus embedded compute_75 PTX, so GPUs other
# than SM 7.5 can JIT-compile the kernels instead of failing with
# "no kernel image is available".
!nvcc -O3 -gencode arch=compute_75,code=sm_75 \
     -gencode arch=compute_75,code=compute_75 \
     --expt-relaxed-constexpr \
     -I cuda_kernels \
     cuda_kernels/flash_attention_int8.cu cuda_kernels/benchmark.cu \
     -o benchmark_int8_dp4a \
     -lcudart

In [None]:
# Run the benchmark binary built above; prints latency statistics and a JSON
# summary for each model configuration.
!./benchmark_int8_dp4a

## Analysis

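The `__dp4a` intrinsic performs four INT8 multiply-accumulates in a single instruction, accumulating into a 32-bit integer:
```
d = c + a[0]*b[0] + a[1]*b[1] + a[2]*b[2] + a[3]*b[3]
```
where `a` and `b` are 32-bit words that each pack four signed INT8 values, and `c` and `d` are 32-bit integers.

**Expected improvements:**
- Up to 4x higher throughput for the Q@K^T dot products. This is the theoretical peak; decode attention is typically memory-bound.
- Lower memory traffic: the INT8 K/V cache is 4x smaller than FP32 (2x smaller than FP16).
- Better cache utilization due to the smaller data types.

Note: the latency reported above is for the full `flash_attention_int8_decode_fp32` call. It includes host-side quantization and host↔device transfers, not just the kernel.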