# Hello, GPT!

Modula can represent any neural net architecture. Here, let's see how to make a transformer that trains on Shakespeare! Let's set out the model parameters first.

In [18]:
# Karpathy's smallest GPT config

vocab_size = 65  # number of token ids; sizes both the Embed input and the output projection
context = 64  # context length passed to load_shakespeare (presumably tokens per training sequence)
num_heads = 4  # attention heads per attention layer
d_embed = 128  # width of the residual stream
d_query = 32  # per-head query/key width
d_value = 32  # per-head value width
num_blocks = 4  # number of (attention block, MLP block) pairs
attention_scale = 1  # scale passed to the Softmax over attention scores
final_scale = 1  # scalar multiplier on the output projection

# training hyperparameters

lr = 0.2  # initial learning rate; decayed linearly toward 0 by lr_schedule in the training loop
beta = 0.0  # momentum coefficient; 0.0 means each update uses the raw gradient
batch_size = 64  # sequences per batch
steps = 401  # training steps; 401 so the last step index (400) is a multiple of val_interval and also validates
log_interval = 10  # print the training loss every this many steps
val_interval = 100  # run validation every this many steps

Now let's download the data.

In [2]:
from data.shakespeare import load_shakespeare

# Load the dataset once: it bundles the train/val loaders with the
# encode/decode functions used to map between text and token ids.
data = load_shakespeare(context, batch_size)

train_loader, val_loader = data["train_loader"], data["val_loader"]
encode, decode = data["encode"], data["decode"]

Let's peek at an example to verify the data loaded correctly.

In [3]:
for inputs, targets in train_loader:
    # A single batch is enough for a sanity check; targets should be the
    # inputs shifted left by one token.
    first_input, first_target = inputs[0], targets[0]
    print("Input shape:", inputs.shape)
    print("Target shape:", targets.shape)
    print("First input sequence:", first_input[:10], "...")
    print("First target sequence:", first_target[:10], "...")
    print("Decoded input:", decode(first_input))
    print("Decoded target:", decode(first_target))
    break

Input shape: (64, 64)
Target shape: (64, 64)
First input sequence: [41 53 50 42  1 40 50 53 53 42] ...
First target sequence: [53 50 42  1 40 50 53 53 42  1] ...
Decoded input: cold blood no spark of honour bides.

NORTHUMBERLAND:
Be thou a 
Decoded target: old blood no spark of honour bides.

NORTHUMBERLAND:
Be thou a p


Great! We're ready to define our transformer. We'll give some extra care to *attention* and *residual connections*.

1. We'll define attention with $1/d$ scaling instead of the usual $1/\sqrt{d}$ scaling to make the attention block Lipschitz. That is to say, with the usual scaling, attention can become arbitrarily sensitive to changes in its input as $d$ grows. With $1/d$ scaling, the sensitivity is bounded at 1 and the sharpness is bounded at 3! This means the entire transformer can be well-normed. The full story is in "Case Study I: Attention" of our paper [Scalable Optimization in the Modular Norm](https://arxiv.org/pdf/2405.14813).

2. We'll define residual connections using a linear combination: if $L$ is the number of residual blocks, then we use $x = \frac{L-1}{L} x + \frac{1}{L} \textsf{block}(x).$ Again the purpose is to control the sensitivity of the transformer. As $L$ grows, the norm of $x$ remains bounded. And the contribution from any particular block shrinks by a factor of at most $e$ through the rest of the network: once added, it passes through at most $L-1$ further residual connections, each of which scales it by $1 - 1/L$, and $(1 - 1/L)^{L-1} > 1/e$ for any depth $L$.

In short, these changes allow us to control the input sensitivity of our transformer as it scales. We're ready to implement it!

In [19]:
from modula.abstract import Identity
from modula.atom import Linear, Embed
from modula.bond import SplitIntoHeads, MergeHeads, Rope, AttentionQK, CausalMask, Softmax, ApplyAttentionScores, GeLU

def Attention(num_heads, d_embed, d_query, d_value, softmax_scale, causal):
    """Multi-head attention.

    Args:
        num_heads: number of attention heads.
        d_embed: width of the input and output features.
        d_query: per-head width of queries and keys.
        d_value: per-head width of values.
        softmax_scale: scale passed to the Softmax over the attention scores.
        causal: if True, apply `CausalMask` to the scores before the softmax;
            if False, the scores go straight from `AttentionQK` to the softmax.

    Returns:
        A modula module mapping d_embed features to d_embed features.
    """
    # For keys, queries, and values we add a heads dimension. For the out projection, we remove heads.
    # Remember modules compose right-to-left, and the order is Linear(d_out, d_in)!
    Q = SplitIntoHeads(num_heads) @ Linear(num_heads * d_query, d_embed)
    K = SplitIntoHeads(num_heads) @ Linear(num_heads * d_query, d_embed)
    V = SplitIntoHeads(num_heads) @ Linear(num_heads * d_value, d_embed)
    W = Linear(d_embed, num_heads * d_value) @ MergeHeads()

    # Previously `causal` was ignored and the mask was always applied.
    # Building the head as (Softmax @ CausalMask) keeps the causal=True
    # composition chain exactly as before.
    score_head = (Softmax(softmax_scale) @ CausalMask()) if causal else Softmax(softmax_scale)

    # Read right-to-left: rotate (Q, K) with RoPE, apply Q @ K.T, optionally mask, softmax (with a scale we can choose).
    AttentionScores = score_head @ AttentionQK() @ Rope(d_query) @ (Q, K)

    # Read right-to-left: apply attention scores, multiply by 1/3 to fix the sharpness to 1, project back to d_embed.
    return W @ (1/3 * ApplyAttentionScores()) @ (V, AttentionScores)

def GPT(vocab_size, num_heads, d_embed, d_query, d_value, num_blocks, blocks_mass=5, attention_scale=1.0, final_scale=1.0):
    """Build a GPT-style transformer as a single modula module.

    The network is embed -> num_blocks x (attention block, MLP block) -> out
    projection. Each block is a residual convex combination
    (1 - 1/L) * x + (1/L) * branch(x) with L = 2 * num_blocks.

    Args:
        vocab_size: number of token ids (embedding input / output width).
        num_heads, d_embed, d_query, d_value: attention/MLP widths.
        num_blocks: number of (attention, MLP) pairs.
        blocks_mass: absolute mass given to all transformer blocks via `tare`.
        attention_scale: softmax scale forwarded to `Attention`.
        final_scale: scalar multiplier on the output projection.

    Returns:
        The composite module `out @ blocks @ embed`.
    """
    # Embedding; per the original notes, a bare `tare()` gives it mass 1, which
    # controls the share of feature learning that happens here.
    embed = Embed(d_embed, vocab_size)
    embed.tare()

    # Sub-layers: causal multi-head attention and a 4x-wide GeLU MLP
    # (read right-to-left; Linear is (d_out, d_in)).
    attention = Attention(num_heads, d_embed, d_query, d_value, attention_scale, causal=True)
    hidden = 4 * d_embed
    mlp = Linear(d_embed, hidden) @ GeLU() @ Linear(hidden, d_embed)

    # Two residual connections per block, so L = 2 * num_blocks.
    branch_weight = 1 / (2 * num_blocks)
    skip_weight = 1 - branch_weight
    attention_block = skip_weight * Identity() + branch_weight * attention
    mlp_block = skip_weight * Identity() + branch_weight * mlp

    # Raising a module to a power composes it with itself that many times.
    blocks = (mlp_block @ attention_block) ** num_blocks

    # With the default blocks_mass=5, blocks get 5/7 of feature learning and
    # the embedding plus out projection share the remaining 2/7.
    blocks.tare(absolute=blocks_mass)

    out = final_scale * Linear(vocab_size, d_embed)
    return out @ blocks @ embed

# Instantiate the transformer from the config cell. The first six arguments
# are positional, in the order of GPT's signature.
model = GPT(
    vocab_size, num_heads, d_embed, d_query, d_value, num_blocks,
    attention_scale=attention_scale,
    final_scale=final_scale,
)

# modula's `jit()` — presumably compiles the module's forward pass; confirm.
model.jit()

# Summary: atom/bond counts, smoothness, input sensitivity, and mass.
print(model)

CompositeModule
...consists of 26 atoms and 78 bonds
...non-smooth
...input sensitivity is 1.0
...contributes proportion 7.0 to feature learning of any supermodule


Now for the loss function. We'll use cross entropy loss, which we can compute using the logsumexp trick:

$$
-\log(\text{target probability}) = -\log(\text{softmax}(\text{logits})_\text{target}) = -\text{logit}_\text{target} + \text{logsumexp}(\text{logits})
$$

In [5]:
import jax
import jax.numpy as jnp

def cross_entropy_loss(w, inputs, targets):
    """Mean next-token cross entropy of the model on one batch.

    Uses -log softmax(logits)[target] = logsumexp(logits) - logits[target],
    which never forms the probabilities explicitly.

    Args:
        w: model weights (as returned by `model.initialize`).
        inputs: token ids, shape [batch, seq_len].
        targets: next-token ids, shape [batch, seq_len].

    Returns:
        Scalar loss averaged over all batch and sequence positions.
    """
    logits = model(inputs, w)  # [batch, seq_len, vocab_size]
    n_batch, n_seq = logits.shape[0], logits.shape[1]
    rows = jnp.arange(n_batch)[:, None]  # [batch, 1]
    cols = jnp.arange(n_seq)[None, :]    # [1, seq_len]
    # rows, cols and targets broadcast to [batch, seq_len], so this gathers
    # logits[b, s, targets[b, s]] for every position.
    target_logits = logits[rows, cols, targets]
    per_token = jax.nn.logsumexp(logits, axis=-1) - target_logits  # [batch, seq_len]
    return per_token.mean()

# Loss and gradient compiled together; used once per training step.
loss_and_grad = jax.jit(jax.value_and_grad(cross_entropy_loss))

And we're ready to train!

In [20]:
import itertools

key = jax.random.PRNGKey(0)
w = model.initialize(key)

# Validation only needs the loss value, so compile it without the backward
# pass that `loss_and_grad` would also run.
eval_loss = jax.jit(cross_entropy_loss)
val_batches = 10  # validation batches averaged at each evaluation


def lr_schedule(step):
    """Linearly decay the learning rate from `lr` at step 0 toward 0 at `steps`."""
    return lr * (steps - step) / steps


step = 0
momentum = [jnp.zeros_like(weight) for weight in w]
for inputs, targets in train_loader:
    loss, grad_w = loss_and_grad(w, inputs, targets)
    # Exponential moving average of the gradient; with beta = 0 it is the raw gradient.
    momentum = [beta * m + (1 - beta) * g_w for m, g_w in zip(momentum, grad_w)]
    # Map the averaged gradient to an update direction with modula's `dualize`.
    d_w = model.dualize(momentum)
    w = [weight - lr_schedule(step) * d_weight for weight, d_weight in zip(w, d_w)]

    if step % log_interval == 0:
        print(f"Step {step}: loss {loss}")

    # Note: validation sees the weights *after* this step's update.
    if step % val_interval == 0:
        val_losses = [
            eval_loss(w, val_inputs, val_targets)
            for val_inputs, val_targets in itertools.islice(val_loader, val_batches)
        ]
        if val_losses:
            print(f"--> val loss {sum(val_losses)/len(val_losses)}")
        else:
            print("--> val loss unavailable: val_loader yielded no batches")

    step += 1

    if step >= steps:
        break

Step 0: loss 4.210641860961914
--> val loss 4.149214267730713
Step 10: loss 3.3135969638824463
Step 20: loss 2.638061285018921
Step 30: loss 2.4735331535339355
Step 40: loss 2.352996826171875
Step 50: loss 2.2717297077178955
Step 60: loss 2.253775119781494
Step 70: loss 2.179286241531372
Step 80: loss 2.1060492992401123
Step 90: loss 2.0014119148254395
Step 100: loss 1.9917633533477783
--> val loss 2.2128195762634277
Step 110: loss 2.0247018337249756
Step 120: loss 1.94600510597229
Step 130: loss 1.9080532789230347
Step 140: loss 1.9041284322738647
Step 150: loss 1.8841463327407837
Step 160: loss 1.8827264308929443
Step 170: loss 1.8754804134368896
Step 180: loss 1.817456603050232
Step 190: loss 1.818677544593811
Step 200: loss 1.7669349908828735
--> val loss 1.9174721240997314
Step 210: loss 1.768047571182251
Step 220: loss 1.789674997329712
Step 230: loss 1.7021842002868652
Step 240: loss 1.743955373764038
Step 250: loss 1.7292665243148804
Step 260: loss 1.7314412593841553
Step 270: 

And as a last check, let's see how our model stacks up to Shakespeare when it writes. Hey, not too bad!

In [21]:
def generate_text(prompt, max_tokens=100, temperature=1.0, seed=0, max_context=None):
    """Extend `prompt` autoregressively with tokens from the trained model.

    Args:
        prompt: text to condition on; must encode to at least one token.
        max_tokens: number of new tokens to generate.
        temperature: logits are divided by this before sampling. 0 selects
            greedy (argmax) decoding; negative values are rejected.
        seed: seed for the JAX PRNG key used for sampling.
        max_context: longest token window fed to the model. Defaults to the
            training `context`. Older tokens are dropped from the model input
            but kept in the returned text.

    Returns:
        The prompt followed by the generated continuation, decoded to a string.

    Raises:
        ValueError: if `temperature` is negative, `max_context` is below 1,
            or `prompt` encodes to no tokens.
    """
    if temperature < 0:
        raise ValueError(f"temperature must be non-negative, got {temperature}")
    window = context if max_context is None else max_context
    if window < 1:
        raise ValueError(f"max_context must be at least 1, got {window}")

    key = jax.random.PRNGKey(seed)
    tokens = jnp.array(encode(prompt))
    if tokens.size == 0:
        raise ValueError("prompt must encode to at least one token")

    for _ in range(max_tokens):
        # Training batches were `context` tokens long, so only feed the model
        # the most recent `window` tokens instead of the whole growing sequence.
        logits = model(jnp.expand_dims(tokens[-window:], 0), w)
        next_token_logits = logits[0, -1]

        if temperature == 0:
            # Greedy decoding; dividing by 0 would turn the logits into inf/nan.
            next_token = jnp.argmax(next_token_logits)
        else:
            # Sample from our model's token distribution
            key, subkey = jax.random.split(key)
            next_token = jax.random.categorical(subkey, next_token_logits / temperature)
        tokens = jnp.append(tokens, next_token)

    return decode(tokens)

for seed in range(8):
    print("Generated: ", generate_text("If", max_tokens=20, seed=seed))
    print("-" * 80)

Generated:  IfOHAUS:
First Merish 
--------------------------------------------------------------------------------
Generated:  Ifamble, tho last,
Bes
--------------------------------------------------------------------------------
Generated:  If
Lor man beend the w
--------------------------------------------------------------------------------
Generated:  If you. Mercies.
Lood 
--------------------------------------------------------------------------------
Generated:  If let, didly ant womo
--------------------------------------------------------------------------------
Generated:  If wich with shy Say c
--------------------------------------------------------------------------------
Generated:  Ifry, Aumit? a to as n
--------------------------------------------------------------------------------
Generated:  If thy hows; brot have
--------------------------------------------------------------------------------
