<img src="https://theaiengineer.dev/tae_logo_gw_flatter.png" width=35% align=right>

# Building a Large Language Model from Scratch — A Step-by-Step Guide Using Python and PyTorch
## Chapter 14 — Beyond the Basics
**© Dr. Yves J. Hilpisch**<br>AI-Powered by GPT-5.

## How to Use This Notebook

- Explore advanced training techniques such as curriculum learning or reinforcement learning signals.
- Benchmark new architectures against your baseline transformer.
- Identify which techniques are worth productizing versus leaving as research spikes.

### Roadmap

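We survey techniques that go beyond the vanilla transformer training loop and demonstrate three of them in code: low-rank adapters (LoRA) for parameter-efficient fine-tuning, weight-only int8 quantization for cheaper inference, and knowledge distillation from a teacher into a student. For each, we highlight when it is worth reaching for.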

### Study Tips

Keep track of assumptions. Advanced tricks often depend on specifics of the dataset or the objective; note them explicitly.

In [None]:
# Setup: Torch, plotting style, device, and contextlib
import torch, contextlib
import matplotlib.pyplot as plt
plt.style.use('seaborn-v0_8')
%config InlineBackend.figure_format = 'svg'
torch.manual_seed(0)
device = ('cuda' if torch.cuda.is_available() else
          'mps' if getattr(torch.backends, 'mps', None) and
                   torch.backends.mps.is_available() else 'cpu')
device


In [None]:
# LoRA: low-rank adapters for a Linear layer
class LoRALinear(torch.nn.Module):
    def __init__(self, d_in, d_out, r=8, alpha=16.0, bias=False):
        """Create a Linear with LoRA adapters.
        Only the small A (r x d_in) and B (d_out x r) train.
        scale = alpha / r.
        """
        super().__init__()
        self.base = torch.nn.Linear(d_in, d_out, bias=bias)
        self.r = int(r); self.scale = float(alpha) / max(1, int(r))
        if self.r > 0:
            self.A = torch.nn.Linear(d_in, self.r, bias=False)
            self.B = torch.nn.Linear(self.r, d_out, bias=False)
            torch.nn.init.kaiming_uniform_(self.A.weight, a=2**0.5)
            torch.nn.init.zeros_(self.B.weight)
            for p in self.base.parameters(): p.requires_grad = False
        else:
            self.A = None; self.B = None
        self.merged = False
    def forward(self, x):
        y = self.base(x)
        if self.r > 0 and not self.merged: y = y + self.scale * self.B(self.A(x))
        return y
    @torch.no_grad()
    def merge(self):
        """Fold LoRA delta into base weight for inference.
        """
        if self.r == 0 or self.merged: self.merged = True; return
        delta = (self.B.weight @ self.A.weight) * self.scale
        self.base.weight += delta; self.merged = True
LoRALinear(16, 16, r=4)

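Because the adapter path is linear, `merge()` is exact up to floating-point rounding: (W0 + s·BA)x = W0x + s·B(Ax) with s = alpha / r. A minimal standalone check with plain tensors, assuming the same shape conventions as `LoRALinear` (A is r × d_in, B is d_out × r); the sizes and random values are arbitrary:

```python
import torch

torch.manual_seed(0)
d_in, d_out, r, alpha = 16, 12, 4, 16.0
s = alpha / r                   # LoRA scale, as in LoRALinear

W0 = torch.randn(d_out, d_in)   # frozen base weight (d_out x d_in)
A = torch.randn(r, d_in)        # down-projection (r x d_in)
B = torch.randn(d_out, r)       # up-projection (d_out x r)
x = torch.randn(5, d_in)        # batch of 5 inputs, one per row

# Unmerged path: base output plus the scaled low-rank update
y_unmerged = x @ W0.T + s * (x @ A.T) @ B.T
# Merged path: fold the delta into a single weight matrix
W_merged = W0 + s * (B @ A)
y_merged = x @ W_merged.T

max_diff = (y_unmerged - y_merged).abs().max().item()
print(f"max |unmerged - merged| = {max_diff:.2e}")
```

After merging, inference costs one matrix multiply per layer, the same as the original Linear; the trade-off is that the adapter can no longer be swapped out unless you keep a copy of W0 or subtract the delta again.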

In [None]:
# Parameter counts and trainables
def count_params(m): return sum(p.numel() for p in m.parameters())
def count_trainable(m): return sum(p.numel() for p in m.parameters()
                                 if p.requires_grad)
base = torch.nn.Linear(256, 256, bias=False)
lora = LoRALinear(256, 256, r=8, alpha=16, bias=False)
# Parenthesize so the notebook displays all four counts, not just the last line
(count_params(base), count_trainable(base),
 count_params(lora), count_trainable(lora))

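Those counts follow from the shapes alone: a rank-r adapter adds r·(d_in + d_out) trainable parameters, while the full matrix has d_in·d_out. For the 256 × 256 layer above with r = 8 that is 4,096 trainable parameters against 65,536 frozen ones (6.25%), and since the ratio for a square layer is 2r/d, the savings grow with width. A short sketch with two hypothetical helper names, `lora_trainable` and `full_params`, ignoring biases:

```python
def lora_trainable(d_in: int, d_out: int, r: int) -> int:
    """Trainable parameters of a rank-r adapter: A is r x d_in, B is d_out x r."""
    return r * (d_in + d_out)


def full_params(d_in: int, d_out: int) -> int:
    """Parameters of a bias-free d_in -> d_out Linear layer."""
    return d_in * d_out


# Square layers of growing width: the trainable fraction falls as 2r/d
for d, r in [(256, 8), (1024, 8), (4096, 16)]:
    lt, fp = lora_trainable(d, d, r), full_params(d, d)
    print(f"d={d:5d} r={r:3d}: LoRA {lt:8d} vs full {fp:10d} ({100 * lt / fp:.2f}%)")
```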

### LoRA mini‑training demo

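Train only the adapter matrices A and B (the base weight stays frozen at its random initialization) to fit a synthetic linear map, and compare against a fully trainable Linear as a sanity check. Expect the LoRA loss to level off above zero: a rank-8 update cannot represent an arbitrary full-rank 64 × 64 correction.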


In [None]:
# Synthetic linear map: each row x of X maps to target @ x, i.e. Y = X @ target.T (no bias)
torch.manual_seed(0)
d = 64
target = torch.randn(d, d)
X = torch.randn(256, d)
Y = X @ target.t()
# Full Linear (all d*d weights train) vs LoRA (frozen random base, only A and B train)
full = torch.nn.Linear(d, d, bias=False)
lora = LoRALinear(d, d, r=8, alpha=16, bias=False)
opt_full = torch.optim.Adam(full.parameters(), lr=3e-2)
opt_lora = torch.optim.Adam([p for p in lora.parameters() if p.requires_grad], lr=3e-2)
losses_f, losses_l = [], []
for step in range(60):
    # full: one Adam step on the MSE
    opt_full.zero_grad()
    lf = ((full(X) - Y) ** 2).mean()
    lf.backward(); opt_full.step()
    # lora: same objective, but gradients only reach A and B
    opt_lora.zero_grad()
    ll = ((lora(X) - Y) ** 2).mean()
    ll.backward(); opt_lora.step()
    if step % 10 == 0:  # record both losses every 10 steps
        losses_f.append(lf.item()); losses_l.append(ll.item())
losses_f, losses_l


In [None]:
# Plot LoRA vs full Linear losses collected every 10 steps
import matplotlib.pyplot as plt
plt.figure(figsize=(4.6,2.8));
plt.plot([i*10 for i in range(len(losses_f))], losses_f, label='full', color='#0A66C2');
plt.plot([i*10 for i in range(len(losses_l))], losses_l, label='lora', color='#DD4444');
plt.title('Mini training: full vs LoRA'); plt.xlabel('step'); plt.ylabel('MSE');
plt.legend(); plt.show()

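The gap between the two curves has a concrete lower bound. The adapter has to supply the correction target − W0, which for a random frozen W0 is generically full rank, using only the rank-8 product BA. By the Eckart–Young theorem, no rank-8 matrix approximates that correction better (in Frobenius norm) than its truncated SVD. A minimal sketch measuring how much of the correction stays out of reach; the uniform stand-in for the frozen base (mimicking `nn.Linear`'s default range of ±1/√d_in) is an assumption, so the number will not match the trained run exactly:

```python
import torch

torch.manual_seed(0)
d, r = 64, 8
target = torch.randn(d, d)                      # full-rank map, as in the demo
bound = 1.0 / d ** 0.5
W0 = torch.empty(d, d).uniform_(-bound, bound)  # assumed stand-in for the frozen base weight
delta = target - W0                             # correction the adapter would need

# Eckart-Young: the truncated SVD is the best rank-r approximation in Frobenius norm
U, S, Vh = torch.linalg.svd(delta)
delta_r = U[:, :r] @ torch.diag(S[:r]) @ Vh[:r, :]

resid = (torch.linalg.norm(delta - delta_r) / torch.linalg.norm(delta)).item()
print(f"relative residual of the best rank-{r} correction: {resid:.3f}")
```

LoRA rests on the hypothesis that, for a pretrained base, the update needed during fine-tuning is close to low rank; with a random base, as here, that hypothesis does not hold, which is why the demo makes a conservative sanity check rather than a realistic benchmark.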

In [None]:
# Tiny weight-only quantization demo (per-tensor, symmetric)
def fake_int8_weight_only(W):
    """Return (q:int8, scale:float) with W ≈ scale * q.
    """
    maxv = W.abs().max().clamp(min=1e-8)
    scale = float(maxv / 127.0)
    q = torch.clamp((W / scale).round(), -127, 127).to(torch.int8)
    return q, scale
def dequant(q, scale): return q.float() * scale
W = torch.randn(256, 256)
q, s = fake_int8_weight_only(W); rec = dequant(q, s)
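# print explicitly: a bare expression mid-cell would not be displayed
print('mean |W - rec|:', float((W - rec).abs().mean()))

# Error histogram: no value is clipped, so rounding error is at most scale / 2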
import matplotlib.pyplot as plt
err = (W - rec).view(-1).detach().numpy()
plt.figure(figsize=(4.2,2.6)); plt.hist(err, bins=40, color='#0A66C2');
plt.title('Quantization error'); plt.show()

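The cell above uses one scale for the whole tensor, so the largest weights set the step size for every row. A common refinement is per-output-channel (here per-row) scaling. A minimal sketch comparing the two on a synthetic weight whose rows span two orders of magnitude; the helper names and the row-scale pattern are illustrative assumptions, not measurements from a real model:

```python
import torch


def quant_per_tensor(W: torch.Tensor):
    """Symmetric int8 with a single scale shared by the whole tensor."""
    scale = W.abs().max().clamp(min=1e-8) / 127.0
    q = torch.clamp(torch.round(W / scale), -127, 127).to(torch.int8)
    return q, scale


def quant_per_row(W: torch.Tensor):
    """Symmetric int8 with one scale per output row (per-channel)."""
    scale = W.abs().amax(dim=1, keepdim=True).clamp(min=1e-8) / 127.0
    q = torch.clamp(torch.round(W / scale), -127, 127).to(torch.int8)
    return q, scale


torch.manual_seed(0)
# Rows with very different magnitudes, a pattern that hurts a single shared scale
row_scales = torch.logspace(-2, 0, steps=256).unsqueeze(1)
W = torch.randn(256, 256) * row_scales

errors = {}
for name, fn in [("per-tensor", quant_per_tensor), ("per-row", quant_per_row)]:
    q, s = fn(W)
    errors[name] = (W - q.float() * s).abs().mean().item()  # dequantize, then compare
    print(f"{name:10s} mean |error| = {errors[name]:.2e}")
```

Per-row scales cost one extra float per output row, which is small next to the d_in int8 values stored in that row.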

In [None]:
# Distillation step with temperature; robust to modules returning
# logits or (logits, loss).
import torch.nn.functional as F
def _forward_logits(m, x):
    out = m(x)
    return out[0] if isinstance(out, (tuple, list)) else out
def distill_step(student, teacher, x, T=2.0, lam=0.5):
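    """Return lam * T^2 * KL(teacher || student) on temperature-softened
    distributions ('batchmean' divides by the first dimension of the logits).
    The hard-label cross-entropy term is not included; add it in the caller.
    """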
    with torch.no_grad(): t_logits = _forward_logits(teacher, x)
    s_logits = _forward_logits(student, x)
    t = F.log_softmax(t_logits / T, dim=-1)
    s = F.log_softmax(s_logits / T, dim=-1)
    kl = F.kl_div(s, t, log_target=True, reduction='batchmean') * (T*T)
    return lam * kl
teacher = torch.nn.Linear(32, 32)
student = torch.nn.Linear(32, 32)
x = torch.randn(8, 32)
distill_step(student, teacher, x).item()

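Why the T² factor in `distill_step`? The gradient of the softened KL with respect to the student logits is (p_s − p_t)/T, where p = softmax(logits / T), and for large T the difference p_s − p_t itself shrinks roughly like 1/T, so the raw gradient fades roughly as 1/T². Multiplying by T² keeps the distillation signal on a comparable scale as T changes. A quick numerical check on random logits (batch size, class count, and seed are arbitrary choices):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
t_logits = torch.randn(64, 10)  # illustrative teacher logits: batch of 64, 10 classes
s_init = torch.randn(64, 10)    # illustrative student logits

unscaled, scaled = {}, {}
for T in (1.0, 2.0, 4.0, 8.0):
    s = s_init.clone().requires_grad_(True)
    kl = F.kl_div(F.log_softmax(s / T, dim=-1), F.log_softmax(t_logits / T, dim=-1),
                  log_target=True, reduction='batchmean')
    (g,) = torch.autograd.grad(kl, s)  # gradient of the unscaled KL w.r.t. student logits
    unscaled[T] = g.norm().item()
    scaled[T] = (T * T) * unscaled[T]  # gradient of T^2 * KL is T^2 times the KL gradient
    print(f"T={T:3.0f}  |grad KL| = {unscaled[T]:.2e}  |grad T^2*KL| = {scaled[T]:.2e}")
```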

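`distill_step` returns only the soft-target term. A common recipe (Hinton-style distillation) adds the ordinary cross-entropy on the true labels and weights the two terms with lam and 1 − lam. For language-model logits of shape (B, L, V), flattening to (B·L, V) first makes `reduction='batchmean'` average per token rather than per sequence. A minimal sketch under those assumptions; `distill_loss` is a hypothetical name, not part of any library:

```python
import torch
import torch.nn.functional as F


def distill_loss(s_logits, t_logits, labels, T=2.0, lam=0.5):
    """Hypothetical combined objective: lam * T^2 * KL(teacher || student) + (1 - lam) * CE.

    s_logits, t_logits: (B, L, V) logits; labels: (B, L) integer token ids.
    """
    V = s_logits.size(-1)
    s_flat = s_logits.reshape(-1, V)           # (B*L, V): one row per token
    t_flat = t_logits.reshape(-1, V).detach()  # the teacher receives no gradient
    kl = F.kl_div(F.log_softmax(s_flat / T, dim=-1),
                  F.log_softmax(t_flat / T, dim=-1),
                  log_target=True, reduction='batchmean') * (T * T)
    ce = F.cross_entropy(s_flat, labels.reshape(-1))  # hard labels at temperature 1
    return lam * kl + (1 - lam) * ce


torch.manual_seed(0)
B, L, V = 2, 5, 11
student_logits = torch.randn(B, L, V, requires_grad=True)
teacher_logits = torch.randn(B, L, V)
labels = torch.randint(0, V, (B, L))
loss = distill_loss(student_logits, teacher_logits, labels)
loss.backward()
print(loss.item(), student_logits.grad.shape)
```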
## Exercises

- Prototype curriculum learning by varying sequence length during training.
- Add reinforcement learning from human feedback (RLHF) stubs and document prerequisites.
- Summarize the pros/cons of two advanced techniques you trialed, including resource implications.
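As a possible starting point for the first exercise, a sequence-length curriculum can be as simple as a schedule that maps the training step to a context length and crops each batch accordingly. A minimal sketch; `seq_len_schedule`, the linear ramp, and the defaults (32 to 256 tokens, rounded down to multiples of 8) are all assumptions to tune:

```python
import torch


def seq_len_schedule(step: int, warmup_steps: int, min_len: int = 32,
                     max_len: int = 256, multiple: int = 8) -> int:
    """Hypothetical linear curriculum: grow the context length from min_len to
    max_len over warmup_steps, rounded down to a multiple of `multiple`."""
    if step >= warmup_steps:
        return max_len
    frac = step / warmup_steps
    length = min_len + frac * (max_len - min_len)
    return max(min_len, int(length) // multiple * multiple)


torch.manual_seed(0)
batch = torch.randint(0, 100, (4, 256))  # illustrative token ids: 4 sequences of 256
for step in (0, 250, 500, 1000):
    cur_len = seq_len_schedule(step, warmup_steps=1000)
    print(step, tuple(batch[:, :cur_len].shape))  # crop each sequence to the current length
```

Rounding to a multiple of 8 is a convention that tends to suit GPU kernels; cropping a random window rather than the prefix would expose the model to more varied content.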
