# GPT from scratch!

Let's train a transformer with Modula.

In [1]:
import itertools

import jax
import jax.numpy as jnp
import optax

First, let's set the model hyperparameters and the training hyperparameters.

In [2]:
# ---------------------------------------------------------------------------
# Model architecture: Karpathy's smallest GPT config
# ---------------------------------------------------------------------------
vocab_size = 65       # size of the token vocabulary (passed to GPT)
context = 64          # tokens per sequence (passed to load_shakespeare)
num_heads = 4         # number of attention heads (passed to GPT)
d_embed = 128         # embedding width (passed to GPT)
d_query = 32          # query dimension (passed to GPT as d_query)
d_value = 32          # value dimension (passed to GPT as d_value)
num_blocks = 4        # number of transformer blocks (passed to GPT)
softmax_scale = 1     # forwarded to GPT; presumably scales attention logits — confirm in modula

# ---------------------------------------------------------------------------
# Training hyperparameters
# ---------------------------------------------------------------------------
lr = 0.5              # step size applied to the dualized gradient
batch_size = 64       # sequences per batch (passed to load_shakespeare)
steps = 1001          # intended number of training steps
log_interval = 100    # evaluate/log every this many steps

Let's load the Shakespeare data.

In [3]:
from data.shakespeare import load_shakespeare

# load_shakespeare returns a dict; unpack the loaders and the tokenizer
# functions that the rest of the notebook uses.
data = load_shakespeare(context, batch_size)
train_loader, val_loader = data["train_loader"], data["val_loader"]
encode, decode = data["encode"], data["decode"]

Let's see what the data looks like.

In [4]:
# Pull only the first training batch and show its raw ids and decoded text.
first_batch = next(iter(train_loader), None)
if first_batch is not None:
    inputs, targets = first_batch
    print("Input shape:", inputs.shape)
    print("Target shape:", targets.shape)
    print("First input sequence:", inputs[0])
    print("First target sequence:", targets[0])
    print("Decoded input:", decode(inputs[0]))
    print("Decoded target:", decode(targets[0]))

Input shape: (64, 64)
Target shape: (64, 64)
First input sequence: [41 53 50 42  1 40 50 53 53 42  1 52 53  1 57 54 39 56 49  1 53 44  1 46
 53 52 53 59 56  1 40 47 42 43 57  8  0  0 26 27 30 32 20 33 25 14 17 30
 24 13 26 16 10  0 14 43  1 58 46 53 59  1 39  1]
First target sequence: [53 50 42  1 40 50 53 53 42  1 52 53  1 57 54 39 56 49  1 53 44  1 46 53
 52 53 59 56  1 40 47 42 43 57  8  0  0 26 27 30 32 20 33 25 14 17 30 24
 13 26 16 10  0 14 43  1 58 46 53 59  1 39  1 54]
Decoded input: cold blood no spark of honour bides.

NORTHUMBERLAND:
Be thou a 
Decoded target: old blood no spark of honour bides.

NORTHUMBERLAND:
Be thou a p


We're ready to define our transformer!

In [5]:
from modula.compound import GPT

# Build the transformer from the hyperparameters in the config cell.
# Keyword order does not matter; they are grouped by role for readability.
model = GPT(
    # vocabulary and depth
    vocab_size=vocab_size,
    num_blocks=num_blocks,
    # widths
    d_embed=d_embed,
    d_query=d_query,
    d_value=d_value,
    # attention
    num_heads=num_heads,
    softmax_scale=softmax_scale,
)

# Modula's jit step; presumably wraps the forward pass in jax.jit — see Modula docs.
model.jit()

# Text summary of the composite module (atom/bond counts, smoothness, etc.).
print(model)

CompositeModule
...consists of 26 atoms and 77 bonds
...non-smooth
...input sensitivity is 1.0
...contributes proportion 7.0 to feature learning of any supermodule


We'll use cross entropy loss.

In [6]:
def cross_entropy_loss(w, inputs, targets):
    """Mean cross-entropy of `model`'s predictions on one batch.

    Args:
        w: model weights, as produced by `model.initialize`.
        inputs: integer token ids fed to the model.
        targets: integer token ids the model should predict (same shape as inputs).

    Returns:
        Scalar: the average of the per-position cross-entropy losses.
    """
    logits = model(inputs, w)
    per_position = optax.softmax_cross_entropy_with_integer_labels(logits, targets)
    return jnp.mean(per_position)


# Jitted function returning (loss, gradient with respect to w) in one pass.
loss_and_grad = jax.jit(jax.value_and_grad(cross_entropy_loss, argnums=0))

And we're ready to train!

In [9]:
key = jax.random.PRNGKey(0)
w = model.initialize(key)

# Batches per validation pass. The loaders may yield batches indefinitely
# (the `steps` setting suggests the train loader is endless), in which case an
# unbounded `for` over val_loader would never return.
val_iters = 10

# Validation only needs the loss value, so jit the plain loss instead of
# paying for a gradient computation on every validation batch.
val_loss_fn = jax.jit(cross_entropy_loss)

step = 0
for inputs, targets in train_loader:
    loss, grad_w = loss_and_grad(w, inputs, targets)
    # Dualize with the same model whose weights are being trained. The previous
    # `mlp` is not defined in any cell of this notebook, so it only ran because
    # of stale kernel state.
    d_w = model.dualize(grad_w)
    w = [weight - lr * d_weight for weight, d_weight in zip(w, d_w)]

    if step % log_interval == 0:
        # Use a separate name so the training `loss` is not clobbered.
        val_losses = [
            val_loss_fn(w, val_inputs, val_targets)
            for val_inputs, val_targets in itertools.islice(val_loader, val_iters)
        ]
        if not val_losses:
            raise ValueError("val_loader yielded no batches")
        val_loss = sum(val_losses) / len(val_losses)
        print(f"Step {step}: train loss {float(loss):.4f}, val loss {float(val_loss):.4f}")

    step += 1
    # Stop after the configured number of steps; without this the loop never
    # ends if train_loader keeps producing batches.
    if step >= steps:
        break

TypeError: dot_general requires contracting dimensions to have the same shape, got (128,) and (64,).