# ML Compilers: Interactive Learning Notebook

This notebook provides hands-on experiments with ML compilers.
Run each cell to understand what compilers do and why they matter.

## Topics Covered
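1. Profiling utility with correct CUDA timing
2. Why compilers matter: eager vs compiled
3. What TorchDynamo captures
4. GELU fusion
5. Compiling a whole model (torch.compile modes)
6. Mixed precision (FP32 vs FP16)
7. Memory-bound vs compute-bound operations
8. Profiling with the PyTorch profiler
9. Summary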

In [None]:
# Setup
import torch
import torch.nn as nn
import torch.nn.functional as F
import time
import warnings
warnings.filterwarnings('ignore')

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

## 1. Profiling Utility

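CUDA kernel launches are asynchronous: a call like `a + b` returns to Python as soon as the kernel is queued, not when it finishes. The helper below synchronizes before and after the timed loop and measures with CUDA events, so it reports GPU execution time rather than launch time. It returns the mean time per call in milliseconds (wall-clock time on CPU).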

In [None]:
def profile_fn(func, warmup=10, iterations=100):
    """Return the mean time per call of func() in milliseconds (CUDA events on GPU)."""
    # Warmup
    for _ in range(warmup):
        func()
    
    if torch.cuda.is_available():
        torch.cuda.synchronize()
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        
        start.record()
        for _ in range(iterations):
            func()
        end.record()
        torch.cuda.synchronize()
        
        return start.elapsed_time(end) / iterations
    else:
        start = time.perf_counter()
        for _ in range(iterations):
            func()
        return (time.perf_counter() - start) * 1000 / iterations

print("Profiling utility ready!")
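
Why the synchronization matters: without it, a timer only sees how long it took to *queue* the work. The sketch below is a stdlib-only analogy, not CUDA code: a `ThreadPoolExecutor` plays the role of the GPU stream, a 50 ms `time.sleep` plays the kernel, and `future.result()` plays `torch.cuda.synchronize()`.

```python
import time
from concurrent.futures import ThreadPoolExecutor


def fake_kernel():
    """Stand-in for GPU work: blocks its worker thread for 50 ms."""
    time.sleep(0.05)


with ThreadPoolExecutor(max_workers=1) as stream:
    start = time.perf_counter()
    future = stream.submit(fake_kernel)        # "launch": returns immediately
    unsynced_ms = (time.perf_counter() - start) * 1000

    future.result()                            # "synchronize": wait for completion
    synced_ms = (time.perf_counter() - start) * 1000

print(f"Without sync: {unsynced_ms:.3f} ms (launch cost only)")
print(f"With sync:    {synced_ms:.3f} ms (actual work)")
```

The unsynchronized number is a tiny fraction of the real 50 ms, which is exactly the error naive `time.perf_counter()` timing of GPU code makes.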

## 2. Why Compilers Matter: Eager vs Compiled

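In eager mode, PyTorch launches one kernel per operation, and each kernel reads its inputs from GPU memory and writes its output back. A compiler can fuse a chain of elementwise ops into a single kernel that reads each input once and writes only the final result.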

In [None]:
# Simple operations that should be fused
size = 10_000_000
device = 'cuda' if torch.cuda.is_available() else 'cpu'

a = torch.randn(size, device=device)
b = torch.randn(size, device=device)
c = torch.randn(size, device=device)

# Eager mode: Each operation is a separate kernel
def eager_compute():
    x = a + b        # Kernel 1: Read a,b → Write x
    y = x * c        # Kernel 2: Read x,c → Write y  
    z = torch.relu(y) # Kernel 3: Read y → Write z
    return z

time_eager = profile_fn(eager_compute)
print(f"Eager mode: {time_eager:.3f} ms")
print("Memory traffic: 8 tensor transfers (read a, b, write x; read x, c, write y; read y, write z)")

In [None]:
# Compiled mode: Operations are fused into one kernel
try:
    compiled_compute = torch.compile(eager_compute)
    
    # First call triggers compilation
    _ = compiled_compute()
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    
    time_compiled = profile_fn(compiled_compute)
    
    print(f"Compiled mode: {time_compiled:.3f} ms")
    print("Memory traffic: 4 tensor transfers (read a, b, c; write z)")
    print(f"\nSpeedup: {time_eager/time_compiled:.2f}x")
    print("\nWhy faster? The compiler fuses the 3 elementwise kernels into 1, halving memory traffic.")
except Exception as e:
    print(f"torch.compile not available: {e}")
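
The transfer counts printed above can be turned into a rough lower bound on runtime. This sketch assumes fp32 (4 bytes per element), that every kernel streams its full inputs and outputs through DRAM with no cache reuse, and a hypothetical bandwidth `ASSUMED_BW_GBPS`; substitute your GPU's datasheet figure.

```python
# Traffic model for the add -> mul -> relu chain (assumed numbers, see above)
N_ELEMS = 10_000_000
BYTES_PER_ELEM = 4          # fp32
ASSUMED_BW_GBPS = 1000.0    # hypothetical DRAM bandwidth in GB/s

# (full-tensor reads, full-tensor writes) for each kernel
unfused_kernels = {"add": (2, 1), "mul": (2, 1), "relu": (1, 1)}
fused_kernels = {"add_mul_relu": (3, 1)}  # read a, b, c once; write z once


def tensor_transfers(kernels):
    """Total full-tensor reads + writes across all kernels."""
    return sum(reads + writes for reads, writes in kernels.values())


unfused = tensor_transfers(unfused_kernels)
fused = tensor_transfers(fused_kernels)

for label, transfers in (("unfused", unfused), ("fused", fused)):
    gigabytes = transfers * N_ELEMS * BYTES_PER_ELEM / 1e9
    min_ms = gigabytes / ASSUMED_BW_GBPS * 1000
    print(f"{label:8s} {transfers} transfers, {gigabytes:.2f} GB, >= {min_ms:.3f} ms")

print(f"Traffic-only speedup estimate: {unfused / fused:.1f}x")
```

Under this model fusion roughly halves the runtime of this chain; launch overhead saved by running one kernel instead of three is not modeled.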

## 3. Visualizing the Compilation

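`torch.compile` works in two main stages: TorchDynamo traces the Python bytecode into FX graphs, and a backend (TorchInductor by default) generates optimized kernels from those graphs (Triton on GPU, C++ on CPU). `torch._dynamo.explain` reports how Dynamo split a function into graphs; a graph break happens when Dynamo reaches code it cannot trace and falls back to regular Python.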

In [None]:
# Explain what Dynamo captures
def simple_fn(x):
    y = x.sin()
    z = y.cos()
    return z + 1

try:
    import torch._dynamo  # explain(fn)(*args) form shown here is PyTorch 2.1+
    x = torch.randn(100, device=device)
    explanation = torch._dynamo.explain(simple_fn)(x)

    print("TorchDynamo Analysis:")
    print(f"  Graph breaks: {explanation.graph_break_count}")
    print(f"  Graphs captured: {len(explanation.graphs)}")
    print("\nNo graph breaks: the whole function is captured as one graph and compiled together.")
except Exception as e:
    print(f"Dynamo explain not available: {e}")

## 4. GELU Fusion Example

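GELU is a good fusion example: its tanh-approximation formula is a chain of eight elementwise ops (a cube, four multiplies, two adds, and a tanh), and in eager mode each op is a separate kernel that reads and writes the full tensor.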

In [None]:
size = 16_000_000
x = torch.randn(size, device=device)

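# Manual GELU (tanh approximation): a chain of separate elementwise ops
def manual_gelu(x):
    # GELU(x) ≈ 0.5 * x * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x³))), with sqrt(2/π) ≈ 0.7978845608
    return 0.5 * x * (1.0 + torch.tanh(0.7978845608 * (x + 0.044715 * x**3)))

# PyTorch's built-in GELU: a single native kernel.
# approximate='tanh' matches manual_gelu; the default ('none') computes the exact erf form.
def pytorch_gelu(x):
    return F.gelu(x, approximate='tanh')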

# Profile
time_manual = profile_fn(lambda: manual_gelu(x))
time_pytorch = profile_fn(lambda: pytorch_gelu(x))

print(f"GELU Implementations ({size/1e6:.0f}M elements):")
print(f"  Manual (unfused):  {time_manual:.3f} ms")
print(f"  PyTorch (fused):   {time_pytorch:.3f} ms")
print(f"  Speedup: {time_manual/time_pytorch:.2f}x")
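
`manual_gelu` uses the tanh approximation, not the exact definition GELU(x) = x·Φ(x), where Φ is the standard normal CDF. A quick stdlib check of how close the two are (the grid over [-6, 6] is an arbitrary choice):

```python
import math

SQRT_2_OVER_PI = math.sqrt(2.0 / math.pi)  # ≈ 0.7978845608, the constant in manual_gelu


def gelu_exact(v):
    """Exact GELU: v * Phi(v), where Phi is the standard normal CDF."""
    return 0.5 * v * (1.0 + math.erf(v / math.sqrt(2.0)))


def gelu_tanh(v):
    """Tanh approximation, same formula as manual_gelu."""
    return 0.5 * v * (1.0 + math.tanh(SQRT_2_OVER_PI * (v + 0.044715 * v ** 3)))


grid = [i / 100 for i in range(-600, 601)]  # [-6, 6] in steps of 0.01
max_err = max(abs(gelu_exact(v) - gelu_tanh(v)) for v in grid)
print(f"Max |exact - tanh approx| on [-6, 6]: {max_err:.1e}")
```

The gap is small but nonzero, so when timing against `F.gelu`, passing `approximate='tanh'` keeps the comparison like for like.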

In [None]:
# Let torch.compile fuse our manual implementation
try:
    compiled_gelu = torch.compile(manual_gelu)
    _ = compiled_gelu(x)  # Warmup/compile
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    
    time_compiled = profile_fn(lambda: compiled_gelu(x))
    
    print(f"torch.compile on manual GELU: {time_compiled:.3f} ms")
    print(f"Speedup vs manual: {time_manual/time_compiled:.2f}x")
    print("\nThe compiler fuses the elementwise chain, typically into a single kernel.")
except Exception as e:
    print(f"Compilation error: {e}")
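
What fusion changes can be mimicked in pure Python. This toy is not how Inductor generates code; it only contrasts the memory pattern. The unfused version materializes a full-size list after every op (like eager kernels writing to DRAM), while the fused version keeps temporaries in local variables (like registers inside one kernel).

```python
import math

SQRT_2_OVER_PI = math.sqrt(2.0 / math.pi)


def gelu_unfused(xs):
    """Eager-style: each op is a separate pass that writes a full-size list."""
    t1 = [v ** 3 for v in xs]                  # pow
    t2 = [0.044715 * v for v in t1]            # mul
    t3 = [a + b for a, b in zip(xs, t2)]       # add
    t4 = [SQRT_2_OVER_PI * v for v in t3]      # mul
    t5 = [math.tanh(v) for v in t4]            # tanh
    t6 = [1.0 + v for v in t5]                 # add
    t7 = [0.5 * v for v in xs]                 # mul
    out = [a * b for a, b in zip(t7, t6)]      # mul
    temporaries = [t1, t2, t3, t4, t5, t6, t7]
    return out, len(temporaries)


def gelu_fused(xs):
    """Fused-style: one pass; temporaries live only in local variables."""
    out = []
    for v in xs:
        inner = SQRT_2_OVER_PI * (v + 0.044715 * v ** 3)
        out.append(0.5 * v * (1.0 + math.tanh(inner)))
    return out, 0


xs = [i / 10 for i in range(-30, 31)]
unfused_out, unfused_temps = gelu_unfused(xs)
fused_out, fused_temps = gelu_fused(xs)
print(f"Full-size temporaries: unfused={unfused_temps}, fused={fused_temps}")
print("Outputs match:", all(math.isclose(a, b, abs_tol=1e-12)
                             for a, b in zip(unfused_out, fused_out)))
```

Both versions compute the same values; the fused one simply never writes the seven intermediate arrays, which is where a memory-bound GPU kernel saves its time.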

## 5. Model Compilation

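Let's compile a whole module and compare `torch.compile` modes. The default mode balances compile time and speed; `mode="reduce-overhead"` additionally uses CUDA graphs to cut per-kernel launch overhead, which helps most when individual kernels are short.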

In [None]:
# Feed-forward sublayer of a transformer block (pre-LayerNorm MLP, no attention or residual)
class SimpleBlock(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.ln = nn.LayerNorm(dim)
        self.fc1 = nn.Linear(dim, dim * 4)
        self.fc2 = nn.Linear(dim * 4, dim)
    
    def forward(self, x):
        x = self.ln(x)
        x = self.fc1(x)
        x = F.gelu(x)
        x = self.fc2(x)
        return x

# Setup
batch, seq, dim = 32, 512, 768
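# Inference benchmark: eval mode and no parameter gradients, so autograd records no graph
model = SimpleBlock(dim).to(device).eval().requires_grad_(False)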
x = torch.randn(batch, seq, dim, device=device)

# Eager
time_eager = profile_fn(lambda: model(x), iterations=50)
print(f"Model input: ({batch}, {seq}, {dim})")
print(f"Eager mode: {time_eager:.3f} ms")

In [None]:
# Compile with different modes
results = [("Eager", time_eager)]

modes = [
    ("default", {}),
    ("reduce-overhead", {"mode": "reduce-overhead"}),
]

for mode_name, kwargs in modes:
    try:
        torch._dynamo.reset()  # clear cached compilations so each mode compiles from scratch
        compiled_model = torch.compile(model, **kwargs)
        
        # Warmup
        for _ in range(5):
            _ = compiled_model(x)
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        
        time_compiled = profile_fn(lambda: compiled_model(x), iterations=50)
        results.append((f"compile ({mode_name})", time_compiled))
    except Exception as e:
        print(f"Mode {mode_name} failed: {e}")

# Print results
print(f"\n{'Mode':<30} {'Time (ms)':<15} {'Speedup'}")
print("-" * 55)
base_time = results[0][1]
for name, t in results:
    print(f"{name:<30} {t:<15.3f} {base_time/t:.2f}x")
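
To turn these timings into throughput, count the work. A sketch, assuming only the two `nn.Linear` layers matter (LayerNorm, GELU and bias adds are ignored as comparatively small) and 2 FLOPs per multiply-add; the helper names are hypothetical:

```python
def simple_block_matmul_flops(batch, seq, dim, expansion=4):
    """FLOPs of SimpleBlock's two nn.Linear layers, 2 FLOPs per multiply-add.

    LayerNorm, GELU and the bias adds are ignored (small next to the matmuls).
    """
    tokens = batch * seq
    hidden = expansion * dim
    fc1 = 2 * tokens * dim * hidden   # (tokens x dim) @ (dim x hidden)
    fc2 = 2 * tokens * hidden * dim   # (tokens x hidden) @ (hidden x dim)
    return fc1 + fc2


def achieved_tflops(flops, time_ms):
    """Convert a FLOP count and a time in milliseconds to TFLOP/s."""
    return flops / (time_ms / 1000) / 1e12


block_flops = simple_block_matmul_flops(32, 512, 768)
print(f"Matmul work per forward pass: {block_flops / 1e9:.1f} GFLOP")
```

For example, `achieved_tflops(block_flops, time_eager)` gives the eager throughput, which you can compare with your GPU's advertised peak.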

## 6. Mixed Precision Impact

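Precision is another lever. This cell uses plain eager matmuls, no compiler: lower-precision inputs halve the bytes per element and, on GPUs with Tensor Cores, run on much faster matrix units.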

In [None]:
# Matrix multiply at different precisions
N = 4096

# FP32
A_fp32 = torch.randn(N, N, device=device)
B_fp32 = torch.randn(N, N, device=device)

# FP16
A_fp16 = A_fp32.half()
B_fp16 = B_fp32.half()

# Profile
time_fp32 = profile_fn(lambda: A_fp32 @ B_fp32, iterations=50)
time_fp16 = profile_fn(lambda: A_fp16 @ B_fp16, iterations=50)

# Calculate TFLOPS
flops = 2 * N * N * N
tflops_fp32 = flops / (time_fp32 / 1000) / 1e12
tflops_fp16 = flops / (time_fp16 / 1000) / 1e12

print(f"Matrix multiply {N}x{N}:")
print(f"  FP32: {time_fp32:.3f} ms ({tflops_fp32:.1f} TFLOPS)")
print(f"  FP16: {time_fp16:.3f} ms ({tflops_fp16:.1f} TFLOPS)")
print(f"  Speedup: {time_fp32/time_fp16:.2f}x")
print("\nOn GPUs with Tensor Cores (Volta or newer), FP16 matmuls run on them; on CPU, FP16 is often no faster.")
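
Besides enabling Tensor Cores, FP16 halves the bytes per element, so the same matmul occupies and moves half as much memory. A back-of-the-envelope sketch (assumes dense n × n operands A, B and C and ignores any library workspace):

```python
DTYPE_BYTES = {"fp32": 4, "fp16": 2, "bf16": 2}


def matmul_footprint_mib(n, dtype):
    """MiB held by A, B and C = A @ B, each a dense n x n matrix."""
    return 3 * n * n * DTYPE_BYTES[dtype] / 2**20


for dtype in ("fp32", "fp16"):
    print(f"{dtype}: {matmul_footprint_mib(4096, dtype):.0f} MiB for A, B and C")
```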

## 7. Memory-Bound vs Compute-Bound

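Arithmetic intensity (FLOPs per byte of memory traffic) decides whether an op is limited by memory bandwidth or by compute throughput, and therefore which optimization helps.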

In [None]:
N = 4096

# Element-wise operation (memory-bound)
x = torch.randn(N * N, device=device)
time_add = profile_fn(lambda: x + 1.0)

# Arithmetic intensity = 1 FLOP / (4-byte read + 4-byte write) = 0.125 FLOP/byte
print("Element-wise add (memory-bound):")
print(f"  Time: {time_add:.3f} ms")
print(f"  Arithmetic intensity: ~0.125 FLOP/byte")

# Matrix multiply (compute-bound)
A = torch.randn(N, N, device=device)
B = torch.randn(N, N, device=device)
time_mm = profile_fn(lambda: A @ B, iterations=50)

# Arithmetic intensity = 2N³ FLOPs / (3·N² elements × 4 bytes) = N/6 ≈ 683 for N=4096
# (reads A and B, writes C, each once)
print("\nMatrix multiply (compute-bound):")
print(f"  Time: {time_mm:.3f} ms")
print(f"  Arithmetic intensity: ~{2 * N**3 / (3 * N * N * 4):.0f} FLOP/byte")

print(f"\nKey insight:")
print(f"  Memory-bound ops benefit from FUSION (less memory traffic)")
print(f"  Compute-bound ops benefit from Tensor Cores (more compute)")
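
The roofline model ties the two cases together: attainable throughput is the lesser of the compute peak and bandwidth × arithmetic intensity. The peak and bandwidth below are hypothetical placeholders, not any specific GPU; the matmul intensity assumes A and B are read once and C is written once.

```python
def intensity_elementwise_add(bytes_per_elem=4):
    """y = x + 1.0: 1 FLOP per element; read one value, write one value."""
    return 1 / (2 * bytes_per_elem)


def intensity_matmul(n, bytes_per_elem=4):
    """C = A @ B: 2n^3 FLOPs; read A and B, write C once (ideal caching)."""
    return 2 * n**3 / (3 * n * n * bytes_per_elem)


def roofline_tflops(intensity, peak_tflops, bandwidth_tbps):
    """Attainable TFLOPS = min(compute roof, bandwidth * intensity).

    FLOP/byte * TB/s = TFLOP/s, so the units line up.
    """
    return min(peak_tflops, intensity * bandwidth_tbps)


PEAK_TFLOPS = 20.0   # hypothetical fp32 compute peak
BW_TBPS = 1.5        # hypothetical memory bandwidth, TB/s
ridge = PEAK_TFLOPS / BW_TBPS  # FLOP/byte where the two roofs meet

for name, ai in (("x + 1.0", intensity_elementwise_add()),
                 ("matmul 4096", intensity_matmul(4096))):
    attainable = roofline_tflops(ai, PEAK_TFLOPS, BW_TBPS)
    kind = "memory-bound" if ai < ridge else "compute-bound"
    print(f"{name:12s} {ai:8.3f} FLOP/B -> {attainable:6.3f} TFLOPS ({kind})")
```

Ops left of the ridge point can only get faster by moving fewer bytes (fusion, lower precision); ops to its right need more compute throughput (Tensor Cores).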

## 8. Profiling with PyTorch Profiler

Let's see what kernels are actually being run.

In [None]:
from torch.profiler import profile, ProfilerActivity

# Profile eager execution
model = SimpleBlock(512).to(device)
x = torch.randn(16, 256, 512, device=device)

with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    record_shapes=True
) as prof:
    for _ in range(5):
        _ = model(x)

print("Eager mode kernel breakdown:")
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

In [None]:
# Profile compiled execution
try:
    compiled_model = torch.compile(model)
    # Warmup
    for _ in range(3):
        _ = compiled_model(x)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    
    with profile(
        activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
        record_shapes=True
    ) as prof:
        for _ in range(5):
            _ = compiled_model(x)
    
    print("Compiled mode kernel breakdown:")
    print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
    print("\nOn GPU, kernels named triton_* are Inductor-generated fused kernels; matmuls typically still call cuBLAS.")
except Exception as e:
    print(f"Compilation error: {e}")

## 9. Summary

Key takeaways from this notebook:

In [None]:
print("""
╔══════════════════════════════════════════════════════════════════╗
║                    ML COMPILERS SUMMARY                          ║
╠══════════════════════════════════════════════════════════════════╣
║                                                                  ║
║  1. KERNEL FUSION is often the #1 optimization                   ║
║     • Reduces memory traffic                                     ║
║     • Many ML ops are memory-bound                               ║
║                                                                  ║
║  2. torch.compile is easy to use                                 ║
║     • Just wrap your model: torch.compile(model)                 ║
║     • Works with most code; watch for graph breaks               ║
║                                                                  ║
║  3. PROFILING is essential                                       ║
║     • Synchronize CUDA before and after timing                   ║
║     • Use the PyTorch profiler for per-kernel detail             ║
║                                                                  ║
║  4. Different compilers for different use cases                  ║
║     • Training: torch.compile                                    ║
║     • NVIDIA inference: TensorRT                                 ║
║     • TPU: JAX/XLA                                               ║
║     • Edge: TVM, ONNX Runtime                                    ║
║                                                                  ║
╚══════════════════════════════════════════════════════════════════╝
""")