In [1]:
# Environment bootstrap: runs shell commands from the notebook kernel.
# Upgrade pip, then install the project requirements listed one directory up.
!python -m pip install --upgrade pip && pip install -r ../requirements.txt
# Install the pinned PyTorch nightly build from the CUDA 12.6 nightly wheel index.
!pip install --pre "torch==2.9.0.dev20250713+cu126" --index-url https://download.pytorch.org/whl/nightly/cu126
# Print the installed torch package metadata so the version is recorded in the output.
!pip show torch
# pytz is used by the logging cell to build an Eastern-time timestamp.
# NOTE(review): unpinned install; consider pinning it in requirements.txt.
!pip install pytz
# Run the FineWeb caching script with argument 4 (presumably the number of shards to fetch -- TODO confirm).
!python ../data/cached_fineweb10B.py 4

Collecting tqdm (from -r ../requirements.txt (line 2))
  Downloading tqdm-4.67.1-py3-none-any.whl.metadata (57 kB)
Collecting huggingface-hub (from -r ../requirements.txt (line 4))
  Downloading huggingface_hub-0.34.4-py3-none-any.whl.metadata (14 kB)
Collecting hf-xet<2.0.0,>=1.1.3 (from huggingface-hub->-r ../requirements.txt (line 4))
  Downloading hf_xet-1.1.8-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (703 bytes)
Downloading tqdm-4.67.1-py3-none-any.whl (78 kB)
Downloading huggingface_hub-0.34.4-py3-none-any.whl (561 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m561.5/561.5 kB[0m [31m43.3 MB/s[0m  [33m0:00:00[0m
[?25hDownloading hf_xet-1.1.8-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.2 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.2/3.2 MB[0m [31m203.3 MB/s[0m  [33m0:00:00[0m
[?25hInstalling collected packages: tqdm, hf-xet, huggingface-hub
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

## Imports

In [1]:
import os, sys

In [2]:
import uuid, time, copy

In [3]:
from dataclasses import dataclass
from functools import lru_cache, partial

In [4]:
import glob
from pathlib import Path

In [5]:
import torch
from torch import Tensor, nn
import torch.distributed as dist
from torch.nn.attention.flex_attention import BlockMask, flex_attention
import torch.nn.functional as F

In [6]:
# Ask the PyTorch caching allocator for expandable segments. Both spellings of the
# variable are exported with the same value.
os.environ.update({
    "PYTORCH_ALLOC_CONF": "expandable_segments:True",
    "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
})

## Logging

In [7]:
from datetime import datetime
import pytz
# Timestamp (US/Eastern wall-clock time) used to name this run's log file.
eastern = pytz.timezone("US/Eastern")
# Format is HH:MM-YYYY-MM-DD, e.g. "03:02-2025-08-25".
# NOTE(review): the ':' ends up in a filename, which is not valid on every filesystem.
timestamp = datetime.now(eastern).strftime("%H:%M-%Y-%m-%d")

In [8]:
# begin logging
# Placeholder so `logfile` exists even if the directory creation below fails.
logfile = None

# Run identifier: the NB_BASE environment variable if set, otherwise a fresh UUID.
# NOTE(review): the fallback is a uuid.UUID object while the env value is a str, and the
# UUID is generated even when NB_BASE is set -- confirm downstream users accept both types.
run_id = os.environ.get("NB_BASE",uuid.uuid4())
# Logs are written to ./logs/<timestamp>.txt relative to the notebook's working directory.
os.makedirs("logs", exist_ok=True)
logfile = f"logs/{timestamp}.txt"
print(logfile)


def print0(s, console=False):
    """Append `s` to the run's log file; only the master process does anything.

    Args:
        s: object to log (rendered via `print`).
        console: when true, also echo `s` to stdout.

    Reads the module-level `master_process` and `logfile` at call time.
    """
    if not master_process:
        return
    with open(logfile, "a") as log_handle:
        if console:
            print(s)
        print(s, file=log_handle)


logs/03:02-2025-08-25.txt


In [9]:
# Display the Python interpreter version and the PyTorch build used by this kernel.
sys.version, torch.version.__version__

('3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0]',
 '2.9.0.dev20250713+cu126')

## Optimizers


In [10]:
@torch.compile ## ns
def zeropower_via_newtonschulz5(G: Tensor, steps: int) -> Tensor:
    """
    Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
    quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
    of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
    zero even beyond the point where the iteration no longer converges all the way to one everywhere
    on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
    where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
    performance at all relative to UV^T, where USV^T = G is the SVD.

    Args:
        G: tensor with at least two dimensions. The last two dimensions are the matrix
            dimensions; any leading dimensions are treated as batch dimensions.
        steps: number of Newton-Schulz iterations to run.

    Returns:
        A tensor with the same shape as G. No dtype cast happens inside this function, so the
        iteration runs in whatever dtype the caller passes in.
    """
    assert G.ndim >= 2 # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng
    # Coefficients of the quintic polynomial applied at every iteration.
    a, b, c = (3.4445, -4.7750,  2.0315)
    X = G
    # Work in the wide orientation (rows <= cols) so that X @ X.mT is the smaller Gram matrix.
    if G.size(-2) > G.size(-1):
        X = X.mT

    # Ensure spectral norm is at most 1
    # (the norm over the last two dims upper-bounds the spectral norm; 1e-7 avoids division by zero)
    X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
    # Perform the NS iterations
    for _ in range(steps):
        A = X @ X.mT
        B = b * A + c * A @ A # quintic computation strategy adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
        X = a * X + B @ X

    # Undo the transpose applied above so the output matches G's shape.
    if G.size(-2) > G.size(-1):
        X = X.mT
    return X

In [11]:
class Muon(torch.optim.Optimizer):
    """
    Muon - MomentUm Orthogonalized by Newton-schulz

    https://kellerjordan.github.io/posts/muon/

    Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
    processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
    matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has
    the advantage that it can be stably run in bfloat16 on the GPU.

    Warning: This optimizer should not be used for the embedding layer, the final fully connected layer,
    or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW).

    Args:
        params: iterable of parameters to optimize. They are regrouped internally into one
            param group per distinct parameter shape.
        lr: base learning rate; scaled per parameter by its aspect ratio and optional `lr_mul`.
        weight_decay: decoupled weight decay coefficient (multiplied by `lr` and optional `wd_mul`).
        momentum: momentum coefficient for the gradient buffer.

    Requires an initialized `torch.distributed` process group: `step` uses collectives.
    """
    def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95):
        defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum)
        params = list(params)
        sizes = {p.shape for p in params}
        # create one buffer per unique parameter-size
        param_groups = []
        for size in sizes:
            group_params = [p for p in params if p.shape == size]
            param_groups.append(dict(params=group_params))
        super().__init__(param_groups, defaults)

    @torch.no_grad()
    def step(self):
        """Run one distributed optimization step.

        Within each same-shape group, parameters are processed in chunks of `world_size`;
        rank r owns parameter `base_i + r` of each chunk. Gradients are averaged with
        reduce_scatter, the owner applies the update, and all_gather redistributes the
        updated parameters to every rank.
        """
        # Efficient systems-wise implementation of step developed by @YouJiacheng,
        # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad,
        # @ryanyang0, and @vagrawal.
        rank = dist.get_rank()
        world_size = dist.get_world_size()
        reduce_scatter_futures: list[torch.Future] = []
        all_reduce_futures: list[torch.Future] = []
        # Phase 1: launch every gradient reduce_scatter asynchronously.
        for group in self.param_groups:
            params: list[Tensor] = group["params"]
            # Output buffer used when this rank owns no parameter in the (partial) last chunk.
            grad = torch.empty_like(params[-1])
            # Pad with zeros so every slice of length world_size is fully populated.
            grad_pad = [param.grad for param in params] + [torch.zeros_like(params[-1])] * world_size
            for base_i in range(0, len(params), world_size):
                if base_i + rank < len(params):
                    # Receive the averaged gradient in place, into the owned parameter's .grad.
                    grad = params[base_i + rank].grad
                # This gives strange dynamo warnings
                reduce_scatter_futures.append(dist.reduce_scatter(grad, grad_pad[base_i:base_i + world_size], op=dist.ReduceOp.AVG, async_op=True).get_future())

        # Phase 2: consume the futures in the same order they were issued.
        idx = 0
        for group in self.param_groups:
            params: list[Tensor] = group["params"]
            # Padding gives non-owning ranks a tensor to contribute/receive in all_gather.
            # Note the padding entries are all the same tensor object.
            params_pad = params + [torch.empty_like(params[-1])] * world_size
            momentum = group["momentum"]
            for base_i in range(0, len(params), world_size):
                reduce_scatter_futures[idx].wait()
                if base_i + rank < len(params):
                    p = params[base_i + rank]
                    grad = p.grad
                    # Scale lr by sqrt(max(1, rows / cols)) and the optional per-parameter multiplier.
                    eff_lr = group["lr"] * max(1, p.size(-2) / p.size(-1)) ** 0.5 * getattr(p, "lr_mul", 1.0)
                    eff_weight_decay = group["lr"] * group["weight_decay"] * getattr(p, "wd_mul", 1.0)
                    state = self.state[p]
                    if len(state) == 0:
                        state["momentum_buffer"] = torch.zeros_like(grad)
                    momentum_buffer = state["momentum_buffer"]
                    # Decoupled weight decay applied directly to the parameter.
                    p.mul_(1 - eff_weight_decay)
                    # buffer <- momentum * buffer + (1 - momentum) * grad
                    momentum_buffer.lerp_(grad, 1 - momentum)
                    # grad <- (1 - momentum) * grad + momentum * buffer (Nesterov-style blend, in place on p.grad)
                    grad = grad.lerp_(momentum_buffer, momentum)
                    # Orthogonalize the update in bfloat16 with 5 Newton-Schulz iterations.
                    v = zeropower_via_newtonschulz5(grad.bfloat16(), 5)
                    p.add_(other=v, alpha=-eff_lr)
                idx += 1
                # Share the freshly updated parameters of this chunk with every rank.
                all_reduce_futures.append(dist.all_gather(params_pad[base_i:base_i + world_size], params_pad[base_i + rank], async_op=True).get_future())
        # Block until every all_gather has completed (the list name is historical; these are all_gather futures).
        torch.futures.collect_all(all_reduce_futures).wait()

In [12]:
class DistAdam(torch.optim.Optimizer):
    """Adam with decoupled weight decay whose state and update are sharded across ranks.

    Each parameter's gradient is reduce-scattered along dim 0, every rank updates its own
    slice of the parameter, and the slices are all-gathered back into the full parameter.

    Args:
        params: iterable of parameters; regrouped internally by parameter shape.
        lr: base learning rate (multiplied by an optional per-parameter `lr_mul`).
        betas: coefficients for the running averages of the gradient and its square.
        eps: term added to the denominator.
        weight_decay: decoupled weight decay (multiplied by lr and optional `wd_mul`).

    NOTE(review): slicing uses `shape[0] // world_size`, so dim 0 of every parameter is
    assumed to be divisible by the world size -- confirm against the callers.
    """
    def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01):
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        params = list(params)
        sizes = {p.shape for p in params}
        # create one buffer per unique parameter-size
        param_groups = []
        for size in sizes:
            group_params = [p for p in params if p.shape == size]
            param_groups.append(dict(params=group_params))
        super().__init__(param_groups, defaults)
        # DistributedAdam implementation by @vagrawal

    @torch.compile
    @torch.no_grad()
    def step(self):
        """Run one sharded Adam step; requires an initialized process group."""
        rank = dist.get_rank()
        world_size = dist.get_world_size()
        reduce_scatter_futures: list[torch.Future] = []
        all_reduce_futures: list[torch.Future] = []
        grad_slices = []
        # Phase 1: launch an async reduce_scatter (average) for every parameter's gradient.
        for group in self.param_groups:
            params: list[Tensor] = group["params"]
            # NOTE(review): this buffer is overwritten before use on the first loop iteration.
            grad = torch.empty_like(params[-1])
            for base_i in range(len(params)):
                grad = params[base_i].grad
                rank_size = grad.shape[0] // world_size
                # Receives this rank's averaged slice of the gradient.
                grad_slice = torch.empty_like(grad[:rank_size])
                reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future())
                grad_slices.append(grad_slice)

        # Phase 2: consume the futures in issue order and update the owned slices.
        idx = 0
        for group in self.param_groups:
            beta1, beta2 = group['betas']
            eps = group['eps']
            wd = group['weight_decay']
            params = group['params']
            for base in range(len(params)):
                reduce_scatter_futures[idx].wait()
                p = params[base]
                rank_size = p.shape[0] // world_size
                # View into the rows of p owned by this rank; in-place ops below modify p.
                p_slice = p[rank * rank_size:(rank + 1) * rank_size]
                lr = group['lr'] * getattr(p, "lr_mul", 1.0)
                state = self.state[p]
                g_slice = grad_slices[idx]
                # State init
                if not state:
                    state['step'] = torch.tensor(0, dtype=torch.int64, device=p.device)
                    state['exp_avg'] = torch.zeros_like(p_slice)
                    state['exp_avg_sq'] = torch.zeros_like(p_slice)
                exp_avg = state['exp_avg']
                exp_avg_sq = state['exp_avg_sq']
                state['step'] += 1
                t = state['step']
                # weight decay
                if wd != 0:
                    eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0)
                    p_slice.mul_(1 - eff_weight_decay)
                # update running averages
                exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2)
                # bias corrections
                # (t is a tensor, so bias1/bias2 are tensors as well)
                bias1 = 1 - beta1 ** t
                bias2 = 1 - beta2 ** t
                # compute step
                denom = exp_avg_sq.sqrt().add_(eps)
                step_size = lr * (torch.sqrt(bias2) / bias1)
                update = exp_avg.div(denom).mul_(step_size)
                p_slice.add_(other=update, alpha=-1.0)
                idx += 1
                # Reassemble the full parameter from every rank's updated slice.
                all_reduce_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future())
        # Block until every all_gather has completed (the list name is historical).
        torch.futures.collect_all(all_reduce_futures).wait()

## Custom Operators

In [13]:
@torch.library.custom_op("nanogpt::mm", mutates_args=())
def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]:
    """FP8 matmul computing `x @ w.T`, registered as the custom op `nanogpt::mm`.

    Args:
        x: contiguous 2-D activation matrix.
        w: contiguous 2-D weight matrix; its second dim must match x's second dim.
        x_s: scale for x; x is divided by it before the cast to float8_e4m3fn.
        w_s: scale for w, used the same way.
        grad_s: scale for the incoming gradient. Unused in the forward computation; it is
            part of the signature so the autograd setup can read it from the op inputs.

    Returns:
        (out, x_f8, w_f8): the bfloat16 product and the two quantized operands.
    """
    @torch.compile
    def impl(x: Tensor, w: Tensor):
        assert x.is_contiguous() and w.is_contiguous()
        # Quantize both operands to FP8 (e4m3) after dividing by their scales.
        x_f8 = x.div(x_s).to(torch.float8_e4m3fn)
        w_f8 = w.div(w_s).to(torch.float8_e4m3fn)
        # The scales are handed back to the kernel as float32 tensors (scale_a / scale_b).
        out = torch._scaled_mm(
            x_f8,
            w_f8.T,
            out_dtype=torch.bfloat16,
            scale_a=x.new_tensor(x_s, dtype=torch.float32),
            scale_b=x.new_tensor(w_s, dtype=torch.float32),
            use_fast_accum=True,
        )
        return out, x_f8, w_f8

    return impl(x, w)

@mm_op.register_fake
def _(x: Tensor, w: Tensor, *_):
    """Fake (meta) implementation of `nanogpt::mm`: validates inputs and returns
    tensors with the shapes/dtypes produced here, without running the FP8 kernel."""
    assert x.ndim == 2 and w.ndim == 2
    assert x.shape[1] == w.shape[1]
    assert x.device == w.device
    assert x.is_contiguous() and w.is_contiguous()
    fp8_dtype = torch.float8_e4m3fn
    product = x @ w.T
    return product, x.to(fp8_dtype), w.to(fp8_dtype)

@torch.library.custom_op("nanogpt::mm_backward", mutates_args=())
def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]:
    """Backward of `nanogpt::mm`, computed with FP8 matmuls.

    Args:
        g: contiguous gradient w.r.t. the forward output.
        x_f8: FP8 activations returned by the forward op.
        w_f8: FP8 weights returned by the forward op.
        x_s, w_s: the scales that were used to quantize x and w in the forward op.
        grad_s: scale for g; g is divided by it before the cast to float8_e5m2.

    Returns:
        (grad_x, grad_w): bfloat16 gradient for x and float32 gradient for w.
    """
    @torch.compile
    def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor):
        assert grad.is_contiguous()
        # Despite the `_inv_` in the names, these hold the scales themselves as float32 tensors.
        x_inv_s = grad.new_tensor(x_s, dtype=torch.float32)
        w_inv_s = grad.new_tensor(w_s, dtype=torch.float32)
        grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32)
        # Gradients are quantized to e5m2, whereas the forward operands use e4m3.
        grad_f8 = grad.div(grad_s).to(torch.float8_e5m2)
        # `.T.contiguous().T` keeps the logical shape but changes the memory layout --
        # presumably to meet _scaled_mm's layout requirement for the second operand (TODO confirm).
        grad_x = torch._scaled_mm(
            grad_f8,
            w_f8.T.contiguous().T,
            out_dtype=torch.bfloat16,
            scale_a=grad_inv_s,
            scale_b=w_inv_s,
            use_fast_accum=False,
        )
        # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768)
        # Computes (x_f8.T @ grad_f8) and transposes the result to obtain grad_w.
        grad_w = torch._scaled_mm(
            x_f8.T.contiguous(),
            grad_f8.T.contiguous().T,
            out_dtype=torch.float32,
            scale_a=x_inv_s,
            scale_b=grad_inv_s,
            use_fast_accum=False,
        ).T
        return grad_x, grad_w

    return impl(g, x_f8, w_f8)

@mm_backward_op.register_fake
def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_):
    """Fake (meta) implementation of `nanogpt::mm_backward`: returns tensors derived from
    the saved FP8 operands, cast to the gradient dtypes (bfloat16 and float32)."""
    grad_x_like = x_f8.to(torch.bfloat16)
    grad_w_like = w_f8.T.contiguous().T.to(torch.float32)
    return grad_x_like, grad_w_like

def backward(ctx, grad_out: Tensor, *_):
    """Autograd backward for `nanogpt::mm`.

    Only the gradient of the first forward output (`grad_out`) is used; gradients of the
    saved FP8 outputs are ignored. Returns one entry per forward input: gradients for
    `x` and `w`, and None for the three float scales.
    """
    x_s, w_s, grad_s = ctx.scales
    x_f8, w_f8 = ctx.saved_tensors
    grads = torch.ops.nanogpt.mm_backward(grad_out, x_f8, w_f8, x_s, w_s, grad_s)
    grad_x, grad_w = grads
    return grad_x, grad_w, None, None, None

def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output):
    """Stash what `backward` needs: the FP8 operands and the three scales.

    Args:
        ctx: autograd context to populate.
        inputs: forward inputs; the last three entries are (x_s, w_s, grad_s).
        output: forward outputs (out, x_f8, w_f8).
    """
    x_s, w_s, grad_s = inputs[-3:]
    _out, x_f8, w_f8 = output
    ctx.scales = (x_s, w_s, grad_s)
    ctx.save_for_backward(x_f8, w_f8)
    # Do not materialize zero tensors for output gradients that are None.
    ctx.set_materialize_grads(False)

# Attach the autograd formula (backward + setup_context) to the nanogpt::mm custom op.
mm_op.register_autograd(backward, setup_context=setup_context)


## Modules

In [14]:
def norm(x: Tensor):
    return F.rms_norm(x, (x.size(-1),))

### CastedLinear

In [15]:
class CastedLinear(nn.Linear):
    """Bias-free linear layer that casts its weight to the input dtype at call time.

    When `use_fp8` is set and the module is in training mode, the matmul is routed
    through the `nanogpt::mm` custom op with the configured scales; otherwise a plain
    `F.linear` is used.
    """

    def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0):
        super().__init__(in_features, out_features, bias=False)
        self.use_fp8, self.x_s, self.w_s, self.grad_s = use_fp8, x_s, w_s, grad_s

    def reset_parameters(self) -> None:
        """Uniform init with std = 0.5 / sqrt(in_features)."""
        # 0.5 is a bit better than the default 1/sqrt(3)
        std = 0.5 * (self.in_features ** -0.5)
        # A uniform distribution on [-b, b] has std b / sqrt(3).
        limit = (3 ** 0.5) * std
        with torch.no_grad():
            self.weight.uniform_(-limit, limit)

    def forward(self, x: Tensor):
        assert x.size(-1) == self.in_features
        if not (self.use_fp8 and self.training):
            return F.linear(x, self.weight.type_as(x))
        # The FP8 op expects a 2-D input: collapse all leading dims, then restore them.
        flat_x = x.flatten(0, -2)
        out: Tensor = torch.ops.nanogpt.mm(flat_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0]
        return out.reshape(*x.shape[:-1], -1)

### RoPE

In [16]:
class Rotary(nn.Module):
    def __init__(self, dim: int, max_seq_len: int):
        super().__init__()
        # half-truncate RoPE by @YouJiacheng (w/ base freq tuning)
        angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=dim // 4, dtype=torch.float32)
        angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(dim // 4)])
        t = torch.arange(max_seq_len, dtype=torch.float32)
        theta = torch.einsum("i,j -> ij", t, angular_freq)
        self.cos = nn.Buffer(theta.cos(), persistent=False)
        self.sin = nn.Buffer(theta.sin(), persistent=False)

    def forward(self, x_BTHD: Tensor):
        assert self.cos.size(0) >= x_BTHD.size(-3)
        cos, sin = self.cos[None, :x_BTHD.size(-3), None, :], self.sin[None, :x_BTHD.size(-3), None, :]
        x1, x2 = x_BTHD.to(dtype=torch.float32).chunk(2, dim=-1)
        y1 = x1 * cos + x2 * sin
        y2 = x1 * (-sin) + x2 * cos
        return torch.cat((y1, y2), 3).type_as(x_BTHD)

### Attention

In [17]:
class CausalSelfAttention(nn.Module):
    """Multi-head self-attention built on FlexAttention with QK-norm, RoPE and an
    optional token value-embedding mixed into V.

    Args:
        dim: model (residual stream) width.
        num_heads: number of attention heads.
        max_seq_len: maximum sequence length, forwarded to Rotary for its tables.
        head_dim: per-head width; the attention width is num_heads * head_dim.
    """
    def __init__(self, dim: int, num_heads: int, max_seq_len: int, head_dim=128):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        hdim = num_heads * head_dim
        std = 0.5 * (dim ** -0.5)
        bound = (3 ** 0.5) * std  # improved init scale by @YouJiacheng
        # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng
        # https://x.com/hi_tysam/status/1879699187107033311
        # Stored as a single (3, hdim, dim) parameter holding the Q, K and V projections.
        self.qkv_w = nn.Parameter(torch.empty(3, hdim, dim).uniform_(-bound, bound))
        self.rotary = Rotary(head_dim, max_seq_len)
        self.c_proj = CastedLinear(hdim, dim)
        self.c_proj.weight.detach().zero_()  # zero init suggested by @Grad62304977
        # scale the attention logits by given constant, instead of the default head_dim**-0.5, by @leloykun
        # inspired by learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283
        self.attn_scale = 0.12

    def forward(self, x: Tensor, ve: Tensor | None, lambdas: Tensor, block_mask: BlockMask):
        """Apply attention to `x`.

        Args:
            x: input with batch on dim 0 and sequence on dim 1.
            ve: optional value embedding, reshaped to V's shape and mixed in; None skips it.
            lambdas: two scalars; lambdas[0] scales V and lambdas[1] scales `ve`.
            block_mask: FlexAttention block mask.

        Returns:
            The output of the final projection applied to the attention result.
        """
        B, T = x.size(0), x.size(1)  # batch size, sequence length

        # One matmul for Q, K and V, then split along the head dimension into three equal chunks.
        q, k, v = F.linear(x, self.qkv_w.flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads,
                                                                             self.head_dim).chunk(3, dim=-2)
        q, k = norm(q), norm(k)  # QK norm @Grad62304977
        q, k = self.rotary(q), self.rotary(k)
        if ve is not None:
            v = lambdas[0] * v + lambdas[1] * ve.view_as(v)  # @KoszarskyB & @Grad62304977
        else:  # skip mid-layers token value embeddings by @YouJiacheng
            v = lambdas[0] * v
        # flex_attention takes (B, H, T, D), hence the transposes around the call.
        y = flex_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), block_mask=block_mask,
                           scale=self.attn_scale).transpose(1, 2)
        y = y.contiguous().view(B, T, self.num_heads * self.head_dim)  # re-assemble all head outputs side by side
        y = self.c_proj(y)
        return y

### MLP

In [18]:
class MLP(nn.Module):
    """Feed-forward block: expand to 4 * dim, apply squared ReLU, project back to dim."""

    def __init__(self, dim: int):
        super().__init__()
        self.c_fc = CastedLinear(dim, 4 * dim)
        self.c_proj = CastedLinear(4 * dim, dim)
        # zero init suggested by @Grad62304977
        self.c_proj.weight.detach().zero_()

    def forward(self, x: Tensor):
        # relu(x)^2: https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU;
        # suggested by @SKYLINEZ007 and @Grad62304977
        hidden = F.relu(self.c_fc(x)).square()
        return self.c_proj(hidden)

### Transformer Block

In [19]:
class Block(nn.Module):
    """Transformer block: learned mix with the initial embedding, then attention and MLP,
    each as a residual branch on a normalized input."""

    def __init__(self, dim: int, num_heads: int, max_seq_len: int, layer_idx: int):
        super().__init__()
        # skip attention of blocks.7 (the 8th layer) by @YouJiacheng
        self.attn = None if layer_idx == 7 else CausalSelfAttention(dim, num_heads, max_seq_len)
        self.mlp = MLP(dim)

    def forward(self, x: Tensor, ve: Tensor | None, x0: Tensor, lambdas: Tensor, sa_lambdas: Tensor,
                block_mask: BlockMask):
        """Run the block.

        `lambdas` weights the running stream `x` against the initial embedding `x0`;
        `ve`, `sa_lambdas` and `block_mask` are passed through to the attention layer.
        """
        x = lambdas[0] * x + lambdas[1] * x0
        if self.attn is not None:
            x = x + self.attn(norm(x), ve, sa_lambdas, block_mask)
        return x + self.mlp(norm(x))

### Create blockmasks

In [20]:
def create_blockmasks(seq_len: int,
                      sliding_window_num_blocks: int,
                      BLOCK_SIZE: int = 128,
                      device: torch.device = "cuda") -> BlockMask:
    assert seq_len % BLOCK_SIZE == 0, "seq_len must be a multiple of BLOCK_SIZE"

    Q = seq_len // BLOCK_SIZE
    i32 = torch.int32

    # Block coordinates
    q = torch.arange(Q, device=device, dtype=i32).view(-1, 1)  # [Q,1]
    k = torch.arange(Q, device=device, dtype=i32).view(1, -1)  # [1,K]

    
    # Partial (current block): counts=1, indices first column = q
    partial_idx = torch.zeros((Q, Q), dtype=i32, device=device)   # [Q,K]
    partial_idx[:, 0] = q.squeeze(1)
    partial_cnt = torch.ones((Q,), dtype=i32, device=device)      # [Q]

    # Full (previous blocks), nearest-first prefix: [q-1, q-2, ..., 0]
    full_idx = (q - 1 - k) % Q                                    # [Q,K]
    full_cnt = torch.arange(Q, dtype=i32, device=device)          # [Q] = number of previous blocks

    # Add (B,H) broadcast dims
    partial_kv_indices    = partial_idx[None, None].contiguous()  # [1,1,Q,K]
    full_kv_indices       = full_idx[None, None].contiguous()     # [1,1,Q,K]
    partial_kv_num_blocks = partial_cnt[None, None].contiguous()  # [1,1,Q]
    full_kv_num_blocks    = full_cnt[None, None].contiguous()     # [1,1,Q]

    # Token-level causal within the current (partial) block
    def causal_mask_mod(b, h, q_idx, kv_idx):
        return q_idx >= kv_idx

    # Windowing (ints → tensor ops via broadcasting)
    W = max(sliding_window_num_blocks, 1)
    max_full = max(W - 1, 0)                                      # int
    full_num = torch.clamp_max(full_kv_num_blocks, max_full)      # [1,1,Q]
    remain_for_partial = torch.clamp_min(W - full_num, 1)         # [1,1,Q]
    partial_num = torch.minimum(partial_kv_num_blocks, remain_for_partial)

    return BlockMask.from_kv_blocks(
        partial_num, partial_kv_indices,
        full_num,    full_kv_indices,
        BLOCK_SIZE=BLOCK_SIZE,
        mask_mod=causal_mask_mod,
    )

In [21]:
# Cache of block masks keyed by (seq_len, n_windows).
masks = {}
@lru_cache(1)
def get_blockmask(seq_len, n_windows):
    """Return the cached BlockMask for (seq_len, n_windows), building it on first use.

    The `masks` dict keeps every mask built so far; the `lru_cache(1)` wrapper only
    short-circuits immediately repeated calls with the same arguments.
    """
    key = (seq_len, n_windows)
    try:
        return masks[key]
    except KeyError:
        masks[key] = create_blockmasks(seq_len, n_windows)
        return masks[key]

### Model

In [22]:
class GPT(nn.Module):
    """GPT-style language model with U-net skip connections, token value embeddings and
    an FP8 output head. `forward` returns the cross-entropy loss directly.

    Args:
        vocab_size: number of rows in the embedding tables / outputs of the LM head.
        num_layers: number of transformer blocks (must be even).
        num_heads: attention heads per block.
        model_dim: residual stream width.
        max_seq_len: maximum sequence length, forwarded to the blocks.

    Requires an initialized process group: the scalar parameter is padded to a multiple
    of the world size.

    NOTE(review): `forward` hard-codes a 12-entry block-mask pattern and a value-embedding
    layout that needs at least 6 blocks; the asserts there only pass for num_layers == 12.
    """
    def __init__(self, vocab_size: int, num_layers: int, num_heads: int, model_dim: int, max_seq_len: int):
        super().__init__()
        # NOTE(review): self-assignment, has no effect.
        vocab_size = vocab_size
        self.embed = nn.Embedding(vocab_size, model_dim)

        # Three extra embedding tables whose outputs are mixed into attention values.
        self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)])
        self.blocks = nn.ModuleList([Block(model_dim, num_heads, max_seq_len, i) for i in range(num_layers)])

        # self.lm_head = CastedLinear(model_dim, vocab_size)
        # FP8 head; the scales below are passed straight to the nanogpt::mm custom op.
        self.lm_head = CastedLinear(model_dim, 
                                    vocab_size, 
                                    use_fp8=True, 
                                    x_s=(model_dim ** 0.5) / 448, 
                                    w_s=24 / 448,
                                    grad_s=1 / 448
                                   )
        self.lm_head.weight.detach().zero_()  # @Grad62304977
        # Add learnable skip connection weights for decoder layers
        assert num_layers % 2 == 0
        # Layout of `scalars` (L = num_layers): [0, L) skip weights, [L, 3L) block lambda
        # pairs, [3L, 5L) attention lambda pairs, then `pad` ones so that the total length
        # is a multiple of the world size.
        pad = (-num_layers * 5) % dist.get_world_size()
        self.scalars = nn.Parameter(torch.cat([
            torch.ones(num_layers),  # skip_weights
            *[torch.tensor([1.0, 0.0]) for _ in range(num_layers)],  # block lambdas
            *[torch.tensor([0.5, 0.5]) for _ in range(num_layers)],  # SA lambdas
            torch.ones(pad),
        ]))
        # set learning rates
        # (`lr_mul` is read by the optimizers via getattr(p, "lr_mul", 1.0))
        for param in self.embed.parameters():
            param.lr_mul = 75.
        for param in self.value_embeds.parameters():
            param.lr_mul = 75.
        self.lm_head.weight.lr_mul = 27.5
        self.scalars.lr_mul = 5.0


    def forward(self, input_seq: Tensor, target_seq: Tensor, long_bm: BlockMask, short_bm: BlockMask):
        """Compute the language-modeling loss.

        Args:
            input_seq: 2-D tensor of token ids.
            target_seq: target token ids; flattened before the loss.
            long_bm: block mask used by the layers marked `long_bm` in the pattern below.
            short_bm: block mask used by the remaining layers.

        Returns:
            Cross-entropy loss, summed over tokens in training mode and averaged otherwise.
        """
        assert input_seq.ndim == 2

        ve = [value_embed(input_seq) for value_embed in self.value_embeds]
        # 012 ... 012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure
        ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]]
        assert len(ve) == len(self.blocks)

        # Fixed per-layer assignment of long/short masks (12 entries).
        block_masks = [long_bm, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, long_bm, short_bm, short_bm,
                       short_bm, long_bm]
        assert len(block_masks) == len(self.blocks)
        
        x = x0 = norm(self.embed(input_seq))  # use of norm here by @Grad62304977

        # U-net design by @brendanh0gan
        skip_connections = []
        # Only the first len(blocks) // 2 of the skip weights are used.
        skip_weights = self.scalars[:(len(self.blocks) // 2)]
        lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2)
        sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2)

        n = len(self.blocks) // 2

        # First half pushes activations on a stack; second half pops them (last in, first out)
        # and adds them, weighted, before running the block.
        for i in range(len(self.blocks)):
            if i >= n:
                x = x + skip_weights[i - n] * skip_connections.pop()
            x = self.blocks[i](x, ve[i], x0, lambdas[i], sa_lambdas[i], block_masks[i])
            if i < n:
                skip_connections.append(x)

        x = norm(x)

        logits = self.lm_head(x).float()

        # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1)
        logits = 30 * torch.sigmoid(logits / (7.5 * x.size(-1) ** 0.5))
        # loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq,
        #                        reduction="sum" if self.training else "mean")
        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq.view(-1),
                       reduction="sum" if self.training else "mean")
        return loss


## Data loader

In [23]:
class EOSBatchFinder:
    """
    Distributed data generator class.
    - align_to_bos: True -> every sequence starts at the beginning of a document.
    - Supports batch_size and seq_len being changed after initialization.

    Shards are `.bin` files with a 256 x int32 header (magic 20240520, version 1, token
    count) followed by uint16 tokens. Every rank walks the same shard and computes the
    same batch layout, then keeps only its own `batch_size // world_size` rows.

    `generator()` yields `(inputs, targets)` pairs of shape `[B_local, seq_len]` with
    dtypes int32 and int64 on `device`; targets are the inputs shifted by one token. The
    generator ends cleanly once the shard files are exhausted.
    """
    def __init__(self, filename_pattern: str, seq_len: int, batch_size: int, 
                 align_to_bos: bool = True, eos_id: int = 50256, device: str = "cuda",
    ):
        """
        Args:
            filename_pattern: glob pattern for the shard files (consumed in sorted order).
            seq_len: tokens per sequence.
            batch_size: global batch size (sequences across all ranks).
            align_to_bos: start every sequence right after an EOS token.
            eos_id: token id that separates documents.
            device: device the yielded tensors are moved to.

        Raises:
            FileNotFoundError: no file matches `filename_pattern`.
            ValueError: `batch_size` is not divisible by the world size.
            StopIteration: `align_to_bos` is set and the first shard contains no EOS token.
        """
        self.files: list[Path] = [Path(p) for p in sorted(glob.glob(filename_pattern))]
        if not self.files:
            raise FileNotFoundError(f"No files match: {filename_pattern}")
        self.file_iter = iter(self.files)

        self.rank = dist.get_rank()
        self.world_size = dist.get_world_size()

        self.seq_len, self.batch_size = seq_len, batch_size
        self.align_to_bos, self.eos_id, self.device = align_to_bos, eos_id, device

        if self.batch_size % self.world_size:
            raise ValueError("batch_size must be divisible by world_size")

        self.tokens, self.pos = self._load_next_shard()
        self.eos_idx, self.i = self._build_eos_index(), 0
        if self.align_to_bos:
            self._seek(self.pos)

    def _load_next_shard(self):
        """Read the next shard into pinned memory; returns `(tokens, 0)`."""
        path = next(self.file_iter)  # may raise StopIteration (end of data)
        header = torch.from_file(str(path), False, 256, dtype=torch.int32) # header is 256 int32
        assert header[0] == 20240520, "magic number mismatch in the data .bin file"
        assert header[1] == 1, "unsupported version"
        num_tokens = int(header[2]) # number of tokens (claimed)
        with path.open("rb", buffering=0) as f:
            tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng
            f.seek(256 * 4)
            nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng
            assert nbytes == 2 * num_tokens, "number of tokens read does not match header"
        return tokens, 0

    def _advance_shard(self, rebuild_eos_index: bool) -> bool:
        """Switch to the next shard; returns False when no shard is left.

        A StopIteration escaping a generator body is turned into RuntimeError by Python
        (PEP 479), so the generators use this helper and `return` to end the stream.
        """
        try:
            self.tokens, self.pos = self._load_next_shard()
        except StopIteration:
            return False
        if rebuild_eos_index:
            self.eos_idx, self.i = self._build_eos_index(), 0
        return True

    def _build_eos_index(self):
        """Return the sorted int64 CPU positions of every EOS token in the current shard."""
        return (self.tokens == self.eos_id).nonzero(as_tuple=True)[0].to(torch.int64).cpu()

    def _seek(self, pos: int):
        """Point the EOS cursor at the first EOS located at or after `pos`."""
        j = int(torch.searchsorted(self.eos_idx, int(pos)))
        if j >= int(self.eos_idx.numel()):
            raise StopIteration("Seek past last EOS in shard.")
        self.i, self.pos = j, pos
    
    def generator(self):
        """Return the batch generator selected by `align_to_bos`."""
        return self._gen_eos() if self.align_to_bos else self._gen_contig()

    def _gen_contig(self):
        """Contiguous slicing (no EOS alignment)."""
        while True:
            # Re-read the sizes every iteration so they can be changed between batches.
            B, L, W, R = self.batch_size, self.seq_len, self.world_size, self.rank
            if B % W: 
                raise ValueError("batch_size must be divisible by world_size")
            B_local = B // W
            tokens_global, tokens_local = B * L, B_local * L

            # Not enough tokens left for a full global batch (+1 for the shifted targets).
            if self.pos + tokens_global + 1 >= len(self.tokens):
                if not self._advance_shard(rebuild_eos_index=False):
                    return
                continue

            start = self.pos + R * tokens_local
            buf = self.tokens[start : start + tokens_local + 1]
            if buf.numel() < tokens_local + 1:
                if not self._advance_shard(rebuild_eos_index=False):
                    return
                continue

            x = buf[:-1].view(B_local, L)
            y = buf[1: ].view(B_local, L)
            self.pos += tokens_global

            yield (
                x.to(self.device, dtype=torch.int32, non_blocking=True),
                y.to(self.device, dtype=torch.int64, non_blocking=True)
            )

    def _gen_eos(self):
        """EOS-aligned slicing (each sample starts after an EOS)."""
        while True:
            # Re-read the sizes every iteration so they can be changed between batches.
            B, L, W, R = self.batch_size, self.seq_len, self.world_size, self.rank
            if B % W: 
                raise ValueError("batch_size must be divisible by world_size")
            B_local = B // W

            n = int(self.eos_idx.numel())
            if self.i >= n:
                if not self._advance_shard(rebuild_eos_index=True):
                    return
                continue

            # Start offsets for every rank: all ranks compute the same layout so the
            # cursor advances identically everywhere.
            starts: list[list[int]] = [[] for _ in range(W)]
            idx = self.i
            cur = int(self.eos_idx[idx])

            ok = True
            for r in range(W):
                for _ in range(B_local):
                    start = cur + 1
                    # First EOS at or after the end of this sample; the next sample starts
                    # right after it. Its existence also guarantees L + 1 tokens are available.
                    j = int(torch.searchsorted(self.eos_idx, start + L))
                    if j >= n:
                        ok = False
                        break
                    starts[r].append(start)
                    # Advance the cursor FIRST, then read the EOS position at the new index.
                    idx = j
                    cur = int(self.eos_idx[idx])
                if not ok:
                    break

            if not ok:
                # The rest of the shard cannot fill a whole global batch.
                if not self._advance_shard(rebuild_eos_index=True):
                    return
                continue

            self.pos = cur
            self.i = idx

            bufs = [self.tokens[s : s + L + 1] for s in starts[R]]
            buf = torch.stack(bufs, dim=0)  # [B_local, L+1]
            x, y = buf[:, :-1], buf[:, 1:]

            yield (
                x.to(self.device, dtype=torch.int32, non_blocking=True),
                y.to(self.device, dtype=torch.int64, non_blocking=True)
            )

## dist init

In [24]:
# Emulate the environment torchrun would provide, for a single local process.
os.environ.update({
    'MASTER_ADDR': '127.0.0.1',
    'MASTER_PORT': '29500',
    'WORLD_SIZE': '1',
    'LOCAL_WORLD_SIZE': '1',
    'RANK': '0',  # This would be 0-7 for each process
    'LOCAL_RANK': '0',  # This would be 0-7 for each process
    'GROUP_RANK': '0',
    'ROLE_RANK': '0',
    'ROLE_NAME': 'default',
    'ROLE_WORLD_SIZE': '1',
    'TORCHELASTIC_RESTART_COUNT': '0',
    'TORCHELASTIC_MAX_RESTARTS': '0',
    'TORCHELASTIC_RUN_ID': 'none',
})

In [25]:
# Read this process's global rank and the total number of processes from the environment.
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
master_process = (rank == 0)  # this process will do logging, checkpointing etc.
# assert world_size == 8  # this code is designed for 8xH100

In [26]:
# A CUDA device is required; each process uses the GPU given by LOCAL_RANK.
assert torch.cuda.is_available()
device = torch.device("cuda", int(os.environ["LOCAL_RANK"]))
# Make it the current device so that later "cuda" allocations land on it.
torch.cuda.set_device(device)

In [27]:
# Create the default NCCL process group bound to this process's device (the rendezvous
# settings come from the environment variables set above), then wait for every rank.
dist.init_process_group(backend="nccl", device_id=device)
dist.barrier()

## hyperparams

In [28]:
def next_multiple_of_n(v: float | int, *, n: int):
    return next(x for x in range(n, int(v) + 1 + n, n) if x >= v)

In [29]:
# TODO: consider refactoring to specify a per-GPU (micro)batch size instead of a global batch size.
# Currently the global batch size is split evenly across all GPUs -- in this notebook, just 1.
# With the current setup, scaling to 8 GPUs means multiplying the batch size by 8 to keep the per-GPU batch size unchanged.
@dataclass
class Hyperparameters:
    # data
    train_files: str = "../data/fineweb10B/fineweb_train_*.bin"  # input .bin to train on
    val_files: str = "../data/fineweb10B/fineweb_val_*.bin"  # input .bin to eval validation loss on
    val_tokens: int = 10485760  # how many tokens of validation data? it's important to keep this fixed for consistent comparisons
    
    # train_seq_len = 1024 * 6
    # train_batch_size = 8 * world_size

    # val_seq_len = 1024 * 5  # FlexAttention sequence length for validation
    # val_batch_size = 4 * world_size
    
    # optimization
    num_iterations = 1750  # number of iterations to run
    cooldown_frac = 0.45  # fraction of training spent cooling down the learning rate
    
    # evaluation and logging
    val_loss_every = 100  # every how many steps to evaluate val loss? 0 for only at the end
    save_checkpoint = False

    vocab_size = next_multiple_of_n(50257, n=128)
    num_layers = 12
    num_heads = 6
    model_dim = 768

    block_size = 128

args = Hyperparameters()

In [30]:
# Validation uses one fixed configuration for the whole run.
val_seq_len = 5 * 1024  # sequence length passed to the validation loader and block masks
val_batch_size = 4      # batch size passed to the validation loader

### schedule

In [31]:
from bisect import bisect_right

In [32]:
def make_piecewise(values, breaks):
    """Build a piecewise-constant lookup ``f(t)`` over a step counter.

    ``f(t)`` returns ``values[i]`` where ``i`` is the number of entries in
    ``breaks`` that are ``<= t``. A break is therefore the first step at which
    the next value applies: with ``breaks=[10, 20]``, steps ``0..9`` map to
    ``values[0]``, ``10..19`` to ``values[1]`` and ``20`` onward to ``values[2]``.

    Args:
        values: One entry per segment; entries are returned as-is and may be
            of any type.
        breaks: Segment boundaries in ascending order, one fewer than ``values``.

    Returns:
        A function mapping a step ``t`` to the value of its segment.

    Raises:
        AssertionError: if the lengths do not match or ``breaks`` is not sorted.
    """
    assert len(values) == len(breaks) + 1, "values must be one longer than breaks"
    # bisect only gives meaningful results on sorted input; unsorted breaks
    # would silently select the wrong segment, so reject them up front.
    assert all(lo <= hi for lo, hi in zip(breaks, breaks[1:])), "breaks must be sorted in ascending order"

    def f(t: int):
        return values[bisect_right(breaks, t)]

    return f

In [33]:
# One entry per training stage, in the order the stages are visited.
#   dense / sparse      -> window counts passed to get_blockmask for the long / short masks
#   seq_len, batch_size -> loader settings for the stage
#   lr_mult             -> multiplier applied on top of the base learning-rate curve
schedule = [
    dict(dense=2, sparse=1, seq_len=6 * 1024, batch_size=8, lr_mult=1.0),
    dict(dense=2, sparse=1, seq_len=6 * 1024, batch_size=8, lr_mult=1.0),
    dict(dense=8, sparse=2, seq_len=6 * 1024, batch_size=8, lr_mult=1.0),
    dict(dense=8, sparse=2, seq_len=6 * 1024, batch_size=8, lr_mult=1.0),
    dict(dense=8, sparse=2, seq_len=6 * 1024, batch_size=8, lr_mult=1.0),
]
# Each break is the first step of the next stage; the last stage starts at step 1476.
scheduler = make_piecewise(schedule, breaks=[324, 768, 1152, 1476])

In [34]:
# learning rate schedule: stable then decay
def get_lr(step: int):
    """Return the learning-rate multiplier for ``step``.

    The multiplier stays at 1.0 until the final ``args.cooldown_frac`` of
    training, then falls linearly towards 0.1. The result is additionally
    scaled by the ``lr_mult`` of the schedule stage active at ``step``.

    Raises:
        AssertionError: if ``step`` is outside ``[0, args.num_iterations)``.
    """
    progress = step / args.num_iterations  # fraction of training completed
    assert 0 <= progress < 1
    if progress < 1 - args.cooldown_frac:
        base = 1.0
    else:
        # remaining == 1 at the start of the cooldown and approaches 0 at the end
        remaining = (1 - progress) / args.cooldown_frac
        base = remaining * 1.0 + (1 - remaining) * 0.1
    return base * scheduler(step)['lr_mult']

## Model init

In [35]:
# Build the model on the current CUDA device.
# max_seq_len must cover the longest sequence the model is ever fed. The
# loaders are sized from the *first* schedule stage, so looking only at the
# last stage (as before) would undersize the model whenever an earlier stage
# uses longer sequences. Take the maximum over every stage and validation.
model: nn.Module = GPT(vocab_size=args.vocab_size,
                       num_layers=args.num_layers,
                       num_heads=args.num_heads,
                       model_dim=args.model_dim,
                       max_seq_len=max(val_seq_len, *(stage['seq_len'] for stage in schedule)),
                      ).cuda()

In [36]:
# Cast every embedding module to bfloat16; all other modules keep their dtype.
for m in model.modules():
    if not isinstance(m, nn.Embedding):
        continue
    m.bfloat16()

# Overwrite each parameter with rank 0's values so all ranks start identical.
for param in model.parameters():
    dist.broadcast(param.detach(), 0)

In [37]:
# Collect the parameters to optimize, grouped by how they will be updated:
# the hidden matrices go to Muon, everything else to DistAdam.
hidden_matrix_params = [
    weight
    for name, weight in model.blocks.named_parameters()
    if weight.ndim >= 2 and "embed" not in name
]
embed_params = [
    weight
    for name, weight in model.named_parameters()
    if "embed" in name
]
scalar_params = [weight for weight in model.parameters() if weight.ndim < 2]
head_params = [model.lm_head.weight]

In [38]:
# init the optimizer(s)
# small adam epsilon by @YouJiacheng. this is an alternate method of fixing the world_size dependence
# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094
optimizer1 = DistAdam(
    scalar_params + head_params + embed_params,
    lr=0.008,
    betas=(0.8, 0.95),
    eps=1e-10,
    weight_decay=0.0,
)
optimizer2 = Muon(
    hidden_matrix_params,
    lr=0.05,
    momentum=0.95,
    weight_decay=0.0,
)
optimizers = [optimizer1, optimizer2]

# Remember each group's base learning rate; the training loop rescales "lr"
# from this value every step.
for group in (g for opt in optimizers for g in opt.param_groups):
    group["initial_lr"] = group["lr"]

In [39]:
# Wrap the model with torch.compile. dynamic=False asks for static-shape
# compilation, so a change in input shape triggers a recompile.
model: nn.Module = torch.compile(model, dynamic=False)

## Training

### Warmup

In [40]:
from tqdm.auto import tqdm
from torch._dynamo import reset
from torch._dynamo.utils import counters

In [41]:
# Warmup the training kernels, then re-initialize the state so we aren't cheating
warmup_steps = 100

# Deep-copied snapshot of the untouched model and optimizer state; it is loaded
# back once the warmup steps have run.
initial_state = {
    "model": copy.deepcopy(model.state_dict()),
    "optimizers": [copy.deepcopy(optimizer.state_dict()) for optimizer in optimizers],
}


In [42]:
# Loader used only for warmup, configured from the schedule stage active at step 0.
first_stage = scheduler(0)
train_loader = EOSBatchFinder(args.train_files, first_stage['seq_len'], first_stage['batch_size'])
train_loader_gen = train_loader.generator()

 878ms wall time. 

In [43]:
# Run throwaway optimisation steps so the kernels for every schedule stage are
# built before the timed run. The stage index cycles over the whole schedule:
# the previous hard-coded `step % 4` never reached the last of the five stages.
for step in tqdm(range(warmup_steps)):
    inputs, targets = next(train_loader_gen)
    stage = schedule[step % len(schedule)]
    long_bm = get_blockmask(train_loader.seq_len, stage['dense'])
    short_bm = get_blockmask(train_loader.seq_len, stage['sparse'])
    model(inputs, targets, long_bm, short_bm).backward()
    for opt in optimizers:
        opt.step()
    model.zero_grad(set_to_none=True)

  0%|          | 0/100 [00:00<?, ?it/s]

  return torch._C._get_cublas_allow_tf32()


In [44]:
# Undo the warmup updates: restore the model and each optimizer from the
# snapshot taken before warmup.
model.load_state_dict(initial_state["model"])
for idx, opt in enumerate(optimizers):
    opt.load_state_dict(initial_state["optimizers"][idx])

In [45]:
# Drop the warmup loader and the saved snapshot; the training section builds a
# fresh loader of its own.
del train_loader, initial_state

### Train loop

In [46]:
# Fresh loader for the timed run, configured from the schedule stage active at step 0.
initial_stage = scheduler(0)
train_loader = EOSBatchFinder(args.train_files, initial_stage['seq_len'], initial_stage['batch_size'])
train_loader_gen = train_loader.generator()

In [47]:
train_steps = args.num_iterations
# Running totals updated by the training loop; re-run this cell before
# restarting the loop so they begin from zero again.
training_time_ms = 0
tokens_processed = 0

In [48]:
# Sanity check: display the schedule stage that is active at step 1000.
scheduler(1000)

{'dense': 8, 'sparse': 2, 'seq_len': 6144, 'batch_size': 8, 'lr_mult': 1.0}

In [48]:
# NOTE(review): this cell is an earlier copy of the training loop in the next
# cell (which differs only by printing inputs.shape on the first step and by
# dropping the commented-out loader overrides). Its recorded output ends in a
# KeyboardInterrupt. Keep a single copy so Restart & Run All trains only once.
#
# Main loop: periodic validation (excluded from the timed total) followed by one
# optimisation step per iteration.

# start the clock
torch.cuda.synchronize()
t0 = time.perf_counter()
# begin training
# One extra iteration (step == train_steps) runs the final validation only.
for step in range(train_steps + 1):
    last_step = (step == train_steps)

    # Settings of the schedule stage active at this step.
    params = scheduler(step)
    # The per-stage loader settings are currently NOT applied; the loader keeps
    # the seq_len / batch_size it was constructed with.
    # train_loader.seq_len = params['seq_len']
    # train_loader.batch_size = params['batch_size']
    
    n_windows_long = params['dense']
    n_windows_short = params['sparse']

    # --------------- VALIDATION SECTION -----------------
    if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0):
        print('---validation---')
        # stop the clock
        torch.cuda.synchronize()
        training_time_ms += 1000 * (time.perf_counter() - t0)
        model.eval()
        
        # Masks for the validation sequence length, using the current stage's window counts.
        long_bm = get_blockmask(val_seq_len, n_windows_long)
        short_bm = get_blockmask(val_seq_len, n_windows_short)

        # The validation budget must split into a whole number of batches.
        assert args.val_tokens % (val_batch_size * val_seq_len) == 0
        val_steps = args.val_tokens // (val_batch_size * val_seq_len)

        # A new loader is built for every validation pass
        # (presumably so each pass reads the same tokens -- TODO confirm against EOSBatchFinder).
        val_loader = EOSBatchFinder(args.val_files, val_seq_len, val_batch_size, align_to_bos=False).generator()
        
        # Accumulate the model output (the loss) over all validation batches, then average.
        val_loss = 0
        with torch.no_grad():
            for _ in range(val_steps):
                inputs, targets = next(val_loader)
                val_loss += model(inputs, targets, long_bm, short_bm)
        val_loss /= val_steps
        del val_loader
        # Average the per-rank mean loss across all ranks.
        dist.all_reduce(val_loss, op=dist.ReduceOp.AVG)
        # max(step, 1) avoids a division by zero at step 0.
        print0(
            f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms",
            console=True)

        model.train()
        # start the clock again
        print('---end of validation---')
        torch.cuda.synchronize()
        t0 = time.perf_counter()

    if last_step:
        # Optionally save model and optimizer state (rank 0 only).
        if master_process and args.save_checkpoint:
            log = dict(step=step, model=model.state_dict(),
                       optimizers=[opt.state_dict() for opt in optimizers])
            os.makedirs(f"logs/{run_id}", exist_ok=True)
            torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt")
        # the last step only has the validation loop, so break to avoid training
        break

    # --------------- TRAINING SECTION -----------------
    inputs, targets = next(train_loader_gen)
    if step == 0: print("First inputs retrieved")
    # Counts this rank's tokens times world_size, i.e. assumes every rank
    # processes a batch of the same size.
    tokens_processed += inputs.numel() * world_size

    long_bm = get_blockmask(train_loader.seq_len, n_windows_long)
    short_bm = get_blockmask(train_loader.seq_len, n_windows_short)

    model(inputs, targets, long_bm, short_bm).backward()

    # Rescale every group's learning rate from its stored base value.
    for opt in optimizers:
        for group in opt.param_groups:
            group["lr"] = group["initial_lr"] * get_lr(step)

    # Momentum of optimizer2 (Muon) ramps linearly from 0.85 to 0.95 over the first 300 steps.
    for group in optimizer2.param_groups:
        frac = min(step / 300, 1)  # momentum warmup for muon
        group["momentum"] = (1 - frac) * 0.85 + frac * 0.95

    for opt in optimizers:
        opt.step()
    model.zero_grad(set_to_none=True)

    # Approximate because no CUDA synchronisation happens before reading the clock.
    approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0)
    print0(
        f"step:{step + 1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / (step + 1):.2f}ms tokens_processed:{tokens_processed}",
        console=True)

---validation---
step:0/1750 val_loss:10.8258 train_time:0ms step_avg:0.37ms
---end of validation---
First inputs retrieved
step:1/1750 train_time:159ms step_avg:158.90ms tokens_processed:49152
step:2/1750 train_time:296ms step_avg:148.15ms tokens_processed:98304
step:3/1750 train_time:443ms step_avg:147.51ms tokens_processed:147456
step:4/1750 train_time:593ms step_avg:148.15ms tokens_processed:196608
step:5/1750 train_time:743ms step_avg:148.55ms tokens_processed:245760
step:6/1750 train_time:893ms step_avg:148.87ms tokens_processed:294912
step:7/1750 train_time:1044ms step_avg:149.16ms tokens_processed:344064
step:8/1750 train_time:1199ms step_avg:149.86ms tokens_processed:393216
step:9/1750 train_time:1350ms step_avg:150.04ms tokens_processed:442368
step:10/1750 train_time:1502ms step_avg:150.21ms tokens_processed:491520
step:11/1750 train_time:1653ms step_avg:150.26ms tokens_processed:540672
step:12/1750 train_time:1805ms step_avg:150.45ms tokens_processed:589824
step:13/1750 trai

KeyboardInterrupt: 

In [None]:
# Main training loop: periodic validation (excluded from the timed total)
# followed by one optimisation step per iteration.
# NOTE(review): near-duplicate of the preceding training cell. It accumulates
# into training_time_ms / tokens_processed and continues train_loader_gen, so
# re-create the loader and reset the counters before re-running it.

# start the clock
torch.cuda.synchronize()
t0 = time.perf_counter()
# begin training
# One extra iteration (step == train_steps) runs the final validation only.
for step in range(train_steps + 1):
    last_step = (step == train_steps)

    # Settings of the schedule stage active at this step. Only the window counts
    # are used here; the stage's seq_len / batch_size are not applied to the loader.
    params = scheduler(step)
    
    n_windows_long = params['dense']
    n_windows_short = params['sparse']

    # --------------- VALIDATION SECTION -----------------
    if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0):
        print('---validation---')
        # stop the clock
        torch.cuda.synchronize()
        training_time_ms += 1000 * (time.perf_counter() - t0)
        model.eval()
        
        # Masks for the validation sequence length, using the current stage's window counts.
        long_bm = get_blockmask(val_seq_len, n_windows_long)
        short_bm = get_blockmask(val_seq_len, n_windows_short)

        # The validation budget must split into a whole number of batches.
        assert args.val_tokens % (val_batch_size * val_seq_len) == 0
        val_steps = args.val_tokens // (val_batch_size * val_seq_len)

        # A new loader is built for every validation pass
        # (presumably so each pass reads the same tokens -- TODO confirm against EOSBatchFinder).
        val_loader = EOSBatchFinder(args.val_files, val_seq_len, val_batch_size, align_to_bos=False).generator()
        
        # Accumulate the model output (the loss) over all validation batches, then average.
        val_loss = 0
        with torch.no_grad():
            for _ in range(val_steps):
                inputs, targets = next(val_loader)
                val_loss += model(inputs, targets, long_bm, short_bm)
        val_loss /= val_steps
        del val_loader
        # Average the per-rank mean loss across all ranks.
        dist.all_reduce(val_loss, op=dist.ReduceOp.AVG)
        # max(step, 1) avoids a division by zero at step 0.
        print0(
            f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms",
            console=True)

        model.train()
        # start the clock again
        print('---end of validation---')
        torch.cuda.synchronize()
        t0 = time.perf_counter()

    if last_step:
        # Optionally save model and optimizer state (rank 0 only).
        if master_process and args.save_checkpoint:
            log = dict(step=step, model=model.state_dict(),
                       optimizers=[opt.state_dict() for opt in optimizers])
            os.makedirs(f"logs/{run_id}", exist_ok=True)
            torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt")
        # the last step only has the validation loop, so break to avoid training
        break

    # --------------- TRAINING SECTION -----------------
    inputs, targets = next(train_loader_gen)
    if step == 0: print("First inputs retrieved", inputs.shape)
    # Counts this rank's tokens times world_size, i.e. assumes every rank
    # processes a batch of the same size.
    tokens_processed += inputs.numel() * world_size

    long_bm = get_blockmask(train_loader.seq_len, n_windows_long)
    short_bm = get_blockmask(train_loader.seq_len, n_windows_short)

    model(inputs, targets, long_bm, short_bm).backward()

    # Rescale every group's learning rate from its stored base value.
    for opt in optimizers:
        for group in opt.param_groups:
            group["lr"] = group["initial_lr"] * get_lr(step)

    # Momentum of optimizer2 (Muon) ramps linearly from 0.85 to 0.95 over the first 300 steps.
    for group in optimizer2.param_groups:
        frac = min(step / 300, 1)  # momentum warmup for muon
        group["momentum"] = (1 - frac) * 0.85 + frac * 0.95

    for opt in optimizers:
        opt.step()
    model.zero_grad(set_to_none=True)

    # Approximate because no CUDA synchronisation happens before reading the clock.
    approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0)
    print0(
        f"step:{step + 1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / (step + 1):.2f}ms tokens_processed:{tokens_processed}",
        console=True)

---validation---
step:0/1750 val_loss:10.8258 train_time:1ms step_avg:0.78ms
---end of validation---
First inputs retrieved torch.Size([8, 6144])
step:1/1750 train_time:157ms step_avg:157.05ms tokens_processed:49152
step:2/1750 train_time:293ms step_avg:146.59ms tokens_processed:98304
step:3/1750 train_time:439ms step_avg:146.48ms tokens_processed:147456
step:4/1750 train_time:589ms step_avg:147.36ms tokens_processed:196608
step:5/1750 train_time:737ms step_avg:147.46ms tokens_processed:245760
step:6/1750 train_time:887ms step_avg:147.80ms tokens_processed:294912
step:7/1750 train_time:1037ms step_avg:148.10ms tokens_processed:344064
step:8/1750 train_time:1192ms step_avg:149.03ms tokens_processed:393216
step:9/1750 train_time:1345ms step_avg:149.42ms tokens_processed:442368
step:10/1750 train_time:1495ms step_avg:149.54ms tokens_processed:491520
step:11/1750 train_time:1646ms step_avg:149.60ms tokens_processed:540672
step:12/1750 train_time:1796ms step_avg:149.63ms tokens_processed:58

In [None]:
# Report the peak CUDA memory seen by the allocator (in MiB), then shut down
# the process group.
peak_allocated_mib = torch.cuda.max_memory_allocated() // 1024 // 1024
peak_reserved_mib = torch.cuda.max_memory_reserved() // 1024 // 1024
print0(
    f"peak memory allocated: {peak_allocated_mib} MiB reserved: {peak_reserved_mib} MiB",
    console=True,
)
dist.destroy_process_group()
