### Part 2: Transformer language model

**Tutorial on Transformers for Mathematics**

*Simons Institute and SLMath Joint Workshop: AI for Mathematics and Theoretical Computer Science, April 8 2025*

Author: Sean Welleck

------

This notebook implements and trains a simple transformer language model.

**NOTE:** if you only want to train a transformer on a dataset and generate with it, *you can safely skip this notebook and move to the next one*. This notebook shows lower-level details of implementing a transformer and a training loop.

#### Generating names

Let's walk through the same simple example from the previous notebook. Note that ideas we present will apply to any other dataset of discrete token sequences. Here is our dataset:

In [13]:
data = open('names.txt').read().splitlines()
data[:10]

['emma',
 'olivia',
 'ava',
 'isabella',
 'sophia',
 'charlotte',
 'mia',
 'amelia',
 'harper',
 'evelyn']

We create a mapping between tokens (characters plus `[S]`) and token indices:

In [14]:
token_to_index = {tok: i for i, tok in enumerate('abcdefghijklmnopqrstuvwxyz')}
# Start/stop token
token_to_index['[S]'] = 26
# Padding token
token_to_index['[PAD]'] = 27

index_to_token = {i: tok for tok, i in token_to_index.items()}

#### Building a dataset

Now we build a dataset that will teach the model to predict the next token.

**Unlike the bigram model, our transformer model will receive all of the preceding tokens as input.** For instance, predicting the fifth token looks like:

$$ (x_1,x_2,x_3,x_4)\rightarrow x_5$$

To format this prediction problem as a dataset, we will have the input be a sequence of tokens and the output be the sequence of tokens shifted one token to the right. Hence, the model needs to output the next token at each position of the input sequence.


In [24]:
import torch

def build_dataset(data):
    X, Y = [], []
    for item in data:
        tokens = ['[S]'] + list(item) + ['[S]']
        indices = [token_to_index[token] for token in tokens]
        X.append(indices[:-1])
        Y.append(indices[1:])
    return X, Y

# Split into train, dev, test
import random
random.seed(123)
random.shuffle(data)

n1 = int(0.8 * len(data))
n2 = int(0.9 * len(data))

X_train, Y_train = build_dataset(data[:n1])
X_dev, Y_dev = build_dataset(data[n1:n2])
X_test, Y_test = build_dataset(data[n2:])

len(X_train), len(Y_train)

(25626, 25626)

### Implement the transformer

Here is the main transformer layer / block. We'll cheat a bit and use Pytorch's implementation of attention:

In [16]:
import torch.nn as nn

class Block(nn.Module):
    def __init__(self, d_model, nhead, dim_ff=64, max_len=128):
        super(Block, self).__init__()
        self.attn = nn.MultiheadAttention(d_model, nhead, dropout=0.0, batch_first=True)
        self.ff1 = nn.Linear(d_model, dim_ff)
        self.ff2 = nn.Linear(dim_ff, d_model)
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)
        self.act = nn.ReLU()
        self.register_buffer('mask', torch.triu(torch.ones(max_len, max_len), diagonal=1).bool())

    def forward(self, x):
        B, T, D = x.size()
        # Pre-normalization
        x = self.ln1(x)
        # Self-attention
        x2 = self.attn(x, x, x, is_causal=True, attn_mask=self.mask[:T,:T])[0]
        # Residual connection
        x = x + x2
        # Pre-normalization
        x = self.ln2(x)
        # Feed-forward
        x2 = self.ff2(self.act(self.ff1(x)))
        # Residual connection
        x = x + x2
        return x

In [17]:
# test out the block
block = Block(10, 2)
x = torch.randn(10, 32, 10)
block(x).shape

torch.Size([10, 32, 10])

The transformer language model contains the blocks/layers, token and position embeddings, and an output layer:

In [18]:
class TransformerLM(nn.Module):
    def __init__(self, vocab_size, d_model, nhead, num_layers, dim_ff, max_len=128):
        super(TransformerLM, self).__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoder = nn.Embedding(max_len, d_model)
        self.blocks = nn.ModuleList([
            Block(d_model, nhead, dim_ff) for _ in range(num_layers)
        ])
        self.fc = nn.Linear(d_model, vocab_size)
        self.d_model = d_model

    def forward(self, x):
        pos = torch.arange(x.size(0), device=x.device).unsqueeze(1)
        x = self.embedding(x) + self.pos_encoder(pos)
        for block in self.blocks:
            x = block(x)
        logits = self.fc(x)
        return logits


In [19]:
model = TransformerLM(len(token_to_index), 64, 2, 2, 64)

x = torch.tensor(X_train[:1])

logits = model(x)
logits.size()

torch.Size([1, 6, 28])

### Formatting the data

During training we provide multiple examples to the transformer in a batch. Since the examples can be of varied length we need to "pad" them so that they are of the same length. We do so by introducing a special `[PAD]` token.

In [20]:
def pad_batch(X_batch, Y_batch, pad_index):
    max_len = max(len(x) for x in X_batch)
    X_padded = torch.zeros(len(X_batch), max_len, dtype=torch.long) + pad_index
    Y_padded = torch.zeros(len(Y_batch), max_len, dtype=torch.long) + pad_index
    for i, (x, y) in enumerate(zip(X_batch, Y_batch)):
        X_padded[i, :len(x)] = torch.tensor(x)
        Y_padded[i, :len(y)] = torch.tensor(y)
    return X_padded, Y_padded

xp, yp = pad_batch(X_train[:4], Y_train[:4], token_to_index['[PAD]'])

print(xp)
for x in xp:
    print([index_to_token[i.item()] for i in x])

tensor([[26, 11, 20,  0, 13, 13, 27, 27, 27, 27],
        [26, 18,  7,  0,  8, 13, 27, 27, 27, 27],
        [26, 17, 20, 15,  4, 17, 19, 27, 27, 27],
        [26, 12, 14, 10, 18,  7,  0,  6, 13,  0]])
['[S]', 'l', 'u', 'a', 'n', 'n', '[PAD]', '[PAD]', '[PAD]', '[PAD]']
['[S]', 's', 'h', 'a', 'i', 'n', '[PAD]', '[PAD]', '[PAD]', '[PAD]']
['[S]', 'r', 'u', 'p', 'e', 'r', 't', '[PAD]', '[PAD]', '[PAD]']
['[S]', 'm', 'o', 'k', 's', 'h', 'a', 'g', 'n', 'a']


We can forward the batch through the model and get the outputs, which are known as "logits". They give a score for each one of the possible next-tokens for each position. Hence the size of the output tensor is `batch x sequence length x vocab size`.

In [22]:
X_batch, Y_batch = pad_batch(X_train[:2], Y_train[:2], token_to_index['[PAD]'])

logits = model(X_batch)
logits.size()

torch.Size([2, 6, 28])

### Training loop

Now we implement the training loop. We go over each batch, compute the loss, and perform a backward pass. We report the loss on the training batches and on a held out validation set:

In [23]:
import torch.optim as optim

model = TransformerLM(len(token_to_index), 64, 2, 2, 64)

# Count model parameters
print(f"Model parameters: {sum(p.numel() for p in model.parameters())}")

# Hyperparameters
learning_rate = 0.001
num_epochs = 10
batch_size = 16

# Loss function and optimizer
# NOTE: We ignore the loss whenever the target token is a padding token
criterion = nn.CrossEntropyLoss(ignore_index=token_to_index['[PAD]'])

optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Training loop
for epoch in range(num_epochs):
    # Reshuffle the data
    perm = torch.randperm(len(X_train))
    X_train = [X_train[i] for i in perm]
    Y_train = [Y_train[i] for i in perm]
    
    model.train()
    total_loss = 0
    for i in range(0, len(X_train), batch_size):
        X_batch = X_train[i:i+batch_size]
        Y_batch = Y_train[i:i+batch_size]
        X_batch, Y_batch = pad_batch(X_batch, Y_batch, token_to_index['[PAD]'])

        # Forward pass
        outputs = model(X_batch) # [batch_size, seq_len, vocab_size]
        outputs = outputs.view(-1, len(token_to_index)) # [batch_size * seq_len, vocab_size]
        Y_batch = Y_batch.view(-1) # [batch_size * seq_len]
        loss = criterion(outputs, Y_batch)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    avg_loss = total_loss / (len(X_train) // batch_size)
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}')

    # Evaluate validation loss
    eval_loss = 0
    model.eval()
    with torch.no_grad():
        for i in range(0, len(X_dev), batch_size):
            X_batch = X_dev[i:i+batch_size]
            Y_batch = Y_dev[i:i+batch_size]
            X_batch, Y_batch = pad_batch(X_batch, Y_batch, token_to_index['[PAD]'])

            outputs = model(X_batch)
            outputs = outputs.view(-1, len(token_to_index))
            Y_batch = Y_batch.view(-1)
            loss = criterion(outputs, Y_batch)

            eval_loss += loss.item()
    avg_eval_loss = eval_loss / (len(X_dev) // batch_size)
    print(f'Epoch [{epoch+1}/{num_epochs}], Validation Loss: {avg_eval_loss:.4f}')


Model parameters: 62236
Epoch [1/10], Loss: 2.3235
Epoch [1/10], Validation Loss: 2.2521
Epoch [2/10], Loss: 2.2105
Epoch [2/10], Validation Loss: 2.2052
Epoch [3/10], Loss: 2.1712
Epoch [3/10], Validation Loss: 2.1708
Epoch [4/10], Loss: 2.1431
Epoch [4/10], Validation Loss: 2.1536
Epoch [5/10], Loss: 2.1220
Epoch [5/10], Validation Loss: 2.1406
Epoch [6/10], Loss: 2.1050
Epoch [6/10], Validation Loss: 2.1310
Epoch [7/10], Loss: 2.0926
Epoch [7/10], Validation Loss: 2.1117
Epoch [8/10], Loss: 2.0818
Epoch [8/10], Validation Loss: 2.1120
Epoch [9/10], Loss: 2.0711
Epoch [9/10], Validation Loss: 2.1021
Epoch [10/10], Loss: 2.0614
Epoch [10/10], Validation Loss: 2.1039


### Generate new names with the model

We do this by sampling one token at a time given the preceding tokens:

In [28]:
# Sample from the model
def sample(model, context, max_length=100):
    model.eval()
    output = []
    with torch.no_grad():
        x = torch.tensor([[token_to_index['[S]']] + context])
        for _ in range(max_length):
            logits = model(x)
            y = torch.softmax(logits[0, -1], dim=0)
            y = torch.multinomial(y, 1)
            token = index_to_token[y.item()]
            if token == '[S]':
                break
            output.append(token)
            x = torch.cat([x, y.unsqueeze(0)], dim=1)
    return ''.join(output)

Generate 10 names:

In [26]:
for i in range(10):
    print(sample(model, []))

leyma
grahaed
auily
alenand
jahaj
nevy
arazie
tyaton
yukn
krija


Much better than the bigram model!

### Prompting

We can ensure that the initial tokens are equal to a "prompt". For instance, we can ensure a generated name starts with `s`:

In [29]:
prompt = 's'
for i in range(10):
    out = sample(model, [token_to_index[tok] for tok in prompt])
    print(prompt + out)

shegon
swax
sarif
sarey
somel
sifaeh
schigera
shael
sehrigh
sari
