# Hello, GPT!

Modula can represent any neural net architecture. Here, let's see how to make a transformer that trains on Shakespeare!

## Getting the data

Let's download the Shakespeare dataset.

In [1]:
context = 64
batch_size = 12

from data.shakespeare import load_shakespeare

data = load_shakespeare(context, batch_size)

train_loader = data["train_loader"]
val_loader = data["val_loader"]
encode = data["encode"]
decode = data["decode"]

Let's peek at an example to verify the data loaded correctly.

In [2]:
for inputs, targets in train_loader:
    print("Input shape:", inputs.shape)
    print("Target shape:", targets.shape)
    print("First input sequence:", inputs[0][:10], "...")
    print("First target sequence:", targets[0][:10], "...")
    print("Decoded input:", decode(inputs[0]))
    print("Decoded target:", decode(targets[0]))
    break

Input shape: (12, 64)
Target shape: (12, 64)
First input sequence: [41 53 50 42  1 40 50 53 53 42] ...
First target sequence: [53 50 42  1 40 50 53 53 42  1] ...
Decoded input: cold blood no spark of honour bides.

NORTHUMBERLAND:
Be thou a 
Decoded target: old blood no spark of honour bides.

NORTHUMBERLAND:
Be thou a p
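The decoded strings above show that each target sequence is the input sequence shifted forward by one character. Here is a minimal sketch of that relationship in plain Python. The toy vocabulary and the helpers `toy_encode` and `toy_decode` are assumptions for illustration, not the ones returned by `load_shakespeare`.

```python
# Hypothetical toy corpus and character-level vocabulary, for illustration only.
text = "cold blood no spark of honour bides."
chars = sorted(set(text))
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for ch, i in stoi.items()}

def toy_encode(s):
    """Map a string to a list of integer token ids."""
    return [stoi[ch] for ch in s]

def toy_decode(ids):
    """Map a list of integer token ids back to a string."""
    return "".join(itos[i] for i in ids)

context = 8
tokens = toy_encode(text)

# The target at position t is the token that follows the input at position t.
start = 0
inputs = tokens[start : start + context]
targets = tokens[start + 1 : start + context + 1]

print(repr(toy_decode(inputs)))   # 'cold blo'
print(repr(toy_decode(targets)))  # 'old bloo'
```

Every position in the sequence therefore supplies its own next-character prediction task.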


## Defining the architecture

In [3]:
# transformer hyperparameters

vocab_size = 65
num_heads = 4
d_embed = 128
d_query = 32
d_value = 32
num_blocks = 4
attention_scale = 1
final_scale = 1

# training hyperparameters

lr = 0.1
beta = 0.95
steps = 2001
log_interval = 10
val_interval = 100
val_iters = 20

Great! We're ready to define our transformer. We'll give some extra care to *attention* and *residual connections*.

## Attention

We'll define attention with $1/d$ scaling of the query-key dot products, where $d$ is the dimension of each query and key vector, instead of the usual $1/\sqrt{d}$ scaling. The goal is to make the attention block Lipschitz: with the usual scaling, attention can become arbitrarily sensitive to changes in its input as $d$ grows, whereas with $1/d$ scaling the sensitivity is bounded by 3. This is why the code below multiplies the attention output by $1/3$, which brings the sensitivity down to 1 and lets the entire transformer be well-normed. For the full story, see "Case Study I: Attention" in our paper [Scalable Optimization in the Modular Norm](https://arxiv.org/pdf/2405.14813).
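To get a feel for the difference between the two scalings, here is a small plain-Python sketch. It assumes, purely for illustration, a query whose entries all have size 1, one key aligned with it and one key anti-aligned with it. It does not reproduce the Lipschitz bound from the paper. It only shows that the logits grow with $d$ under $1/\sqrt{d}$ scaling and stay fixed under $1/d$ scaling.

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def attention_weights(d, scale):
    # Illustrative assumption: q has entries of size 1, with one key equal
    # to q and one key equal to -q.
    q = [1.0] * d
    keys = [[1.0] * d, [-1.0] * d]
    logits = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
    return softmax(logits)

for d in (16, 64, 256, 1024):
    w_sqrt = attention_weights(d, 1 / math.sqrt(d))
    w_lin = attention_weights(d, 1 / d)
    print(f"d={d:5d}  1/sqrt(d): {w_sqrt[0]:.6f}  1/d: {w_lin[0]:.6f}")
```

In this toy setup the weight on the aligned key saturates towards 1 as $d$ grows under $1/\sqrt{d}$ scaling, while under $1/d$ scaling it is the same for every $d$.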

In [4]:
from modula.abstract import Identity
from modula.atom import Linear
from modula.bond import SplitIntoHeads, MergeHeads, Rope, AttentionQK, CausalMask, Softmax, ApplyAttentionScores, GeLU

def Attention(num_heads, d_embed, d_query, d_value, attention_scale):
    """Multi-head attention"""

    # For keys, queries, and values we add a heads dimension. For the out projection, we remove heads.
    # Remember modules compose right-to-left, and the order is Linear(d_out, d_in)! And @ means compose.
    Q = SplitIntoHeads(num_heads) @ Linear(num_heads * d_query, d_embed)
    K = SplitIntoHeads(num_heads) @ Linear(num_heads * d_query, d_embed)
    V = SplitIntoHeads(num_heads) @ Linear(num_heads * d_value, d_embed)
    W = Linear(d_embed, num_heads * d_value) @ MergeHeads()

    # Read right-to-left: rotate (Q, K) with RoPE, apply Q @ K.T, mask, softmax (with a scale we can choose).
    AttentionScores = Softmax(attention_scale) @ CausalMask() @ AttentionQK() @ Rope(d_query) @ (Q, K)

    # Read right-to-left: apply attention scores, multiply by 1/3 to fix the sensitivity to 1, project back to d_embed.
    return W @ (1/3 * ApplyAttentionScores()) @ (V, AttentionScores)

Let's check that the sensitivity is 1 at initialization.

In [5]:
print(Attention(num_heads, d_embed, d_query, d_value, attention_scale))

CompositeModule
...consists of 4 atoms and 10 bonds
...smooth
...input sensitivity is 1.0
...contributes proportion 4 to feature learning of any supermodule


## Residual Connections

We'll define residual connections using a convex combination: if $L$ is the number of residual blocks, then each residual connection replaces $x$ with the convex combination $x \mapsto \frac{L-1}{L} x + \frac{1}{L} \textsf{block}(x)$ of the identity and the block. Again, the purpose is to control the sensitivity of the transformer. As the number of blocks $L$ grows, the norm of $x$ remains bounded. And the contribution from any particular block is attenuated by a factor no smaller than $1/e$ on its way through the rest of the network, because its output passes through at most $L-1$ later residual connections and $(1 - 1/L)^{L-1} > 1/e$ for any depth $L > 1$.
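As a quick numerical sanity check, the sketch below computes how much of a block's output survives the identity paths of the residual connections that follow it. The helper name `surviving_fraction` is made up for this example. The worst case is the first block, whose output passes through the $L-1$ later connections, each multiplying it by $1 - 1/L$.

```python
import math

def surviving_fraction(L, k):
    """Factor applied by the identity path across k residual connections."""
    return (1 - 1 / L) ** k

for L in (2, 8, 32, 128, 1024):
    worst = surviving_fraction(L, L - 1)
    print(f"L={L:5d}  (1 - 1/L)^(L-1) = {worst:.4f}")

print(f"1/e = {1 / math.e:.4f}")
```

The printed factors decrease towards $1/e$ as $L$ grows but stay above it.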

In short, these changes allow us to control the input sensitivity of our transformer as it scales. We're ready to implement it!

In [6]:
from modula.atom import Embed

def GPT(vocab_size, num_heads, d_embed, d_query, d_value, num_blocks, blocks_mass=5, attention_scale=1.0, final_scale=1.0):
    # Set embed to have mass 1. This controls the proportion of feature learning that it contributes to the whole network.
    embed = Embed(d_embed, vocab_size)
    embed.tare()

    # Let's create attention and MLP layers. 
    att = Attention(num_heads, d_embed, d_query, d_value, attention_scale)
    mlp = Linear(d_embed, 4*d_embed) @ GeLU() @ Linear(4*d_embed, d_embed)

    # For our residual connections, L = 2*num_blocks because each block has two residual connections.
    att_block = (1-1/(2*num_blocks)) * Identity() + 1/(2*num_blocks) * att
    mlp_block = (1-1/(2*num_blocks)) * Identity() + 1/(2*num_blocks) * mlp

    # We can use powers of a module to compose it with itself many times!
    blocks = (mlp_block @ att_block) ** num_blocks

    # Set all transformer blocks to have mass 5 (by default).
    # So 5/7 of the change in the network output is due to the blocks,
    # and 2/7 of the change in output is due to the embedding and out projection.
    blocks.tare(absolute=blocks_mass)

    out = final_scale * Linear(vocab_size, d_embed)

    return out @ blocks @ embed

Let's construct our GPT!

In [7]:
model = GPT(
    vocab_size=vocab_size,
    num_heads=num_heads,
    d_embed=d_embed,
    d_query=d_query,
    d_value=d_value,
    num_blocks=num_blocks,
    attention_scale=attention_scale,
    final_scale=final_scale,
)

model.jit()

print(model)

CompositeModule
...consists of 26 atoms and 78 bonds
...non-smooth
...input sensitivity is 1.0
...contributes proportion 7.0 to feature learning of any supermodule
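The total of 7.0 reported above matches the mass bookkeeping described in the comments of `GPT`. Here is that arithmetic as a tiny sketch. It assumes the out projection has mass 1, which is what the stated 2/7 share for the embedding and out projection implies. The dictionary keys are just labels for this example.

```python
from fractions import Fraction

# Masses as described in the comments of GPT; "out" = 1 is an assumption
# inferred from the reported total of 7.0.
masses = {"embed": 1, "blocks": 5, "out": 1}
total = sum(masses.values())

# Each part's share of feature learning is its mass divided by the total.
proportions = {name: Fraction(m, total) for name, m in masses.items()}
for name, p in proportions.items():
    print(f"{name}: {p}")
```

Changing `blocks_mass` shifts this balance between the blocks and the embedding and out projection.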


## Loss function and training

To train our transformer we'll use cross entropy loss, which we can compute by decomposing the softmax:

$$
-\log(\text{target probability}) = -\log(\text{softmax}(\text{logits})_\text{target}) = -\text{logit}_\text{target} + \text{logsumexp}(\text{logits})
$$
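The identity can be checked numerically with a few lines of plain Python. This is only an illustrative sketch with made-up logits, separate from the JAX implementation in the next cell. It also shows why the decomposed form is preferred: subtracting the maximum inside `logsumexp` keeps the computation finite even for very large logits.

```python
import math

def logsumexp(xs):
    """Stable log(sum(exp(x))) obtained by factoring out the maximum."""
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))

def softmax(xs):
    lse = logsumexp(xs)
    return [math.exp(x - lse) for x in xs]

logits = [2.0, -1.0, 0.5, 3.0]  # hypothetical logits for a 4-token vocabulary
target = 2

direct = -math.log(softmax(logits)[target])
decomposed = -logits[target] + logsumexp(logits)
print(direct, decomposed)

# With a huge logit, math.exp(1000.0) on its own would raise OverflowError,
# but the decomposed loss is still well defined.
big = [1000.0, 0.0]
loss_big = -big[1] + logsumexp(big)
print(loss_big)
```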

In [8]:
import jax
import jax.numpy as jnp

def cross_entropy_loss(w, inputs, targets):
    # We use the logsumexp trick for stable cross entropy
    logits = model(inputs, w)  # shape is [batch, seq_len, vocab_size]
    batch_indices = jnp.arange(logits.shape[0])[:, None]  # shape is [batch, 1]
    seq_indices = jnp.arange(logits.shape[1])[None, :]    # shape is [1, seq_len]
    # This indexing selects out logits[b, s, targets[b, s]], which is the target logit
    losses = -logits[batch_indices, seq_indices, targets] + jax.nn.logsumexp(logits, axis=-1)  # shape is [batch, seq_len]
    return losses.mean()

loss_and_grad = jax.jit(jax.value_and_grad(cross_entropy_loss))

And we're ready to train!

In [9]:
key = jax.random.PRNGKey(0)
w = model.initialize(key)

step = 0
momentum = [0 * weight for weight in w]
lr_schedule = lambda step: lr * (steps - step) / steps
for inputs, targets in train_loader:
    loss, grad_w = loss_and_grad(w, inputs, targets)
    momentum = [beta * m + (1 - beta) * g_w for m, g_w in zip(momentum, grad_w)]
    d_w = model.dualize(momentum)
    w = [weight - lr_schedule(step) * d_weight for weight, d_weight in zip(w, d_w)]

    if step % log_interval == 0:
        print(f"Step {step}: loss {loss}")
    
    if step % val_interval == 0:
        val_losses = []
        for val_inputs, val_targets in val_loader:
            val_loss, _ = loss_and_grad(w, val_inputs, val_targets)
            val_losses.append(val_loss)
            if len(val_losses) >= val_iters:
                break
        print(f"--> val loss {sum(val_losses)/len(val_losses)}")

    step += 1

    if step >= steps:
        break

Step 0: loss 4.226325988769531
--> val loss 4.179544925689697
Step 10: loss 3.873875379562378
Step 20: loss 3.3448638916015625
Step 30: loss 2.8050029277801514
Step 40: loss 2.68573260307312
Step 50: loss 2.6098484992980957
Step 60: loss 2.40746808052063
Step 70: loss 2.4183788299560547
Step 80: loss 2.359755039215088
Step 90: loss 2.268550157546997
Step 100: loss 2.3140738010406494
--> val loss 2.542112112045288
Step 110: loss 2.2838447093963623
Step 120: loss 2.2057571411132812
Step 130: loss 2.16951322555542
Step 140: loss 2.274829864501953
Step 150: loss 2.125481367111206
Step 160: loss 2.2481937408447266
Step 170: loss 2.173231601715088
Step 180: loss 2.101841449737549
Step 190: loss 2.1321706771850586
Step 200: loss 2.127584457397461
--> val loss 2.3946735858917236
Step 210: loss 2.076207160949707
Step 220: loss 2.1188466548919678
Step 230: loss 1.9753059148788452
Step 240: loss 2.097914457321167
Step 250: loss 1.9066812992095947
Step 260: loss 2.082390308380127
Step 270: loss 2.
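Two pieces of the training loop are easy to inspect in isolation: the linearly decaying learning rate and the exponential moving average used for momentum. The sketch below restates them on plain floats, using the same `lr`, `steps` and `beta` values as above. The helper `update_momentum` is a name introduced for this example, and the constant gradient of 1.0 is an assumption chosen to make the behavior easy to see.

```python
lr = 0.1
beta = 0.95
steps = 2001

def lr_schedule(step):
    # Linear decay from lr at step 0 towards zero at step == steps.
    return lr * (steps - step) / steps

def update_momentum(m, g):
    # Exponential moving average of gradients, as in the training loop.
    return beta * m + (1 - beta) * g

for step in (0, 1000, 2000):
    print(step, lr_schedule(step))

# Starting from zero with a constant gradient of 1.0, the average after n
# updates is 1 - beta**n, so it takes a while to approach the gradient.
m = 0.0
for _ in range(20):
    m = update_momentum(m, 1.0)
print(m)
```

Because the loop stops once `step` reaches `steps`, the last update (at step 2000) still uses a small positive learning rate.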

## Though this be madness, yet there is method in't

And indeed, let us look at how our wee model stacks up against the master.

In [10]:
def generate_text(prompt, max_tokens=100, temperature=0.5, seed=0):
    key = jax.random.PRNGKey(seed)
    tokens = jnp.array(encode(prompt))
    for _ in range(max_tokens):
        logits = model(jnp.expand_dims(tokens, 0), w)
        next_token_logits = logits[0, -1] / temperature
        
        # Sample from our model's token distribution
        key, subkey = jax.random.split(key)
        next_token = jax.random.categorical(subkey, next_token_logits)
        tokens = jnp.append(tokens, next_token)
    
    return decode(tokens)

for seed in range(3):
    print(f"Sample {seed}:\n\n{generate_text('If', max_tokens=100, seed=seed)}")
    print("-" * 80)

Sample 0:

If the heave elder
That were the heave the fare the greath the wors,
The provery the graint be will my
--------------------------------------------------------------------------------
Sample 1:

Ifens the content the and the say man
The be stranger to heart, and I dain their comme the her all her
--------------------------------------------------------------------------------
Sample 2:

If the can heart the was constleman:
In the proper the true the bloody.

CLAUDITA:
I parce then ountra
--------------------------------------------------------------------------------
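The `temperature` argument in `generate_text` divides the logits before sampling. As a rough illustration with made-up logits, the sketch below shows how lowering the temperature concentrates probability on the most likely token. It uses plain Python in place of `jax.random.categorical`.

```python
import math

def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

logits = [2.0, 1.0, 0.0]  # hypothetical next-token logits

for temperature in (1.0, 0.5, 0.1):
    probs = softmax([x / temperature for x in logits])
    print(temperature, [round(p, 3) for p in probs])
```

Lower temperatures give more repetitive but safer samples, while higher temperatures give more varied ones.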


Hey, not too bad!