## Transformer in image

<img src="./assert/transformer.png" width="50%" height="50%" alt="transformer">

In [None]:
# Add mask function to multi-head attention

import torch
import torch.nn as nn
import math

class MaskedMultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention with an optional mask.

    Inputs may be unbatched ``(seq_len, d_model)`` or carry leading batch
    dimensions ``(..., seq_len, d_model)``. ``key`` and ``value`` must share
    one sequence length, which may differ from the query's (cross-attention).
    """

    def __init__(self, d_model, n_heads):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"

        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads

        self.w_query = nn.Linear(d_model, d_model)
        self.w_key = nn.Linear(d_model, d_model)
        self.w_value = nn.Linear(d_model, d_model)
        # Output projection applied to the concatenated heads. The attribute
        # name is kept unchanged so existing state dicts keep loading.
        self.attention_scores = nn.Linear(d_model, d_model)

    def _split_heads(self, x):
        """Reshape ``(..., seq, d_model)`` into ``(..., n_heads, seq, head_dim)``."""
        x = x.reshape(*x.shape[:-1], self.n_heads, self.head_dim)
        return x.transpose(-3, -2)

    def forward(self, query, key, value, mask=None):
        """Attend from ``query`` positions over ``key``/``value`` positions.

        Args:
            query: Tensor of shape ``(..., q_len, d_model)``.
            key: Tensor of shape ``(..., k_len, d_model)``.
            value: Tensor of shape ``(..., k_len, d_model)``.
            mask: Optional tensor broadcastable to
                ``(..., n_heads, q_len, k_len)``, e.g. ``(q_len, k_len)``.
                Positions where ``mask == 0`` are hidden from attention.

        Returns:
            Tensor of shape ``(..., q_len, d_model)``.

        Note:
            A query row whose keys are *all* masked has no valid softmax
            distribution and yields NaN for that row.
        """
        # For self-attention the caller passes the same tensor three times;
        # each role still gets its own learned projection.
        query = self._split_heads(self.w_query(query))
        key = self._split_heads(self.w_key(key))
        value = self._split_heads(self.w_value(value))

        # (..., n_heads, q_len, head_dim) @ (..., n_heads, head_dim, k_len)
        # -> (..., n_heads, q_len, k_len): one score per (query, key) pair
        # and head, scaled to keep the softmax in a well-behaved range.
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.head_dim)

        if mask is not None:
            # A masked score must not be 0: exp(0) = 1, so the position would
            # still receive attention weight. -inf gives exp(-inf) = 0, and
            # the remaining weights still normalise to 1.
            # masked_fill is out-of-place, so its result has to be kept.
            scores = scores.masked_fill(mask == 0, float("-inf"))

        weights = torch.softmax(scores, dim=-1)
        context = torch.matmul(weights, value)

        # (..., n_heads, q_len, head_dim) -> (..., q_len, d_model)
        context = context.transpose(-3, -2).flatten(-2)
        return self.attention_scores(context)
            
        
        

In [None]:
# We have a component for Add & Norm; it computes layer_norm(x + sublayer(x))

class TransformerAddNorm(nn.Module):
    """Residual connection followed by layer normalisation ("Add & Norm")."""

    def __init__(self, d_model):
        super().__init__()
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, x, sublayer):
        """Return ``layer_norm(x + sublayer(x))``.

        Args:
            x: Input tensor; its last dimension is normalised by the
                ``nn.LayerNorm(d_model)`` built in ``__init__``.
            sublayer: Callable applied to ``x``; its result is added to ``x``.
        """
        residual = x
        transformed = sublayer(x)
        return self.layer_norm(residual + transformed)

class FeedForward(nn.Module):
    """Two-layer feed-forward network: Linear -> ReLU -> Linear."""

    def __init__(self, d_model, d_ff):
        super().__init__()
        # Widen to d_ff, apply the non-linearity, then project back to d_model.
        expand = nn.Linear(d_model, d_ff)
        activation = nn.ReLU()
        project = nn.Linear(d_ff, d_model)
        self.feed_forward = nn.Sequential(expand, activation, project)

    def forward(self, x):
        """Run ``x`` through the stack; the last dim goes d_model -> d_ff -> d_model."""
        output = self.feed_forward(x)
        return output


In [None]:
# Encoder

class Encoder(nn.Module):
    """Single Transformer encoder block.

    Self-attention then a feed-forward network, each wrapped in a residual
    "Add & Norm" step.

    Args:
        d_model: Feature size of the input and output.
        n_heads: Number of attention heads.
        d_ff: Hidden size of the feed-forward network.
        max_seq_len: Accepted for interface compatibility; currently unused.
        dropout: Accepted for interface compatibility; currently unused.
    """

    def __init__(self, d_model, n_heads, d_ff, max_seq_len, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        # nn.Sequential cannot drive these modules: it passes one argument to
        # each child, but attention needs (query, key, value) and Add & Norm
        # needs (x, sublayer). Keep them in an indexed container (same
        # "encoder_layers.<i>" parameter names) and wire them up in forward().
        self.encoder_layers = nn.ModuleList([
            MaskedMultiHeadAttention(d_model, n_heads),
            TransformerAddNorm(d_model),
            FeedForward(d_model, d_ff),
            TransformerAddNorm(d_model),
        ])

    def forward(self, x):
        """Encode ``x`` and return a tensor of the same shape.

        ``x`` must have a shape the attention module accepts, with ``d_model``
        as its last dimension.
        """
        attention, attention_norm, feed_forward, feed_forward_norm = self.encoder_layers
        # Self-attention: the same tensor serves as query, key and value.
        x = attention_norm(x, lambda t: attention(t, t, t))
        return feed_forward_norm(x, feed_forward)
        

In [None]:
class Decoder(nn.Module):
    """Single Transformer decoder block.

    Masked self-attention, then attention over the encoder output, then a
    feed-forward network, each wrapped in a residual "Add & Norm" step.

    Args:
        d_model: Feature size of the input and output.
        n_heads: Number of attention heads.
        d_ff: Hidden size of the feed-forward network.
        max_seq_len: Accepted for interface compatibility; currently unused.
        dropout: Accepted for interface compatibility; currently unused.
    """

    def __init__(self, d_model, n_heads, d_ff, max_seq_len, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        # nn.Sequential cannot drive these modules: it passes one argument to
        # each child, but attention needs (query, key, value) and Add & Norm
        # needs (x, sublayer). Keep them in an indexed container (same
        # "decoder_layers.<i>" parameter names) and wire them up in forward().
        self.decoder_layers = nn.ModuleList([
            MaskedMultiHeadAttention(d_model, n_heads),
            TransformerAddNorm(d_model),
            MaskedMultiHeadAttention(d_model, n_heads),
            TransformerAddNorm(d_model),
            FeedForward(d_model, d_ff),
            TransformerAddNorm(d_model),
        ])

    def forward(self, x, memory=None, mask=None):
        """Decode ``x`` and return a tensor of the same shape.

        Args:
            x: Decoder input with ``d_model`` as its last dimension.
            memory: Encoder output used as key/value by the second attention
                layer. When omitted, that layer attends over the decoder's
                own intermediate representation instead.
            mask: Optional mask forwarded to the first (self-)attention
                layer, typically a lower-triangular causal mask.
        """
        (
            self_attention,
            self_attention_norm,
            cross_attention,
            cross_attention_norm,
            feed_forward,
            feed_forward_norm,
        ) = self.decoder_layers

        x = self_attention_norm(x, lambda t: self_attention(t, t, t, mask))

        # Queries come from the decoder; keys and values come from the source.
        source = x if memory is None else memory
        x = cross_attention_norm(x, lambda t: cross_attention(t, source, source))

        return feed_forward_norm(x, feed_forward)