# Optimized Transformer Benchmark
This notebook benchmarks the `ActualOptimizedTransformer` defined in `transformer.py` against a vanilla PyTorch transformer (`torch.nn.Transformer`).
- Measures **inference time** and **peak GPU memory** usage.
- Tests multiple sequence lengths.
- Shows model size breakdown.
- Verifies chunked attention matches FlashAttention (SDPA) numerically.

In [None]:
import torch
import time
import importlib
import json
import matplotlib.pyplot as plt

# Import your optimized transformer
import transformer
importlib.reload(transformer)  # reload in case we edit

from transformer import ActualOptimizedTransformer, FlashAttentionFallback, ActualChunkedAttention

In [None]:
def benchmark_model(model, seq_len, device, *, vocab_size=5000):
    """Time a single forward pass of ``model`` and report peak GPU memory.

    The model is switched to eval mode and fed one batch of random token ids
    of shape ``(1, seq_len)`` under ``torch.no_grad()``.

    Args:
        model: Module called as ``model(input_ids)``.
        seq_len: Number of token ids in the random input sequence.
        device: ``torch.device`` on which the input is created and measured.
        vocab_size: Exclusive upper bound for the random token ids.

    Returns:
        ``(elapsed, peak_mem)``: wall-clock seconds for the forward pass, and
        the peak allocated CUDA memory in MiB, or ``None`` when ``device`` is
        not a CUDA device.
    """
    on_cuda = device.type == 'cuda'

    model.eval()
    input_ids = torch.randint(0, vocab_size, (1, seq_len), device=device)

    if on_cuda:
        # CUDA work is queued asynchronously: drain pending kernels so they
        # are not billed to this measurement, then restart peak tracking.
        torch.cuda.synchronize(device)
        torch.cuda.reset_peak_memory_stats(device)

    # perf_counter is monotonic and high resolution, unlike time.time().
    start = time.perf_counter()
    with torch.no_grad():
        _ = model(input_ids)
    if on_cuda:
        # Wait for the forward pass to actually finish before stopping the clock.
        torch.cuda.synchronize(device)
    elapsed = time.perf_counter() - start

    peak_mem = None
    if on_cuda:
        peak_mem = torch.cuda.max_memory_allocated(device) / (1024 ** 2)
    return elapsed, peak_mem

In [None]:
# Devices to test: CPU always, CUDA when available.
devices = ['cpu']
if torch.cuda.is_available():
    devices.append('cuda')

seq_lengths = [256, 512, 1024, 2048]

# One dict per (device, seq_len) pair with timings and peak memory for both models.
results = []

for device_name in devices:
    device = torch.device(device_name)
    on_cuda = device.type == 'cuda'
    print(f"\n=== DEVICE: {device} ===")

    for seq_len in seq_lengths:
        print(f"Seq Len {seq_len}")
        # Optimized model
        opt_model = ActualOptimizedTransformer(
            vocab_size=5000,
            d_model=256,
            n_heads=8,
            n_layers=4,
            chunk_size=128,
            use_4bit=True,
            use_flash_attention=True
        ).to(device)
        t_opt, m_opt = benchmark_model(opt_model, seq_len, device)
        # Release the optimized model before measuring the baseline; otherwise
        # its weights are still allocated and inflate the baseline's peak memory.
        del opt_model

        # Vanilla baseline
        # NOTE(review): this is an encoder-decoder (4 + 4 layers) fed float
        # tensors directly, with no embedding layer - confirm it is the
        # intended comparison for the 4-layer optimized model.
        baseline = torch.nn.Transformer(
            d_model=256,
            nhead=8,
            num_encoder_layers=4,
            num_decoder_layers=4
        ).to(device)
        # no_grad() does not disable dropout; eval mode does.
        baseline.eval()
        src = torch.randn(seq_len, 1, 256, device=device)
        tgt = torch.randn(seq_len, 1, 256, device=device)

        if on_cuda:
            # Finish setup work (weight/input copies) before timing starts.
            torch.cuda.synchronize(device)
            torch.cuda.reset_peak_memory_stats(device)
        start = time.perf_counter()
        with torch.no_grad():
            baseline(src, tgt)
        if on_cuda:
            # CUDA kernels run asynchronously; wait for them before stopping the clock.
            torch.cuda.synchronize(device)
        t_base = time.perf_counter() - start
        m_base = torch.cuda.max_memory_allocated(device) / (1024 ** 2) if on_cuda else None

        # Drop the baseline and its inputs so they do not count towards the
        # next iteration's optimized-model peak memory.
        del baseline, src, tgt

        results.append({
            'device': device_name,
            'seq_len': seq_len,
            'time_opt': t_opt,
            'mem_opt': m_opt,
            'time_base': t_base,
            'mem_base': m_base
        })

In [None]:
# Persist the collected benchmark rows as pretty-printed JSON.
with open('bench_results.json', 'w') as out_file:
    out_file.write(json.dumps(results, indent=2))

# Trailing bare expression so the notebook displays the rows.
results

In [None]:
# Plotting: one timing figure per device, plus a memory figure for devices
# where peak memory was actually recorded.
for device_name in devices:
    device_rows = [r for r in results if r['device'] == device_name]
    if not device_rows:
        # Nothing was benchmarked on this device; skip instead of plotting empty data.
        continue

    # Take the x axis from the recorded rows so it always matches the y values.
    lengths = [r['seq_len'] for r in device_rows]
    times_opt = [r['time_opt'] for r in device_rows]
    times_base = [r['time_base'] for r in device_rows]

    plt.figure(figsize=(8,4))
    plt.plot(lengths, times_opt, label='Optimized')
    plt.plot(lengths, times_base, label='Baseline')
    plt.title(f'Inference Time ({device_name})')
    plt.xlabel('Sequence Length')
    plt.ylabel('Time (s)')
    plt.legend()
    plt.show()

    # Memory figures are None for non-CUDA runs, so gate on the recorded data
    # rather than on whether CUDA merely exists on this machine.
    has_memory = all(
        r['mem_opt'] is not None and r['mem_base'] is not None
        for r in device_rows
    )
    if has_memory:
        mem_opt = [r['mem_opt'] for r in device_rows]
        mem_base = [r['mem_base'] for r in device_rows]
        plt.figure(figsize=(8,4))
        plt.plot(lengths, mem_opt, label='Optimized')
        plt.plot(lengths, mem_base, label='Baseline')
        plt.title(f'Peak GPU Memory ({device_name})')
        plt.xlabel('Sequence Length')
        plt.ylabel('Memory (MB)')
        plt.legend()
        plt.show()

In [None]:
# Build a CPU copy of the optimized model and ask it for its size breakdown.
model_config = {
    'vocab_size': 5000,
    'd_model': 256,
    'n_heads': 8,
    'n_layers': 4,
    'chunk_size': 128,
    'use_4bit': True,
    'use_flash_attention': True,
}
opt_model = ActualOptimizedTransformer(**model_config).to('cpu')
size_info = opt_model.get_real_model_size()

# Trailing bare expression so the notebook displays the breakdown.
size_info