# 10. AdamW Optimizer

**Using AdamW to update all parameters and complete the training step**

We've computed gradients for every parameter in the model. Now we need to **use** those gradients to actually update the weights.

We're using **AdamW**, the standard optimizer for training GPT-style models, LLaMA, and most other modern LLMs.

## Why Not Just Subtract the Gradient?

You might think: "We have gradients. Just do `θ = θ - lr * gradient` and we're done."

That's **stochastic gradient descent (SGD)**, and it has problems:

1. **Same learning rate for all parameters**—some need big updates, others small
2. **No momentum**—can get stuck oscillating in valleys
3. **Sensitive to learning rate**—too high diverges, too low is slow

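AdamW addresses the first two problems directly. In practice it is also much more forgiving of the learning-rate choice, although the learning rate still has to be tuned.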
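The first problem is easy to see with numbers. The sketch below compares a plain SGD step with AdamW's first step for two hypothetical parameters whose gradients differ by four orders of magnitude. It assumes t = 1 with both moments starting at zero, and it leaves weight decay out. The gradient values are made up for illustration.

```python
import math

LR = 0.001

def sgd_step(grad, lr=LR):
    """Plain SGD step: its size is proportional to the gradient."""
    return -lr * grad

def adamw_first_step(grad, lr=LR, beta1=0.9, beta2=0.999, eps=1e-8):
    """AdamW's first step (t = 1, m0 = v0 = 0), weight decay omitted."""
    m = (1 - beta1) * grad
    v = (1 - beta2) * grad ** 2
    m_hat = m / (1 - beta1)        # bias correction at t = 1
    v_hat = v / (1 - beta2)
    return -lr * m_hat / (math.sqrt(v_hat) + eps)

# Hypothetical gradients: one large, one tiny.
for grad in (100.0, 0.01):
    print(f"grad={grad:>7}: SGD step={sgd_step(grad):+.6f}  "
          f"AdamW step={adamw_first_step(grad):+.6f}")
```

SGD moves the first parameter 10,000 times farther than the second. AdamW moves both by about the base learning rate.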

## What Is AdamW?

AdamW combines three powerful ideas:

1. **Adaptive learning rates**: each parameter gets its own effective step size, based on its gradient history
2. **Momentum**: updates are smoothed using an exponential moving average of past gradients
3. **Decoupled weight decay**: parameters are shrunk directly, separately from the gradient-based update (the "W" in AdamW)

```python
import math

# AdamW hyperparameters
learning_rate = 0.001    # α: base learning rate
beta1 = 0.9              # β₁: decay for first moment (momentum)
beta2 = 0.999            # β₂: decay for second moment (adaptive LR)
epsilon = 1e-8           # ε: numerical stability
weight_decay = 0.01      # λ: decoupled weight decay coefficient (not an L2 penalty)
t = 1                    # time step (first update)

print("AdamW Hyperparameters:")
print(f"  learning_rate (α) = {learning_rate}")
print(f"  beta1 (β₁)        = {beta1}")
print(f"  beta2 (β₂)        = {beta2}")
print(f"  epsilon (ε)       = {epsilon}")
print(f"  weight_decay (λ)  = {weight_decay}")
```
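The two β values set how far back each moving average effectively looks. This minimal sketch relies on two things. One is the common rule of thumb that an exponential moving average with decay β averages over about 1/(1 - β) recent steps. The other is a threshold chosen arbitrarily here: bias correction counts as negligible once 1 - βᵗ exceeds 0.99.

```python
import math

beta1, beta2 = 0.9, 0.999   # same values as the hyperparameter cell above

# Rule of thumb: an EMA with decay beta averages over ~1 / (1 - beta) steps.
window_m = 1 / (1 - beta1)
window_v = 1 / (1 - beta2)

def steps_until_unbiased(beta, tol=0.01):
    """First t where beta**t < tol, i.e. the correction factor 1 - beta**t > 1 - tol."""
    return math.ceil(math.log(tol) / math.log(beta))

print(f"m: ~{window_m:.0f}-step window, bias correction matters for ~{steps_until_unbiased(beta1)} steps")
print(f"v: ~{window_v:.0f}-step window, bias correction matters for ~{steps_until_unbiased(beta2)} steps")
```

The second moment is the slow one. Its bias correction stays significant for thousands of steps, not just the first few.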

## The AdamW Algorithm

For each parameter $\theta$, AdamW maintains two **moment estimates**, both initialized to zero:

1. **First moment** $m$—exponential moving average of gradients (momentum)
2. **Second moment** $v$—exponential moving average of squared gradients

### Step 1: Update Biased Moment Estimates

$$m_t = \beta_1 \cdot m_{t-1} + (1 - \beta_1) \cdot g_t$$
$$v_t = \beta_2 \cdot v_{t-1} + (1 - \beta_2) \cdot g_t^2$$
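Unrolling the recurrence from $m_0 = 0$ shows why these raw estimates start out too small. The same argument applies to $v_t$, with $\beta_2$ in place of $\beta_1$ and $g_i^2$ in place of $g_i$:

$$m_t = (1 - \beta_1) \sum_{i=1}^{t} \beta_1^{\,t-i}\, g_i, \qquad (1 - \beta_1) \sum_{i=1}^{t} \beta_1^{\,t-i} = 1 - \beta_1^t$$

The weights sum to $1 - \beta_1^t$ rather than 1. With a constant gradient $g$, this gives $m_t = (1 - \beta_1^t)\, g$, which at $t = 1$ is only $0.1\,g$. Step 2 divides out exactly this factor.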

### Step 2: Bias Correction

$$\hat{m}_t = \frac{m_t}{1 - \beta_1^t}$$
$$\hat{v}_t = \frac{v_t}{1 - \beta_2^t}$$
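To see the correction at work, this sketch feeds a hypothetical constant gradient of 1.0 into the first-moment update. It uses fresh variable names so the notebook's `t`, `g`, and `m_hat` are left untouched.

```python
beta1 = 0.9
g_const = 1.0     # hypothetical constant gradient, to isolate the bias
m_demo = 0.0      # first moment starts at zero, as in AdamW
history = []      # (raw m, bias-corrected m_hat) for each step

for step in range(1, 6):
    m_demo = beta1 * m_demo + (1 - beta1) * g_const   # Step 1: biased estimate
    m_hat_demo = m_demo / (1 - beta1 ** step)         # Step 2: bias correction
    history.append((m_demo, m_hat_demo))
    print(f"t={step}: m={m_demo:.4f}  m_hat={m_hat_demo:.4f}")
```

The raw $m$ creeps up from 0.1, while $\hat{m}$ recovers the true gradient of 1.0 from the very first step.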

### Step 3: Weight Decay

$$\theta_{\text{decayed}} = \theta \cdot (1 - \alpha \cdot \lambda)$$
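Applying the decay directly to $\theta$, instead of adding an L2 penalty $\lambda\theta$ to the gradient, is what separates AdamW from Adam with L2 regularization. The sketch below assumes a hypothetical parameter $\theta = 0.5$ whose loss gradient is zero, so only regularization acts. It compares the two on the first step, where $\hat{m} = g$ and $\hat{v} = g^2$.

```python
import math

lr, wd, eps = 0.001, 0.01, 1e-8
theta0 = 0.5       # hypothetical parameter value
loss_grad = 0.0    # hypothetical: zero loss gradient, so only regularization acts

def adam_first_step(theta, grad):
    """Adam's first update (t = 1, m0 = v0 = 0), where m_hat = grad and v_hat = grad**2."""
    return theta - lr * grad / (math.sqrt(grad ** 2) + eps)

# Adam + L2 penalty: lambda * theta is added to the gradient, so it gets
# divided by sqrt(v_hat) like any other gradient component.
theta_l2 = adam_first_step(theta0, loss_grad + wd * theta0)

# AdamW (Step 3 then Step 4): shrink theta directly, then take the gradient step.
theta_adamw = adam_first_step(theta0 * (1 - lr * wd), loss_grad)

print(f"Adam + L2 shrinks θ by {theta0 - theta_l2:.9f}")
print(f"AdamW     shrinks θ by {theta0 - theta_adamw:.9f}")
```

Under Adam, the L2 term is normalized like a gradient. On this step, θ therefore moves by about a full α almost regardless of λ. AdamW shrinks it by exactly $\alpha\lambda\theta$, which is the amount λ is meant to control.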

### Step 4: Update

$$\theta_{\text{new}} = \theta_{\text{decayed}} - \frac{\alpha \cdot \hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}$$
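On the very first step ($t = 1$, $m_0 = v_0 = 0$), these formulas collapse to a simple form:

$$\hat{m}_1 = \frac{(1 - \beta_1)\, g_1}{1 - \beta_1} = g_1, \qquad \hat{v}_1 = \frac{(1 - \beta_2)\, g_1^2}{1 - \beta_2} = g_1^2$$

$$\theta_{\text{new}} = \theta_{\text{decayed}} - \frac{\alpha\, g_1}{|g_1| + \epsilon} \approx \theta_{\text{decayed}} - \alpha \cdot \operatorname{sign}(g_1) \quad \text{when } |g_1| \gg \epsilon$$

So the first update moves every parameter by roughly $\alpha$, whatever the size of its gradient.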

```python
def adamw_update(theta, gradient, m, v, t,
                 lr=0.001, beta1=0.9, beta2=0.999, eps=1e-8, wd=0.01):
    """
    Perform one AdamW update step.
    
    Args:
        theta: current parameter value
        gradient: gradient of loss w.r.t. theta
        m: first moment estimate (momentum)
        v: second moment estimate (adaptive LR)
        t: timestep (starting from 1)
        lr, beta1, beta2, eps, wd: hyperparameters
    
    Returns:
        new_theta, new_m, new_v
    """
    # Step 1: Update biased moments
    m_new = beta1 * m + (1 - beta1) * gradient
    v_new = beta2 * v + (1 - beta2) * gradient**2
    
    # Step 2: Bias correction
    m_hat = m_new / (1 - beta1**t)
    v_hat = v_new / (1 - beta2**t)
    
    # Step 3: Weight decay
    theta_decayed = theta * (1 - lr * wd)
    
    # Step 4: Update
    theta_new = theta_decayed - lr * m_hat / (math.sqrt(v_hat) + eps)
    
    return theta_new, m_new, v_new
```

## Example: Updating One Parameter

Let's walk through updating the first element of the `<BOS>` token embedding.

```python
# Example parameter
theta = 0.024634      # Current value
g = -0.352893         # Gradient (negative = increase to reduce loss)
m_0 = 0.0             # Initial first moment
v_0 = 0.0             # Initial second moment

print("Initial State:")
print(f"  θ (parameter)  = {theta:.6f}")
print(f"  g (gradient)   = {g:.6f}")
print(f"  m₀ (momentum)  = {m_0:.6f}")
print(f"  v₀ (variance)  = {v_0:.6f}")
```

```python
# Step 1: Update biased moments
m_1 = beta1 * m_0 + (1 - beta1) * g
v_1 = beta2 * v_0 + (1 - beta2) * g**2

print("Step 1: Update Biased Moments")
print(f"  m₁ = {beta1} × {m_0} + {1-beta1} × {g:.6f}")
print(f"     = {m_1:.6f}")
print()
print(f"  v₁ = {beta2} × {v_0} + {1-beta2} × {g:.6f}²")
print(f"     = {v_1:.6f}")
```

```python
# Step 2: Bias correction
m_hat = m_1 / (1 - beta1**t)
v_hat = v_1 / (1 - beta2**t)

print("Step 2: Bias Correction")
print(f"  m̂ = {m_1:.6f} / (1 - {beta1}¹)")
print(f"    = {m_1:.6f} / {1 - beta1}")
print(f"    = {m_hat:.6f}")
print()
print(f"  v̂ = {v_1:.6f} / (1 - {beta2}¹)")
print(f"    = {v_1:.6f} / {1 - beta2}")
print(f"    = {v_hat:.6f}")
print()
print("Note: m̂ equals g and v̂ equals g² (up to rounding), as expected on the first step with m₀ = v₀ = 0.")
```

```python
# Step 3: Weight decay
theta_decayed = theta * (1 - learning_rate * weight_decay)

print("Step 3: Weight Decay")
print(f"  θ_decayed = {theta:.6f} × (1 - {learning_rate} × {weight_decay})")
print(f"           = {theta:.6f} × {1 - learning_rate * weight_decay}")
print(f"           = {theta_decayed:.6f}")
print()
print(f"Decay is tiny: {theta - theta_decayed:.9f}")
```

```python
# Step 4: Compute update
adaptive_lr = learning_rate / (math.sqrt(v_hat) + epsilon)
update = m_hat * adaptive_lr
theta_new = theta_decayed - update

print("Step 4: Compute Update")
print(f"  Adaptive LR = {learning_rate} / (√{v_hat:.6f} + {epsilon})")
print(f"             = {learning_rate} / {math.sqrt(v_hat):.6f}")
print(f"             = {adaptive_lr:.6f}")
print()
print(f"  This is {adaptive_lr/learning_rate:.2f}× the base learning rate!")
print()
print(f"  Update = {m_hat:.6f} × {adaptive_lr:.6f} = {update:.6f}")
print()
print(f"  θ_new = {theta_decayed:.6f} - ({update:.6f})")
print(f"        = {theta_new:.6f}")
```

```python
# Summary
print("="*60)
print("SUMMARY")
print("="*60)
print(f"  Before: θ = {theta:.6f}")
print(f"  After:  θ = {theta_new:.6f}")
print(f"  Change:    {theta_new - theta:+.6f}")
print()
print("The gradient was negative, so θ increased.")
print("Moving θ against the gradient should lower the loss (for a small enough step).")
```
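You can check the hand calculation against a compact copy of `adamw_update`. The sketch below also takes a second step with a made-up gradient of -0.2, to show the state (m, v, t) carrying over. That second gradient is hypothetical; a real one would come from another forward and backward pass. On step 2, $\hat{m}$ no longer equals the current gradient, because it blends in the first one.

```python
import math

def adamw_update(theta, gradient, m, v, t,
                 lr=0.001, beta1=0.9, beta2=0.999, eps=1e-8, wd=0.01):
    """Compact copy of the AdamW step defined earlier in this section."""
    m = beta1 * m + (1 - beta1) * gradient            # Step 1: moments
    v = beta2 * v + (1 - beta2) * gradient ** 2
    m_hat = m / (1 - beta1 ** t)                      # Step 2: bias correction
    v_hat = v / (1 - beta2 ** t)
    theta = theta * (1 - lr * wd) - lr * m_hat / (math.sqrt(v_hat) + eps)  # Steps 3-4
    return theta, m, v

# Step 1: the worked example above.
theta1, m1, v1 = adamw_update(0.024634, -0.352893, 0.0, 0.0, t=1)

# Step 2: a hypothetical next gradient.
g2 = -0.2
theta2, m2, v2 = adamw_update(theta1, g2, m1, v1, t=2)

print(f"after step 1: θ = {theta1:.6f}  (m = {m1:.6f}, v = {v1:.6f})")
print(f"after step 2: θ = {theta2:.6f}  (m = {m2:.6f}, v = {v2:.6f})")
```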

## Updating All Parameters

We apply this same process to **every single parameter** in the model.

Our model has **3,456 parameters**, the sum of the components below:
- Token embeddings: 6 × 16 = 96
- Position embeddings: 5 × 16 = 80
- Attention weights (Q, K, V per head): 2 × 3 × (16 × 8) = 768
- Output projection: 16 × 16 = 256
- FFN W1, b1: 64 × 16 + 64 = 1,088
- FFN W2, b2: 16 × 64 + 16 = 1,040
- Layer norm γ, β: 16 + 16 = 32
- LM head: 6 × 16 = 96

All of them get updated using AdamW.
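Here is a minimal sketch of "every parameter gets updated". It is built on several assumptions. The parameter groups mirror the breakdown above, but their names are illustrative. Q, K, and V for both heads are lumped into one group. Values and gradients are random placeholders, not the real model's numbers. Each scalar gets its own `m` and `v`, and all scalars share the step counter `t`.

```python
import math
import random

# Hypothetical parameter groups; sizes follow the breakdown above.
PARAM_SIZES = {
    "token_embeddings": 6 * 16,
    "position_embeddings": 5 * 16,
    "attn_qkv": 2 * 3 * (16 * 8),
    "attn_output": 16 * 16,
    "ffn_w1_b1": 64 * 16 + 64,
    "ffn_w2_b2": 16 * 64 + 16,
    "layernorm_gamma_beta": 16 + 16,
    "lm_head": 6 * 16,
}

def adamw_step_all(params, grads, state, t,
                   lr=0.001, beta1=0.9, beta2=0.999, eps=1e-8, wd=0.01):
    """Apply one AdamW step, in place, to every scalar in every group."""
    for name, values in params.items():
        g, m, v = grads[name], state[name]["m"], state[name]["v"]
        for i in range(len(values)):
            m[i] = beta1 * m[i] + (1 - beta1) * g[i]
            v[i] = beta2 * v[i] + (1 - beta2) * g[i] ** 2
            m_hat = m[i] / (1 - beta1 ** t)
            v_hat = v[i] / (1 - beta2 ** t)
            values[i] = values[i] * (1 - lr * wd) - lr * m_hat / (math.sqrt(v_hat) + eps)

# Random placeholder values and gradients.
random.seed(0)
params = {n: [random.gauss(0.0, 0.02) for _ in range(k)] for n, k in PARAM_SIZES.items()}
grads = {n: [random.gauss(0.0, 0.1) for _ in range(k)] for n, k in PARAM_SIZES.items()}
# One (m, v) pair per scalar parameter, both starting at zero.
state = {n: {"m": [0.0] * k, "v": [0.0] * k} for n, k in PARAM_SIZES.items()}

before = {n: list(vals) for n, vals in params.items()}
adamw_step_all(params, grads, state, t=1)

total = sum(PARAM_SIZES.values())
changed = sum(a != b for n in params for a, b in zip(params[n], before[n]))
print(f"total parameters: {total}, updated this step: {changed}")
```

Two design notes. First, the optimizer stores an `m` and a `v` for every scalar, so its state is twice the size of the model. Second, many real training setups exclude biases and layer-norm parameters from weight decay, whereas this sketch decays everything, as the text describes.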

## Why AdamW Works So Well

1. **Automatic per-parameter learning rates**—parameters with large gradients get smaller effective rates

2. **Momentum smoothing**—the first moment $m$ smooths out noisy gradients over time

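3. **Bounded step sizes**: dividing by $\sqrt{\hat{v}}$ keeps each step roughly on the order of $\alpha$, so a sudden spike in the gradient cannot cause an arbitrarily large jump

4. **Decoupled weight decay**: every parameter shrinks by the same factor $(1 - \alpha\lambda)$ per step. Under Adam, an L2 penalty would instead be rescaled by each parameter's gradient history.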
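The momentum point can be illustrated with a hypothetical noisy gradient: a true mean of 1.0 plus noise that alternates between +2 and -2. Only the first-moment update is shown, with fresh variable names.

```python
beta1 = 0.9
m_demo = 0.0
raw, smoothed = [], []

for step in range(1, 201):
    # Hypothetical noisy gradient: true mean 1.0 plus alternating noise of ±2.
    g_noisy = 1.0 + 2.0 * (-1) ** step
    m_demo = beta1 * m_demo + (1 - beta1) * g_noisy   # first-moment EMA (momentum)
    raw.append(g_noisy)
    smoothed.append(m_demo)

late = smoothed[100:]   # skip the warm-up from m0 = 0
print(f"raw gradient range:    {min(raw):+.3f} .. {max(raw):+.3f}")
print(f"momentum range (late): {min(late):+.3f} .. {max(late):+.3f}")
```

The raw gradient swings between -1 and +3 and even flips sign. The moving average settles into a narrow band around the true mean of 1.0.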

This is why AdamW is the default for training LLMs.

## The Complete Training Loop

We've now completed **one full training step**:

1. ✅ **Forward pass**—computed embeddings, attention, FFN, layer norm, and loss
2. ✅ **Backward pass**—computed gradients for every parameter via backpropagation
3. ✅ **Optimization**—updated every parameter using AdamW

This is what training a neural network means. Repeat this loop thousands or millions of times:

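```python
# Schematic only: model, compute_loss, backpropagate, data, and num_epochs
# stand in for the pieces built in earlier sections. parameters, m, and v are
# dicts keyed by parameter name, with m and v starting at 0.0.
t = 0
for epoch in range(num_epochs):
    for batch in data:
        predictions = model(batch)                # Forward
        loss = compute_loss(predictions, batch)   # Loss (targets come from the batch)
        gradients = backpropagate(loss)           # Backward
        t += 1                                    # AdamW's step counter starts at 1
        for name in parameters:                   # Optimize every parameter
            parameters[name], m[name], v[name] = adamw_update(
                parameters[name], gradients[name], m[name], v[name], t)
```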
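To see the whole loop actually run, here is a sketch on a hypothetical toy task: learning y = 2x + 1 with a two-parameter model (`w`, `b`), full-batch MSE loss, and analytic gradients. The task, data, and learning rate of 0.05 are illustrative choices, not part of the transformer built earlier. `adamw_update` is repeated so the block runs on its own.

```python
import math

def adamw_update(theta, gradient, m, v, t,
                 lr=0.001, beta1=0.9, beta2=0.999, eps=1e-8, wd=0.01):
    """Same AdamW step as above, repeated so this block is self-contained."""
    m = beta1 * m + (1 - beta1) * gradient
    v = beta2 * v + (1 - beta2) * gradient ** 2
    m_hat = m / (1 - beta1 ** t)
    v_hat = v / (1 - beta2 ** t)
    theta = theta * (1 - lr * wd) - lr * m_hat / (math.sqrt(v_hat) + eps)
    return theta, m, v

# Hypothetical toy data: y = 2x + 1.
xs = [-1.0, -0.5, 0.0, 0.5, 1.0]
ys = [2.0 * x + 1.0 for x in xs]

params = {"w": 0.0, "b": 0.0}
m_state = {name: 0.0 for name in params}
v_state = {name: 0.0 for name in params}
losses = []

for step in range(1, 501):
    # Forward pass and mean-squared-error loss over the whole (tiny) dataset.
    preds = [params["w"] * x + params["b"] for x in xs]
    errors = [p - y for p, y in zip(preds, ys)]
    losses.append(sum(e * e for e in errors) / len(xs))
    # Backward pass: analytic gradients of the MSE loss.
    grads = {
        "w": sum(2 * e * x for e, x in zip(errors, xs)) / len(xs),
        "b": sum(2 * e for e in errors) / len(xs),
    }
    # Optimize: one AdamW step per parameter, sharing the step counter.
    for name in params:
        params[name], m_state[name], v_state[name] = adamw_update(
            params[name], grads[name], m_state[name], v_state[name], step, lr=0.05)

print(f"loss: {losses[0]:.4f} -> {losses[-1]:.6f}")
print(f"w = {params['w']:.3f}, b = {params['b']:.3f}")
```

Weight decay pulls the final `w` and `b` very slightly toward zero, so they land close to, but not exactly on, 2 and 1.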

Over many iterations, the loss tends to fall and the model gets better. It does not necessarily fall on every step, since each batch gives a noisy gradient.

## What We've Learned

We calculated—by hand—a complete training step through a transformer:

- **Forward pass**: embeddings → attention → FFN → loss
- **Backward pass**: gradients via chain rule
- **Optimization**: AdamW updates for all 3,456 parameters

Real LLMs train on hundreds of billions of tokens or more; GPT-3 was trained on 300 billion. Every one of those tokens goes through this same process.

But now you understand exactly what those calculations are doing.

## Closing Thoughts

You've made it through the entire pipeline.

You've seen every matrix multiplication, every activation function, every gradient calculation, every weight update.

Nothing was hidden. No magic. Just math.

When someone says "a transformer learns by gradient descent," you now know **exactly** what that means—down to the individual floating-point operations.

You understand transformers not because someone explained it in abstract terms, but because you **calculated it yourself**.

That's the difference between knowing about something and truly understanding it.