# MLX‑Test: Pedagogical Walk‑Through of Your OpenELM Architecture

This notebook is meant to **teach** every moving part of the exact model you are training.  
For each major block you’ll find:

1. **Purpose & intuition** – why the block exists and design trade‑offs.  
2. **Key equations / pseudocode** in plain English.  
3. **The real MLX implementation** from your repo so you can run or edit live.

Feel free to experiment: change hidden sizes, activation functions, etc., and re‑run cells.

*Generated 2025‑07‑14 02:38*


## 1. Hyper‑Parameter Dataclass — `SMLMConfig`

All hyper‑parameters (HPs) live in one dataclass, so the model code reads them from a single
place instead of hard‑coding them.  We separate:

* **Architecture HPs** (e.g. `model_dim`, `head_dim`, number of layers).
* **Training HPs** (batch size, learning‑rate schedule, etc.).

Loading from *config.json* keeps experiments reproducible and makes sweeps trivial
— swap JSON files, not code.

In [None]:
import dataclasses, json, pathlib
from typing import List

@dataclasses.dataclass
class SMLMConfig:
    """Model + training hyper‑parameters."""
    tokenizer_path: str
    checkpoint_dir: str
    vocab_size: int
    model_dim: int
    num_transformer_layers: int
    head_dim: int
    num_query_heads: List[int]
    num_kv_heads: List[int]
    num_gqa_groups: int
    normalize_qk_projections: bool
    ffn_multipliers: List[float]
    ffn_dim_divisor: int
    ffn_with_glu: bool
    rope_freq_constant: int
    rope_max_length: int
    normalization_layer_name: str
    activation_fn_name: str
    initializer_range: float
    share_input_output_layers: bool
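    dropout: float = 0.0  # read by FeedForward; optional in config.json
    # training‑time HPs omitted for brevity …

    @classmethod
    def from_json(cls, path):
        """Load a config, ignoring JSON keys this (trimmed) dataclass does not declare."""
        raw = json.loads(pathlib.Path(path).read_text())
        known = {f.name for f in dataclasses.fields(cls)}
        return cls(**{k: v for k, v in raw.items() if k in known})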
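
A minimal, self-contained sketch of the JSON round trip. It assumes a trimmed, hypothetical `ToyConfig` with only a few of the fields above and made-up values; it is not your real *config.json*. Keys the class does not declare (such as training HPs) are filtered out before construction.

```python
import dataclasses, json, pathlib, tempfile
from typing import List

@dataclasses.dataclass
class ToyConfig:
    """Hypothetical, trimmed stand-in for SMLMConfig (illustration only)."""
    model_dim: int
    head_dim: int
    num_query_heads: List[int]
    num_kv_heads: List[int]
    ffn_multipliers: List[float]

    @classmethod
    def from_json(cls, path):
        raw = json.loads(pathlib.Path(path).read_text())
        known = {f.name for f in dataclasses.fields(cls)}
        # Drop keys this class does not declare (e.g. training-only settings).
        return cls(**{k: v for k, v in raw.items() if k in known})

# Made-up values for a 2-layer toy model (2 query heads per KV head in every layer).
toy = {
    "model_dim": 64,
    "head_dim": 16,
    "num_query_heads": [4, 8],
    "num_kv_heads": [2, 4],
    "ffn_multipliers": [0.5, 1.0],
    "batch_size": 8,  # training HP, ignored by ToyConfig
}

with tempfile.TemporaryDirectory() as d:
    path = pathlib.Path(d) / "config.json"
    path.write_text(json.dumps(toy))
    cfg = ToyConfig.from_json(path)

print(cfg.model_dim, cfg.num_kv_heads)  # 64 [2, 4]
```

Sweeping a hyper‑parameter then only means writing a different JSON file.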


## 2. Utility Helpers

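* **`RMSNorm`** – root‑mean‑square normalization: rescales each feature vector by its RMS and a learned gain, skipping LayerNorm's mean subtraction and bias, so it is slightly cheaper and invariant to rescaling of the input.  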
* **`repeat_kv`** – duplicates keys/values so multiple *query* heads can share them
  (Grouped‑Query Attention).

In [None]:
import mlx.nn as nn, mlx.core as mx

RMSNorm = nn.RMSNorm  # thin alias

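def repeat_kv(x: mx.array, n: int, axis: int = 2) -> mx.array:
    """Repeat each KV head `n` times along the heads `axis` for grouped‑query attention.

    mx.repeat yields h0 (n times), h1 (n times), ..., so query head j reads KV head j // n.
    """
    return x if n == 1 else mx.repeat(x, n, axis=axis)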
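
For intuition, here is a NumPy sketch of what the two helpers compute. It is a reference re‑implementation, assuming `nn.RMSNorm` computes \(x/\sqrt{\text{mean}(x^2)+\epsilon}\cdot w\) over the last axis; `np.repeat` uses the same element ordering as `mx.repeat`. Sizes are arbitrary.

```python
import numpy as np

def rms_norm(x, weight, eps=1e-6):
    # Divide each vector by its root-mean-square over the last axis, then apply a learned gain.
    rms = np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)
    return x / rms * weight

def repeat_kv_np(x, n, axis):
    # Each KV head is copied n times in a row: [h0, h0, h1, h1] for n = 2.
    return np.repeat(x, n, axis=axis)

rng = np.random.default_rng(0)
x = rng.normal(size=(2, 5, 8))          # (batch, seq, features)
w = np.ones(8)

# Scale invariance: multiplying the input by c > 0 leaves the output almost unchanged.
y1 = rms_norm(x, w)
y2 = rms_norm(3.0 * x, w)
print(np.abs(y1 - y2).max())            # tiny; differs only through eps

# GQA tiling: 2 KV heads shared by 6 query heads (groups = 3), layout (B, H, L, D).
kv = rng.normal(size=(1, 2, 4, 8))
kv_tiled = repeat_kv_np(kv, 3, axis=1)
print(kv_tiled.shape)                   # (1, 6, 4, 8)
print([j // 3 for j in range(6)])       # [0, 0, 0, 1, 1, 1]: KV head read by each query head
```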


## 3. Position‑wise Feed‑Forward Network (FFN)

### Why?
Each token, **independently**, gets a non‑linear transformation after attention.
It mixes information across hidden features (but not across sequence positions).

### Design choices
| Choice | Your value | Reason |
|---|---|---|
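| Hidden multiplier schedule | `ffn_multipliers` (one entry per layer) | Smaller FFNs in early layers save parameters; larger ones add capacity deeper in the stack |
| GLU variant | **SwiGLU** when `ffn_with_glu=True` | Gating usually improves quality, but `proj_in` doubles to `2 × hidden`, so the FFN holds 3 (not 2) `model_dim × hidden` weight matrices |
| Activation | SiLU (aka Swish) if `activation_fn_name == "swish"`, else GELU | Smooth and self‑gating; the standard choice inside SwiGLU |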

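Equation (with GLU), where \(\text{SiLU}(z)=z\,\sigma(z)\) and \(\sigma\) is the logistic sigmoid:  
\[\text{FFN}(x)=W_{2}\big(\text{SiLU}(W_{1,a}x)\odot W_{1,b}x\big)\]

Here \(W_{1,a}\) and \(W_{1,b}\) are the gate and value halves of `proj_in`, and \(W_2\) is `proj_out`.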

Below is the exact implementation.

In [None]:
import math, mlx.nn as nn, mlx.core as mx

class FeedForward(nn.Module):
    def __init__(self, cfg: SMLMConfig, idx: int):
        super().__init__()
        mult = cfg.ffn_multipliers[idx]
        # Round the hidden width up to a multiple of ffn_dim_divisor.
        hidden = math.ceil(mult * cfg.model_dim / cfg.ffn_dim_divisor) * cfg.ffn_dim_divisor
        out_feats = hidden * 2 if cfg.ffn_with_glu else hidden

        # 3.1 Projection to hidden (or 2× hidden for GLU: gate half + value half)
        self.proj_in = nn.Linear(cfg.model_dim, out_feats, bias=False)

        # 3.2 Non‑linearity
        self.act = nn.SiLU() if cfg.activation_fn_name == "swish" else nn.GELU()
        self.use_glu = cfg.ffn_with_glu
        # Fall back to no dropout if the config does not define `dropout`.
        self.dropout = nn.Dropout(getattr(cfg, "dropout", 0.0))

        # 3.3 Back to model_dim
        self.proj_out = nn.Linear(hidden, cfg.model_dim, bias=False)

    def __call__(self, x):
        y = self.proj_in(x)
        if self.use_glu:
            a, b = mx.split(y, 2, axis=-1)  # gate & value
            y = self.act(a) * b
        else:
            y = self.act(y)
        return self.proj_out(self.dropout(y))
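
A NumPy sketch of the same computation, useful for checking sizes by hand. It assumes hypothetical toy values (`model_dim = 64`, `ffn_dim_divisor = 32`, random weights) and omits dropout; it reuses the rounding rule and gate/value split from `FeedForward`, and also checks the "position‑wise" claim from the intuition above.

```python
import math
import numpy as np

def ffn_hidden(mult, model_dim, divisor):
    # Same rounding rule as FeedForward: round up to a multiple of `divisor`.
    return math.ceil(mult * model_dim / divisor) * divisor

def silu(z):
    return z / (1.0 + np.exp(-z))

def swiglu_ffn(x, w_in, w_out):
    # x: (..., d); w_in: (d, 2*h); w_out: (h, d)
    y = x @ w_in
    a, b = np.split(y, 2, axis=-1)     # gate half, value half
    return (silu(a) * b) @ w_out

d, divisor = 64, 32                     # hypothetical toy sizes
h = ffn_hidden(0.5, d, divisor)         # 32.0 is already a multiple of 32
h_big = ffn_hidden(1.3, d, divisor)     # 83.2 is rounded up to 96
print(h, h_big)                         # 32 96

rng = np.random.default_rng(0)
w_in = rng.normal(size=(d, 2 * h_big)) * 0.02
w_out = rng.normal(size=(h_big, d)) * 0.02
x = rng.normal(size=(3, 10, d))         # (batch, seq, d)

out = swiglu_ffn(x, w_in, w_out)
# Position-wise: running one token alone matches its output inside the sequence.
alone = swiglu_ffn(x[:, 4:5], w_in, w_out)
print(out.shape, np.allclose(out[:, 4:5], alone))

# Weight count: GLU uses d*2h + h*d = 3*d*h, versus 2*d*h without gating.
print(d * 2 * h_big + h_big * d, 2 * d * h_big)  # 18432 12288
```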


## 4. Grouped‑Query Attention (GQA) with Rotary Positional Encoding

### 4.1 Quick refresher  
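In a causal decoder, multi‑head attention lets each token attend to itself and every earlier token.  
GQA reduces **KV** redundancy: many *query* heads share a smaller set of *key/value* heads, so the K/V projections and the KV cache shrink by a factor of \(H_q/H_{kv}\) relative to standard multi‑head attention.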


### 4.2 Head shapes

| Symbol | Value in layer *i* |
|---|---|
| \(H_q\) query heads | `num_query_heads[i]` |
| \(H_{kv}\) key/value heads | `num_kv_heads[i]` |
| \(G = H_q/H_{kv}\) (query heads per KV head) | `num_gqa_groups` (one value for all layers) |
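
A quick arithmetic check of what grouping buys, assuming hypothetical sizes: 12 query heads, 3 KV heads, `head_dim` 64, a 2048‑token context and 2‑byte (fp16/bf16) values. The KV cache scales with \(H_{kv}\), not \(H_q\).

```python
def kv_cache_bytes(num_kv_heads, head_dim, seq_len, bytes_per_value=2):
    # Per layer and per sequence: K and V each store seq_len x num_kv_heads x head_dim values.
    return 2 * seq_len * num_kv_heads * head_dim * bytes_per_value

Hq, Hkv, D, T = 12, 3, 64, 2048       # hypothetical sizes
groups = Hq // Hkv                     # 4 query heads per KV head

# Which KV head each query head reads (matches the repeat_kv ordering).
mapping = [j // groups for j in range(Hq)]
print(mapping)  # [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2]

mha = kv_cache_bytes(Hq, D, T)         # standard MHA: one KV head per query head
gqa = kv_cache_bytes(Hkv, D, T)
print(mha, gqa, mha // gqa)            # 6291456 1572864 4
```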

### 4.3 Positional Encoding  
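**RoPE** treats each pair of Q/K features as a complex number and rotates it by an angle proportional to the token's position, with a different frequency per pair (base set by `rope_freq_constant`). Because both Q and K are rotated, the score \(q_m^\top k_n\) depends only on the relative offset \(m-n\), and no learned position table is needed.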
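
A NumPy sketch of the rotation. It assumes the "half‑split" pairing (feature \(i\) paired with \(i + D/2\)), intended to mirror `nn.RoPE` with its default `traditional=False`, and frequencies \(\theta_i = \text{base}^{-2i/D}\) with base 10000 standing in for `rope_freq_constant`; treat both as assumptions. The checks show that shifting both positions by the same amount leaves the score unchanged.

```python
import numpy as np

def rope(x, pos, base=10000.0):
    # x: (..., D) with D even; pos: integer position of this vector.
    D = x.shape[-1]
    half = D // 2
    inv_freq = base ** (-np.arange(0, half) * 2.0 / D)   # theta_i = base^(-2i/D)
    angle = pos * inv_freq
    cos, sin = np.cos(angle), np.sin(angle)
    x1, x2 = x[..., :half], x[..., half:]
    # Rotate each (x1[i], x2[i]) pair by angle[i].
    return np.concatenate([x1 * cos - x2 * sin, x1 * sin + x2 * cos], axis=-1)

rng = np.random.default_rng(0)
q, k = rng.normal(size=64), rng.normal(size=64)

s1 = rope(q, 10) @ rope(k, 7)      # offset 3
s2 = rope(q, 110) @ rope(k, 107)   # same offset, both shifted by 100
s3 = rope(q, 10) @ rope(k, 2)      # different offset
print(np.isclose(s1, s2), np.isclose(s1, s3))
# Rotations preserve vector norms.
print(np.isclose(np.linalg.norm(rope(q, 42)), np.linalg.norm(q)))
```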

### 4.4 Implementation steps
1. Linear projection → \([Q;K;V]\)  
2. Apply RoPE to Q & K  
3. `repeat_kv` to share K,V across query groups  
4. `mx.fast.scaled_dot_product_attention` (Metal accelerated)  
5. Output projection back to `model_dim`

In [None]:
class GQAttention(nn.Module):
    def __init__(self, cfg: SMLMConfig, idx: int):
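        super().__init__()
        Hq, Hkv, D = cfg.num_query_heads[idx], cfg.num_kv_heads[idx], cfg.head_dim
        assert Hq % Hkv == 0, "query heads must be a multiple of KV heads"
        self.Hq, self.Hkv, self.D = Hq, Hkv, D
        self.groups = Hq // Hkv  # expected to equal cfg.num_gqa_groups

        # total projection size = (Hq + 2*Hkv) * D
        self.qkv = nn.Linear(cfg.model_dim, (Hq + 2*Hkv) * D, bias=False)
        # Optional per‑head RMSNorm on Q and K (honours normalize_qk_projections)
        self.qk_norm = cfg.normalize_qk_projections
        if self.qk_norm:
            self.q_norm = RMSNorm(D, eps=1e-6)
            self.k_norm = RMSNorm(D, eps=1e-6)
        self.rope = nn.RoPE(D, base=cfg.rope_freq_constant)
        self.proj_out = nn.Linear(Hq * D, cfg.model_dim, bias=False)

    def __call__(self, x, *, mask):
        B, L, _ = x.shape
        # (B, L, (Hq+2Hkv)*D) -> (B, Hq+2Hkv, L, D)
        qkv = self.qkv(x).reshape(B, L, self.Hq + 2*self.Hkv, self.D).transpose(0,2,1,3)
        q, k, v = mx.split(qkv, [self.Hq, self.Hq+self.Hkv], axis=1)

        if self.qk_norm:
            q, k = self.q_norm(q), self.k_norm(k)
        q, k = self.rope(q), self.rope(k)               # positional info
        # Tile K/V so each group of query heads sees its shared KV head.
        # (Recent MLX versions of mx.fast.scaled_dot_product_attention also accept
        # un‑tiled K/V for GQA; explicit tiling is kept here for clarity.)
        k = repeat_kv(k, self.groups, axis=1)
        v = repeat_kv(v, self.groups, axis=1)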

        attn = mx.fast.scaled_dot_product_attention(
            q, k, v, scale=1/math.sqrt(self.D), mask=mask)

        out = attn.transpose(0,2,1,3).reshape(B, L, -1)
        return self.proj_out(out)
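
The steps in 4.4 can be cross‑checked with a plain NumPy reference. This is a sketch under simplifying assumptions: RoPE and Q/K normalization are left out, weights are random, and the sizes are hypothetical. It mirrors the reshape, split, `repeat_kv`, causal mask and output projection of `GQAttention`.

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def gqa_reference(x, w_qkv, w_out, Hq, Hkv, D):
    # x: (B, L, d). RoPE is omitted to keep the sketch short.
    B, L, _ = x.shape
    qkv = (x @ w_qkv).reshape(B, L, Hq + 2 * Hkv, D).transpose(0, 2, 1, 3)
    q, k, v = np.split(qkv, [Hq, Hq + Hkv], axis=1)
    groups = Hq // Hkv
    k = np.repeat(k, groups, axis=1)               # (B, Hq, L, D)
    v = np.repeat(v, groups, axis=1)
    mask = np.triu(np.full((L, L), -np.inf), k=1)  # block attention to future positions
    scores = q @ k.transpose(0, 1, 3, 2) / np.sqrt(D) + mask
    attn = softmax(scores) @ v                     # (B, Hq, L, D)
    out = attn.transpose(0, 2, 1, 3).reshape(B, L, Hq * D)
    return out @ w_out

B, L, d, Hq, Hkv, D = 2, 6, 32, 4, 2, 8            # hypothetical toy sizes
rng = np.random.default_rng(0)
w_qkv = rng.normal(size=(d, (Hq + 2 * Hkv) * D)) * 0.1
w_out = rng.normal(size=(Hq * D, d)) * 0.1
x = rng.normal(size=(B, L, d))

y = gqa_reference(x, w_qkv, w_out, Hq, Hkv, D)
print(y.shape)  # (2, 6, 32)

# Causality: perturbing the last token must not change earlier outputs.
x2 = x.copy()
x2[:, -1] += 1.0
y2 = gqa_reference(x2, w_qkv, w_out, Hq, Hkv, D)
print(np.allclose(y[:, :-1], y2[:, :-1]))  # True
```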


## 5. Decoder Layer (Pre‑Norm Residual)

Sequence of operations:

1. **RMSNorm** on input  
2. **Attention** + residual add  
3. **RMSNorm**  
4. **Feed‑Forward** + residual add
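
In equations, writing \(x_\ell\) for the residual stream entering layer \(\ell\):

\[
\begin{aligned}
h_\ell &= x_\ell + \mathrm{Attn}\big(\mathrm{RMSNorm}_1(x_\ell)\big) \\
x_{\ell+1} &= h_\ell + \mathrm{FFN}\big(\mathrm{RMSNorm}_2(h_\ell)\big)
\end{aligned}
\]

The norms sit inside the branches, so the path from \(x_\ell\) to \(x_{\ell+1}\) is an identity plus two additive updates; pre‑norm stacks are generally considered easier to train at depth than post‑norm ones.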

In [None]:
class DecoderLayer(nn.Module):
    def __init__(self, cfg: SMLMConfig, idx):
        super().__init__()
        self.norm1 = RMSNorm(cfg.model_dim, eps=1e-6)
        self.attn  = GQAttention(cfg, idx)
        self.norm2 = RMSNorm(cfg.model_dim, eps=1e-6)
        self.ffn   = FeedForward(cfg, idx)
    def __call__(self, x, *, mask):
        x = x + self.attn(self.norm1(x), mask=mask)
        return x + self.ffn(self.norm2(x))
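
A NumPy sketch of the pre‑norm wiring with hypothetical stand‑in branches (a zero function and a random linear map) in place of attention and the FFN. It checks two consequences: with zero branches the layer is exactly the identity, and a branch's update does not grow with the scale of the residual stream, because each branch only sees normalized input.

```python
import numpy as np

def rms_norm(x, eps=1e-6):
    return x / np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)

def pre_norm_layer(x, attn, ffn):
    # Same wiring as DecoderLayer: normalize each branch's input, add its output back.
    h = x + attn(rms_norm(x))
    return h + ffn(rms_norm(h))

rng = np.random.default_rng(0)
x = 5.0 * rng.normal(size=(2, 4, 16))        # deliberately large-scale residual stream

zero = lambda t: np.zeros_like(t)
print(np.array_equal(pre_norm_layer(x, zero, zero), x))   # True: identity path

W = rng.normal(size=(16, 16)) * 0.1          # hypothetical toy branch weights
toy_branch = lambda t: t @ W
upd_small = pre_norm_layer(x, toy_branch, zero) - x
upd_big = pre_norm_layer(10.0 * x, toy_branch, zero) - 10.0 * x
print(np.allclose(upd_small, upd_big, atol=1e-6))          # update independent of scale
```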


## 6. Full Decoder

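* **Embedding layer** – maps token ids to `model_dim` vectors.  
* **Stack of decoder layers** – deep computation.  
* **Final RMSNorm** – normalizes the last hidden state before the LM head.  
* **LM head** – linear projection to vocabulary logits.  Tying it to the embedding matrix (`share_input_output_layers`) saves `vocab_size × model_dim` parameters.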

In [None]:
class OpenELM(nn.Module):
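    def __init__(self, cfg: SMLMConfig):
        super().__init__()
        self.emb = nn.Embedding(cfg.vocab_size, cfg.model_dim)
        self.layers = [DecoderLayer(cfg, i) for i in range(cfg.num_transformer_layers)]
        self.final_norm = RMSNorm(cfg.model_dim, eps=1e-6)
        self.tie = cfg.share_input_output_layers
        if not self.tie:
            self.lm_head = nn.Linear(cfg.model_dim, cfg.vocab_size, bias=False)
        # When tied, reuse the embedding matrix via Embedding.as_linear instead of
        # aliasing lm_head.weight, which MLX would treat as two separate parameters
        # (the tie would be lost after the first optimizer step).

    def __call__(self, tokens):
        _, L = tokens.shape
        h = self.emb(tokens)
        # Additive causal mask (0 on/below the diagonal, large negative above), in h's dtype
        mask = nn.MultiHeadAttention.create_additive_causal_mask(L, dtype=h.dtype)
        for layer in self.layers:
            h = layer(h, mask=mask)
        h = self.final_norm(h)
        return self.emb.as_linear(h) if self.tie else self.lm_head(h)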
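
A NumPy sketch of the tied output head and the additive causal mask, assuming hypothetical toy sizes (`vocab_size = 1000`, `model_dim = 64`) and skipping the decoder stack. With tying, the logits are the hidden states multiplied by the transposed embedding table, which is what `nn.Embedding.as_linear` computes in MLX.

```python
import numpy as np

vocab, d = 1000, 64                        # hypothetical toy sizes
rng = np.random.default_rng(0)
E = rng.normal(size=(vocab, d)) * 0.02     # embedding table, one row per token id

tokens = np.array([[1, 5, 42]])
h = E[tokens]                              # embedding lookup: (1, 3, 64)
logits = h @ E.T                           # tied LM head: reuse E as the output projection
print(h.shape, logits.shape)               # (1, 3, 64) (1, 3, 1000)

untied_extra = vocab * d                   # a separate lm_head would add this many weights
print(untied_extra)                        # 64000

# Additive causal mask: 0 where attention is allowed, a large negative value above the diagonal.
L = 4
mask = np.triu(np.full((L, L), -1e9), k=1)
print(mask)
```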
