<p style="text-align: center; font-size:50px;">Implementing Transformers</p>

#### In this notebook, I will show you how transformers can be implemented. 
#### PyTorch is a fantastic framework that supports transformers with just a few lines of code. 
#### To show that the models work, we will run the wikitext dataset available on Hugging Face through them, but do not expect superb results.
#### From personal experience, even free online cloud notebooks with strong GPUs and large amounts of memory are not sufficient to build a language model that converses like a human being.
#### Therefore, this notebook focuses purely on the implementation of Transformers.
#### I will closely follow the architecture shown in the "Attention Is All You Need" paper and the official PyTorch tutorial on transformers.

# Credit 
* https://arxiv.org/pdf/1706.03762.pdf
* https://pytorch.org/tutorials/beginner/transformer_tutorial.html
* https://towardsdatascience.com/how-to-code-the-transformer-in-pytorch-24db27c8f9ec

In [3]:
import torch 
import torch.nn as nn 
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from torch.nn import functional as F
import math
import torchinfo 
from torchinfo import summary
from tqdm import tqdm
import gc
import datasets

# Loading Data

In [4]:
from datasets import load_dataset

dataset = load_dataset('wikitext', 'wikitext-2-v1')

Downloading and preparing dataset wikitext/wikitext-2-v1 to /root/.cache/huggingface/datasets/wikitext/wikitext-2-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126...


Dataset wikitext downloaded and prepared to /root/.cache/huggingface/datasets/wikitext/wikitext-2-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126. Subsequent calls will reuse this data.



In [5]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # Use the GPU when one is available

print(device)

# Tokenizing each word 
tokenizer = get_tokenizer('basic_english')
# Building a vocab 
vocab = build_vocab_from_iterator(map(tokenizer, dataset['train']['text']), min_freq=2, specials=['<unk>', '<eos>'])
vocab.set_default_index(vocab['<unk>']) 

print(f"The length of the vocab is {len(vocab)}")                         
print(vocab.get_itos()[:10])    

cuda
The length of the vocab is 28783
['<unk>', '<eos>', 'the', ',', '.', 'of', 'and', 'in', 'to', 'a']
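Conceptually, `build_vocab_from_iterator` counts every token, keeps the ones that occur at least `min_freq` times, and places the special tokens first; `set_default_index` then maps unseen tokens to `<unk>`. The plain-Python sketch below illustrates that idea on a toy corpus. It is not torchtext's implementation: the helper names are hypothetical, and the tie-breaking order and the lowercase-and-split tokenizer are assumptions made for illustration.

```python
from collections import Counter


def simple_tokenize(text):
    # Hypothetical stand-in for the basic_english tokenizer: lowercase + whitespace split
    return text.lower().split()


def build_vocab(texts, tokenize, min_freq=2, specials=("<unk>", "<eos>")):
    # Count every token across the corpus
    counts = Counter(tok for text in texts for tok in tokenize(text))
    # Specials first, then tokens seen at least min_freq times (most frequent first, ties alphabetical)
    kept = [tok for tok, c in sorted(counts.items(), key=lambda kv: (-kv[1], kv[0]))
            if c >= min_freq and tok not in specials]
    itos = list(specials) + kept
    stoi = {tok: i for i, tok in enumerate(itos)}
    return itos, stoi


def encode(text, tokenize, stoi, unk="<unk>"):
    # Out-of-vocabulary tokens fall back to the <unk> index
    return [stoi.get(tok, stoi[unk]) for tok in tokenize(text)]


texts = ["The cat sat", "the dog sat", "a bird"]
itos, stoi = build_vocab(texts, simple_tokenize)
print(itos)                                            # ['<unk>', '<eos>', 'sat', 'the']
print(encode("the bird sat", simple_tokenize, stoi))   # [3, 0, 2]
```

Here "cat", "dog", "a" and "bird" each appear only once, so with `min_freq=2` they are dropped and "bird" is encoded as `<unk>` (index 0).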


In [6]:
# Splitting the data 

encoded_data = [torch.tensor(vocab(tokenizer(item)), dtype=torch.long) for item in dataset['train']['text'] if len(item) > 0]
n = int(0.8*len(encoded_data)) # first 80% will be train, rest val
train_data = encoded_data[:n]
val_data = encoded_data[n:]

del encoded_data
gc.collect()

# train_data and val_data will have encoded data (numerical representation of our text inputs)

93

In [7]:
flat_train_data = torch.cat(train_data)
flat_val_data = torch.cat(val_data)

del train_data
del val_data
gc.collect()

0

In [8]:
# Batching 
# We have to come up with a function to be able to appropriately batch our data

batch_size = 64 # The batch size our model will process in parallel 
block_size = 128 # Maximum length that feeds into our model 
# The "- 1" leaves room for the one-token shift between X and y in the last batch
number_of_batches_train = (len(flat_train_data) - 1) // (batch_size*block_size) # Total number of batches in the training set
number_of_batches_val = (len(flat_val_data) - 1) // (batch_size*block_size) # Total number of batches in the validation set 

def get_batch(split, batch_idx):
    data = flat_train_data if split == 'train' else flat_val_data
    start = batch_idx*batch_size*block_size # Index of the first token in this batch
    # Row i of X is a block of block_size consecutive tokens; row i of y is the same block offset by one token
    X = torch.stack([data[start + i*block_size : start + (i+1)*block_size] for i in range(batch_size)])
    y = torch.stack([data[start + i*block_size + 1 : start + (i+1)*block_size + 1] for i in range(batch_size)])
    return X.to(device), y.to(device)

X, y = get_batch('train', 197)
print(f"This is what X looks like:\n{X}\n")
print(f"This is what y looks like:\n{y}")
print("")
print(f"We have {number_of_batches_train} number of training batches\n")
print(f"We have {number_of_batches_val} number of validation batches")

This is what X looks like:
tensor([[   23,   169,     8,  ...,    20,    69,   114],
        [    5,    37,  1951,  ...,    10,    10,    10],
        [    2,   209,   113,  ...,     6, 12079,    10],
        ...,
        [   11,   409,   522,  ...,   302,  1456,  1757],
        [ 2083,   130,  5603,  ...,  3195,     4,    91],
        [   37,  1144,     6,  ...,  2417,  3256,     5]], device='cuda:0')

This is what y looks like:
tensor([[  169,     8,    71,  ...,    69,   114,     5],
        [   37,  1951,     3,  ...,    10,    10,     2],
        [  209,   113,    34,  ..., 12079,    10,    10],
        ...,
        [  409,   522,     8,  ...,  1456,  1757,  2083],
        [  130,  5603,    15,  ...,     4,    91,    37],
        [ 1144,     6,    28,  ...,  3256,     5,  1416]], device='cuda:0')

We have 199 number of training batches

We have 51 number of validation batches
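The indexing in `get_batch` can also be expressed by slicing one contiguous chunk of tokens and reshaping it with `view`. The sketch below uses a hypothetical helper, `get_batch_view`, that takes the token tensor and sizes as arguments instead of reading globals. It builds the same `X` and `y` on a toy token stream and makes the one-token shift explicit. Because each batch needs `batch_size * block_size + 1` tokens, the number of full batches that fit is `(len(data) - 1) // (batch_size * block_size)`.

```python
import torch


def get_batch_view(data: torch.Tensor, batch_idx: int, batch_size: int, block_size: int):
    """Hypothetical equivalent of get_batch: slice one chunk, then reshape it into rows."""
    start = batch_idx * batch_size * block_size
    # batch_size * block_size tokens for X, plus one extra token for the shifted targets
    chunk = data[start : start + batch_size * block_size + 1]
    X = chunk[:-1].view(batch_size, block_size)  # row i starts at token start + i*block_size
    y = chunk[1:].view(batch_size, block_size)   # same rows offset by one token: the next-token targets
    return X, y


# Toy token stream 0, 1, ..., 29 so the shift is easy to see
toy_data = torch.arange(30)
toy_X, toy_y = get_batch_view(toy_data, batch_idx=1, batch_size=3, block_size=4)
print(toy_X)
print(toy_y)
print((len(toy_data) - 1) // (3 * 4))  # number of full batches that fit: 2
```

Since the toy tokens are consecutive integers, every entry of `toy_y` is the matching entry of `toy_X` plus one, which is exactly the "predict the next token" pairing the model trains on.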


# Building the model using PyTorch's Transformer modules

#### First, we will build our transformer using PyTorch's built-in nn.TransformerEncoder and nn.TransformerDecoder modules. 

#### We will start with the positional encoder. 
#### Unlike RNNs, Transformers do not process their input sequentially, so we have to supply information about each token's position in the sequence explicitly. 
#### The paper does this with sine and cosine functions of different frequencies, whose values are added to the token embeddings. 
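For reference, these are the formulas from the paper, where $pos$ is the token position and $i$ indexes pairs of embedding dimensions. Both entries of a pair share the same frequency:

$$PE_{(pos,\,2i)} = \sin\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right), \qquad PE_{(pos,\,2i+1)} = \cos\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right)$$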

In [9]:
class PositionalEncoding(nn.Module):

    def __init__(self, d_model: int, max_len: int = 5000):
        super().__init__()

        position = torch.arange(max_len).unsqueeze(1) # positioning vector, shape [max_len, 1]
        # As in the paper, the sin (dim 2i) and cos (dim 2i+1) entries of a pair share the frequency 1 / 10000^(2i/d_model)
        div_term = 10000 ** (torch.arange(0, d_model, 2) / d_model)
        pe = torch.zeros(1, max_len, d_model)
        pe[0, :, 0::2] = torch.sin(position / div_term)
        pe[0, :, 1::2] = torch.cos(position / div_term[: d_model // 2]) # slicing also handles an odd d_model
        # A buffer is not trained but moves with the module when calling .to(device)
        self.register_buffer('pe', pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Arguments:
            x: Tensor, shape [batch_size, seq_len, embedding_dim]
        """
        seq_len = x.size(1)

        return x + self.pe[:, :seq_len, :] # only until the actual sequence length 

print("Testing to ensure positional encoder works\n")
eg = torch.randn(64, 128, 300)

try:    
    pe = PositionalEncoding(300)
    out = pe(eg)
    print(f"The shape of the output is {out.shape}")
except Exception as e:
    print(f"Error! {e}")

Testing to ensure positional encoder works

The shape of the output is torch.Size([64, 128, 300])
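Beyond the shape, a couple of properties follow directly from the formulas and make a quick numerical check. The sketch below is a standalone helper (`sinusoidal_pe` is a hypothetical name, and it assumes an even `d_model`). It computes the frequencies in log space, as the PyTorch tutorial credited above does, which is mathematically the same as dividing by $10000^{2i/d_{\text{model}}}$. At position 0 every sine is 0 and every cosine is 1. Because each (sin, cos) pair has unit length, every row has norm $\sqrt{d_{\text{model}}/2}$.

```python
import math

import torch


def sinusoidal_pe(max_len: int, d_model: int) -> torch.Tensor:
    # PE[pos, 2i] = sin(pos / 10000^(2i/d)), PE[pos, 2i+1] = cos(same angle); assumes even d_model
    position = torch.arange(max_len, dtype=torch.float32).unsqueeze(1)
    # exp(-2i * ln(10000) / d) == 1 / 10000^(2i/d), computed in log space
    div_term = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float32) * (-math.log(10000.0) / d_model))
    pe = torch.zeros(max_len, d_model)
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    return pe


pe_table = sinusoidal_pe(128, 300)
print(pe_table[0, :4])          # tensor([0., 1., 0., 1.])
print(pe_table.norm(dim=1)[:3])  # each close to sqrt(150), about 12.247
```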


#### Ok! 
#### Seems like our positional encoder is working properly.
#### We can now proceed to use the transformer modules from PyTorch.

In [10]:
vocab_size = len(vocab)
d_model = 300 
n_head = 3
n_layers = 7
d_hidden = 2000
dropout = 0.4

def generate_square_subsequent_mask(sz: int) -> torch.Tensor:
    """Generates an upper-triangular matrix of ``-inf``, with zeros on ``diag``."""
    return torch.triu(torch.ones(sz, sz) * float('-inf'), diagonal=1).to(device)

class Transformer(nn.Module):

    def __init__(self, vocab_size:int, d_model:int, n_head:int, n_layers:int, d_hidden:int, dropout:float = 0.1):
        super().__init__()
        self.positional_encoder = PositionalEncoding(d_model, max_len = 200) 
        self.embedding = nn.Embedding(vocab_size, d_model)

        # Encoder 
        encoder_layers = nn.TransformerEncoderLayer(d_model, n_head, d_hidden, dropout, batch_first=True)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layers, n_layers)

        # Decoder
        decoder_layers = nn.TransformerDecoderLayer(d_model, n_head, d_hidden, dropout, batch_first=True)
        self.transformer_decoder = nn.TransformerDecoder(decoder_layers, n_layers)

        self.out = nn.Linear(d_model, vocab_size)

        self.init_weights()

    def init_weights(self) -> None:
        initrange = 0.1
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.out.bias.data.zero_()
        self.out.weight.data.uniform_(-initrange, initrange)

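    def forward(self, x, target, mask):
        """Returns (logits, loss) for input tokens x and target tokens, both of shape [B, T]."""
        # Encoder: embed the input tokens, add positional encodings, run the encoder stack (no mask)
        x = self.positional_encoder(self.embedding(x))
        memory = self.transformer_encoder(x)
        # Decoder: the embedded target tokens are the decoder input; `mask` is the causal tgt_mask
        tgt = self.positional_encoder(self.embedding(target))
        x = self.transformer_decoder(tgt, memory, tgt_mask=mask)
        x = self.out(x) # [B, T, vocab_size]

        logits = x.clone().detach()

        # Calculating the cross-entropy loss over every position in the batch
        B, T, C = x.shape
        loss = F.cross_entropy(x.reshape(B*T, C), target.reshape(B*T))
        return logits, loss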

mask = generate_square_subsequent_mask(block_size)
model = Transformer(vocab_size, d_model, n_head, n_layers, d_hidden, dropout).to(device)
summary(model)


Layer (type:depth-idx)                                            Param #
Transformer                                                       --
├─PositionalEncoding: 1-1                                         --
├─Embedding: 1-2                                                  8,634,900
├─TransformerEncoder: 1-3                                         --
│    └─ModuleList: 2-1                                            --
│    │    └─TransformerEncoderLayer: 3-1                          1,564,700
│    │    └─TransformerEncoderLayer: 3-2                          1,564,700
│    │    └─TransformerEncoderLayer: 3-3                          1,564,700
│    │    └─TransformerEncoderLayer: 3-4                          1,564,700
│    │    └─TransformerEncoderLayer: 3-5                          1,564,700
│    │    └─TransformerEncoderLayer: 3-6                          1,564,700
│    │    └─TransformerEncoderLayer: 3-7                          1,564,700
├─TransformerDecoder: 1-4                 
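The summary output is cut off after the decoder line, but the per-layer counts can be reproduced by hand from the layer shapes. The sketch below assumes PyTorch's default layer layout: a packed Q/K/V input projection, biases on every linear layer, two LayerNorms per encoder layer and three per decoder layer, and no final norm on either stack. The helper functions are hypothetical, and the hyperparameters are copied from the cells above. The result is consistent with the figure of around 40 million parameters quoted at the end of the notebook.

```python
import torch.nn as nn

vocab_size, d_model, n_head, n_layers, d_hidden = 28783, 300, 3, 7, 2000


def mha_params(d: int) -> int:
    # packed Q/K/V projection (3*d*d weights + 3*d biases) + output projection (d*d + d)
    return 3 * d * d + 3 * d + d * d + d


def ffn_params(d: int, h: int) -> int:
    # linear1: d -> h and linear2: h -> d, both with biases
    return d * h + h + h * d + d


def norm_params(d: int) -> int:
    return 2 * d  # LayerNorm weight and bias


enc_layer = mha_params(d_model) + ffn_params(d_model, d_hidden) + 2 * norm_params(d_model)
# Decoder layers add a cross-attention block and a third LayerNorm
dec_layer = 2 * mha_params(d_model) + ffn_params(d_model, d_hidden) + 3 * norm_params(d_model)
embedding = vocab_size * d_model
output = d_model * vocab_size + vocab_size

print(enc_layer)  # 1564700, matching each TransformerEncoderLayer in the summary
print(dec_layer)  # 1926500
total = embedding + n_layers * (enc_layer + dec_layer) + output
print(total)      # 41736983, i.e. roughly 42 million

# Cross-check against PyTorch's own layers
enc = nn.TransformerEncoderLayer(d_model, n_head, d_hidden, batch_first=True)
dec = nn.TransformerDecoderLayer(d_model, n_head, d_hidden, batch_first=True)
print(sum(p.numel() for p in enc.parameters()) == enc_layer,
      sum(p.numel() for p in dec.parameters()) == dec_layer)
```

About 17 million of those parameters sit in the input embedding and the output projection alone, because both scale with the 28,783-entry vocabulary.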

#### Before we write any training loop, it's good practice to check that our model can take in an input batch and produce an output without errors.
#### Let's try out our model with the first batch of our data.

In [11]:
X, y = get_batch('train', 0)
try:
    out, loss = model(X, y, mask)
    print("Model working successfully")
except Exception as e:
    print(f"Error occurred: {e}")

Model working successfully
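The forward pass runs, but it is worth checking what the mask actually allows. `generate_square_subsequent_mask` puts zeros on the diagonal, so each position may attend to itself. In the model above, the decoder's input is the target `y` itself, so decoder position t can see `y[t]`, the very token it is asked to predict. The encoder memory is also computed without a mask and attended to without a `memory_mask`, so it exposes `x[t+1]`, which equals `y[t]`. For next-token language modelling, a common alternative (and the setup in the PyTorch tutorial credited above) is a single self-attention stack with a causal mask over the inputs only. The sketch below assumes that setup. `CausalLM`, `causal_mask` and the toy sizes are hypothetical, and learned positional embeddings are swapped in for the sinusoidal encoder purely to keep it short.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def causal_mask(sz: int) -> torch.Tensor:
    # -inf above the diagonal, 0 on and below it: position t may attend to positions <= t
    return torch.triu(torch.full((sz, sz), float('-inf')), diagonal=1)


print(causal_mask(4))  # zeros on the diagonal: each position can see itself


class CausalLM(nn.Module):
    """Hypothetical decoder-only language model: one self-attention stack with a causal mask."""

    def __init__(self, vocab_size: int, d_model: int, n_head: int, n_layers: int,
                 d_hidden: int, max_len: int = 512, dropout: float = 0.1):
        super().__init__()
        self.tok_emb = nn.Embedding(vocab_size, d_model)
        # Learned positional embeddings, swapped in for the sinusoidal encoding to keep this short
        self.pos_emb = nn.Embedding(max_len, d_model)
        layer = nn.TransformerEncoderLayer(d_model, n_head, d_hidden, dropout, batch_first=True)
        self.blocks = nn.TransformerEncoder(layer, n_layers)
        self.out = nn.Linear(d_model, vocab_size)

    def forward(self, x: torch.Tensor, targets=None):
        B, T = x.shape
        pos = torch.arange(T, device=x.device)
        h = self.tok_emb(x) + self.pos_emb(pos)               # [B, T, d_model]
        h = self.blocks(h, mask=causal_mask(T).to(x.device))  # only the inputs X are ever seen
        logits = self.out(h)                                  # [B, T, vocab_size]
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.reshape(B * T, -1), targets.reshape(B * T))
        return logits, loss


# Smoke test on random token ids (toy sizes, not the notebook's hyperparameters)
torch.manual_seed(0)
toy_model = CausalLM(vocab_size=50, d_model=32, n_head=4, n_layers=2, d_hidden=64)
toy_X = torch.randint(0, 50, (2, 10))
toy_y = torch.randint(0, 50, (2, 10))
toy_logits, toy_loss = toy_model(toy_X, toy_y)
print(toy_logits.shape, toy_loss.item())

# Causality check: changing the last input token must not change earlier predictions
toy_model.eval()
with torch.no_grad():
    before, _ = toy_model(toy_X)
    changed = toy_X.clone()
    changed[:, -1] = (changed[:, -1] + 1) % 50
    after, _ = toy_model(changed)
print(torch.allclose(before[:, :-1], after[:, :-1], atol=1e-5))  # True
```

Because the only input is `X` and the mask hides later positions, the prediction for position t depends only on `x[0..t]`. The same batches from `get_batch` can be used unchanged, with `y` serving purely as the loss target.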


#### Perfect!
#### Seems like everything is working properly. 
#### Let's write up our training and validation loop now and watch the training progress.

In [12]:
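# Xavier (Glorot) uniform initialization for every weight matrix (1-D biases and LayerNorm parameters are skipped).
# This overrides the uniform initialization applied in init_weights() to the embedding and output layers.
for p in model.parameters():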
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)

In [13]:
@torch.no_grad()
def validate():
    model.eval()
    validation_loss = 0 
    count = 0 
    mask = generate_square_subsequent_mask(block_size)
    for batch_idx in range(number_of_batches_val):
        X, y = get_batch('val', batch_idx)
        _, loss = model(X, y, mask)
        validation_loss += loss.item()
        count += 1
    return validation_loss / count 
    

In [None]:
import warnings
warnings.filterwarnings("ignore")

num_epochs = 100
eval_iters = 50 # Validate every 50 training batches
learning_rate = 0.01

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, eps=1e-9)
# StepLR with step_size=1 multiplies the learning rate by 0.95 on every scheduler.step() call;
# here scheduler.step() is called after each validation, not once per epoch
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.95)
train_loss = 0 # Running sum of training losses since the start of training
count = 0      # Number of training steps taken so far
best_loss = float('inf')
mask = generate_square_subsequent_mask(block_size)

for epoch in range(num_epochs):
    print(f"====Epoch {epoch}====")
    for batch_idx in range(number_of_batches_train):

        # Sample a batch of data
        X, y = get_batch('train', batch_idx)

        # Validating
        if (batch_idx+1) % eval_iters == 0:
            losses = validate()
            # The reported train loss is the average over all training steps so far
            print(f"epoch: {epoch+1}, batch_idx: {batch_idx+1}, train loss: {train_loss/count:.4f}, val loss: {losses:.4f}, lr: {scheduler.get_last_lr()[0]:.6f}")
            if losses < best_loss:
                print("Saving better model")
                torch.save(model.state_dict(), './best_params.pt') # Saving best model parameters
                best_loss = losses
            print("\n")
            scheduler.step()

        # Training (validate() switches to eval mode, so switch back before every step)
        model.train()
        logits, loss = model(X, y, mask)
        train_loss += loss.item()
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5) # Clip the total gradient norm at 0.5
        optimizer.step()
        count += 1

====Epoch 0====
epoch: 1, batch_idx: 50, train loss: 7.6192, val loss: 8.0036, lr: 0.010000
Saving better model


epoch: 1, batch_idx: 100, train loss: 7.3791, val loss: 8.1459, lr: 0.009500


epoch: 1, batch_idx: 150, train loss: 7.2881, val loss: 7.8042, lr: 0.009025
Saving better model


====Epoch 1====
epoch: 2, batch_idx: 50, train loss: 7.1818, val loss: 7.9196, lr: 0.008574


epoch: 2, batch_idx: 100, train loss: 7.1396, val loss: 7.4546, lr: 0.008145
Saving better model


epoch: 2, batch_idx: 150, train loss: 7.1099, val loss: 7.4691, lr: 0.007738


====Epoch 2====
epoch: 3, batch_idx: 50, train loss: 7.0723, val loss: 7.5743, lr: 0.007351


epoch: 3, batch_idx: 100, train loss: 7.0599, val loss: 7.4409, lr: 0.006983
Saving better model


epoch: 3, batch_idx: 150, train loss: 7.0495, val loss: 7.3011, lr: 0.006634
Saving better model


====Epoch 3====
epoch: 4, batch_idx: 50, train loss: 7.0304, val loss: 7.3483, lr: 0.006302


epoch: 4, batch_idx: 100, train loss: 7.0235, val 
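To put these numbers in perspective: the loss here is PyTorch's mean cross-entropy, measured in nats, so exponentiating it gives the perplexity. Perplexity is roughly the number of equally likely tokens the model is choosing between at each step. The sketch below converts a few validation losses from the log above, plus the roughly 7.1 plateau mentioned below, and compares them with a uniform guess over the 28,783-entry vocabulary. The helper name `perplexity` is hypothetical.

```python
import math


def perplexity(loss: float) -> float:
    # Mean cross-entropy in nats -> perplexity
    return math.exp(loss)


# A few validation losses from the training log, plus the ~7.1 plateau
for val_loss in (8.0036, 7.3011, 7.1):
    print(f"val loss {val_loss:.4f} -> perplexity {perplexity(val_loss):.0f}")

# Baseline: guessing uniformly over all 28,783 vocabulary entries gives loss ln(28783)
uniform_loss = math.log(28783)
print(f"uniform-guess loss: {uniform_loss:.2f}")  # 10.27
```

A validation loss around 7.1 to 7.3 therefore corresponds to a perplexity of roughly 1,200 to 1,500: clearly better than a uniform guess, but still far from a fluent language model.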

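#### The losses do decrease, but only up to a point: the validation loss levels off at around 7.1. 
#### One possible reason is the model's limited capacity of around 40 million parameters, although the training setup may also play a role: Adam with a learning rate of 0.01 and no warmup is considerably more aggressive than the warmup schedule used in the paper. 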
#### However, I hope you got to see how we can implement transformers using PyTorch and its accompanying nn.Transformer modules. 
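For comparison, the paper trains with Adam ($\beta_1 = 0.9$, $\beta_2 = 0.98$, $\epsilon = 10^{-9}$) and a learning rate that rises linearly during a warmup phase and then decays with the inverse square root of the step number: $\text{lrate} = d_{\text{model}}^{-0.5} \cdot \min(\text{step}^{-0.5}, \text{step} \cdot \text{warmup\_steps}^{-1.5})$, with $\text{warmup\_steps} = 4000$. The notebook above was not trained this way. The sketch below is a minimal version of that schedule using `torch.optim.lr_scheduler.LambdaLR`, with a base learning rate of 1.0 so the lambda returns the learning rate directly. The helper `noam_lambda` and the stand-in `nn.Linear` model are hypothetical placeholders.

```python
import torch
import torch.nn as nn


def noam_lambda(d_model: int, warmup_steps: int = 4000):
    # lrate = d_model^-0.5 * min(step^-0.5, step * warmup_steps^-1.5)
    def fn(step: int) -> float:
        step = max(step, 1)  # LambdaLR first calls this with step 0; avoid 0 ** -0.5
        return d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)
    return fn


params = nn.Linear(300, 300).parameters()  # stand-in for model.parameters()
opt = torch.optim.Adam(params, lr=1.0, betas=(0.9, 0.98), eps=1e-9)
sched = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=noam_lambda(d_model=300))

lrs = []
for step in range(8000):
    opt.step()    # a real loop would compute the loss and call backward() first
    sched.step()
    lrs.append(sched.get_last_lr()[0])

print(f"peak lr ~ {max(lrs):.2e} at step {lrs.index(max(lrs)) + 1}")  # peak lr ~ 9.13e-04 at step 4000
```

With $d_{\text{model}} = 300$ the peak learning rate is $(300 \cdot 4000)^{-0.5} \approx 9.1 \times 10^{-4}$, roughly ten times smaller than the constant 0.01 used above. In a real training loop, `sched.step()` would be called once per optimizer step.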