<h1 style="font-size:60px;"><center>Transformer</center></h1>

<div style="text-align:center"><img src="https://i.kym-cdn.com/entries/icons/original/000/036/585/Attention_is_all_you_need.jpg" /></div>

# Prerequisites

+ Python
+ Working knowledge of PyTorch
+ Neural Networks
+ Seq2Seq
+ Attention

# Where did it come from?

| Before | | After |
| :----: | :----: | :----: |
| RNN | | BERT |
| LSTM & GRU | ⬅ **2017 - [Attention Is All You Need](https://arxiv.org/pdf/1706.03762)** ➡️ | GPT |
| Attention | | Vision Transformer |

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np

# The Architecture

<div style="text-align:center">
  <img src="https://miro.medium.com/v2/resize:fit:856/1*ZCFSvkKtppgew3cc7BIaug.png" width="500px"/>
</div>


# Multi-Head Attention

## Self Attention

<div style="text-align:center">
  <img src="https://i.imgur.com/thdHvQx.png" width="400px"/>
</div>
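
For reference, self-attention of this kind is usually written as the scaled dot-product formula from the paper linked at the top, where $Q$, $K$ and $V$ are the query, key and value matrices (one row per token) and $d_k$ is the key dimension:

$$
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^{\top}}{\sqrt{d_k}}\right)V
$$

The softmax is applied to each row separately, so every query ends up with a set of non-negative weights over the keys that sum to 1.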

In [None]:
class DummyAttention:
    def __init__(self):
        # Toy input: one row per token, one column per feature
        self.data = np.random.normal(size=(50,50))

        # Random (untrained) projection matrices
        self.weight_key = np.random.normal(size=(50,50))
        self.weight_query = np.random.normal(size=(50,50))
        self.weight_value = np.random.normal(size=(50,50))

    def key(self):
        return self.data @ self.weight_key

    def query(self):
        return self.data @ self.weight_query

    def value(self):
        return self.data @ self.weight_value

    def forward(self):
        k = self.key()
        q = self.query()
        v = self.value()

        # Scaled dot-product scores: one row per query, one column per key
        scores = q @ k.T
        scores = scores / k.shape[-1]**0.5

        # Row-wise softmax; subtracting the row max keeps exp() from overflowing
        scores = scores - scores.max(axis=-1, keepdims=True)
        weights = np.exp(scores) / np.sum(np.exp(scores), axis=-1, keepdims=True)

        # Each output row is a weighted average of the value rows
        attn = weights @ v
        return attn


In [None]:
a = DummyAttention()
att = a.forward()
print(att.shape)

(50, 50)

## Multi-Head

<div style="text-align:center">
  <img src="https://miro.medium.com/v2/resize:fit:1010/0*0KPEV8QidHkteKeY.png" width="500px"/>
</div>

In [None]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()

        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        self.query = nn.Linear(d_model, d_model)
        self.key = nn.Linear(d_model, d_model)
        self.value = nn.Linear(d_model, d_model)

        self.out = nn.Linear(d_model, d_model)

    def forward(self, q, k, v, mask=None):
        batch_size = q.size(0)

        # Perform linear projections
        Q = self.query(q)  # (batch_size, seq_length, d_model)
        K = self.key(k)    # (batch_size, seq_length, d_model)
        V = self.value(v)  # (batch_size, seq_length, d_model)

        # Split the projections into multiple heads and reshape
        Q = Q.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)  # (batch_size, num_heads, seq_length, d_k)
        K = K.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)  # (batch_size, num_heads, seq_length, d_k)
        V = V.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)  # (batch_size, num_heads, seq_length, d_k)

        # Calculate attention scores
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)  # (batch_size, num_heads, seq_length, seq_length)

        if mask is not None:
            # Add a head dimension; masked_fill then broadcasts over heads (and query positions if needed)
            mask = mask.unsqueeze(1)
            scores = scores.masked_fill(mask == 0, -1e20)

        attention_weights = F.softmax(scores, dim=-1)  # (batch_size, num_heads, seq_length, seq_length)

        # Apply attention weights to the values
        attention_output = torch.matmul(attention_weights, V)  # (batch_size, num_heads, seq_length, d_k)

        # Concatenate the heads and reshape
        attention_output = attention_output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)  # (batch_size, seq_length, d_model)

        # Final linear layer
        output = self.out(attention_output)  # (batch_size, seq_length, d_model)

        return output
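
The `view` / `transpose` steps in the class above can be hard to picture, so here is a small standalone sketch of the head split and merge on a random tensor. The sizes are toy values chosen only for illustration; nothing here is learned.

```python
import torch

batch_size, seq_len, d_model, num_heads = 2, 5, 16, 4  # toy sizes, assumed for illustration
d_k = d_model // num_heads

x = torch.randn(batch_size, seq_len, d_model)

# Split: (batch, seq, d_model) -> (batch, seq, heads, d_k) -> (batch, heads, seq, d_k)
heads = x.view(batch_size, seq_len, num_heads, d_k).transpose(1, 2)
print(heads.shape)

# Head h simply sees the feature slice [h * d_k : (h + 1) * d_k] of every token
print(torch.equal(heads[:, 1], x[:, :, d_k:2 * d_k]))

# Merge: undo the transpose, then flatten the heads back into d_model
merged = heads.transpose(1, 2).contiguous().view(batch_size, seq_len, d_model)
print(torch.equal(merged, x))
```

Splitting and merging are pure reshapes, so no information is lost; the heads differ only because the learned projections put different features into each slice.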

# Encoder

## Feed Forward

In [None]:
class FeedForward(nn.Module):
    def __init__(self, d_model, d_ff):
        super(FeedForward, self).__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        x = F.relu(self.linear1(x))
        x = self.linear2(x)
        return x
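
In equation form, the block above applies the same two-layer network to every position independently, with $W_1, b_1$ the parameters of `linear1` and $W_2, b_2$ those of `linear2`:

$$
\mathrm{FFN}(x) = \max(0,\; xW_1 + b_1)\,W_2 + b_2
$$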

## Encoder layer

In [None]:
class EncoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super(EncoderLayer, self).__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.ff = FeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        attn_output = self.self_attn(x, x, x, mask)
        x = self.norm1(x + self.dropout(attn_output))
        ff_output = self.ff(x)
        x = self.norm2(x + self.dropout(ff_output))
        return x

## Embedding layer + Positional Embedding

In [None]:
class Encoder(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, num_layers, vocab_size, max_seq_len, dropout=0.1):
        super(Encoder, self).__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
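        # Register as a buffer so the table follows the module across devices (e.g. model.to("cuda"))
        self.register_buffer("pos_encoding", self._generate_positional_encoding(max_seq_len, d_model), persistent=False)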
        self.layers = nn.ModuleList([EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])
        self.dropout = nn.Dropout(dropout)

    def _generate_positional_encoding(self, max_seq_len, d_model):
        pos = torch.arange(max_seq_len).unsqueeze(1)
        i = torch.arange(d_model).unsqueeze(0)
        angle_rates = 1 / torch.pow(10000, (2 * (i // 2)) / torch.tensor(d_model, dtype=torch.float32))
        pos_encoding = pos * angle_rates
        pos_encoding[:, 0::2] = torch.sin(pos_encoding[:, 0::2])
        pos_encoding[:, 1::2] = torch.cos(pos_encoding[:, 1::2])
        return pos_encoding.unsqueeze(0)

    def forward(self, x, mask=None):
        seq_len = x.size(1)
        x = self.embedding(x)
        x = x + self.pos_encoding[:, :seq_len, :]
        x = self.dropout(x)
        for layer in self.layers:
            x = layer(x, mask)
        return x
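
As a standalone reference, here is a minimal sketch of the sinusoidal encoding that `_generate_positional_encoding` builds: even columns hold $\sin(pos / 10000^{2i/d_{model}})$ and odd columns hold $\cos(pos / 10000^{2i/d_{model}})$. The function name is hypothetical and the sketch assumes an even `d_model`.

```python
import math
import torch

def sinusoidal_positional_encoding(max_seq_len: int, d_model: int) -> torch.Tensor:
    """Return a (max_seq_len, d_model) table of sinusoidal position encodings.

    Assumes d_model is even, so every sine column has a matching cosine column.
    """
    position = torch.arange(max_seq_len, dtype=torch.float32).unsqueeze(1)  # (max_seq_len, 1)
    # One frequency per sin/cos pair: 1 / 10000^(2i / d_model), computed in log space
    div_term = torch.exp(
        torch.arange(0, d_model, 2, dtype=torch.float32) * (-math.log(10000.0) / d_model)
    )
    pe = torch.zeros(max_seq_len, d_model)
    pe[:, 0::2] = torch.sin(position * div_term)  # even columns
    pe[:, 1::2] = torch.cos(position * div_term)  # odd columns
    return pe

pe = sinusoidal_positional_encoding(max_seq_len=100, d_model=512)
print(pe.shape)
```

Position 0 is all zeros in the sine columns and all ones in the cosine columns, which is a quick sanity check for any implementation.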

# Decoder

In [None]:
class DecoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super(DecoderLayer, self).__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.enc_dec_attn = MultiHeadAttention(d_model, num_heads)
        self.ff = FeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, enc_output, tgt_mask=None, memory_mask=None):
        self_attn_output = self.self_attn(x, x, x, tgt_mask)
        x = self.norm1(x + self.dropout(self_attn_output))

        enc_dec_attn_output = self.enc_dec_attn(x, enc_output, enc_output, memory_mask)
        x = self.norm2(x + self.dropout(enc_dec_attn_output))

        ff_output = self.ff(x)
        x = self.norm3(x + self.dropout(ff_output))
        return x

In [None]:
class Decoder(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, num_layers, vocab_size, max_seq_len, dropout=0.1):
        super(Decoder, self).__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
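        # Register as a buffer so the table follows the module across devices (e.g. model.to("cuda"))
        self.register_buffer("pos_encoding", self._generate_positional_encoding(max_seq_len, d_model), persistent=False)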
        self.layers = nn.ModuleList([DecoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])
        self.dropout = nn.Dropout(dropout)

    def _generate_positional_encoding(self, max_seq_len, d_model):
        pos = torch.arange(max_seq_len).unsqueeze(1)
        i = torch.arange(d_model).unsqueeze(0)
        angle_rates = 1 / torch.pow(10000, (2 * (i // 2)) / torch.tensor(d_model, dtype=torch.float32))
        pos_encoding = pos * angle_rates
        pos_encoding[:, 0::2] = torch.sin(pos_encoding[:, 0::2])
        pos_encoding[:, 1::2] = torch.cos(pos_encoding[:, 1::2])
        return pos_encoding.unsqueeze(0)

    def forward(self, x, enc_output, tgt_mask=None, memory_mask=None):
        seq_len = x.size(1)
        x = self.embedding(x)
        x = x + self.pos_encoding[:, :seq_len, :]
        x = self.dropout(x)
        for layer in self.layers:
            x = layer(x, enc_output, tgt_mask, memory_mask)
        return x

# Put it all together

## Transformer

In [None]:
class Transformer(nn.Module):
    def __init__(self, src_vocab_size, tgt_vocab_size, d_model, num_heads, d_ff, num_layers, max_seq_len, dropout=0.1):
        super(Transformer, self).__init__()
        self.encoder = Encoder(d_model, num_heads, d_ff, num_layers, src_vocab_size, max_seq_len, dropout)
        self.decoder = Decoder(d_model, num_heads, d_ff, num_layers, tgt_vocab_size, max_seq_len, dropout)
        self.fc_out = nn.Linear(d_model, tgt_vocab_size)

    def forward(self, src, tgt, src_mask=None, tgt_mask=None, memory_mask=None):
        enc_output = self.encoder(src, src_mask)
        dec_output = self.decoder(tgt, enc_output, tgt_mask, memory_mask)
        output = self.fc_out(dec_output)
        return output

## Masks

In [None]:
def create_src_mask(src, pad_idx):
    src_mask = (src != pad_idx).unsqueeze(-2)
    return src_mask

def create_tgt_mask(tgt, pad_idx):
    tgt_pad_mask = (tgt != pad_idx).unsqueeze(-2)
    tgt_len = tgt.size(1)
    tgt_sub_mask = torch.tril(torch.ones((tgt_len, tgt_len), device=tgt.device)).bool()
    tgt_mask = tgt_pad_mask & tgt_sub_mask
    return tgt_mask
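
To see what the target mask looks like, here is a small sketch that rebuilds the same two pieces inline on a toy sequence. The token IDs are made up, and `pad_idx = 0` is assumed to match the value used in this notebook.

```python
import torch

pad_idx = 0  # assumed padding token id
tgt = torch.tensor([[5, 7, 9, 0]])  # one toy sequence whose last token is padding

# Padding mask: True where the token is real, shape (batch, 1, tgt_len)
pad_mask = (tgt != pad_idx).unsqueeze(-2)

# Causal mask: position i may attend to positions <= i, shape (tgt_len, tgt_len)
tgt_len = tgt.size(1)
causal_mask = torch.tril(torch.ones((tgt_len, tgt_len))).bool()

# Broadcasting combines both into shape (batch, tgt_len, tgt_len)
combined = pad_mask & causal_mask
print(combined.int())
```

Row `i` of the result lists the positions that token `i` may look at: the lower triangle hides future tokens, and the padded last column is zero in every row.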

In [None]:
src_vocab_size = 10000
tgt_vocab_size = 10000
d_model = 512
num_heads = 8
d_ff = 2048
num_layers = 6
max_seq_len = 100
dropout = 0.1
pad_idx = 0

In [None]:
transformer = Transformer(src_vocab_size, tgt_vocab_size, d_model, num_heads, d_ff, num_layers, max_seq_len, dropout)
src = torch.randint(0, src_vocab_size, (32, max_seq_len))
tgt = torch.randint(0, tgt_vocab_size, (32, max_seq_len))
print(src.shape, tgt.shape)
src_mask = create_src_mask(src, pad_idx)
tgt_mask = create_tgt_mask(tgt, pad_idx)
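# Reuse the source padding mask so cross-attention ignores padded source tokens
output = transformer(src, tgt, src_mask, tgt_mask, memory_mask=src_mask)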

print(output.shape)

torch.Size([32, 100]) torch.Size([32, 100])
torch.Size([32, 100, 10000])


## Tokenization

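Tokenization is the step that converts text into a sequence of tokens, each mapped to an integer ID that the embedding layer can look up.

The two simplest approaches are: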

| Method | Vocab size | Sequence length |
| :---------------- | :------: | ----: |
| One ID per character | Small | Very long |
| One ID per word | Very large | Moderate (one token per word) |

Modern LLMs use subword tokenizers, which sit somewhere between these two approaches. How they work is outside the scope of this lecture. Two popular approaches are:
+ [Byte Pair Tokenization](https://www.youtube.com/watch?v=HEikzVL-lZU)
+ [WordPiece Tokenization](https://www.youtube.com/watch?v=qpv6ms_t_1A)
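
To make the trade-off in the table above concrete, here is a minimal sketch of the two simple schemes in plain Python. The sentence and the way IDs are assigned (sorted order) are arbitrary choices for illustration, not how any real tokenizer works.

```python
text = "the cat sat on the mat"

# Character-level: one ID per distinct character (spaces included)
char_vocab = {ch: idx for idx, ch in enumerate(sorted(set(text)))}
char_ids = [char_vocab[ch] for ch in text]

# Word-level: one ID per distinct whitespace-separated word
words = text.split()
word_vocab = {w: idx for idx, w in enumerate(sorted(set(words)))}
word_ids = [word_vocab[w] for w in words]

print("char:", len(char_vocab), "vocab entries,", len(char_ids), "tokens")
print("word:", len(word_vocab), "vocab entries,", len(word_ids), "tokens")
```

The character sequence is several times longer than the word sequence. On a sentence this short the two vocabularies are similar in size; the gap in the table shows up on real corpora, where the character set stays bounded while the number of distinct words keeps growing.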

# Why is transformer so revolutionary?

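+ Very compute-efficient: every position in a sequence is processed in parallel, unlike the step-by-step recurrence of an RNN
+ [The Bitter Lesson](http://www.incompleteideas.net/IncIdeas/BitterLesson.html): the Transformer is one of the greatest examples of this.
+ Generalizes across domains (e.g. text with BERT and GPT, images with the Vision Transformer)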

# Limitations of transformers

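+ High memory usage: the attention scores form a `seq_length` x `seq_length` matrix per head, so memory grows quadratically with sequence length
+ Large compute and data requirements
+ Limited sequence length: the model only handles inputs of up to `max_seq_len` tokens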
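
A rough back-of-the-envelope sketch of the memory point: the attention scores for one layer have shape `(batch_size, num_heads, seq_length, seq_length)`. The helper below is hypothetical and assumes float32 values and a fully materialized score tensor, with the batch size and head count taken from the demo above.

```python
def attention_scores_bytes(batch_size: int, num_heads: int, seq_len: int,
                           bytes_per_value: int = 4) -> int:
    """Bytes needed to store one layer's attention scores.

    Assumes a tensor of shape (batch_size, num_heads, seq_len, seq_len)
    with bytes_per_value bytes per entry (4 for float32).
    """
    return batch_size * num_heads * seq_len * seq_len * bytes_per_value

for seq_len in (100, 1_000, 10_000):
    size = attention_scores_bytes(batch_size=32, num_heads=8, seq_len=seq_len)
    print(f"seq_len={seq_len:>6}: {size / 2**20:,.1f} MiB")
```

Making the sequence 10 times longer makes this tensor 100 times larger, which is why long inputs get expensive quickly.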

# How does it lead to fancy stuff like BERT and GPT?

<div style="text-align:center">
  <img src="https://miro.medium.com/v2/resize:fit:434/1*D5xg0yz7YzBSzS_F1efLAA.png" width="500px"/>
</div>