In [None]:
# ⚙️ Setup
import subprocess
import sys

try:
    # google.colab is importable only inside a Colab runtime, so the import is
    # used purely as an environment probe.
    import google.colab  # noqa: F401
except ImportError:
    # Not running on Colab: numba must already be available locally.
    pass
else:
    # On Colab, install numba into the interpreter that backs this kernel.
    # Kept out of the `try` body so that only the probe import is guarded.
    subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "numba"])

import numpy as np
from numba import cuda
import time

print("⚠️  CUDA C++ is the PRIMARY learning material!")

---

## Part 1: Attention Overview

### Scaled Dot-Product Attention

```
Attention(Q, K, V) = softmax(Q @ K.T / sqrt(d_k)) @ V

Shapes:
  Q: [batch, heads, seq_len, head_dim]  = [B, H, N, D]
  K: [batch, heads, seq_len, head_dim]  = [B, H, N, D]
  V: [batch, heads, seq_len, head_dim]  = [B, H, N, D]
  
  Scores = Q @ K.T:    [B, H, N, N]  ← O(N²) memory!
  Attention @ V:       [B, H, N, D]
```

### Memory Challenge

```
For GPT-style model:
  Sequence length N = 2048
  Batch size B = 32
  Heads H = 32
  
  Attention matrix: B × H × N × N × 4 bytes
                  = 32 × 32 × 2048 × 2048 × 4
                  ≈ 17 GB just for attention scores!
```

### 🔷 CUDA C++ Implementation (Primary)

In [None]:
%%writefile naive_attention.cu
// naive_attention.cu - Basic attention implementation
#include <stdio.h>
#include <cuda_runtime.h>
#include <cublas_v2.h>
#include <float.h>

// Evaluate a CUDA runtime call exactly once. If it returns anything other
// than cudaSuccess, print the error string together with the source location
// of the failing call and terminate the process with exit code 1.
//
// The body is wrapped in do { ... } while (0) so that `CHECK_CUDA(x);` expands
// to a single statement and stays correct inside unbraced if/else branches.
// The argument is parenthesized and the temporary is named err_ so that it
// cannot collide with a variable named `err` used inside the checked call.
#define CHECK_CUDA(call) do { \
    cudaError_t err_ = (call); \
    if (err_ != cudaSuccess) { \
        fprintf(stderr, "CUDA error: %s (%s:%d)\n", \
                cudaGetErrorString(err_), __FILE__, __LINE__); \
        exit(1); \
    } \
} while (0)

// Row-wise softmax computed in place on a rows x cols matrix stored with
// `cols` contiguous floats per row.
//
// The row handled by a block is taken from blockIdx.x; threadIdx is not used,
// so every thread that runs walks its whole row serially in three passes:
// maximum, exponentiate-and-sum, normalize.
__global__ void softmax_kernel(
    float* scores, int rows, int cols
) {
    const int r = blockIdx.x;
    if (r < rows) {
        float* const line = scores + r * cols;

        // Pass 1: row maximum. Subtracting it below keeps every expf
        // argument <= 0, so the exponentials cannot overflow.
        float peak = -FLT_MAX;
        for (int c = 0; c < cols; ++c) {
            peak = fmaxf(peak, line[c]);
        }

        // Pass 2: overwrite each entry with exp(x - max) and accumulate
        // the total in the same sweep.
        float total = 0.0f;
        for (int c = 0; c < cols; ++c) {
            const float e = expf(line[c] - peak);
            line[c] = e;
            total += e;
        }

        // Pass 3: divide by the total so the row sums to one.
        for (int c = 0; c < cols; ++c) {
            line[c] /= total;
        }
    }
}

// Host driver for the naive attention demo.
//
// For each (batch, head) pair it materializes the full seq_len x seq_len score
// matrix with one cuBLAS GEMM, applies softmax_kernel row by row, and
// multiplies by V with a second GEMM. The whole loop is timed with CUDA
// events and the memory cost of the score matrix is printed afterwards.
// Every CUDA and cuBLAS call is checked; any failure aborts the program.
int main() {
    printf("=== Naive Attention Implementation ===\n\n");

    // Small example for correctness
    const int batch = 1;
    const int heads = 8;
    const int seq_len = 64;
    const int head_dim = 64;
    const float scale = 1.0f / sqrtf(head_dim);

    int total_heads = batch * heads;

    // Element counts and byte sizes, computed once in size_t so that larger
    // configurations do not overflow int arithmetic.
    const size_t head_elems   = (size_t)seq_len * head_dim;   // one head of Q, K or V
    const size_t score_elems  = (size_t)seq_len * seq_len;    // one head of scores
    const size_t qkv_elems    = (size_t)total_heads * head_elems;
    const size_t qkv_bytes    = qkv_elems * sizeof(float);
    const size_t scores_bytes = (size_t)total_heads * score_elems * sizeof(float);

    // cuBLAS reports failures through its own status type, so it needs a
    // checker separate from CHECK_CUDA.
    auto check_cublas = [](cublasStatus_t status, const char* what) {
        if (status != CUBLAS_STATUS_SUCCESS) {
            printf("cuBLAS error in %s: status %d\n", what, (int)status);
            exit(1);
        }
    };

    printf("Configuration:\n");
    printf("  Batch: %d, Heads: %d\n", batch, heads);
    printf("  Sequence length: %d\n", seq_len);
    printf("  Head dimension: %d\n\n", head_dim);

    // Allocate Q, K, V, Scores, Output
    float *d_Q, *d_K, *d_V, *d_scores, *d_output;
    CHECK_CUDA(cudaMalloc(&d_Q, qkv_bytes));
    CHECK_CUDA(cudaMalloc(&d_K, qkv_bytes));
    CHECK_CUDA(cudaMalloc(&d_V, qkv_bytes));
    CHECK_CUDA(cudaMalloc(&d_scores, scores_bytes));
    CHECK_CUDA(cudaMalloc(&d_output, qkv_bytes));

    // Initialize with random data in [-0.5, 0.5); the same host buffer is
    // uploaded to Q, K and V.
    float* h_data = new float[qkv_elems];
    for (size_t i = 0; i < qkv_elems; i++) {
        h_data[i] = (rand() % 1000) / 1000.0f - 0.5f;
    }
    CHECK_CUDA(cudaMemcpy(d_Q, h_data, qkv_bytes, cudaMemcpyHostToDevice));
    CHECK_CUDA(cudaMemcpy(d_K, h_data, qkv_bytes, cudaMemcpyHostToDevice));
    CHECK_CUDA(cudaMemcpy(d_V, h_data, qkv_bytes, cudaMemcpyHostToDevice));

    cublasHandle_t handle;
    check_cublas(cublasCreate(&handle), "cublasCreate");

    cudaEvent_t start, stop;
    CHECK_CUDA(cudaEventCreate(&start));
    CHECK_CUDA(cudaEventCreate(&stop));

    CHECK_CUDA(cudaEventRecord(start));

    for (int h = 0; h < total_heads; h++) {
        float* Q_h = d_Q + h * head_elems;
        float* K_h = d_K + h * head_elems;
        float* V_h = d_V + h * head_elems;
        float* scores_h = d_scores + h * score_elems;
        float* out_h = d_output + h * head_elems;

        // Step 1: Scores = Q @ K.T * scale
        // cuBLAS is column-major, so a row-major [seq_len, head_dim] buffer
        // is seen as its transpose. Computing K_cm^T * Q_cm therefore yields
        // the column-major transpose of the scores, which is exactly the
        // row-major score matrix.
        float alpha = scale, beta = 0.0f;
        check_cublas(
            cublasSgemm(handle, CUBLAS_OP_T, CUBLAS_OP_N,
                        seq_len, seq_len, head_dim,
                        &alpha,
                        K_h, head_dim,
                        Q_h, head_dim,
                        &beta,
                        scores_h, seq_len),
            "cublasSgemm (Q @ K.T)");

        // Step 2: Softmax (one block per score row)
        softmax_kernel<<<seq_len, 1>>>(scores_h, seq_len, seq_len);
        // Launch-configuration errors are only visible via cudaGetLastError.
        CHECK_CUDA(cudaGetLastError());

        // Step 3: Output = Attention @ V
        // Same column-major trick: V_cm * P_cm is the transpose of P @ V.
        alpha = 1.0f;
        check_cublas(
            cublasSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N,
                        head_dim, seq_len, seq_len,
                        &alpha,
                        V_h, head_dim,
                        scores_h, seq_len,
                        &beta,
                        out_h, head_dim),
            "cublasSgemm (Attention @ V)");
    }

    CHECK_CUDA(cudaEventRecord(stop));
    // Synchronizing on `stop` also surfaces asynchronous execution errors
    // from the kernels and GEMMs above.
    CHECK_CUDA(cudaEventSynchronize(stop));

    float ms;
    CHECK_CUDA(cudaEventElapsedTime(&ms, start, stop));

    printf("Attention computed in %.3f ms\n", ms);

    // Memory usage
    float scores_mem = total_heads * seq_len * seq_len * 4.0f / 1e6;
    printf("\nMemory Analysis:\n");
    printf("  Attention scores: %.2f MB\n", scores_mem);
    printf("  For seq_len=2048: %.2f MB\n",
           total_heads * 2048.0f * 2048.0f * 4.0f / 1e6);

    // Cleanup
    delete[] h_data;
    CHECK_CUDA(cudaEventDestroy(start));
    CHECK_CUDA(cudaEventDestroy(stop));
    cudaFree(d_Q); cudaFree(d_K); cudaFree(d_V);
    cudaFree(d_scores); cudaFree(d_output);
    cublasDestroy(handle);

    return 0;
}

In [None]:
# Build naive_attention.cu with nvcc (target sm_75, linked against cuBLAS), then run the binary.
!nvcc -arch=sm_75 -lcublas -o naive_attention naive_attention.cu
!./naive_attention

---

## Part 2: Tiled Attention

### Reducing Memory with Tiling

```
Instead of materializing the full N×N attention matrix:

For each query block Q_i:
    For each key block K_j:
        Compute partial scores S_ij = Q_i @ K_j.T
        Update running softmax statistics
        Accumulate output contribution
        
Memory: O(block_size²) instead of O(N²)
```

### 🔷 CUDA C++ Implementation (Primary)

In [None]:
%%writefile tiled_attention.cu
// tiled_attention.cu - Memory-efficient tiled attention
#include <stdio.h>
#include <cuda_runtime.h>
#include <float.h>

// Tile width: the number of K/V rows staged in shared memory per iteration of
// the tiled kernel. main() also uses it as the thread count per block and to
// size the dynamic shared-memory allocation.
#define BLOCK_SIZE 32

// Online softmax update: fold one block of BLOCK_SIZE scores (and the
// matching BLOCK_SIZE x head_dim rows of V) into the running statistics.
//
// o_prev is treated as already normalized by *l_prev: each entry is multiplied
// back by the old sum, merged with the new block's contribution, and divided
// by the new sum. On return *m_prev and *l_prev hold the updated max and sum.
__device__ void update_softmax(
    float* m_prev,      // in/out: running max
    float* l_prev,      // in/out: running sum of exponentials
    float* o_prev,      // in/out: running output (head_dim values)
    float* scores,      // in: scores of the current block (BLOCK_SIZE values)
    float* v_block,     // in: current V block, row-major with head_dim columns
    int head_dim
) {
    // Max over the previous running max and the current block.
    float running_max = *m_prev;
    for (int k = 0; k < BLOCK_SIZE; ++k) {
        running_max = fmaxf(running_max, scores[k]);
    }

    // Exponentiate the block relative to the new max and total the weights.
    float weights[BLOCK_SIZE];
    float block_sum = 0.0f;
    for (int k = 0; k < BLOCK_SIZE; ++k) {
        weights[k] = expf(scores[k] - running_max);
        block_sum += weights[k];
    }

    // Factor that re-expresses the old statistics relative to the new max.
    const float rescale = expf(*m_prev - running_max);
    const float merged_sum = rescale * (*l_prev) + block_sum;

    // Merge the old (normalized) output with the weighted V rows of this block.
    for (int d = 0; d < head_dim; ++d) {
        float fresh = 0.0f;
        for (int k = 0; k < BLOCK_SIZE; ++k) {
            fresh += weights[k] * v_block[k * head_dim + d];
        }
        const float carried = rescale * (*l_prev) * o_prev[d];
        o_prev[d] = (carried + fresh) / merged_sum;
    }

    *m_prev = running_max;
    *l_prev = merged_sum;
}

// Simplified tiled attention: each thread block handles one query row
// (blockIdx.x) and streams over K/V in tiles of BLOCK_SIZE rows, keeping
// running softmax statistics so the full score row is never stored.
//
// Q, K, V and output are row-major [seq_len, head_dim]. `scale` multiplies
// every dot product before the softmax.
//
// Dynamic shared memory must hold
//   head_dim + 2 * BLOCK_SIZE * head_dim + BLOCK_SIZE
// floats (query row, K tile, V tile, tile scores).
//
// All cooperative loads and the score computation are strided by blockDim.x,
// so the kernel is correct for any block size. head_dim must not exceed the
// fixed accumulator capacity (64); larger values are rejected with a message.
__global__ void tiled_attention_kernel(
    float* Q, float* K, float* V, float* output,
    int seq_len, int head_dim, float scale
) {
    // Capacity of the per-thread output accumulator o[] below.
    const int kMaxHeadDim = 64;

    int query_idx = blockIdx.x;
    int tid = threadIdx.x;

    // Both guards depend only on block-uniform values, so every thread of a
    // block takes the same path and the barriers below stay safe.
    if (query_idx >= seq_len) return;
    if (head_dim > kMaxHeadDim) {
        // Writing past o[] would corrupt memory; refuse instead.
        if (query_idx == 0 && tid == 0) {
            printf("tiled_attention_kernel: head_dim=%d exceeds supported maximum %d\n",
                   head_dim, kMaxHeadDim);
        }
        return;
    }

    // Carve the dynamic shared-memory buffer into its four regions.
    extern __shared__ float shared[];
    float* q_row = shared;  // head_dim
    float* k_block = shared + head_dim;  // BLOCK_SIZE * head_dim
    float* v_block = shared + head_dim + BLOCK_SIZE * head_dim;  // BLOCK_SIZE * head_dim
    float* scores = shared + head_dim + 2 * BLOCK_SIZE * head_dim;  // BLOCK_SIZE

    // Load the query row cooperatively. The loop is strided because the block
    // may have fewer threads than head_dim (e.g. 32 threads, head_dim 64); a
    // plain `if (tid < head_dim)` would leave the tail of q_row uninitialized.
    for (int d = tid; d < head_dim; d += blockDim.x) {
        q_row[d] = Q[query_idx * head_dim + d];
    }
    __syncthreads();

    // Running softmax statistics (only thread 0's copies are updated/used).
    float m = -FLT_MAX;      // Running max
    float l = 0.0f;          // Running sum
    float o[kMaxHeadDim];    // Unnormalized output accumulator
    for (int d = 0; d < head_dim; d++) o[d] = 0.0f;

    // Process K,V in blocks
    int num_blocks = (seq_len + BLOCK_SIZE - 1) / BLOCK_SIZE;

    for (int block = 0; block < num_blocks; block++) {
        int block_start = block * BLOCK_SIZE;
        // The last tile may be partial.
        int block_len = min(BLOCK_SIZE, seq_len - block_start);

        // Load the K and V tiles
        for (int i = tid; i < block_len * head_dim; i += blockDim.x) {
            int row = i / head_dim;
            int col = i % head_dim;
            k_block[row * head_dim + col] = K[(block_start + row) * head_dim + col];
            v_block[row * head_dim + col] = V[(block_start + row) * head_dim + col];
        }
        __syncthreads();

        // Compute scaled dot products for this tile, one key row per
        // iteration, strided so that any blockDim.x covers the whole tile.
        for (int j = tid; j < block_len; j += blockDim.x) {
            float score = 0.0f;
            for (int d = 0; d < head_dim; d++) {
                score += q_row[d] * k_block[j * head_dim + d];
            }
            scores[j] = score * scale;
        }
        __syncthreads();

        // Thread 0 updates softmax (simplified)
        if (tid == 0) {
            // Find new max
            float m_new = m;
            for (int i = 0; i < block_len; i++) {
                m_new = fmaxf(m_new, scores[i]);
            }

            // Compute exp scores and sum (scores[] is overwritten in place)
            float l_new = 0.0f;
            for (int i = 0; i < block_len; i++) {
                scores[i] = expf(scores[i] - m_new);
                l_new += scores[i];
            }

            // Rescale previous output to the new max
            float scale_prev = expf(m - m_new);
            for (int d = 0; d < head_dim; d++) {
                o[d] *= scale_prev;
            }

            // Add contribution from current block
            for (int i = 0; i < block_len; i++) {
                for (int d = 0; d < head_dim; d++) {
                    o[d] += scores[i] * v_block[i * head_dim + d];
                }
            }

            // Update statistics
            l = scale_prev * l + l_new;
            m = m_new;
        }
        // Keep other threads from overwriting the tiles while thread 0 reads them.
        __syncthreads();
    }

    // Write output, normalizing by the final running sum
    if (tid == 0) {
        for (int d = 0; d < head_dim; d++) {
            output[query_idx * head_dim + d] = o[d] / l;
        }
    }
}

// Host driver for the tiled attention demo.
//
// Allocates Q/K/V/output for a single head, fills Q, K and V with the same
// pseudo-random data, launches tiled_attention_kernel 100 times between two
// CUDA events, and prints the elapsed time plus a size comparison between a
// full score matrix and a BLOCK_SIZE x BLOCK_SIZE tile.
// Every CUDA call is checked; any failure aborts the program.
int main() {
    printf("=== Tiled Attention (Memory Efficient) ===\n\n");

    const int seq_len = 256;
    const int head_dim = 64;
    const float scale = 1.0f / sqrtf(head_dim);

    // Abort with a message on any CUDA runtime failure.
    auto check = [](cudaError_t err, const char* what) {
        if (err != cudaSuccess) {
            printf("CUDA error (%s): %s\n", what, cudaGetErrorString(err));
            exit(1);
        }
    };

    printf("Sequence length: %d\n", seq_len);
    printf("Block size: %d\n", BLOCK_SIZE);
    printf("Head dimension: %d\n\n", head_dim);

    const size_t num_elems = (size_t)seq_len * head_dim;
    const size_t num_bytes = num_elems * sizeof(float);

    float *d_Q, *d_K, *d_V, *d_output;
    check(cudaMalloc(&d_Q, num_bytes), "cudaMalloc Q");
    check(cudaMalloc(&d_K, num_bytes), "cudaMalloc K");
    check(cudaMalloc(&d_V, num_bytes), "cudaMalloc V");
    check(cudaMalloc(&d_output, num_bytes), "cudaMalloc output");

    // Initialize with values in [-0.5, 0.5); one host buffer feeds Q, K and V.
    float* h_data = new float[num_elems];
    for (size_t i = 0; i < num_elems; i++) {
        h_data[i] = (rand() % 1000) / 1000.0f - 0.5f;
    }
    check(cudaMemcpy(d_Q, h_data, num_bytes, cudaMemcpyHostToDevice), "cudaMemcpy Q");
    check(cudaMemcpy(d_K, h_data, num_bytes, cudaMemcpyHostToDevice), "cudaMemcpy K");
    check(cudaMemcpy(d_V, h_data, num_bytes, cudaMemcpyHostToDevice), "cudaMemcpy V");

    // Shared memory size: query row + K tile + V tile + tile scores,
    // matching the layout the kernel carves out of its dynamic buffer.
    int shared_size = (head_dim + 2 * BLOCK_SIZE * head_dim + BLOCK_SIZE) * sizeof(float);

    cudaEvent_t start, stop;
    check(cudaEventCreate(&start), "cudaEventCreate start");
    check(cudaEventCreate(&stop), "cudaEventCreate stop");

    check(cudaEventRecord(start), "cudaEventRecord start");
    for (int i = 0; i < 100; i++) {
        tiled_attention_kernel<<<seq_len, BLOCK_SIZE, shared_size>>>(
            d_Q, d_K, d_V, d_output, seq_len, head_dim, scale
        );
    }
    // Launch-configuration errors (e.g. too much shared memory) are only
    // reported through cudaGetLastError.
    check(cudaGetLastError(), "tiled_attention_kernel launch");
    check(cudaEventRecord(stop), "cudaEventRecord stop");
    // Synchronizing also surfaces errors raised while the kernels executed.
    check(cudaEventSynchronize(stop), "cudaEventSynchronize");

    float ms;
    check(cudaEventElapsedTime(&ms, start, stop), "cudaEventElapsedTime");

    printf("Tiled attention: %.3f ms (100 iterations)\n", ms);
    printf("\nMemory comparison:\n");
    printf("  Naive (full matrix): %.2f KB\n", seq_len * seq_len * 4.0f / 1024);
    printf("  Tiled (per block):   %.2f KB\n", BLOCK_SIZE * BLOCK_SIZE * 4.0f / 1024);
    printf("  Reduction: %.0fx\n", (float)(seq_len * seq_len) / (BLOCK_SIZE * BLOCK_SIZE));

    delete[] h_data;
    check(cudaEventDestroy(start), "cudaEventDestroy start");
    check(cudaEventDestroy(stop), "cudaEventDestroy stop");
    cudaFree(d_Q); cudaFree(d_K); cudaFree(d_V); cudaFree(d_output);

    return 0;
}

In [None]:
# Build tiled_attention.cu with nvcc (target sm_75), then run the binary.
!nvcc -arch=sm_75 -o tiled_attention tiled_attention.cu
!./tiled_attention

---

## Part 3: Flash Attention Concepts

### Key Innovations

```
Flash Attention improvements over naive:

1. Online Softmax
   - Compute softmax incrementally
   - Track running max and sum
   - Rescale accumulated outputs

2. IO Awareness
   - Minimize HBM ↔ SRAM transfers
   - Keep Q, K, V tiles in shared memory
   - Never materialize full attention matrix

3. Tiling Strategy
   - Tile Q in outer loop
   - Tile K, V in inner loop
   - Fuse all operations
   
Result: 2-4x faster, O(N) memory
```

### 🔶 Python/Numba (Optional - Conceptual Demo)

In [None]:
# Flash Attention algorithm (conceptual Python)
def flash_attention_algorithm():
    """Print the Flash Attention pseudocode, one line of output per step.

    Nothing is computed: the function only writes the tiled online-softmax
    recurrence to stdout as a reading aid.
    """
    # The leading spaces inside each string carry the pseudocode's nesting;
    # the Python statements themselves all sit at one indentation level.
    pseudocode = [
        "Flash Attention Algorithm",
        "=" * 50,
        "",
        "for each query block Q_i:",
        "    initialize: m_i = -inf, l_i = 0, O_i = 0",
        "    ",
        "    for each key/value block (K_j, V_j):",
        "        # Compute local attention scores",
        "        S_ij = Q_i @ K_j.T * scale",
        "        ",
        "        # Online softmax update",
        "        m_new = max(m_i, rowmax(S_ij))",
        "        P_ij = exp(S_ij - m_new)",
        "        l_new = exp(m_i - m_new) * l_i + rowsum(P_ij)",
        "        ",
        "        # Rescale and accumulate output",
        "        O_i = exp(m_i - m_new) * O_i + P_ij @ V_j",
        "        m_i, l_i = m_new, l_new",
        "    ",
        "    # Final normalization",
        "    O_i = O_i / l_i",
    ]
    print("\n".join(pseudocode))

flash_attention_algorithm()

---

## Exercises

### Exercise 1: Causal Masking

Modify the attention to support causal (autoregressive) masking:
- Only attend to positions j ≤ i
- Set masked positions to -infinity before softmax

---

## Summary

### Attention Optimization Techniques

| Technique | Memory | Speed | Complexity |
|-----------|--------|-------|------------|
| Naive | O(N²) | Baseline | Simple |
| Tiled | O(B²) | 2x | Moderate |
| Flash Attention | O(N) | 2-4x | Complex |

### Key Takeaways

1. **Memory is the bottleneck** for long sequences
2. **Online softmax** enables streaming computation
3. **Tiling** trades compute for memory
4. Use **cuBLAS batched GEMM** for multi-head attention