# GPT from scratch!

Let's train a transformer with Modula.

In [1]:
import jax
import jax.numpy as jnp

First, let's set up the model parameters and the training parameters.

In [8]:
# Karpathy's smallest GPT config

vocab_size = 65
context = 64
num_heads = 4
d_embed = 128
d_query = 32
d_value = 32
num_blocks = 3
softmax_scale = 1

# training hyperparameters

lr = 0.1
beta = 0.95
batch_size = 64
steps = 2001
log_interval = 10
val_interval = 100

Let's download the data.

In [3]:
from data.shakespeare import load_shakespeare

data = load_shakespeare(context, batch_size)

train_loader = data["train_loader"]
val_loader = data["val_loader"]
encode = data["encode"]
decode = data["decode"]

Let's see what the data looks like.

In [4]:
for inputs, targets in train_loader:
    print("Input shape:", inputs.shape)
    print("Target shape:", targets.shape)
    print("First input sequence:", inputs[0])
    print("First target sequence:", targets[0])
    print("Decoded input:", decode(inputs[0]))
    print("Decoded target:", decode(targets[0]))
    break

Input shape: (64, 64)
Target shape: (64, 64)
First input sequence: [41 53 50 42  1 40 50 53 53 42  1 52 53  1 57 54 39 56 49  1 53 44  1 46
 53 52 53 59 56  1 40 47 42 43 57  8  0  0 26 27 30 32 20 33 25 14 17 30
 24 13 26 16 10  0 14 43  1 58 46 53 59  1 39  1]
First target sequence: [53 50 42  1 40 50 53 53 42  1 52 53  1 57 54 39 56 49  1 53 44  1 46 53
 52 53 59 56  1 40 47 42 43 57  8  0  0 26 27 30 32 20 33 25 14 17 30 24
 13 26 16 10  0 14 43  1 58 46 53 59  1 39  1 54]
Decoded input: cold blood no spark of honour bides.

NORTHUMBERLAND:
Be thou a 
Decoded target: old blood no spark of honour bides.

NORTHUMBERLAND:
Be thou a p


We're ready to define our transformer!

In [9]:
from modula.compound import GPT

model = GPT(
    vocab_size=vocab_size,
    num_heads=num_heads,
    d_embed=d_embed,
    d_query=d_query,
    d_value=d_value,
    num_blocks=num_blocks,
    softmax_scale=softmax_scale
)

model.jit()

print(model)

CompositeModule
...consists of 20 atoms and 58 bonds
...non-smooth
...input sensitivity is 1.0
...contributes proportion 7.0 to feature learning of any supermodule


We'll use cross entropy loss. We compute it using the logsumexp trick.

$$
-\log(\text{target probability}) = -\log(\text{softmax}(\text{logits}, \text{axis}=-1)_\text{target}) = -\text{logit}_\text{target} + \text{logsumexp}(\text{logits})
$$

In [28]:
def cross_entropy_loss(w, inputs, targets):
    # We use the logsumexp trick for stable cross entropy
    logits = model(inputs, w)  # shape is [batch, seq_len, vocab_size]
    batch_indices = jnp.arange(logits.shape[0])[:, None]  # shape is [batch, 1]
    seq_indices = jnp.arange(logits.shape[1])[None, :]    # shape is [1, seq_len]
    losses = -logits[batch_indices, seq_indices, targets] + jax.nn.logsumexp(logits, axis=-1)  # shape is [batch, seq_len]
    return losses.mean()

loss_and_grad = jax.jit(jax.value_and_grad(cross_entropy_loss))

And we're ready to train!

In [29]:
key = jax.random.PRNGKey(0)
w = model.initialize(key)

step = 0
momentum = [0 * weight for weight in w]
lr_schedule = lambda step: lr * (steps - step) / steps
for inputs, targets in train_loader:
    loss, grad_w = loss_and_grad(w, inputs, targets)
    momentum = [beta * m + (1 - beta) * g_w for m, g_w in zip(momentum, grad_w)]
    d_w = model.dualize(momentum)
    w = [weight - lr_schedule(step) * d_weight for weight, d_weight in zip(w, d_w)]

    if step % log_interval == 0:
        print(f"Step {step}: loss {loss}")
    
    # if step % val_interval == 0:
    #     val_losses = []
    #     for val_inputs, val_targets in val_loader:
    #         loss, _ = loss_and_grad(w, val_inputs, val_targets)
    #         val_losses.append(loss)
    #     print(f"\tval loss {sum(val_losses)/len(val_losses)}")

    step += 1

    if step >= steps:
        break

Step 0: loss 4.184149265289307
Step 10: loss 3.728048324584961
Step 20: loss 3.069167137145996


KeyboardInterrupt: 