You must restart kernel after you install the cell below. Go to Kernel -> Restart Kernel

In [None]:
!pip install --upgrade datasets

Collecting datasets
  Downloading datasets-4.0.0-py3-none-any.whl.metadata (19 kB)
Collecting pyarrow>=15.0.0 (from datasets)
  Downloading pyarrow-21.0.0-cp310-cp310-manylinux_2_28_x86_64.whl.metadata (3.3 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting pandas (from datasets)
  Downloading pandas-2.3.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (91 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m91.2/91.2 kB[0m [31m12.0 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting requests>=2.32.2 (from datasets)
  Downloading requests-2.32.4-py3-none-any.whl.metadata (4.9 kB)
Collecting tqdm>=4.66.3 (from datasets)
  Downloading tqdm-4.67.1-py3-none-any.whl.metadata (57 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m57.7/57.7 kB[0m [31m33.7 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp310-cp310-manylinux_2

Now that you installed the cell above, restart the kernel. Go to Kernel -> Restart Kernel

In [None]:
!pip install matplotlib transformers seaborn

Collecting seaborn
  Downloading seaborn-0.13.2-py3-none-any.whl.metadata (5.4 kB)
Downloading seaborn-0.13.2-py3-none-any.whl (294 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m294.9/294.9 kB[0m [31m839.9 kB/s[0m eta [36m0:00:00[0m [36m0:00:01[0mm
[?25hInstalling collected packages: seaborn
Successfully installed seaborn-0.13.2
[0m

## Learning rate search for Muon and AdamW

In [None]:
!pip install torch==2.7.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Looking in indexes: https://download.pytorch.org/whl/cu128
Collecting torch==2.7.0
  Downloading https://download.pytorch.org/whl/cu128/torch-2.7.0%2Bcu128-cp310-cp310-manylinux_2_28_x86_64.whl.metadata (29 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.8.61 (from torch==2.7.0)
  Downloading https://download.pytorch.org/whl/cu128/nvidia_cuda_nvrtc_cu12-12.8.61-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl.metadata (1.7 kB)
Collecting nvidia-cuda-runtime-cu12==12.8.57 (from torch==2.7.0)
  Downloading https://download.pytorch.org/whl/cu128/nvidia_cuda_runtime_cu12-12.8.57-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (1.7 kB)
Collecting nvidia-cuda-cupti-cu12==12.8.57 (from torch==2.7.0)
  Downloading https://download.pytorch.org/whl/cu128/nvidia_cuda_cupti_cu12-12.8.57-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (1.7 kB)
Collecting nvidia-cudnn-cu12==9.7.1.26 (from torch==2.7.0)
  Downloading https://download.pytorch.org/whl/cu128/nvidia_cud

You will possibly need to reset the kernel again after installing torch, make sure the cell above says 2.7.0+cu128, otherwise reset the kernel again

In [None]:
import torch
print(torch.__version__)
# restart kernel

2.7.0+cu128


Cell below was run on T4, which is the free Google Colab GPU, so it's small enough for that, if you are running it on a stronger GPU, you can ask Claude Sonnet to increase size of the ablations or models

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.amp import autocast, GradScaler
import math
import random
import numpy as np
from datasets import load_dataset
from tqdm import tqdm
import json
import time
import matplotlib.pyplot as plt
from transformers import AutoTokenizer
import subprocess
import sys
from dataclasses import dataclass
from typing import Dict, List, Tuple, Optional
import warnings
import os
from datetime import datetime
import seaborn as sns
warnings.filterwarnings('ignore')

# Install required packages
def install_packages():
    """Install required packages for Colab"""
    packages = ['datasets', 'transformers', 'accelerate', 'scipy', 'seaborn']
    for package in packages:
        try:
            __import__(package)
        except ImportError:
            print(f"Installing {package}...")
            subprocess.check_call([sys.executable, "-m", "pip", "install", package])

install_packages()

from scipy import stats

def set_seed(seed: int = 42):
    """Set all random seeds for reproducibility"""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    print(f"🌱 Set all seeds to {seed}")

@dataclass
class ExperimentConfig:
    # Model architecture (FIXED)
    d_model: int = 256
    n_heads: int = 8
    n_layers: int = 4
    d_ff: int = 1024
    max_seq_len: int = 256

    # Training (FIXED)
    batch_size: int = 32  # Increased for faster training
    max_steps: int = 600  # Short but enough to see trends
    eval_every: int = 100

    # Data (FIXED)
    num_documents: int = 200
    max_tokens: int = 30000
    vocab_size: Optional[int] = None

    # Optimization (VARIABLE - this is what we're testing)
    learning_rate: float = 1e-3  # This will be varied
    weight_decay: float = 0.01   # Fixed low value
    grad_clip: float = 1.0

    # System (FIXED)
    use_amp: bool = True

    def __post_init__(self):
        self.d_k = self.d_model // self.n_heads
        assert self.d_model % self.n_heads == 0, "d_model must be divisible by n_heads"

@torch.compile
def zeropower_via_newtonschulz5(G: torch.Tensor, steps: int = 5) -> torch.Tensor:
    """Newton-Schulz iteration to compute the zeroth power / orthogonalization of G."""
    assert G.ndim >= 2
    a, b, c = (3.4445, -4.7750, 2.0315)
    X = G.bfloat16()

    if G.size(-2) > G.size(-1):
        X = X.mT

    X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)

    for _ in range(steps):
        A = X @ X.mT
        B = b * A + c * A @ A
        X = a * X + B @ X

    if G.size(-2) > G.size(-1):
        X = X.mT

    return X

class Muon(torch.optim.Optimizer):
    """Muon - MomentUm Orthogonalized by Newton-schulz"""
    def __init__(self, params, lr=0.02, momentum=0.95, nesterov=True, ns_steps=5):
        defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self):
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue

                g = p.grad
                state = self.state[p]

                if "momentum_buffer" not in state:
                    state["momentum_buffer"] = torch.zeros_like(g)

                buf = state["momentum_buffer"]
                buf.lerp_(g, 1 - group["momentum"])
                g = g.lerp_(buf, group["momentum"]) if group["nesterov"] else buf
                g = zeropower_via_newtonschulz5(g, steps=group["ns_steps"])
                p.add_(g.view_as(p), alpha=-group["lr"] * max(1, p.size(-2) / p.size(-1))**0.5)

def load_data(config: ExperimentConfig):
    """Load and tokenize data"""
    print(f"Loading dataset...")
    tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM-135M", token=False)

    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    dataset = load_dataset("HuggingFaceTB/smollm-corpus", "cosmopedia-v2", split="train", streaming=True, token=False)

    texts = []
    for i, item in enumerate(dataset):
        if i >= config.num_documents:
            break
        texts.append(item["text"][:1500])  # Shorter texts for faster processing

    print(f"Loaded {len(texts)} documents")
    config.vocab_size = tokenizer.vocab_size

    return texts, tokenizer

class TextTokenDataset(Dataset):
    def __init__(self, texts: List[str], tokenizer, seq_len: int = 256, max_tokens: int = 30000):
        self.tokenizer = tokenizer
        self.seq_len = seq_len

        print("Tokenizing texts...")
        all_tokens = []
        for text in tqdm(texts, desc="Tokenizing"):
            tokens = tokenizer.encode(text, add_special_tokens=False)
            all_tokens.extend(tokens)

        self.tokens = all_tokens[:max_tokens]
        print(f"Using {len(self.tokens):,} tokens")

    def __len__(self):
        return max(0, len(self.tokens) - self.seq_len)

    def __getitem__(self, idx):
        x = torch.tensor(self.tokens[idx:idx + self.seq_len], dtype=torch.long)
        y = torch.tensor(self.tokens[idx + 1:idx + self.seq_len + 1], dtype=torch.long)
        return x, y

class Rotary(nn.Module):
    def __init__(self, dim: int, max_seq_len: int):
        super().__init__()
        angular_freq = (1 / 10000) ** torch.linspace(0, 1, steps=dim//4, dtype=torch.float32)
        angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(dim//4)])
        t = torch.arange(max_seq_len, dtype=torch.float32)
        theta = torch.einsum("i,j -> ij", t, angular_freq)
        self.register_buffer('cos', theta.cos(), persistent=False)
        self.register_buffer('sin', theta.sin(), persistent=False)

    def forward(self, x_BTHD: torch.Tensor):
        assert self.cos.size(0) >= x_BTHD.size(-3)
        cos, sin = self.cos[None, :x_BTHD.size(-3), None, :], self.sin[None, :x_BTHD.size(-3), None, :]
        x1, x2 = x_BTHD.to(dtype=torch.float32).chunk(2, dim=-1)
        y1 = x1 * cos + x2 * sin
        y2 = x1 * (-sin) + x2 * cos
        return torch.cat((y1, y2), 3).type_as(x_BTHD)

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, n_heads: int, max_seq_len: int):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads

        self.qkv = nn.Linear(d_model, d_model * 3, bias=False)
        self.w_o = nn.Linear(d_model, d_model, bias=False)
        self.rotary = Rotary(self.d_k, max_seq_len)

    def forward(self, x):
        batch_size, seq_len = x.size(0), x.size(1)

        qkv = self.qkv(x).reshape(batch_size, seq_len, 3, self.n_heads, self.d_k)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        Q, K, V = qkv[0], qkv[1], qkv[2]

        Q = self.rotary(Q)
        K = self.rotary(K)

        attn_output = F.scaled_dot_product_attention(Q, K, V, is_causal=True, dropout_p=0.0)
        attn_output = attn_output.transpose(1, 2).reshape(batch_size, seq_len, self.d_model)
        return self.w_o(attn_output)

class FeedForward(nn.Module):
    def __init__(self, d_model: int, d_ff: int):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff, bias=False)
        self.linear2 = nn.Linear(d_ff, d_model, bias=False)

    def forward(self, x):
        return self.linear2(F.silu(self.linear1(x)))

class TransformerBlock(nn.Module):
    def __init__(self, d_model: int, n_heads: int, d_ff: int, max_seq_len: int):
        super().__init__()
        self.attention = MultiHeadAttention(d_model, n_heads, max_seq_len)
        self.feed_forward = FeedForward(d_model, d_ff)
        self.norm1 = nn.RMSNorm(d_model)
        self.norm2 = nn.RMSNorm(d_model)

    def forward(self, x):
        attn_out = self.attention(self.norm1(x))
        x = x + attn_out
        ff_out = self.feed_forward(self.norm2(x))
        x = x + ff_out
        return x

class MinimalLLM(nn.Module):
    def __init__(self, config: ExperimentConfig):
        super().__init__()
        self.config = config

        self.token_embedding = nn.Embedding(config.vocab_size, config.d_model)
        self.transformer_blocks = nn.ModuleList([
            TransformerBlock(config.d_model, config.n_heads, config.d_ff, config.max_seq_len)
            for _ in range(config.n_layers)
        ])
        self.norm = nn.RMSNorm(config.d_model)

        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, x):
        x = self.token_embedding(x) * math.sqrt(self.config.d_model)

        for block in self.transformer_blocks:
            x = block(x)

        x = self.norm(x)
        logits = F.linear(x, self.token_embedding.weight)
        return logits

class MetricsTracker:
    def __init__(self):
        self.metrics = {}

    def log_step(self, step: int, **kwargs):
        for key, value in kwargs.items():
            if key not in self.metrics:
                self.metrics[key] = []
            self.metrics[key].append((step, value))

def setup_optimizer(model: nn.Module, optimizer_type: str, learning_rate: float, config: ExperimentConfig):
    """Setup optimizer with specific learning rate"""

    if optimizer_type == 'muon':
        muon_params = []
        adamw_params = []

        for name, param in model.named_parameters():
            if (param.ndim == 2 and
                'token_embedding' not in name and
                'norm' not in name and
                param.requires_grad):
                muon_params.append(param)
            else:
                adamw_params.append(param)

        muon_optimizer = Muon(muon_params, lr=learning_rate, momentum=0.95)
        adamw_optimizer = torch.optim.AdamW(adamw_params, lr=learning_rate*0.1, weight_decay=config.weight_decay)

        return [muon_optimizer, adamw_optimizer]

    else:  # adamw
        optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=config.weight_decay)
        return [optimizer]

def evaluate_model(model: nn.Module, val_loader: DataLoader, config: ExperimentConfig) -> Dict:
    """Quick model evaluation"""
    model.eval()
    total_loss = 0
    total_tokens = 0
    total_correct = 0

    device = next(model.parameters()).device

    with torch.no_grad():
        for i, (x, y) in enumerate(val_loader):
            if i >= 3:  # Very quick eval
                break
            x, y = x.to(device), y.to(device)

            with autocast('cuda', enabled=config.use_amp):
                logits = model(x)
                loss = F.cross_entropy(logits.view(-1, config.vocab_size), y.view(-1))

            total_loss += loss.item() * y.numel()
            total_tokens += y.numel()

            predictions = logits.argmax(dim=-1)
            total_correct += (predictions == y).sum().item()

    avg_loss = total_loss / total_tokens
    accuracy = total_correct / total_tokens
    perplexity = math.exp(min(avg_loss, 20))

    model.train()

    return {
        'val_loss': avg_loss,
        'val_accuracy': accuracy,
        'val_perplexity': perplexity
    }

def train_with_learning_rate(optimizer_type: str, learning_rate: float, config: ExperimentConfig,
                           train_loader: DataLoader, val_loader: DataLoader) -> Dict:
    """Train model with specific learning rate"""

    print(f"🚀 {optimizer_type.upper()} LR={learning_rate}")

    # Initialize model
    set_seed(42)  # Same initialization for all runs
    model = MinimalLLM(config)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)

    # Setup optimizer with this learning rate
    optimizers = setup_optimizer(model, optimizer_type, learning_rate, config)

    # Setup schedulers (no decay for clean LR comparison)
    # schedulers = []
    # for optimizer in optimizers:
    #     scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer, factor=1.0)
    #     schedulers.append(scheduler)

    scaler = GradScaler('cuda') if config.use_amp else None
    tracker = MetricsTracker()

    # Training loop
    model.train()
    step = 0
    start_time = time.time()

    # Track if training becomes unstable
    min_loss = float('inf')
    unstable_count = 0

    while step < config.max_steps:
        for batch_idx, (x, y) in enumerate(train_loader):
            if step >= config.max_steps:
                break

            x, y = x.to(device), y.to(device)

            # Zero gradients
            for optimizer in optimizers:
                optimizer.zero_grad()

            # Forward pass
            try:
                if config.use_amp:
                    with autocast('cuda'):
                        logits = model(x)
                        loss = F.cross_entropy(logits.view(-1, config.vocab_size), y.view(-1))

                    if torch.isnan(loss) or torch.isinf(loss):
                        print(f"  💥 NaN/Inf loss at step {step}")
                        return {'failed': True, 'reason': 'nan_loss', 'step': step}

                    scaler.scale(loss).backward()

                    # Unscale and clip gradients
                    for optimizer in optimizers:
                        scaler.unscale_(optimizer)
                    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip)

                    # Step optimizers
                    for optimizer in optimizers:
                        scaler.step(optimizer)

                    scaler.update()
                else:
                    logits = model(x)
                    loss = F.cross_entropy(logits.view(-1, config.vocab_size), y.view(-1))

                    if torch.isnan(loss) or torch.isinf(loss):
                        print(f"  💥 NaN/Inf loss at step {step}")
                        return {'failed': True, 'reason': 'nan_loss', 'step': step}

                    loss.backward()

                    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip)

                    for optimizer in optimizers:
                        optimizer.step()

            except RuntimeError as e:
                print(f"  💥 Runtime error at step {step}: {e}")
                return {'failed': True, 'reason': 'runtime_error', 'step': step}

            # Check for training instability
            if loss.item() > min_loss * 2 and step > 100:
                unstable_count += 1
                if unstable_count > 10:
                    print(f"  ⚠️ Training became unstable at step {step}")
                    return {'failed': True, 'reason': 'unstable', 'step': step}
            else:
                min_loss = min(min_loss, loss.item())
                unstable_count = 0

            # Log metrics
            if step % 50 == 0:
                with torch.no_grad():
                    predictions = logits.argmax(dim=-1)
                    accuracy = (predictions == y).float().mean().item()
                    perplexity = math.exp(min(loss.item(), 20))

                tracker.log_step(
                    step,
                    train_loss=loss.item(),
                    train_accuracy=accuracy,
                    train_perplexity=perplexity,
                    grad_norm=grad_norm.item(),
                    learning_rate=optimizers[0].param_groups[0]['lr']
                )

            # Evaluation
            if step % config.eval_every == 0 and step > 0:
                eval_metrics = evaluate_model(model, val_loader, config)
                for key, value in eval_metrics.items():
                    tracker.log_step(step, **{key: value})

            step += 1

    training_time = time.time() - start_time

    # Final evaluation
    final_eval = evaluate_model(model, val_loader, config)

    # Get final training loss
    final_train_loss = None
    if 'train_loss' in tracker.metrics and len(tracker.metrics['train_loss']) > 0:
        final_train_loss = tracker.metrics['train_loss'][-1][1]

    # Clean up
    del model
    torch.cuda.empty_cache()

    return {
        'failed': False,
        'tracker': tracker,
        'training_time': training_time,
        'final_metrics': final_eval,
        'final_train_loss': final_train_loss,
        'learning_rate': learning_rate
    }

def run_learning_rate_search():
    """Run comprehensive learning rate search"""

    print("🔍 COMPREHENSIVE LEARNING RATE SEARCH: MUON vs ADAMW")
    print("="*80)

    # Fixed configuration
    config = ExperimentConfig()

    # Learning rate ranges to test
    adamw_lrs = [1e-5, 3e-5, 1e-4, 3e-4, 1e-3, 3e-3, 1e-2]
    muon_lrs = [0.001, 0.003, 0.005, 0.01, 0.015, 0.02, 0.03, 0.05]

    print(f"🏗️ Model Architecture: {config.d_model}d, {config.n_layers}L, {config.n_heads}H")
    print(f"📊 AdamW LRs: {adamw_lrs}")
    print(f"📊 Muon LRs: {muon_lrs}")
    print(f"⏱️ Training steps: {config.max_steps}")
    print(f"📦 Batch size: {config.batch_size}")

    # Load data once
    texts, tokenizer = load_data(config)
    dataset = TextTokenDataset(texts, tokenizer, config.max_seq_len, config.max_tokens)

    # Split data
    val_size = len(dataset) // 10
    train_size = len(dataset) - val_size
    train_dataset, val_dataset = torch.utils.data.random_split(
        dataset, [train_size, val_size],
        generator=torch.Generator().manual_seed(42)
    )

    train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=config.batch_size, shuffle=False)

    # Run experiments
    all_results = {
        'adamw': {},
        'muon': {},
        'config': config
    }

    # Test AdamW
    print(f"\n{'='*50}")
    print("🔵 TESTING ADAMW")
    print(f"{'='*50}")

    for lr in adamw_lrs:
        result = train_with_learning_rate('adamw', lr, config, train_loader, val_loader)
        all_results['adamw'][lr] = result

        if result['failed']:
            print(f"  ❌ LR {lr}: FAILED ({result['reason']})")
        else:
            final_loss = result['final_train_loss'] if result['final_train_loss'] else "N/A"
            val_acc = result['final_metrics']['val_accuracy']
            print(f"  ✅ LR {lr}: Final Loss={final_loss}, Val Acc={val_acc:.3f}")

    # Test Muon
    print(f"\n{'='*50}")
    print("🔴 TESTING MUON")
    print(f"{'='*50}")

    for lr in muon_lrs:
        result = train_with_learning_rate('muon', lr, config, train_loader, val_loader)
        all_results['muon'][lr] = result

        if result['failed']:
            print(f"  ❌ LR {lr}: FAILED ({result['reason']})")
        else:
            final_loss = result['final_train_loss'] if result['final_train_loss'] else "N/A"
            val_acc = result['final_metrics']['val_accuracy']
            print(f"  ✅ LR {lr}: Final Loss={final_loss}, Val Acc={val_acc:.3f}")

    # Save and analyze results
    save_lr_search_results(all_results)

    return all_results

def save_lr_search_results(all_results: Dict):
    """Save learning rate search results"""

    # Create results directory
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    results_dir = f"results/lr_search_{timestamp}"
    os.makedirs(results_dir, exist_ok=True)

    print(f"\n💾 Saving results to {results_dir}")

    # Generate plots
    generate_lr_plots(all_results, results_dir)

    # Generate report
    generate_lr_report(all_results, results_dir)

    # Save raw data
    save_lr_raw_data(all_results, results_dir)

    print(f"✅ All results saved to {results_dir}")

def generate_lr_plots(all_results: Dict, results_dir: str):
    """Generate learning rate analysis plots"""

    # Set style
    plt.style.use('default')
    sns.set_palette("husl")

    # 1. Learning Rate vs Final Metrics
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    fig.suptitle('Learning Rate Sensitivity Analysis', fontsize=16, fontweight='bold')

    # Collect data for successful runs
    adamw_data = {'lrs': [], 'final_loss': [], 'val_loss': [], 'val_acc': [], 'val_ppl': []}
    muon_data = {'lrs': [], 'final_loss': [], 'val_loss': [], 'val_acc': [], 'val_ppl': []}

    for lr, result in all_results['adamw'].items():
        if not result['failed'] and result['final_train_loss'] is not None:
            adamw_data['lrs'].append(lr)
            adamw_data['final_loss'].append(result['final_train_loss'])
            adamw_data['val_loss'].append(result['final_metrics']['val_loss'])
            adamw_data['val_acc'].append(result['final_metrics']['val_accuracy'])
            adamw_data['val_ppl'].append(result['final_metrics']['val_perplexity'])

    for lr, result in all_results['muon'].items():
        if not result['failed'] and result['final_train_loss'] is not None:
            muon_data['lrs'].append(lr)
            muon_data['final_loss'].append(result['final_train_loss'])
            muon_data['val_loss'].append(result['final_metrics']['val_loss'])
            muon_data['val_acc'].append(result['final_metrics']['val_accuracy'])
            muon_data['val_ppl'].append(result['final_metrics']['val_perplexity'])

    # Final Training Loss
    axes[0,0].semilogx(adamw_data['lrs'], adamw_data['final_loss'], 'o-', label='AdamW', color='blue', linewidth=2, markersize=8)
    axes[0,0].semilogx(muon_data['lrs'], muon_data['final_loss'], 's-', label='Muon', color='red', linewidth=2, markersize=8)
    axes[0,0].set_title('Final Training Loss vs Learning Rate')
    axes[0,0].set_xlabel('Learning Rate')
    axes[0,0].set_ylabel('Final Training Loss')
    axes[0,0].set_yscale('log')
    axes[0,0].legend()
    axes[0,0].grid(True, alpha=0.3)

    # Validation Loss
    axes[0,1].semilogx(adamw_data['lrs'], adamw_data['val_loss'], 'o-', label='AdamW', color='blue', linewidth=2, markersize=8)
    axes[0,1].semilogx(muon_data['lrs'], muon_data['val_loss'], 's-', label='Muon', color='red', linewidth=2, markersize=8)
    axes[0,1].set_title('Validation Loss vs Learning Rate')
    axes[0,1].set_xlabel('Learning Rate')
    axes[0,1].set_ylabel('Validation Loss')
    axes[0,1].set_yscale('log')
    axes[0,1].legend()
    axes[0,1].grid(True, alpha=0.3)

    # Validation Accuracy
    axes[1,0].semilogx(adamw_data['lrs'], adamw_data['val_acc'], 'o-', label='AdamW', color='blue', linewidth=2, markersize=8)
    axes[1,0].semilogx(muon_data['lrs'], muon_data['val_acc'], 's-', label='Muon', color='red', linewidth=2, markersize=8)
    axes[1,0].set_title('Validation Accuracy vs Learning Rate')
    axes[1,0].set_xlabel('Learning Rate')
    axes[1,0].set_ylabel('Validation Accuracy')
    axes[1,0].legend()
    axes[1,0].grid(True, alpha=0.3)

    # Validation Perplexity
    axes[1,1].semilogx(adamw_data['lrs'], adamw_data['val_ppl'], 'o-', label='AdamW', color='blue', linewidth=2, markersize=8)
    axes[1,1].semilogx(muon_data['lrs'], muon_data['val_ppl'], 's-', label='Muon', color='red', linewidth=2, markersize=8)
    axes[1,1].set_title('Validation Perplexity vs Learning Rate')
    axes[1,1].set_xlabel('Learning Rate')
    axes[1,1].set_ylabel('Validation Perplexity')
    axes[1,1].set_yscale('log')
    axes[1,1].legend()
    axes[1,1].grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(f'{results_dir}/lr_sensitivity_analysis.png', dpi=300, bbox_inches='tight')
    plt.close()

    # 2. Training curves for best learning rates
    best_adamw_lr = adamw_data['lrs'][np.argmax(adamw_data['val_acc'])] if adamw_data['val_acc'] else None
    best_muon_lr = muon_data['lrs'][np.argmax(muon_data['val_acc'])] if muon_data['val_acc'] else None

    if best_adamw_lr and best_muon_lr:
        fig, axes = plt.subplots(1, 3, figsize=(18, 5))
        fig.suptitle(f'Training Curves: Best LRs (AdamW: {best_adamw_lr}, Muon: {best_muon_lr})', fontsize=16, fontweight='bold')

        # Training loss
        adamw_result = all_results['adamw'][best_adamw_lr]
        muon_result = all_results['muon'][best_muon_lr]

        if 'train_loss' in adamw_result['tracker'].metrics:
            steps, losses = zip(*adamw_result['tracker'].metrics['train_loss'])
            axes[0].plot(steps, losses, label='AdamW', color='blue', linewidth=2)

        if 'train_loss' in muon_result['tracker'].metrics:
            steps, losses = zip(*muon_result['tracker'].metrics['train_loss'])
            axes[0].plot(steps, losses, label='Muon', color='red', linewidth=2)

        axes[0].set_title('Training Loss')
        axes[0].set_xlabel('Steps')
        axes[0].set_ylabel('Loss')
        axes[0].set_yscale('log')
        axes[0].legend()
        axes[0].grid(True, alpha=0.3)

        # Validation loss
        if 'val_loss' in adamw_result['tracker'].metrics:
            steps, losses = zip(*adamw_result['tracker'].metrics['val_loss'])
            axes[1].plot(steps, losses, label='AdamW', color='blue', linewidth=2)

        if 'val_loss' in muon_result['tracker'].metrics:
            steps, losses = zip(*muon_result['tracker'].metrics['val_loss'])
            axes[1].plot(steps, losses, label='Muon', color='red', linewidth=2)

        axes[1].set_title('Validation Loss')
        axes[1].set_xlabel('Steps')
        axes[1].set_ylabel('Loss')
        axes[1].set_yscale('log')
        axes[1].legend()
        axes[1].grid(True, alpha=0.3)

        # Validation accuracy
        if 'val_accuracy' in adamw_result['tracker'].metrics:
            steps, accs = zip(*adamw_result['tracker'].metrics['val_accuracy'])
            axes[2].plot(steps, accs, label='AdamW', color='blue', linewidth=2)

        if 'val_accuracy' in muon_result['tracker'].metrics:
            steps, accs = zip(*muon_result['tracker'].metrics['val_accuracy'])
            axes[2].plot(steps, accs, label='Muon', color='red', linewidth=2)

        axes[2].set_title('Validation Accuracy')
        axes[2].set_xlabel('Steps')
        axes[2].set_ylabel('Accuracy')
        axes[2].legend()
        axes[2].grid(True, alpha=0.3)

        plt.tight_layout()
        plt.savefig(f'{results_dir}/best_lr_training_curves.png', dpi=300, bbox_inches='tight')
        plt.close()

    # 3. Success/Failure visualization
    fig, axes = plt.subplots(1, 2, figsize=(12, 5))
    fig.suptitle('Learning Rate Stability Analysis', fontsize=16, fontweight='bold')

    # AdamW stability
    adamw_lrs_all = list(all_results['adamw'].keys())
    adamw_success = [1 if not all_results['adamw'][lr]['failed'] else 0 for lr in adamw_lrs_all]

    axes[0].semilogx(adamw_lrs_all, adamw_success, 'o-', color='blue', linewidth=2, markersize=8)
    axes[0].set_title('AdamW Training Stability')
    axes[0].set_xlabel('Learning Rate')
    axes[0].set_ylabel('Success (1) / Failure (0)')
    axes[0].set_ylim(-0.1, 1.1)
    axes[0].grid(True, alpha=0.3)

    # Muon stability
    muon_lrs_all = list(all_results['muon'].keys())
    muon_success = [1 if not all_results['muon'][lr]['failed'] else 0 for lr in muon_lrs_all]

    axes[1].semilogx(muon_lrs_all, muon_success, 's-', color='red', linewidth=2, markersize=8)
    axes[1].set_title('Muon Training Stability')
    axes[1].set_xlabel('Learning Rate')
    axes[1].set_ylabel('Success (1) / Failure (0)')
    axes[1].set_ylim(-0.1, 1.1)
    axes[1].grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(f'{results_dir}/lr_stability_analysis.png', dpi=300, bbox_inches='tight')
    plt.close()

def generate_lr_report(all_results: Dict, results_dir: str):
    """Generate comprehensive learning rate report"""

    report = []
    report.append("=" * 80)
    report.append("COMPREHENSIVE LEARNING RATE SEARCH: MUON vs ADAMW")
    report.append("=" * 80)
    report.append(f"Generated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    report.append("")

    config = all_results['config']
    report.append("EXPERIMENTAL SETUP")
    report.append("-" * 40)
    report.append(f"Model: {config.d_model}d, {config.n_layers}L, {config.n_heads}H, {config.d_ff}ff")
    report.append(f"Training steps: {config.max_steps}")
    report.append(f"Batch size: {config.batch_size}")
    report.append(f"Weight decay: {config.weight_decay}")
    report.append("")

    # Analyze AdamW results
    report.append("ADAMW RESULTS")
    report.append("-" * 40)

    adamw_successful = []
    adamw_failed = []

    for lr, result in all_results['adamw'].items():
        if result['failed']:
            adamw_failed.append((lr, result['reason']))
        else:
            adamw_successful.append((lr, result))

    report.append(f"Successful runs: {len(adamw_successful)}/{len(all_results['adamw'])}")
    report.append(f"Failed runs: {len(adamw_failed)}")

    if adamw_failed:
        report.append("Failed learning rates:")
        for lr, reason in adamw_failed:
            report.append(f"  LR {lr}: {reason}")
    report.append("")

    if adamw_successful:
        # Find best learning rate
        best_lr, best_result = max(adamw_successful, key=lambda x: x[1]['final_metrics']['val_accuracy'])
        report.append(f"BEST ADAMW LR: {best_lr}")
        report.append(f"  Final Training Loss: {best_result['final_train_loss']:.4f}")
        report.append(f"  Validation Loss: {best_result['final_metrics']['val_loss']:.4f}")
        report.append(f"  Validation Accuracy: {best_result['final_metrics']['val_accuracy']:.4f}")
        report.append(f"  Validation Perplexity: {best_result['final_metrics']['val_perplexity']:.2f}")
        report.append(f"  Training Time: {best_result['training_time']:.1f}s")
        report.append("")

        # Learning rate range analysis
        lrs = [x[0] for x in adamw_successful]
        val_accs = [x[1]['final_metrics']['val_accuracy'] for x in adamw_successful]
        report.append(f"LR Range Analysis:")
        report.append(f"  Working LR range: {min(lrs)} to {max(lrs)}")
        report.append(f"  Best val accuracy: {max(val_accs):.4f}")
        report.append(f"  Worst val accuracy: {min(val_accs):.4f}")
        report.append("")

    # Analyze Muon results
    report.append("MUON RESULTS")
    report.append("-" * 40)

    muon_successful = []
    muon_failed = []

    for lr, result in all_results['muon'].items():
        if result['failed']:
            muon_failed.append((lr, result['reason']))
        else:
            muon_successful.append((lr, result))

    report.append(f"Successful runs: {len(muon_successful)}/{len(all_results['muon'])}")
    report.append(f"Failed runs: {len(muon_failed)}")

    if muon_failed:
        report.append("Failed learning rates:")
        for lr, reason in muon_failed:
            report.append(f"  LR {lr}: {reason}")
    report.append("")

    if muon_successful:
        # Find best learning rate
        best_lr, best_result = max(muon_successful, key=lambda x: x[1]['final_metrics']['val_accuracy'])
        report.append(f"BEST MUON LR: {best_lr}")
        report.append(f"  Final Training Loss: {best_result['final_train_loss']:.4f}")
        report.append(f"  Validation Loss: {best_result['final_metrics']['val_loss']:.4f}")
        report.append(f"  Validation Accuracy: {best_result['final_metrics']['val_accuracy']:.4f}")
        report.append(f"  Validation Perplexity: {best_result['final_metrics']['val_perplexity']:.2f}")
        report.append(f"  Training Time: {best_result['training_time']:.1f}s")
        report.append("")

        # Learning rate range analysis
        lrs = [x[0] for x in muon_successful]
        val_accs = [x[1]['final_metrics']['val_accuracy'] for x in muon_successful]
        report.append(f"LR Range Analysis:")
        report.append(f"  Working LR range: {min(lrs)} to {max(lrs)}")
        report.append(f"  Best val accuracy: {max(val_accs):.4f}")
        report.append(f"  Worst val accuracy: {min(val_accs):.4f}")
        report.append("")

    # Comparison
    if adamw_successful and muon_successful:
        report.append("DIRECT COMPARISON")
        report.append("-" * 40)

        best_adamw = max(adamw_successful, key=lambda x: x[1]['final_metrics']['val_accuracy'])
        best_muon = max(muon_successful, key=lambda x: x[1]['final_metrics']['val_accuracy'])

        adamw_acc = best_adamw[1]['final_metrics']['val_accuracy']
        muon_acc = best_muon[1]['final_metrics']['val_accuracy']

        adamw_loss = best_adamw[1]['final_metrics']['val_loss']
        muon_loss = best_muon[1]['final_metrics']['val_loss']

        report.append(f"Best validation accuracy:")
        report.append(f"  AdamW (LR={best_adamw[0]}): {adamw_acc:.4f}")
        report.append(f"  Muon (LR={best_muon[0]}):  {muon_acc:.4f}")
        report.append(f"  Winner: {'Muon' if muon_acc > adamw_acc else 'AdamW'} by {abs(muon_acc - adamw_acc):.4f}")
        report.append("")

        report.append(f"Best validation loss:")
        report.append(f"  AdamW (LR={best_adamw[0]}): {adamw_loss:.4f}")
        report.append(f"  Muon (LR={best_muon[0]}):  {muon_loss:.4f}")
        report.append(f"  Winner: {'Muon' if muon_loss < adamw_loss else 'AdamW'} by {abs(muon_loss - adamw_loss):.4f}")
        report.append("")

        # Stability comparison
        adamw_stable_range = max(lrs for lrs, _ in adamw_successful) / min(lrs for lrs, _ in adamw_successful)
        muon_stable_range = max(lrs for lrs, _ in muon_successful) / min(lrs for lrs, _ in muon_successful)

        report.append(f"Learning rate stability:")
        report.append(f"  AdamW stable range: {adamw_stable_range:.1f}x")
        report.append(f"  Muon stable range: {muon_stable_range:.1f}x")
        report.append(f"  More stable: {'AdamW' if adamw_stable_range > muon_stable_range else 'Muon'}")
        report.append("")

    # Key findings
    report.append("KEY FINDINGS")
    report.append("-" * 40)

    if adamw_successful and muon_successful:
        best_adamw_acc = max(x[1]['final_metrics']['val_accuracy'] for x in adamw_successful)
        best_muon_acc = max(x[1]['final_metrics']['val_accuracy'] for x in muon_successful)

        if best_muon_acc > best_adamw_acc:
            report.append("✓ Muon achieves higher peak performance than AdamW")
        else:
            report.append("✗ AdamW achieves higher peak performance than Muon")

        if len(adamw_successful) > len(muon_successful):
            report.append("✓ AdamW is more stable across learning rates")
        elif len(muon_successful) > len(adamw_successful):
            report.append("✓ Muon is more stable across learning rates")
        else:
            report.append("≈ Both optimizers show similar stability")

    # Save report
    with open(f'{results_dir}/lr_search_report.txt', 'w') as f:
        f.write('\n'.join(report))

def save_lr_raw_data(all_results: Dict, results_dir: str):
    """Save raw learning rate search data"""

    # Prepare serializable data
    raw_data = {
        'config': {
            'd_model': all_results['config'].d_model,
            'n_layers': all_results['config'].n_layers,
            'n_heads': all_results['config'].n_heads,
            'd_ff': all_results['config'].d_ff,
            'max_steps': all_results['config'].max_steps,
            'batch_size': all_results['config'].batch_size,
        },
        'results': {}
    }

    for optimizer_type in ['adamw', 'muon']:
        raw_data['results'][optimizer_type] = {}
        for lr, result in all_results[optimizer_type].items():
            raw_data['results'][optimizer_type][str(lr)] = {
                'failed': result['failed'],
                'learning_rate': result['learning_rate'],
                'training_time': result.get('training_time', 0),
                'final_train_loss': result.get('final_train_loss'),
                'final_metrics': result.get('final_metrics', {}),
                'failure_reason': result.get('reason') if result['failed'] else None,
                'failure_step': result.get('step') if result['failed'] else None
            }

            # Add metrics if available
            if not result['failed'] and 'tracker' in result:
                raw_data['results'][optimizer_type][str(lr)]['metrics'] = result['tracker'].metrics

    with open(f'{results_dir}/lr_search_raw_data.json', 'w') as f:
        json.dump(raw_data, f, indent=2, default=str)

if __name__ == "__main__":
    # Check system
    print(f"🔍 Device: {'CUDA' if torch.cuda.is_available() else 'CPU'}")
    if torch.cuda.is_available():
        print(f"GPU: {torch.cuda.get_device_name()}")
        print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

    # Set global seed
    set_seed(42)

    # Run learning rate search
    start_time = time.time()
    results = run_learning_rate_search()
    total_time = time.time() - start_time

    print(f"\n🎉 LEARNING RATE SEARCH COMPLETED IN {total_time/60:.1f} MINUTES!")
    print("✅ Results saved to 'results/' folder")
    print("📊 Check the generated plots and report for optimal learning rates")

🔍 Device: CUDA
GPU: Tesla T4
Memory: 15.8 GB
🌱 Set all seeds to 42
🔍 COMPREHENSIVE LEARNING RATE SEARCH: MUON vs ADAMW
🏗️ Model Architecture: 256d, 4L, 8H
📊 AdamW LRs: [1e-05, 3e-05, 0.0001, 0.0003, 0.001, 0.003, 0.01]
📊 Muon LRs: [0.001, 0.003, 0.005, 0.01, 0.015, 0.02, 0.03, 0.05]
⏱️ Training steps: 600
📦 Batch size: 32
Loading dataset...


Resolving data files:   0%|          | 0/104 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/104 [00:00<?, ?it/s]

Loaded 200 documents
Tokenizing texts...



Tokenizing:   0%|          | 0/200 [00:00<?, ?it/s][A
Tokenizing:  27%|██▋       | 54/200 [00:00<00:00, 537.14it/s][A
Tokenizing:  62%|██████▏   | 123/200 [00:00<00:00, 624.66it/s][A
Tokenizing: 100%|██████████| 200/200 [00:00<00:00, 633.41it/s]


Using 30,000 tokens

🔵 TESTING ADAMW
🚀 ADAMW LR=1e-05
🌱 Set all seeds to 42
  ✅ LR 1e-05: Final Loss=8.459477424621582, Val Acc=0.070
🚀 ADAMW LR=3e-05
🌱 Set all seeds to 42
  ✅ LR 3e-05: Final Loss=6.6022467613220215, Val Acc=0.183
🚀 ADAMW LR=0.0001
🌱 Set all seeds to 42
  ✅ LR 0.0001: Final Loss=4.173043251037598, Val Acc=0.371
🚀 ADAMW LR=0.0003
🌱 Set all seeds to 42
  ✅ LR 0.0003: Final Loss=1.478623628616333, Val Acc=0.777
🚀 ADAMW LR=0.001
🌱 Set all seeds to 42
  ✅ LR 0.001: Final Loss=0.16272294521331787, Val Acc=0.968
🚀 ADAMW LR=0.003
🌱 Set all seeds to 42
  ✅ LR 0.003: Final Loss=0.12793155014514923, Val Acc=0.972
🚀 ADAMW LR=0.01
🌱 Set all seeds to 42
  ✅ LR 0.01: Final Loss=3.3094286918640137, Val Acc=0.284

🔴 TESTING MUON
🚀 MUON LR=0.001
🌱 Set all seeds to 42
  ✅ LR 0.001: Final Loss=3.733657121658325, Val Acc=0.430
🚀 MUON LR=0.003
🌱 Set all seeds to 42
  ✅ LR 0.003: Final Loss=0.5717266798019409, Val Acc=0.953
🚀 MUON LR=0.005
🌱 Set all seeds to 42
  ✅ LR 0.005: Final Loss=0.13

📊 Using optimal learning rates from LR search:
   - AdamW: 0.003
   - Muon:  0.01

## Model size ablations with best learning rate for both

Cell below was run on Nvidia GTX 4090, if you are running it on free Google Colab GPU or weaker GPU:
- Copy code into AI and ask it to estimate if you will run out of CUDA memory
- Tell AI (Claude Sonnet is good) to saver results after each step so if later experiments cause out of memory, your results are saved
- Possibly tell AI to reduce size

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast, GradScaler
import math
import random
import numpy as np
from datasets import load_dataset
from tqdm import tqdm
import json
import time
import matplotlib.pyplot as plt
from transformers import AutoTokenizer
import subprocess
import sys
from dataclasses import dataclass
from typing import Dict, List, Tuple, Optional
import warnings
import os
import pickle
from datetime import datetime
import seaborn as sns
from scipy import stats
warnings.filterwarnings('ignore')

def set_seed(seed: int = 42):
    """Set all random seeds for reproducibility"""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    print(f"🌱 Set all seeds to {seed}")

@dataclass
class ImprovedModelConfig:
    # Model architecture (required fields first)
    d_model: int
    n_heads: int
    n_layers: int
    d_ff: int
    batch_size: int
    max_steps: int

    # Training parameters - MUCH MORE AGGRESSIVE
    gradient_accumulation_steps: int = 4  # Simulate larger batches

    # Data parameters - LARGER DATASET
    max_seq_len: int = 512  # Longer sequences
    num_documents: int = 2000  # 5x more documents
    max_tokens: int = 500000  # 8x more tokens

    # Evaluation
    eval_every: int = 500  # Less frequent but more comprehensive
    eval_steps: int = 100  # More validation batches

    # Learning rates (from your search)
    adamw_lr: float = 0.003
    muon_lr: float = 0.01

    # Regularization
    weight_decay: float = 0.1  # Stronger regularization
    dropout: float = 0.1  # Add dropout
    grad_clip: float = 1.0

    # Technical
    use_amp: bool = True
    compile_model: bool = False
    vocab_size: Optional[int] = None

    def __post_init__(self):
        self.d_k = self.d_model // self.n_heads
        assert self.d_model % self.n_heads == 0, "d_model must be divisible by n_heads"

def improved_model_configs():
    """More challenging model configurations"""
    return {
        'Tiny': ImprovedModelConfig(
            d_model=192, n_heads=6, n_layers=4, d_ff=768,
            batch_size=32, max_steps=6000
        ),
        'Small': ImprovedModelConfig(
            d_model=384, n_heads=8, n_layers=6, d_ff=1536,
            batch_size=24, max_steps=5000
        ),
        'Medium': ImprovedModelConfig(
            d_model=512, n_heads=8, n_layers=8, d_ff=2048,
            batch_size=16, max_steps=4000
        ),
        'Large': ImprovedModelConfig(
            d_model=768, n_heads=16, n_layers=10, d_ff=3072,
            batch_size=12, max_steps=3000
        )
    }

@torch.compile
def zeropower_via_newtonschulz5(G: torch.Tensor, steps: int = 5) -> torch.Tensor:
    """Newton-Schulz iteration to compute the zeroth power / orthogonalization of G."""
    assert G.ndim >= 2
    a, b, c = (3.4445, -4.7750, 2.0315)
    X = G.bfloat16()

    if G.size(-2) > G.size(-1):
        X = X.mT

    X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)

    for _ in range(steps):
        A = X @ X.mT
        B = b * A + c * A @ A
        X = a * X + B @ X

    if G.size(-2) > G.size(-1):
        X = X.mT

    return X

class Muon(torch.optim.Optimizer):
    """Muon - MomentUm Orthogonalized by Newton-schulz"""
    def __init__(self, params, lr=0.02, momentum=0.95, nesterov=True, ns_steps=5):
        defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self):
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue

                g = p.grad
                state = self.state[p]

                if "momentum_buffer" not in state:
                    state["momentum_buffer"] = torch.zeros_like(g)

                buf = state["momentum_buffer"]
                buf.lerp_(g, 1 - group["momentum"])
                g = g.lerp_(buf, group["momentum"]) if group["nesterov"] else buf
                g = zeropower_via_newtonschulz5(g, steps=group["ns_steps"])
                p.add_(g.view_as(p), alpha=-group["lr"] * max(1, p.size(-2) / p.size(-1))**0.5)

def load_and_cache_data(config: ImprovedModelConfig, cache_dir: str = "data_cache"):
    """Load and cache tokenized data to avoid reprocessing"""

    os.makedirs(cache_dir, exist_ok=True)
    cache_file = f"{cache_dir}/tokenized_data_{config.num_documents}_{config.max_tokens}.pkl"

    # Check if cached data exists
    if os.path.exists(cache_file):
        print(f"📦 Loading cached data from {cache_file}")
        with open(cache_file, 'rb') as f:
            cached_data = pickle.load(f)

        texts = cached_data['texts']
        tokenizer = cached_data['tokenizer']
        tokens = cached_data['tokens']

        # Update vocab size in config
        config.vocab_size = tokenizer.vocab_size

        print(f"✅ Loaded {len(texts)} documents, {len(tokens):,} tokens from cache")
        return texts, tokenizer, tokens

    print(f"🔄 Processing new data (will cache for future use)")

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM-135M", token=False)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Load dataset
    dataset = load_dataset("HuggingFaceTB/smollm-corpus", "cosmopedia-v2", split="train", streaming=True, token=False)

    texts = []
    for i, item in enumerate(dataset):
        if i >= config.num_documents:
            break
        texts.append(item["text"][:3000])  # Longer text chunks

    print(f"Loaded {len(texts)} documents")

    # Tokenize
    print("Tokenizing texts...")
    all_tokens = []
    for text in tqdm(texts, desc="Tokenizing"):
        tokens = tokenizer.encode(text, add_special_tokens=False)
        all_tokens.extend(tokens)

    # Limit tokens
    tokens = all_tokens[:config.max_tokens]
    print(f"Using {len(tokens):,} tokens")

    # Update config
    config.vocab_size = tokenizer.vocab_size

    # Cache the processed data
    cached_data = {
        'texts': texts,
        'tokenizer': tokenizer,
        'tokens': tokens
    }

    with open(cache_file, 'wb') as f:
        pickle.dump(cached_data, f)

    print(f"💾 Cached data to {cache_file}")

    return texts, tokenizer, tokens

class TextTokenDataset(Dataset):
    def __init__(self, tokens: List[int], seq_len: int = 512):
        self.tokens = tokens
        self.seq_len = seq_len

    def __len__(self):
        return max(0, len(self.tokens) - self.seq_len)

    def __getitem__(self, idx):
        x = torch.tensor(self.tokens[idx:idx + self.seq_len], dtype=torch.long)
        y = torch.tensor(self.tokens[idx + 1:idx + self.seq_len + 1], dtype=torch.long)
        return x, y

class Rotary(nn.Module):
    def __init__(self, dim: int, max_seq_len: int):
        super().__init__()
        angular_freq = (1 / 10000) ** torch.linspace(0, 1, steps=dim//4, dtype=torch.float32)
        angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(dim//4)])
        t = torch.arange(max_seq_len, dtype=torch.float32)
        theta = torch.einsum("i,j -> ij", t, angular_freq)
        self.register_buffer('cos', theta.cos(), persistent=False)
        self.register_buffer('sin', theta.sin(), persistent=False)

    def forward(self, x_BTHD: torch.Tensor):
        assert self.cos.size(0) >= x_BTHD.size(-3)
        cos, sin = self.cos[None, :x_BTHD.size(-3), None, :], self.sin[None, :x_BTHD.size(-3), None, :]
        x1, x2 = x_BTHD.to(dtype=torch.float32).chunk(2, dim=-1)
        y1 = x1 * cos + x2 * sin
        y2 = x1 * (-sin) + x2 * cos
        return torch.cat((y1, y2), 3).type_as(x_BTHD)

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, n_heads: int, max_seq_len: int, dropout: float = 0.1):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads

        self.qkv = nn.Linear(d_model, d_model * 3, bias=False)
        self.w_o = nn.Linear(d_model, d_model, bias=False)
        self.rotary = Rotary(self.d_k, max_seq_len)
        self.dropout = dropout

    def forward(self, x):
        batch_size, seq_len = x.size(0), x.size(1)

        qkv = self.qkv(x).reshape(batch_size, seq_len, 3, self.n_heads, self.d_k)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        Q, K, V = qkv[0], qkv[1], qkv[2]

        Q = self.rotary(Q)
        K = self.rotary(K)

        attn_output = F.scaled_dot_product_attention(
            Q, K, V, is_causal=True, dropout_p=self.dropout if self.training else 0.0
        )
        attn_output = attn_output.transpose(1, 2).reshape(batch_size, seq_len, self.d_model)
        return self.w_o(attn_output)

class FeedForward(nn.Module):
    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff, bias=False)
        self.linear2 = nn.Linear(d_ff, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.linear2(self.dropout(F.silu(self.linear1(x))))

class ImprovedTransformerBlock(nn.Module):
    def __init__(self, d_model: int, n_heads: int, d_ff: int,
                 max_seq_len: int, dropout: float = 0.1):
        super().__init__()
        self.attention = MultiHeadAttention(d_model, n_heads, max_seq_len, dropout)
        self.feed_forward = FeedForward(d_model, d_ff, dropout)
        self.norm1 = nn.RMSNorm(d_model)
        self.norm2 = nn.RMSNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # Pre-norm with dropout
        attn_out = self.attention(self.norm1(x))
        x = x + self.dropout(attn_out)

        ff_out = self.feed_forward(self.norm2(x))
        x = x + self.dropout(ff_out)
        return x

class ImprovedMinimalLLM(nn.Module):
    def __init__(self, config: ImprovedModelConfig):
        super().__init__()
        self.config = config

        self.token_embedding = nn.Embedding(config.vocab_size, config.d_model)
        self.position_dropout = nn.Dropout(config.dropout)

        self.transformer_blocks = nn.ModuleList([
            ImprovedTransformerBlock(
                config.d_model, config.n_heads, config.d_ff,
                config.max_seq_len, config.dropout
            ) for _ in range(config.n_layers)
        ])

        self.norm = nn.RMSNorm(config.d_model)
        self.output_dropout = nn.Dropout(config.dropout)

        # Tie weights
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        self.lm_head.weight = self.token_embedding.weight

        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, x):
        x = self.token_embedding(x) * math.sqrt(self.config.d_model)
        x = self.position_dropout(x)

        for block in self.transformer_blocks:
            x = block(x)

        x = self.norm(x)
        x = self.output_dropout(x)
        logits = self.lm_head(x)
        return logits

class MetricsTracker:
    def __init__(self):
        self.metrics = {}
        self.memory_usage = []

    def log_step(self, step: int, **kwargs):
        for key, value in kwargs.items():
            if key not in self.metrics:
                self.metrics[key] = []
            self.metrics[key].append((step, value))

    def log_memory(self, step: int):
        if torch.cuda.is_available():
            allocated = torch.cuda.memory_allocated() / 1e9
            reserved = torch.cuda.memory_reserved() / 1e9
            self.memory_usage.append((step, allocated, reserved))

def setup_optimizer(model: nn.Module, optimizer_type: str, config: ImprovedModelConfig):
    """Setup optimizer with optimal learning rates"""

    if optimizer_type == 'muon':
        muon_params = []
        adamw_params = []

        for name, param in model.named_parameters():
            if (param.ndim == 2 and
                'token_embedding' not in name and
                'norm' not in name and
                param.requires_grad):
                muon_params.append(param)
            else:
                adamw_params.append(param)

        print(f"  Muon parameters: {sum(p.numel() for p in muon_params):,}")
        print(f"  AdamW parameters: {sum(p.numel() for p in adamw_params):,}")

        muon_optimizer = Muon(muon_params, lr=config.muon_lr, momentum=0.95)
        adamw_optimizer = torch.optim.AdamW(adamw_params, lr=config.muon_lr*0.1, weight_decay=config.weight_decay)

        return [muon_optimizer, adamw_optimizer]

    else:  # adamw
        optimizer = torch.optim.AdamW(model.parameters(), lr=config.adamw_lr, weight_decay=config.weight_decay)
        return [optimizer]

def comprehensive_evaluate_model(model: nn.Module, val_loader: DataLoader,
                               config: ImprovedModelConfig) -> Dict:
    """More comprehensive evaluation with multiple metrics"""
    model.eval()
    total_loss = 0
    total_tokens = 0
    total_correct = 0
    total_correct_top5 = 0

    device = next(model.parameters()).device

    with torch.no_grad():
        for i, (x, y) in enumerate(val_loader):
            if i >= config.eval_steps:  # More evaluation steps
                break
            x, y = x.to(device), y.to(device)

            with autocast(enabled=config.use_amp):
                logits = model(x)
                loss = F.cross_entropy(logits.view(-1, config.vocab_size), y.view(-1))

            total_loss += loss.item() * y.numel()
            total_tokens += y.numel()

            # Top-1 accuracy
            predictions = logits.argmax(dim=-1)
            total_correct += (predictions == y).sum().item()

            # Top-5 accuracy
            top5_predictions = logits.topk(5, dim=-1)[1]
            total_correct_top5 += (top5_predictions == y.unsqueeze(-1)).any(dim=-1).sum().item()

    avg_loss = total_loss / total_tokens
    accuracy = total_correct / total_tokens
    top5_accuracy = total_correct_top5 / total_tokens
    perplexity = math.exp(min(avg_loss, 20))

    model.train()

    return {
        'val_loss': avg_loss,
        'val_accuracy': accuracy,
        'val_top5_accuracy': top5_accuracy,
        'val_perplexity': perplexity
    }

def improved_train_model(optimizer_type: str, config: ImprovedModelConfig,
                        train_loader: DataLoader, val_loader: DataLoader,
                        model_name: str, run_id: int = 0) -> Tuple[MetricsTracker, Dict]:
    """Improved training with gradient accumulation and better evaluation"""

    print(f"\n🚀 Training {optimizer_type.upper()} on {model_name} (Run {run_id+1})")

    # Initialize with different seeds for multiple runs
    set_seed(42 + run_id * 1000)
    model = ImprovedMinimalLLM(config)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)

    if config.compile_model:
        try:
            model = torch.compile(model, mode='max-autotune')
            print("  ✅ Model compiled with max-autotune")
        except Exception as e:
            print(f"  ⚠️ Compilation failed: {e}")

    total_params = sum(p.numel() for p in model.parameters())
    print(f"  📊 Total parameters: {total_params:,}")

    # Setup optimizers
    optimizers = setup_optimizer(model, optimizer_type, config)

    # Improved learning rate schedule
    schedulers = []
    for optimizer in optimizers:
        warmup_steps = config.max_steps // 20  # 5% warmup
        def lr_lambda(step):
            if step < warmup_steps:
                return step / warmup_steps
            else:
                # Cosine annealing to 10% of peak
                progress = (step - warmup_steps) / (config.max_steps - warmup_steps)
                return 0.1 + 0.9 * 0.5 * (1 + math.cos(math.pi * progress))

        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
        schedulers.append(scheduler)

    scaler = GradScaler() if config.use_amp else None
    tracker = MetricsTracker()

    # Training state
    model.train()
    step = 0
    accumulated_loss = 0
    start_time = time.time()
    best_val_loss = float('inf')
    patience_counter = 0
    patience_limit = 2000  # Early stopping patience

    pbar = tqdm(total=config.max_steps, desc=f"{optimizer_type.upper()}")

    while step < config.max_steps:
        for batch_idx, (x, y) in enumerate(train_loader):
            if step >= config.max_steps:
                break

            x, y = x.to(device), y.to(device)

            # Forward pass with gradient accumulation
            if config.use_amp:
                with autocast():
                    logits = model(x)
                    loss = F.cross_entropy(logits.view(-1, config.vocab_size), y.view(-1))
                    loss = loss / config.gradient_accumulation_steps

                scaler.scale(loss).backward()
                accumulated_loss += loss.item()
            else:
                logits = model(x)
                loss = F.cross_entropy(logits.view(-1, config.vocab_size), y.view(-1))
                loss = loss / config.gradient_accumulation_steps
                loss.backward()
                accumulated_loss += loss.item()

            # Optimizer step after accumulation
            if (step + 1) % config.gradient_accumulation_steps == 0:
                if config.use_amp:
                    # Unscale and clip gradients
                    for optimizer in optimizers:
                        scaler.unscale_(optimizer)
                    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip)

                    # Step optimizers
                    for optimizer in optimizers:
                        scaler.step(optimizer)
                        optimizer.zero_grad()
                    for scheduler in schedulers:
                        scheduler.step()
                    scaler.update()
                else:
                    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip)
                    for optimizer in optimizers:
                        optimizer.step()
                        optimizer.zero_grad()
                    for scheduler in schedulers:
                        scheduler.step()

                # Reset accumulated loss
                accumulated_loss = 0

            # Logging
            if step % 50 == 0:
                with torch.no_grad():
                    predictions = logits.argmax(dim=-1)
                    accuracy = (predictions == y).float().mean().item()
                    current_loss = loss.item() * config.gradient_accumulation_steps
                    perplexity = math.exp(min(current_loss, 20))

                tracker.log_step(
                    step,
                    train_loss=current_loss,
                    train_accuracy=accuracy,
                    train_perplexity=perplexity,
                    grad_norm=grad_norm.item() if 'grad_norm' in locals() else 0,
                    learning_rate=optimizers[0].param_groups[0]['lr']
                )

                if step % 500 == 0:  # Less frequent memory logging
                    tracker.log_memory(step)

                pbar.set_postfix({
                    'loss': f'{current_loss:.4f}',
                    'acc': f'{accuracy:.3f}',
                    'ppl': f'{perplexity:.1f}',
                    'lr': f'{optimizers[0].param_groups[0]["lr"]:.2e}'
                })

            # Comprehensive evaluation
            if step % config.eval_every == 0 and step > 0:
                eval_metrics = comprehensive_evaluate_model(model, val_loader, config)
                for key, value in eval_metrics.items():
                    tracker.log_step(step, **{key: value})

                # Early stopping check
                if eval_metrics['val_loss'] < best_val_loss:
                    best_val_loss = eval_metrics['val_loss']
                    patience_counter = 0
                else:
                    patience_counter += config.eval_every

                if patience_counter >= patience_limit:
                    print(f"\n  🛑 Early stopping at step {step} (patience exceeded)")
                    break

            step += 1
            if step % 50 == 0:  # Update progress bar every 10 steps
                pbar.update(50)

    pbar.close()

    training_time = time.time() - start_time
    print(f"  ⏱️ Training completed in {training_time:.1f} seconds")

    # Final comprehensive evaluation
    final_eval = comprehensive_evaluate_model(model, val_loader, config)
    print(f"  📊 Final - Loss: {final_eval['val_loss']:.4f}, "
          f"Acc: {final_eval['val_accuracy']:.4f}, PPL: {final_eval['val_perplexity']:.2f}")

    # Cleanup
    del model
    torch.cuda.empty_cache()

    return tracker, {
        'training_time': training_time,
        'final_metrics': final_eval,
        'best_val_loss': best_val_loss,
        'total_params': total_params,
        'steps_completed': step
    }

def run_comprehensive_ablation(num_runs: int = 3):
    """Run multiple experiments and average results"""

    print(f"🚀 COMPREHENSIVE ABLATION: {num_runs} runs per configuration")
    print("="*80)
    print("📊 Using optimal learning rates:")
    print("   AdamW: 0.003")
    print("   Muon:  0.01")
    print("🔬 Enhanced with: gradient accumulation, dropout, longer training")
    print("="*80)

    model_configs = improved_model_configs()
    all_results = {}

    for model_name, config in model_configs.items():
        print(f"\n{'='*80}")
        print(f"🔬 TESTING {model_name.upper()} MODEL")
        print(f"   Architecture: {config.d_model}d, {config.n_layers}L, {config.n_heads}H, {config.d_ff}ff")
        print(f"   Training: {config.max_steps} steps, batch size {config.batch_size}")
        print(f"   Data: {config.max_tokens:,} tokens, seq_len {config.max_seq_len}")
        print(f"{'='*80}")

        # Load data once per model
        texts, tokenizer, tokens = load_and_cache_data(config)
        dataset = TextTokenDataset(tokens, config.max_seq_len)

        # Fixed train/val split
        val_size = len(dataset) // 10
        train_size = len(dataset) - val_size
        train_dataset, val_dataset = torch.utils.data.random_split(
            dataset, [train_size, val_size],
            generator=torch.Generator().manual_seed(42)
        )

        train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True, num_workers=2)
        val_loader = DataLoader(val_dataset, batch_size=config.batch_size, shuffle=False, num_workers=2)

        model_results = {'adamw': [], 'muon': []}

        # Multiple runs per optimizer
        for optimizer_type in ['adamw', 'muon']:
            for run_id in range(num_runs):
                print(f"\n📊 {optimizer_type.upper()} Run {run_id+1}/{num_runs}")

                tracker, run_results = improved_train_model(
                    optimizer_type, config, train_loader, val_loader,
                    model_name, run_id
                )

                model_results[optimizer_type].append({
                    'tracker': tracker,
                    **run_results
                })

        all_results[model_name] = {
            'config': config,
            'results': model_results
        }

    return all_results

def generate_comprehensive_plots(all_results: Dict, results_dir: str):
    """Generate comprehensive ablation plots with multiple runs"""

    plt.style.use('default')
    sns.set_palette("husl")

    model_names = list(all_results.keys())

    # 1. Training Loss Curves with Error Bars
    fig, axes = plt.subplots(2, len(model_names), figsize=(6*len(model_names), 10))
    if len(model_names) == 1:
        axes = axes.reshape(-1, 1)

    fig.suptitle('Training Metrics: Muon vs AdamW (Mean ± Std)', fontsize=16, fontweight='bold')

    for idx, (model_name, model_data) in enumerate(all_results.items()):
        # Training Loss
        ax_loss = axes[0, idx]
        ax_acc = axes[1, idx]

        for optimizer_type in ['adamw', 'muon']:
            runs = model_data['results'][optimizer_type]
            color = 'blue' if optimizer_type == 'adamw' else 'red'

            # Collect all runs data
            all_loss_curves = []
            all_acc_curves = []

            for run in runs:
                tracker = run['tracker']
                if 'train_loss' in tracker.metrics:
                    steps, losses = zip(*tracker.metrics['train_loss'])
                    all_loss_curves.append((steps, losses))
                if 'train_accuracy' in tracker.metrics:
                    steps, accs = zip(*tracker.metrics['train_accuracy'])
                    all_acc_curves.append((steps, accs))

            # Plot mean with std bands
            if all_loss_curves:
                # Find common steps
                min_len = min(len(curve[1]) for curve in all_loss_curves)
                common_steps = all_loss_curves[0][0][:min_len]
                loss_matrix = np.array([curve[1][:min_len] for curve in all_loss_curves])

                mean_loss = np.mean(loss_matrix, axis=0)
                std_loss = np.std(loss_matrix, axis=0)

                ax_loss.plot(common_steps, mean_loss, color=color, linewidth=2, label=f'{optimizer_type.upper()}')
                ax_loss.fill_between(common_steps, mean_loss - std_loss, mean_loss + std_loss,
                                   color=color, alpha=0.2)

            if all_acc_curves:
                min_len = min(len(curve[1]) for curve in all_acc_curves)
                common_steps = all_acc_curves[0][0][:min_len]
                acc_matrix = np.array([curve[1][:min_len] for curve in all_acc_curves])

                mean_acc = np.mean(acc_matrix, axis=0)
                std_acc = np.std(acc_matrix, axis=0)

                ax_acc.plot(common_steps, mean_acc, color=color, linewidth=2, label=f'{optimizer_type.upper()}')
                ax_acc.fill_between(common_steps, mean_acc - std_acc, mean_acc + std_acc,
                                  color=color, alpha=0.2)

        config = model_data['config']
        total_params = model_data['results']['adamw'][0]['total_params']

        ax_loss.set_title(f'{model_name} Training Loss\n({total_params:,} params)')
        ax_loss.set_xlabel('Steps')
        ax_loss.set_ylabel('Training Loss')
        ax_loss.set_yscale('log')
        ax_loss.legend()
        ax_loss.grid(True, alpha=0.3)

        ax_acc.set_title(f'{model_name} Training Accuracy')
        ax_acc.set_xlabel('Steps')
        ax_acc.set_ylabel('Training Accuracy')
        ax_acc.legend()
        ax_acc.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(f'{results_dir}/training_curves_with_uncertainty.png', dpi=300, bbox_inches='tight')
    plt.close()

    # 2. Final Performance Comparison with Error Bars
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    fig.suptitle('Final Performance: Muon vs AdamW (Mean ± Std)', fontsize=16, fontweight='bold')

    # Collect aggregated metrics
    model_params = []
    adamw_metrics = {'val_loss': [], 'val_accuracy': [], 'val_perplexity': [], 'training_time': []}
    muon_metrics = {'val_loss': [], 'val_accuracy': [], 'val_perplexity': [], 'training_time': []}
    adamw_stds = {'val_loss': [], 'val_accuracy': [], 'val_perplexity': [], 'training_time': []}
    muon_stds = {'val_loss': [], 'val_accuracy': [], 'val_perplexity': [], 'training_time': []}

    for model_name, model_data in all_results.items():
        model_params.append(model_data['results']['adamw'][0]['total_params'])

        for optimizer_type in ['adamw', 'muon']:
            runs = model_data['results'][optimizer_type]

            # Collect metrics from all runs
            val_losses = [run['final_metrics']['val_loss'] for run in runs]
            val_accs = [run['final_metrics']['val_accuracy'] for run in runs]
            val_ppls = [run['final_metrics']['val_perplexity'] for run in runs]
            times = [run['training_time'] for run in runs]

            metrics_dict = adamw_metrics if optimizer_type == 'adamw' else muon_metrics
            stds_dict = adamw_stds if optimizer_type == 'adamw' else muon_stds

            metrics_dict['val_loss'].append(np.mean(val_losses))
            metrics_dict['val_accuracy'].append(np.mean(val_accs))
            metrics_dict['val_perplexity'].append(np.mean(val_ppls))
            metrics_dict['training_time'].append(np.mean(times))

            stds_dict['val_loss'].append(np.std(val_losses))
            stds_dict['val_accuracy'].append(np.std(val_accs))
            stds_dict['val_perplexity'].append(np.std(val_ppls))
            stds_dict['training_time'].append(np.std(times))

    # Plot with error bars
    axes[0,0].errorbar(model_params, adamw_metrics['val_loss'], yerr=adamw_stds['val_loss'],
                      fmt='o-', label='AdamW', color='blue', linewidth=2, markersize=8, capsize=5)
    axes[0,0].errorbar(model_params, muon_metrics['val_loss'], yerr=muon_stds['val_loss'],
                      fmt='s-', label='Muon', color='red', linewidth=2, markersize=8, capsize=5)
    axes[0,0].set_title('Validation Loss vs Model Size')
    axes[0,0].set_xlabel('Parameters')
    axes[0,0].set_ylabel('Validation Loss')
    axes[0,0].set_xscale('log')
    axes[0,0].set_yscale('log')
    axes[0,0].legend()
    axes[0,0].grid(True, alpha=0.3)

    axes[0,1].errorbar(model_params, adamw_metrics['val_accuracy'], yerr=adamw_stds['val_accuracy'],
                      fmt='o-', label='AdamW', color='blue', linewidth=2, markersize=8, capsize=5)
    axes[0,1].errorbar(model_params, muon_metrics['val_accuracy'], yerr=muon_stds['val_accuracy'],
                      fmt='s-', label='Muon', color='red', linewidth=2, markersize=8, capsize=5)
    axes[0,1].set_title('Validation Accuracy vs Model Size')
    axes[0,1].set_xlabel('Parameters')
    axes[0,1].set_ylabel('Validation Accuracy')
    axes[0,1].set_xscale('log')
    axes[0,1].legend()
    axes[0,1].grid(True, alpha=0.3)

    axes[1,0].errorbar(model_params, adamw_metrics['val_perplexity'], yerr=adamw_stds['val_perplexity'],
                      fmt='o-', label='AdamW', color='blue', linewidth=2, markersize=8, capsize=5)
    axes[1,0].errorbar(model_params, muon_metrics['val_perplexity'], yerr=muon_stds['val_perplexity'],
                      fmt='s-', label='Muon', color='red', linewidth=2, markersize=8, capsize=5)
    axes[1,0].set_title('Validation Perplexity vs Model Size')
    axes[1,0].set_xlabel('Parameters')
    axes[1,0].set_ylabel('Validation Perplexity')
    axes[1,0].set_xscale('log')
    axes[1,0].set_yscale('log')
    axes[1,0].legend()
    axes[1,0].grid(True, alpha=0.3)

    axes[1,1].errorbar(model_params, adamw_metrics['training_time'], yerr=adamw_stds['training_time'],
                      fmt='o-', label='AdamW', color='blue', linewidth=2, markersize=8, capsize=5)
    axes[1,1].errorbar(model_params, muon_metrics['training_time'], yerr=muon_stds['training_time'],
                      fmt='s-', label='Muon', color='red', linewidth=2, markersize=8, capsize=5)
    axes[1,1].set_title('Training Time vs Model Size')
    axes[1,1].set_xlabel('Parameters')
    axes[1,1].set_ylabel('Training Time (seconds)')
    axes[1,1].set_xscale('log')
    axes[1,1].legend()
    axes[1,1].grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(f'{results_dir}/final_performance_comparison.png', dpi=300, bbox_inches='tight')
    plt.close()

def generate_comprehensive_report(all_results: Dict, results_dir: str, num_runs: int):
    """Generate comprehensive ablation report with statistical analysis"""

    report = []
    report.append("=" * 80)
    report.append("COMPREHENSIVE MODEL SIZE ABLATION: MUON vs ADAMW")
    report.append("=" * 80)
    report.append(f"Generated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    report.append(f"Number of runs per configuration: {num_runs}")
    report.append("")
    report.append("EXPERIMENTAL SETUP")
    report.append("-" * 40)
    report.append("Enhanced experimental setup with:")
    report.append("  • Optimal learning rates (AdamW: 0.003, Muon: 0.01)")
    report.append("  • Gradient accumulation (4 steps)")
    report.append("  • Dropout regularization (0.1)")
    report.append("  • Longer sequences (512 tokens)")
    report.append("  • Larger dataset (500k tokens, 2000 documents)")
    report.append("  • Extended training (6k-12k steps)")
    report.append("  • Multiple runs for statistical significance")
    report.append("")

    # Model configurations
    report.append("Model Configurations:")
    for model_name, model_data in all_results.items():
        config = model_data['config']
        total_params = model_data['results']['adamw'][0]['total_params']
        report.append(f"  {model_name}: {config.d_model}d, {config.n_layers}L, {config.n_heads}H, {config.d_ff}ff")
        report.append(f"    Parameters: {total_params:,}")
        report.append(f"    Training: {config.max_steps} steps, batch size {config.batch_size}")
    report.append("")

    # Detailed results for each model
    all_improvements = {'acc': [], 'loss': [], 'ppl': [], 'time': []}

    for model_name, model_data in all_results.items():
        report.append(f"{'='*60}")
        report.append(f"{model_name.upper()} MODEL RESULTS")
        report.append(f"{'='*60}")

        config = model_data['config']
        total_params = model_data['results']['adamw'][0]['total_params']

        report.append(f"Architecture: {config.d_model}d, {config.n_layers}L, {config.n_heads}H, {config.d_ff}ff")
        report.append(f"Parameters: {total_params:,}")
        report.append(f"Training steps: {config.max_steps}")
        report.append("")

        # Aggregate metrics across runs
        adamw_runs = model_data['results']['adamw']
        muon_runs = model_data['results']['muon']

        # Calculate means and stds
        adamw_metrics = {
            'val_loss': [run['final_metrics']['val_loss'] for run in adamw_runs],
            'val_accuracy': [run['final_metrics']['val_accuracy'] for run in adamw_runs],
            'val_perplexity': [run['final_metrics']['val_perplexity'] for run in adamw_runs],
            'training_time': [run['training_time'] for run in adamw_runs]
        }

        muon_metrics = {
            'val_loss': [run['final_metrics']['val_loss'] for run in muon_runs],
            'val_accuracy': [run['final_metrics']['val_accuracy'] for run in muon_runs],
            'val_perplexity': [run['final_metrics']['val_perplexity'] for run in muon_runs],
            'training_time': [run['training_time'] for run in muon_runs]
        }

        report.append("FINAL PERFORMANCE METRICS (Mean ± Std)")
        report.append("-" * 45)
        report.append(f"                    AdamW              Muon               Δ (p-value)")

        # Statistical tests and reporting
        for metric_key, metric_name in [('val_loss', 'Val Loss'), ('val_accuracy', 'Val Accuracy'),
                                       ('val_perplexity', 'Val Perplexity'), ('training_time', 'Training Time')]:
            adamw_values = adamw_metrics[metric_key]
            muon_values = muon_metrics[metric_key]

            adamw_mean, adamw_std = np.mean(adamw_values), np.std(adamw_values)
            muon_mean, muon_std = np.mean(muon_values), np.std(muon_values)

            # Statistical test
            if len(adamw_values) > 1 and len(muon_values) > 1:
                t_stat, p_value = stats.ttest_ind(muon_values, adamw_values)
                significance = "***" if p_value < 0.001 else "**" if p_value < 0.01 else "*" if p_value < 0.05 else ""
            else:
                p_value = 1.0
                significance = ""

            # Improvement calculation
            if metric_key in ['val_loss', 'val_perplexity', 'training_time']:
                improvement = (adamw_mean - muon_mean) / adamw_mean * 100  # Lower is better
            else:
                improvement = (muon_mean - adamw_mean) / adamw_mean * 100  # Higher is better

            report.append(f"{metric_name:15s}: {adamw_mean:6.4f}±{adamw_std:5.4f}   {muon_mean:6.4f}±{muon_std:5.4f}   {improvement:+6.1f}% (p={p_value:.3f}){significance}")

            # Store for overall analysis
            if metric_key == 'val_accuracy':
                all_improvements['acc'].append(improvement)
            elif metric_key == 'val_loss':
                all_improvements['loss'].append(improvement)
            elif metric_key == 'val_perplexity':
                all_improvements['ppl'].append(improvement)
            elif metric_key == 'training_time':
                all_improvements['time'].append(improvement)

        report.append("")

        # Winner determination
        muon_acc_mean = np.mean(muon_metrics['val_accuracy'])
        adamw_acc_mean = np.mean(adamw_metrics['val_accuracy'])

        if muon_acc_mean > adamw_acc_mean:
            report.append("🏆 WINNER: Muon (higher validation accuracy)")
        else:
            report.append("🏆 WINNER: AdamW (higher validation accuracy)")
        report.append("")

    # Overall analysis
    report.append("=" * 80)
    report.append("OVERALL ANALYSIS")
    report.append("=" * 80)

    # Count wins
    muon_wins = sum(1 for imp in all_improvements['acc'] if imp > 0)
    total_models = len(all_improvements['acc'])

    report.append("MUON vs ADAMW PERFORMANCE")
    report.append("-" * 35)
    report.append(f"Validation Accuracy Wins: Muon {muon_wins}/{total_models}, AdamW {total_models-muon_wins}/{total_models}")
    report.append("")

    report.append("AVERAGE IMPROVEMENTS (Muon vs AdamW)")
    report.append("-" * 40)

    for metric, values in all_improvements.items():
        metric_name = {'acc': 'Validation Accuracy', 'loss': 'Validation Loss',
                      'ppl': 'Validation Perplexity', 'time': 'Training Time'}[metric]

        mean_imp = np.mean(values)
        std_imp = np.std(values)

        # Statistical significance test
        if len(values) > 1:
            t_stat, p_value = stats.ttest_1samp(values, 0)
            significance = "***" if p_value < 0.001 else "**" if p_value < 0.01 else "*" if p_value < 0.05 else ""
        else:
            p_value = 1.0
            significance = ""

        report.append(f"{metric_name:20s}: {mean_imp:+6.2f}% ± {std_imp:5.2f}% (p={p_value:.3f}){significance}")

    report.append("")

    # Key findings
    report.append("KEY FINDINGS")
    report.append("-" * 20)

    acc_improvement = np.mean(all_improvements['acc'])
    loss_improvement = np.mean(all_improvements['loss'])
    time_improvement = np.mean(all_improvements['time'])

    if acc_improvement > 0:
        report.append("✓ Muon consistently outperforms AdamW in validation accuracy")
    else:
        report.append("✗ AdamW outperforms Muon in validation accuracy")

    if loss_improvement > 0:
        report.append("✓ Muon achieves lower validation loss than AdamW")
    else:
        report.append("✗ AdamW achieves lower validation loss than Muon")

    if abs(time_improvement) < 5:
        report.append("≈ Similar training times between optimizers")
    elif time_improvement > 0:
        report.append("⚡ Muon is faster than AdamW")
    else:
        report.append("⚠ Muon is slower than AdamW")

    # Statistical significance summary
    report.append("")
    report.append("STATISTICAL SIGNIFICANCE")
    report.append("-" * 30)

    significant_metrics = []
    for metric, values in all_improvements.items():
        if len(values) > 1:
            _, p_value = stats.ttest_1samp(values, 0)
            if p_value < 0.05:
                metric_name = {'acc': 'Accuracy', 'loss': 'Loss', 'ppl': 'Perplexity', 'time': 'Time'}[metric]
                significant_metrics.append(metric_name)

    if significant_metrics:
        report.append(f"Significant improvements (p < 0.05): {', '.join(significant_metrics)}")
    else:
        report.append("No statistically significant differences found")

    # Scaling behavior
    if len(all_improvements['acc']) >= 2:
        report.append("")
        report.append("SCALING BEHAVIOR")
        report.append("-" * 20)

        # Check correlation with model size
        model_params = [all_results[name]['results']['adamw'][0]['total_params'] for name in all_results.keys()]
        correlation, p_value = stats.pearsonr(model_params, all_improvements['acc'])

        if correlation > 0.5 and p_value < 0.05:
            report.append("📈 Muon's advantage significantly increases with model size")
        elif correlation < -0.5 and p_value < 0.05:
            report.append("📉 Muon's advantage significantly decreases with model size")
        else:
            report.append("📊 No clear scaling trend observed")

    # Save report
    with open(f'{results_dir}/comprehensive_ablation_report.txt', 'w') as f:
        f.write('\n'.join(report))

def save_comprehensive_results(all_results: Dict, results_dir: str, num_runs: int):
    """Save all results and generate comprehensive analysis"""

    # Create results directory
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    full_results_dir = f"{results_dir}/comprehensive_ablation_{timestamp}"
    os.makedirs(full_results_dir, exist_ok=True)

    print(f"\n💾 Saving comprehensive results to {full_results_dir}")

    # Generate plots
    generate_comprehensive_plots(all_results, full_results_dir)

    # Generate report
    generate_comprehensive_report(all_results, full_results_dir, num_runs)

    # Save raw data
    raw_data = {}
    for model_name, model_data in all_results.items():
        raw_data[model_name] = {
            'config': {
                'd_model': model_data['config'].d_model,
                'n_layers': model_data['config'].n_layers,
                'n_heads': model_data['config'].n_heads,
                'd_ff': model_data['config'].d_ff,
                'max_steps': model_data['config'].max_steps,
                'batch_size': model_data['config'].batch_size,
                'gradient_accumulation_steps': model_data['config'].gradient_accumulation_steps,
                'dropout': model_data['config'].dropout,
                'max_seq_len': model_data['config'].max_seq_len,
                'max_tokens': model_data['config'].max_tokens,
            },
            'results': {}
        }

        for optimizer_type in ['adamw', 'muon']:
            runs_data = []
            for run in model_data['results'][optimizer_type]:
                runs_data.append({
                    'training_time': run['training_time'],
                    'final_metrics': run['final_metrics'],
                    'best_val_loss': run['best_val_loss'],
                    'total_params': run['total_params'],
                    'steps_completed': run['steps_completed'],
                    'metrics': run['tracker'].metrics,
                    'memory_usage': run['tracker'].memory_usage
                })
            raw_data[model_name]['results'][optimizer_type] = runs_data

    with open(f'{full_results_dir}/comprehensive_raw_data.json', 'w') as f:
        json.dump(raw_data, f, indent=2, default=str)

    print(f"✅ Comprehensive results saved to {full_results_dir}")
    return full_results_dir

if __name__ == "__main__":
    # Check system
    print(f"🔍 Device: {'CUDA' if torch.cuda.is_available() else 'CPU'}")
    if torch.cuda.is_available():
        print(f"GPU: {torch.cuda.get_device_name()}")
        print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

    # Set global seed
    set_seed(42)

    # Configuration
    NUM_RUNS = 2  # Number of runs per configuration

    print(f"\n🚀 Starting comprehensive model size ablation with {NUM_RUNS} runs per configuration")
    print("⏱️ Estimated time: 2-4 hours (depending on hardware)")
    print("💾 Results will be automatically saved with timestamps")

    # Run comprehensive ablation
    start_time = time.time()
    results = run_comprehensive_ablation(num_runs=NUM_RUNS)
    total_time = time.time() - start_time

    # Save results
    results_dir = save_comprehensive_results(results, "results", NUM_RUNS)

    print(f"\n🎉 COMPREHENSIVE ABLATION COMPLETED!")
    print(f"⏱️ Total time: {total_time/3600:.1f} hours")
    print(f"📊 Results saved to: {results_dir}")
    print("✅ Check the generated plots and comprehensive report for detailed analysis")
    print("🔬 All data cached for future analysis")

    # Quick summary
    print(f"\n📋 QUICK SUMMARY:")
    for model_name, model_data in results.items():
        adamw_acc = np.mean([run['final_metrics']['val_accuracy'] for run in model_data['results']['adamw']])
        muon_acc = np.mean([run['final_metrics']['val_accuracy'] for run in model_data['results']['muon']])
        improvement = (muon_acc - adamw_acc) / adamw_acc * 100
        winner = "Muon" if muon_acc > adamw_acc else "AdamW"
        print(f"  {model_name}: {winner} wins ({improvement:+.2f}% accuracy improvement)")

🔍 Device: CUDA
GPU: NVIDIA GeForce RTX 4090
Memory: 25.3 GB
🌱 Set all seeds to 42

🚀 Starting comprehensive model size ablation with 2 runs per configuration
⏱️ Estimated time: 2-4 hours (depending on hardware)
💾 Results will be automatically saved with timestamps
🚀 COMPREHENSIVE ABLATION: 2 runs per configuration
📊 Using optimal learning rates:
   AdamW: 0.003
   Muon:  0.01
🔬 Enhanced with: gradient accumulation, dropout, longer training

🔬 TESTING TINY MODEL
   Architecture: 192d, 4L, 6H, 768ff
   Training: 6000 steps, batch size 32
   Data: 500,000 tokens, seq_len 512
📦 Loading cached data from data_cache/tokenized_data_2000_500000.pkl
✅ Loaded 2000 documents, 500,000 tokens from cache

📊 ADAMW Run 1/2

🚀 Training ADAMW on Tiny (Run 1)
🌱 Set all seeds to 42
  📊 Total parameters: 11,208,384


ADAMW: 100%|██████████| 6000/6000 [04:34<00:00, 21.84it/s, loss=1.7671, acc=0.571, ppl=5.9, lr=2.72e-03]    

  ⏱️ Training completed in 274.8 seconds





  📊 Final - Loss: 0.8373, Acc: 0.7875, PPL: 2.31

📊 ADAMW Run 2/2

🚀 Training ADAMW on Tiny (Run 2)
🌱 Set all seeds to 1042
  📊 Total parameters: 11,208,384


ADAMW: 100%|██████████| 6000/6000 [04:33<00:00, 21.97it/s, loss=1.7058, acc=0.576, ppl=5.5, lr=2.72e-03]    

  ⏱️ Training completed in 273.1 seconds





  📊 Final - Loss: 0.7851, Acc: 0.8006, PPL: 2.19

📊 MUON Run 1/2

🚀 Training MUON on Tiny (Run 1)
🌱 Set all seeds to 42
  📊 Total parameters: 11,208,384
  Muon parameters: 1,769,472
  AdamW parameters: 9,438,912


MUON: 100%|██████████| 6000/6000 [04:45<00:00, 21.00it/s, loss=1.8745, acc=0.569, ppl=6.5, lr=9.07e-03]    

  ⏱️ Training completed in 285.7 seconds





  📊 Final - Loss: 0.9444, Acc: 0.7830, PPL: 2.57

📊 MUON Run 2/2

🚀 Training MUON on Tiny (Run 2)
🌱 Set all seeds to 1042
  📊 Total parameters: 11,208,384
  Muon parameters: 1,769,472
  AdamW parameters: 9,438,912


MUON: 100%|██████████| 6000/6000 [04:41<00:00, 21.31it/s, loss=1.8783, acc=0.568, ppl=6.5, lr=9.07e-03]    

  ⏱️ Training completed in 281.5 seconds





  📊 Final - Loss: 0.9423, Acc: 0.7841, PPL: 2.57

🔬 TESTING SMALL MODEL
   Architecture: 384d, 6L, 8H, 1536ff
   Training: 5000 steps, batch size 24
   Data: 500,000 tokens, seq_len 512
📦 Loading cached data from data_cache/tokenized_data_2000_500000.pkl
✅ Loaded 2000 documents, 500,000 tokens from cache

📊 ADAMW Run 1/2

🚀 Training ADAMW on Small (Run 1)
🌱 Set all seeds to 42
  📊 Total parameters: 29,496,192


ADAMW: 100%|██████████| 5000/5000 [04:28<00:00, 18.65it/s, loss=0.5567, acc=0.850, ppl=1.7, lr=2.72e-03]    

  ⏱️ Training completed in 268.2 seconds





  📊 Final - Loss: 0.1948, Acc: 0.9494, PPL: 1.22

📊 ADAMW Run 2/2

🚀 Training ADAMW on Small (Run 2)
🌱 Set all seeds to 1042
  📊 Total parameters: 29,496,192


ADAMW: 100%|██████████| 5000/5000 [04:28<00:00, 18.63it/s, loss=0.6559, acc=0.826, ppl=1.9, lr=2.72e-03]    

  ⏱️ Training completed in 268.3 seconds





  📊 Final - Loss: 0.1930, Acc: 0.9497, PPL: 1.21

📊 MUON Run 1/2

🚀 Training MUON on Small (Run 1)
🌱 Set all seeds to 42
  📊 Total parameters: 29,496,192
  Muon parameters: 10,616,832
  AdamW parameters: 18,879,360


MUON: 100%|██████████| 5000/5000 [04:37<00:00, 17.99it/s, loss=0.5100, acc=0.874, ppl=1.7, lr=9.07e-03]    

  ⏱️ Training completed in 277.9 seconds





  📊 Final - Loss: 0.1582, Acc: 0.9626, PPL: 1.17

📊 MUON Run 2/2

🚀 Training MUON on Small (Run 2)
🌱 Set all seeds to 1042
  📊 Total parameters: 29,496,192
  Muon parameters: 10,616,832
  AdamW parameters: 18,879,360


MUON: 100%|██████████| 5000/5000 [04:38<00:00, 17.98it/s, loss=0.5751, acc=0.858, ppl=1.8, lr=9.07e-03]    

  ⏱️ Training completed in 278.1 seconds





  📊 Final - Loss: 0.1593, Acc: 0.9623, PPL: 1.17

🔬 TESTING MEDIUM MODEL
   Architecture: 512d, 8L, 8H, 2048ff
   Training: 4000 steps, batch size 16
   Data: 500,000 tokens, seq_len 512
📦 Loading cached data from data_cache/tokenized_data_2000_500000.pkl
✅ Loaded 2000 documents, 500,000 tokens from cache

📊 ADAMW Run 1/2

🚀 Training ADAMW on Medium (Run 1)
🌱 Set all seeds to 42
  📊 Total parameters: 50,340,352


ADAMW: 100%|██████████| 4000/4000 [03:23<00:00, 19.68it/s, loss=0.7908, acc=0.790, ppl=2.2, lr=2.72e-03]   

  ⏱️ Training completed in 203.3 seconds





  📊 Final - Loss: 0.3210, Acc: 0.9143, PPL: 1.38

📊 ADAMW Run 2/2

🚀 Training ADAMW on Medium (Run 2)
🌱 Set all seeds to 1042
  📊 Total parameters: 50,340,352


ADAMW: 100%|██████████| 4000/4000 [03:22<00:00, 19.73it/s, loss=0.7070, acc=0.812, ppl=2.0, lr=2.72e-03]   

  ⏱️ Training completed in 202.7 seconds





  📊 Final - Loss: 0.2986, Acc: 0.9204, PPL: 1.35

📊 MUON Run 1/2

🚀 Training MUON on Medium (Run 1)
🌱 Set all seeds to 42
  📊 Total parameters: 50,340,352
  Muon parameters: 25,165,824
  AdamW parameters: 25,174,528


MUON: 100%|██████████| 4000/4000 [03:31<00:00, 18.91it/s, loss=0.4998, acc=0.882, ppl=1.6, lr=9.08e-03]    

  ⏱️ Training completed in 211.5 seconds





  📊 Final - Loss: 0.1644, Acc: 0.9604, PPL: 1.18

📊 MUON Run 2/2

🚀 Training MUON on Medium (Run 2)
🌱 Set all seeds to 1042
  📊 Total parameters: 50,340,352
  Muon parameters: 25,165,824
  AdamW parameters: 25,174,528


MUON: 100%|██████████| 4000/4000 [03:31<00:00, 18.91it/s, loss=0.4573, acc=0.890, ppl=1.6, lr=9.08e-03]    

  ⏱️ Training completed in 211.5 seconds





  📊 Final - Loss: 0.1575, Acc: 0.9624, PPL: 1.17

🔬 TESTING LARGE MODEL
   Architecture: 768d, 10L, 16H, 3072ff
   Training: 3000 steps, batch size 12
   Data: 500,000 tokens, seq_len 512
📦 Loading cached data from data_cache/tokenized_data_2000_500000.pkl
✅ Loaded 2000 documents, 500,000 tokens from cache

📊 ADAMW Run 1/2

🚀 Training ADAMW on Large (Run 1)
🌱 Set all seeds to 42
  📊 Total parameters: 108,543,744


ADAMW: 100%|██████████| 3000/3000 [03:25<00:00, 14.61it/s, loss=4.1460, acc=0.249, ppl=63.2, lr=2.73e-03]  

  ⏱️ Training completed in 205.3 seconds





  📊 Final - Loss: 3.8141, Acc: 0.2791, PPL: 45.33

📊 ADAMW Run 2/2

🚀 Training ADAMW on Large (Run 2)
🌱 Set all seeds to 1042
  📊 Total parameters: 108,543,744


ADAMW: 100%|██████████| 3000/3000 [03:25<00:00, 14.62it/s, loss=4.1773, acc=0.247, ppl=65.2, lr=2.73e-03]  

  ⏱️ Training completed in 205.2 seconds





  📊 Final - Loss: 3.7074, Acc: 0.2878, PPL: 40.75

📊 MUON Run 1/2

🚀 Training MUON on Large (Run 1)
🌱 Set all seeds to 42
  📊 Total parameters: 108,543,744
  Muon parameters: 70,778,880
  AdamW parameters: 37,764,864


MUON: 100%|██████████| 3000/3000 [03:35<00:00, 13.91it/s, loss=0.5369, acc=0.858, ppl=1.7, lr=9.09e-03]    

  ⏱️ Training completed in 215.7 seconds





  📊 Final - Loss: 0.2260, Acc: 0.9453, PPL: 1.25

📊 MUON Run 2/2

🚀 Training MUON on Large (Run 2)
🌱 Set all seeds to 1042
  📊 Total parameters: 108,543,744
  Muon parameters: 70,778,880
  AdamW parameters: 37,764,864


MUON: 100%|██████████| 3000/3000 [03:35<00:00, 13.90it/s, loss=0.4284, acc=0.901, ppl=1.5, lr=9.09e-03]    

  ⏱️ Training completed in 215.8 seconds





  📊 Final - Loss: 0.2161, Acc: 0.9475, PPL: 1.24

💾 Saving comprehensive results to results/comprehensive_ablation_20250721_205356
✅ Comprehensive results saved to results/comprehensive_ablation_20250721_205356

🎉 COMPREHENSIVE ABLATION COMPLETED!
⏱️ Total time: 1.1 hours
📊 Results saved to: results/comprehensive_ablation_20250721_205356
✅ Check the generated plots and comprehensive report for detailed analysis
🔬 All data cached for future analysis

📋 QUICK SUMMARY:
  Tiny: AdamW wins (-1.33% accuracy improvement)
  Small: Muon wins (+1.36% accuracy improvement)
  Medium: Muon wins (+4.80% accuracy improvement)
  Large: Muon wins (+233.88% accuracy improvement)


### This code just fixes a bug when uploading this notebook to github, you can ignore it otherwise.

In [1]:
from IPython.display import display
from ipywidgets import Widget

# Close all widgets and clear their state
Widget.close_all()
