In [1]:
import torch
import math
import torch.nn as nn
from typing_extensions import Annotated


class PositionalEncoding(nn.Module):
    """Adds fixed (non-learned) sinusoidal positional encodings to embeddings.

    For position ``pos`` and even column ``2i`` the table holds
    ``sin(pos / 10000^(2i / d_model))``; the following odd column holds the
    matching cosine. The table is precomputed once for ``max_len`` positions
    and stored as a buffer, so it follows the module across devices and is
    saved in the state dict without being a trainable parameter.

    Args:
        max_len: Number of positions to precompute (the longest sequence
            length that ``forward`` will accept).
        d_model: Size of each token embedding. Both even and odd values
            are supported.
    """

    def __init__(self,
                 max_len: Annotated[int, "It means how many no of words are there in a sequence"],
                 d_model: Annotated[int, "It tells in how many dimension each and every word represents"]) -> None:
        super().__init__()
        positional_encoder = torch.zeros(max_len, d_model)  # (max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()  # (max_len, 1)

        # One frequency per even column index: (ceil(d_model / 2),)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = position * div_term  # (max_len, ceil(d_model / 2))

        positional_encoder[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd column than even columns,
        # so drop the last frequency for the cosine half.
        positional_encoder[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        # Shape it to (1, max_len, d_model) for broadcasting with input: (batch_size, seq_len, d_model)
        positional_encoder = positional_encoder.unsqueeze(0)

        self.register_buffer("positional_encoder", positional_encoder)

    def forward(self, input_data: torch.Tensor) -> torch.Tensor:
        """Add the positional encoding to ``input_data``.

        Args:
            input_data: Tensor of shape (batch_size, seq_len, d_model).

        Returns:
            Tensor of the same shape with the first ``seq_len`` rows of the
            encoding table added to every sequence in the batch.

        Raises:
            ValueError: If ``seq_len`` exceeds the ``max_len`` the table was
                built for.
        """
        seq_len = input_data.size(1)
        max_len = self.positional_encoder.size(1)
        if seq_len > max_len:
            # Without this check a table with max_len == 1 would silently
            # broadcast position 0 onto every token.
            raise ValueError(
                f"sequence length {seq_len} exceeds max_len {max_len}")
        return input_data + self.positional_encoder[:, :seq_len]


In [12]:
import math
import torch
import torch.nn as nn
from typing import List, Annotated

class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention with an optional causal mask.

    The input is projected to queries, keys and values, split into
    ``num_heads`` heads of size ``embed_dim // num_heads``, attended per head,
    concatenated, and passed through a final linear layer that mixes the heads.

    Args:
        num_heads: Number of attention heads; must divide ``embed_dim``.
        embed_dim: Size of each token embedding.
        seq_length: Nominal (padded) sequence length. It is stored for
            reference only; ``forward`` reads the actual length from its input.
        bias: Whether the query/key/value projections carry a bias term.
            The output projection always uses ``nn.Linear``'s default bias.
        mask: If True, apply a causal mask so a position cannot attend to
            later positions.
    """

    def __init__(self, num_heads: Annotated[int, "No of self attention needed"],
                 embed_dim: Annotated[int, "dimension of each word"],
                 seq_length: Annotated[int, "Length of sentence after padding"],
                 bias : Annotated[bool, "Required bias during trining"] = False,
                 mask: Annotated[bool, "normal MHA or masked MHA?"] = False) -> None:
        super().__init__()
        assert embed_dim % num_heads == 0, "embed_dim % num_heads != 0"
        self.num_heads = num_heads
        self.seq_length = seq_length
        self.embed_dim = embed_dim
        self.head_dim = self.embed_dim // self.num_heads
        self.wq = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.wk = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.wv = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.output_projection = nn.Linear(self.embed_dim, self.embed_dim)
        self.require_mask = mask
        print("All parameters are set for multihead attention")

    def forward(self, batched_input_data:Annotated[torch.Tensor, "batch of data from the input data"]) -> torch.Tensor:
        """Apply self-attention to a batch.

        Args:
            batched_input_data: Tensor of shape (batch, seq_len, embed_dim).
                ``seq_len`` may differ from the ``seq_length`` given at
                construction time.

        Returns:
            Tensor of shape (batch, seq_len, embed_dim).
        """
        batch = batched_input_data.size(0)
        # Use the real length of this batch rather than the constructor value,
        # so shorter or longer sequences do not break the reshape below.
        seq_length = batched_input_data.size(1)

        q = self.wq(batched_input_data)
        k = self.wk(batched_input_data)
        v = self.wv(batched_input_data)

        # Split embed_dim into (num_heads, head_dim) and move heads ahead of
        # the sequence axis: (batch, num_heads, seq_len, head_dim)
        q = q.reshape(batch, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.reshape(batch, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.reshape(batch, seq_length, self.num_heads, self.head_dim).transpose(1, 2)

        # Scaled dot-product scores: (batch, num_heads, seq_len, seq_len)
        score = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)

        if self.require_mask:
            # True strictly above the diagonal marks "future" positions.
            # Build the mask on the same device as the scores so the module
            # also works after being moved off the CPU.
            causal_mask = torch.triu(
                torch.ones(seq_length, seq_length, device=score.device),
                diagonal=1,
            ).bool()
            # (seq_len, seq_len) broadcasts over the batch and head axes.
            score = score.masked_fill(causal_mask, float("-inf"))

        attention_score = torch.softmax(score, dim=-1)
        attention = attention_score @ v

        # Concatenate the heads back into (batch, seq_len, embed_dim)
        attention = attention.transpose(1, 2)
        attention = attention.reshape(batch, seq_length, self.embed_dim)

        # Concatenation alone keeps the heads separate; the linear layer
        # mixes information across them.
        mha_output = self.output_projection(attention)
        return mha_output

In [3]:
import math
import torch
import torch.nn as nn

class LayerNormalization(nn.Module):
    """Layer normalisation over the last (embedding) dimension.

    Each token vector is shifted to zero mean and scaled to unit variance,
    then passed through a learnable per-feature scale ``alpha`` and shift
    ``beta``. ``alpha`` starts at one and ``beta`` at zero, so the layer
    initially outputs the plain normalised values.

    Args:
        embed_dim: Size of the last dimension of the input.
        eps: Small constant added to the variance for numerical stability.
    """

    def __init__(self, embed_dim:int, eps:float = 1e-9) -> None:
        super().__init__()
        self.alpha = nn.Parameter(torch.ones(embed_dim))   # learnable scale
        self.beta = nn.Parameter(torch.zeros(embed_dim))   # learnable shift
        self.eps = eps

    def forward(self, input_data:torch.Tensor) -> torch.Tensor:
        """Normalise ``input_data`` along its last dimension.

        Args:
            input_data: Tensor whose last dimension has size ``embed_dim``,
                e.g. (batch, seq_len, embed_dim).

        Returns:
            Tensor of the same shape as ``input_data``.
        """
        # Example: for an input of shape (2, 3, 6) the statistics are (2, 3, 1)
        mean = input_data.mean(dim=-1, keepdim=True)
        # Population (biased) variance: well-defined even when embed_dim == 1,
        # where the unbiased estimator would divide by zero.
        variance = input_data.var(dim=-1, unbiased=False, keepdim=True)

        # (2, 3, 6) - (2, 3, 1) broadcasts back to (2, 3, 6)
        normalized_input_data = (input_data - mean) / torch.sqrt(variance + self.eps)

        # alpha/beta let the model rescale or undo the normalisation per feature
        return self.alpha * normalized_input_data + self.beta
        

In [17]:
import math
import torch
import torch.nn as nn
from typing import List, Annotated

class FeedForward(nn.Module):
    """Position-wise feed-forward network: Linear -> ReLU -> Dropout -> Linear.

    Args:
        embed_dim: Size of the input and output features.
        hidden_dim: Size of the intermediate (expanded) representation.
        dropout: Dropout probability applied after the activation.
        bias: Whether both linear layers carry a bias term.
    """

    def __init__(self, embed_dim:int, hidden_dim:int, dropout:float = 0.1,  bias:bool = True):
        super().__init__()
        # Expansion layer: embed_dim -> hidden_dim
        self.w1 = nn.Linear(embed_dim, hidden_dim, bias=bias)
        self.dropout = nn.Dropout(dropout)
        self.relu = nn.ReLU()
        # Contraction layer: hidden_dim -> embed_dim
        self.w2 = nn.Linear(hidden_dim, embed_dim, bias=bias)

    def forward(self, input_data:torch.Tensor) -> torch.Tensor:
        """Apply the network to the last dimension of ``input_data``.

        Returns a tensor with the same shape as the input.
        """
        hidden = self.relu(self.w1(input_data))
        return self.w2(self.dropout(hidden))


### Suppose we have a dataset (rows = 100, seq_len = 10, embed_dim = 8) 

In [None]:

# Toy dataset: 100 sequences x 10 tokens x 8 embedding dims.
# Created as float32 so it can be fed straight into float-only layers such as
# nn.Linear instead of relying on a later operation to promote an int64 tensor.
input_data = torch.arange(100 * 10 * 8, dtype=torch.float32).reshape(100, 10, 8)
input_data.shape

torch.Size([100, 10, 8])

In [7]:
# Walkthrough configuration
batch, seq_len = 2, 10         # sequences per batch, tokens per sequence
embed_dim, num_heads = 8, 2    # embedding size, attention heads

# Take the first `batch` sequences of the dataset as the working batch
batch_data = input_data[:batch]
batch_data.shape

torch.Size([2, 10, 8])

### Add positional encoding to the batch

In [8]:
# Build the encoding table for the configured sequence length / embedding size
# and add it to every sequence in the batch.
pe = PositionalEncoding(max_len=seq_len, d_model=embed_dim)
positional_encoded_batch_data = pe(batch_data)
positional_encoded_batch_data.shape

torch.Size([2, 10, 8])

### Now that we have positionally encoded data, it's time to pass it through multi-head attention (MHA)

In [13]:
# Unmasked self-attention over the positionally encoded batch
mha = MultiHeadAttention(
    num_heads=num_heads,
    embed_dim=embed_dim,
    seq_length=seq_len,
)
mha_output_for_batched_data = mha(positional_encoded_batch_data)

All parameters are set for multihead attention


### Apply Add & Norm

In [16]:
# Add & Norm: residual connection around the attention sub-layer,
# followed by layer normalisation over the embedding dimension.
ln = LayerNormalization(embed_dim=embed_dim)
add_output_bach_data = positional_encoded_batch_data + mha_output_for_batched_data
layer_normalized_batch_data = ln(add_output_bach_data)

layer_normalized_batch_data.shape

torch.Size([2, 10, 8])

### Pass this through an FFN to capture non-linearity in the data

In [19]:
# Position-wise feed-forward network: expands each embed_dim-sized token to
# 1024 hidden units and projects it back. Dropout and bias use FeedForward's
# defaults (dropout=0.1, bias=True).
ffn = FeedForward(embed_dim=embed_dim, hidden_dim=1024)
# Input is the output of the first Add & Norm step; the shape is unchanged.
ffn_batched_data = ffn(layer_normalized_batch_data)

### Get the encoder output after the second Add & Norm

In [None]:
# Second Add & Norm: residual connection around the feed-forward sub-layer.
ffn_layered_batch_data = ffn_batched_data + layer_normalized_batch_data
# NOTE(review): this reuses the `ln` instance from the first Add & Norm cell,
# so both normalisations share the same alpha/beta parameters.
encoder_output = ln(ffn_layered_batch_data)

### This is how a single encoder block works. The paper stacks N identical blocks (N = 6 in the original paper; we use Nx = 8 here), so let's repeat the block using a loop

In [None]:
Nx = 8  # number of times the encoder block is applied

# Build every sub-layer from the shared configuration instead of repeating
# the literal sizes, so changing seq_len / embed_dim updates all of them.
pe = PositionalEncoding(seq_len, embed_dim)
mha = MultiHeadAttention(num_heads=num_heads, embed_dim=embed_dim, seq_length=seq_len)
ln1 = LayerNormalization(embed_dim=embed_dim)
ln2 = LayerNormalization(embed_dim=embed_dim)
ffn = FeedForward(embed_dim=embed_dim, hidden_dim=1024)

# This is for the first batch of the data.
# NOTE: the same mha / ln1 / ln2 / ffn instances are applied on every
# iteration, so all Nx passes share one set of weights.
positional_encoded_batch_data = pe(batch_data)
for _layer_index in range(Nx):
    # Attention sub-layer + residual + norm
    mha_output_for_batched_data = mha(positional_encoded_batch_data)
    layer_normalized_batch_data = ln1(mha_output_for_batched_data + positional_encoded_batch_data)
    # Feed-forward sub-layer + residual + norm
    ffn_batched_data = ffn(layer_normalized_batch_data)
    encoder_output = ln2(ffn_batched_data + layer_normalized_batch_data)
    # The output of this pass is the input of the next one
    positional_encoded_batch_data = encoder_output

### We made a mistake
- Each of the Nx encoder blocks should have its own MHA, FFN, and layer-norm parameters, but the loop above reuses the same modules (and therefore the same weights) on every iteration.

- To get separate parameters we could initialize and keep track of every module by hand (unnecessary code),

- or use `nn.ModuleList` to hold the modules, and hence the parameters, of each encoder block.

In [27]:
class EncoderBlock(nn.Module):
    """One Transformer encoder block with its own parameters.

    Computes ``x = ln1(x + dropout(mha(x)))`` followed by
    ``x = ln2(x + dropout(ffn(x)))`` (normalisation applied after each
    residual connection).

    Args:
        seq_len: Nominal sequence length, forwarded to the attention layer.
        embed_dim: Size of each token embedding.
        hidden_dim: Hidden size of the feed-forward network.
        num_heads: Number of attention heads.
        dropout: Dropout probability used on both sub-layer outputs and
            inside the feed-forward network.
        bias: Whether the attention query/key/value projections and the
            feed-forward linear layers carry a bias term.
    """

    def __init__(self, seq_len:int, embed_dim:int, hidden_dim:int,num_heads:int, dropout: float = 0.1, bias:bool = False):
        super().__init__()
        # `bias` is forwarded to both sub-layers so the flag is honoured
        # consistently across the block.
        self.mha = MultiHeadAttention(num_heads=num_heads, embed_dim=embed_dim, seq_length=seq_len, bias=bias)
        self.ln1 = LayerNormalization(embed_dim=embed_dim)
        self.ln2 = LayerNormalization(embed_dim=embed_dim)
        self.ffn = FeedForward(embed_dim=embed_dim, hidden_dim=hidden_dim, dropout=dropout, bias=bias)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x:torch.Tensor) -> torch.Tensor:
        """Run the block on ``x`` of shape (batch, seq_len, embed_dim).

        Returns a tensor of the same shape.
        """
        # Multi-head attention + residual + norm
        x_mha = self.mha(x)
        x = self.ln1(x + self.dropout(x_mha))

        # Feedforward + residual + norm
        x_ffn = self.ffn(x)
        x = self.ln2(x + self.dropout(x_ffn))
        return x



class Encoder(nn.Module):
    """Stack of ``Nx`` independently parameterised encoder blocks.

    Args:
        Nx: Number of encoder blocks to stack.
        embed_dim: Size of each token embedding.
        seq_len: Nominal sequence length passed to every block.
        num_heads: Number of attention heads per block.
        ff_hidden_dim: Hidden size of each block's feed-forward network.
        dropout: Dropout probability passed to every block.
    """

    def __init__(self, Nx, embed_dim, seq_len, num_heads, ff_hidden_dim, dropout=0.1):
        super().__init__()
        # nn.ModuleList registers each block, so every block's parameters
        # are tracked by the parent module.
        self.layers = nn.ModuleList()
        for _ in range(Nx):
            block = EncoderBlock(seq_len, embed_dim, ff_hidden_dim, num_heads, dropout)
            self.layers.append(block)

    def forward(self, x):
        """Feed ``x`` through the blocks in order and return the final output."""
        hidden = x
        for block in self.layers:
            hidden = block(hidden)
        return hidden



### Test with above input

In [30]:
# Rebuild the toy dataset as float32 (100 sequences x 10 tokens x 8 dims)
# together with the configuration used to size the encoder.
batch, seq_len, embed_dim, num_heads = 2, 10, 8, 2
input_data = torch.arange(100 * 10 * 8, dtype=torch.float32).reshape(100, 10, 8)

# First `batch` sequences form the test batch
batch_data = input_data[:batch]
batch_data.shape

torch.Size([2, 10, 8])

In [28]:
# Eight stacked encoder blocks, each with its own parameters. All sizes come
# from the configuration cell so they cannot drift from the input data.
encoder = Encoder(Nx=8,
                  embed_dim=embed_dim,
                  seq_len=seq_len,
                  ff_hidden_dim=1024,
                  num_heads=num_heads)

All parameters are set for multihead attention
All parameters are set for multihead attention
All parameters are set for multihead attention
All parameters are set for multihead attention
All parameters are set for multihead attention
All parameters are set for multihead attention
All parameters are set for multihead attention
All parameters are set for multihead attention


In [32]:
# Run the test batch through the stacked encoder blocks. Each block returns a
# tensor of the same shape as its input, so the output shape matches batch_data.
encoder_output = encoder(batch_data)

### And that's it: we have successfully built the encoder