# CUDA Programming Fundamentals - Interactive Notebook

This notebook provides hands-on experiments with CUDA programming concepts.
Each cell demonstrates a key concept with profiled measurements.

## Topics
1. Thread Hierarchy
2. Memory Coalescing
3. Kernel Fusion
4. Memory-Bound vs Compute-Bound
5. Synchronization
6. Mixed Precision
7. Flash Attention

In [None]:
# Setup
import torch
import torch.nn.functional as F
import time
import math

print(f"PyTorch: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

In [None]:
# Profiling utility with proper CUDA timing
def profile_cuda(func, warmup=10, iterations=100, name=""):
    """Return the mean time per call of func in milliseconds, measured with CUDA events."""
    if not torch.cuda.is_available():
        return 0.0
    
    # Warmup
    for _ in range(warmup):
        func()
    torch.cuda.synchronize()
    
    # Time with CUDA events
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    
    start.record()
    for _ in range(iterations):
        func()
    end.record()
    torch.cuda.synchronize()
    
    return start.elapsed_time(end) / iterations

print("Profiling utility ready!")

## 1. CUDA Thread Hierarchy

```
Grid (all blocks)
└── Block (group of threads, share memory)
    └── Warp (32 threads, execute together)
        └── Thread (individual execution unit)
```
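
As a rough illustration of how the hierarchy maps onto a 1-D problem, the sketch below computes a launch configuration by ceiling division. The helper name `launch_config` and the default of 256 threads per block are assumptions for illustration; PyTorch picks its own configuration internally.

```python
def launch_config(num_elements, threads_per_block=256, warp_size=32):
    """Return (blocks, warps_per_block, idle_threads) for a 1-D launch.

    Hypothetical helper: real libraries choose their own launch parameters.
    """
    # Ceiling division: enough blocks to cover every element
    blocks = (num_elements + threads_per_block - 1) // threads_per_block
    # Each block is executed as groups of warp_size threads
    warps_per_block = (threads_per_block + warp_size - 1) // warp_size
    # Threads in the last block that have no element to process
    idle_threads = blocks * threads_per_block - num_elements
    return blocks, warps_per_block, idle_threads


for n in (1000, 1024, 1_048_576):
    blocks, warps, idle = launch_config(n)
    print(f"n={n:>9}: {blocks} blocks x {warps} warps/block, {idle} idle threads")
```

Sizes that are not a multiple of the block size leave some threads in the last block idle, which is why kernels usually guard with a bounds check.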

In [None]:
# Demonstrate how work scales with size
print("How PyTorch launches kernels (conceptually):")
print(f"{'Size':<15} {'Est. Blocks':<15} {'Threads/Block':<15} {'Time (ms)'}")
print("-" * 60)

for size in [1024, 65536, 1048576, 16777216]:
    x = torch.randn(size, device='cuda')
    
    time_ms = profile_cuda(lambda: x + 1.0, iterations=1000)
    
    # Estimate launch config (256 threads/block is common)
    threads_per_block = 256
    num_blocks = (size + threads_per_block - 1) // threads_per_block
    
    print(f"{size:<15} {num_blocks:<15} {threads_per_block:<15} {time_ms:.4f}")

print("\nKey insight: small sizes are dominated by launch overhead; large sizes spread across many blocks")

## 2. Memory Coalescing

When the threads in a warp access consecutive memory addresses, the GPU combines their loads into fewer memory transactions.

**Good (coalesced):** Thread 0→addr 0, Thread 1→addr 1, Thread 2→addr 2...

**Bad (strided):** Thread 0→addr 0, Thread 1→addr N, Thread 2→addr 2N...
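
The effect can be approximated with a small model that counts how many memory segments one warp touches. This is a simplified sketch: the 128-byte segment size, 4-byte elements, and the helper name `segments_touched` are assumptions, and real hardware differs in the details.

```python
def segments_touched(stride, warp_size=32, elem_bytes=4, segment_bytes=128):
    """Count distinct memory segments one warp touches.

    Thread t reads the element at index t * stride. The segment size is an
    assumed value for illustration, not a property of a specific GPU.
    """
    addresses = [t * stride * elem_bytes for t in range(warp_size)]
    return len({addr // segment_bytes for addr in addresses})


for stride in (1, 2, 32, 4096):
    print(f"stride {stride:>4}: {segments_touched(stride)} segment(s) per warp")
```

With stride 1 the whole warp is served by a single segment; once the stride reaches 32 elements, every thread lands in its own segment and most of each transaction is wasted.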

In [None]:
# Coalesced vs strided access
rows, cols = 4096, 4096
x = torch.randn(rows, cols, device='cuda')

print(f"Matrix shape: {rows} x {cols}")
print(f"\n{'Access Pattern':<30} {'Time (ms)':<15} {'Bandwidth (GB/s)'}")
print("-" * 60)

bytes_total = rows * cols * 4

# Row-wise reduction (dim=1): the reduced dimension is contiguous in memory
time_row = profile_cuda(lambda: x.sum(dim=1))
bw_row = bytes_total / (time_row / 1000) / 1e9
print(f"{'Row sum (dim=1, contiguous)':<30} {time_row:<15.3f} {bw_row:.1f}")

# Column-wise reduction (dim=0): consecutive elements of a column are `cols` apart.
# PyTorch's reduction kernels are tuned for both layouts, so the gap may be small.
time_col = profile_cuda(lambda: x.sum(dim=0))
bw_col = bytes_total / (time_col / 1000) / 1e9
print(f"{'Column sum (dim=0, strided)':<30} {time_col:<15.3f} {bw_col:.1f}")

print(f"\nColumn/row time ratio: {time_col/time_row:.2f}x")
print("Prefer access patterns where adjacent threads touch adjacent addresses.")

## 3. Kernel Fusion

Fusing operations reduces memory traffic - intermediate results stay in registers.
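
A back-of-the-envelope estimate of the traffic involved, assuming every element-wise kernel reads one FP32 tensor and writes one FP32 tensor. The helper name `elementwise_traffic_bytes` is hypothetical.

```python
def elementwise_traffic_bytes(num_elements, num_kernels, elem_bytes=4):
    """Bytes moved when each kernel reads one tensor and writes one tensor."""
    passes = 2 * num_kernels  # one read pass + one write pass per kernel
    return passes * num_elements * elem_bytes


n = 16_000_000
unfused_bytes = elementwise_traffic_bytes(n, num_kernels=4)  # add, mul, sub, relu
fused_bytes = elementwise_traffic_bytes(n, num_kernels=1)    # one fused kernel

print(f"Unfused: {unfused_bytes / 1e6:.0f} MB over 8 passes")
print(f"Fused:   {fused_bytes / 1e6:.0f} MB over 2 passes")
print(f"Traffic reduction: {unfused_bytes / fused_bytes:.0f}x")
```

For memory-bound kernels, the traffic ratio is an upper bound on the speedup fusion can deliver.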

In [None]:
size = 16_000_000
x = torch.randn(size, device='cuda')

# Unfused: Each operation writes to memory
def unfused():
    a = x + 1
    b = a * 2
    c = b - 0.5
    d = torch.relu(c)
    return d

# Algebraically simplified: (x + 1) * 2 - 0.5 equals x * 2 + 1.5.
# Eager PyTorch still launches one kernel per op (mul, add, relu).
def partially_fused():
    return torch.relu(x * 2 + 1.5)

time_unfused = profile_cuda(unfused)
time_fused = profile_cuda(partially_fused)

print(f"Element-wise operations ({size/1e6:.0f}M elements):")
print(f"  Unfused (4 kernels):    {time_unfused:.3f} ms")
print(f"  Simplified (3 kernels): {time_fused:.3f} ms")
print(f"  Speedup: {time_unfused/time_fused:.2f}x")
print("\nMemory traffic: unfused ~8 passes, simplified ~6 passes, fully fused ~2 passes")

In [None]:
# torch.compile fusion
try:
    compiled_unfused = torch.compile(unfused)
    _ = compiled_unfused()  # Warmup
    torch.cuda.synchronize()
    
    time_compiled = profile_cuda(compiled_unfused)
    print(f"torch.compile (auto-fused): {time_compiled:.3f} ms")
    print(f"Speedup vs unfused: {time_unfused/time_compiled:.2f}x")
except Exception as e:
    print(f"torch.compile not available: {e}")

## 4. Memory-Bound vs Compute-Bound

**Arithmetic Intensity** = FLOPs / Bytes moved

- Low AI → Memory-bound (waiting for data)
- High AI → Compute-bound (waiting for math)
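
The classification can be sketched with a roofline-style comparison: a kernel is memory-bound when its arithmetic intensity is below the ratio of peak compute to peak bandwidth. The peak numbers below describe a hypothetical GPU and are assumptions, not measurements.

```python
def arithmetic_intensity(flops, bytes_moved):
    """FLOPs performed per byte moved to or from memory."""
    return flops / bytes_moved


def classify(ai, peak_flops, peak_bandwidth):
    """Compare AI against the ridge point (peak FLOP/s divided by peak bytes/s)."""
    ridge = peak_flops / peak_bandwidth
    return "memory-bound" if ai < ridge else "compute-bound"


# Hypothetical GPU: 20 TFLOP/s compute and 1 TB/s bandwidth (ridge = 20 FLOP/byte)
PEAK_FLOPS = 20e12
PEAK_BW = 1e12

N = 4096
# Vector add: 1 FLOP per element, 2 FP32 reads + 1 FP32 write = 12 bytes
vector_add_ai = arithmetic_intensity(1, 12)
# Matmul: 2*N^3 FLOPs, three N x N FP32 matrices moved once each
matmul_ai = arithmetic_intensity(2 * N**3, 3 * N * N * 4)

for name, ai in (("a + b", vector_add_ai), ("A @ B", matmul_ai)):
    print(f"{name:<8} AI={ai:8.3f} FLOP/byte -> {classify(ai, PEAK_FLOPS, PEAK_BW)}")
```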

In [None]:
N = 4096

print(f"{'Operation':<25} {'AI (FLOP/Byte)':<18} {'Type':<15} {'Time (ms)'}")
print("-" * 75)

# Vector add: 1 FLOP per 12 bytes
a = torch.randn(N * N, device='cuda')
b = torch.randn(N * N, device='cuda')
time_add = profile_cuda(lambda: a + b)
print(f"{'a + b':<25} {1/12:<18.3f} {'MEMORY':<15} {time_add:.3f}")

# Softmax: ~10 FLOPs per 8 bytes
x = torch.randn(N, N, device='cuda')
time_soft = profile_cuda(lambda: F.softmax(x, dim=-1))
print(f"{'softmax':<25} {10/8:<18.3f} {'MEMORY':<15} {time_soft:.3f}")

# Matrix multiply: 2N FLOPs per element
A = torch.randn(N, N, device='cuda')
B = torch.randn(N, N, device='cuda')
time_mm = profile_cuda(lambda: A @ B, iterations=50)
print(f"{'A @ B':<25} {2*N/12:<18.0f} {'COMPUTE':<15} {time_mm:.3f}")

print(f"\nMost element-wise ops are memory-bound!")
print("Matrix multiply is compute-bound: each loaded value is reused about N times.")

## 5. Synchronization Importance

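CUDA operations are **asynchronous**: a kernel launch returns to the host immediately, before the GPU has finished the work.
Synchronize before reading the clock, or the measurement covers only the launch overhead.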
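
The same pitfall can be reproduced without a GPU. The sketch below is only a CPU analogy: a single-worker thread pool stands in for the CUDA stream, `fake_kernel` is a hypothetical placeholder for GPU work, and waiting on the futures plays the role of `torch.cuda.synchronize()`.

```python
import time
from concurrent.futures import ThreadPoolExecutor


def fake_kernel():
    """Stand-in for roughly 50 ms of GPU work (hypothetical placeholder)."""
    time.sleep(0.05)


with ThreadPoolExecutor(max_workers=1) as pool:
    start = time.perf_counter()
    # submit() queues the work and returns immediately, like a kernel launch
    futures = [pool.submit(fake_kernel) for _ in range(3)]
    launch_ms = (time.perf_counter() - start) * 1000

    # Waiting for every result is the analogue of synchronizing
    for future in futures:
        future.result()
    total_ms = (time.perf_counter() - start) * 1000

print(f"Measured without waiting: {launch_ms:.2f} ms")
print(f"Measured after waiting:   {total_ms:.2f} ms")
```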

In [None]:
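x = torch.randn(10000, 10000, device='cuda')
_ = x @ x  # Warm up so one-time library setup is not included in the timing
torch.cuda.synchronize()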

# WRONG: Without sync (measures only launch time)
start = time.perf_counter()
for _ in range(10):
    y = x @ x
no_sync_time = (time.perf_counter() - start) * 1000 / 10
torch.cuda.synchronize()  # Complete the work

# CORRECT: With sync (measures actual execution)
torch.cuda.synchronize()
start = time.perf_counter()
for _ in range(10):
    y = x @ x
torch.cuda.synchronize()
sync_time = (time.perf_counter() - start) * 1000 / 10

print(f"Without sync (WRONG): {no_sync_time:.2f} ms")
print(f"With sync (CORRECT):  {sync_time:.2f} ms")
print(f"\nWithout sync is {sync_time/no_sync_time:.0f}x faster - but it's a lie!")
print("Always use CUDA events or synchronize for timing.")

## 6. Mixed Precision

FP16 and BF16 matrix multiplies run on Tensor Cores, which deliver much higher throughput than FP32 on GPUs that have them.
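
The speedup is paid for with reduced precision and range. As a small CPU-only sketch, the standard library's `struct` module can round a Python float to IEEE 754 half precision. The helper name `to_fp16` is hypothetical, and mapping overflow to infinity is a choice made here for illustration.

```python
import struct


def to_fp16(value):
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    try:
        return struct.unpack("<e", struct.pack("<e", value))[0]
    except OverflowError:
        # struct refuses out-of-range values; report them as infinity here
        return float("inf") if value > 0 else float("-inf")


for value in (1.0, 0.1, 2049.0, 70000.0):
    print(f"{value:>8} -> {to_fp16(value)!r}")
```

FP16 has an 11-bit significand, so not every integer above 2048 is representable, and its largest finite value is 65504. This is one reason mixed-precision training usually keeps accumulations in FP32.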

In [None]:
N = 4096
A = torch.randn(N, N, device='cuda')
B = torch.randn(N, N, device='cuda')

# FP32
time_fp32 = profile_cuda(lambda: A @ B, iterations=50)

# FP16
A_fp16 = A.half()
B_fp16 = B.half()
time_fp16 = profile_cuda(lambda: A_fp16 @ B_fp16, iterations=50)

# Calculate TFLOPS
flops = 2 * N * N * N
tflops_fp32 = flops / (time_fp32 / 1000) / 1e12
tflops_fp16 = flops / (time_fp16 / 1000) / 1e12

print(f"Matrix multiply {N}x{N}:")
print(f"  FP32: {time_fp32:.3f} ms ({tflops_fp32:.1f} TFLOPS)")
print(f"  FP16: {time_fp16:.3f} ms ({tflops_fp16:.1f} TFLOPS)")
print(f"  Speedup: {time_fp32/time_fp16:.2f}x")
print(f"\nFP16 uses Tensor Cores!")

## 7. Flash Attention Benefit

Flash Attention never materializes the full N×N score matrix, so its extra memory grows as O(N) in sequence length N instead of O(N²).
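
To put numbers on the difference, the sketch below compares the size of the FP16 score matrix with a per-row statistics buffer. Treating the fused kernel's extra state as one FP32 value per query row is a simplifying assumption, and both helper names are hypothetical.

```python
def score_matrix_bytes(batch, heads, seq_len, elem_bytes=2):
    """Memory for the full seq_len x seq_len score matrix (FP16 by default)."""
    return batch * heads * seq_len * seq_len * elem_bytes


def row_stats_bytes(batch, heads, seq_len, elem_bytes=4):
    """Assumed extra state for a fused kernel: one FP32 value per query row."""
    return batch * heads * seq_len * elem_bytes


batch, heads = 4, 8
for seq_len in (256, 512, 1024, 2048):
    full = score_matrix_bytes(batch, heads, seq_len)
    stats = row_stats_bytes(batch, heads, seq_len)
    print(f"N={seq_len:<5} scores: {full / 2**20:8.1f} MiB   row stats: {stats / 2**10:6.1f} KiB")
```

Doubling the sequence length quadruples the score matrix but only doubles the per-row buffer.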

In [None]:
batch, heads, head_dim = 4, 8, 64

print(f"Attention comparison (batch={batch}, heads={heads}, head_dim={head_dim}):")
print(f"\n{'Seq Len':<12} {'Standard (ms)':<18} {'SDPA (ms)':<18} {'Speedup'}")
print("-" * 60)

for seq_len in [256, 512, 1024, 2048]:
    Q = torch.randn(batch, heads, seq_len, head_dim, device='cuda', dtype=torch.float16)
    K = torch.randn(batch, heads, seq_len, head_dim, device='cuda', dtype=torch.float16)
    V = torch.randn(batch, heads, seq_len, head_dim, device='cuda', dtype=torch.float16)
    
    # Standard attention
    def standard():
        scale = 1.0 / math.sqrt(head_dim)
        scores = torch.matmul(Q, K.transpose(-2, -1)) * scale
        attn = F.softmax(scores, dim=-1)
        return torch.matmul(attn, V)
    
    # SDPA (may dispatch to a Flash Attention kernel, depending on inputs and hardware)
    def sdpa():
        return F.scaled_dot_product_attention(Q, K, V)
    
    time_standard = profile_cuda(standard, iterations=50)
    time_sdpa = profile_cuda(sdpa, iterations=50)
    
    print(f"{seq_len:<12} {time_standard:<18.3f} {time_sdpa:<18.3f} {time_standard/time_sdpa:.2f}x")

print("\nFlash Attention skips the N x N score matrix: O(N) extra memory instead of O(N²)!")

## Summary

Key CUDA concepts demonstrated:

In [None]:
print("""
╔══════════════════════════════════════════════════════════════════╗
║                 CUDA FUNDAMENTALS SUMMARY                        ║
╠══════════════════════════════════════════════════════════════════╣
║                                                                  ║
║  1. THREAD HIERARCHY                                             ║
║     Grid → Blocks → Warps → Threads                              ║
║     More elements = more parallelism                             ║
║                                                                  ║
║  2. MEMORY COALESCING                                            ║
║     Access consecutive addresses within warps                    ║
║     Strided access wastes memory bandwidth                       ║
║                                                                  ║
║  3. KERNEL FUSION                                                ║
║     Fuse operations to reduce memory traffic                     ║
║     torch.compile does this automatically                        ║
║                                                                  ║
║  4. MEMORY vs COMPUTE BOUND                                      ║
║     Most ops are memory-bound                                    ║
║     Matmul is compute-bound (Tensor Cores)                       ║
║                                                                  ║
║  5. SYNCHRONIZATION                                              ║
║     Always sync for correct timing!                              ║
║     Use CUDA events for best precision                           ║
║                                                                  ║
║  6. MIXED PRECISION                                              ║
║     FP16/BF16 = Tensor Cores = massive speedup                   ║
║                                                                  ║
║  7. FLASH ATTENTION                                              ║
║     O(N) extra memory instead of O(N²)                           ║
║                                                                  ║
╚══════════════════════════════════════════════════════════════════╝
""")