<p style="text-align: center; font-size:50px;">About Transformer</p>

### Credits
* https://pytorch.org/tutorials/beginner/transformer_tutorial.html
* https://blog.floydhub.com/the-transformer-in-pytorch/

![img](https://cdn-images-1.medium.com/max/800/1*2vyKzFlzIHfSmOU_lnQE4A.png)

#### This is what the famous Transformer model looks like.
#### Much of the recent progress in Artificial Intelligence and NLP builds on this architecture.
#### That fact alone warrants a good look into the inner workings of a transformer.

<p style="text-align: center; font-size:30px;">Embedding</p>

#### For anyone who has dipped their toes into the world of NLP, embedding is not a foreign term.
#### It is a long way from the simple one-hot encoding of words.
#### An embedding layer represents each token as a dense vector with a fixed number of dimensions.
#### The values of these multidimensional vectors are weights that are learned through gradient descent during training.
#### With PyTorch, it is relatively simple to use the [Embedding layer](https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html), which we can call as follows.

In [1]:
import torch 
import torch.nn as nn 

vocab_size = 10
# 300 is a common embedding dimension in papers, but the right value depends on the experiment
embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=300)

# Named token_ids rather than input to avoid shadowing the Python builtin
token_ids = torch.tensor([[1, 3, 5, 6, 7, 8],
                          [4, 5, 6, 2, 6, 3]])

out = embedding(token_ids)

print("====Output====\n")
print(out)
print("")
print("====Shape of Output====\n")
print(f"{out.shape} -> [batch_size, sequence_length, embedding_dimensions]")

====Output====

tensor([[[ 1.4846, -1.1310, -0.1385,  ..., -2.2551,  0.2152,  1.0611],
         [-1.4802,  0.4204,  0.8213,  ..., -1.0180,  2.0208,  0.6539],
         [ 0.9492, -0.4499,  1.1388,  ...,  1.1301,  0.8191, -2.3527],
         [-0.7887, -1.1344, -1.1274,  ..., -1.2663,  0.1031, -0.0537],
         [ 0.1318, -0.8184,  0.5276,  ..., -0.4793, -0.1821,  0.0338],
         [-0.7172, -1.2700, -0.3463,  ...,  0.7418, -0.6525, -1.4330]],

        [[ 1.0981, -1.1134,  0.0327,  ..., -1.1440,  0.0639,  0.0747],
         [ 0.9492, -0.4499,  1.1388,  ...,  1.1301,  0.8191, -2.3527],
         [-0.7887, -1.1344, -1.1274,  ..., -1.2663,  0.1031, -0.0537],
         [ 0.9758, -0.4939,  1.5580,  ..., -0.5166,  1.2630, -0.9081],
         [-0.7887, -1.1344, -1.1274,  ..., -1.2663,  0.1031, -0.0537],
         [-1.4802,  0.4204,  0.8213,  ..., -1.0180,  2.0208,  0.6539]]],
       grad_fn=<EmbeddingBackward0>)

====Shape of Output====

torch.Size([2, 6, 300]) -> [batch_size, sequence_length, embedding_dimensions]
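
#### To connect this back to one-hot encoding, here is a minimal sketch showing that an embedding lookup gives the same result as multiplying a one-hot vector by the embedding weight matrix.
#### The small sizes and the variable names are illustrative assumptions, not values used elsewhere in this notebook.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
vocab_size, embedding_dim = 10, 4  # small illustrative sizes
embedding = nn.Embedding(vocab_size, embedding_dim)

token_ids = torch.tensor([1, 3, 5])

# Route 1: direct row lookup in the weight matrix
looked_up = embedding(token_ids)

# Route 2: one-hot encode, then multiply by the same weight matrix
one_hot = F.one_hot(token_ids, num_classes=vocab_size).float()
multiplied = one_hot @ embedding.weight

print(looked_up.shape)                        # torch.Size([3, 4])
print(torch.allclose(looked_up, multiplied))  # True
```

#### The lookup is simply the cheaper way to do the same thing, since it skips the large and mostly-zero one-hot matrix.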

#### Apart from the token embedding, there is also the positional encoding that needs to be taken into account.
#### Self-attention has no built-in notion of word order, so positional encoding is what tells the model where each token sits in the sequence.
#### It is represented by fixed, position-specific values that are not trained.
#### The positional encoding must have the same dimension as the input embeddings so that the two matrices can be summed.
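
#### For reference, the sinusoidal scheme from the original Transformer paper can be written as follows, where `pos` is the token position, `i` indexes a sine/cosine pair, and `d_model` is the embedding dimension.
#### The `PositionalEncoding` class in this notebook appears to implement the same formula.

```latex
\begin{aligned}
PE_{(pos,\,2i)}   &= \sin\!\left(\frac{pos}{10000^{\,2i/d_{\text{model}}}}\right) \\
PE_{(pos,\,2i+1)} &= \cos\!\left(\frac{pos}{10000^{\,2i/d_{\text{model}}}}\right) \\
\frac{1}{10000^{\,2i/d_{\text{model}}}} &= \exp\!\left(-\frac{2i\,\ln 10000}{d_{\text{model}}}\right)
\end{aligned}
```

#### The last line is the scaling factor that the code computes as `div_term`, with `torch.arange(0, d_model, 2)` supplying the values of `2i`.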

In [2]:
import math

class PositionalEncoding(nn.Module):

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Arguments:
            x: Tensor, shape ``[seq_len, batch_size, embedding_dim]``
        """
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)

#### This adds positional information to our embedding vectors.
#### Each token's vector then carries both its meaning and its position in the sequence.
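
#### As a quick sanity check, the sketch below rebuilds the same sinusoidal table as a plain 2-D tensor and inspects it.
#### The helper name `sinusoidal_table` and the small sizes are hypothetical, chosen only for illustration.

```python
import math
import torch

def sinusoidal_table(max_len: int, d_model: int) -> torch.Tensor:
    """Return a [max_len, d_model] table of sinusoidal position values (d_model must be even)."""
    position = torch.arange(max_len).unsqueeze(1)  # [max_len, 1]
    div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
    table = torch.zeros(max_len, d_model)
    table[:, 0::2] = torch.sin(position * div_term)  # even columns
    table[:, 1::2] = torch.cos(position * div_term)  # odd columns
    return table

table = sinusoidal_table(max_len=50, d_model=8)
print(table.shape)  # torch.Size([50, 8])
print(table[0])     # position 0: sin(0) = 0 in even columns, cos(0) = 1 in odd columns

# Count distinct rows: if every position has its own vector, order can be recovered
distinct_rows = torch.unique(table, dim=0).shape[0]
print(distinct_rows == table.shape[0])
```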

<p style="text-align: center; font-size:30px;">Creating our Masks</p>

#### Masking in transformers is important for two different reasons.
#### It zeroes out attention weights on padding tokens in the input sentence.
#### Additionally, in the decoder, it prevents the model from 'peeking ahead' at the rest of the output.

In [3]:
def generate_square_subsequent_mask(sz: int) -> torch.Tensor:
    return torch.triu(torch.ones(sz, sz) * float('-inf'), diagonal=1)

In [4]:
generate_square_subsequent_mask(5)

tensor([[0., -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf],
        [0., 0., 0., 0., -inf],
        [0., 0., 0., 0., 0.]])
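
#### To see why `-inf` does the job, here is a minimal sketch of both kinds of mask applied to dummy attention scores before the softmax.
#### The all-zero scores, the `pad_id` value and the variable names are assumptions made for illustration.

```python
import torch

seq_len = 4
scores = torch.zeros(seq_len, seq_len)  # pretend every query-key score is equal

# Causal mask: -inf above the diagonal, 0 elsewhere
causal_mask = torch.triu(torch.full((seq_len, seq_len), float('-inf')), diagonal=1)

# Adding -inf before the softmax drives the corresponding weight to exactly 0
weights = torch.softmax(scores + causal_mask, dim=-1)
print(weights)
# Row i spreads its weight evenly over positions 0..i and gives 0 to later positions.

# Padding mask: True marks padded positions that should be ignored as keys
pad_id = 0  # hypothetical padding index
tokens = torch.tensor([[5, 7, 2, pad_id]])          # [batch=1, seq_len]
key_padding_mask = tokens == pad_id                  # [1, seq_len]
padded_scores = scores.unsqueeze(0).masked_fill(key_padding_mask.unsqueeze(1), float('-inf'))
padded_weights = torch.softmax(padded_scores, dim=-1)
print(padded_weights[0, 0])  # the last (padded) key receives weight 0
```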

<p style="text-align: center; font-size:30px;">Multi Headed Attention</p>

#### With our embeddings and masks, we can build the core of our model, the multi-head attention module.
#### PyTorch ships this as a ready-made [module](https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html), so using it is once again very simple.
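
#### A minimal sketch of calling `nn.MultiheadAttention` for masked self-attention is shown here.
#### The tensor sizes are illustrative assumptions rather than the values used later in this notebook.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Illustrative sizes; d_model must be divisible by nhead
seq_len, batch_size, d_model, nhead = 5, 2, 16, 4

# By default the module expects inputs shaped [seq_len, batch_size, d_model]
mha = nn.MultiheadAttention(embed_dim=d_model, num_heads=nhead)
x = torch.randn(seq_len, batch_size, d_model)

# Causal mask so that position i only attends to positions <= i
attn_mask = torch.triu(torch.full((seq_len, seq_len), float('-inf')), diagonal=1)

# Self-attention: query, key and value are all the same tensor
out, attn_weights = mha(x, x, x, attn_mask=attn_mask)

print(out.shape)           # torch.Size([5, 2, 16])
print(attn_weights.shape)  # torch.Size([2, 5, 5]), averaged over the heads
```

#### The output has the same shape as the input, which is what lets these blocks be stacked one after another.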

<p style="text-align: center; font-size:30px;">Feed Forward Network</p>

#### Now this is the simplest part of them all.
#### After the attention sub-layer, each position's output is passed through a feed-forward network made of two linear layers.

In [5]:
import torch.nn.functional as F

class FeedForward(nn.Module):
    def __init__(self, d_model, d_ff=2048, dropout=0.1):
        super().__init__()
        # d_ff is the hidden width of the feed-forward network; it defaults to 2048
        self.linear_1 = nn.Linear(d_model, d_ff)
        self.dropout = nn.Dropout(dropout)
        self.linear_2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        # Expand to d_ff, apply ReLU and dropout, then project back to d_model
        x = self.dropout(F.relu(self.linear_1(x)))
        x = self.linear_2(x)
        return x

#### This is practically all we need to know for now. 
#### This is the rough overview of how a transformer is built. 
#### Now let's apply this model to some data.

<p style="text-align: center; font-size:30px;">Data</p>

#### I will be using the WikiText-2 corpus, loaded through the Hugging Face `datasets` library.
#### I will tokenize the text, build a vocabulary from it, and then arrange the token ids into batches of a fixed length.

In [6]:
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
import datasets  # Hugging Face datasets library

dataset = datasets.load_dataset('wikitext', 'wikitext-2-raw-v1')

train_iter = dataset['train']
tokenizer = get_tokenizer('basic_english')
vocab = build_vocab_from_iterator(map(tokenizer, train_iter['text']), specials=['<unk>'])
vocab.set_default_index(vocab['<unk>'])
print(f"The length of the vocab is {len(vocab)}") 

The length of the vocab is 66058


In [7]:
def data_process(raw_text_iter) -> torch.Tensor:
    """Converts raw text into a flat Tensor."""
    data = [torch.tensor(vocab(tokenizer(item['text'])), dtype=torch.long) for item in raw_text_iter]
    return torch.cat(tuple(filter(lambda t: t.numel() > 0, data)))

train_data = data_process(dataset['train'])
val_data = data_process(dataset['validation'])
test_data = data_process(dataset['test'])

In [8]:
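# Use Apple's Metal backend when it is available, otherwise fall back to the CPU
device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')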

# Dividing data into equal length of numerous batches
def batchify(data: torch.Tensor, batch_size: int) -> torch.Tensor:
    seq_len = data.size(0) // batch_size
    data = data[:seq_len * batch_size]
    data = data.view(batch_size, seq_len).t().contiguous()
    return data.to(device)

batch_size = 20
eval_batch_size = 10
train_data = batchify(train_data, batch_size)  # shape ``[seq_len, batch_size]``
val_data = batchify(val_data, eval_batch_size)
test_data = batchify(test_data, eval_batch_size)
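
#### To make the reshaping concrete, here is a small sketch that applies the same steps to a toy stream of 26 token ids.
#### The helper name `batchify_demo` is hypothetical; it mirrors `batchify` without the device transfer.

```python
import torch

def batchify_demo(data: torch.Tensor, batch_size: int) -> torch.Tensor:
    """Same reshaping as batchify, without moving the result to a device."""
    seq_len = data.size(0) // batch_size
    data = data[:seq_len * batch_size]  # drop the leftover tokens that do not fill a column
    return data.view(batch_size, seq_len).t().contiguous()

tokens = torch.arange(26)           # stand-in for a flat stream of 26 token ids
batches = batchify_demo(tokens, 4)  # 26 // 4 = 6 rows; ids 24 and 25 are dropped

print(batches.shape)  # torch.Size([6, 4])
print(batches)
# Each column is one contiguous slice of the stream:
# column 0 holds 0..5, column 1 holds 6..11, and so on.
```

#### The columns are treated as independent sequences by the model, so dependencies that cross a column boundary are lost in exchange for efficient batch processing.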

#### We also need a function that gives us our input and target tensors.
#### With language modeling, the input is a sequence of tokens and the target is the same sequence shifted by one position, so that each input token is paired with the token that follows it.
#### This trains the model to predict the next word at every position.

In [9]:
import typing

bptt = 35
def get_batch(source: torch.Tensor, i: int) -> typing.Tuple[torch.Tensor, torch.Tensor]:
    seq_len = min(bptt, len(source) - 1 - i)
    data = source[i:i+seq_len]
    target = source[i+1:i+1+seq_len].reshape(-1)
    return data, target
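
#### The shift by one position is easiest to see on a toy grid, as in the sketch below.
#### The names `bptt_demo` and `get_batch_demo` and the chunk length of 3 are hypothetical; the slicing is the same as in `get_batch`.

```python
import torch

bptt_demo = 3  # hypothetical small chunk length for illustration

def get_batch_demo(source: torch.Tensor, i: int):
    """Same slicing as get_batch, with a small chunk length."""
    seq_len = min(bptt_demo, len(source) - 1 - i)
    data = source[i:i + seq_len]
    target = source[i + 1:i + 1 + seq_len].reshape(-1)
    return data, target

# A [seq_len=6, batch_size=2] grid: column 0 is tokens 0..5, column 1 is tokens 6..11
source = torch.arange(12).view(2, 6).t().contiguous()

data, target = get_batch_demo(source, 0)
print(data)    # tensor([[0, 6], [1, 7], [2, 8]])
print(target)  # tensor([1, 7, 2, 8, 3, 9])
```

#### Each target entry is the token that follows the matching input entry in the same column, flattened so that it lines up with the flattened model output in the loss.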

<p style="text-align: center; font-size:30px;">Training</p>

In [10]:
import torch.nn.functional as F
from torch.nn import TransformerEncoder, TransformerEncoderLayer

ntokens = len(vocab)  # size of vocabulary
emsize = 200  # embedding dimension
d_hid = 200  # dimension of the feedforward network model in ``nn.TransformerEncoder``
nlayers = 3  # number of ``nn.TransformerEncoderLayer`` in ``nn.TransformerEncoder``
nhead = 2  # number of heads in ``nn.MultiheadAttention``
dropout = 0.2  

class TransformerModel(nn.Module):

    def __init__(self, ntoken: int, d_model: int, nhead: int, d_hid: int,
                 nlayers: int, dropout: float = 0.5):
        super().__init__()
        self.model_type = 'Transformer'
        self.pos_encoder = PositionalEncoding(d_model, dropout)
        encoder_layers = TransformerEncoderLayer(d_model, nhead, d_hid, dropout)
        self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
        self.encoder = nn.Embedding(ntoken, d_model)
        self.d_model = d_model
        self.decoder = nn.Linear(d_model, ntoken)

        self.init_weights()

    def init_weights(self) -> None:
        initrange = 0.1
        self.encoder.weight.data.uniform_(-initrange, initrange)
        self.decoder.bias.data.zero_()
        self.decoder.weight.data.uniform_(-initrange, initrange)

    def forward(self, src: torch.Tensor, src_mask: torch.Tensor) -> torch.Tensor:
        src = self.encoder(src) * math.sqrt(self.d_model)
        src = self.pos_encoder(src)
        output = self.transformer_encoder(src, src_mask)
        output = self.decoder(output)
        return output

model = TransformerModel(ntokens, emsize, nhead, d_hid, nlayers, dropout).to(device)

#### The [nn.TransformerEncoderLayer](https://pytorch.org/docs/stable/generated/torch.nn.TransformerEncoderLayer.html) saves us from having to write out the multi-head attention and feed-forward blocks by hand.
#### Let's write the training loop code and train it!

In [11]:
import torchinfo 
torchinfo.summary(model)

Layer (type:depth-idx)                                            Param #
TransformerModel                                                  --
├─PositionalEncoding: 1-1                                         --
│    └─Dropout: 2-1                                               --
├─TransformerEncoder: 1-2                                         --
│    └─ModuleList: 2-2                                            --
│    │    └─TransformerEncoderLayer: 3-1                          242,000
│    │    └─TransformerEncoderLayer: 3-2                          242,000
│    │    └─TransformerEncoderLayer: 3-3                          242,000
├─Embedding: 1-3                                                  13,211,600
├─Linear: 1-4                                                     13,277,658
Total params: 27,215,258
Trainable params: 27,215,258
Non-trainable params: 0
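
#### The parameter counts in the summary can be reproduced by hand, as in the sketch below.
#### It assumes the usual layout of an encoder layer (a fused query/key/value projection, an output projection, two linear layers and two layer norms), which is consistent with the 242,000 figure reported above.

```python
vocab_size, d_model, d_hid, nlayers = 66058, 200, 200, 3

embedding = vocab_size * d_model             # one d_model-sized row per token
decoder = d_model * vocab_size + vocab_size  # linear weights plus bias

attn_in_proj = 3 * (d_model * d_model + d_model)  # query, key and value projections
attn_out_proj = d_model * d_model + d_model
feed_forward = (d_model * d_hid + d_hid) + (d_hid * d_model + d_model)
layer_norms = 2 * (2 * d_model)                   # two LayerNorms, each with weight and bias
encoder_layer = attn_in_proj + attn_out_proj + feed_forward + layer_norms

total = embedding + decoder + nlayers * encoder_layer
print(encoder_layer, embedding, decoder, total)
# 242000 13211600 13277658 27215258
```

#### Almost all of the parameters sit in the embedding table and the output layer, because both scale with the vocabulary size.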

In [12]:
import time 
from tqdm import tqdm

criterion = nn.CrossEntropyLoss()
lr = 5.0 
optimizer = torch.optim.SGD(model.parameters(), lr = lr)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.95)

def train(model: nn.Module) -> None:
    model.train()  # turn on train mode
    total_loss = 0.
    log_interval = 200
    start_time = time.time()
    src_mask = generate_square_subsequent_mask(bptt).to(device)

    num_batches = len(train_data) // bptt
    for batch, i in enumerate(tqdm(range(0, train_data.size(0) - 1, bptt))):
        data, targets = get_batch(train_data, i)
        seq_len = data.size(0)
        if seq_len != bptt:  # only on last batch
            src_mask = src_mask[:seq_len, :seq_len]
        output = model(data, src_mask)
        loss = criterion(output.view(-1, ntokens), targets)

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5) # Prevent exploding gradients 
        optimizer.step()

        total_loss += loss.item()
        if batch % log_interval == 0 and batch > 0:
            lr = scheduler.get_last_lr()[0]
            ms_per_batch = (time.time() - start_time) * 1000 / log_interval
            cur_loss = total_loss / log_interval
            ppl = math.exp(cur_loss)
            print(f'| epoch {epoch:3d} | {batch:5d}/{num_batches:5d} batches | '
                  f'lr {lr:02.2f} | ms/batch {ms_per_batch:5.2f} | '
                  f'loss {cur_loss:5.2f} | ppl {ppl:8.2f}')
            total_loss = 0
            start_time = time.time()

def evaluate(model: nn.Module, eval_data: torch.Tensor) -> float:
    model.eval()  # turn on evaluation mode
    total_loss = 0.
    src_mask = generate_square_subsequent_mask(bptt).to(device)
    with torch.no_grad():
        for i in range(0, eval_data.size(0) - 1, bptt):
            data, targets = get_batch(eval_data, i)
            seq_len = data.size(0)
            if seq_len != bptt:
                src_mask = src_mask[:seq_len, :seq_len]
            output = model(data, src_mask)
            output_flat = output.view(-1, ntokens)
            total_loss += seq_len * criterion(output_flat, targets).item()
    return total_loss / (len(eval_data) - 1)

In [48]:
best_val_loss = float('inf')
epochs = 3

for epoch in range(1, epochs + 1):
    print(f"====Epoch {epoch}====")
    epoch_start_time = time.time()
    train(model)
    val_loss = evaluate(model, val_data)
    val_ppl = math.exp(val_loss)
    elapsed = time.time() - epoch_start_time
    print('-' * 89)
    print(f'| end of epoch {epoch:3d} | time: {elapsed:5.2f}s | '
        f'valid loss {val_loss:5.2f} | valid ppl {val_ppl:8.2f}')
    print('-' * 89)

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), 'best_model_params.pt')

    scheduler.step()

====Epoch 1====
| epoch   1 |   200/ 2930 batches | lr 5.00 | ms/batch 222.81 | loss  9.20 | ppl  9916.78
| epoch   1 |   400/ 2930 batches | lr 5.00 | ms/batch 217.64 | loss  8.03 | ppl  3063.10
| epoch   1 |   600/ 2930 batches | lr 5.00 | ms/batch 222.96 | loss  7.28 | ppl  1454.47
| epoch   1 |   800/ 2930 batches | lr 5.00 | ms/batch 216.70 | loss  6.96 | ppl  1053.15
| epoch   1 |  1000/ 2930 batches | lr 5.00 | ms/batch 218.78 | loss  6.68 | ppl   796.04
| epoch   1 |  1200/ 2930 batches | lr 5.00 | ms/batch 219.07 | loss  6.64 | ppl   767.70
| epoch   1 |  1400/ 2930 batches | lr 5.00 | ms/batch 218.52 | loss  6.60 | ppl   735.40
| epoch   1 |  1600/ 2930 batches | lr 5.00 | ms/batch 217.76 | loss  6.56 | ppl   707.15
| epoch   1 |  1800/ 2930 batches | lr 5.00 | ms/batch 217.84 | loss  6.42 | ppl   611.85
| epoch   1 |  2000/ 2930 batches | lr 5.00 | ms/batch 218.01 | loss  6.42 | ppl   612.75
| epoch   1 |  2200/ 2930 batches | lr 5.00 | ms/batch 216.98 | loss  6.27 | ppl   529.17
| epoch   1 |  2400/ 2930 batches | lr 5.00 | ms/batch 217.14 | loss  6.37 | ppl   582.71
| epoch   1 |  2600/ 2930 batches | lr 5.00 | ms/batch 217.81 | loss  6.35 | ppl   572.85
| epoch   1 |  2800/ 2930 batches | lr 5.00 | ms/batch 220.02 | loss  6.26 | ppl   523.73
-----------------------------------------------------------------------------------------
| end of epoch   1 | time: 657.30s | valid loss  6.37 | valid ppl   586.21
-----------------------------------------------------------------------------------------
====Epoch 2====
| epoch   2 |   200/ 2930 batches | lr 4.75 | ms/batch 219.19 | loss  6.27 | ppl   528.51
| epoch   2 |   400/ 2930 batches | lr 4.75 | ms/batch 219.98 | loss  6.25 | ppl   518.40
| epoch   2 |   600/ 2930 batches | lr 4.75 | ms/batch 215.73 | loss  6.06 | ppl   427.67
| epoch   2 |   800/ 2930 batches | lr 4.75 | ms/batch 216.31 | loss  6.07 | ppl   434.11
| epoch   2 |  1000/ 2930 batches | lr 4.75 | ms/batch 215.39 | loss  5.94 | ppl   381.43
| epoch   2 |  1200/ 2930 batches | lr 4.75 | ms/batch 215.34 | loss  6.02 | ppl   411.76
| epoch   2 |  1400/ 2930 batches | lr 4.75 | ms/batch 215.44 | loss  6.05 | ppl   425.60
| epoch   2 |  1600/ 2930 batches | lr 4.75 | ms/batch 214.47 | loss  6.07 | ppl   432.72
| epoch   2 |  1800/ 2930 batches | lr 4.75 | ms/batch 214.31 | loss  5.95 | ppl   385.33
| epoch   2 |  2000/ 2930 batches | lr 4.75 | ms/batch 225.88 | loss  6.00 | ppl   401.48
| epoch   2 |  2200/ 2930 batches | lr 4.75 | ms/batch 230.89 | loss  5.86 | ppl   350.11
| epoch   2 |  2400/ 2930 batches | lr 4.75 | ms/batch 229.90 | loss  5.98 | ppl   393.80
| epoch   2 |  2600/ 2930 batches | lr 4.75 | ms/batch 230.71 | loss  5.98 | ppl   396.89
| epoch   2 |  2800/ 2930 batches | lr 4.75 | ms/batch 230.92 | loss  5.91 | ppl   369.43
-----------------------------------------------------------------------------------------
| end of epoch   2 | time: 664.63s | valid loss  6.23 | valid ppl   508.02
-----------------------------------------------------------------------------------------
====Epoch 3====
| epoch   3 |   200/ 2930 batches | lr 4.51 | ms/batch 231.14 | loss  5.96 | ppl   388.11
| epoch   3 |   400/ 2930 batches | lr 4.51 | ms/batch 230.38 | loss  5.98 | ppl   395.91
| epoch   3 |   600/ 2930 batches | lr 4.51 | ms/batch 230.18 | loss  5.76 | ppl   317.72
| epoch   3 |   800/ 2930 batches | lr 4.51 | ms/batch 230.01 | loss  5.81 | ppl   334.19

 28%|██▊       | 818/2931 [03:08<08:00,  4.40it/s]
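
#### Two numbers in this log are worth unpacking: the learning rate and the perplexity.
#### The sketch below reproduces the `lr` column from the `StepLR` settings and shows how perplexity relates to the loss; the uniform-guess baseline is an illustrative assumption added for scale, not something measured in this notebook.

```python
import math

lr0, gamma = 5.0, 0.95

# StepLR with step_size=1 multiplies the learning rate by gamma after every epoch
for epoch_idx in range(1, 4):
    print(f"epoch {epoch_idx}: lr {lr0 * gamma ** (epoch_idx - 1):.2f}")

# Perplexity is the exponential of the average cross-entropy loss
def perplexity(avg_loss: float) -> float:
    return math.exp(avg_loss)

# A model guessing uniformly over a 66,058-word vocabulary would have loss log(66058)
uniform_loss = math.log(66058)
print(f"uniform-guess loss {uniform_loss:.2f}, ppl {perplexity(uniform_loss):.0f}")
```

#### Against that baseline, a validation perplexity of around 500 means the model has narrowed its next-word guess considerably, even after a short training run.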

<p style="text-align: center; font-size:30px;">Evaluation</p>

#### Let's see some examples of how well our model has trained on the corpus of data.

In [13]:
model.load_state_dict(torch.load('./best_model_params.pt')) # load best model states

<All keys matched successfully>

In [41]:
def generate(prompt, max_seq_len, model, tokenizer, vocab, device, seed=None):
    if seed is not None:
        torch.manual_seed(seed)
        
    model.eval()
    tokens = tokenizer(prompt)
    indices = [vocab[t] for t in tokens]
    with torch.no_grad():
        for i in range(max_seq_len):
            src = torch.LongTensor([indices]).to(device)
            src_mask = generate_square_subsequent_mask(len(indices)).to(device)
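            logits = model(src.view(-1, 1), src_mask)  # shape [seq_len, 1, ntokens]
            probs = torch.softmax(logits[-1, 0], dim=-1)  # next-token distribution at the last position
            prediction = torch.multinomial(probs, num_samples=1).item()

            # Resample instead of emitting the unknown-word token
            while prediction == vocab['<unk>']:
                prediction = torch.multinomial(probs, num_samples=1).item()

            # The vocabulary was built without an '<eos>' token, so vocab['<eos>']
            # falls back to the '<unk>' index and this check never fires here.
            if prediction == vocab['<eos>']:
                break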

            indices.append(prediction)

    itos = vocab.get_itos()
    tokens = [itos[i] for i in indices]
    return tokens

prompt = 'Think about what our children will think about the war. The war is an uncivilized form of humanity and the action we should take is'
max_seq_len = 30
seed = 0

print(' '.join(generate(prompt, max_seq_len, model, tokenizer, vocab, device, seed=None)))

think about what our children will think about the war . the war is an <unk> form of humanity and the action we should take is just over 60 years in europe . a medium or their outs wrote that they may be one , when his sudden life though himself of bad is no attempt
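
#### The `generate` function avoids `<unk>` by resampling in a loop. An alternative, sketched below, is to zero out its probability and renormalise before sampling.
#### This is a different technique from the one used in `generate`, and the index `unk_index` and the random logits are assumptions made for illustration.

```python
import torch

torch.manual_seed(0)
unk_index = 0            # hypothetical index of '<unk>' in the vocabulary
logits = torch.randn(8)  # stand-in for the model's scores over an 8-word vocabulary

probs = torch.softmax(logits, dim=-1)
probs[unk_index] = 0.0       # rule out '<unk>' up front
probs = probs / probs.sum()  # renormalise so the remaining entries sum to 1

# Draw many samples to confirm the excluded index never appears
samples = torch.multinomial(probs, num_samples=1000, replacement=True)
counts = samples.bincount(minlength=8)
print(counts)  # the count at index 0 is 0
```

#### Both approaches sample from the same distribution over the remaining words; this one just needs a single draw per step.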


#### Considering we only trained for about 30 minutes, I believe we got somewhere.
#### Of course, Large Language Models take far more computational power and data to train.
#### Hopefully, you now have a sense of what a transformer is and how easy it is to implement one using PyTorch.