# Assignment 4: Training Transformers in PyTorch

*Author:* Thomas Adler

*Copyright statement:* This  material,  no  matter  whether  in  printed  or  electronic  form,  may  be  used  for  personal  and non-commercial educational use only.  Any reproduction of this manuscript, no matter whether as a whole or in parts, no matter whether in printed or in electronic form, requires explicit prior acceptance of the authors.

In this assignment we will implement and train a small transformer model and compare it to the LSTM in the previous assignment. 

## Exercise 1: Causal Self-Attention

Write a class named `CausalSelfAttention` that derives from `nn.Module` and whose `__init__` method takes (apart from the trivial `self`) one argument `hidden_size`. Implement a method `forward` that takes an input sequence `x` of shape $(N, T, D)$ (where $N$ is batch size, $T$ is sequence length, $D$ is hidden size) and performs scaled dot-product self-attention, i.e., 
$$
Y = \operatorname{softmax}\left(\frac{1}{\sqrt{D}} Q K^\top\right) V,
$$
where $Q = X W_Q$ and $K = X W_K$ and $V = X W_V$ and $X \in \mathbb{R}^{T \times D}$ and $W_Q, W_K, W_V \in \mathbb{R}^{D \times D}$ and softmax is applied in a row-wise manner and neglecting bias units. 
It is called self-attention because $Q, K, V$ are all computed from the same input $X$, which hence attends to itself. 

To have the attention be *causal* we need to make sure that we do not allow peeks into the future. That is, the output at time $t$ must be a function of the input at times $1, \dots, t$ but no further. The score matrix $E = \frac{1}{\sqrt{D}} Q K^\top$ has a shape of $T \times T$ and the entry $e_{ij}$ measures how strong the query at time $i$ attends to the key at time $j$. Therefore, positions where $j > i$ constitute peeks into the future and we have to set the corresponding attention values (i.e., the softmax-activated score) to zero. We can do that by setting the corresponding score to `float('-inf')`, which has the advantage that the normalization is adjusted automatically by the softmax. 

In [1]:
import torch
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else torch.device("mps") if torch.mps.is_available() else "cpu")
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
import time

class CausalSelfAttention(nn.Module):
    def __init__(self, hidden_dim):
        super(CausalSelfAttention, self).__init__()
        self.hidden_dim = hidden_dim
        self.WQ = nn.Linear(hidden_dim, hidden_dim, bias=False, device=DEVICE)
        self.WK = nn.Linear(hidden_dim, hidden_dim, bias=False, device=DEVICE)
        self.WV = nn.Linear(hidden_dim, hidden_dim, bias=False, device=DEVICE)
    
    def forward(self, x):
        N, T, D = x.size()
        Q = self.WQ(x)                                                  #  (N, T, D)
        K = self.WK(x)                                                  #  (N, T, D)
        V = self.WV(x)                                                  #  (N, T, D)
        Q = Q / math.sqrt(D)
        scores = torch.matmul(Q, K.transpose(1, 2))                     #  (N, T, D) * (N, D, T) -> (N, T, T)
        mask = torch.triu(torch.ones(T, T), diagonal=1).to(x.device)    #  (T, T)
        scores = scores.masked_fill(mask == 1, float('-inf'))           #  (N, T, T)
        attention = F.softmax(scores, dim=-1)                           #  (N, T, T)
        out = torch.matmul(attention, V)                                #  (N, T, T) * (N, T, D) -> (N, T, D)
        return out

## Exercise 2: Multi-Head Attention

Write a class `MultiHeadCausalSelfAttention` that derives from `nn.Module` and extends the functionality of `CausalSelfAttention` from the previous exercise. 
The `__init__` method takes arguments `hidden_size, n_head, dropout`. `n_head` specifies the number of attention heads and `dropout` specifies the intensity for the dropout layers. 
The `forward` method should split the hidden dimension of the pre-activations (i.e., $Q, K, V$) in `n_head` equally sized parts and perform attention to these parts in parallel. 
Apply the first dropout layer direcly after the softmax. 
After the multiplication of the scores with the values, recombine the output of the distinct attention heads back into a single hidden dimension of size $D$, i.e., the resulting shape should be the shape of the input. 
Then perform an additional output projection again resulting in a hidden dimension of $D$. 
Finally, apply the second dropout layer after the output projection. 

In [2]:
class MultiHeadCausalSelfAttention(nn.Module):
    def __init__(self, hidden_size, n_head, dropout):
        super(MultiHeadCausalSelfAttention, self).__init__()
        self.hidden_size = hidden_size
        self.n_head = n_head
        self.dropout = nn.Dropout(dropout)
        self.WQ = nn.Linear(hidden_size, hidden_size, device=DEVICE)
        self.WK = nn.Linear(hidden_size, hidden_size, device=DEVICE)
        self.WV = nn.Linear(hidden_size, hidden_size, device=DEVICE)

    def forward(self, x):
        N, T, D = x.size()
        Q = self.WQ(x)
        K = self.WK(x)
        V = self.WV(x)
        Q = Q.view(N, T, self.n_head, D // self.n_head).permute(0, 2, 1, 3)
        K = K.view(N, T, self.n_head, D // self.n_head).permute(0, 2, 1, 3)
        V = V.view(N, T, self.n_head, D // self.n_head).permute(0, 2, 1, 3)
        Q = Q / math.sqrt(D // self.n_head)
        scores = torch.matmul(Q, K.permute(0, 1, 3, 2))
        mask = torch.triu(torch.ones(T, T), diagonal=1).to(x.device)
        scores = scores.masked_fill(mask == 1, float('-inf'))
        attention = F.softmax(scores, dim=-1)
        out = torch.matmul(self.dropout(attention), V).permute(0, 2, 1, 3).contiguous().view(N, T, D)
        return out

## Exercise 3: Multi-Layer Perceptron

Write a class `MLP` that derives from `nn.Module` and whose `__init__` method takes two arguments: `hidden_size` and `dropout`. 
It should implement a 2-layer feedforward network with `hidden_size` inputs, `4*hidden_size` hiddens, and `hidden_size` outputs. 
It should apply the GELU activation function to the hiddens and dropout to the outputs. 

In [3]:
########## YOUR SOLUTION HERE ##########
class MLP(nn.Module):
    def __init__(self, hidden_size, dropout):
        super(MLP, self).__init__()
        self.hidden_size = hidden_size
        self.dropout = nn.Dropout(dropout)
        self.W1 = nn.Linear(hidden_size, 4*hidden_size, device=DEVICE)
        self.W2 = nn.Linear(4*hidden_size, hidden_size, device=DEVICE)
    
    def forward(self, x):
        x = self.W1(x)
        x = F.gelu(x)
        x = self.W2(x)
        x = self.dropout(x)
        return x

## Exercise 4: Block

Write a class `Block` that derives from `nn.Module` and whose `__init__` method takes arguments `hidden_size, n_head, dropout`. 
It should apply `nn.LayerNorm`, `CausalMultiHeadSelfAttention`, `nn.LayerNorm`, `MLP` in that order and feature residual connections from the input to the output of `CausalMultiHeadSelfAttention` and from there to the output of `MLP`. 

In [4]:
class Block(nn.Module):
    def __init__(self, hidden_size, n_head, dropout):
        super(Block, self).__init__()
        self.hidden_size = hidden_size
        self.n_head = n_head
        self.dropout = nn.Dropout(dropout)
        self.ln1 = nn.LayerNorm(hidden_size, device=DEVICE)
        self.ln2 = nn.LayerNorm(hidden_size, device=DEVICE)
        self.attention = MultiHeadCausalSelfAttention(hidden_size, n_head, dropout)
        self.mlp = MLP(hidden_size, dropout)

    def forward(self, x):
        x = self.attention(self.ln1(x)) + x
        x = self.mlp(self.ln2(x)) + x
        return x

## Exercise 5: GPT

Write a class `GPT` that derives from `nn.Module` and whose `__init__` method takes arguments `vocab_size, context_size, hidden_size, n_layer, n_head, dropout`. 
The `forward` method should take two arguments `x, y` representing sequences of input and target tokens, respectively, both of which have type `torch.long` and shape ($N$, $T$), and returns logits and loss as a tuple. 
The `GPT` module should feature two `nn.Embedding` layers, one for token embeddings and one for positional embedding, i.e., it should embed the position of the corresponding token within the input sequence. 
The positional embedding is necessary for the Transformer to determine the order of its inputs. 
Add the two embeddings and apply a dropout layer. 
Next, apply `n_layers` layers of `Block`s followed by a `nn.LayerNorm` and a `nn.Linear` (without bias) mapping to an output dimension of `vocab_size`. 
Finally, apply the cross-entropy loss function to the logits. 
To save some parameters, apply weight tying between the token embedding layer and the output layer, i.e., they should use the same weights. 
Initialize all weights using a normal distribution with a mean of zero and a standard deviation of 0.02 (except for the output layers of the `MLP`s use $0.02/\sqrt{2 * \mathtt{n\_layer}}$) and all biases to zero. 
Use the argument `dropout` as intensity for all dropout layers in the network. 

In [5]:
class GPT(nn.Module):
    def __init__(self, vocab_size, context_size, hidden_size, n_layer, n_head, dropout=0.0):
        super(GPT, self).__init__()
        self.token_embedding = nn.Embedding(vocab_size, hidden_size, device=DEVICE)
        self.position_embedding = nn.Embedding(context_size, hidden_size, device=DEVICE)
        self.dropout = nn.Dropout(dropout)
        self.layers = nn.ModuleList([Block(hidden_size, n_head, dropout) for _ in range(n_layer)])
        self.ln = nn.LayerNorm(hidden_size, device=DEVICE)
        self.linear = nn.Linear(hidden_size, vocab_size, bias=False, device=DEVICE)
        self.linear.weight = self.token_embedding.weight  # weight tying
        for module in self.modules():
            if isinstance(module, nn.Linear):
                torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
                if module.bias is not None:
                    torch.nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Embedding):
                torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            elif isinstance(module, nn.LayerNorm):
                torch.nn.init.zeros_(module.bias)
                torch.nn.init.ones_(module.weight)
        torch.nn.init.normal_(self.linear.weight, mean=0.0, std=0.02)

    
    def forward(self, x, y):
        x = x[:, -self.position_embedding.num_embeddings:]
        x = self.token_embedding(x)
        pos_emb = self.position_embedding(torch.arange(x.shape[1], device=x.device))
        x = x + pos_emb
        x = self.dropout(x)
        for block in self.layers:
            x = block(x)
        x = self.ln(x)
        x = self.linear(x)
        logits = x
        if y is None:
            return logits
        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), y.view(-1), ignore_index=-1)
        return logits, loss

## Exercise 6: Optimizer

Add a method `configure_optimizers` to the class `GPT` that takes arguments `weight_decay, learning_rate, betas`. 
Divide the model parameters into two groups. 
The first group consists of all parameters with at least 2 dimensions, e.g., weight/embedding matrices and uses a decay of `weight_decay`. 
The second group consists of all other parameters, e.g., biases and layer norms, and does not use weight decay.
Construct and return a `torch.optim.AdamW` optimizer with `learning_rate` and `betas` that operates on these two parameter groups. 

In [6]:
########## YOUR SOLUTION HERE ##########
def configure_optimizers(self, weight_decay, learning_rate, betas):
    at_least_2_dim_params = (p for p in self.parameters() if p.ndimension() >= 2)
    other_params = (p for p in self.parameters() if p.ndimension() < 2)
    optimizer = torch.optim.AdamW(
            [
                {'params': at_least_2_dim_params, 'weight_decay': weight_decay},
                {'params': other_params, 'weight_decay': 0.0}
            ], lr=learning_rate, betas=betas)
    return optimizer

GPT.configure_optimizers = configure_optimizers

## Exercise 7: Training

In the code cell below you find some globals, helper functions, and boilerplate code. Extend the given code by a training loop that 
* stops after `max_iters` iterations
* applies the learning rate schedule implemented in `get_lr`
* applies gradient clipping at `grad_clip` using `torch.nn.utils.clip_grad_norm_`
* accumulates gradients for `gradient_accumulation_steps` batches before each weight update
* logs the training loss and learning rate every `log_interval` iterations
* evaluates (and potentially checkpoints) the model using `estimate_loss` every `eval_iters` iterations.

The provided hyperparameter values should be a good guess for training a tiny model on CPU but feel free to experiment with them as you please. In particular, if you have a GPU available, you can try to scale things up a bit. 

In [7]:
eval_interval = 200 # validate model every .. iterations
log_interval = 10 # log training loss every .. iterations
eval_iters = 20 # number of batches for loss estimation
gradient_accumulation_steps = 2 # used to simulate larger training batch sizes
batch_size = 16 # if gradient_accumulation_steps > 1, this is the micro-batch size
context_size = 128 # sequence length
vocab = 'abcdefghijklmnopqrstuvwxyz0123456789 .!?' # vocabulary
vocab_size = len(vocab) # 40
n_layer = 12 # number of layers
n_head = 12 # number of attention heads
hidden_size = 768 # layer size
dropout = 0.1 # for pretraining 0 is good, for finetuning try 0.1+
learning_rate = 3e-4 # max learning rate
max_iters = 10_000 # total number of training iterations
weight_decay = 1e-1
beta1 = 0.9 # for AdamW
beta2 = 0.999 # for AdamW
grad_clip = 0.8 # clip gradients at this value, or disable with 0.0
warmup_iters = 500 # how many steps to warm up for
min_lr = 3e-5 # minimum learning rate, usually ~= learning_rate/10

# learning rate decay scheduler (cosine with warmup)
def get_lr(it):
    # 1) linear warmup for warmup_iters steps
    if it < warmup_iters:
        return learning_rate * it / warmup_iters
    # 2) if it > max_iters, return min learning rate
    if it > max_iters:
        return min_lr
    # 3) in between, use cosine decay down to min learning rate
    decay_ratio = (it - warmup_iters) / (max_iters - warmup_iters)
    assert 0 <= decay_ratio <= 1
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1
    return min_lr + coeff * (learning_rate - min_lr)

def load_data(split):
    import re
    
    with open(f'trump_{split}.txt', 'r') as f:
        text = f.read()
    
    text = text.lower() # convert to lower case
    text = re.sub('[^a-z0-9 .!?]', ' ', text) # replace all unknown chars with ' '
    text = re.sub(' +', ' ', text) # reduce multiple blanks to one
    text = [vocab.index(t) for t in text]
    text = torch.tensor(text, dtype=torch.long, device=DEVICE)
    return text
    
def get_batch(split):
    data = train_data if split == 'train' else val_data
    # Random starting indices (shape: [batch_size])
    ix = torch.randint(len(data) - context_size, (batch_size,), device=DEVICE)

    # Create a 2D index tensor of shape batch_size X context_size
    #  For each element in ix, we want to collect [i, i+1, ..., i+context_size-1].
    #  So we broadcast-add a range of length `context_size` to each element of ix.
    x_positions = ix.unsqueeze(-1) + torch.arange(context_size, device=DEVICE)
    y_positions = x_positions + 1  # Shift by 1
    x = data[x_positions]  # batch_size X context_size
    y = data[y_positions]  # batch_size X context_size

    return x, y
    # old function
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - context_size, (batch_size,))
    x = torch.stack([data[i:i+context_size] for i in ix])
    y = torch.stack([data[i+1:i+1+context_size] for i in ix])
    return x, y

# helps estimate an arbitrarily accurate loss over either split using many batches
@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

# data, model, optimizer, etc. 
train_data = load_data('train')
val_data = load_data('val')
model = GPT(vocab_size, context_size, hidden_size, n_layer, n_head)
optimizer = model.configure_optimizers(weight_decay, learning_rate, (beta1, beta2))
iter_num = 0
best_val_loss = 1e9
X, Y = get_batch('train') # fetch the very first batch
t0 = time.time()

########## YOUR SOLUTION HERE ##########

for iter_num in range(max_iters):
    optimizer.zero_grad()
    logits, loss = model(X, Y)
    loss.backward()
    if grad_clip > 0.0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
    optimizer.step()
    if iter_num % log_interval == 0:
        print(f'[{iter_num}/{max_iters}] loss={loss.item()}')
    if iter_num % eval_interval == 0:
        val_loss = estimate_loss()['val']
        print(f'[{iter_num}/{max_iters}] val_loss={val_loss}')
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), 'best_model.pth')
    X, Y = get_batch('train')
    lr = get_lr(iter_num)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

print(f'training took {time.time()-t0} seconds')

[0/10000] loss=3.8327651023864746
[0/10000] val_loss=4.317932605743408
[10/10000] loss=3.3078763484954834
[20/10000] loss=3.0520565509796143
[30/10000] loss=2.8776397705078125
[40/10000] loss=2.871124267578125
[50/10000] loss=2.828428268432617
[60/10000] loss=2.6825506687164307
[70/10000] loss=2.530851364135742
[80/10000] loss=2.492325782775879
[90/10000] loss=2.485316753387451
[100/10000] loss=2.4421496391296387
[110/10000] loss=2.372389793395996
[120/10000] loss=2.3307151794433594
[130/10000] loss=2.367492914199829
[140/10000] loss=2.3824682235717773
[150/10000] loss=2.3652303218841553
[160/10000] loss=2.345362901687622
[170/10000] loss=2.388284921646118
[180/10000] loss=2.352450370788574
[190/10000] loss=2.239760398864746
[200/10000] loss=2.2890467643737793
[200/10000] val_loss=2.3523406982421875
[210/10000] loss=2.2736756801605225
[220/10000] loss=2.2879414558410645
[230/10000] loss=2.3273520469665527
[240/10000] loss=2.2648463249206543
[250/10000] loss=2.2978127002716064
[260/1000

KeyboardInterrupt: 

I early stopped training here!

## Exercise 8: Inference

Add a method `generate` to the class `GPT` that takes arguments `x, max_new_tokens, temperature=1.0`. 
The method should take a batch of token sequences `x`, which it should extend by `max_new_tokens` new tokens generated by the model. 
Once you have computed the logits for the next token, divide them by `temperature` before applying the softmax. 
After applying the softmax, sample the next token from the resulting categorical distribution. 
Try out different values for `temperature` and compare the results to those from the previous assignment. 

In [8]:
def generate(self, x, max_new_tokens, temperature=1.0):
    self.eval()
    N, T = x.size()
    output = torch.zeros(size=(N, T + max_new_tokens), dtype=torch.long, device=x.device)
    output[:, :T] = x
    for t in range(T, T + max_new_tokens):
        logits = self(output[:, t-context_size:t], None)[:, -1, :] / temperature
        probs = F.softmax(logits, dim=-1)
        output[:, t] = torch.multinomial(probs, 1).squeeze(-1)
    return output

GPT.generate = generate

In [12]:
temps = [0.1, 0.2, 0.3, 0.5, 0.6, 0.7, 0.8, 1.0]

for temp in temps:
    model.load_state_dict(torch.load('best_model.pth'))
    model.eval()
    sentence = 'make america great'
    x = torch.ones(size=(1, context_size), dtype=torch.long, device=DEVICE) * vocab.index(' ')
    x[:, -len(sentence):] = torch.tensor([vocab.index(c) for c in sentence], dtype=torch.long).unsqueeze(0)
    output = model.generate(x, 500, temperature=temp)
    text = ''.join([vocab[i] for i in output[0].cpu().numpy()])

    print(f'--- temperature={temp} ---')
    print(text)

--- temperature=0.1 ---
                                                                                                              make america great again. the days of deadly ignorance will end that we will not be able to deal and all of that. but we have to rebuild our country. our country is going to be one of the countries of the way establishment people in and we ve seen in the polls. no i don t think i ve even say that before. i ve already gone after this before. the election is not bad. i said no no no. i don t think i think it s a man of the other reason i think it s going to be very hard. i mean i think it s a much bigger than anybo
--- temperature=0.2 ---
                                                                                                              make america great again. the arm you know the dreamers workers again. the margins against the world. the old post office we have to rebuild our country. we have a presidential race to do a fabulous job. we have a