#### Transformer Language Model

As with the RNN, we start from a sequence of word embedding vectors $w_1, ..., w_N$ and pass it through a `Transformer decoder network` (i.e. causal masking is applied to the attention scores) to obtain a sequence of output vectors $h_1, ..., h_N$. Each $h_i$ can then be regarded as a contextual representation of the words $w_1, ..., w_i$, so we can use it to compute a probability distribution over the next word given all the preceding words:

$P(w_{i+1} | w_1,...,w_i) = f(h_i)$

Again, as with the RNN, we use a simple feed-forward network for the transformation $f$: a linear projection followed by a softmax. Instead of word-level tokenization, we use sub-word tokenization, specifically Byte-Pair Encoding (BPE). We train the model on the collected works of Shakespeare.
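To make the transformation $f$ concrete, here is a minimal plain-Python sketch of a linear projection followed by a softmax. The toy sizes, the weight matrix `W`, the bias `b` and the helper names are assumptions chosen for illustration; the actual model uses `torch.nn.Linear` and `F.softmax`.

```python
import math

def softmax(logits):
    # subtract the max before exponentiating for numerical stability
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def next_token_distribution(h, W, b):
    # linear projection: one logit per vocabulary entry
    logits = [sum(w_j * h_j for w_j, h_j in zip(row, h)) + b_k
              for row, b_k in zip(W, b)]
    return softmax(logits)

# hypothetical toy sizes: hidden dimension 2, vocabulary of 3 tokens
h_i = [0.5, -1.0]
W = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
b = [0.0, 0.0, 0.0]
p = next_token_distribution(h_i, W, b)
print(p)  # a probability distribution over the 3 tokens
```

Each entry of `p` plays the role of $P(w_{i+1} = k | w_1,...,w_i)$ for one vocabulary item $k$.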

In [1]:
%load_ext autoreload
%autoreload 2

import random, math
from nltk.tokenize import word_tokenize
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from tqdm import tqdm
import psutil
from BPE import BPE
import wandb
wandb.login()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
wandb: Currently logged in as: tanzids. Use `wandb login --relogin` to force relogin


True

In [2]:
# prep the training data
with open('shakespeare.txt', 'r') as file:
    lines = file.readlines()

In [3]:
# now lets train a BPE tokenizer on the Shakespeare corpus
tokenizer = BPE(max_vocab_size=10000, sos_token="<s>", eos_token="</s>")
tokenizer.learn(lines)
tokenizer.precompute_word_tokens(lines)
print(f"\nLearned vocab: {tokenizer.vocab}\n")


Number of words in corpus: 334797


Building vocab. Num tokens added --> :  22%|██▏       | 2249/10000 [00:38<01:54, 67.70it/s]

In [None]:
# encode and decode a test sentence to make sure the tokenizer is working properly
s = ["Yeah, I'm gonna inform them today."]
s_tokenized = tokenizer.tokenize_sentences(s)
s_encoded = tokenizer.encode(s)
s_decoded = tokenizer.decode(s_encoded)
print(s_tokenized)
print(s_encoded)
print(s_decoded)

[['<s>_', 'Y', 'ea', 'h_', ',_', 'I_', "'", 'm_', 'g', 'on_', 'na_', 'inform_', 'them_', 'to', 'day_', '._', '</s>_']]
[[94, 2105, 4214, 5226, 72, 1024, 5, 6265, 4970, 6884, 6656, 5645, 9029, 9192, 3759, 81, 92]]
["<s> Yeah , I 'm gon na inform them today . \n"]
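The `BPE` class comes from a local module whose source is not shown here. As a rough illustration of the general idea behind Byte-Pair Encoding, and not of that module's actual implementation, the sketch below repeatedly merges the most frequent adjacent symbol pair in a toy corpus. The function names, the toy word counts and the use of `_` as an end-of-word marker are assumptions.

```python
from collections import Counter

def count_pairs(words):
    # words maps a tuple of symbols to its frequency in the corpus
    pairs = Counter()
    for symbols, freq in words.items():
        for a, b in zip(symbols, symbols[1:]):
            pairs[(a, b)] += freq
    return pairs

def merge_pair(words, pair):
    # replace every occurrence of the adjacent pair with a single merged symbol
    merged = {}
    for symbols, freq in words.items():
        out, i = [], 0
        while i < len(symbols):
            if i + 1 < len(symbols) and (symbols[i], symbols[i + 1]) == pair:
                out.append(symbols[i] + symbols[i + 1])
                i += 2
            else:
                out.append(symbols[i])
                i += 1
        key = tuple(out)
        merged[key] = merged.get(key, 0) + freq
    return merged

def learn_bpe(words, num_merges):
    merges = []
    for _ in range(num_merges):
        pairs = count_pairs(words)
        if not pairs:
            break
        best = max(pairs, key=pairs.get)  # most frequent adjacent pair
        merges.append(best)
        words = merge_pair(words, best)
    return merges, words

# toy corpus: word (split into characters plus end-of-word marker) -> count
toy = {("l", "o", "w", "_"): 5,
       ("l", "o", "w", "e", "r", "_"): 2,
       ("n", "e", "w", "_"): 3}
merges, segmented = learn_bpe(toy, num_merges=2)
print(merges)
print(segmented)
```

Every merge adds one new sub-word to the vocabulary, so a setting like `max_vocab_size` effectively bounds how many merges are learned.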


In [None]:
# save the trained tokenizer to file

import pickle
with open("BPE_tokenizer.pkl", 'wb') as file:
    pickle.dump(tokenizer, file)

# to load the trained BPE tokenizer from file instead:
# with open('BPE_tokenizer.pkl', 'rb') as f:
#     tokenizer = pickle.load(f)

Tokenize the sentences and create train-validation splits

In [None]:
# subword tokenize the sentences and convert to integer tokens
sentences_tokenized = tokenizer.encode(lines)

# create train-val splits (the last 10% of sentences are held out for validation)
num_sent = len(sentences_tokenized)
num_val = int(0.1 * num_sent)
x_train = sentences_tokenized[:-num_val]
x_val   = sentences_tokenized[-num_val:]

# concatenate all sentences
x_train = [w for s in x_train for w in s] 
x_val   = [w for s in x_val for w in s] 

In [None]:
print(f"Num tokens train: {len(x_train)}")
print(f"Num tokens val: {len(x_val)}")

Num tokens train: 316352
Num tokens val: 32542


Create PyTorch dataset

In [None]:
class Shakespeare(Dataset):
    def __init__(self, corpus, block_size=16):
        self.corpus = corpus
        self.block_size = block_size

    def __len__(self):
        return len(self.corpus)-self.block_size-1

    def __getitem__(self, idx):
        inputs = torch.tensor(self.corpus[idx:idx+self.block_size], dtype=torch.long)
        targets = torch.tensor(self.corpus[idx+1:idx+1+self.block_size], dtype=torch.long)
        return inputs, targets
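The dataset returns an input window and a target window that is the same slice shifted one token to the right, so position $i$ of the target is the token that follows position $i$ of the input. A small sketch with plain lists, where the toy corpus and the helper name `make_example` are made up for illustration:

```python
def make_example(corpus, idx, block_size):
    # same slicing as Shakespeare.__getitem__, but on plain Python lists
    inputs = corpus[idx : idx + block_size]
    targets = corpus[idx + 1 : idx + 1 + block_size]
    return inputs, targets

corpus = [10, 11, 12, 13, 14, 15]  # toy token ids
inputs, targets = make_example(corpus, idx=0, block_size=4)
print(inputs)   # [10, 11, 12, 13]
print(targets)  # [11, 12, 13, 14]
```

With causal masking, every position in the window therefore contributes its own next-token prediction to the loss, not just the last one.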

Create Transformer decoder model

In [None]:
class TransformerLM(torch.nn.Module):
    def __init__(self, vocab_size, block_size, embedding_dim=16, feedforward_dim=64, num_heads=1, num_layers=1, dropout_rate=0.1):
        super().__init__()
        self.vocab_size = vocab_size
        self.block_size = block_size
        # embedding layers
        self.emb = torch.nn.Embedding(vocab_size, embedding_dim)
        self.pos_emb = torch.nn.Embedding(block_size, embedding_dim)
        c = 0.01        
        torch.nn.init.uniform_(self.emb.weight, -c, c)
        torch.nn.init.uniform_(self.pos_emb.weight, -c, c)

        # transformer decoder
        decoder_layer = torch.nn.TransformerDecoderLayer(d_model=embedding_dim, nhead=num_heads, dim_feedforward=feedforward_dim, dropout=dropout_rate, batch_first=True)
        self.transformer_decoder = torch.nn.TransformerDecoder(decoder_layer, num_layers=num_layers)

        # create output layer (computes output class logits for each item in sequence)
        self.output_layer =  torch.nn.Linear(embedding_dim, vocab_size)
        # tie the output layer weights with the embedding layer weights
        self.output_layer.weight = self.emb.weight

    def create_causal_mask(self, input):
        _, L, _ = input.shape
        # L x L boolean matrix that is True on and below the diagonal (allowed positions)
        allowed = torch.tril(torch.ones(size=(L,L), dtype=torch.bool, device=input.device))
        # additive float mask: 0 where attention is allowed, -infinity at future positions
        mask = torch.zeros(size=(L,L), device=input.device).masked_fill(~allowed, float("-inf"))
        return mask

    def sinusoidal_positional_encoding(self, input):
        _, L, D = input.shape
        # static (non-learned) sinusoidal positional encoding, shape: (L, D)
        pos_emb = torch.zeros(size=(L, D), device=input.device)
        for pos in range(L):
            for i in range(0, D, 2):
                # i is already the even index 2k, so the exponent is i/D
                angle = pos / (10000 ** (i / D))
                pos_emb[pos, i] = math.sin(angle)
                if i+1 < D:
                    # the odd dimension shares the frequency of the preceding even dimension
                    pos_emb[pos, i+1] = math.cos(angle)
        return pos_emb

    # forward pass
    def forward(self, x, y=None):
        # get embeddings for batch of input sequences of length L
        x = self.emb(x) # shape: (B,L,D)
        # add positional embedding
        #x = x + self.sinusoidal_positional_encoding(x)# shape: (B,L,D)
        x = x + self.pos_emb(torch.arange(x.shape[1], device=x.device)) # shape: (B,L,D)
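        # pass through transformer decoder layers
        mask = self.create_causal_mask(x)
        # the decoder layers also cross-attend to `memory` (here the same sequence), so the
        # causal mask is applied there too; otherwise future tokens leak into each position
        x = self.transformer_decoder(x, x, tgt_mask=mask, memory_mask=mask) # shape: (B,L,D)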
        # compute output logits
        x = self.output_layer(x) # shape: (B,L,vocab_size)

        if y is None:
            return x

        # reshape
        x = x.view(-1,x.shape[-1]) # shape: (B*L,vocab_size)
        y = y.view(-1) # shape: (B*L,)
        # compute cross entropy loss
        loss = F.cross_entropy(x, y)
        return x, loss
    
    @torch.no_grad()
    def generate(self, vocab, subword2idx, temperature=1.0, topk=None, start_token="<s>", end_token="</s>", max_len=30, device="cpu"):
        self.eval()
        # generate one token at a time
        x = torch.full(size=(1,1), fill_value=subword2idx[start_token], dtype=torch.long, device=device)
        tokens = [x.item()]
        for _ in range(max_len):
            # crop the input sequence so that it doesn't exceed block size (only keep the last block_size tokens in the sequence to generate the next token)
            x = x[:,-self.block_size:]
            logits = self.forward(x) # shape: (1,L,V)
            # rescale the logits with the temperature
            logits = logits / temperature
            if topk is not None:
                # keep only the k largest logits for the last position in the sequence
                topk_logits, idx = torch.topk(logits[0,-1,:], k=min(topk, logits.shape[-1]))
                # sample from the renormalized distribution over the top-k candidates
                p = F.softmax(topk_logits, dim=-1) # shape: (k,)
                next_word_idx = idx[torch.multinomial(p, num_samples=1)]
            else:             
                # sample from the distribution for the last word in the sequence
                p = F.softmax(logits[:,-1,:], dim=-1) # shape: (V,)
                next_word_idx = torch.multinomial(p, num_samples=1)
            # append to the sequence
            x = torch.cat((x, next_word_idx.view(1,1)), dim=1)
            tokens.append(next_word_idx.item())

        self.train()
        return tokens

# training loop
def train(model, optimizer, scheduler, train_dataloader, val_dataloader, device="cpu", num_epochs=10, val_every=1, save_every=10, log_metrics=None):
    avg_loss = 0
    train_acc = 0
    val_loss = 0
    val_acc = 0
    pp = 0
    for epoch in range(num_epochs):
        num_correct = 0
        num_total = 0
        pbar = tqdm(train_dataloader, desc="Epochs")
        for batch in pbar:
            inputs, targets = batch
            # move batch to device
            inputs, targets = inputs.to(device), targets.to(device)
            # reset gradients
            optimizer.zero_grad()
            # forward pass
            logits, loss = model(inputs, targets)
            # backward pass
            loss.backward()
            # optimizer step
            optimizer.step()
            avg_loss = 0.9* avg_loss + 0.1*loss.item()
            y_pred = logits.argmax(dim=-1) # shape (B*L)
            y = targets.view(-1) # shape (B*L)
            mask = (y != -1)
            num_correct += (torch.eq(y[mask], y_pred[mask])).sum().item()
            num_total += len(y[mask])
            
            pbar.set_description(f"Epoch {epoch + 1}, EMA Train Loss: {avg_loss:.3f}, Train Accuracy: {train_acc: .3f}, Val Loss: {val_loss: .3f}, Val Accuracy: {val_acc: .3f}, Val Perplexity: {pp:.1f}")  

            if log_metrics:
                metrics = {"Batch loss" : loss.item(), "Moving Avg Loss" : avg_loss, "Val Loss": val_loss}
                log_metrics(metrics)

        scheduler.step()
        train_acc = num_correct / num_total        
        if epoch%val_every == 0:
            # compute validation loss
            val_loss, val_acc, pp = validation(model, val_dataloader, device=device)

        if (epoch+1) % save_every == 0:
            save_model_checkpoint(model, optimizer, epoch, avg_loss)

def validation(model, val_dataloader, device="cpu"):
    model.eval()
    val_losses = torch.zeros(len(val_dataloader))
    with torch.no_grad():
        num_correct = 0
        num_total = 0
        for i,batch in enumerate(val_dataloader):
            inputs, targets = batch
            inputs, targets = inputs.to(device), targets.to(device)
            logits, loss = model(inputs, targets)
            y_pred = logits.argmax(dim=-1) # shape (B*L)
            y = targets.view(-1) # shape (B*L)
            mask = (y != -1)
            num_correct += (torch.eq(y[mask], y_pred[mask])).sum().item()
            num_total += len(y[mask])
            val_losses[i] = loss.item()
    model.train()
    val_loss = val_losses.mean().item()
    val_accuracy = num_correct / num_total
    perplexity = math.exp(val_loss)
    return val_loss, val_accuracy, perplexity


def save_model_checkpoint(model, optimizer, epoch=None, loss=None):
    # Save the model and optimizer state_dict
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
    }

    # Save the checkpoint to a file
    torch.save(checkpoint, 'transformer_checkpoint.pth')
    print("Saved model checkpoint!")


def load_model_checkpoint(model, optimizer):
    checkpoint = torch.load('transformer_checkpoint.pth')
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    model.train()
    print("Loaded model from checkpoint!")
    return model, optimizer      
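To see what an additive causal mask does to attention, the following plain-Python sketch adds 0 at allowed positions and $-\infty$ at future positions before a row-wise softmax. It is a simplified stand-in for what happens inside the attention layers, with made-up uniform scores, and the helper names are assumptions.

```python
import math

def causal_mask(L):
    # 0.0 where position i may attend to position j (j <= i), -inf otherwise
    return [[0.0 if j <= i else float("-inf") for j in range(L)] for i in range(L)]

def masked_attention_weights(scores):
    L = len(scores)
    mask = causal_mask(L)
    weights = []
    for i in range(L):
        row = [scores[i][j] + mask[i][j] for j in range(L)]
        m = max(row)  # the diagonal is always allowed, so m is finite
        exps = [math.exp(v - m) for v in row]  # exp(-inf) evaluates to 0.0
        total = sum(exps)
        weights.append([e / total for e in exps])
    return weights

# uniform raw scores for a sequence of length 3
scores = [[0.0] * 3 for _ in range(3)]
weights = masked_attention_weights(scores)
for row in weights:
    print(row)
```

Row $i$ of `weights` is zero for every column $j > i$, which is what ensures $h_i$ depends only on $w_1, ..., w_i$.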

In [None]:
L = 32
train_dataset = Shakespeare(x_train, block_size=L)
val_dataset = Shakespeare(x_val, block_size=L)

In [None]:
B = 32
D = 32
vocab_size = len(tokenizer.vocab)
num_heads = 8
num_layers = 4
learning_rate = 1e-4
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"  # fall back to CPU when no GPU is present

train_dataloader = DataLoader(train_dataset, batch_size=B, shuffle=True, pin_memory=True, num_workers=2)
val_dataloader = DataLoader(val_dataset, batch_size=B, shuffle=True, pin_memory=True, num_workers=2)


model = TransformerLM(vocab_size, L, embedding_dim=D, feedforward_dim=4*D, num_heads=num_heads, num_layers=num_layers, dropout_rate=0.1).to(DEVICE)
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
scheduler =  torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.95)
#model, optimizer = load_model_checkpoint(model, optimizer)


num_params = sum(p.numel() for p in model.parameters())
print(f"Total number of parameters in transformer network: {num_params/1e6} M")
print(f"RAM used: {psutil.Process().memory_info().rss / (1024 * 1024):.2f} MB")

Total number of parameters in transformer network: 0.367285 M
RAM used: 1010.53 MB


In [None]:
train(model, optimizer, scheduler, train_dataloader, val_dataloader, device=DEVICE, num_epochs=1, save_every=100, val_every=1) #, log_metrics=log_metrics)

Epochs:   0%|          | 0/9885 [00:00<?, ?it/s]

Epoch 1, EMA Train Loss: 4.262, Train Accuracy:  0.000, Val Loss:  0.000, Val Accuracy:  0.000, Val Perplexity: 0.0: 100%|██████████| 9885/9885 [01:36<00:00, 102.03it/s]
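Perplexity is the exponential of the mean cross-entropy loss (in nats), as computed in `validation`. The short sketch below shows the relationship. The vocabulary size of 10000 is a round number assumed for illustration, and 4.262 is the EMA training loss from the progress bar above, used here only as an example input, not as a validation result.

```python
import math

def perplexity(mean_cross_entropy):
    # perplexity = exp(average negative log-likelihood per token)
    return math.exp(mean_cross_entropy)

# a model that guesses uniformly over V tokens has loss log(V) and perplexity V
vocab_size = 10000  # hypothetical round number
uniform_loss = math.log(vocab_size)
print(f"uniform baseline: loss {uniform_loss:.3f}, perplexity {perplexity(uniform_loss):.1f}")

# the same conversion applied to a loss of 4.262
print(f"loss 4.262 -> perplexity {perplexity(4.262):.1f}")
```

A perplexity of $k$ can be read loosely as the model being as uncertain as if it were choosing uniformly among $k$ tokens at each step.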


In [None]:
tokens = model.generate(tokenizer.vocab, tokenizer.subword2idx, temperature=1.0, topk=None, device=DEVICE, max_len=2000)
decoded = tokenizer.decode([tokens])
print(decoded[0])

<s>Upon Whiter ced burden gues armed deliverance 
 <s> And f-resipenitence ebb le vitae 
 <s> Ah , Edward experty , black ewtribunes ; That tily and twispeechless er ear . 
 <s> COMINIUS : 
 <s> Why , Hontolmber ; Dockws 't , keys , and supLa's duke child . 
 <s> Grithirty or Clarence . he ne'er pense with our lookof men 
 <s> If I when I 'll not warrant : 
 <s> criservant can do leave Warwick we slain ; 
 <s> Is full on the end with Romeo , I 'll good company 
 <s> And patence 
 <s> You , , here is been eyes , youth , this die , 
 <s> If I cut my unpreserved . Unto the entrails 
 <s> But he , that speak'st they have in this well ! I big is early . 
 <s> 
 <s> 
 <s> Away was thy yourself he , she were in them : Menenius , 
 <s> To brittle living . 
 <s> this a angels impossible him . 
 <s> 
 <s> 
 <s> 
 <s> CLAUDIO : 
 <s> Good friends changed to rough prize and thou good who well from him : 
 <s> But I know to make me peace than doubt the Padua by this wound 
 <s> Sgentleman great eas
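The `temperature` and `topk` arguments of `generate` shape how adventurous the sampled text is. The sketch below shows the general technique in plain Python: logits are divided by the temperature, optionally restricted to the `topk` largest entries, and a token is drawn from the resulting softmax distribution. The function `sample_next`, the toy logits and the seeded generator are assumptions for illustration, not part of the notebook's model code.

```python
import math
import random

def sample_next(logits, temperature=1.0, topk=None, rng=random):
    # lower temperature sharpens the distribution, higher temperature flattens it
    scaled = [z / temperature for z in logits]
    # candidate token ids ordered from most to least likely
    candidates = sorted(range(len(scaled)), key=lambda k: scaled[k], reverse=True)
    if topk is not None:
        candidates = candidates[:topk]  # discard everything outside the top k
    m = scaled[candidates[0]]
    # unnormalized softmax weights; random.choices normalizes them internally
    weights = [math.exp(scaled[k] - m) for k in candidates]
    return rng.choices(candidates, weights=weights, k=1)[0]

rng = random.Random(0)  # seeded so the sketch is repeatable
logits = [2.0, 1.0, 0.5, -1.0]
draws = [sample_next(logits, temperature=0.8, topk=2, rng=rng) for _ in range(200)]
print(sorted(set(draws)))  # only token ids from the top 2 can appear
```

Setting `topk=1` reduces this to greedy decoding, while `topk=None` with `temperature=1.0` samples from the model's unmodified distribution.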