# ***Qwen 3 LLM Implementation From Scratch***

**Project Overview**
- Lightweight LLM inspired by Qwen3, built from scratch in PyTorch.
- Implements modern transformer components including RMSNorm, Rotary Position Embeddings (RoPE), Grouped-Query Attention (GQA), and SwiGLU feed-forward layers.
- Trained using a hybrid Muon + AdamW optimizer setup with causal masking, efficient batching, and evaluation utilities.
- Includes full training pipeline, model loading, and interactive text generation demos for hands-on experimentation.

**Useful Materials**
- Qwen 3 Technical Report PDF: https://arxiv.org/pdf/2505.09388
- Qwen 3 GitHub Repo: https://github.com/QwenLM/Qwen3

**Step by Step Overview (Table of Contents)**
1. Imports
2. Utility Functions (set_seed, ...)
3. Model Configuration
4. Key/Value Head Expansion Function
5. Muron Optimizer (Orthogonalized Momentum via Newton–Schulz)
6. Data Loading and Caching
7. TextTokenDataset Class
8. Rotary Position Embeddings (RoPE)
9. Grouped-Query Attention (GQA)
10. SwiGLU Feed-Forward Network (FFN)
11. Transformer block (attention + FFN + RMSNorm + residuals)
12. Language model class (MinimalLLM)
13. Evaluation function (loss, accuracy, perplexity)
14. Optimizer setup (hybrid Muon + AdamW)
15. Training loop (AMP, grad accumulation, schedulers)
16. Training Script
17. Model Loading
18. Model Inference - Autoregressive Text Generation and Chat Interactive Inference

# **1. Imports**

In [1]:
import os                                          # File system operations (creating folders, path checking, etc.)
import time                                        # Timing utilities, measuring time
import math                                        # Standard math operations (e.g. sqrt, exp, cos)
import random                                      # Python's random number utilities (used for seeding)
import pickle                                      # Python object serialization (used to save/load preprocessed datasets)
import warnings                                    # Suppress or handle warnings

import numpy as np                                 # Numerical computing library, used for random seeding and general array ops
import torch                                       # PyTorch main library
import torch.nn as nn                              # Neural network modules like Linear, Embedding, etc.
import torch.nn.functional as F                    # Functional interface for operations like cross_entropy, silu, etc.
from torch.utils.data import Dataset, DataLoader  # Base class and utilities for loading datasets
from torch.cuda.amp import autocast, GradScaler    # Automatic Mixed Precision (AMP) tools for faster/lower-memory training

from datasets import load_dataset                  # Hugging Face Datasets library for streaming large datasets
from tqdm import tqdm                              # Progress bar visualization library, great for loops
from transformers import AutoTokenizer             # Load pretrained tokenizers from HuggingFace with one line

from dataclasses import dataclass                  # Define simple classes for configs with less boilerplate
from typing import List, Optional                  # Type hints for better readability and tooling

warnings.filterwarnings('ignore')                  # Silences warnings for cleaner outputs during training


# **2. Utility Functions**

In [2]:
# Set Seed Utility Function Ensuring Reproducibiity
def set_seed(seed: int = 42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    print(f"Set all seeds to {seed}")

# **3. Model Configuration**

In [3]:
@dataclass
class ModelConfig:
    """
    Compact Qwen3-style configuration for reasoning-focused experiments.
    --------------------------------------------------------------------
    This configuration mirrors Qwen3’s architectural principles (GQA, RMSNorm,
    no attention bias, SwiGLU FFN) but at a much smaller scale for efficient
    prototyping and reasoning research on limited hardware.
    """

    # ----------------------
    # Model architecture
    # ----------------------
    d_model: int = 384                    # Hidden dimension size for each token embedding (Qwen3 uses 2K–8K)
    n_heads: int = 8                      # Total number of attention heads for multi-head attention
    n_layers: int = 6                     # Number of Transformer layers (Qwen3 ranges 28–94)
    d_ff: int = 1536                      # Feedforward dimension (≈4 × d_model, using SwiGLU in Qwen3)
    batch_size: int = 24                  # Mini-batch size for training
    max_steps: int = 2000                 # Total number of training iterations (smaller-scale training)

    # ----------------------
    # Qwen3-specific architectural tweaks
    # ----------------------
    n_kv_heads: int = 4                   # Number of KV heads for Grouped Query Attention (GQA)
    sliding_window: int = 4096            # Context chunk size; large default effectively disables sliding
    attention_bias: bool = False          # Qwen3 removes QKV bias for training stability
    rms_norm_eps: float = 1e-6            # Stabilizing epsilon for RMSNorm (Qwen3 uses 1e-6)

    # ----------------------
    # Training hyperparameters
    # ----------------------
    gradient_accumulation_steps: int = 4  # Accumulate gradients to simulate larger batch size
    muon_lr: float = 0.01                 # Learning rate (placeholder; scaled for small models)

    # ----------------------
    # Data handling parameters
    # ----------------------
    max_seq_len: int = 512                # Max sequence length per training sample
    num_documents: int = 2000             # Dataset size (for testing or research)
    max_tokens: int = 500000              # Limit total processed tokens to manage memory

    # ----------------------
    # Evaluation
    # ----------------------
    eval_every: int = 500                 # Frequency of evaluation during training
    eval_steps: int = 100                 # Number of evaluation steps

    # ----------------------
    # Regularization & stability
    # ----------------------
    weight_decay: float = 0.1             # L2 regularization strength
    dropout: float = 0.1                  # Dropout rate; Qwen3 dense models often disable dropout
    grad_clip: float = 1.0                # Gradient clipping threshold

    # ----------------------
    # Technical settings
    # ----------------------
    use_amp: bool = True                  # Enable automatic mixed precision for faster, lower-memory training
    vocab_size: Optional[int] = None      # Vocabulary size (Qwen3 uses 151,669 BBPE tokens)

    def __post_init__(self):
        # Per-head dimension for attention projections
        self.d_k = self.d_model // self.n_heads
        assert self.d_model % self.n_heads == 0, "d_model must be divisible by n_heads"

        # Grouped Query Attention (GQA) consistency check
        assert self.n_heads % self.n_kv_heads == 0, "n_heads must be divisible by n_kv_heads"

        # Each Key/Value head is shared by multiple Query heads
        # Example: n_heads=8, n_kv_heads=4 → each KV head serves 2 Q heads
        self.n_kv_groups = self.n_heads // self.n_kv_heads

## **4. Key/Value Head Expansion for Grouped-Query Attention**
- In GQA, multiple query heads share a smaller number of key/value heads to save memory and compute.
- This helper function "expands" the K/V tensors so that each query head has a corresponding K/V pair to attend to.
## **Example:**
- if n_heads = 8 and n_kv_heads = 4 → each K/V head is repeated twice.
- Input shape:  (batch, n_kv_heads, seq_len, head_dim)
- Output shape: (batch, n_heads,    seq_len, head_dim)
- This effectively enables multi-head queries to reuse fewer key/value projections efficiently.



In [4]:
## **4. Key/Value Head Expansion for Multi-Query or Grouped-Query Attention**
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Expands the number of key/value heads to match the number of attention heads
    in multi-query or grouped-query attention setups.

    Args:
        hidden_states: Tensor of shape (batch, num_kv_heads, seq_len, head_dim)
        n_rep: Number of repetitions for each key/value head

    Returns:
        Tensor of shape (batch, num_kv_heads * n_rep, seq_len, head_dim)

    Example:
        Suppose hidden_states has shape (2, 2, 4, 8) — 2 batches, 2 KV heads,
        sequence length 4, and head dimension 8.

        If n_rep = 4 (e.g., 8 attention heads / 2 KV heads),
        the output will have shape (2, 8, 4, 8):
        - Each KV head is repeated 4 times to match total attention heads.
    """
    batch, num_kv_heads, slen, head_dim = hidden_states.shape

    if n_rep == 1:
        return hidden_states

    # Insert new dimension → expand → flatten
    # (b, kv, s, d) → (b, kv, 1, s, d) → (b, kv, n_rep, s, d) → (b, kv*n_rep, s, d)
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_kv_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_kv_heads * n_rep, slen, head_dim)


## **5. Muron Optimizer (Orthogonalized Momentum via Newton–Schulz)**

- The Muon optimizer is a stability-oriented optimizer that applies Newton–Schulz orthogonalization to the gradient or momentum updates.

- Instead of directly applying raw gradient steps (which can accumulate correlated noise or directional bias), Muon reprojects updates onto an approximately orthogonal basis.
    - This helps preserve feature diversity and prevent parameter collapse—especially important in large, low-rank or transformer-style models like Qwen.

### **orthogonalize_newton_schulz Function**
- uses a polynomially approximated Newton–Schulz iteration to find the “zeroth power” of GᵀG (i.e., G * (GᵀG)⁻¹/²).
  - This approximates orthogonalization efficiently without an expensive matrix inverse or SVD.

### **Example Flow:**

#### ***Step 1: Start with a gradient tensor***

$$
G =
\begin{bmatrix}
1.0 & 2.0 \\
3.0 & 4.0
\end{bmatrix}
$$



- This gradient is *not orthogonal* — its columns are correlated.

---

#### ***Step 2: Normalize for numerical stability***
$$
G_{\text{norm}} = \frac{G}{\|G\|_F}
$$


- This scales \( G \) to unit Frobenius norm, preventing exploding values during iteration.

---

#### Step 3: Newton–Schulz iteration
- Each Newton–Schulz iteration refines \( X \) toward orthogonality via the polynomial update:

$$
X_{t+1} = \alpha X_t + \big( \beta (X_t X_t^\top) + \gamma (X_t X_t^\top X_t X_t^\top) \big) X_t
$$

- The coefficients $\alpha = 3.4445$, $\beta = -4.7750$, and $\gamma = 2.0315$ are carefully tuned to balance convergence speed and numerical stability, ensuring that the Newton–Schulz iteration  

---

#### ***Step 4: Result — Orthogonalized Gradient***
- After 5 iterations, \( X \) approaches:
$$
X_{orth} =
\begin{bmatrix}
-0.89 & 0.46 \\
0.46 & 0.89
\end{bmatrix}
$$
- Now the columns are approximately **orthogonal** and **unit-norm**.

---

#### ***Step 5: Parameter Update***
- Muon applies momentum, Nesterov correction (if enabled), and updates weights:
$$
\theta \leftarrow \theta - \eta \cdot X_{orth}
$$
- This ensures stable and diverse gradient directions during training.

---






In [5]:
@torch.compile
def orthogonalize_newton_schulz(G: torch.Tensor, num_iters: int = 5) -> torch.Tensor:
    """
    Orthogonalizes a tensor using the Newton–Schulz iterative method.

    This function projects the input matrix G toward an orthogonal basis
    (approximate zeroth power of GᵀG) using a numerically stable polynomial
    approximation.

    Args:
        G (torch.Tensor): Input tensor (matrix or batch of matrices).
        num_iters (int): Number of Newton–Schulz iterations (default: 5).

    Returns:
        torch.Tensor: Orthogonalized tensor with approximately unit-norm columns.
    """
    assert G.ndim >= 2, "Input must have at least 2 dimensions."

    # Coefficients for the polynomial update (empirically chosen)
    alpha, beta, gamma = 3.4445, -4.7750, 2.0315

    # Convert to bfloat16 for stability and performance
    X = G.bfloat16()

    # Ensure the "tall" matrix case is handled consistently
    transpose_needed = G.size(-2) > G.size(-1)
    if transpose_needed:
        X = X.mT

    # Normalize to unit Frobenius norm (stabilizes iteration)
    X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)

    # Newton–Schulz iterations to approach orthogonality
    for _ in range(num_iters):
        XXt = X @ X.mT
        poly_term = beta * XXt + gamma * (XXt @ XXt)
        X = alpha * X + poly_term @ X

    if transpose_needed:
        X = X.mT

    return X


class Muon(torch.optim.Optimizer):
    """
    Muon Optimizer: Momentum Orthogonalized via Newton–Schulz

    This optimizer extends classical momentum/Nesterov updates by
    orthogonalizing the gradient using the Newton–Schulz iteration.
    It helps stabilize updates in overparameterized regimes and can
    improve convergence for large, structured models.

    Key points:
    - Momentum: Maintains an exponential moving average of past gradients
      to smooth updates and accelerate convergence.
    - Nesterov Momentum: Looks ahead by applying the gradient correction
      based on the estimated future position, improving stability
      and convergence speed.

    Args:
        params (iterable): Model parameters to optimize.
        lr (float): Learning rate (default: 0.02).
        momentum (float): Momentum coefficient (default: 0.95).
        nesterov (bool): Whether to apply Nesterov momentum (default: True).
        ns_steps (int): Number of Newton–Schulz iterations for orthogonalization (default: 5).
    """
    def __init__(self, params, lr=0.02, momentum=0.95, nesterov=True, ns_steps=5):
        defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self):
        for group in self.param_groups:
            lr = group["lr"]
            momentum = group["momentum"]
            use_nesterov = group["nesterov"]
            ns_steps = group["ns_steps"]

            for param in group["params"]:
                if param.grad is None:
                    continue

                grad = param.grad
                state = self.state[param]

                # Initialize momentum buffer on first update
                if "momentum_buffer" not in state:
                    state["momentum_buffer"] = torch.zeros_like(grad)

                momentum_buf = state["momentum_buffer"]

                # Exponential moving average update of gradient
                momentum_buf.lerp_(grad, 1 - momentum)

                # Apply Nesterov correction if enabled
                if use_nesterov:
                    grad = grad.lerp(momentum_buf, momentum)
                else:
                    grad = momentum_buf

                # Orthogonalize the gradient update
                grad = orthogonalize_newton_schulz(grad, num_iters=ns_steps)

                # Adaptive learning rate scaling based on matrix shape
                aspect_ratio = max(1.0, param.size(-2) / param.size(-1))
                scale = (aspect_ratio ** 0.5)

                # Parameter update
                param.add_(grad.view_as(param), alpha=-lr * scale)


## **6. Data Loading & Caching**

Loading and processing large text corpora can be time-consuming. This step implements an efficient **data loading and caching pipeline** to minimize redundant work and speed up repeated experiments.

#### **Key features of this pipeline include:**

* **Caching mechanism**: Saves tokenized data to disk and reloads it if available, avoiding repeated processing.
* **Configurable dataset size**: Limits the number of documents processed using `config.num_documents`.
* **Token limit enforcement**: Ensures the total number of tokens does not exceed `config.max_tokens`.
* **Pretrained tokenizer**: Uses HuggingFace’s SmolLM tokenizer and automatically handles missing `pad_token`.
* **Streaming dataset loading**: Efficiently loads the SmolLM corpus without overwhelming memory.
* **Truncation of long documents**: Limits document length (e.g., first 3000 characters) to prevent excessively long inputs.
* **Progress tracking**: Displays tokenization progress using `tqdm`.
* **Automatic vocab size update**: Sets `config.vocab_size` based on the tokenizer.
* **Reproducibility & consistency**: Guarantees consistent processed data across runs via caching.

This ensures that training and experimentation are faster, memory-efficient, and reproducible.


In [6]:
def load_and_cache_data(config: ModelConfig, cache_dir: str = "data_cache"):
    """
    Load and cache tokenized data for model training to avoid redundant preprocessing.

    This function ensures efficient data handling by:
    1. Checking for pre-existing cached data to skip reprocessing.
    2. Loading a tokenizer and assigning a pad token if missing.
    3. Streaming a dataset to limit memory usage and control document size.
    4. Tokenizing each document into integer token IDs compatible with the model.
    5. Truncating the total number of tokens to `config.max_tokens` to fit memory limits.
    6. Saving the processed data to a cache file for faster subsequent runs.

    Args:
        config (ModelConfig): Configuration object with dataset and token limits.
        cache_dir (str): Directory where cached tokenized data will be stored.

    Returns:
        texts (List[str]): List of raw text documents.
        tokenizer (PreTrainedTokenizer): HuggingFace tokenizer used to encode text.
        tokens (List[int]): Flattened list of token IDs ready for model input.
    """

    # Ensure the cache directory exists; create if it doesn't
    os.makedirs(cache_dir, exist_ok=True)

    # Construct a cache file path that encodes the dataset size and token limit
    cache_file = f"{cache_dir}/tokenized_data_{config.num_documents}_{config.max_tokens}.pkl"

    # ----------------------------
    # 1. Attempt to load cached data
    # ----------------------------
    if os.path.exists(cache_file):
        print(f"Loading cached data from {cache_file}")
        with open(cache_file, 'rb') as f:
            cached_data = pickle.load(f)

        # Extract cached components
        texts = cached_data['texts']
        tokenizer = cached_data['tokenizer']
        tokens = cached_data['tokens']
        config.vocab_size = tokenizer.vocab_size  # Update model config

        print(f"Loaded {len(texts)} documents and {len(tokens):,} tokens from cache")
        return texts, tokenizer, tokens

    # Cache not found; will process data from scratch
    print("No cached data found; processing new dataset...")

    # ----------------------------
    # 2. Load and configure tokenizer
    # ----------------------------
    tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM-135M")

    # Ensure pad token exists for batch padding; if not, use EOS token
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # ----------------------------
    # 3. Stream dataset to avoid memory overload
    # ----------------------------
    dataset = load_dataset(
        "HuggingFaceTB/smollm-corpus",
        "cosmopedia-v2",
        split="train",
        streaming=True
    )

    texts = []
    for i, item in enumerate(dataset):
        if i >= config.num_documents:
            break
        # Truncate each document to 3000 characters to reduce memory usage
        texts.append(item["text"][:3000])

    print(f"Loaded {len(texts)} documents for processing")

    # ----------------------------
    # 4. Tokenize all documents
    # ----------------------------
    print("Tokenizing documents into model-compatible token IDs...")
    all_tokens = []
    for text in tqdm(texts, desc="Tokenizing"):
        # Encode text without special tokens to maintain uniform tokenization
        tokens = tokenizer.encode(text, add_special_tokens=False)
        all_tokens.extend(tokens)

    # ----------------------------
    # 5. Truncate total tokens to fit memory and model constraints
    # ----------------------------
    tokens = all_tokens[:config.max_tokens]
    print(f"Final token count: {len(tokens):,}")
    config.vocab_size = tokenizer.vocab_size  # Update config with tokenizer vocab size

    # ----------------------------
    # 6. Save processed data to cache for future runs
    # ----------------------------
    cached_data = {
        'texts': texts,
        'tokenizer': tokenizer,
        'tokens': tokens
    }

    with open(cache_file, 'wb') as f:
        pickle.dump(cached_data, f)

    print(f"Processed data cached to {cache_file} for future use")

    return texts, tokenizer, tokens


## **7. TextTokenDataset: Sequential Token Dataset for Language Modeling**

The `TextTokenDataset` class provides a simple and efficient way to prepare tokenized text data for training autoregressive language models. It transforms a flat list of token IDs into overlapping input-target sequences suitable for next-token prediction.

**Key Features:**

* Converts a flat token list into sequences of a fixed length (`seq_len`), ready for model input.
* Returns `(x, y)` pairs where `x` is the input sequence and `y` is the corresponding next-token targets.
* Supports efficient slicing without duplicating data in memory.
* Compatible with PyTorch `DataLoader` for batching and shuffling.
* Handles corpus edges gracefully by computing `len(dataset)` as the number of available sequences.

In [7]:
class TextTokenDataset(Dataset):
    """
    PyTorch Dataset for sequential tokenized text.

    This dataset converts a flat list of tokens into overlapping
    input-target sequences suitable for language model training.

    Attributes:
        tokens (List[int]): The complete tokenized text corpus.
        seq_len (int): Length of each input sequence (default: 512).

    Each item returns a pair (x, y):
        x: Tensor of token IDs for the input sequence.
        y: Tensor of token IDs for the next-token targets (shifted by 1).
    """

    def __init__(self, tokens: List[int], seq_len: int = 512):
        self.tokens = tokens
        self.seq_len = seq_len

    def __len__(self):
        # Total number of sequences available in the dataset
        return max(0, len(self.tokens) - self.seq_len)

    def __getitem__(self, idx):
        """
        Returns a single training example.

        Args:
            idx (int): Starting index of the sequence.

        Returns:
            x (torch.LongTensor): Input sequence of length `seq_len`.
            y (torch.LongTensor): Target sequence (input shifted by 1).
        """
        x = torch.tensor(self.tokens[idx:idx + self.seq_len], dtype=torch.long)
        y = torch.tensor(self.tokens[idx + 1:idx + self.seq_len + 1], dtype=torch.long)
        return x, y


## **8. Rotary Position Embeddings (RoPE)**

Rotary Position Embeddings (RoPE) provide a continuous, rotation-based method for encoding positional information in transformer models. Unlike fixed sinusoidal or learned positional embeddings, RoPE applies **2D rotations to token embeddings based on their position**, allowing the model to naturally generalize to sequences longer than those seen during training.

#### ***Key Benefits:***

* **Long-range generalization:** Works well for sequences longer than the training length.
* **Seamless integration:** Encodes position directly in the attention mechanism via rotation matrices.
* **Preserves relative distances:** Positional information is embedded in a way that maintains token-to-token relationships, improving context modeling.

### **RoPE (Rotary Positional Embedding) Math**

$$
\text{For a token embedding } x \in \mathbb{R}^d \text{ at position } p,
\text{ split } x \text{ into pairs } (x_{2i}, x_{2i+1}) \text{ and apply:}
$$

$$
\begin{bmatrix} y_{2i} \\ y_{2i+1} \end{bmatrix}
=
\begin{bmatrix}
\cos \theta_{p,i} & -\sin \theta_{p,i} \\
\sin \theta_{p,i} & \cos \theta_{p,i}
\end{bmatrix}
\begin{bmatrix} x_{2i} \\ x_{2i+1} \end{bmatrix}
$$

$$
\text{where } \theta_{p,i} = p \cdot \omega_i, \quad \text{with} \quad \omega_i = \frac{1}{10000^{2i/d}}
$$

$$
i = 0, 1, \dots, \frac{d}{2}-1, \quad d \text{ is the embedding dimension.}
$$

$$
\text{This rotates each pair of embedding dimensions by an angle proportional to position, preserving the vector norm.}
$$





In [18]:
class Rotary(nn.Module):
    """
    Rotary Positional Embeddings (RoPE) for transformer attention.

    This module encodes relative positional information by applying
    sinusoidal rotations to the query/key embeddings in the attention mechanism.

    Conceptually:
        - Each embedding vector is split into two halves.
        - These halves are rotated in a 2D plane according to a precomputed
          sine and cosine based on the token's position and frequency.
        - This effectively introduces a continuous, relative positional bias
          without requiring absolute positional embeddings.

    Advantages:
        - Captures relative positions naturally.
        - Works well with long sequences.
        - Can be applied directly to Q/K in multi-head attention.

    Args:
        dim (int): Dimensionality of the input embeddings (must be divisible by 4).
        max_seq_len (int): Maximum sequence length to precompute rotation matrices.
    """
    def __init__(self, dim: int, max_seq_len: int):
        super().__init__()
        assert dim % 2 == 0, "Head dim must be divisible by 2"
        half_dim = dim // 2

        # Frequencies
        inv_freq = 1.0 / (10000 ** (torch.arange(0, half_dim, dtype=torch.float32) / half_dim))

        # Positions
        positions = torch.arange(max_seq_len, dtype=torch.float32)

        # Outer product -> [seq_len, half_dim]
        theta = torch.einsum("i,j->ij", positions, inv_freq)

        # Cache sin/cos
        self.register_buffer("cos_cached", theta.cos(), persistent=False)
        self.register_buffer("sin_cached", theta.sin(), persistent=False)

    def forward(self, x: torch.Tensor):
        # x: [batch, seq_len, n_heads, head_dim]
        batch, seq_len, n_heads, head_dim = x.shape
        assert head_dim % 2 == 0
        half_dim = head_dim // 2

        # Slice and expand for broadcasting: [1, seq_len, 1, half_dim]
        cos = self.cos_cached[:seq_len, :].unsqueeze(0).unsqueeze(2)
        sin = self.sin_cached[:seq_len, :].unsqueeze(0).unsqueeze(2)

        # Split last dimension
        x1, x2 = x[..., :half_dim], x[..., half_dim:]

        # Apply RoPE rotation
        x_rot = torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)
        return x_rot



## **9. Grouped-Query Attention (GQA) Implementation**

The attention mechanism is the core of the Qwen3 model, and **Grouped-Query Attention (GQA)** is a key efficiency innovation. This section implements the attention flow while reducing computation and memory without sacrificing performance.

**Key Steps in the Attention Flow:**

1. **Project Queries, Keys, and Values Separately**
   Each input embedding is projected into query (Q), key (K), and value (V) vectors using dedicated linear layers. This allows fine-grained attention computation per head.

2. **Apply QK Normalization (RMSNorm)**
   Queries and keys are normalized using RMSNorm.

   * RMSNorm stabilizes the magnitude of the vectors while preserving direction.
   * This prevents numerical instabilities during dot-product attention and improves training convergence.

3. **Incorporate Rotary Position Embeddings (RoPE)**
   Q and K vectors are rotated based on token positions.

   * Encodes relative and absolute positions efficiently.
   * Enables the model to generalize to longer sequences than it was trained on.

4. **Grouped-Query Attention (GQA)**
   Keys and values are repeated across groups of query heads.

   * Reduces the number of K/V heads compared to Q heads, saving memory and computation.
   * Maintains expressivity because each query head can attend to a shared key/value representation.

5. **Compute Scaled Dot-Product Attention**
   Attention scores are computed as the scaled dot-product of Q and K, optionally with causal masking for autoregressive generation.

   * The resulting attention weights are applied to the V vectors to produce context-aware outputs.

6. **Final Linear Projection**
   The attended output from all heads is concatenated and projected back to the model dimension, ready for the next transformer block.

**Grouped-Query Attention Flow (GQA)**

$$
\begin{aligned}
& X \in \mathbb{R}^{B \times L \times D} \quad \text{(input embeddings)} \\
& Q = \text{Linear}_Q(X), \quad K = \text{Linear}_K(X), \quad V = \text{Linear}_V(X) \\
& Q' = \text{RMSNorm}(Q), \quad K' = \text{RMSNorm}(K) \\
& Q'' = \text{RoPE}(Q'), \quad K'' = \text{RoPE}(K') \\
& K_\text{gqa}, V_\text{gqa} = \text{RepeatKV}(K'', V, n_\text{kv\_groups}) \\
& \text{Attention} = \text{Softmax}\Bigg(\frac{Q'' K_\text{gqa}^\top}{\sqrt{d_k}}\Bigg) V_\text{gqa} \\
& \text{Output} = \text{Linear}_O(\text{Attention}) \in \mathbb{R}^{B \times L \times D}
\end{aligned}
$$

Where:  
- $B$ = batch size  
- $L$ = sequence length  
- $D$ = hidden dimension of the model  
- $d_k$ = dimension per attention head  
- $n_\text{kv\_groups}$ = number of Key/Value groups in GQA



In [19]:
class Qwen3Attention(nn.Module):
    """
    Qwen3 Multi-Head Attention with Grouped-Query Attention (GQA) and Rotary Position Embeddings (RoPE).

    Key Features:
    - **Grouped-Query Attention (GQA):** Uses fewer Key/Value heads than Query heads to reduce computation while maintaining expressivity.
    - **RMSNorm:** Root Mean Square Layer Normalization stabilizes the magnitude of query and key vectors, preventing numerical issues and helping training stability.
    - **RoPE (Rotary Position Embeddings):** Encodes token positions using sinusoidal rotations of query/key vectors, allowing the model to generalize to longer sequences.
    - **Causal Attention:** Ensures autoregressive behavior, so each token attends only to previous tokens.

    Args:
        config (ModelConfig): Configuration object with attention parameters (number of heads, head dimension, etc.)
    """
    def __init__(self, config: ModelConfig):
        super().__init__()
        self.d_model = config.d_model
        self.num_query_heads = config.n_heads
        self.num_kv_heads = config.n_kv_heads
        self.num_kv_groups = config.n_kv_groups
        self.head_dim = config.d_k

        # Linear projections for queries, keys, values
        self.query_proj = nn.Linear(self.d_model, self.num_query_heads * self.head_dim, bias=config.attention_bias)
        self.key_proj = nn.Linear(self.d_model, self.num_kv_heads * self.head_dim, bias=config.attention_bias)
        self.value_proj = nn.Linear(self.d_model, self.num_kv_heads * self.head_dim, bias=config.attention_bias)
        self.output_proj = nn.Linear(self.d_model, self.d_model, bias=False)

        # RMSNorm for queries and keys
        self.query_norm = nn.RMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.key_norm = nn.RMSNorm(self.head_dim, eps=config.rms_norm_eps)

        # Rotary positional embeddings
        self.rotary = Rotary(self.head_dim, config.max_seq_len)
        self.dropout = config.dropout

    def forward(self, hidden_states):
        """
        Compute multi-head attention with GQA, RoPE, and RMSNorm.

        Args:
            hidden_states (torch.Tensor): Input embeddings of shape (batch_size, seq_len, d_model)

        Returns:
            torch.Tensor: Output of attention layer, shape (batch_size, seq_len, d_model)
        """
        batch_size, seq_len, _ = hidden_states.size()

        # 1. Project input embeddings into query, key, value vectors
        queries = self.query_proj(hidden_states)
        keys = self.key_proj(hidden_states)
        values = self.value_proj(hidden_states)

        # 2. Reshape into separate heads
        queries = queries.view(batch_size, seq_len, self.num_query_heads, self.head_dim)
        keys = keys.view(batch_size, seq_len, self.num_kv_heads, self.head_dim)
        values = values.view(batch_size, seq_len, self.num_kv_heads, self.head_dim)

        # 3. Apply RMSNorm to stabilize magnitudes
        queries = self.query_norm(queries)  # RMSNorm scales queries to maintain consistent norm
        keys = self.key_norm(keys)

        # 4. Apply Rotary Position Embeddings
        queries = self.rotary(queries.permute(0, 2, 1, 3)).permute(0, 2, 1, 3)
        keys = self.rotary(keys.permute(0, 2, 1, 3)).permute(0, 2, 1, 3)

        # 5. Transpose for attention computation: (batch, num_heads, seq_len, head_dim)
        queries = queries.transpose(1, 2)
        keys = keys.transpose(1, 2)
        values = values.transpose(1, 2)

        # 6. Repeat K and V heads for Grouped-Query Attention
        keys = repeat_kv(keys, self.num_kv_groups)
        values = repeat_kv(values, self.num_kv_groups)

        # 7. Compute scaled dot-product attention with causal masking
        attention_output = F.scaled_dot_product_attention(
            queries, keys, values, is_causal=True, dropout_p=self.dropout if self.training else 0.0
        )

        # 8. Reshape back to (batch, seq_len, d_model)
        attention_output = attention_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)

        # 9. Final linear projection
        return self.output_proj(attention_output)


## **10. SwiGLU Feed-Forward Network (FFN)**

SwiGLU (Swish-Gated Linear Unit) is a modern feed-forward network activation that combines the **Swish** activation with a **gating mechanism (GLU)**. Compared to traditional activations like ReLU or GELU, SwiGLU improves expressivity, gradient flow, and selective feature propagation, making it well-suited for transformer-style models such as Qwen3.

Conceptually, the SwiGLU operation can be thought of as:

```
output = gate(x) * value(x)
```

Or, in an intuitive analogy:

```
light = brightness_control × light_source
```

Where:

* `gate(x)` (brightness_control) determines **which features are allowed to pass through**.
* `value(x)` (light_source) represents the **candidate features or information**.
* Multiplying them element-wise allows the network to **selectively amplify or suppress information**, improving learning dynamics and overall model performance.

Absolutely! Here's a concise MathJax version you can include in your README to visualize the SwiGLU gating mechanism:

---

### **SwiGLU Operation (Mathematical Form)**

$$
\text{Output} = \text{Gate}(x) \odot \text{Value}(x)
$$

Where:

$$
\text{Gate}(x) = \text{Swish}(W_g x), \quad
\text{Value}(x) = W_v x
$$

* $(W_g)$ and $(W_v)$ are learnable projection matrices.
* $(\odot)$ denotes element-wise multiplication.
* Swish activation, with $(\sigma)$ the sigmoid function.: $$(\text{Swish}(z) = z \cdot \sigma(z))$$

In [20]:
class SwiGLUFeedForward(nn.Module):
    """
    SwiGLU Feed-Forward Network (FFN)

    SwiGLU (Swish-Gated Linear Unit) is a modern variant of the feed-forward
    network used in transformers. It combines the Swish activation with a gating
    mechanism (GLU) to improve expressivity and gradient flow compared to ReLU or GELU.

    Conceptually:
        output = gate(x) * value(x)
    Think of it like:
        light = brightness_control × light_source

    Gating Mechanism Explanation:
    ---------------------------------
    - The gate pathway controls which elements of the "value" pathway are allowed
      to pass through.
    - gate(x) = Swish(W_gate * x) produces a vector with values typically in [0, 1]
      that act like "valves" on the information flow.
    - value(x) = W_up * x represents the candidate activations (information to propagate).
    - By multiplying gate(x) * value(x) element-wise, the network can selectively
      suppress or amplify features dynamically based on the input.
    - This improves gradient flow, prevents saturation, and allows the FFN to model
      richer interactions compared to traditional activations like ReLU.

    Input:
        x: Tensor of shape (batch_size, seq_len, d_model)

    Output:
        Tensor of shape (batch_size, seq_len, d_model)
        Same shape as input, suitable for residual connections in transformer blocks.
    """


    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        # Project input to feed-forward dimension for the gating mechanism
        self.gate_proj = nn.Linear(d_model, d_ff, bias=False)

        # Linear layer that produces the "value" path
        self.up_proj = nn.Linear(d_model, d_ff, bias=False)

        # Linear layer to project back to model dimension
        self.down_proj = nn.Linear(d_ff, d_model, bias=False)

        # Dropout for regularization
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass:
        1. Compute gate values using Swish (F.silu)
        2. Multiply gate values with the "value" projection
        3. Apply dropout for regularization
        4. Project back to model dimension
        """
        # Compute gate activations
        gate_activations = F.silu(self.gate_proj(x))

        # Compute value projections
        value_activations = self.up_proj(x)

        # Element-wise gating
        gated_output = gate_activations * value_activations

        # Apply dropout and project back
        return self.down_proj(self.dropout(gated_output))


## **11. Transformer Block**

Each Transformer block integrates multi-head attention and a feed-forward network (FFN), connected through residual pathways and pre-normalization layers to ensure stable gradient flow during deep training.

Unlike traditional implementations that use LayerNorm, this model employs RMSNorm (Root Mean Square Normalization), which normalizes activations based on their RMS magnitude rather than their mean and variance. This approach improves numerical stability, training efficiency, and scalability, especially in large or low-precision models.


In [21]:
class TransformerBlock(nn.Module):
    """
    Transformer Block with Grouped-Query Attention (GQA) and SwiGLU Feed-Forward Network.

    This block integrates attention and feed-forward computation in a pre-normalized
    residual architecture for stable deep training. Each block performs:

        Input
          → RMSNorm                    # Normalizes activations by their root-mean-square value
                                       # Ensures stable feature scaling before attention
                                       # Helps gradients stay well-conditioned in deep networks

          → Multi-Head GQA (with RoPE) # Applies Grouped-Query Attention
                                       # - Projects input into queries, keys, and values
                                       # - Uses shared key/value projections for efficiency
                                       # - Injects positional context using RoPE (rotational encoding)
                                       # Outputs context-aware representations per token

          → Dropout + Residual Add     # Combines attention output with the original input (residual path)
                                       # - Dropout regularizes to prevent overfitting
                                       # - Residual connection preserves original information flow

          → RMSNorm                    # Normalizes again before feed-forward processing
                                       # Prepares stable activations for the non-linear transformation

          → SwiGLU Feed-Forward Network# Applies two-layer feed-forward MLP with SwiGLU activation
                                       # - SwiGLU uses a gated Swish: (x * σ(W₁x))W₂
                                       # - Expands hidden dimension (d_ff), then projects back to d_model
                                       # - Increases representational capacity and non-linearity

          → Dropout + Residual Add     # Adds the feed-forward output back to the residual stream
                                       # - Dropout regularizes the FFN output
                                       # - Residual path ensures smoother gradient propagation

          → Output                     # Returns the updated hidden states
                                       # Now encoded with both contextual (attention) and
                                       # nonlinear (feed-forward) information, ready for the next block

    Key Features:
        • **Grouped-Query Attention (GQA)** — shares key/value projections across query groups,
          reducing memory and computation while preserving representational power.
        • **RoPE (Rotary Positional Embedding)** — encodes absolute and relative positional
          information directly in query/key space via rotational transformations.
        • **SwiGLU Feed-Forward Network** — uses a gated Swish activation for enhanced
          nonlinearity and expressivity over ReLU or GELU.
        • **RMSNorm (Root Mean Square Normalization)** — normalizes activations by their RMS
          magnitude, improving numerical stability and training efficiency compared to LayerNorm.
        • **Pre-Norm Residual Structure** — applies normalization before sublayers to ensure
          gradient flow and stable convergence in deep transformer stacks.

    Args:
        x (torch.Tensor): Input tensor of shape (batch_size, seq_len, d_model)

    Returns:
        torch.Tensor: Output tensor of shape (batch_size, seq_len, d_model)
    """
    def __init__(self, config: ModelConfig):
        super().__init__()

        # -----------------------------
        # Attention Block Components
        # -----------------------------

        # Multi-head attention with Grouped-Query Attention (GQA)
        # - QK-normalization stabilizes the queries and keys
        # - RoPE encodes positional information
        self.attention = Qwen3Attention(config)

        # RMSNorm applied before attention (pre-norm)
        # - Normalizes input across feature dimension
        # - Helps with gradient stability in deep networks
        self.attn_norm = nn.RMSNorm(config.d_model, eps=config.rms_norm_eps)

        # -----------------------------
        # Feed-Forward Block Components
        # -----------------------------

        # Feed-forward network using SwiGLU activation
        # - SwiGLU = Swish gating mechanism × linear projection
        # - Improves expressivity compared to ReLU/GELU
        self.feed_forward = SwiGLUFeedForward(config.d_model, config.d_ff, config.dropout)

        # RMSNorm applied before feed-forward (pre-norm)
        # - Stabilizes input to FFN
        self.ffn_norm = nn.RMSNorm(config.d_model, eps=config.rms_norm_eps)

        # Dropout applied to residual connections for regularization
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass through the Transformer block.

        Args:
            x: Input tensor of shape (batch_size, seq_len, d_model)

        Returns:
            Output tensor of shape (batch_size, seq_len, d_model)
        """

        # -----------------------------
        # 1. Attention Block
        # -----------------------------

        # Normalize input features before attention
        normed_input = self.attn_norm(x)  # shape: (batch_size, seq_len, d_model)

        # Compute attention output:
        # - Queries, keys, values are projected and normalized
        # - Rotary embeddings (RoPE) add positional information
        # - Grouped-Query Attention (GQA) allows KV sharing for efficiency
        attention_output = self.attention(normed_input)  # shape: (batch_size, seq_len, d_model)

        # Apply residual connection and dropout:
        # - Residual preserves original input information
        # - Dropout helps prevent overfitting
        x = x + self.dropout(attention_output)

        # -----------------------------
        # 2. Feed-Forward Block
        # -----------------------------

        # Normalize the output of attention block before feed-forward
        normed_residual = self.ffn_norm(x)  # shape: (batch_size, seq_len, d_model)

        # Compute feed-forward output:
        # - SwiGLU applies gating mechanism (Swish × linear)
        # - Expands features to d_ff dimension and projects back to d_model
        ffn_output = self.feed_forward(normed_residual)  # shape: (batch_size, seq_len, d_model)

        # Apply residual connection and dropout:
        # - Preserves input to feed-forward network
        # - Dropout regularizes FFN output
        x = x + self.dropout(ffn_output)

        # Output has same shape as input, ready for next Transformer block
        return x


## **12. Complete Language Model**

This stage integrates all the components into a cohesive **Transformer-based language model**. The model encodes token sequences, processes them through stacked attention and feed-forward layers, and outputs logits for next-token prediction.

**Architecture Overview:**

* **Token Embeddings:** Convert discrete tokens into continuous vector representations.
* **Positional Dropout:** Adds regularization to improve generalization and robustness.
* **Stacked Transformer Blocks:** Each block refines contextual understanding through attention and non-linear transformations.
* **Final RMSNorm & Output Projection:** Normalize the final representations and project them back to the vocabulary dimension.
* **Weight Tying:** Shares weights between input embeddings and output projection to reduce parameters and improve consistency between encoding and decoding spaces.


In [22]:
class MinimalLLM(nn.Module):
    """
    Minimal Transformer-based Language Model.

    This module implements a compact, decoder-only Transformer architecture
    capturing the essential design of modern large language models (LLMs).
    It includes token embeddings, stacked self-attention and feed-forward
    blocks, normalization layers, and a tied output projection head.

    **Model Overview**
    1. **Token Embedding:** Converts input token IDs into dense vector representations.
    2. **Positional Dropout:** Adds stochastic regularization to embeddings before encoding.
    3. **Transformer Stack:** Sequentially applies multiple attention + feed-forward
       blocks to model contextual dependencies and hierarchical semantics.
    4. **Final RMSNorm:** Normalizes the hidden states for improved numerical stability.
    5. **Output Projection (Weight-Tied):** Maps the final hidden states back to
       vocabulary logits using shared weights from the input embeddings.

    Designed for clarity and educational readability while preserving
    architectural fidelity to modern decoder-only LLMs.
    """

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config

        # -------------------------------------------------------
        # 1. Embedding Layer
        # -------------------------------------------------------

        # Token embeddings map discrete vocabulary indices to dense vectors.
        self.token_embedding = nn.Embedding(config.vocab_size, config.d_model)

        # Optional dropout on embeddings — helps regularize positional patterns
        self.position_dropout = nn.Dropout(config.dropout)

        # -------------------------------------------------------
        # 2. Transformer Backbone
        # -------------------------------------------------------

        # Stack of N identical Transformer blocks (each with attention + FFN)
        self.transformer_blocks = nn.ModuleList([
            TransformerBlock(config) for _ in range(config.n_layers)
        ])

        # -------------------------------------------------------
        # 3. Output Normalization and Projection
        # -------------------------------------------------------

        # Final RMSNorm stabilizes activations before output projection.
        self.norm = nn.RMSNorm(config.d_model, eps=config.rms_norm_eps)

        # Dropout before logits for additional regularization.
        self.output_dropout = nn.Dropout(config.dropout)

        # -------------------------------------------------------
        # 4. Output Head (Weight Tying)
        # -------------------------------------------------------

        # Linear layer projects hidden states back to vocabulary space.
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)

        # Weight tying — reuse token embedding weights for output projection.
        # This reduces parameter count and improves generalization,
        # ensuring that the input and output token spaces share semantics.
        self.lm_head.weight = self.token_embedding.weight

        # -------------------------------------------------------
        # 5. Parameter Initialization
        # -------------------------------------------------------

        # Initialize weights with small Gaussian noise for stable training.
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Applies standard normal initialization to Linear and Embedding layers."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass through the minimal LLM.

        Args:
            x (torch.Tensor): Token indices of shape (batch_size, seq_len).

        Returns:
            torch.Tensor: Logits of shape (batch_size, seq_len, vocab_size),
            representing unnormalized probabilities for each vocabulary token.
        """

        # 1. Token Embedding + Scaling
        # Scale embeddings by sqrt(d_model) to keep variance consistent.
        x = self.token_embedding(x) * math.sqrt(self.config.d_model)

        # 2. Embedding Dropout
        x = self.position_dropout(x)

        # 3. Transformer Stack
        # Sequentially apply each attention–feedforward block.
        for block in self.transformer_blocks:
            x = block(x)

        # 4. Final Normalization + Dropout
        x = self.norm(x)
        x = self.output_dropout(x)

        # 5. Output Projection (shared weights)
        logits = self.lm_head(x)

        return logits


## **13. Evaluation Function**

During training, it is essential to periodically assess the model’s performance on a **validation dataset**.
This evaluation function computes the following key metrics:

* **Loss:** Average token-level cross-entropy, measuring how well the model predicts the next token.
* **Accuracy:** Fraction of tokens correctly predicted, providing an intuitive performance measure.
* **Perplexity:** Exponential of the loss, commonly used to quantify uncertainty in language modeling — lower perplexity indicates better predictive performance.

The function is designed to be **efficient** by:

* Disabling gradient computation to save memory and computation time.
* Using **mixed precision** if enabled for faster evaluation.
* Limiting the number of evaluation steps (`config.eval_steps`) to control runtime on large datasets.

In [23]:
def evaluate_model(model: nn.Module, val_loader: DataLoader, config: ModelConfig):
    """
    Evaluate a trained language model on a validation dataset.

    **Evaluation Steps:**
    1. Switch the model to evaluation mode (disables dropout, etc.)
    2. Iterate over validation batches (limited by `config.eval_steps`)
    3. Move inputs and targets to the model device (CPU/GPU)
    4. Forward pass through the model (mixed precision if enabled)
    5. Compute token-level cross-entropy loss
    6. Accumulate total loss, correct predictions, and token counts
    7. Compute average metrics: loss, accuracy, and perplexity
    8. Restore the model to training mode

    Returns:
        dict: {
            'val_loss': float,
            'val_accuracy': float,
            'val_perplexity': float
        }
    """
    # -----------------------------
    # 1. Switch to evaluation mode
    # -----------------------------
    model.eval()

    # Initialize running totals
    total_loss = 0.0          # Total cross-entropy loss
    total_tokens = 0          # Total number of tokens processed
    total_correct = 0         # Number of correctly predicted tokens

    # Determine which device the model is on (CPU/GPU)
    device = next(model.parameters()).device

    # -----------------------------
    # 2. Iterate over validation batches
    # -----------------------------
    with torch.no_grad():  # Disable gradients for efficiency
        for step, (x, y) in enumerate(val_loader):

            # Limit evaluation to a maximum number of steps
            if step >= config.eval_steps:
                break

            # -----------------------------
            # 3. Move batch to device
            # -----------------------------
            x, y = x.to(device), y.to(device)

            # -----------------------------
            # 4. Forward pass with mixed precision
            # -----------------------------
            with autocast(enabled=config.use_amp):
                logits = model(x)

                # -----------------------------
                # 5. Compute cross-entropy loss
                # -----------------------------
                # Flatten batch and sequence dimensions for proper token-wise loss
                loss = F.cross_entropy(
                    logits.view(-1, config.vocab_size),
                    y.view(-1),
                    reduction='mean'
                )

            # -----------------------------
            # 6. Accumulate metrics
            # -----------------------------
            batch_tokens = y.numel()              # Number of tokens in this batch
            total_loss += loss.item() * batch_tokens  # Weighted sum of loss
            total_tokens += batch_tokens

            # Token-level predictions (argmax over vocab dimension)
            predictions = logits.argmax(dim=-1)
            total_correct += (predictions == y).sum().item()

    # -----------------------------
    # 7. Compute average metrics
    # -----------------------------
    avg_loss = total_loss / total_tokens
    accuracy = total_correct / total_tokens
    perplexity = math.exp(min(avg_loss, 20))  # Clamp to prevent overflow

    # -----------------------------
    # 8. Restore training mode
    # -----------------------------
    model.train()

    return {
        'val_loss': avg_loss,
        'val_accuracy': accuracy,
        'val_perplexity': perplexity
    }


## **14. Optimizer Setup**

The model employs a **hybrid optimization strategy** to leverage the strengths of two different optimizers:

* **Muon Optimizer:** Applied to 2D parameters (primarily linear layer weights in attention and feed-forward blocks).

  * Performs **momentum-orthogonalized updates** to stabilize training and improve convergence for large weight matrices.

* **AdamW:** Applied to all remaining parameters (embeddings, biases, normalization layers).

  * Provides standard **adaptive learning rate optimization** and weight decay regularization.

**Step-by-Step Flow:**

```
Model Parameters
      ↓
Separate by type
      ├─ 2D Linear Weights → Muon Optimizer → Orthogonalized Momentum Updates
      └─ Embeddings / Biases / Norms → AdamW → Adaptive Learning Rate Updates
      ↓
Training Updates Applied to Model
```

This hybrid approach allows the model to benefit from **Muon’s stability for large weight matrices** while **AdamW efficiently handles smaller or non-square parameters**, combining the advantages of both optimizers in a single training run.


In [24]:
def setup_muon_optimizer(model: nn.Module, config: ModelConfig):
    """
    Configure a hybrid optimizer setup combining Muon and AdamW.

    Certain parameters (mostly weight matrices in linear layers) are optimized
    using the Muon optimizer, which applies momentum orthogonalization for better
    convergence and stability. The remaining parameters (embeddings, biases, norms)
    are optimized using standard AdamW.

    Steps:
    1. Separate model parameters into Muon vs AdamW groups.
    2. Print parameter counts for transparency.
    3. Instantiate Muon optimizer for applicable parameters.
    4. Instantiate AdamW optimizer for remaining parameters.
    5. Return both optimizers in a list for joint training.
    """

    # Lists to hold parameters for each optimizer
    muon_params = []  # Typically weight matrices in linear layers
    adamw_params = [] # Embeddings, biases, normalization layers, etc.

    # -----------------------------
    # 1. Separate parameters
    # -----------------------------
    for name, param in model.named_parameters():
        # Criteria for Muon optimizer:
        # - 2D tensors (linear layer weights)
        # - Not embeddings or normalization layers
        # - Must require gradients
        if (param.ndim == 2 and
            'token_embedding' not in name and
            'norm' not in name and
            param.requires_grad):
            muon_params.append(param)
        else:
            adamw_params.append(param)

    # -----------------------------
    # 2. Print parameter counts
    # -----------------------------
    # Provides quick transparency for debugging and reporting
    print(f"  Muon parameters: {sum(p.numel() for p in muon_params):,}")
    print(f"  AdamW parameters: {sum(p.numel() for p in adamw_params):,}")

    # -----------------------------
    # 3. Instantiate Muon optimizer
    # -----------------------------
    # Momentum-orthogonalized updates for linear weight matrices
    muon_optimizer = Muon(
        muon_params,
        lr=config.muon_lr,      # Muon-specific learning rate
        momentum=0.95
    )

    # -----------------------------
    # 4. Instantiate AdamW optimizer
    # -----------------------------
    # Standard AdamW for embeddings, biases, and normalization layers
    adamw_optimizer = torch.optim.AdamW(
        adamw_params,
        lr=config.muon_lr * 0.1,  # Smaller learning rate than Muon
        weight_decay=config.weight_decay
    )

    # -----------------------------
    # 5. Return both optimizers
    # -----------------------------
    return [muon_optimizer, adamw_optimizer]


## **15. Training Loop**

The training loop orchestrates the end-to-end learning process for the language model.
It integrates **forward/backward passes, gradient accumulation, optimizer updates, learning rate scheduling, logging, evaluation, and checkpointing**.

### **Step-by-Step Flow**

```
Initialize Model → Move to Device → Setup Optimizers & Schedulers → Configure Mixed Precision
      ↓
Iterate over Training Batches
      ↓
Forward Pass → Compute Cross-Entropy Loss → Scale Loss (if using AMP)
      ↓
Backward Pass → Gradient Accumulation
      ↓
Gradient Clipping → Optimizer Step → Scheduler Step → Reset Gradients
      ↓
Periodic Logging → Loss / Accuracy / Perplexity
      ↓
Periodic Validation → Evaluate on Val Set → Save Best Model
      ↓
Repeat Until Max Steps Reached
      ↓
Final Evaluation → Save Final Model Checkpoint
```

### **Key Features**

* **Hybrid Optimizer:** Uses **Muon** for 2D weight matrices and **AdamW** for remaining parameters.
* **Mixed Precision (AMP):** Optional acceleration of training with minimal loss in numerical precision.
* **Gradient Accumulation:** Allows effective large-batch training even with limited GPU memory.
* **Learning Rate Scheduling:** Linear warmup followed by cosine decay for stable convergence.
* **Logging & Metrics:** Tracks token-level loss, accuracy, and perplexity every few steps.
* **Checkpointing:** Saves both the **best validation model** and the **final model**.

This loop ensures **efficient, stable, and reproducible training** of the MinimalLLM language model.

---

If you want, I can also make a **visual “full training pipeline” diagram** that combines **embedding → Transformer blocks → optimizer → evaluation → checkpointing** in a single cohesive flow, so it matches all your previous README sections visually.

Do you want me to do that next?


In [25]:
def train_model(config: ModelConfig, train_loader: DataLoader, val_loader: DataLoader):
    """
    Train a MinimalLLM model using a hybrid Muon + AdamW optimizer with optional mixed-precision.

    This function handles all steps of training from initialization to final evaluation and model saving.
    It is designed for clarity, stability, and reproducibility, with advanced features such as:
        - Gradient accumulation for effective large-batch training.
        - Mixed-precision (AMP) for faster computation with minimal memory footprint.
        - Hybrid optimizer setup (Muon for 2D weights, AdamW for others).
        - Cosine learning rate schedule with linear warmup.
        - Gradient clipping to avoid exploding gradients.
        - Periodic evaluation and best model checkpointing.
        - Logging of loss, token-level accuracy, and perplexity for monitoring.

    -----------------------------
    Training Steps Overview
    -----------------------------
    1. Initialize the MinimalLLM model and set a random seed for reproducibility.
    2. Move the model to the correct device (GPU if available, otherwise CPU).
    3. Count total trainable parameters for reference.
    4. Setup the hybrid optimizers:
        - Muon optimizer for 2D weight matrices (attention and feed-forward layers)
        - AdamW optimizer for embeddings, biases, and normalization layers
    5. Configure learning rate schedulers with linear warmup + cosine decay.
    6. Setup mixed-precision training (if enabled via config.use_amp) using torch.cuda.amp.GradScaler.
    7. Iterate through training batches:
        a) Move input (x) and target (y) tensors to device.
        b) Forward pass through the model to get logits.
        c) Compute cross-entropy loss and normalize it for gradient accumulation.
        d) Backward pass:
            - Use scaler.scale() for AMP
            - Accumulate gradients over multiple steps if gradient_accumulation_steps > 1
        e) Gradient clipping to avoid exploding gradients.
        f) Optimizer step and zero gradients after accumulation.
        g) Step the learning rate schedulers.
        h) Log training metrics every few steps:
            - Loss
            - Token-level accuracy
            - Perplexity
            - Learning rate
        i) Periodically evaluate on validation set and save best model checkpoint.
    8. After completing all training steps:
        - Perform final evaluation on validation set.
        - Save the final model checkpoint.

    Args:
        config (ModelConfig): Configuration object containing model and training hyperparameters.
        train_loader (DataLoader): PyTorch DataLoader for training dataset.
        val_loader (DataLoader): PyTorch DataLoader for validation dataset.

    Returns:
        model (nn.Module): Trained MinimalLLM model.
        final_eval (dict): Dictionary containing final evaluation metrics:
            - val_loss
            - val_accuracy
            - val_perplexity
    """

    print(f"\n🚀 Training Small model with Muon optimizer")

     # -----------------------------
    # 1. Initialize model
    # -----------------------------
    set_seed(42)  # Fix random seed for reproducibility of results
    model = MinimalLLM(config)  # Instantiate the MinimalLLM model with given configuration
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # Select GPU if available
    model = model.to(device)  # Move model parameters to selected device

    # Count and print total number of trainable parameters
    total_params = sum(p.numel() for p in model.parameters())
    print(f"  📊 Total parameters: {total_params:,}")

    # -----------------------------
    # 2. Setup optimizers and schedulers
    # -----------------------------
    optimizers = setup_muon_optimizer(model, config)  # Create Muon + AdamW hybrid optimizers

    schedulers = []
    for optimizer in optimizers:
        warmup_steps = config.max_steps // 20  # Set warmup for first 5% of steps

        def lr_lambda(step):
            # Linear warmup for first `warmup_steps` steps
            if step < warmup_steps:
                return step / warmup_steps
            else:
                # Cosine decay for remaining steps
                progress = (step - warmup_steps) / (config.max_steps - warmup_steps)
                return 0.1 + 0.9 * 0.5 * (1 + math.cos(math.pi * progress))

        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)  # Attach scheduler
        schedulers.append(scheduler)

    # -----------------------------
    # 3. Setup mixed precision (optional)
    # -----------------------------
    scaler = GradScaler() if config.use_amp else None  # Enable AMP for faster computation if required

    # -----------------------------
    # 4. Training loop setup
    # -----------------------------
    model.train()  # Set model to training mode
    step = 0  # Initialize global step counter
    start_time = time.time()  # Start timing the training
    best_val_loss = float('inf')  # Initialize best validation loss for checkpointing

    pbar = tqdm(total=config.max_steps, desc="Training")  # Create progress bar

    # -----------------------------
    # 5. Iterate over batches
    # -----------------------------
    while step < config.max_steps:
        for batch_idx, (x, y) in enumerate(train_loader):
            if step >= config.max_steps:
                break  # Stop if max training steps reached

            # Move input and target tensors to the correct device
            x, y = x.to(device), y.to(device)

            # -----------------------------
            # Forward and backward pass
            # -----------------------------
            if config.use_amp:  # Use mixed precision
                with autocast():  # Enable automatic mixed precision
                    logits = model(x)  # Forward pass
                    # Compute cross-entropy loss and normalize for gradient accumulation
                    loss = F.cross_entropy(logits.view(-1, config.vocab_size), y.view(-1))
                    loss = loss / config.gradient_accumulation_steps
                scaler.scale(loss).backward()  # Scale loss and backpropagate
            else:
                logits = model(x)  # Forward pass
                loss = F.cross_entropy(logits.view(-1, config.vocab_size), y.view(-1))
                loss = loss / config.gradient_accumulation_steps
                loss.backward()  # Backpropagation

            # -----------------------------
            # Optimizer step after gradient accumulation
            # -----------------------------
            if (step + 1) % config.gradient_accumulation_steps == 0:
                # Clip gradients to prevent exploding gradients
                grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip)

                if config.use_amp:
                    for optimizer in optimizers:
                        scaler.unscale_(optimizer)  # Unscale gradients before optimizer step
                        scaler.step(optimizer)  # Apply optimizer step
                        optimizer.zero_grad()  # Reset gradients
                    for scheduler in schedulers:
                        scheduler.step()  # Update learning rate
                    scaler.update()  # Update scaler for next iteration
                else:
                    for optimizer in optimizers:
                        optimizer.step()  # Apply optimizer step
                        optimizer.zero_grad()  # Reset gradients
                    for scheduler in schedulers:
                        scheduler.step()  # Update learning rate

            # -----------------------------
            # Logging
            # -----------------------------
            if step % 10 == 0:
                with torch.no_grad():  # Disable gradient computation for metrics
                    predictions = logits.argmax(dim=-1)  # Predicted token IDs
                    accuracy = (predictions == y).float().mean().item()  # Token-level accuracy
                    current_loss = loss.item() * config.gradient_accumulation_steps  # Rescale loss
                    perplexity = math.exp(min(current_loss, 20))  # Compute perplexity

                # Update progress bar with metrics
                pbar.set_postfix({
                    'loss': f'{current_loss:.4f}',
                    'acc': f'{accuracy:.3f}',
                    'ppl': f'{perplexity:.1f}',
                    'lr': f'{optimizers[0].param_groups[0]["lr"]:.2e}'
                })

            # -----------------------------
            # Periodic evaluation and checkpointing
            # -----------------------------
            if step % config.eval_every == 0 and step > 0:
                eval_metrics = evaluate_model(model, val_loader, config)  # Evaluate on validation set
                print(f"\nStep {step}: Val Loss: {eval_metrics['val_loss']:.4f}, "
                      f"Val Acc: {eval_metrics['val_accuracy']:.4f}, "
                      f"Val PPL: {eval_metrics['val_perplexity']:.2f}")

                # Save checkpoint if validation loss improves
                if eval_metrics['val_loss'] < best_val_loss:
                    best_val_loss = eval_metrics['val_loss']
                    torch.save({
                        'model_state_dict': model.state_dict(),
                        'config': config,
                        'step': step,
                        'best_val_loss': best_val_loss,
                        'final_metrics': eval_metrics
                    }, 'best_model.pt')
                    print(f"💾 Saved best model with val_loss: {best_val_loss:.4f}")

            step += 1  # Increment global step
            if step % 10 == 0:
                pbar.update(10)  # Update progress bar

    pbar.close()  # Close progress bar at the end of training

    # -----------------------------
    # Final evaluation and saving
    # -----------------------------
    training_time = time.time() - start_time
    print(f"  ⏱️ Training completed in {training_time:.1f} seconds")

    final_eval = evaluate_model(model, val_loader, config)  # Evaluate final model
    print(f"  📊 Final - Loss: {final_eval['val_loss']:.4f}, "
          f"Acc: {final_eval['val_accuracy']:.4f}, PPL: {final_eval['val_perplexity']:.2f}")

    # Save final model checkpoint
    torch.save({
        'model_state_dict': model.state_dict(),
        'config': config,
        'step': step,
        'final_metrics': final_eval
    }, 'final_model.pt')
    print(f"💾 Saved final model to final_model.pt")

    return model, final_eval  # Return trained model and final evaluation metrics

## **16. Main Training Script**

This script orchestrates the end-to-end training of the **MinimalLLM** model.

**Step-by-Step Overview:**

1. **Check System and Device**

   * Detects whether GPU (CUDA) is available; prints GPU name and memory if applicable.

2. **Set Random Seed**

   * Ensures reproducible results for model initialization, data shuffling, and training.

3. **Create Model Configuration**

   * Initializes a `ModelConfig` object with architecture and training hyperparameters.
   * Prints a summary of model size, training steps, batch size, and dataset details.

4. **Load and Preprocess Data**

   * Tokenizes raw text and converts it to token IDs.
   * Wraps tokens in a PyTorch `TextTokenDataset` for easy batching.

5. **Train/Validation Split**

   * Splits dataset into 90% training and 10% validation.
   * Uses a fixed random seed for reproducibility.

6. **Create PyTorch DataLoaders**

   * Wraps datasets in `DataLoader` for batched training and evaluation.
   * Enables shuffling for training, disables shuffling for validation.

7. **Train the Model**

   * Calls `train_model()` to run the full training loop with hybrid Muon + AdamW optimizer, gradient accumulation, optional mixed precision, logging, and checkpointing.

8. **Print Final Training Summary**

   * Reports total training time.
   * Displays final validation metrics: loss, token-level accuracy, and perplexity.

This main script provides a reproducible and fully automated pipeline for training a small transformer-based language model on your dataset.

In [26]:
if __name__ == "__main__":
    # -----------------------------
    # 1. Check system and device
    # -----------------------------
    device_name = 'CUDA' if torch.cuda.is_available() else 'CPU'
    print(f"🔍 Device: {device_name}")  # Show whether GPU or CPU will be used

    if torch.cuda.is_available():
        # Print GPU name
        print(f"GPU: {torch.cuda.get_device_name()}")
        # Print total GPU memory in GB
        total_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
        print(f"Memory: {total_mem_gb:.1f} GB")

    # -----------------------------
    # 2. Set random seed
    # -----------------------------
    set_seed(42)  # Ensures reproducibility of model initialization, data shuffling, etc.

    # -----------------------------
    # 3. Create model configuration
    # -----------------------------
    config = ModelConfig()  # Initialize default configuration for the Small MinimalLLM
    print(f"\n📋 Model Configuration:")
    print(f"   Architecture: {config.d_model}d, {config.n_layers}L, {config.n_heads}H, {config.d_ff}ff")
    print(f"   Training: {config.max_steps} steps, batch size {config.batch_size}")
    print(f"   Data: {config.max_tokens:,} tokens, seq_len {config.max_seq_len}")

    # -----------------------------
    # 4. Load and preprocess data
    # -----------------------------
    # load_and_cache_data() should return:
    # - texts: raw text strings
    # - tokenizer: tokenizer object for encoding/decoding
    # - tokens: list of token ids for all text
    texts, tokenizer, tokens = load_and_cache_data(config)

    # Wrap token ids in a PyTorch dataset for easy batching
    dataset = TextTokenDataset(tokens, config.max_seq_len)

    # -----------------------------
    # 5. Train/validation split
    # -----------------------------
    val_size = len(dataset) // 10  # 10% for validation
    train_size = len(dataset) - val_size  # 90% for training

    # Random split with fixed seed for reproducibility
    train_dataset, val_dataset = torch.utils.data.random_split(
        dataset, [train_size, val_size], generator=torch.Generator().manual_seed(42)
    )

    # -----------------------------
    # 6. Create PyTorch DataLoaders
    # -----------------------------
    train_loader = DataLoader(
        train_dataset,
        batch_size=config.batch_size,
        shuffle=True,  # Shuffle training data each epoch
        num_workers=2  # Number of subprocesses for data loading
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=config.batch_size,
        shuffle=False,  # No need to shuffle validation data
        num_workers=2
    )

    print(f"📊 Dataset: {len(train_dataset)} train, {len(val_dataset)} val samples")

    # -----------------------------
    # 7. Train the model
    # -----------------------------
    start_time = time.time()  # Record start time for total training duration
    model, final_metrics = train_model(config, train_loader, val_loader)  # Call training function
    total_time = time.time() - start_time  # Compute total training time

    # -----------------------------
    # 8. Print summary of training
    # -----------------------------
    print(f"\n🎉 TRAINING COMPLETED!")
    print(f"⏱️ Total time: {total_time/60:.1f} minutes")  # Convert seconds to minutes
    print(f"🏆 Final Results:")
    print(f"   Validation Loss: {final_metrics['val_loss']:.4f}")  # Cross-entropy loss
    print(f"   Validation Accuracy: {final_metrics['val_accuracy']:.4f}")  # Token-level accuracy
    print(f"   Validation Perplexity: {final_metrics['val_perplexity']:.2f}")  # Exponential of loss


🔍 Device: CUDA
GPU: Tesla T4
Memory: 15.8 GB
Set all seeds to 42

📋 Model Configuration:
   Architecture: 384d, 6L, 8H, 1536ff
   Training: 2000 steps, batch size 24
   Data: 500,000 tokens, seq_len 512
Loading cached data from data_cache/tokenized_data_2000_500000.pkl
Loaded 2000 documents and 500,000 tokens from cache
📊 Dataset: 449540 train, 49948 val samples

🚀 Training Small model with Muon optimizer
Set all seeds to 42
  📊 Total parameters: 32,150,976
  Muon parameters: 13,271,040
  AdamW parameters: 18,879,936




Training:   0%|          | 0/2000 [00:00<?, ?it/s][A[A

Training:   0%|          | 0/2000 [39:50<?, ?it/s]
W1020 14:41:34.085000 1316 torch/_inductor/utils.py:1436] [0/0] Not enough SMs to use max_autotune_gemm mode


Training:   0%|          | 10/2000 [00:06<22:25,  1.48it/s, loss=10.8048, acc=0.015, ppl=49254.8, lr=0.00e+00][A[A

Training:   0%|          | 10/2000 [00:07<22:25,  1.48it/s, loss=10.8061, acc=0.014, ppl=49318.8, lr=2.00e-04][A[A

Training:   1%|          | 20/2000 [00:10<16:25,  2.01it/s, loss=10.8061, acc=0.014, ppl=49318.8, lr=2.00e-04][A[A

Training:   1%|          | 20/2000 [00:10<16:25,  2.01it/s, loss=10.7881, acc=0.016, ppl=48441.5, lr=5.00e-04][A[A

Training:   2%|▏         | 30/2000 [00:13<13:59,  2.35it/s, loss=10.7881, acc=0.016, ppl=48441.5, lr=5.00e-04][A[A

Training:   2%|▏         | 30/2000 [00:14<13:59,  2.35it/s, loss=10.7615, acc=0.016, ppl=47169.0, lr=7.00e-04][A[A

Training:   2%|▏         | 40/2000 [00:17<13:17,  2.46it/s, loss=10.76


Step 500: Val Loss: 5.9331, Val Acc: 0.2099, Val PPL: 377.32
💾 Saved best model with val_loss: 5.9331




Training:  26%|██▌       | 510/2000 [03:18<18:06,  1.37it/s, loss=6.1523, acc=0.196, ppl=469.8, lr=1.00e-02][A[A

Training:  26%|██▌       | 510/2000 [03:19<18:06,  1.37it/s, loss=6.0776, acc=0.194, ppl=436.0, lr=1.00e-02][A[A

Training:  26%|██▌       | 520/2000 [03:22<15:21,  1.61it/s, loss=6.0776, acc=0.194, ppl=436.0, lr=1.00e-02][A[A

Training:  26%|██▌       | 520/2000 [03:22<15:21,  1.61it/s, loss=5.8630, acc=0.209, ppl=351.8, lr=9.99e-03][A[A

Training:  26%|██▋       | 530/2000 [03:25<13:10,  1.86it/s, loss=5.8630, acc=0.209, ppl=351.8, lr=9.99e-03][A[A

Training:  26%|██▋       | 530/2000 [03:26<13:10,  1.86it/s, loss=5.9263, acc=0.210, ppl=374.7, lr=9.99e-03][A[A

Training:  27%|██▋       | 540/2000 [03:29<11:53,  2.05it/s, loss=5.9263, acc=0.210, ppl=374.7, lr=9.99e-03][A[A

Training:  27%|██▋       | 540/2000 [03:29<11:53,  2.05it/s, loss=5.8789, acc=0.218, ppl=357.4, lr=9.99e-03][A[A

Training:  28%|██▊       | 550/2000 [03:32<10:44,  2.25it/s, loss=5.87


Step 1000: Val Loss: 4.6150, Val Acc: 0.3280, Val PPL: 100.99
💾 Saved best model with val_loss: 4.6150




Training:  50%|█████     | 1010/2000 [06:29<12:16,  1.34it/s, loss=4.9164, acc=0.287, ppl=136.5, lr=9.86e-03][A[A

Training:  50%|█████     | 1010/2000 [06:30<12:16,  1.34it/s, loss=4.8647, acc=0.308, ppl=129.6, lr=9.86e-03][A[A

Training:  51%|█████     | 1020/2000 [06:33<10:20,  1.58it/s, loss=4.8647, acc=0.308, ppl=129.6, lr=9.86e-03][A[A

Training:  51%|█████     | 1020/2000 [06:34<10:20,  1.58it/s, loss=4.7778, acc=0.311, ppl=118.8, lr=9.85e-03][A[A

Training:  52%|█████▏    | 1030/2000 [06:37<08:48,  1.84it/s, loss=4.7778, acc=0.311, ppl=118.8, lr=9.85e-03][A[A

Training:  52%|█████▏    | 1030/2000 [06:37<08:48,  1.84it/s, loss=4.8822, acc=0.292, ppl=131.9, lr=9.85e-03][A[A

Training:  52%|█████▏    | 1040/2000 [06:40<07:53,  2.03it/s, loss=4.8822, acc=0.292, ppl=131.9, lr=9.85e-03][A[A

Training:  52%|█████▏    | 1040/2000 [06:41<07:53,  2.03it/s, loss=4.9360, acc=0.294, ppl=139.2, lr=9.84e-03][A[A

Training:  52%|█████▎    | 1050/2000 [06:44<07:04,  2.24it/s, 


Step 1500: Val Loss: 3.7905, Val Acc: 0.4302, Val PPL: 44.28
💾 Saved best model with val_loss: 3.7905




Training:  76%|███████▌  | 1510/2000 [09:40<05:59,  1.36it/s, loss=4.2221, acc=0.369, ppl=68.2, lr=9.54e-03][A[A

Training:  76%|███████▌  | 1510/2000 [09:41<05:59,  1.36it/s, loss=4.0710, acc=0.379, ppl=58.6, lr=9.54e-03][A[A

Training:  76%|███████▌  | 1520/2000 [09:44<05:00,  1.60it/s, loss=4.0710, acc=0.379, ppl=58.6, lr=9.54e-03][A[A

Training:  76%|███████▌  | 1520/2000 [09:44<05:00,  1.60it/s, loss=4.0592, acc=0.382, ppl=57.9, lr=9.53e-03][A[A

Training:  76%|███████▋  | 1530/2000 [09:47<04:13,  1.85it/s, loss=4.0592, acc=0.382, ppl=57.9, lr=9.53e-03][A[A

Training:  76%|███████▋  | 1530/2000 [09:48<04:13,  1.85it/s, loss=4.2892, acc=0.360, ppl=72.9, lr=9.52e-03][A[A

Training:  77%|███████▋  | 1540/2000 [09:51<03:45,  2.04it/s, loss=4.2892, acc=0.360, ppl=72.9, lr=9.52e-03][A[A

Training:  77%|███████▋  | 1540/2000 [09:51<03:45,  2.04it/s, loss=4.0346, acc=0.383, ppl=56.5, lr=9.51e-03][A[A

Training:  78%|███████▊  | 1550/2000 [09:55<03:20,  2.25it/s, loss=4.0

  ⏱️ Training completed in 755.4 seconds





  📊 Final - Loss: 3.1932, Acc: 0.5054, PPL: 24.37
💾 Saved final model to final_model.pt

🎉 TRAINING COMPLETED!
⏱️ Total time: 12.8 minutes
🏆 Final Results:
   Validation Loss: 3.1932
   Validation Accuracy: 0.5054
   Validation Perplexity: 24.37


## **17. Model Loading**

After training, we can load our saved model.

In [27]:
def load_trained_model(model_path: str = "final_model.pt"):
    """
    Load a trained MinimalLLM model from a checkpoint.

    This function handles:
        - Safe deserialization for PyTorch 2.6+.
        - Loading model configuration and weights.
        - Moving the model to GPU if available.
        - Setting the model to evaluation mode.

    Args:
        model_path (str): Path to the checkpoint file.

    Returns:
        model (nn.Module): Loaded MinimalLLM model ready for inference.
        config (ModelConfig): Model configuration used for training.
    """
    print(f"🔄 Loading model from {model_path}")

    # Ensure ModelConfig is available for safe deserialization (PyTorch 2.6+)
    from torch.serialization import add_safe_globals
    add_safe_globals([ModelConfig])

    # Load checkpoint
    try:
        # Standard loading
        checkpoint = torch.load(model_path, map_location='cpu')
        config = checkpoint['config']
    except Exception as e:
        # Fallback for potential PyTorch serialization differences
        print(f"⚠️ Warning: failed standard load. Retrying with weights_only=False...")
        checkpoint = torch.load(model_path, map_location='cpu', weights_only=False)
        config = checkpoint['config']

    # Initialize model with loaded config
    model = MinimalLLM(config)

    # Load weights into model
    model.load_state_dict(checkpoint['model_state_dict'])

    # Move model to available device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)

    # Set model to evaluation mode (disables dropout, etc.)
    model.eval()

    # Print summary
    total_params = sum(p.numel() for p in model.parameters())
    print(f"✅ Model loaded successfully")
    print(f"   Parameters: {total_params:,}")
    print(f"   Device: {device}")

    return model, config

## **18a. Model Inference - Text Generation**

In [28]:
def generate_text(
    model: nn.Module,
    tokenizer,
    prompt: str,
    max_length: int = 100,
    temperature: float = 0.8,
    top_k: int = 50,
    top_p: float = 0.9
):
    """
    Generate text autoregressively using a trained language model.

    The model predicts one token at a time, conditioned on the prompt and previously
    generated tokens. Sampling is controlled by `temperature`, `top_k`, and `top_p`
    (nucleus sampling) to balance creativity and coherence.

    Args:
        model (nn.Module): Trained language model.
        tokenizer: Tokenizer for encoding and decoding text.
        prompt (str): Initial text prompt to condition generation.
        max_length (int): Maximum number of new tokens to generate.
        temperature (float): Scales output logits — higher = more random sampling.
        top_k (int): Keeps only the top-k most probable tokens.
        top_p (float): Keeps the smallest set of tokens whose cumulative probability ≥ top_p.

    Returns:
        str: The generated text (prompt + model-generated continuation).
    """

    # Ensure model is in inference mode (disables dropout, etc.)
    model.eval()

    # Detect which device (CPU/GPU) the model is on
    device = next(model.parameters()).device

    # Convert prompt into token IDs and move to model's device
    input_ids = tokenizer.encode(prompt, add_special_tokens=False, return_tensors='pt').to(device)

    # This tensor will grow as we generate more tokens
    generated_ids = input_ids.clone()

    with torch.no_grad():  # Disable gradients for faster inference
        for _ in range(max_length):
            # 1. Forward pass: get logits for the current sequence
            logits = model(generated_ids)

            # 2. Extract the logits for the last token (the next-token prediction)
            next_token_logits = logits[0, -1, :] / temperature

            # 3. Apply Top-K filtering (keep only the top K most likely tokens)
            if top_k > 0:
                top_k_logits, top_k_indices = torch.topk(next_token_logits, top_k)
                filtered_logits = torch.full_like(next_token_logits, float('-inf'))
                filtered_logits[top_k_indices] = top_k_logits
                next_token_logits = filtered_logits

            # 4. Apply Top-P (nucleus) filtering — keep tokens that make up top_p cumulative probability
            if top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
                cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

                # Mask out tokens beyond the top_p threshold
                sorted_indices_to_remove = cumulative_probs > top_p
                sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].clone()
                sorted_indices_to_remove[0] = 0  # Always keep the highest probability token
                indices_to_remove = sorted_indices[sorted_indices_to_remove]
                next_token_logits[indices_to_remove] = float('-inf')

            # 5. Sample next token based on filtered probabilities
            probs = F.softmax(next_token_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)

            # 6. Append new token to the generated sequence
            next_token = next_token.unsqueeze(0)  # Add batch dimension
            generated_ids = torch.cat([generated_ids, next_token], dim=1)

            # 7. Stop if the model predicts the end-of-sequence token
            if next_token.item() == tokenizer.eos_token_id:
                break

    # 8. Decode all generated tokens back into text
    generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
    return generated_text


In [29]:
def demo_inference(model_path: str = "final_model.pt"):
    """
    Run a quick text generation demo to showcase the model's capabilities.

    This function automatically loads the trained model and tokenizer,
    then generates short text samples from a list of example prompts.
    It's a fast, hands-free way to verify that training and inference work correctly.

    Args:
        model_path (str): Path to the saved model checkpoint file (default: "final_model.pt")
    """

    print("🎭 Running inference demo — showcasing model creativity and coherence")

    # ------------------------------------------------------
    # 1. Load the trained model and its configuration
    # ------------------------------------------------------
    model, config = load_trained_model(model_path)

    # ------------------------------------------------------
    # 2. Load tokenizer (same as training tokenizer)
    # ------------------------------------------------------
    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM-135M")

    # Ensure the tokenizer has a valid padding token
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # ------------------------------------------------------
    # 3. Define demo prompts — diverse topics to test creativity
    # ------------------------------------------------------
    demo_prompts = [
        "A mysterious signal was detected from deep space",
        "The future of human and AI collaboration is",
        "Every morning, the robot barista greeted its customers by saying",
        "In the ruins of an ancient digital city, explorers found",
        "When the sun finally rose over the last surviving colony,",
        "The secret ingredient to lifelong curiosity is",
        "By 2099, memories became a form of currency",
        "The scientist stared at the hologram and realized",
        "In the middle of the storm, the android whispered",
        "The code began to write itself when"
    ]

    # ------------------------------------------------------
    # 4. Loop through each demo prompt and generate text
    # ------------------------------------------------------
    for i, prompt in enumerate(demo_prompts, 1):
        print(f"\n🧠 Demo {i}: '{prompt}'")
        print("-" * 60)

        # Generate text for each prompt
        generated_text = generate_text(
            model=model,
            tokenizer=tokenizer,
            prompt=prompt,
            max_length=120,     # Slightly longer outputs for richness
            temperature=0.7,    # Moderate creativity
            top_k=40,           # Focus sampling on top 40 likely tokens
            top_p=0.85          # Nucleus sampling for coherence and diversity
        )

        # Display generated result
        print(f"📝 {generated_text}\n")

    print("✅ Demo complete — model generation looks healthy and coherent!")


In [30]:
if __name__ == "__main__":
    # Check if we have a trained model
    import os

    if os.path.exists("final_model.pt"):
        print("🎉 Found trained model! Running demo...")
        demo_inference("final_model.pt")

    else:
        print("⚠️ No trained model found. Please run the training cells first.")
        print("💡 Look for 'final_model.pt' or 'best_model.pt' in your directory.")

🎉 Found trained model! Running demo...
🎭 Running inference demo — showcasing model creativity and coherence
🔄 Loading model from final_model.pt
✅ Model loaded successfully
   Parameters: 32,150,976
   Device: cuda

🧠 Demo 1: 'A mysterious signal was detected from deep space'
------------------------------------------------------------
📝 A mysterious signal was detected from deep space and create something called the 't'' comes the "The Dottie."

And so, they learned about this, and even more about their experiences. Brandon and its impact on how they found themselves.

Emily took a small town's work. She looked like how to make his friends and even more about them. With a special places, Mr. Brandon smiled and asked Mr.

"I want my love."

"That sounds fascinating things, "It means. But I see you have your favorite!"

Sally the fun and said,


🧠 Demo 2: 'The future of human and AI collaboration is'
------------------------------------------------------------
📝 The future of human and A

## **18b. Model Inference - Chat Interaction**

In [31]:
def interactive_inference(model_path: str = "final_model.pt"):
    """
    Launch an interactive text generation session with a trained language model.

    The user enters prompts, and the model responds with generated text.
    This simulates an interactive chat or creative writing interface.

    Args:
        model_path (str): Path to the saved model checkpoint file.

    Example:
        >>> interactive_inference("final_model.pt")
        🤖 Starting interactive inference session
        Enter your prompt: Once upon a time...
        📝 Generated text: Once upon a time, in a faraway land...
    """

    print("🤖 Starting interactive inference session")
    print("Type 'quit' or 'exit' to stop\n")

    # ------------------------------------------------------
    # 1. Load the trained model and its configuration
    # ------------------------------------------------------
    model, config = load_trained_model(model_path)

    # ------------------------------------------------------
    # 2. Load the same tokenizer used during training
    # ------------------------------------------------------
    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM-135M")

    # Ensure pad token is defined (some tokenizers don’t include it)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # ------------------------------------------------------
    # 3. Continuous prompt-response loop
    # ------------------------------------------------------
    while True:
        try:
            # Get input prompt from the user
            prompt = input("\n🗨️  Enter your prompt: ").strip()

            # Exit conditions
            if prompt.lower() in {"quit", "exit", "q"}:
                print("\n👋 Goodbye!")
                break
            if not prompt:
                continue  # Skip empty input lines

            # Generate text
            print("\n🔄 Generating response...\n")
            generated_text = generate_text(
                model=model,
                tokenizer=tokenizer,
                prompt=prompt,
                max_length=150,
                temperature=0.8,
                top_k=50,
                top_p=0.9
            )

            # Display the model's output
            print("📝 Generated text:\n")
            print(generated_text)

        except KeyboardInterrupt:
            print("\n👋 Session interrupted by user.")
            break
        except Exception as e:
            print(f"❌ Error during generation: {e}")

In [33]:
if __name__ == "__main__":
    # Check if we have a trained model
    import os

    if os.path.exists("final_model.pt"):
        interactive_inference("final_model.pt")

    else:
        print("⚠️ No trained model found. Please run the training cells first.")
        print("💡 Look for 'final_model.pt' or 'best_model.pt' in your directory.")

🤖 Starting interactive inference session
Type 'quit' or 'exit' to stop

🔄 Loading model from final_model.pt
✅ Model loaded successfully
   Parameters: 32,150,976
   Device: cuda

🗨️  Enter your prompt: hi

🔄 Generating response...

📝 Generated text:

hi, the heart of the beginning of the world.

Section 3: What is Cultural Heritage

Imagine trying to bring friends and family, and family. To achieve this, people who lived, known as the world, and stories of us. However, many people lived there were some people called "The Mahanta" as a country. Let's learn about what makes it on the story and how they believe in today.

Section 4: What does 'A "The Mahanta"

* **
* **B:** Our group has a particular group of "The Mahanta" by the "The Mahanta" through the United States, a "The Mahanta" through the "The Mahanta" by various

👋 Session interrupted by user.
