# Transformer From Scratch

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

import torch
import torch.nn as nn
import torch.optim as optim

import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

import nltk
from nltk.tokenize import word_tokenize

## Create Tokenizer & Vocabulary

In [None]:
import pickle

# Pickled dataset file, resolved relative to the current working directory.
filename = 'english-german-both.pkl'

# NOTE(review): pickle.load can execute arbitrary code while deserializing, so
# this file must come from a trusted source.
# `data` is indexed below as data[row, column] with column 0 printed as English
# and column 1 as German.
with open(filename, 'rb') as file:
    data = pickle.load(file)

In [None]:
# Sanity check: show both columns of the first row of the loaded data.
print(f"English: '{data[0,0]}'")
print(f"German: '{data[0, 1]}'")

In [None]:
from tqdm import tqdm
import nltk
from nltk.tokenize import word_tokenize

# Collect every distinct token of each language. The sentence-boundary markers
# are seeded up front so they always receive an id.
en_vocab_set = {"<sos>", "<eos>"}
ge_vocab_set = {"<sos>", "<eos>"}

for pair in data:
    en_vocab_set.update(word_tokenize(pair[0]))
    ge_vocab_set.update(word_tokenize(pair[1]))


def _index_tokens(tokens):
    """Return ``(token_to_id, id_to_token)`` for the given token collection.

    Id 0 is reserved for "<pad>". The remaining tokens are numbered from 1 in
    sorted order: iterating a set of strings directly yields an order that can
    change between interpreter runs (string hash randomisation), which would
    make the ids incompatible with any previously saved model or data.
    """
    token_to_id = {"<pad>": 0}
    id_to_token = {0: "<pad>"}
    for token_id, token in enumerate(sorted(tokens), start=1):
        token_to_id[token] = token_id
        id_to_token[token_id] = token
    return token_to_id, id_to_token


# Token <-> integer mappings for both languages.
en_vocab, en_vocab_reversed = _index_tokens(en_vocab_set)
ge_vocab, ge_vocab_reversed = _index_tokens(ge_vocab_set)

In [None]:
def _encode(sentence, vocab):
    """Tokenize ``sentence``, wrap it in <sos>/<eos> and map tokens to ids."""
    tokens = ["<sos>", *word_tokenize(sentence), "<eos>"]
    return torch.tensor([vocab[tok] for tok in tokens])


# One 1-D id tensor per sentence, in the same order as the rows of `data`.
processed_data_en = [_encode(ex[0], en_vocab) for ex in data]
processed_data_ge = [_encode(ex[1], ge_vocab) for ex in data]

# Longest encoded sentence over both languages (0 when there is no data).
max_tok_len = max(map(len, processed_data_en + processed_data_ge), default=0)

In [None]:
# Right-pad every id tensor with zeros (the "<pad>" id) up to max_tok_len so
# that all sentences share one length.
padded_data_en = [
    F.pad(seq, (0, max_tok_len - len(seq))) for seq in processed_data_en
]
padded_data_ge = [
    F.pad(seq, (0, max_tok_len - len(seq))) for seq in processed_data_ge
]

In [None]:
# Index of the example to inspect; must be a valid position in the padded lists.
idx = 745

# Print the padded id tensor and the tokens it decodes to (trailing 0s show up
# as "<pad>").
print('--- English ---')
print(padded_data_en[idx])
print([en_vocab_reversed[val.item()] for val in padded_data_en[idx]])

print()
print('--- German ---')

print(padded_data_ge[idx])
print([ge_vocab_reversed[val.item()] for val in padded_data_ge[idx]])

## Transformer Modules

### Input & Output Embeddings

In [None]:
class InputEmbedding(nn.Module):
    """Learned lookup table turning token ids into ``dmodel``-wide vectors."""

    def __init__(self, vocab_size, dmodel):
        """Create the table.

        Args:
            vocab_size: Number of rows, i.e. one more than the largest valid id.
            dmodel: Width of each embedding vector.
        """
        super().__init__()
        self.embedder = nn.Embedding(num_embeddings=vocab_size, embedding_dim=dmodel)

    def forward(self, x):
        """Return the embedding row for every integer id in ``x``."""
        embedded = self.embedder(x)
        return embedded


class OutputEmbedding(nn.Module):
    """Learned lookup table turning token ids into ``dmodel``-wide vectors.

    Structurally the same as the input embedding, but holds its own weights.
    """

    def __init__(self, vocab_size, dmodel):
        """Create the table.

        Args:
            vocab_size: Number of rows, i.e. one more than the largest valid id.
            dmodel: Width of each embedding vector.
        """
        super().__init__()
        self.embedder = nn.Embedding(num_embeddings=vocab_size, embedding_dim=dmodel)

    def forward(self, x):
        """Return the embedding row for every integer id in ``x``."""
        embedded = self.embedder(x)
        return embedded

### Scaled Dot-Product Attention

In [None]:
class ScaledDotAttention(nn.Module):
    """One attention head: linear Q/K/V projections, then
    ``softmax(Q K^T / sqrt(dk)) V`` with masked positions suppressed.
    """

    def __init__(self, dmodel, dk):
        """Create the three projections.

        Args:
            dmodel: Width of the incoming feature vectors.
            dk: Width of the projected queries, keys and values.
        """
        super().__init__()
        self.Wq = nn.Linear(dmodel, dk)
        self.Wk = nn.Linear(dmodel, dk)
        self.Wv = nn.Linear(dmodel, dk)

        self.dk = dk

    def apply_masks(self, query_key_matrix, padding_mask, causal_mask):
        """Overwrite masked scores with -1e9.

        A position is masked where the corresponding mask entry equals 1. The
        padding mask is always applied; the causal mask only when it is not
        ``None``.
        """
        scores = torch.where(
            padding_mask == 1,
            torch.full_like(query_key_matrix, -1e9),
            query_key_matrix,
        )
        if causal_mask is None:
            return scores

        return torch.where(
            causal_mask == 1,
            torch.full_like(scores, -1e9),
            scores,
        )

    def forward(self, Qx, Kx, Vx, padding_mask, causal_mask=None):
        """Attend from ``Qx`` over ``Kx``/``Vx``.

        The inputs are transposed on dims 1 and 2, so they are expected to be
        3-D with the sequence on dim 1 and features on dim 2.
        """
        queries, keys, values = self.Wq(Qx), self.Wk(Kx), self.Wv(Vx)

        # Scale by sqrt(dk) before masking and normalising over the key axis.
        scores = torch.matmul(queries, keys.transpose(1, 2)) / np.sqrt(self.dk)
        scores = self.apply_masks(scores, padding_mask, causal_mask)
        weights = F.softmax(scores, dim=-1)

        return torch.matmul(weights, values)

### Multi Headed Attention

In [None]:
class MultiHeadedAttention(nn.Module):
    """Run several independent attention heads and merge their outputs.

    Each head is a ``ScaledDotAttention`` producing ``dk`` features; the heads'
    outputs are concatenated on dim 2 and projected back to ``dmodel``.
    """

    def __init__(self, num_heads, dmodel, dk):
        """Create ``num_heads`` heads plus the output projection ``Wo``."""
        super().__init__()
        heads = [ScaledDotAttention(dmodel, dk) for _ in range(num_heads)]
        self.attention_heads = nn.ModuleList(heads)

        self.num_heads = num_heads
        self.Wo = nn.Linear(self.num_heads * dk, dmodel)

    def forward(self, Qx, Kx, Vx, padding_mask, causal_mask=None):
        """Apply every head to the same inputs and masks, then mix with ``Wo``."""
        per_head = []
        for head in self.attention_heads:
            per_head.append(head(Qx, Kx, Vx, padding_mask, causal_mask))

        merged = torch.cat(per_head, dim=2)
        return self.Wo(merged)

### Feed Forward

In [None]:
class FeedForward(nn.Module):
    """Two-layer MLP applied to the last dimension: dmodel -> dff -> dmodel."""

    def __init__(self, dmodel, dff):
        """Create the expanding (``inner``) and contracting (``outer``) layers."""
        super().__init__()
        self.inner_linear = nn.Linear(dmodel, dff)
        self.outer_linear = nn.Linear(dff, dmodel)

    def forward(self, x):
        """Return ``outer_linear(relu(inner_linear(x)))``."""
        hidden = self.inner_linear(x)
        activated = F.relu(hidden)
        return self.outer_linear(activated)

### Add & Layer Norm

In [None]:
class AddAndNorm(nn.Module):
    """Residual connection followed by layer normalisation over ``dmodel``."""

    def __init__(self, dmodel):
        """Create the ``nn.LayerNorm`` that normalises the last dimension."""
        super().__init__()
        self.layer_norm = nn.LayerNorm(dmodel)

    def forward(self, x, skipped_x):
        """Return ``LayerNorm(x + skipped_x)``."""
        residual_sum = x + skipped_x
        return self.layer_norm(residual_sum)

### Positional Encoding

In [None]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position signals to a batch of embeddings.

    For position ``pos`` and feature index ``j`` the signal is
    ``sin(pos / 10000 ** (2k / dmodel))`` when ``j = 2k`` and
    ``cos(pos / 10000 ** (2k / dmodel))`` when ``j = 2k + 1``, so each
    sine/cosine pair shares one frequency. The module holds no parameters or
    buffers; the table is rebuilt from the input's shape on every call.
    """

    def __init__(self, dmodel):
        """Accept ``dmodel`` for a uniform constructor; the feature width used
        in ``forward`` is read from the input tensor itself.
        """
        super().__init__()

    def forward(self, x):
        """Return ``x`` plus the positional signal.

        Args:
            x: Tensor of shape (batch, num_tok, dmodel).

        Returns:
            Tensor of the same shape as ``x``; the same signal is added to
            every batch row.
        """
        _, tok_len, dmodel = x.shape  # (B x num_tok x dmodel)

        positions = torch.arange(
            tok_len, device=x.device, dtype=torch.float32
        ).unsqueeze(1)                                   # (num_tok, 1)
        feature_idx = torch.arange(dmodel, device=x.device)  # (dmodel,)
        is_even = feature_idx % 2 == 0

        # Features 2k and 2k+1 must share the exponent 2k / dmodel. Using the
        # raw feature index here would give the sine and cosine of a pair
        # different frequencies and push the exponent past 1.
        pair_exponent = (feature_idx - feature_idx % 2).to(torch.float32) / dmodel
        angles = positions / torch.pow(10000.0, pair_exponent)  # (num_tok, dmodel)

        encodings = torch.where(is_even, torch.sin(angles), torch.cos(angles))

        # Broadcasts over the batch dimension.
        return x + encodings.to(x.dtype)

### Causal Mask

In [None]:
def create_causal_mask(batch_size, max_tok_len):
    """Build a (batch_size, max_tok_len, max_tok_len) float mask.

    Entry [b, i, j] is 1.0 where j > i (strictly above the diagonal) and 0.0
    elsewhere; every batch row holds the same matrix.
    """
    all_ones = torch.ones((max_tok_len, max_tok_len))
    strictly_upper = torch.triu(all_ones, diagonal=1)
    return strictly_upper.unsqueeze(0).expand(batch_size, -1, -1).contiguous()

### Padding Mask

In [None]:
def create_padding_mask(x, max_tok_len):
    """Build a (batch, max_tok_len, max_tok_len) float mask from token ids.

    A token counts as padding when its id is 0. Entry [b, i, j] is 1.0 when
    position i or position j of sequence b is padding, else 0.0.

    Args:
        x: Integer id tensor of shape (batch, max_tok_len).
        max_tok_len: Sequence length used for both mask dimensions.
    """
    num_sequences = x.shape[0]

    is_pad = x == 0
    # Row view: every row i repeats the per-position padding flags (index j).
    pad_columns = is_pad.unsqueeze(1).expand(num_sequences, max_tok_len, max_tok_len)
    # Transposed view: flags indexed by i instead.
    pad_rows = pad_columns.transpose(1, 2)

    return torch.logical_or(pad_columns, pad_rows).float()

## Building Encoder

In [None]:
class Encoder(nn.Module):
    """One encoder block: self-attention and a feed-forward network, each
    wrapped in a residual add followed by layer norm.
    """

    def __init__(self, num_heads, dmodel, dk, dff):
        """Create the attention, the two add-and-norm layers and the MLP."""
        super().__init__()
        self.multi_headed_attention = MultiHeadedAttention(num_heads, dmodel, dk)
        self.add_norm_1 = AddAndNorm(dmodel)
        self.add_norm_2 = AddAndNorm(dmodel)
        self.feed_forward = FeedForward(dmodel, dff)

    def forward(self, x, padding_mask):
        """Transform ``x`` and return a tensor of the same shape.

        ``x`` serves as query, key and value of the attention; no causal mask
        is used.
        """
        attended = self.multi_headed_attention(x, x, x, padding_mask)
        attended = self.add_norm_1(attended, x)

        transformed = self.feed_forward(attended)
        return self.add_norm_2(transformed, attended)

## Building Decoder

In [None]:
class Decoder(nn.Module):
    """One decoder block: masked self-attention, attention over the encoder
    output, and a feed-forward network, each followed by add-and-norm.

    NOTE(review): the same ``padding_mask`` is passed to both attention layers;
    confirm against the caller that it describes the right sequences for each.
    """

    def __init__(self, num_heads, dmodel, dk, dff):
        """Create both attention layers, three add-and-norm layers and the MLP."""
        super().__init__()
        self.masked_multi_headed_attention = MultiHeadedAttention(num_heads, dmodel, dk)
        self.multi_headed_attention = MultiHeadedAttention(num_heads, dmodel, dk)
        self.add_norm_1 = AddAndNorm(dmodel)
        self.add_norm_2 = AddAndNorm(dmodel)
        self.add_norm_3 = AddAndNorm(dmodel)
        self.feed_forward = FeedForward(dmodel, dff)

    def forward(self, dec_input, enc_out, padding_mask, causal_mask):
        """Run the block and return a tensor shaped like ``dec_input``.

        Args:
            dec_input: Decoder-side features; query, key and value of the
                first attention.
            enc_out: Encoder output; key and value of the second attention.
            padding_mask: Mask given to both attention layers.
            causal_mask: Mask given only to the first attention layer.
        """
        self_attended = self.masked_multi_headed_attention(
            dec_input, dec_input, dec_input, padding_mask, causal_mask
        )
        self_attended = self.add_norm_1(self_attended, dec_input)

        cross_attended = self.multi_headed_attention(
            self_attended, enc_out, enc_out, padding_mask
        )
        cross_attended = self.add_norm_2(cross_attended, self_attended)

        transformed = self.feed_forward(cross_attended)
        return self.add_norm_3(transformed, cross_attended)

## Building Transformer

In [None]:
class Transformer(nn.Module):
    """Encoder-decoder model assembled from the blocks defined in this file.

    Token ids are embedded, given positional signals, passed through
    ``num_blocks`` encoder blocks and ``num_blocks`` decoder blocks, and
    finally projected to ``tgt_vocab_size`` scores per position.
    """

    def __init__(self, dk, num_heads, dmodel, dff, num_blocks, src_vocab_size, tgt_vocab_size, max_tok_len):
        """Store the hyper-parameters and build all sub-modules.

        ``max_tok_len`` is stored on the instance but not otherwise used by
        this class.
        """
        super().__init__()
        self.dk = dk
        self.num_heads = num_heads
        self.dmodel = dmodel
        self.dff = dff
        self.num_blocks = num_blocks
        self.src_vocab_size = src_vocab_size
        self.tgt_vocab_size = tgt_vocab_size
        self.max_tok_len = max_tok_len

        self.create_layers()

    def create_layers(self):
        """Instantiate embeddings, encoder/decoder stacks and the output layer."""
        self.positional_encoding = PositionalEncoding(self.dmodel)
        self.encoder_embedding = InputEmbedding(self.src_vocab_size, self.dmodel)
        self.decoder_embedding = OutputEmbedding(self.tgt_vocab_size, self.dmodel)

        encoder_stack = [
            Encoder(self.num_heads, self.dmodel, self.dk, self.dff)
            for _ in range(self.num_blocks)
        ]
        self.encoders = nn.ModuleList(encoder_stack)

        decoder_stack = [
            Decoder(self.num_heads, self.dmodel, self.dk, self.dff)
            for _ in range(self.num_blocks)
        ]
        self.decoders = nn.ModuleList(decoder_stack)

        self.final_linear = nn.Linear(self.dmodel, self.tgt_vocab_size)

    def forward(self, enc_input, dec_input, padding_mask, causal_mask):
        """Return unnormalised scores over the target vocabulary.

        Args:
            enc_input: Source token ids.
            dec_input: Target-side token ids fed to the decoder.
            padding_mask: Passed to every encoder and decoder block.
            causal_mask: Passed to every decoder block.

        Returns:
            Output of the final linear layer; its last dimension has size
            ``tgt_vocab_size``. No softmax is applied.
        """
        memory = self.positional_encoding(self.encoder_embedding(enc_input))
        for encoder in self.encoders:
            memory = encoder(memory, padding_mask)

        # Every decoder block attends over the output of the last encoder block.
        hidden = self.positional_encoding(self.decoder_embedding(dec_input))
        for decoder in self.decoders:
            hidden = decoder(hidden, memory, padding_mask, causal_mask)

        return self.final_linear(hidden)

## Build Dataset

In [None]:
# Examples per batch; used by both DataLoaders and to size the causal mask.
batch_size = 32

In [None]:
import torch
import numpy as np
from torch.utils.data import TensorDataset, DataLoader

# Stack the equal-length id tensors into one tensor per language and pair
# them up row by row.
dataset = TensorDataset(torch.stack(padded_data_en), torch.stack(padded_data_ge))

# 90% of the pairs go to training, the remainder to testing.
train_size = int(0.9 * len(dataset))
test_size = len(dataset) - train_size

# NOTE(review): no generator/seed is passed, so the split presumably differs
# between runs — confirm whether that is intended.
train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])

# drop_last=True discards a final short batch, so every batch has exactly
# batch_size rows.
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True, drop_last=True)

## Training Loop

In [None]:
# Keyword arguments for the Transformer constructor (unpacked with **configs).
configs = {
    "dk": 64,  # width of each attention head's projections
    "num_heads": 8,
    "dmodel": 512,  # embedding / model width; equals num_heads * dk here
    "dff": 2048,  # hidden width of the feed-forward layers
    "num_blocks": 1,  # number of encoder blocks and of decoder blocks
    "src_vocab_size": len(en_vocab),  # includes the "<pad>" entry at id 0
    "tgt_vocab_size": len(ge_vocab),  # includes the "<pad>" entry at id 0
    # One less than the padded length: the training loop slices one position
    # off every sequence before feeding the model.
    "max_tok_len": max_tok_len - 1,
}

In [None]:
import torch

# Use the GPU when CUDA is available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = Transformer(**configs).to(device)
model.train()

optimizer = optim.Adam(model.parameters(), lr=0.001)
# Targets equal to 0 ("<pad>") do not contribute to the loss.
loss_fn = nn.CrossEntropyLoss(ignore_index=0).to(device)

# Built once and reused for every batch; this relies on all batches having
# exactly batch_size rows.
causal_mask = create_causal_mask(batch_size, max_tok_len-1).to(device)
losses = []  # mean training loss of each epoch

num_epochs = 10
for epoch in range(num_epochs):
    total_loss = 0
    for batch_idx, batch in enumerate(train_dataloader):
        src, trg = (part.to(device) for part in batch)

        # Source without its first position; decoder input is the target
        # without its last position; labels are the target shifted left by one.
        encoder_tokens = src[:, 1:]
        decoder_tokens = trg[:, :-1]
        target_tokens = trg[:, 1:]

        padding_mask = create_padding_mask(encoder_tokens, configs["max_tok_len"]).to(device)

        output = model(encoder_tokens, decoder_tokens, padding_mask, causal_mask)
        # CrossEntropyLoss expects the class dimension at index 1.
        loss = loss_fn(output.permute(0, 2, 1), target_tokens)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        total_loss += batch_loss

        if batch_idx % 50 == 0:
            print(f'Epoch {epoch} Batch {batch_idx} Loss {batch_loss}')

    epoch_loss = total_loss / len(train_dataloader)
    losses.append(epoch_loss)
    print(f'Epoch {epoch} Average Loss {epoch_loss}')
