<a href="https://colab.research.google.com/github/ymoslem/PyTorchNLP/blob/main/Ex4-NMT-Transformer-from-scratch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# NMT with Transformer Architecture

* **Paper:** <a href="https://arxiv.org/abs/1706.03762">Attention is all you need</a>

<center><img src="https://drive.google.com/uc?id=1LjdP4THlAZryfpP51FfDArS0qX9kMqtf" width="45%"></center>

In [None]:
import torch
import torch.nn as nn

class SelfAttention(nn.Module):
    """Multi-head scaled dot-product attention ("Attention Is All You Need", Sec. 3.2).

    Args:
        embedding_size: model width ``d_model``; must be divisible by ``heads``.
        heads: number of attention heads; each head works on
            ``head_dim = embedding_size // heads`` features.
    """

    def __init__(self, embedding_size, heads):
        super(SelfAttention, self).__init__()
        self.embedding_size = embedding_size
        self.heads = heads
        self.head_dim = embedding_size // heads

        assert (self.head_dim * heads == embedding_size), "Embed size must be dividable by heads"

        # Project the full d_model-wide input, then split into heads. This gives
        # every head its own learned W^Q / W^K / W^V slice, as in the paper.
        # Projecting each head_dim slice with one shared Linear would tie all
        # heads to the same weights.
        self.values = nn.Linear(embedding_size, embedding_size, bias=False)
        self.keys = nn.Linear(embedding_size, embedding_size, bias=False)
        self.queries = nn.Linear(embedding_size, embedding_size, bias=False)

        self.fc_out = nn.Linear(heads * self.head_dim, embedding_size)

    def forward(self, values, keys, query, mask):
        """Attend from ``query`` over ``keys``/``values``.

        Args:
            values: tensor of shape (N, value_len, embedding_size).
            keys: tensor of shape (N, key_len, embedding_size); key_len must
                equal value_len.
            query: tensor of shape (N, query_len, embedding_size).
            mask: ``None`` or a tensor broadcastable to
                (N, heads, query_len, key_len). Positions where ``mask == 0``
                receive (effectively) zero attention weight.

        Returns:
            Tensor of shape (N, query_len, embedding_size).
        """
        N = query.shape[0]  # batch size
        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]

        # Project, then split the embedding into `heads` pieces of `head_dim`.
        values = self.values(values).reshape(N, value_len, self.heads, self.head_dim)
        keys = self.keys(keys).reshape(N, key_len, self.heads, self.head_dim)
        queries = self.queries(query).reshape(N, query_len, self.heads, self.head_dim)

        # queries: (N, query_len, heads, head_dim)
        # keys:    (N, key_len, heads, head_dim)
        # energy:  (N, heads, query_len, key_len)
        # Scale by sqrt(d_k) with d_k = head_dim, per the paper (not sqrt(d_model)).
        energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys]) / (self.head_dim ** 0.5)

        if mask is not None:
            # Use the dtype's most negative finite value so the fill never
            # overflows (e.g. under fp16), while still driving softmax to ~0.
            energy = energy.masked_fill(mask == 0, torch.finfo(energy.dtype).min)

        attention = torch.softmax(energy, dim=3)

        # attention: (N, heads, query_len, key_len)
        # values:    (N, value_len, heads, head_dim)
        # einsum ->  (N, query_len, heads, head_dim), then flatten the last two dimensions.
        out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(
            N, query_len, self.heads * self.head_dim
        )

        return self.fc_out(out)

# Compare PostNorm to PreNorm: https://arxiv.org/pdf/2002.04745.pdf
class TransformerBlock(nn.Module):
    """Post-norm Transformer sub-block: multi-head attention + position-wise FFN.

    Each sub-layer follows the paper's ``LayerNorm(x + Dropout(Sublayer(x)))``
    layout. Dropout is applied to the sub-layer output before the residual add,
    never to the residual stream itself.

    Args:
        embedding_size: model width ``d_model``.
        heads: number of attention heads.
        dropout: dropout probability applied to each sub-layer output.
        forward_expantion: hidden-width multiplier of the feed-forward network
            (hidden size = ``forward_expantion * embedding_size``).
    """

    def __init__(self, embedding_size, heads, dropout, forward_expantion):
        super(TransformerBlock, self).__init__()
        self.attention = SelfAttention(embedding_size, heads)
        self.norm1 = nn.LayerNorm(embedding_size)
        self.norm2 = nn.LayerNorm(embedding_size)

        hidden_size = forward_expantion * embedding_size
        self.feedforward = nn.Sequential(
            nn.Linear(embedding_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, embedding_size),
        )

        self.dropout = nn.Dropout(dropout)

    def forward(self, value, key, query, mask):
        """Attend from ``query`` over ``key``/``value``, then apply the FFN.

        Args:
            value, key: tensors passed as values/keys to ``SelfAttention``.
            query: tensor passed as the query; also used as the residual input.
            mask: attention mask forwarded unchanged to ``SelfAttention``.

        Returns:
            Tensor with the same shape as ``query``.
        """
        attention = self.attention(value, key, query, mask)

        # Adding `query` back is the residual ("skip") connection around attention.
        x = self.norm1(query + self.dropout(attention))
        ff_out = self.feedforward(x)
        return self.norm2(x + self.dropout(ff_out))

class Encoder(nn.Module):
    """Transformer encoder: token + learned position embeddings, then a stack of blocks.

    Args:
        src_vocab_size: size of the source vocabulary.
        embedding_size: model width ``d_model``.
        num_layers: number of ``TransformerBlock`` layers.
        heads: number of attention heads per layer.
        device: kept for backward compatibility. ``forward`` builds its
            position indices on the input tensor's device instead.
        forward_expantion: FFN hidden-width multiplier.
        dropout: dropout probability.
        max_length: number of learned positions; longer inputs are rejected.
    """

    def __init__(self,
                 src_vocab_size,
                 embedding_size,
                 num_layers,
                 heads,
                 device,
                 forward_expantion,
                 dropout,
                 max_length
    ):
        super(Encoder, self).__init__()
        self.embedding_size = embedding_size
        self.device = device
        self.words_embedding = nn.Embedding(src_vocab_size, embedding_size)
        self.position_embedding = nn.Embedding(max_length, embedding_size)

        self.layers = nn.ModuleList(
            [
                TransformerBlock(
                    embedding_size,
                    heads,
                    dropout=dropout,
                    forward_expantion=forward_expantion,
                )
                for _ in range(num_layers)
            ]
        )

        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        """Encode a batch of source token ids.

        Args:
            x: integer tensor of shape (N, seq_length).
            mask: attention mask forwarded to every layer.

        Returns:
            Tensor of shape (N, seq_length, embedding_size).

        Raises:
            ValueError: if ``seq_length`` exceeds ``max_length``.
        """
        N, seq_length = x.shape

        max_length = self.position_embedding.num_embeddings
        if seq_length > max_length:
            raise ValueError(
                f"source length {seq_length} exceeds max_length {max_length}"
            )

        # Build positions on the input's device so the module works wherever it
        # was moved (e.g. CPU), independent of the `device` constructor argument.
        positions = torch.arange(seq_length, device=x.device).expand(N, seq_length)

        out = self.dropout(self.words_embedding(x) + self.position_embedding(positions))

        # Self-attention: the same tensor serves as value, key and query.
        for layer in self.layers:
            out = layer(out, out, out, mask)

        return out
        
        
class DecoderBlock(nn.Module):
    """One decoder layer: masked self-attention, then cross-attention + FFN.

    Args:
        embedding_size: model width ``d_model``.
        heads: number of attention heads.
        forward_expantion: FFN hidden-width multiplier.
        dropout: dropout probability applied to each sub-layer output.
        device: unused; kept for backward compatibility with existing callers.
    """

    def __init__(self, embedding_size, heads, forward_expantion, dropout, device):
        super(DecoderBlock, self).__init__()
        self.attention = SelfAttention(embedding_size, heads)
        self.norm = nn.LayerNorm(embedding_size)
        self.transformer_block = TransformerBlock(embedding_size, heads, dropout, forward_expantion)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, value, key, src_mask, tgt_mask):
        """Run one decoder layer.

        Args:
            x: decoder hidden states (also the residual input).
            value, key: encoder output used for cross-attention.
            src_mask: mask applied in cross-attention.
            tgt_mask: mask applied in decoder self-attention.

        Returns:
            Output of the cross-attention + feed-forward sub-block (same shape as ``x``).
        """
        # Masked self-attention with a post-norm residual; dropout is applied to
        # the sub-layer output only.
        attention = self.attention(x, x, x, tgt_mask)
        query = self.norm(x + self.dropout(attention))

        # Cross-attention over the encoder output followed by the FFN. This is
        # the layer's result; returning `query` would discard it entirely.
        return self.transformer_block(value, key, query, src_mask)
    

class Decoder(nn.Module):
    """Transformer decoder: embeddings, a stack of ``DecoderBlock``s, and a vocab projection.

    Args:
        tgt_vocab_size: size of the target vocabulary (output logit width).
        embedding_size: model width ``d_model``.
        num_layers: number of ``DecoderBlock`` layers.
        heads: number of attention heads per layer.
        device: kept for backward compatibility. ``forward`` builds its
            position indices on the input tensor's device instead.
        forward_expantion: FFN hidden-width multiplier.
        dropout: dropout probability.
        max_length: number of learned positions; longer inputs are rejected.
    """

    def __init__(self,
                 tgt_vocab_size,
                 embedding_size,
                 num_layers,
                 heads,
                 device,
                 forward_expantion,
                 dropout,
                 max_length
    ):
        super(Decoder, self).__init__()

        self.device = device
        self.word_embedding = nn.Embedding(tgt_vocab_size, embedding_size)
        self.position_embedding = nn.Embedding(max_length, embedding_size)

        self.layers = nn.ModuleList(
            [
                DecoderBlock(embedding_size, heads, forward_expantion, dropout, device)
                for _ in range(num_layers)
            ]
        )

        self.fc_out = nn.Linear(embedding_size, tgt_vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, enc_out, src_mask, tgt_mask):
        """Decode a batch of target token ids against the encoder output.

        Args:
            x: integer tensor of shape (N, seq_length).
            enc_out: encoder output, used as keys/values in cross-attention.
            src_mask: mask for cross-attention.
            tgt_mask: mask for decoder self-attention.

        Returns:
            Unnormalized logits of shape (N, seq_length, tgt_vocab_size).

        Raises:
            ValueError: if ``seq_length`` exceeds ``max_length``.
        """
        N, seq_length = x.shape

        max_length = self.position_embedding.num_embeddings
        if seq_length > max_length:
            raise ValueError(
                f"target length {seq_length} exceeds max_length {max_length}"
            )

        # Build positions on the input's device so the module works wherever it
        # was moved (e.g. CPU), independent of the `device` constructor argument.
        positions = torch.arange(seq_length, device=x.device).expand(N, seq_length)
        x = self.dropout(self.word_embedding(x) + self.position_embedding(positions))

        for layer in self.layers:
            x = layer(x, enc_out, enc_out, src_mask, tgt_mask)

        return self.fc_out(x)

class Transformer(nn.Module):
    """Encoder-decoder Transformer for sequence-to-sequence translation.

    Args:
        src_vocab_size: size of the source vocabulary.
        tgt_vocab_size: size of the target vocabulary.
        src_pad_idx: source padding id; padded positions are masked out.
        tgt_pad_idx: target padding id (stored; not used by the causal mask).
        embedding_size: model width ``d_model``.
        num_layers: number of encoder and decoder layers.
        forward_expantion: FFN hidden-width multiplier.
        heads: number of attention heads.
        dropout: dropout probability.
        device: kept for backward compatibility. Masks are created on the
            input tensors' device, so the model runs wherever it was moved.
        max_length: maximum source/target length (learned positions).
    """

    def __init__(self,
                 src_vocab_size,
                 tgt_vocab_size,
                 src_pad_idx,
                 tgt_pad_idx,
                 embedding_size=256,
                 num_layers=6,
                 forward_expantion=4,
                 heads=8,
                 dropout=0,
                 device="cuda",
                 max_length=100
    ):
        super(Transformer, self).__init__()

        self.encoder = Encoder(
            src_vocab_size,
            embedding_size,
            num_layers,
            heads,
            device,
            forward_expantion,
            dropout,
            max_length,
        )

        self.decoder = Decoder(
            tgt_vocab_size,
            embedding_size,
            num_layers,
            heads,
            device,
            forward_expantion,
            dropout,
            max_length,
        )

        self.src_pad_idx = src_pad_idx
        self.tgt_pad_idx = tgt_pad_idx
        self.device = device

    def make_src_mask(self, src):
        """Return a padding mask of shape (N, 1, 1, src_length); True where ``src`` is not padding."""
        # Comparison result already lives on src's device.
        return (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)

    def make_tgt_mask(self, tgt):
        """Return a causal (lower-triangular) mask of shape (N, 1, tgt_length, tgt_length).

        Entry [i, j] is 1 when position i may attend to position j (j <= i).
        """
        N, tgt_length = tgt.shape
        causal = torch.tril(torch.ones((tgt_length, tgt_length), device=tgt.device))
        return causal.expand(N, 1, tgt_length, tgt_length)

    def forward(self, src, tgt):
        """Return decoder logits of shape (N, tgt_length, tgt_vocab_size)."""
        src_mask = self.make_src_mask(src)
        tgt_mask = self.make_tgt_mask(tgt)

        enc_src = self.encoder(src, src_mask)
        return self.decoder(tgt, enc_src, src_mask, tgt_mask)

In [None]:
# Mock test: run a tiny batch through the model and check the output shape.

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
x = torch.tensor([[1, 5, 6, 4, 3, 9, 5, 2, 0], [1, 8, 7, 3, 4, 5, 6, 7, 2]]).to(device)
tgt = torch.tensor([[1, 7, 4, 3, 5, 9, 2, 0], [1, 5, 6, 2, 4, 7, 6, 2]]).to(device)

src_pad_idx = 0
tgt_pad_idx = 0
src_vocab_size = 10
tgt_vocab_size = 10

# Pass the detected device explicitly; the constructor's default is "cuda",
# which breaks on CPU-only machines.
model = Transformer(
    src_vocab_size, tgt_vocab_size, src_pad_idx, tgt_pad_idx, device=device
).to(device)

# Teacher forcing: feed the target without its last token.
out = model(x, tgt[:, :-1])
print(out.shape)
assert out.shape == (x.shape[0], tgt.shape[1] - 1, tgt_vocab_size)

torch.Size([2, 7, 10])
