# Code & Train GPT-5 From Scratch - Step by Step

> **To train the model just run the cells below!**

---

GPT-5 is not open sourced, however, we can make an educated guess on how it's built:
- **GPT architecture + latest advancements in LLM pretraining**

If you need a reminder on GPT or base LLM architecture check:

---

🎓 **[🦙 LLaMA 4 From Scratch (first 2h 30min)](https://youtu.be/wcDV3l4CD14)**

> In the first **2h 30min**, I give a **clear and intuitive explanation** of both:

* 🧠 Attention Mechanism
* 🧩 Tokens & Tokenizers

Highly recommended if you're just starting out or want a solid refresher.

---

🎥 **[📘 GPT From Scratch by Andrej Karpathy](https://youtu.be/kCc8FmEb1nY)**

> A legendary tutorial from Karpathy that walks through building a GPT model from scratch. Great for understanding the fundamentals!
- If Andrej's course is too complex or difficult, then return to it after watching few more of the videos / courses below that you find more digestible.

---

## 🧠 Modern GPT architecture with latest advancements

⚙️ For attention mechanism there are 3 options:
1. Option 1: OpenAI invented its own attention mechanism, in this case it is almost certainly not too different from others, unless they compretely surpassed transformer architecture, which is unlikely to happen until 2026 or 2027. So their architecture is likely very similar to option 2 - GQA.

2. Option 2: Grouped-Query Attention (GQA) - memory and compute efficienct attention
- Explained in my course [Code & Train Qwen 3 From Scratch - Full Course](https://youtu.be/wM-KP_wNAeY) starting at 16:35

3. Option 3 (less likely): Multihead Latent Attention (MLA)
- Given that some of the latest LLMs are using Multihead Latent Attention (MLA) by DeepSeek, there is some chance GPT-5 also uses it instead of the GQA - you can learn how to code it in my [Code DeepSeek From Scratch Course](https://youtu.be/TfEG0TwueTs)
- Advantage of MLA over GQA is lower memory usage, however it's a bit more complex to build.
- I recommend learning all of these as it will help you understand neural networks, transformers and how to surpass transformers and invent the next architecture.

💡 Rotary Positional Embeddings (RoPE) for better performance and context window extrapolation
- 📌 **[Rotary Positional Embeddings & Rotation Matrix + Python LLM Code](https://youtu.be/wiJ-OU-URYg)**
- 🧠 **[Get SMARTER Than 99% of AI Researchers](https://youtu.be/X0JryI85hL0)** - Beginning part
- 🛠️ **[RoPE In DeepSeek V3 – Code Step by Step](https://youtu.be/Rs9tLDSMUkM)**
- 🏋️ **[Excercises with ChatGPT Chat](https://chatgpt.com/share/68945a01-8d48-8002-8cf0-04b7f6db744b)**

🚀 Muon optimizer using Newton-Schulz orthogonalization for better weight updates, faster learning with less data
- This is the new best optimizer for 2D matrices, while AdamW is used for other parts of LLM. Highly likely Muon is used in GPT-5 as [OpenAI's researcher is tied to its invention](https://kellerjordan.github.io/posts/muon/).
- 🔁 [Backpropagation From Scratch](https://youtu.be/W8g1hvW4Wic) — Understand gradients deeply
- 🧠 [Orthonormal Matrix Intuition](https://youtu.be/FbYRZpBgFz4) — Key concept behind Muon’s update step
- Search for "Muon" on YouTube, you will find more tutorials.

> For all other things (below and above), watch my [Code & Train Qwen 3 From Scratch - Full Course](https://youtu.be/wM-KP_wNAeY) - I built and trained it on modern GPT architecture with latest advancements

📐 QK-Norm with RMSNorm for improved numerical / training stability

🔁 Hybrid optimization using Muon for matrices and AdamW for other parameters

🔄 SwiGLU activation and deep residual learning in the feedforward layers

🔢 Efficient dataset tokenization and caching with HuggingFace Datasets and Transformers

🧪 Validation metrics including loss, accuracy, and perplexity

🧵 Gradient accumulation + AMP (Automatic Mixed Precision) training for larger batch sizes

🎛️ Cosine learning rate scheduling with warmup

Find more tutorials / courses on AI research and engineering on [my YouTube](https://www.youtube.com/channel/UC7XJj9pv_11a11FUxCMz15g).



# Training

`max_steps: int = 2000` determines the duration of training, 2000 steps is about 12 minutes

If it asks you for `HF_TOKEN`, you can just click `cancel` as it doesn't need it.

In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast, GradScaler
import math
import random
import numpy as np
from datasets import load_dataset
from tqdm import tqdm
import time
from transformers import AutoTokenizer
from dataclasses import dataclass
from typing import List, Optional
import warnings
import os
import pickle
warnings.filterwarnings('ignore')

def set_seed(seed: int = 42):
    """Set all random seeds for reproducibility"""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    print(f"🌱 Set all seeds to {seed}")

@dataclass
class ModelConfig:
    # Model architecture
    d_model: int = 384
    n_heads: int = 8
    n_layers: int = 6
    d_ff: int = 1536
    batch_size: int = 24
    max_steps: int = 2000

    # Qwen3-like parameters
    n_kv_heads: int = 4  # For Grouped-Query Attention
    sliding_window: int = 4096  # Set a large default, effectively disabling it unless specified
    attention_bias: bool = False  # Qwen3 often sets this to False
    rms_norm_eps: float = 1e-6  # Epsilon for RMSNorm

    # Training parameters
    gradient_accumulation_steps: int = 4
    muon_lr: float = 0.01

    # Data parameters
    max_seq_len: int = 512
    num_documents: int = 2000
    max_tokens: int = 500000

    # Evaluation
    eval_every: int = 500
    eval_steps: int = 100

    # Regularization
    weight_decay: float = 0.1
    dropout: float = 0.1
    grad_clip: float = 1.0

    # Technical
    use_amp: bool = True
    vocab_size: Optional[int] = None

    def __post_init__(self):
        self.d_k = self.d_model // self.n_heads
        assert self.d_model % self.n_heads == 0, "d_model must be divisible by n_heads"
        assert self.n_heads % self.n_kv_heads == 0, "n_heads must be divisible by n_kv_heads"
        self.n_kv_groups = self.n_heads // self.n_kv_heads

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep).
    The hidden states go from (batch, num_key_value_heads, seqlen, head_dim)
    to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)

@torch.compile
def zeropower_via_newtonschulz5(G: torch.Tensor, steps: int = 5) -> torch.Tensor:
    """Newton-Schulz iteration to compute the zeroth power / orthogonalization of G."""
    assert G.ndim >= 2
    a, b, c = (3.4445, -4.7750, 2.0315)
    X = G.bfloat16()

    if G.size(-2) > G.size(-1):
        X = X.mT

    X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)

    for _ in range(steps):
        A = X @ X.mT
        B = b * A + c * A @ A
        X = a * X + B @ X

    if G.size(-2) > G.size(-1):
        X = X.mT

    return X

class Muon(torch.optim.Optimizer):
    """Muon - MomentUm Orthogonalized by Newton-schulz"""
    def __init__(self, params, lr=0.02, momentum=0.95, nesterov=True, ns_steps=5):
        defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self):
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue

                g = p.grad
                state = self.state[p]

                if "momentum_buffer" not in state:
                    state["momentum_buffer"] = torch.zeros_like(g)

                buf = state["momentum_buffer"]
                buf.lerp_(g, 1 - group["momentum"])
                g = g.lerp_(buf, group["momentum"]) if group["nesterov"] else buf
                g = zeropower_via_newtonschulz5(g, steps=group["ns_steps"])
                p.add_(g.view_as(p), alpha=-group["lr"] * max(1, p.size(-2) / p.size(-1))**0.5)

def load_and_cache_data(config: ModelConfig, cache_dir: str = "data_cache"):
    """Load and cache tokenized data to avoid reprocessing"""
    os.makedirs(cache_dir, exist_ok=True)
    cache_file = f"{cache_dir}/tokenized_data_{config.num_documents}_{config.max_tokens}.pkl"

    # Check if cached data exists
    if os.path.exists(cache_file):
        print(f"📦 Loading cached data from {cache_file}")
        with open(cache_file, 'rb') as f:
            cached_data = pickle.load(f)

        texts = cached_data['texts']
        tokenizer = cached_data['tokenizer']
        tokens = cached_data['tokens']
        config.vocab_size = tokenizer.vocab_size

        print(f"✅ Loaded {len(texts)} documents, {len(tokens):,} tokens from cache")
        return texts, tokenizer, tokens

    print(f"🔄 Processing new data (will cache for future use)")

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM-135M", token=False)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Load dataset
    dataset = load_dataset("HuggingFaceTB/smollm-corpus", "cosmopedia-v2", split="train", streaming=True, token=False)

    texts = []
    for i, item in enumerate(dataset):
        if i >= config.num_documents:
            break
        texts.append(item["text"][:3000])

    print(f"Loaded {len(texts)} documents")

    # Tokenize
    print("Tokenizing texts...")
    all_tokens = []
    for text in tqdm(texts, desc="Tokenizing"):
        tokens = tokenizer.encode(text, add_special_tokens=False)
        all_tokens.extend(tokens)

    tokens = all_tokens[:config.max_tokens]
    print(f"Using {len(tokens):,} tokens")
    config.vocab_size = tokenizer.vocab_size

    # Cache the processed data
    cached_data = {'texts': texts, 'tokenizer': tokenizer, 'tokens': tokens}
    with open(cache_file, 'wb') as f:
        pickle.dump(cached_data, f)

    print(f"💾 Cached data to {cache_file}")
    return texts, tokenizer, tokens

class TextTokenDataset(Dataset):
    def __init__(self, tokens: List[int], seq_len: int = 512):
        self.tokens = tokens
        self.seq_len = seq_len

    def __len__(self):
        return max(0, len(self.tokens) - self.seq_len)

    def __getitem__(self, idx):
        x = torch.tensor(self.tokens[idx:idx + self.seq_len], dtype=torch.long)
        y = torch.tensor(self.tokens[idx + 1:idx + self.seq_len + 1], dtype=torch.long)
        return x, y

class Rotary(nn.Module):
    def __init__(self, dim: int, max_seq_len: int):
        super().__init__()
        angular_freq = (1 / 10000) ** torch.linspace(0, 1, steps=dim//4, dtype=torch.float32)
        angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(dim//4)])
        t = torch.arange(max_seq_len, dtype=torch.float32)
        theta = torch.einsum("i,j -> ij", t, angular_freq)
        self.register_buffer('cos', theta.cos(), persistent=False)
        self.register_buffer('sin', theta.sin(), persistent=False)

    def forward(self, x_BTHD: torch.Tensor):
        assert self.cos.size(0) >= x_BTHD.size(-3)
        cos, sin = self.cos[None, :x_BTHD.size(-3), None, :], self.sin[None, :x_BTHD.size(-3), None, :]
        x1, x2 = x_BTHD.to(dtype=torch.float32).chunk(2, dim=-1)
        y1 = x1 * cos + x2 * sin
        y2 = x1 * (-sin) + x2 * cos
        return torch.cat((y1, y2), 3).type_as(x_BTHD)

class Qwen3Attention(nn.Module):
    def __init__(self, config: ModelConfig):
        super().__init__()
        self.d_model = config.d_model
        self.n_heads = config.n_heads
        self.n_kv_heads = config.n_kv_heads
        self.n_kv_groups = config.n_kv_groups
        self.d_k = config.d_k

        # Separate linear layers for Q, K, V
        self.q_proj = nn.Linear(self.d_model, self.n_heads * self.d_k, bias=config.attention_bias)
        self.k_proj = nn.Linear(self.d_model, self.n_kv_heads * self.d_k, bias=config.attention_bias)
        self.v_proj = nn.Linear(self.d_model, self.n_kv_heads * self.d_k, bias=config.attention_bias)
        self.w_o = nn.Linear(self.d_model, self.d_model, bias=False)

        # QK-Normalization layers
        self.q_norm = nn.RMSNorm(self.d_k, eps=config.rms_norm_eps)
        self.k_norm = nn.RMSNorm(self.d_k, eps=config.rms_norm_eps)

        self.rotary = Rotary(self.d_k, config.max_seq_len)
        self.dropout = config.dropout

    def forward(self, x):
        batch_size, seq_len = x.size(0), x.size(1)

        # 1. Project Q, K, V separately
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)

        # 2. Reshape into heads
        q = q.view(batch_size, seq_len, self.n_heads, self.d_k)
        k = k.view(batch_size, seq_len, self.n_kv_heads, self.d_k)
        v = v.view(batch_size, seq_len, self.n_kv_heads, self.d_k)

        # 3. Apply QK-Norm
        q = self.q_norm(q)
        k = self.k_norm(k)

        # 4. Apply RoPE
        # Transpose to (batch, seq_len, n_heads, d_k) -> (batch, n_heads, seq_len, d_k) for rotary
        q = self.rotary(q.permute(0, 2, 1, 3)).permute(0, 2, 1, 3)
        k = self.rotary(k.permute(0, 2, 1, 3)).permute(0, 2, 1, 3)

        # Transpose for attention: (batch, seq_len, n_heads, d_k) -> (batch, n_heads, seq_len, d_k)
        Q = q.transpose(1, 2)
        K = k.transpose(1, 2)
        V = v.transpose(1, 2)

        # 5. Repeat K and V heads for GQA
        K = repeat_kv(K, self.n_kv_groups)
        V = repeat_kv(V, self.n_kv_groups)

        # 6. Scaled Dot-Product Attention
        attn_output = F.scaled_dot_product_attention(
            Q, K, V, is_causal=True, dropout_p=self.dropout if self.training else 0.0
        )

        # 7. Reshape and final projection
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        return self.w_o(attn_output)

class SwiGLUFeedForward(nn.Module):
    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.gate_proj = nn.Linear(d_model, d_ff, bias=False)
        self.down_proj = nn.Linear(d_ff, d_model, bias=False)
        self.up_proj = nn.Linear(d_model, d_ff, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # Implementation of the SwiGLU activation function
        # F.silu is the Swish activation function
        activated_x = F.silu(self.gate_proj(x)) * self.up_proj(x)
        return self.down_proj(self.dropout(activated_x))

class TransformerBlock(nn.Module):
    def __init__(self, config: ModelConfig):  # Pass the entire config object
        super().__init__()
        self.attention = Qwen3Attention(config)
        self.feed_forward = SwiGLUFeedForward(config.d_model, config.d_ff, config.dropout)
        self.norm1 = nn.RMSNorm(config.d_model, eps=config.rms_norm_eps)
        self.norm2 = nn.RMSNorm(config.d_model, eps=config.rms_norm_eps)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        attn_out = self.attention(self.norm1(x))
        x = x + self.dropout(attn_out)
        ff_out = self.feed_forward(self.norm2(x))
        x = x + self.dropout(ff_out)
        return x

class MinimalLLM(nn.Module):
    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config

        self.token_embedding = nn.Embedding(config.vocab_size, config.d_model)
        self.position_dropout = nn.Dropout(config.dropout)

        self.transformer_blocks = nn.ModuleList([
            TransformerBlock(config) for _ in range(config.n_layers)
        ])

        self.norm = nn.RMSNorm(config.d_model, eps=config.rms_norm_eps)
        self.output_dropout = nn.Dropout(config.dropout)

        # Tie weights
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        self.lm_head.weight = self.token_embedding.weight

        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, x):
        x = self.token_embedding(x) * math.sqrt(self.config.d_model)
        x = self.position_dropout(x)

        for block in self.transformer_blocks:
            x = block(x)

        x = self.norm(x)
        x = self.output_dropout(x)
        logits = self.lm_head(x)
        return logits

def evaluate_model(model: nn.Module, val_loader: DataLoader, config: ModelConfig):
    """Evaluate model performance"""
    model.eval()
    total_loss = 0
    total_tokens = 0
    total_correct = 0

    device = next(model.parameters()).device

    with torch.no_grad():
        for i, (x, y) in enumerate(val_loader):
            if i >= config.eval_steps:
                break
            x, y = x.to(device), y.to(device)

            with autocast(enabled=config.use_amp):
                logits = model(x)
                loss = F.cross_entropy(logits.view(-1, config.vocab_size), y.view(-1))

            total_loss += loss.item() * y.numel()
            total_tokens += y.numel()

            predictions = logits.argmax(dim=-1)
            total_correct += (predictions == y).sum().item()

    avg_loss = total_loss / total_tokens
    accuracy = total_correct / total_tokens
    perplexity = math.exp(min(avg_loss, 20))

    model.train()
    return {'val_loss': avg_loss, 'val_accuracy': accuracy, 'val_perplexity': perplexity}

def setup_muon_optimizer(model: nn.Module, config: ModelConfig):
    """Setup Muon optimizer with hybrid approach"""
    muon_params = []
    adamw_params = []

    for name, param in model.named_parameters():
        if (param.ndim == 2 and
            'token_embedding' not in name and
            'norm' not in name and
            param.requires_grad):
            muon_params.append(param)
        else:
            adamw_params.append(param)

    print(f"  Muon parameters: {sum(p.numel() for p in muon_params):,}")
    print(f"  AdamW parameters: {sum(p.numel() for p in adamw_params):,}")

    muon_optimizer = Muon(muon_params, lr=config.muon_lr, momentum=0.95)
    adamw_optimizer = torch.optim.AdamW(adamw_params, lr=config.muon_lr*0.1, weight_decay=config.weight_decay)

    return [muon_optimizer, adamw_optimizer]

def train_model(config: ModelConfig, train_loader: DataLoader, val_loader: DataLoader):
    """Train the model with Muon optimizer"""
    print(f"\n🚀 Training Small model with Muon optimizer")

    # Initialize model
    set_seed(42)
    model = MinimalLLM(config)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)

    total_params = sum(p.numel() for p in model.parameters())
    print(f"  📊 Total parameters: {total_params:,}")

    # Setup optimizers
    optimizers = setup_muon_optimizer(model, config)

    # Learning rate schedule
    schedulers = []
    for optimizer in optimizers:
        warmup_steps = config.max_steps // 20
        def lr_lambda(step):
            if step < warmup_steps:
                return step / warmup_steps
            else:
                progress = (step - warmup_steps) / (config.max_steps - warmup_steps)
                return 0.1 + 0.9 * 0.5 * (1 + math.cos(math.pi * progress))

        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
        schedulers.append(scheduler)

    scaler = GradScaler() if config.use_amp else None

    # Training loop
    model.train()
    step = 0
    start_time = time.time()
    best_val_loss = float('inf')

    pbar = tqdm(total=config.max_steps, desc="Training")

    while step < config.max_steps:
        for batch_idx, (x, y) in enumerate(train_loader):
            if step >= config.max_steps:
                break

            x, y = x.to(device), y.to(device)

            # Forward pass with gradient accumulation
            if config.use_amp:
                with autocast():
                    logits = model(x)
                    loss = F.cross_entropy(logits.view(-1, config.vocab_size), y.view(-1))
                    loss = loss / config.gradient_accumulation_steps
                scaler.scale(loss).backward()
            else:
                logits = model(x)
                loss = F.cross_entropy(logits.view(-1, config.vocab_size), y.view(-1))
                loss = loss / config.gradient_accumulation_steps
                loss.backward()

            # Optimizer step after accumulation
            if (step + 1) % config.gradient_accumulation_steps == 0:
                if config.use_amp:
                    for optimizer in optimizers:
                        scaler.unscale_(optimizer)
                    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip)

                    for optimizer in optimizers:
                        scaler.step(optimizer)
                        optimizer.zero_grad()
                    for scheduler in schedulers:
                        scheduler.step()
                    scaler.update()
                else:
                    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip)
                    for optimizer in optimizers:
                        optimizer.step()
                        optimizer.zero_grad()
                    for scheduler in schedulers:
                        scheduler.step()

            # Logging
            if step % 10 == 0:
                with torch.no_grad():
                    predictions = logits.argmax(dim=-1)
                    accuracy = (predictions == y).float().mean().item()
                    current_loss = loss.item() * config.gradient_accumulation_steps
                    perplexity = math.exp(min(current_loss, 20))

                pbar.set_postfix({
                    'loss': f'{current_loss:.4f}',
                    'acc': f'{accuracy:.3f}',
                    'ppl': f'{perplexity:.1f}',
                    'lr': f'{optimizers[0].param_groups[0]["lr"]:.2e}'
                })

            # Evaluation
            if step % config.eval_every == 0 and step > 0:
                eval_metrics = evaluate_model(model, val_loader, config)
                print(f"\nStep {step}: Val Loss: {eval_metrics['val_loss']:.4f}, "
                      f"Val Acc: {eval_metrics['val_accuracy']:.4f}, "
                      f"Val PPL: {eval_metrics['val_perplexity']:.2f}")

                if eval_metrics['val_loss'] < best_val_loss:
                    best_val_loss = eval_metrics['val_loss']
                    # Save best model
                    torch.save({
                        'model_state_dict': model.state_dict(),
                        'config': config,
                        'step': step,
                        'best_val_loss': best_val_loss,
                        'final_metrics': eval_metrics
                    }, 'best_model.pt')
                    print(f"💾 Saved best model with val_loss: {best_val_loss:.4f}")

            step += 1
            if step % 10 == 0:
                pbar.update(10)

    pbar.close()

    training_time = time.time() - start_time
    print(f"  ⏱️ Training completed in {training_time:.1f} seconds")

    # Final evaluation
    final_eval = evaluate_model(model, val_loader, config)
    print(f"  📊 Final - Loss: {final_eval['val_loss']:.4f}, "
          f"Acc: {final_eval['val_accuracy']:.4f}, PPL: {final_eval['val_perplexity']:.2f}")

    # Save final model
    torch.save({
        'model_state_dict': model.state_dict(),
        'config': config,
        'step': step,
        'final_metrics': final_eval
    }, 'final_model.pt')
    print(f"💾 Saved final model to final_model.pt")

    return model, final_eval

if __name__ == "__main__":
    # Check system
    print(f"🔍 Device: {'CUDA' if torch.cuda.is_available() else 'CPU'}")
    if torch.cuda.is_available():
        print(f"GPU: {torch.cuda.get_device_name()}")
        print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

    # Set seed
    set_seed(42)

    # Create config for Small model
    config = ModelConfig()
    print(f"\n📋 Model Configuration:")
    print(f"   Architecture: {config.d_model}d, {config.n_layers}L, {config.n_heads}H, {config.d_ff}ff")
    print(f"   Training: {config.max_steps} steps, batch size {config.batch_size}")
    print(f"   Data: {config.max_tokens:,} tokens, seq_len {config.max_seq_len}")

    # Load data
    texts, tokenizer, tokens = load_and_cache_data(config)
    dataset = TextTokenDataset(tokens, config.max_seq_len)

    # Train/val split
    val_size = len(dataset) // 10
    train_size = len(dataset) - val_size
    train_dataset, val_dataset = torch.utils.data.random_split(
        dataset, [train_size, val_size], generator=torch.Generator().manual_seed(42)
    )

    train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True, num_workers=2)
    val_loader = DataLoader(val_dataset, batch_size=config.batch_size, shuffle=False, num_workers=2)

    print(f"📊 Dataset: {len(train_dataset)} train, {len(val_dataset)} val samples")

    # Train model
    start_time = time.time()
    model, final_metrics = train_model(config, train_loader, val_loader)
    total_time = time.time() - start_time

    print(f"\n🎉 TRAINING COMPLETED!")
    print(f"⏱️ Total time: {total_time/60:.1f} minutes")
    print(f"🏆 Final Results:")
    print(f"   Validation Loss: {final_metrics['val_loss']:.4f}")
    print(f"   Validation Accuracy: {final_metrics['val_accuracy']:.4f}")
    print(f"   Validation Perplexity: {final_metrics['val_perplexity']:.2f}")

🔍 Device: CUDA
GPU: Tesla T4
Memory: 15.8 GB
🌱 Set all seeds to 42

📋 Model Configuration:
   Architecture: 384d, 6L, 8H, 1536ff
   Training: 2000 steps, batch size 24
   Data: 500,000 tokens, seq_len 512
📦 Loading cached data from data_cache/tokenized_data_2000_500000.pkl
✅ Loaded 2000 documents, 500,000 tokens from cache
📊 Dataset: 449540 train, 49948 val samples

🚀 Training Small model with Muon optimizer
🌱 Set all seeds to 42
  📊 Total parameters: 32,150,976
  Muon parameters: 13,271,040
  AdamW parameters: 18,879,936



Training:   0%|          | 0/2000 [00:00<?, ?it/s][A
Training:   0%|          | 0/2000 [00:00<?, ?it/s, loss=10.8028, acc=0.015, ppl=49156.2, lr=0.00e+00][A
Training:   0%|          | 10/2000 [00:04<14:15,  2.33it/s, loss=10.8028, acc=0.015, ppl=49156.2, lr=0.00e+00][A
Training:   0%|          | 10/2000 [00:04<14:15,  2.33it/s, loss=10.8007, acc=0.014, ppl=49055.0, lr=2.00e-04][A
Training:   1%|          | 20/2000 [00:08<13:18,  2.48it/s, loss=10.8007, acc=0.014, ppl=49055.0, lr=2.00e-04][A
Training:   1%|          | 20/2000 [00:08<13:18,  2.48it/s, loss=10.7808, acc=0.016, ppl=48088.1, lr=5.00e-04][A
Training:   2%|▏         | 30/2000 [00:11<12:11,  2.69it/s, loss=10.7808, acc=0.016, ppl=48088.1, lr=5.00e-04][A
Training:   2%|▏         | 30/2000 [00:12<12:11,  2.69it/s, loss=10.7480, acc=0.016, ppl=46535.3, lr=7.00e-04][A
Training:   2%|▏         | 40/2000 [00:15<12:15,  2.66it/s, loss=10.7480, acc=0.016, ppl=46535.3, lr=7.00e-04][A
Training:   2%|▏         | 40/2000 [00:15<


Step 500: Val Loss: 4.9941, Val Acc: 0.2329, Val PPL: 147.55
💾 Saved best model with val_loss: 4.9941



Training:  26%|██▌       | 510/2000 [03:13<18:02,  1.38it/s, loss=5.2135, acc=0.221, ppl=183.7, lr=1.00e-02][A
Training:  26%|██▌       | 510/2000 [03:14<18:02,  1.38it/s, loss=5.1443, acc=0.217, ppl=171.5, lr=1.00e-02][A
Training:  26%|██▌       | 520/2000 [03:17<15:38,  1.58it/s, loss=5.1443, acc=0.217, ppl=171.5, lr=1.00e-02][A
Training:  26%|██▌       | 520/2000 [03:18<15:38,  1.58it/s, loss=4.9569, acc=0.231, ppl=142.2, lr=9.99e-03][A
Training:  26%|██▋       | 530/2000 [03:21<13:39,  1.79it/s, loss=4.9569, acc=0.231, ppl=142.2, lr=9.99e-03][A
Training:  26%|██▋       | 530/2000 [03:22<13:39,  1.79it/s, loss=4.9854, acc=0.238, ppl=146.3, lr=9.99e-03][A
Training:  27%|██▋       | 540/2000 [03:25<12:14,  1.99it/s, loss=4.9854, acc=0.238, ppl=146.3, lr=9.99e-03][A
Training:  27%|██▋       | 540/2000 [03:25<12:14,  1.99it/s, loss=4.9169, acc=0.243, ppl=136.6, lr=9.99e-03][A
Training:  28%|██▊       | 550/2000 [03:28<10:52,  2.22it/s, loss=4.9169, acc=0.243, ppl=136.6, lr=9.99


Step 1000: Val Loss: 3.1843, Val Acc: 0.4083, Val PPL: 24.15
💾 Saved best model with val_loss: 3.1843



Training:  50%|█████     | 1010/2000 [06:22<11:57,  1.38it/s, loss=3.5977, acc=0.356, ppl=36.5, lr=9.86e-03][A
Training:  50%|█████     | 1010/2000 [06:22<11:57,  1.38it/s, loss=3.4113, acc=0.375, ppl=30.3, lr=9.86e-03][A
Training:  51%|█████     | 1020/2000 [06:26<10:07,  1.61it/s, loss=3.4113, acc=0.375, ppl=30.3, lr=9.86e-03][A
Training:  51%|█████     | 1020/2000 [06:26<10:07,  1.61it/s, loss=3.2506, acc=0.399, ppl=25.8, lr=9.85e-03][A
Training:  52%|█████▏    | 1030/2000 [06:29<08:35,  1.88it/s, loss=3.2506, acc=0.399, ppl=25.8, lr=9.85e-03][A
Training:  52%|█████▏    | 1030/2000 [06:29<08:35,  1.88it/s, loss=3.4952, acc=0.364, ppl=33.0, lr=9.85e-03][A
Training:  52%|█████▏    | 1040/2000 [06:33<07:44,  2.07it/s, loss=3.4952, acc=0.364, ppl=33.0, lr=9.85e-03][A
Training:  52%|█████▏    | 1040/2000 [06:33<07:44,  2.07it/s, loss=3.5006, acc=0.365, ppl=33.1, lr=9.84e-03][A
Training:  52%|█████▎    | 1050/2000 [06:36<06:54,  2.29it/s, loss=3.5006, acc=0.365, ppl=33.1, lr=9.84


Step 1500: Val Loss: 1.8982, Val Acc: 0.6076, Val PPL: 6.67
💾 Saved best model with val_loss: 1.8982



Training:  76%|███████▌  | 1510/2000 [09:29<05:55,  1.38it/s, loss=2.5132, acc=0.495, ppl=12.3, lr=9.54e-03][A
Training:  76%|███████▌  | 1510/2000 [09:30<05:55,  1.38it/s, loss=2.2681, acc=0.533, ppl=9.7, lr=9.54e-03] [A
Training:  76%|███████▌  | 1520/2000 [09:33<04:57,  1.61it/s, loss=2.2681, acc=0.533, ppl=9.7, lr=9.54e-03][A
Training:  76%|███████▌  | 1520/2000 [09:34<04:57,  1.61it/s, loss=2.3572, acc=0.517, ppl=10.6, lr=9.53e-03][A
Training:  76%|███████▋  | 1530/2000 [09:36<04:09,  1.88it/s, loss=2.3572, acc=0.517, ppl=10.6, lr=9.53e-03][A
Training:  76%|███████▋  | 1530/2000 [09:37<04:09,  1.88it/s, loss=2.4870, acc=0.494, ppl=12.0, lr=9.52e-03][A
Training:  77%|███████▋  | 1540/2000 [09:40<03:42,  2.07it/s, loss=2.4870, acc=0.494, ppl=12.0, lr=9.52e-03][A
Training:  77%|███████▋  | 1540/2000 [09:41<03:42,  2.07it/s, loss=2.3369, acc=0.515, ppl=10.3, lr=9.51e-03][A
Training:  78%|███████▊  | 1550/2000 [09:43<03:16,  2.29it/s, loss=2.3369, acc=0.515, ppl=10.3, lr=9.51e

  ⏱️ Training completed in 741.2 seconds





  📊 Final - Loss: 1.1151, Acc: 0.7511, PPL: 3.05
💾 Saved final model to final_model.pt

🎉 TRAINING COMPLETED!
⏱️ Total time: 12.6 minutes
🏆 Final Results:
   Validation Loss: 1.1151
   Validation Accuracy: 0.7511
   Validation Perplexity: 3.05


# Inference

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import random
import numpy as np
from transformers import AutoTokenizer
from dataclasses import dataclass
from typing import List, Optional, Tuple
import warnings
import os
import pickle
from tqdm import tqdm
import torch.serialization
warnings.filterwarnings('ignore')

# Import the model classes from the training file
import sys
sys.path.append('.')

def set_seed(seed: int = 42):
    """Set all random seeds for reproducibility"""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    print(f" Set all seeds to {seed}")

class TextGenerator:
    """Text generation class for the trained model"""

    def __init__(self, model_path: str = "final_model.pt", tokenizer_path: str = "HuggingFaceTB/SmolLM-135M", device: str = "auto"):
        """
        Initialize the text generator

        Args:
            model_path: Path to the saved model checkpoint (default: final_model.pt)
            tokenizer_path: Path to the tokenizer (default uses the same as training)
            device: Device to run inference on ("auto", "cpu", "cuda")
        """
        self.device = self._get_device(device)
        print(f"🔧 Using device: {self.device}")

        # Load tokenizer
        print("📚 Loading tokenizer...")
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, token=False)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # Load model
        print("🤖 Loading model...")
        self.model, self.config = self._load_model(model_path)
        self.model.to(self.device)
        self.model.eval()

        print(f"✅ Model loaded successfully!")
        print(f"   Parameters: {sum(p.numel() for p in self.model.parameters()):,}")
        print(f"   Vocabulary size: {self.config.vocab_size}")
        print(f"   Max sequence length: {self.config.max_seq_len}")

    def _get_device(self, device: str) -> torch.device:
        """Determine the best device to use"""
        if device == "auto":
            if torch.cuda.is_available():
                return torch.device("cuda")
            else:
                return torch.device("cpu")
        else:
            return torch.device(device)

    def _load_model(self, model_path: str) -> Tuple[MinimalLLM, ModelConfig]:
        """Load the trained model from checkpoint"""
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"Model checkpoint not found: {model_path}")

        # Add ModelConfig to safe globals for PyTorch 2.6+ compatibility
        torch.serialization.add_safe_globals([ModelConfig])

        try:
            # Try loading with weights_only=True first (PyTorch 2.6+ default)
            checkpoint = torch.load(model_path, map_location=self.device, weights_only=True)
        except Exception as e:
            print(f"⚠️  weights_only=True failed, trying weights_only=False: {e}")
            # Fallback to weights_only=False for older checkpoints
            checkpoint = torch.load(model_path, map_location=self.device, weights_only=False)

        # Extract config and create model
        config = checkpoint['config']
        model = MinimalLLM(config)

        # Load state dict
        model.load_state_dict(checkpoint['model_state_dict'])

        return model, config

    def generate(
        self,
        prompt: str,
        max_length: int = 100,
        temperature: float = 0.8,
        top_p: float = 0.9,
        top_k: int = 50,
        do_sample: bool = True,
        num_return_sequences: int = 1,
        stop_tokens: Optional[List[str]] = None
    ) -> List[str]:
        """
        Generate text from a prompt

        Args:
            prompt: Input text prompt
            max_length: Maximum length of generated text (including prompt)
            temperature: Sampling temperature (higher = more random)
            top_p: Nucleus sampling parameter
            top_k: Top-k sampling parameter
            do_sample: Whether to use sampling (False for greedy decoding)
            num_return_sequences: Number of sequences to generate
            stop_tokens: List of tokens to stop generation at

        Returns:
            List of generated text sequences
        """
        # Tokenize prompt
        input_ids = self.tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt")
        input_ids = input_ids.to(self.device)

        # Convert stop tokens to IDs
        stop_token_ids = []
        if stop_tokens:
            for token in stop_tokens:
                token_id = self.tokenizer.encode(token, add_special_tokens=False)
                if token_id:
                    stop_token_ids.extend(token_id)

        generated_sequences = []

        for _ in range(num_return_sequences):
            # Generate sequence
            generated_ids = self._generate_sequence(
                input_ids=input_ids,
                max_length=max_length,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                do_sample=do_sample,
                stop_token_ids=stop_token_ids
            )

            # Decode to text
            generated_text = self.tokenizer.decode(generated_ids, skip_special_tokens=True)
            generated_sequences.append(generated_text)

        return generated_sequences

    def _generate_sequence(
        self,
        input_ids: torch.Tensor,
        max_length: int,
        temperature: float,
        top_p: float,
        top_k: int,
        do_sample: bool,
        stop_token_ids: List[int]
    ) -> torch.Tensor:
        """Generate a single sequence using the model"""

        current_ids = input_ids.clone()
        generated_length = current_ids.shape[1]

        with torch.no_grad():
            while generated_length < max_length:
                # Get model predictions
                logits = self.model(current_ids)
                next_token_logits = logits[0, -1, :] / temperature

                # Apply top-k filtering
                if top_k > 0:
                    top_k_logits, top_k_indices = torch.topk(next_token_logits, top_k)
                    next_token_logits = torch.full_like(next_token_logits, float('-inf'))
                    next_token_logits[top_k_indices] = top_k_logits

                # Apply top-p (nucleus) filtering
                if top_p < 1.0:
                    sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
                    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

                    # Remove tokens with cumulative probability above the threshold
                    sorted_indices_to_remove = cumulative_probs > top_p
                    sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].clone()
                    sorted_indices_to_remove[0] = 0

                    indices_to_remove = sorted_indices[sorted_indices_to_remove]
                    next_token_logits[indices_to_remove] = float('-inf')

                # Sample or take argmax
                if do_sample:
                    probs = F.softmax(next_token_logits, dim=-1)
                    next_token = torch.multinomial(probs, num_samples=1)
                else:
                    next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)

                # Ensure next_token has the correct shape for concatenation
                # next_token should be (1, 1) to match current_ids shape (1, seq_len)
                if next_token.dim() == 1:
                    next_token = next_token.unsqueeze(0)

                # Check for stop tokens
                if next_token.item() in stop_token_ids:
                    break

                # Append to sequence
                current_ids = torch.cat([current_ids, next_token], dim=1)
                generated_length += 1

        return current_ids[0]

    def get_perplexity(self, text: str) -> float:
        """Calculate perplexity of the given text"""
        # Tokenize text
        tokens = self.tokenizer.encode(text, add_special_tokens=False)

        if len(tokens) < 2:
            return float('inf')

        # Create sequences for evaluation
        sequences = []
        for i in range(len(tokens) - 1):
            sequences.append((tokens[i], tokens[i + 1]))

        total_loss = 0
        total_tokens = 0

        self.model.eval()
        with torch.no_grad():
            for input_token, target_token in sequences:
                # Create input tensor
                input_tensor = torch.tensor([[input_token]], device=self.device)
                target_tensor = torch.tensor([[target_token]], device=self.device)

                # Get model prediction
                logits = self.model(input_tensor)
                loss = F.cross_entropy(logits.view(-1, self.config.vocab_size), target_tensor.view(-1))

                total_loss += loss.item()
                total_tokens += 1

        avg_loss = total_loss / total_tokens
        perplexity = math.exp(avg_loss)

        return perplexity

def interactive_mode(generator: TextGenerator):
    """Run interactive text generation mode"""
    print("\n🎭 Interactive Text Generation Mode")
    print("Type 'quit' to exit, 'help' for commands")
    print("=" * 50)

    while True:
        try:
            prompt = input("\n Enter your prompt: ").strip()

            if prompt.lower() == 'quit':
                print("👋 Goodbye!")
                break
            elif prompt.lower() == 'help':
                print("\n Available commands:")
                print("  help - Show this help message")
                print("  quit - Exit the program")
                print("  settings - Show current generation settings")
                print("  sample - Generate with sampling")
                print("  greedy - Generate with greedy decoding")
                print("\n💡 Tips:")
                print("  - Use 'sample' or 'greedy' prefix to change generation mode")
                print("  - Example: 'sample The quick brown fox'")
                continue
            elif prompt.lower() == 'settings':
                print("\n⚙️ Current settings:")
                print("  Temperature: 0.8")
                print("  Top-p: 0.9")
                print("  Top-k: 50")
                print("  Max length: 100")
                continue

            # Check for mode prefixes
            do_sample = True
            if prompt.startswith('sample '):
                prompt = prompt[7:]
                do_sample = True
            elif prompt.startswith('greedy '):
                prompt = prompt[7:]
                do_sample = False

            if not prompt:
                continue

            print(f"\n Generating text...")
            generated_texts = generator.generate(
                prompt=prompt,
                max_length=100,
                temperature=0.8,
                top_p=0.9,
                top_k=50,
                do_sample=do_sample,
                num_return_sequences=1
            )

            print(f"\n✨ Generated text:")
            print("-" * 40)
            for i, text in enumerate(generated_texts, 1):
                print(f"{i}. {text}")
            print("-" * 40)

        except KeyboardInterrupt:
            print("\n👋 Goodbye!")
            break
        except Exception as e:
            print(f"❌ Error: {e}")

def main():
    # Set seed
    set_seed(42)

    # Initialize generator with default settings
    try:
        generator = TextGenerator("final_model.pt")
    except Exception as e:
        print(f"❌ Failed to load model: {e}")
        return

    # Run interactive mode
    interactive_mode(generator)

if __name__ == "__main__":
    main()

 Set all seeds to 42
🔧 Using device: cuda
📚 Loading tokenizer...
🤖 Loading model...
✅ Model loaded successfully!
   Parameters: 32,150,976
   Vocabulary size: 49152
   Max sequence length: 512

🎭 Interactive Text Generation Mode
Type 'quit' to exit, 'help' for commands

 Enter your prompt: the future of AI is

 Generating text...

✨ Generated text:
----------------------------------------
1. the future of AI is like the First, the way, the right away to build something called an online, and take care of like a lot of this, it.

So why is it important? Well, understanding and keep track of this important to keep track of like exploring and understand market trends and why using smart decisions. By fostering a better future. Who knows, maybe someday you'll become an example of this incredible inventions or problems in helping us. Happy exploring! Chapter 10
----------------------------------------
