# Triton Programming - Interactive Notebook

Learn GPU programming with Triton - high-level, Pythonic, and fast!

## Topics
1. Your First Triton Kernel: Vector Addition (memory access, masking)
2. Fused Operations (GELU)
3. Auto-tuning
4. Softmax - A Complete Example
5. Summary

Each section includes a performance comparison against PyTorch.

In [None]:
# Setup
import torch
import triton
import triton.language as tl
import time

# Report library versions and whether a CUDA device is visible.
print(f"PyTorch: {torch.__version__}")
print(f"Triton: {triton.__version__}")
_cuda_available = torch.cuda.is_available()
print(f"CUDA: {_cuda_available}")
if _cuda_available:
    # Only query the device name when a CUDA device is actually present.
    print(f"GPU: {torch.cuda.get_device_name(0)}")

In [None]:
# Profiling utility
def profile_triton(func, warmup=25, iterations=100):
    """Return the mean time per call of ``func`` in milliseconds.

    ``func`` is called ``warmup`` times untimed, then ``iterations`` times
    between two CUDA events; the elapsed time between the events is divided
    by ``iterations``.

    Args:
        func: Zero-argument callable that enqueues the work to measure.
        warmup: Number of untimed calls made before measuring.
        iterations: Number of timed calls; must be positive.

    Returns:
        Mean elapsed time per call, in milliseconds.

    Raises:
        ValueError: If ``iterations`` is not positive.
    """
    # Fail fast: otherwise the whole warm-up runs and the final division
    # raises an uninformative ZeroDivisionError.
    if iterations <= 0:
        raise ValueError(f"iterations must be positive, got {iterations}")

    for _ in range(warmup):
        func()
    # Drain the warm-up work so it is not attributed to the timed region.
    torch.cuda.synchronize()

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    start.record()
    for _ in range(iterations):
        func()
    end.record()
    # The events are recorded asynchronously; wait before reading them.
    torch.cuda.synchronize()

    return start.elapsed_time(end) / iterations

print("Profiling ready!")

## 1. Your First Triton Kernel: Vector Addition

The "Hello World" of GPU programming.

In [None]:
@triton.jit
def add_kernel(
    x_ptr,
    y_ptr,
    out_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Element-wise vector addition: out[i] = x[i] + y[i].

    Arguments:
    - x_ptr, y_ptr: pointers to the two inputs
    - out_ptr: pointer to the output
    - n_elements: total number of elements to process
    - BLOCK_SIZE: compile-time number of elements handled per program

    Each program handles the BLOCK_SIZE consecutive elements starting at
    program_id * BLOCK_SIZE. Loads and the store are masked so the final,
    partially filled block never touches indices >= n_elements.
    """
    # Index of this program along axis 0 (like blockIdx.x in CUDA).
    block_id = tl.program_id(axis=0)

    # Element indices covered by this program.
    idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)

    # True only for indices that fall inside the vectors.
    in_bounds = idx < n_elements

    lhs = tl.load(x_ptr + idx, mask=in_bounds)
    rhs = tl.load(y_ptr + idx, mask=in_bounds)

    tl.store(out_ptr + idx, lhs + rhs, mask=in_bounds)

print("Kernel defined!")

In [None]:
def triton_add(x, y):
    """Add two same-shaped tensors with the Triton ``add_kernel``.

    Args:
        x: First input tensor.
        y: Second input tensor; must have the same shape as ``x``.

    Returns:
        A new tensor holding ``x + y``.

    Raises:
        ValueError: If ``x`` and ``y`` differ in shape.
    """
    # The kernel walks both inputs with the same flat offsets, so a shape
    # mismatch would read past the end of the smaller tensor.
    if x.shape != y.shape:
        raise ValueError(
            f"x and y must have the same shape, got {tuple(x.shape)} and {tuple(y.shape)}"
        )

    # Flat offsets only match logical element order for contiguous memory.
    # contiguous() returns the tensor itself when it is already contiguous.
    x = x.contiguous()
    y = y.contiguous()

    output = torch.empty_like(x)
    n_elements = output.numel()
    if n_elements == 0:
        # Nothing to compute; avoid launching a zero-size grid.
        return output

    # One program per BLOCK_SIZE elements, rounded up.
    BLOCK_SIZE = 1024
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)

    # Launch kernel
    add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=BLOCK_SIZE)

    return output

# Correctness check: compare the Triton result with PyTorch's own addition.
size = 1_000_000
x = torch.randn(size, device='cuda')
y = torch.randn(size, device='cuda')

out_triton = triton_add(x, y)
out_torch = x + y

# Largest absolute element-wise deviation between the two results.
print(f"Correctness: {torch.allclose(out_triton, out_torch)}")
print(f"Max difference: {(out_triton - out_torch).abs().max():.2e}")

In [None]:
# Benchmark the Triton wrapper against PyTorch's built-in addition.
print(f"\nPerformance ({size/1e6:.0f}M elements):")
print(f"{'Method':<20} {'Time (ms)':<15}")
print("-" * 35)

time_triton = profile_triton(lambda: triton_add(x, y))
time_torch = profile_triton(lambda: x + y)

# One table row per method, in the same order as measured.
for _label, _elapsed in (("Triton", time_triton), ("PyTorch", time_torch)):
    print(f"{_label:<20} {_elapsed:.4f}")
print("\nFor simple ops, PyTorch is already optimized.")
print("Triton shines for CUSTOM fused operations!")

## 2. Fused Operations - Where Triton Shines

Triton excels at fusing multiple operations into one kernel.

In [None]:
@triton.jit
def fused_gelu_kernel(
    x_ptr, out_ptr, n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Fused GELU activation (tanh approximation).
    GELU(x) ≈ 0.5 * x * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x³)))

    The tanh is evaluated through the exact identity
        0.5 * (1 + tanh(z)) = 1 / (1 + exp(-2z))
    so the kernel needs only tl.exp and does not depend on tl.libdevice,
    which is not available in every Triton release.

    Arguments:
    - x_ptr: pointer to the input
    - out_ptr: pointer to the output
    - n_elements: total number of elements to process
    - BLOCK_SIZE: compile-time number of elements handled per program

    The whole computation happens between one load and one store.
    """
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    # Load
    x = tl.load(x_ptr + offsets, mask=mask)

    # Fused GELU computation
    x3 = x * x * x
    inner = 0.7978845608 * (x + 0.044715 * x3)  # sqrt(2/pi)
    # x * sigmoid(2 * inner) == 0.5 * x * (1 + tanh(inner)).
    # For large |inner| the exp saturates to 0 or inf, giving x or 0.
    result = x / (1.0 + tl.exp(-2.0 * inner))

    # Store
    tl.store(out_ptr + offsets, result, mask=mask)

def triton_gelu(x):
    """Apply GELU to ``x`` with the Triton ``fused_gelu_kernel``.

    Args:
        x: Input tensor.

    Returns:
        A new tensor with the same shape and dtype as ``x``.
    """
    # The kernel addresses memory with flat offsets, which only match the
    # logical element order for contiguous tensors. contiguous() returns the
    # tensor itself when no copy is needed.
    x = x.contiguous()
    out = torch.empty_like(x)
    n = out.numel()
    if n == 0:
        # Nothing to compute; avoid launching a zero-size grid.
        return out
    # One program per BLOCK_SIZE elements, rounded up.
    grid = lambda meta: (triton.cdiv(n, meta['BLOCK_SIZE']),)
    fused_gelu_kernel[grid](x, out, n, BLOCK_SIZE=1024)
    return out

print("Fused GELU kernel defined!")

In [None]:
import torch.nn.functional as F

size = 16_000_000
x = torch.randn(size, device='cuda')

# Verify correctness
# The Triton kernel implements the tanh approximation of GELU, so it must be
# compared against the same formula. The default F.gelu is the exact
# erf-based variant, which differs from the approximation by far more than
# atol=1e-5 for some inputs and would make this check report False.
out_triton = triton_gelu(x)
out_torch = F.gelu(x, approximate='tanh')
print(f"Correctness: {torch.allclose(out_triton, out_torch, atol=1e-5)}")

# Manual unfused GELU: the same formula written as separate PyTorch ops.
def manual_gelu(x):
    return 0.5 * x * (1.0 + torch.tanh(0.7978845608 * (x + 0.044715 * x**3)))

# Benchmark (all three variants compute the tanh approximation, so the
# timings compare like with like).
time_manual = profile_triton(lambda: manual_gelu(x))
time_pytorch = profile_triton(lambda: F.gelu(x, approximate='tanh'))
time_triton = profile_triton(lambda: triton_gelu(x))

# Speedups are reported relative to the manual unfused version.
print(f"\nGELU Performance ({size/1e6:.0f}M elements):")
print(f"{'Method':<25} {'Time (ms)':<15} {'Speedup'}")
print("-" * 55)
print(f"{'Manual (unfused)':<25} {time_manual:<15.3f} 1.0x")
print(f"{'PyTorch F.gelu':<25} {time_pytorch:<15.3f} {time_manual/time_pytorch:.2f}x")
print(f"{'Triton fused':<25} {time_triton:<15.3f} {time_manual/time_triton:.2f}x")

## 3. Auto-tuning

Triton can automatically find optimal configurations.

In [None]:
@triton.autotune(
    # One candidate configuration per block size.
    configs=[
        triton.Config({'BLOCK_SIZE': block_size})
        for block_size in (64, 128, 256, 512, 1024)
    ],
    key=['n_elements'],  # Re-tune when size changes
)
@triton.jit
def autotuned_add_kernel(
    x_ptr, y_ptr, out_ptr, n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Element-wise vector addition whose BLOCK_SIZE is chosen by the autotuner.

    Each program handles BLOCK_SIZE consecutive elements; loads and the
    store are masked against n_elements.
    """
    block_id = tl.program_id(0)
    idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    valid = idx < n_elements
    lhs = tl.load(x_ptr + idx, mask=valid)
    rhs = tl.load(y_ptr + idx, mask=valid)
    tl.store(out_ptr + idx, lhs + rhs, mask=valid)

def autotuned_add(x, y):
    """Add ``x`` and ``y`` with the auto-tuned Triton kernel.

    BLOCK_SIZE is not passed here; the launch grid reads whichever value
    the autotuner selected from the ``meta`` dictionary.
    """
    out = torch.empty_like(x)
    n = out.numel()

    def grid(meta):
        # Number of programs needed to cover n elements, rounded up.
        return (triton.cdiv(n, meta['BLOCK_SIZE']),)

    autotuned_add_kernel[grid](x, y, out, n)
    return out

# Run the auto-tuned kernel at several problem sizes (2^10 ... 2^24).
print("Auto-tuning in action:")
for size in (1 << 10, 1 << 16, 1 << 20, 1 << 24):
    x = torch.randn(size, device='cuda')
    y = torch.randn(size, device='cuda')

    # Call once before timing so tuning for this size happens outside
    # the measured region.
    _ = autotuned_add(x, y)
    torch.cuda.synchronize()

    time_ms = profile_triton(lambda: autotuned_add(x, y))
    print(f"  Size {size:>10}: {time_ms:.4f} ms")

## 4. Softmax - A Complete Example

Softmax combines row-wise reductions (max and sum) with element-wise operations (exp and division), all fused into a single kernel.

In [None]:
@triton.jit
def softmax_kernel(
    input_ptr, output_ptr,
    input_stride, output_stride,
    n_cols,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Row-wise softmax: each program processes exactly one row.

    softmax(x_i) = exp(x_i - max) / sum(exp(x_j - max))

    Arguments:
    - input_ptr, output_ptr: pointers to the input and output
    - input_stride, output_stride: element offset between consecutive rows
    - n_cols: number of valid columns in a row
    - BLOCK_SIZE: compile-time number of column slots loaded per row

    Only the first BLOCK_SIZE columns of a row are read and written, so the
    caller must pick BLOCK_SIZE >= n_cols.
    """
    row_id = tl.program_id(0)

    # Column indices handled by this program and which of them are real.
    cols = tl.arange(0, BLOCK_SIZE)
    valid = cols < n_cols

    # Padding slots load as -inf so they cannot win the max.
    src = input_ptr + row_id * input_stride + cols
    values = tl.load(src, mask=valid, other=-float('inf'))

    # Subtract the row maximum before exponentiating.
    shifted = values - tl.max(values, axis=0)
    # Padding slots are forced to zero so they add nothing to the sum.
    exps = tl.where(valid, tl.exp(shifted), 0.0)
    probs = exps / tl.sum(exps, axis=0)

    dst = output_ptr + row_id * output_stride + cols
    tl.store(dst, probs, mask=valid)

def triton_softmax(x):
    """Compute softmax over the last dimension of a 2-D tensor with Triton.

    Args:
        x: 2-D input tensor of shape (n_rows, n_cols).

    Returns:
        A new tensor with the same shape as ``x``; each row is the softmax
        of the corresponding input row.

    Raises:
        ValueError: If ``n_cols`` exceeds the largest supported block size.
    """
    n_rows, n_cols = x.shape

    # The kernel loads a whole row in one block. Silently capping the block
    # size would leave every column past the cap unprocessed and the
    # matching output uninitialised, so reject such inputs explicitly.
    max_block_size = 8192
    if n_cols > max_block_size:
        raise ValueError(
            f"triton_softmax supports at most {max_block_size} columns, got {n_cols}"
        )

    # The kernel addresses columns as row_start + column index, which is
    # only valid when the column stride is 1.
    if x.stride(1) != 1:
        x = x.contiguous()

    out = torch.empty_like(x)
    if n_rows == 0 or n_cols == 0:
        # Nothing to compute; avoid launching an empty grid or block.
        return out

    # tl.arange needs a power-of-two extent; extra slots are masked off.
    BLOCK_SIZE = triton.next_power_of_2(n_cols)

    # One program per row.
    softmax_kernel[(n_rows,)](
        x, out,
        x.stride(0), out.stride(0),
        n_cols,
        BLOCK_SIZE=BLOCK_SIZE,
    )
    return out

print("Softmax kernel defined!")

In [None]:
# Check the Triton softmax against PyTorch on a 1024x1024 input.
rows, cols = 1024, 1024
x = torch.randn(rows, cols, device='cuda')

out_triton = triton_softmax(x)
out_torch = F.softmax(x, dim=-1)

print(f"Shape: {rows}x{cols}")
print(f"Correctness: {torch.allclose(out_triton, out_torch, atol=1e-5)}")

# Time both implementations on square inputs of increasing size.
print("\nSoftmax Performance:")
print(f"{'Shape':<18} {'PyTorch (ms)':<15} {'Triton (ms)':<15} {'Speedup'}")
print("-" * 60)

for rows, cols in ((128, 128), (512, 512), (1024, 1024), (2048, 2048)):
    x = torch.randn(rows, cols, device='cuda')

    time_torch = profile_triton(lambda: F.softmax(x, dim=-1))
    time_triton = profile_triton(lambda: triton_softmax(x))

    # Speedup > 1 means the Triton kernel was faster than PyTorch.
    _shape_label = str((rows, cols))
    _speedup = time_torch / time_triton
    print(f"{_shape_label:<18} {time_torch:<15.4f} {time_triton:<15.4f} {_speedup:.2f}x")

## Summary

In [None]:
# Print a boxed recap: the kernel template, key tl.* functions, and guidance
# on when Triton is a good fit. The text is emitted verbatim.
print("""
╔══════════════════════════════════════════════════════════════════╗
║                   TRITON PROGRAMMING SUMMARY                     ║
╠══════════════════════════════════════════════════════════════════╣
║                                                                  ║
║  TRITON KERNEL TEMPLATE:                                         ║
║  ──────────────────────────────────────────────────────────────  ║
║  @triton.jit                                                     ║
║  def kernel(ptr, n, BLOCK: tl.constexpr):                        ║
║      pid = tl.program_id(0)           # Block index              ║
║      offs = pid * BLOCK + tl.arange(0, BLOCK)  # Offsets         ║
║      mask = offs < n                  # Bounds check             ║
║      data = tl.load(ptr + offs, mask=mask)  # Load               ║
║      # ... compute ...                                           ║
║      tl.store(ptr + offs, data, mask=mask)  # Store              ║
║                                                                  ║
║  KEY FUNCTIONS:                                                  ║
║  ──────────────────────────────────────────────────────────────  ║
║  • tl.program_id(axis)  - Block index                            ║
║  • tl.arange(start,end) - Range of offsets                       ║
║  • tl.load(ptr, mask)   - Load with masking                      ║
║  • tl.store(ptr, val)   - Store with masking                     ║
║  • tl.dot(a, b)         - Matrix multiply (Tensor Cores!)        ║
║  • tl.sum/max/min       - Reductions                             ║
║                                                                  ║
║  WHEN TO USE TRITON:                                             ║
║  ──────────────────────────────────────────────────────────────  ║
║  ✓ Custom fused operations                                       ║
║  ✓ Operations not in PyTorch                                     ║
║  ✓ Memory-bound ops needing fusion                               ║
║  ✓ Flash Attention variants                                      ║
║  ✗ Simple ops (PyTorch is already optimized)                     ║
║                                                                  ║
╚══════════════════════════════════════════════════════════════════╝
""")