In [3]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
from torch.utils.tensorboard import SummaryWriter

In [3]:
from datasets import load_dataset
from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.trainers import WordLevelTrainer
from tokenizers.pre_tokenizers import Whitespace

In [4]:
import warnings
from pathlib import Path
from typing import Any
from tqdm import tqdm
import math

In [None]:
class InputEmbeddings(nn.Module):
    """
    InputEmbeddings class:
    - Looks up a learned vector of length `d_model` for every token index.
    - Multiplies the looked-up vectors by sqrt(d_model) (the scaling used in the original Transformer).
    - Parameters:
        - d_model (int): length of each embedding vector.
        - num_embeddings (int): vocabulary size, i.e. number of rows in the embedding table.
    - Usage: its output is the token representation fed to the Transformer encoder/decoder.

    Inputs:
        - x (Tensor): integer token indices, typically shape (batch_size, seq_length).
    Outputs:
        - Tensor: shape x.shape + (d_model,), the scaled embeddings.
    """
    def __init__(self, d_model: int, num_embeddings: int) -> None:
        super().__init__()
        self.d_model = d_model                # width of every embedding vector
        self.num_embeddings = num_embeddings  # vocabulary size
        self.embedding = nn.Embedding(num_embeddings, d_model)

    def forward(self, x):
        scale = math.sqrt(self.d_model)
        vectors = self.embedding(x)
        return vectors * scale

In [None]:
class PositionalEncoding(nn.Module):
    """
    PositionalEncoding class:
    - Adds positional information to token embeddings so the Transformer can capture the order of tokens.
    - Uses the fixed sinusoidal table: even dimensions get sin(pos * w_i), odd dimensions get
      cos(pos * w_i), with w_i = 1 / 10000^(2i / d_model).
    - Works for both even and odd `d_model` (for odd sizes the last column is a sine).
    - Parameters:
        - d_model (int): dimension of embedding vectors.
        - seq_len (int): maximum sequence length the precomputed table covers.
        - dropout (float): dropout rate applied after adding positional encoding.
    - Usage: added to token embeddings before feeding into Transformer layers.

    Inputs:
        - x (Tensor): shape (batch_size, seq_length, d_model) with seq_length <= seq_len.
    Outputs:
        - x (Tensor): shape (batch_size, seq_length, d_model) with positional information added, dropout applied.
    """
    def __init__(self, d_model: int, seq_len: int, dropout: float) -> None:
        super().__init__()
        self.d_model = d_model
        self.seq_len = seq_len
        self.dropout = nn.Dropout(dropout)

        pe = torch.zeros(seq_len, d_model)
        position = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1)  # (seq_len, 1)

        # Frequencies 1 / 10000^(i / d_model) for the even indices i, computed in log space.
        i = torch.arange(0, d_model, 2, dtype=torch.float)
        div_term = torch.exp(i * (-math.log(10000)) / d_model)  # length ceil(d_model / 2)

        pe[:, 0::2] = torch.sin(position * div_term)
        # There are only d_model // 2 odd columns; for odd d_model that is one fewer
        # than the number of frequencies, so drop the extra one to keep shapes aligned.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])

        # Leading batch axis so the table broadcasts over (batch, seq, d_model) inputs.
        self.register_buffer("pe", pe.unsqueeze(0))

    def forward(self, x):
        # Use only as many positions as the input sequence actually has.
        x = x + self.pe[:, : x.shape[1], :]
        return self.dropout(x)

In [None]:
class LayerNormalization(nn.Module):
    """
    LayerNormalization class:
    - Normalizes inputs across the last dimension to zero mean and unit variance:
      y = alpha * (x - mean) / sqrt(var + eps) + bias,
      where var is the biased (population) variance, matching torch.nn.LayerNorm.
    - Learnable parameters (alpha, bias) allow the model to scale and shift the normalized values.
    - Helps stabilize and accelerate training of deep networks, especially Transformers.
    - Parameters:
        - eps (float): added to the variance before the square root to avoid division by zero (default 1e-6).
        - features (int): size of alpha and bias. The default 1 shares a single scale and shift across
          all features; pass the size of the last dimension for per-feature parameters.

    Inputs:
        - x (Tensor): shape (..., features), can be any shape with the last dim as features.
    Outputs:
        - normalized_x (Tensor): same shape as input, normalized along the last dimension.
    """
    def __init__(self, eps: float = 1e-6, features: int = 1) -> None:
        super().__init__()
        self.eps = eps

        self.alpha = nn.Parameter(torch.ones(features))
        self.bias = nn.Parameter(torch.zeros(features))

    def forward(self, x):
        mean = x.mean(dim=-1, keepdim=True)
        centered = x - mean
        # Biased variance (divide by N): gives exactly unit variance after normalization and is
        # defined for a single feature, where the unbiased std (divide by N - 1) is NaN.
        var = centered.pow(2).mean(dim=-1, keepdim=True)
        return self.alpha * centered * torch.rsqrt(var + self.eps) + self.bias

In [None]:
from abc import ABC, abstractmethod

class SubLayer(nn.Module, ABC):
    """
    SubLayer abstract class:
    - Common base for the Transformer sub-layers in this file (attention, feed-forward).
    - Declares `forward` abstract, so SubLayer itself cannot be instantiated and every
      concrete subclass must provide its own `forward`.
    - The signature accepts extra keyword arguments so subclasses can require more inputs.

    Inputs:
        - x (Tensor): shape (batch_size, seq_length, d_model), input embeddings or outputs from previous layer.
        - **kwargs: optional additional arguments needed by specific sub-layers.
    Outputs:
        - Tensor: processed output, same shape as input in most cases.
    """
    @abstractmethod
    def forward(self, x, **kwargs):
        # Only reachable through an explicit super().forward(...) call from a subclass.
        message = "Subclasses must implement the forward method"
        raise NotImplementedError(message)

In [None]:
class FeedForward(SubLayer):
    """
    FeedForward sub-layer for Transformers:
    - Position-wise MLP: Linear(d_model -> d_ff), ReLU, dropout, Linear(d_ff -> d_model).
    - The linear layers act on the last dimension only, so every position is transformed independently.

    Parameters:
        - d_model (int): input and output size of the last dimension.
        - d_ff (int): hidden size between the two linear layers (usually larger than d_model).
        - dropout (float): dropout probability applied to the ReLU output.

    Inputs:
        - x (Tensor): shape (batch_size, seq_length, d_model), input from previous layer.
    Outputs:
        - Tensor: shape (batch_size, seq_length, d_model), transformed output.
    """
    def __init__(self, d_model: int, d_ff: int, dropout: float) -> None:
        super().__init__()
        self.linear_1 = nn.Linear(d_model, d_ff)   # expand to the hidden size
        self.dropout = nn.Dropout(dropout)
        self.linear_2 = nn.Linear(d_ff, d_model)   # project back to d_model

    def forward(self, x):
        hidden = self.linear_1(x)
        hidden = torch.relu(hidden)
        hidden = self.dropout(hidden)
        return self.linear_2(hidden)

In [None]:
class ResidualConnection(nn.Module):
    """
    ResidualConnection module for Transformers:
    - Wraps one sub-layer call in a skip connection plus layer normalization.
    - Two placements of the normalization are supported:
        * Pre-Norm (default):  x + dropout(sub_layer(norm(x)))
        * Post-Norm:           norm(x + dropout(sub_layer(x)))
    - Dropout is always applied to the sub-layer output before it is added to x.

    Parameters:
        - dropout (float): dropout probability applied to the sub-layer output.
        - pre_norm (bool): True selects Pre-Norm, False selects Post-Norm.

    Inputs:
        - x (Tensor): shape (batch_size, seq_length, d_model), input to the residual block.
        - sub_layer: anything callable with a single tensor argument, e.g. a FeedForward module
          or a lambda/closure around an attention call (the encoder/decoder blocks pass lambdas).

    Outputs:
        - Tensor: same shape as x.
    """
    def __init__(self, dropout: float, pre_norm: bool = True) -> None:
        super().__init__()
        self.pre_norm = pre_norm
        self.dropout = nn.Dropout(dropout)
        self.norm = LayerNormalization()

    def forward(self, x, sub_layer: SubLayer):
        if not self.pre_norm:
            # Post-Norm: add first, then normalize the sum.
            summed = x + self.dropout(sub_layer(x))
            return self.norm(summed)
        # Pre-Norm: the sub-layer sees normalized input; the skip path keeps raw x.
        update = sub_layer(self.norm(x))
        return x + self.dropout(update)


In [None]:
class MultiHeadAttention(SubLayer):
    """
    Multi-Head Attention module for Transformers:
    - Projects queries, keys and values, splits them into `h` heads of size d_k = d_model // h,
      runs scaled dot-product attention per head, then merges the heads and applies an output projection.
    - Supports optional attention masking and dropout on the attention weights.
    - Queries may come from a different sequence (and length) than keys/values, as in cross-attention.

    Parameters:
        - d_model (int): dimension of input embeddings; must be divisible by h.
        - h (int): number of attention heads.
        - dropout (float): dropout probability applied to the attention weights.

    Inputs:
        - q (Tensor): shape (batch_size, q_len, d_model), query embeddings.
        - k (Tensor): shape (batch_size, kv_len, d_model), key embeddings.
        - v (Tensor): shape (batch_size, kv_len, d_model), value embeddings.
        - mask (Tensor, optional): broadcastable to (batch_size, h, q_len, kv_len); positions where
          the mask equals 0 receive (effectively) zero attention weight.

    Outputs:
        - Tensor: shape (batch_size, q_len, d_model), output after multi-head attention.
    Side effects:
        - The attention weights of the last call, shape (batch_size, h, q_len, kv_len), are stored
          on `self.attention_scores`.

    Raises:
        - ValueError: if d_model is not divisible by h.
    """
    def __init__(self, d_model: int, h: int, dropout: float) -> None:
        super().__init__()
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if d_model % h != 0:
            raise ValueError(f"d_model ({d_model}) must be divisible by h ({h})")
        self.d_model = d_model
        self.h = h
        self.d_k = d_model // h

        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(dropout)

    @staticmethod
    def attention(q, k, v, mask=None, dropout=None):
        """
        Scaled dot-product attention over the last two dimensions.

        Returns a tuple (weights @ v, weights), where weights is softmax(q k^T / sqrt(d_k)).
        """
        d_k = q.shape[-1]
        attention_scores = (q @ k.transpose(-2, -1)) / math.sqrt(d_k)
        if mask is not None:
            # Most negative finite value of the score dtype: unlike a hard-coded -1e9 it does not
            # overflow in float16/bfloat16, and it still drives masked weights to zero after softmax.
            fill_value = torch.finfo(attention_scores.dtype).min
            attention_scores = attention_scores.masked_fill(mask == 0, fill_value)
        attention_scores = attention_scores.softmax(dim=-1)
        if dropout is not None:
            attention_scores = dropout(attention_scores)

        return attention_scores @ v, attention_scores

    def _split_heads(self, t):
        """Reshape (batch, seq, d_model) into (batch, h, seq, d_k)."""
        return t.view(t.shape[0], t.shape[1], self.h, self.d_k).transpose(1, 2)

    def forward(self, q, k, v, mask=None):
        batch_size, q_len = q.shape[0], q.shape[1]

        query = self._split_heads(self.w_q(q))
        key = self._split_heads(self.w_k(k))
        value = self._split_heads(self.w_v(v))

        x, self.attention_scores = self.attention(query, key, value, mask, self.dropout)
        # Merge heads back: (batch, h, q_len, d_k) -> (batch, q_len, d_model). The sequence
        # length must be the query's; the value tensor here is (batch, h, kv_len, d_k), so its
        # dim 1 is the head count, not a sequence length.
        x = x.transpose(1, 2).contiguous().view(batch_size, q_len, self.d_model)
        return self.w_o(x)

In [None]:
class EncoderBlock(nn.Module):
    """
    EncoderBlock module for Transformers:
    - Runs self-attention, then a position-wise feed-forward network.
    - Each of the two steps is wrapped in its own ResidualConnection (normalization, dropout, skip).

    Parameters:
        - self_attention_block (MultiHeadAttention): attention layer, called with x as query, key and value.
        - feedforward_block (FeedForward): position-wise feed-forward layer.
        - dropout (float): dropout rate given to both residual connections (default 0.1).

    Inputs:
        - x (Tensor): shape (batch_size, seq_length, d_model), input embeddings.
        - source_mask (Tensor or None): passed unchanged to the attention layer; positions where it
          is 0 are masked out. None disables masking.

    Outputs:
        - Tensor: shape (batch_size, seq_length, d_model).
    """
    def __init__(self, self_attention_block: MultiHeadAttention, feedforward_block: FeedForward, dropout=0.1):
        super().__init__()
        self.self_attention_block = self_attention_block
        self.feed_forward_block = feedforward_block
        # Index 0 wraps the attention step, index 1 wraps the feed-forward step.
        self.residual_connections = nn.ModuleList(
            ResidualConnection(dropout) for _ in range(2)
        )

    def forward(self, x, source_mask):
        attention_residual, feed_forward_residual = self.residual_connections

        def self_attend(tokens):
            # Self-attention: one sequence supplies queries, keys and values.
            return self.self_attention_block(tokens, tokens, tokens, source_mask)

        attended = attention_residual(x, self_attend)
        return feed_forward_residual(attended, self.feed_forward_block)

In [None]:
class Encoder(nn.Module):
    """
    Encoder module for Transformers:
    - Applies a stack of encoder blocks in order, feeding each block's output to the next.
    - Finishes with one LayerNormalization over the output of the last block.

    Parameters:
        - layers (nn.ModuleList): the blocks to apply; each is called as block(x, mask).

    Inputs:
        - x (Tensor): shape (batch_size, seq_length, d_model), input embeddings.
        - mask (Tensor or None): handed unchanged to every block (e.g. a padding mask for self-attention).

    Outputs:
        - Tensor: shape (batch_size, seq_length, d_model), normalized output of the final block.
    """
    def __init__(self, layers):
        super().__init__()
        self.layers = layers
        self.norm = LayerNormalization()

    def forward(self, x, mask):
        hidden = x
        for encoder_block in self.layers:
            hidden = encoder_block(hidden, mask)
        # When the blocks use pre-norm residuals (the ResidualConnection default), their output
        # is un-normalized, so the stack ends with this final normalization.
        return self.norm(hidden)

In [None]:
class DecoderBlock(nn.Module):
    """
    Single Decoder block for Transformers:
    - Three steps in order: masked self-attention, cross-attention over the encoder output,
      and a position-wise feed-forward network.
    - Each step is wrapped in its own ResidualConnection (normalization, dropout, skip).

    Parameters:
        - self_attention_block (MultiHeadAttention): attention over the target sequence itself.
        - cross_attention_block (MultiHeadAttention): attention with target queries and encoder keys/values.
        - feed_forward_block (FeedForward): position-wise feed-forward network.
        - dropout (float): dropout rate given to all three residual connections.

    Inputs:
        - x (Tensor): shape (batch_size, target_len, d_model), target embeddings or previous block output.
        - encoder_output (Tensor): shape (batch_size, source_len, d_model), output of the encoder.
        - source_mask (Tensor or None): mask for cross-attention, broadcastable to
          (batch_size, 1, 1, source_len), e.g. to hide source padding.
        - target_mask (Tensor or None): mask for self-attention, broadcastable to
          (batch_size, 1, target_len, target_len), e.g. to hide future tokens.

    Outputs:
        - Tensor: shape (batch_size, target_len, d_model).
    """
    def __init__(
        self,
        self_attention_block: MultiHeadAttention,
        cross_attention_block: MultiHeadAttention,
        feed_forward_block: FeedForward,
        dropout: float,
    ):
        super().__init__()
        self.self_attention_block = self_attention_block
        self.cross_attention_block = cross_attention_block
        self.feed_forward_block = feed_forward_block
        # Indices: 0 = self-attention, 1 = cross-attention, 2 = feed-forward.
        self.residual_connections = nn.ModuleList(
            ResidualConnection(dropout) for _ in range(3)
        )

    def forward(self, x, encoder_output, source_mask, target_mask):
        self_residual, cross_residual, feed_forward_residual = self.residual_connections

        def attend_to_target(tokens):
            return self.self_attention_block(tokens, tokens, tokens, target_mask)

        def attend_to_source(tokens):
            # Queries come from the decoder; keys and values come from the encoder.
            return self.cross_attention_block(tokens, encoder_output, encoder_output, source_mask)

        hidden = self_residual(x, attend_to_target)
        hidden = cross_residual(hidden, attend_to_source)
        return feed_forward_residual(hidden, self.feed_forward_block)

In [None]:
class Decoder(nn.Module):
    """
    Transformer Decoder:
    - Applies a stack of decoder blocks in order; every block receives the same encoder output and masks.
    - Finishes with one LayerNormalization over the output of the last block.

    Parameters:
        - layers (nn.ModuleList): blocks called as block(x, encoder_output, source_mask, target_mask).

    Inputs:
        - x (Tensor): (batch_size, target_len, d_model), target embeddings.
        - encoder_output (Tensor): (batch_size, source_len, d_model), encoder outputs.
        - source_mask (Tensor or None): forwarded to every block (e.g. source padding mask).
        - target_mask (Tensor or None): forwarded to every block (e.g. causal mask).

    Output:
        - Tensor: (batch_size, target_len, d_model), normalized output of the final block.
    """
    def __init__(self, layers: nn.ModuleList):
        super().__init__()
        self.layers = layers
        self.norm = LayerNormalization()

    def forward(self, x, encoder_output, source_mask, target_mask):
        hidden = x
        for decoder_block in self.layers:
            hidden = decoder_block(hidden, encoder_output, source_mask, target_mask)
        return self.norm(hidden)