# Chapter 9 — Building the Model Step by Step (Colab)

A self-contained notebook that assembles a compact GPT from the classes developed in the book. Each cell creates one object and displays it immediately.

In [None]:
# Colab-friendly bootstrap: import PyTorch, installing it first when the import fails
import sys, subprocess
try:
    import torch  # noqa: F401
    print('PyTorch found')
except Exception:
    print('Installing PyTorch...')
    # Heuristic: a successful `nvidia-smi` run means an NVIDIA driver is present,
    # so prefer the CUDA wheels; any failure to run it falls back to CPU wheels.
    try:
        probe = subprocess.run(['nvidia-smi'], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        has_cuda = probe.returncode == 0
    except Exception:
        has_cuda = False
    # Pick the wheel index once, then issue a single pip invocation.
    index_url = 'https://download.pytorch.org/whl/cu121' if has_cuda else 'https://download.pytorch.org/whl/cpu'
    subprocess.check_call([sys.executable, '-m', 'pip', 'install', '--index-url', index_url, 'torch', 'torchvision', 'torchaudio'])
    import torch  # noqa: F401
print('torch', torch.__version__)


In [None]:
# Imports, style, and a tiny seed
import math, platform
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
# Apply the matplotlib stylesheet named 'seaborn-v0_8' to all subsequent figures
plt.style.use('seaborn-v0_8')
# IPython magic (not plain Python): have the inline backend render figures as SVG
%config InlineBackend.figure_format = 'svg'
# Fix the global PyTorch RNG seed so the random tensors below are repeatable
torch.manual_seed(0)
# Choose the compute device by preference: CUDA, then Apple MPS, then CPU.
# The getattr guard covers torch builds that have no `torch.backends.mps`.
if torch.cuda.is_available():
    device = 'cuda'
elif getattr(torch.backends, 'mps', None) and torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'
device


In [None]:
# Sinusoidal positions (Chapter 8)
def sinusoidal_positions(T: int, d_model: int, device: torch.device | None = None) -> torch.Tensor:
    pos = torch.arange(T, device=device).float()[:, None]
    i = torch.arange(d_model, device=device).float()[None, :]
    angle = pos / (10000 ** (2 * (i//2) / d_model))
    enc = torch.zeros(T, d_model, device=device)
    enc[:, 0::2] = torch.sin(angle[:, 0::2])
    enc[:, 1::2] = torch.cos(angle[:, 1::2])
    return enc
sinusoidal_positions(4, 8).shape


In [None]:
# Multi-head self-attention (Chapter 8)
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, num_heads: int, dropout: float = 0.0):
        super().__init__()
        assert d_model % num_heads == 0
        self.h = num_heads
        self.d = d_model // num_heads
        self.qkv = nn.Linear(d_model, 3 * d_model, bias=False)
        self.out = nn.Linear(d_model, d_model, bias=False)
        self.drop = nn.Dropout(dropout)
    def forward(self, x: torch.Tensor, mask: torch.Tensor | None = None):
        B, T, Dm = x.shape
        qkv = self.qkv(x)
        q, k, v = qkv.chunk(3, dim=-1)
        def split(t):
            return t.view(B, T, self.h, self.d).transpose(1, 2)
        q, k, v = map(split, (q, k, v))
        sdpa_mask = None
        if mask is not None:
            if mask.dim() == 2:
                base = (mask == 0).bool()[None, None, :, :]
                sdpa_mask = base.expand(B, self.h, T, T)
            elif mask.dim() == 3:
                base = (mask == 0).bool().unsqueeze(1)
                sdpa_mask = base.expand(B, self.h, T, T)
            elif mask.dim() == 4:
                if mask.size(1) == 1:
                    sdpa_mask = (mask == 0).bool().expand(B, self.h, T, T)
                else:
                    sdpa_mask = (mask == 0).bool()
        attn = F.scaled_dot_product_attention(q, k, v, attn_mask=sdpa_mask)
        attn = self.drop(attn)
        y = attn.transpose(1, 2).contiguous().view(B, T, Dm)
        return self.out(y)
MultiHeadAttention(16, 4)


In [None]:
# Feed-forward (Chapter 8)
class FeedForward(nn.Module):
    """Position-wise MLP: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Args:
        d_model: input and output width.
        d_ff: hidden width of the inner layer.
        dropout: probability used by both dropout layers.
    """
    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.0):
        super().__init__()
        widen = nn.Linear(d_model, d_ff)
        narrow = nn.Linear(d_ff, d_model)
        # Same layer order (and therefore the same `net.<i>` parameter names).
        stages = [widen, nn.GELU(), nn.Dropout(dropout), narrow, nn.Dropout(dropout)]
        self.net = nn.Sequential(*stages)
    def forward(self, x):
        """Run ``x`` through the MLP; the last dimension must be ``d_model``."""
        return self.net(x)
FeedForward(16, 64, 0.1)


In [None]:
# Residual wrapper (pre-norm)
class Residual(nn.Module):
    """Pre-norm residual connection: ``x + sublayer(LayerNorm(x), ...)``.

    Args:
        d_model: size of the last dimension normalised by the LayerNorm.
    """
    def __init__(self, d_model: int):
        super().__init__()
        self.norm = nn.LayerNorm(d_model)
    def forward(self, x, sublayer: nn.Module, *args, **kwargs):
        """Normalise ``x``, apply ``sublayer`` and add the result back onto ``x``.

        Extra positional and keyword arguments are forwarded to ``sublayer``.
        """
        normed = self.norm(x)
        branch = sublayer(normed, *args, **kwargs)
        return x + branch
Residual(16)


In [None]:
# Transformer block (Chapter 8)
class TransformerBlock(nn.Module):
    """One pre-norm transformer layer: self-attention, then feed-forward.

    Each sub-layer is wrapped in its own ``Residual`` module.

    Args:
        d_model: model width.
        num_heads: number of attention heads.
        d_ff: hidden width of the feed-forward sub-layer.
        dropout: dropout probability passed to both sub-layers.
    """
    def __init__(self, d_model: int, num_heads: int, d_ff: int, dropout: float = 0.0):
        super().__init__()
        self.mha = MultiHeadAttention(d_model, num_heads, dropout)
        self.ffn = FeedForward(d_model, d_ff, dropout)
        self.res1 = Residual(d_model)
        self.res2 = Residual(d_model)
    def forward(self, x, mask=None):
        """Apply attention (with the optional ``mask``) and then the MLP."""
        attended = self.res1(x, self.mha, mask)
        return self.res2(attended, self.ffn)
TransformerBlock(16, 4, 64, 0.1)


In [None]:
# GPTConfig dataclass
from dataclasses import dataclass
@dataclass
class GPTConfig:
    """Hyperparameters for the compact GPT model.

    Values are validated on construction so that a bad configuration fails
    immediately with a descriptive ``ValueError`` instead of surfacing later
    while the model is being built.

    Raises:
        ValueError: if a size is not positive, ``n_layer`` is negative,
            ``d_model`` is not divisible by ``n_head``, or ``dropout`` is
            outside ``[0, 1]``.
    """
    vocab_size: int            # number of distinct token ids
    block_size: int            # maximum sequence length the model accepts
    d_model: int = 128         # embedding / hidden width
    n_head: int = 4            # attention heads per block
    n_layer: int = 2           # number of transformer blocks
    d_ff: int = 512            # hidden width of the feed-forward sub-layer
    dropout: float = 0.1       # dropout probability
    pos_type: str = 'learned'  # 'learned' = trainable table; any other value = sinusoidal
    tie_weights: bool = True   # share the token-embedding matrix with the output head

    def __post_init__(self):
        for field_name in ('vocab_size', 'block_size', 'd_model', 'n_head', 'd_ff'):
            value = getattr(self, field_name)
            if value <= 0:
                raise ValueError(f'{field_name} must be positive, got {value}')
        if self.n_layer < 0:
            raise ValueError(f'n_layer must be non-negative, got {self.n_layer}')
        # Checked after positivity so n_head == 0 cannot cause a ZeroDivisionError.
        if self.d_model % self.n_head != 0:
            raise ValueError(
                f'd_model ({self.d_model}) must be divisible by n_head ({self.n_head})'
            )
        if not 0.0 <= self.dropout <= 1.0:
            raise ValueError(f'dropout must be in [0, 1], got {self.dropout}')
cfg = GPTConfig(vocab_size=256, block_size=32)
cfg


In [None]:
# GPT model (compact)
class GPT(nn.Module):
    """Compact decoder-only transformer language model.

    Pipeline: token embedding + positions -> dropout -> ``n_layer`` transformer
    blocks -> final LayerNorm -> linear head producing vocabulary logits.

    Args:
        cfg: hyperparameters. ``cfg.pos_type == 'learned'`` uses a trainable
            position table; any other value uses ``sinusoidal_positions``.
            ``cfg.tie_weights`` shares the token-embedding matrix with the head.
    """
    def __init__(self, cfg: GPTConfig):
        super().__init__()
        self.cfg = cfg
        V, Tm, D = cfg.vocab_size, cfg.block_size, cfg.d_model
        self.tok_emb = nn.Embedding(V, D)
        self.pos_emb = nn.Embedding(Tm, D) if cfg.pos_type == 'learned' else None
        self.drop = nn.Dropout(cfg.dropout)
        self.blocks = nn.ModuleList([TransformerBlock(D, cfg.n_head, cfg.d_ff, cfg.dropout) for _ in range(cfg.n_layer)])
        self.norm_f = nn.LayerNorm(D)
        self.lm_head = nn.Linear(D, V, bias=False)
        if cfg.tie_weights:
            # Both modules now hold the very same Parameter object.
            self.lm_head.weight = self.tok_emb.weight
        self.apply(self._init_weights)
    def _init_weights(self, m):
        """Initialise Linear/Embedding weights from N(0, 0.02**2) and zero Linear biases."""
        if isinstance(m, (nn.Linear, nn.Embedding)):
            nn.init.normal_(m.weight, mean=0.0, std=0.02)
        if isinstance(m, nn.Linear) and m.bias is not None:
            nn.init.zeros_(m.bias)
    def _build_mask(self, input_ids, attention_mask=None, pad_id=None):
        """Build a float (B, T, T) mask; entry [b, i, j] is 1 where query i may see key j.

        A pair is allowed when j <= i (causal) and key j is a real token.
        Real tokens are given by ``attention_mask`` (nonzero = real) when it is
        provided, otherwise by ``input_ids != pad_id`` when ``pad_id`` is
        provided. With neither, the mask is purely causal.
        """
        B, T = input_ids.size()
        device = input_ids.device
        causal = torch.tril(torch.ones(T, T, device=device))
        if attention_mask is not None:
            keep_bt = attention_mask.float()
        elif pad_id is not None:
            keep_bt = (input_ids != pad_id).float()
        else:
            return causal.unsqueeze(0).expand(B, -1, -1)
        mask = keep_bt[:, None, :] * causal
        # A query row with no visible key (e.g. leading padding) makes the
        # attention softmax undefined and can yield NaNs that spread to every
        # position in later layers. Let such rows attend to themselves; rows
        # that already see at least one key are left untouched.
        no_visible_key = mask.sum(dim=-1, keepdim=True) == 0
        self_only = torch.eye(T, device=device)
        return torch.where(no_visible_key, self_only, mask)
    def forward(self, input_ids, targets=None, attention_mask=None, pad_id=None):
        """Compute logits and, when ``targets`` is given, the cross-entropy loss.

        Args:
            input_ids: integer token ids of shape (B, T), with T <= ``cfg.block_size``.
            targets: optional token ids of shape (B, T) used as labels.
            attention_mask: optional (B, T) tensor, nonzero for real tokens.
            pad_id: optional padding id; used to derive the padding mask when
                ``attention_mask`` is absent and as the loss ``ignore_index``
                (otherwise -100 is ignored).

        Returns:
            ``(logits, loss)``: logits of shape (B, T, vocab_size), and a
            scalar loss or ``None`` when ``targets`` is not given.
        """
        B, T = input_ids.size()
        assert T <= self.cfg.block_size, (
            f'sequence length {T} exceeds block_size {self.cfg.block_size}'
        )
        x = self.tok_emb(input_ids)
        if self.cfg.pos_type == 'learned':
            positions = torch.arange(T, device=input_ids.device)[None, :]
            x = x + self.pos_emb(positions)
        else:
            pe = sinusoidal_positions(T, self.cfg.d_model, device=input_ids.device)
            x = x + pe[None, :, :]
        x = self.drop(x)
        mask = self._build_mask(input_ids, attention_mask, pad_id)
        for block in self.blocks:
            x = block(x, mask)
        x = self.norm_f(x)
        logits = self.lm_head(x)
        loss = None
        if targets is not None:
            # Flatten batch and time so each position is one classification example.
            logits_f = logits.reshape(B*T, -1)
            targets_f = targets.reshape(B*T)
            ignore = pad_id if pad_id is not None else -100
            loss = F.cross_entropy(logits_f, targets_f, ignore_index=ignore)
        return logits, loss
model = GPT(cfg).to(device); model


In [None]:
# Draw a small random batch of token ids (B sequences of length T) and show its shape
B, T = 2, 16
x = torch.randint(low=0, high=cfg.vocab_size, size=(B, T), device=device)
x.shape


In [None]:
# Forward pass without targets: only the logits matter (the returned loss is None)
with torch.no_grad():
    logits, _ = model(input_ids=x)
logits.shape


In [None]:
# Forward pass with (random) targets: the model also returns the cross-entropy loss
y = torch.randint(low=0, high=cfg.vocab_size, size=(B, T), device=device)
_, loss = model(x, targets=y)
loss.item()


In [None]:
# Tiny optimization sanity check: a few AdamW steps on one fixed batch;
# the recorded loss should decrease a bit
opt = torch.optim.AdamW(model.parameters(), lr=3e-3)
hist = []
for step in range(5):
    opt.zero_grad()
    _, loss = model(x, targets=y)
    loss.backward()
    opt.step()
    # .item() already returns a detached Python float
    hist.append(loss.item())
hist
