# Hello, GPT!

In this notebook, we're going to build a transformer. In particular, we'll see how to define attention and residual blocks in Modula.

## Getting the data

First, let's download the Shakespeare dataset. The task will be to predict the next character.

In [1]:
context = 64
batch_size = 12

from data.shakespeare import load_shakespeare

data = load_shakespeare(context, batch_size)

train_loader = data["train_loader"]
val_loader = data["val_loader"]
encode = data["encode"]
decode = data["decode"]

Downloading Shakespeare dataset...
Processing Shakespeare dataset...
Length of dataset in characters: 1,115,394
Vocabulary size: 65
Train has 1,003,854 tokens
Val has 111,540 tokens
Shakespeare dataset processing complete.


Let's peek at an example to verify the data loaded correctly!

In [2]:
for inputs, targets in train_loader:
    print("Input shape:", inputs.shape)
    print("Target shape:", targets.shape)
    print("First input sequence:", inputs[0][:10], "...")
    print("First target sequence:", targets[0][:10], "...")
    print("\nDecoded input:", decode(inputs[0]))
    print("\nDecoded target:", decode(targets[0]))
    break

Input shape: (12, 64)
Target shape: (12, 64)
First input sequence: [41 53 50 42  1 40 50 53 53 42] ...
First target sequence: [53 50 42  1 40 50 53 53 42  1] ...

Decoded input: cold blood no spark of honour bides.

NORTHUMBERLAND:
Be thou a 

Decoded target: old blood no spark of honour bides.

NORTHUMBERLAND:
Be thou a p
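The decoded pair shows that each target sequence is just the input shifted one character to the left. Here is a minimal sketch of how such pairs could be produced, assuming a character-level vocabulary built from the sorted set of characters. `build_vocab` and `sample_batch` are hypothetical helpers for illustration, not necessarily what `load_shakespeare` does internally.

```python
import numpy as np

# Hypothetical helpers; load_shakespeare may build its vocabulary and batches differently.
def build_vocab(text):
    chars = sorted(set(text))
    stoi = {ch: i for i, ch in enumerate(chars)}
    itos = {i: ch for ch, i in stoi.items()}
    encode = lambda s: [stoi[c] for c in s]
    decode = lambda ids: "".join(itos[int(i)] for i in ids)
    return encode, decode, len(chars)

def sample_batch(tokens, context, batch_size, rng):
    # Random window starts; each target window is the input window shifted by one token.
    starts = rng.integers(0, len(tokens) - context, size=batch_size)
    inputs = np.stack([tokens[s : s + context] for s in starts])
    targets = np.stack([tokens[s + 1 : s + context + 1] for s in starts])
    return inputs, targets

text = "First Citizen:\nBefore we proceed any further, hear me speak.\n"
encode, decode, vocab_size = build_vocab(text)
tokens = np.array(encode(text))
rng = np.random.default_rng(0)
inputs, targets = sample_batch(tokens, context=8, batch_size=3, rng=rng)
print(inputs.shape, targets.shape)  # (3, 8) (3, 8)
print(repr(decode(inputs[0])), "->", repr(decode(targets[0])))
```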


## Defining the architecture

Let's use a very small setting for our transformer so it is fast to train.

In [3]:
# transformer hyperparameters

vocab_size = 65
num_heads = 4
d_embed = 128
d_query = 32
d_value = 32
num_blocks = 4
attention_scale = 1
final_scale = 1

# training hyperparameters

lr = 0.1
beta = 0.95
steps = 2001
log_interval = 10
val_interval = 100
val_iters = 20

Next up, we'll define the *attention* module and *residual blocks*.

## Attention in Modula

In Modula, we'll define attention by stringing together several bond modules to do the parameterless computations. The roadmap is:
* Map `(batch, token, d_embed)` into `(batch, head, token, d_query)` via `Linear` and `SplitIntoHeads` (likewise for the key, and for the value with `d_value` in place of `d_query`)
* Use Rotary Positional Embeddings (RoPE) on the query and the key via `Rope`
* Map `query` and `key` into attention similarities of shape `(batch, head, token, token)` via `AttentionQK`
* Use a causal mask and then softmax to create attention scores via `CausalMask` and `Softmax`
* Use the attention scores to create output vectors via `ApplyAttentionScores`, then `MergeHeads` and `Linear`
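To make the shape bookkeeping in this roadmap concrete, here is a minimal NumPy sketch of the same pipeline. It is not Modula's implementation: the weight layout `(d_out, d_in)` mirrors `Linear(d_out, d_in)`, but the RoPE variant (rotating the first and second halves of each head with base 10000), the helper names, and the omission of Modula's `1/3` factor and sensitivity bookkeeping are all assumptions made for illustration.

```python
import numpy as np

def split_heads(x, num_heads):
    # (batch, token, heads * d) -> (batch, head, token, d)
    b, t, hd = x.shape
    return x.reshape(b, t, num_heads, hd // num_heads).transpose(0, 2, 1, 3)

def merge_heads(x):
    # (batch, head, token, d) -> (batch, token, head * d)
    b, h, t, d = x.shape
    return x.transpose(0, 2, 1, 3).reshape(b, t, h * d)

def rope(x, base=10000.0):
    # Assumed RoPE variant: rotate (first half, second half) pairs by position-dependent angles.
    t, d = x.shape[-2], x.shape[-1]
    half = d // 2
    freqs = base ** (-np.arange(half) / half)
    angles = np.arange(t)[:, None] * freqs[None, :]  # (token, d/2)
    cos, sin = np.cos(angles), np.sin(angles)
    x1, x2 = x[..., :half], x[..., half:]
    return np.concatenate([x1 * cos - x2 * sin, x1 * sin + x2 * cos], axis=-1)

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def attention(x, Wq, Wk, Wv, Wo, num_heads, scale=1.0):
    # scale must be positive; it plays the role of attention_scale.
    q = rope(split_heads(x @ Wq.T, num_heads))
    k = rope(split_heads(x @ Wk.T, num_heads))
    v = split_heads(x @ Wv.T, num_heads)
    d_head = q.shape[-1]
    scores = scale * (q @ k.swapaxes(-1, -2)) / d_head  # 1/d_head scaling, (batch, head, token, token)
    t = scores.shape[-1]
    mask = np.tril(np.ones((t, t), dtype=bool))
    scores = np.where(mask, scores, -np.inf)  # causal mask: no attention to future tokens
    probs = softmax(scores, axis=-1)
    return merge_heads(probs @ v) @ Wo.T  # (batch, token, d_embed)

rng = np.random.default_rng(0)
batch, tokens, d_embed, num_heads, d_head = 2, 5, 16, 4, 4
x = rng.standard_normal((batch, tokens, d_embed))
Wq, Wk, Wv = (rng.standard_normal((num_heads * d_head, d_embed)) for _ in range(3))
Wo = rng.standard_normal((d_embed, num_heads * d_head))
y = attention(x, Wq, Wk, Wv, Wo, num_heads)
print(y.shape)  # (2, 5, 16)
```

Changing the last token leaves every earlier output unchanged, which is a quick way to confirm the mask is causal.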

The main difference from a standard transformer is that `AttentionQK` uses $1/d_\text{head}$ scaling instead of the standard $1/\sqrt{d_\text{head}}$. This provides Lipschitz guarantees for attention that are independent of $d_\text{head}$. For more information, see Appendix B.6 of [Scalable Optimization in the Modular Norm](https://arxiv.org/pdf/2405.14813).
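A worked example of why the scaling matters: if query and key entries are bounded by 1 in magnitude, their dot product can be as large as $d_\text{head}$. With $1/\sqrt{d_\text{head}}$ scaling the logit can then still grow like $\sqrt{d_\text{head}}$, while $1/d_\text{head}$ scaling keeps it bounded by 1. This sketch only shows the worst case of identical all-ones vectors; it is an illustration, not the Lipschitz argument from the paper.

```python
import numpy as np

def aligned_logit(d_head, scaling):
    # Worst case for entries bounded by 1: query and key are identical all-ones
    # vectors, so the raw dot product equals d_head.
    q = k = np.ones(d_head)
    return float(q @ k) * scaling(d_head)

inv_sqrt = lambda d: 1.0 / np.sqrt(d)  # standard transformer scaling
inv = lambda d: 1.0 / d                # scaling used by AttentionQK

for d_head in [16, 64, 256, 1024]:
    print(d_head, aligned_logit(d_head, inv_sqrt), aligned_logit(d_head, inv))
# 16 4.0 1.0
# 64 8.0 1.0
# 256 16.0 1.0
# 1024 32.0 1.0
```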

And here's the implementation:

In [4]:
from modula.atom import Linear
from modula.bond import SplitIntoHeads, MergeHeads, Rope, AttentionQK, CausalMask, Softmax, ApplyAttentionScores, GeLU

def Attention(num_heads, d_embed, d_query, d_value, attention_scale):
    """Multi-head attention"""

    # For keys, queries, and values we add a heads dimension. For the out projection, we remove heads.
    # Remember modules compose right-to-left, and the order is Linear(d_out, d_in)! And @ means compose.
    Q = SplitIntoHeads(num_heads) @ Linear(num_heads * d_query, d_embed)
    K = SplitIntoHeads(num_heads) @ Linear(num_heads * d_query, d_embed)
    V = SplitIntoHeads(num_heads) @ Linear(num_heads * d_value, d_embed)
    W = Linear(d_embed, num_heads * d_value) @ MergeHeads()

    # Read right-to-left: rotate (Q, K) with RoPE, apply Q @ K.T, mask, softmax (with a scale we can choose).
    AttentionScores = Softmax(attention_scale) @ CausalMask() @ AttentionQK() @ Rope(d_query) @ (Q, K)

    # Read right-to-left: apply attention scores, multiply by 1/3 to fix the sensitivity to 1, project back to d_embed.
    return W @ (1/3 * ApplyAttentionScores()) @ (V, AttentionScores)

Let's check that the sensitivity is 1 at initialization.

In [5]:
print(Attention(num_heads, d_embed, d_query, d_value, attention_scale))

CompositeModule
...consists of 4 atoms and 10 bonds
...smooth
...input sensitivity is 1.0
...contributes proportion 4 to feature learning of any supermodule


## Residual blocks in Modula

To implement the rest of our transformer, the roadmap is:
* Embed the input tokens
* Apply residual blocks for attention and the MLP
* Project out

All that's left is to set up the residual blocks. In Modula, we define residual connections using a convex combination. If $L$ is the number of residual blocks, then we use a convex combination of the identity and the block to get $x \mapsto \frac{L-1}{L} \cdot x + \frac{1}{L} \cdot \textsf{block}(x)$. The purpose is to create a Lipschitz guarantee that is independent of the number of blocks. For more information, see Proposition 4 of [Scalable Optimization in the Modular Norm](https://arxiv.org/pdf/2405.14813).

In short, these changes enable Lipschitz guarantees on our transformer even as we scale the width and the depth!
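The effect of the convex combination can be checked numerically. In this minimal sketch, `np.tanh` is a hypothetical 1-Lipschitz stand-in for a block. Each convex residual step is a convex combination of two 1-Lipschitz maps, so it is itself 1-Lipschitz, and stacking $L$ of them never expands distances. The standard residual $x + f(x)$ can expand distances by up to a factor of 2 per block.

```python
import numpy as np

def convex_residual(block, L):
    # x -> (L-1)/L * x + 1/L * block(x)
    return lambda x: (1 - 1 / L) * x + (1 / L) * block(x)

def standard_residual(block):
    # x -> x + block(x)
    return lambda x: x + block(x)

def compose_n(f, n):
    # Apply f to its own output n times.
    def g(x):
        for _ in range(n):
            x = f(x)
        return x
    return g

L = 8
block = np.tanh  # 1-Lipschitz stand-in for an attention or MLP block
rng = np.random.default_rng(0)
x = 0.01 * rng.standard_normal(32)
y = x + 1e-3 * rng.standard_normal(32)

def expansion(net):
    # Ratio of output distance to input distance for the pair (x, y).
    return np.linalg.norm(net(x) - net(y)) / np.linalg.norm(x - y)

convex_net = compose_n(convex_residual(block, L), L)
standard_net = compose_n(standard_residual(block), L)
print("convex residual:", expansion(convex_net))      # at most 1
print("standard residual:", expansion(standard_net))  # greater than 1, bounded by 2**L
```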

In [6]:
from modula.abstract import Identity
from modula.atom import Embed

def GPT(vocab_size, num_heads, d_embed, d_query, d_value, num_blocks, blocks_mass=5, attention_scale=1.0, final_scale=1.0):
    # Set embed to have mass 1. This controls the proportion of feature learning that it contributes to the whole network.
    embed = Embed(d_embed, vocab_size)
    embed.tare()

    # Let's create attention and MLP layers. 
    att = Attention(num_heads, d_embed, d_query, d_value, attention_scale)
    mlp = Linear(d_embed, 4*d_embed) @ GeLU() @ Linear(4*d_embed, d_embed)

    # For our residual connections, L = 2*num_blocks because each block has two residual connections.
    att_block = (1-1/(2*num_blocks)) * Identity() + 1/(2*num_blocks) * att
    mlp_block = (1-1/(2*num_blocks)) * Identity() + 1/(2*num_blocks) * mlp

    # We can use powers of a module to compose it with itself many times!
    blocks = (mlp_block @ att_block) ** num_blocks

    # Set all transformer blocks to have mass 5 (by default).
    # So 5/7 of the change in the network output is due to the blocks,
    # and 2/7 of the change in output is due to the embedding and out projection.
    blocks.tare(absolute=blocks_mass)

    out = final_scale * Linear(vocab_size, d_embed)

    return out @ blocks @ embed

And finally we are ready to construct our GPT!

In [7]:
model = GPT(
    vocab_size=vocab_size,
    num_heads=num_heads,
    d_embed=d_embed,
    d_query=d_query,
    d_value=d_value,
    num_blocks=num_blocks,
    attention_scale=attention_scale,
    final_scale=final_scale,
)

model.jit()

print(model)

CompositeModule
...consists of 26 atoms and 78 bonds
...non-smooth
...input sensitivity is 1.0
...contributes proportion 7.0 to feature learning of any supermodule
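As a sanity check on the printout, the atom count and mass total can be reproduced by hand. The sketch below counts one atom per `Embed` or `Linear` and, assuming these layers carry no bias terms, also tallies the number of weights. The masses (1 for the embedding, 5 for the blocks, and 1 for the output projection) are the ones implied by the comments in `GPT`.

```python
# Hypothetical bookkeeping, assuming Embed and Linear layers have no bias terms.
vocab_size, d_embed, num_heads, d_query, d_value, num_blocks = 65, 128, 4, 32, 32, 4

embed_params = vocab_size * d_embed
out_params = d_embed * vocab_size
# Q and K: (num_heads * d_query, d_embed); V: (num_heads * d_value, d_embed); W: (d_embed, num_heads * d_value)
attn_params = 2 * (num_heads * d_query * d_embed) + 2 * (num_heads * d_value * d_embed)
mlp_params = 2 * (4 * d_embed * d_embed)  # up and down projections
block_params = attn_params + mlp_params

num_atoms = 2 + num_blocks * (4 + 2)  # embed + out, plus 4 attention and 2 MLP linears per block
total_params = embed_params + out_params + num_blocks * block_params
print(num_atoms, total_params)  # 26 803072

masses = {"embed": 1, "blocks": 5, "out": 1}
total_mass = sum(masses.values())
print(total_mass, {name: m / total_mass for name, m in masses.items()})
```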


## Loss function and training

To train our transformer we'll use cross entropy loss, which we can compute by decomposing the softmax:

$$
-\log(\text{target probability}) = -\log(\text{softmax}(\text{logits})_\text{target}) = -\text{logit}_\text{target} + \text{log\,sum\,exp}(\text{logits})
$$
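The identity is easy to check numerically. This NumPy sketch compares the direct formula with the decomposed one. Subtracting the maximum logit inside the log-sum-exp, the usual way such routines avoid overflow, keeps the decomposed form finite even for very large logits, where the direct formula would overflow in `np.exp`.

```python
import numpy as np

def naive_ce(logits, target):
    # Direct definition: -log of the softmax probability of the target.
    probs = np.exp(logits) / np.exp(logits).sum()
    return -np.log(probs[target])

def stable_ce(logits, target):
    # Decomposed form: -logit_target + logsumexp(logits), with the max subtracted for stability.
    m = logits.max()
    return -logits[target] + m + np.log(np.exp(logits - m).sum())

logits = np.array([2.0, -1.0, 0.5, 3.0])
print(naive_ce(logits, 3), stable_ce(logits, 3))  # the two values agree

big = np.array([1000.0, 0.0, -1000.0])
print(stable_ce(big, 0))  # 0.0
```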

In [8]:
import jax
import jax.numpy as jnp

def cross_entropy_loss(w, inputs, targets):
    # We use the logsumexp trick for stable cross entropy
    logits = model(inputs, w)  # shape is [batch, seq_len, vocab_size]
    batch_indices = jnp.arange(logits.shape[0])[:, None]  # shape is [batch, 1]
    seq_indices = jnp.arange(logits.shape[1])[None, :]    # shape is [1, seq_len]
    # This indexing selects out logits[b, s, targets[b, s]], which is the target logit
    losses = -logits[batch_indices, seq_indices, targets] + jax.nn.logsumexp(logits, axis=-1)  # shape is [batch, seq_len]
    return losses.mean()

loss_and_grad = jax.jit(jax.value_and_grad(cross_entropy_loss))
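The advanced indexing in `cross_entropy_loss` can look opaque. This NumPy sketch, which relies on the same indexing rules that `jax.numpy` follows here, checks on a tiny random example that it selects the same entries as `np.take_along_axis` and as an explicit double loop.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, seq_len, vocab = 2, 3, 5
logits = rng.standard_normal((batch, seq_len, vocab))
targets = rng.integers(0, vocab, size=(batch, seq_len))

# Broadcasting (batch, 1), (1, seq_len) and (batch, seq_len) selects logits[b, s, targets[b, s]].
batch_indices = np.arange(batch)[:, None]
seq_indices = np.arange(seq_len)[None, :]
fancy = logits[batch_indices, seq_indices, targets]

# The same selection with take_along_axis on the vocabulary axis.
along = np.take_along_axis(logits, targets[..., None], axis=-1)[..., 0]

# The same selection with explicit loops.
loops = np.array([[logits[b, s, targets[b, s]] for s in range(seq_len)] for b in range(batch)])

print(fancy.shape)  # (2, 3)
```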

And we're ready to train!

In [9]:
key = jax.random.PRNGKey(0)
w = model.initialize(key)

step = 0
momentum = [0 * weight for weight in w]
lr_schedule = lambda step: lr * (steps - step) / steps
for inputs, targets in train_loader:
    loss, grad_w = loss_and_grad(w, inputs, targets)
    momentum = [beta * m + (1 - beta) * g_w for m, g_w in zip(momentum, grad_w)]
    d_w = model.dualize(momentum)
    w = [weight - lr_schedule(step) * d_weight for weight, d_weight in zip(w, d_w)]

    if step % log_interval == 0:
        print(f"Step {step}: loss {loss}")
    
    if step % val_interval == 0:
        val_losses = []
        for val_inputs, val_targets in val_loader:
            loss, _ = loss_and_grad(w, val_inputs, val_targets)
            val_losses.append(loss)
            if len(val_losses) >= val_iters:
                break
        print(f"--> val loss {sum(val_losses)/len(val_losses)}")

    step += 1

    if step >= steps:
        break

Step 0: loss 4.226325988769531
--> val loss 4.179544448852539
Step 10: loss 3.8738746643066406
Step 20: loss 3.3448646068573
Step 30: loss 2.805002212524414
Step 40: loss 2.68573260307312
Step 50: loss 2.6098480224609375
Step 60: loss 2.407468557357788
Step 70: loss 2.418379783630371
Step 80: loss 2.359757423400879
Step 90: loss 2.2685279846191406
Step 100: loss 2.314124584197998
--> val loss 2.541980743408203
Step 110: loss 2.283424139022827
Step 120: loss 2.2063167095184326
Step 130: loss 2.1598031520843506
Step 140: loss 2.252727508544922
Step 150: loss 2.124152660369873
Step 160: loss 2.23785662651062
Step 170: loss 2.2059123516082764
Step 180: loss 2.102996587753296
Step 190: loss 2.132392168045044
Step 200: loss 2.130244255065918
--> val loss 2.359212636947632
Step 210: loss 2.0895276069641113
Step 220: loss 2.1278815269470215
Step 230: loss 1.9647449254989624
Step 240: loss 2.1118733882904053
Step 250: loss 1.9459623098373413
Step 260: loss 2.118051290512085
Step 270: loss 2.060
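The training loop applies momentum, then `model.dualize`, then a linearly decaying learning rate. For a single weight matrix, one common choice of dual map under the spectral norm is to orthogonalize the gradient, replacing $G = U\Sigma V^\top$ with $UV^\top$. The sketch below assumes that choice for one matrix; Modula's actual `dualize` may differ, for example in its per-layer scaling or by using an iterative approximation instead of an SVD. `orthogonalize` and `sketch_step` are hypothetical names.

```python
import numpy as np

def orthogonalize(g):
    # Replace G = U S V^T by U V^T: same singular vectors, all singular values set to 1.
    u, _, vt = np.linalg.svd(g, full_matrices=False)
    return u @ vt

def sketch_step(w, grad, momentum, step, lr=0.1, beta=0.95, steps=2001):
    # Momentum as an exponential moving average of gradients, as in the training loop.
    momentum = beta * momentum + (1 - beta) * grad
    # Assumed dual map for one linear layer: orthogonalize the momentum.
    direction = orthogonalize(momentum)
    # Linear learning-rate decay from lr at step 0 toward 0 at step `steps`.
    lr_t = lr * (steps - step) / steps
    return w - lr_t * direction, momentum

rng = np.random.default_rng(0)
w = rng.standard_normal((8, 4))
grad = rng.standard_normal((8, 4))
w_new, m = sketch_step(w, grad, np.zeros_like(w), step=0)
print(np.linalg.svd(orthogonalize(grad), compute_uv=False))  # approximately all ones
```

Because every singular value of the update direction is 1, the spectral norm of each weight change equals the current learning rate, regardless of the gradient's magnitude.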

## Though this be madness, yet there is method in't

And indeed, let us see how our wee model stacks up against the master.

In [10]:
def generate_text(prompt, max_tokens=100, temperature=0.5, seed=0):
    key = jax.random.PRNGKey(seed)
    tokens = jnp.array(encode(prompt))
    for _ in range(max_tokens):
        logits = model(jnp.expand_dims(tokens, 0), w)
        next_token_logits = logits[0, -1] / temperature
        
        # Sample from our model's token distribution
        key, subkey = jax.random.split(key)
        next_token = jax.random.categorical(subkey, next_token_logits)
        tokens = jnp.append(tokens, next_token)
    
    return decode(tokens)

for seed in range(3):
    print(f"Sample {seed}:\n\n{generate_text('If', max_tokens=100, seed=seed)}")
    print("-" * 80)

Sample 0:

If where his elperiend and is here in think the comfore be pray virtue deather I the grouth a pears my
--------------------------------------------------------------------------------
Sample 1:

If as the conture the weet to the man's death the greeen he with thought rame the prosates he palousen
--------------------------------------------------------------------------------
Sample 2:

If him the be not me were and let for the earth the forth,
That the his a wort of you the fearshould a
--------------------------------------------------------------------------------
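These samples use `temperature=0.5`. Dividing the logits by a temperature below 1 sharpens the next-character distribution toward the most likely characters, while a temperature above 1 flattens it. Here is a NumPy sketch of the same sampling rule, using `rng.choice` in place of `jax.random.categorical`; `sample_with_temperature` is a hypothetical helper.

```python
import numpy as np

def softmax(z):
    z = z - z.max()
    e = np.exp(z)
    return e / e.sum()

def sample_with_temperature(logits, temperature, rng):
    # Divide logits by the temperature, then draw one index from the resulting softmax.
    probs = softmax(logits / temperature)
    return rng.choice(len(logits), p=probs), probs

logits = np.array([2.0, 1.0, 0.0, -1.0])
rng = np.random.default_rng(0)
for t in [2.0, 1.0, 0.5, 0.1]:
    token, probs = sample_with_temperature(logits, t, rng)
    # Lower temperature puts more probability on the top-scoring token.
    print(t, token, probs.max().round(3))
```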
