In [1]:
import torch
import torch.nn as nn
import math
from torch import Tensor
from typing import Optional, Tuple

class PositionalEncoding(nn.Module):
    """
    Injects positional information into the input embeddings.
    This helps the model understand the order of tokens in a sequence,
    as the self-attention mechanism itself is permutation-invariant.
    """
    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        """
        Args:
            d_model (int): The dimensionality of the embeddings (must be even).
            dropout (float): Dropout probability.
            max_len (int): The maximum possible length of the input sequences.
        """
        super().__init__()
        if d_model % 2 != 0:
            raise ValueError(f"d_model must be even, got {d_model}")

        self.dropout = nn.Dropout(p=dropout)

        # Create a positional encoding matrix of shape (max_len, d_model)
        position = torch.arange(max_len).unsqueeze(1) # Shape: (max_len, 1)
        # Calculate the division term for the sinusoidal functions
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)) # Shape: (d_model / 2)
        pe = torch.zeros(max_len, d_model)
        # Apply sine to even indices
        pe[:, 0::2] = torch.sin(position * div_term)
        # Apply cosine to odd indices
        pe[:, 1::2] = torch.cos(position * div_term)
        # Add a batch dimension and register as a buffer (not a model parameter)
        pe = pe.unsqueeze(0).transpose(0, 1) # Shape: (max_len, 1, d_model)
        self.register_buffer('pe', pe)

    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x (Tensor): Input embeddings of shape (seq_len, batch_size, d_model).

        Returns:
            Tensor: Embeddings with added positional information, same shape as input.
        """
        # Add positional encoding to the input embeddings
        # x.size(0) is the sequence length of the current batch
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)

class MultiHeadAttention(nn.Module):
    """
    Implements the Multi-Head Attention mechanism as described in
    "Attention Is All You Need". Allows the model to jointly attend
    to information from different representation subspaces at different positions.
    """
    def __init__(self, d_model: int, nhead: int, dropout: float = 0.1):
        """
        Args:
            d_model (int): Total dimension of the model.
            nhead (int): Number of parallel attention heads. d_model must be divisible by nhead.
            dropout (float): Dropout probability.
        """
        super().__init__()
        if d_model % nhead != 0:
            raise ValueError(f"d_model ({d_model}) must be divisible by nhead ({nhead})")

        self.d_model = d_model
        self.nhead = nhead
        self.d_k = d_model // nhead # Dimension of each head

        # Linear layers for Query, Key, Value, and the final output projection
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)

        self.dropout = nn.Dropout(dropout)
        self.scale = math.sqrt(self.d_k)

    def _scaled_dot_product_attention(self, Q: Tensor, K: Tensor, V: Tensor, mask: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
        """
        Computes the scaled dot-product attention.

        Args:
            Q (Tensor): Queries, shape (batch_size, nhead, seq_len_q, d_k)
            K (Tensor): Keys, shape (batch_size, nhead, seq_len_k, d_k)
            V (Tensor): Values, shape (batch_size, nhead, seq_len_v, d_k) (seq_len_k == seq_len_v)
            mask (Optional[Tensor]): Mask to prevent attention to certain positions.
                                     Shape (batch_size, 1, seq_len_q, seq_len_k) or (seq_len_q, seq_len_k)

        Returns:
            Tuple[Tensor, Tensor]: Output tensor and attention weights.
                                    Output shape: (batch_size, nhead, seq_len_q, d_k)
                                    Attention weights shape: (batch_size, nhead, seq_len_q, seq_len_k)
        """
        # Calculate attention scores (QK^T / sqrt(d_k))
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale # Shape: (batch_size, nhead, seq_len_q, seq_len_k)

        # Apply mask if provided (e.g., for padding or future tokens)
        if mask is not None:
            # Ensure mask has compatible dimensions for broadcasting
            if mask.dim() == 2: # Sequence mask (tgt_len, src_len) -> (1, 1, tgt_len, src_len)
                 mask = mask.unsqueeze(0).unsqueeze(0)
            elif mask.dim() == 3: # Padding mask (batch_size, 1, src_len) -> (batch_size, 1, 1, src_len)
                 mask = mask.unsqueeze(2)
            # Apply the mask by setting masked positions to a very small number (-inf)
            attn_scores = attn_scores.masked_fill(mask == 0, -1e9) # Use -1e9 instead of float('-inf') for stability

        # Apply softmax to get attention weights
        attn_weights = torch.softmax(attn_scores, dim=-1) # Shape: (batch_size, nhead, seq_len_q, seq_len_k)
        attn_weights = self.dropout(attn_weights)

        # Multiply weights by values (Weighted sum)
        output = torch.matmul(attn_weights, V) # Shape: (batch_size, nhead, seq_len_q, d_k)

        return output, attn_weights

    def forward(self, query: Tensor, key: Tensor, value: Tensor, mask: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
        """
        Forward pass for Multi-Head Attention.

        Args:
            query (Tensor): Query tensor, shape (seq_len_q, batch_size, d_model).
            key (Tensor): Key tensor, shape (seq_len_k, batch_size, d_model).
            value (Tensor): Value tensor, shape (seq_len_v, batch_size, d_model) (seq_len_k == seq_len_v).
            mask (Optional[Tensor]): Mask, shape depends on use case (padding or look-ahead).
                                     Typical shapes: (tgt_seq_len, src_seq_len) for decoder self-attention mask,
                                     (batch_size, src_seq_len) for encoder padding mask.

        Returns:
            Tuple[Tensor, Tensor]: The attention output and the attention weights.
                                    Output shape: (seq_len_q, batch_size, d_model)
                                    Attention weights shape: (batch_size, nhead, seq_len_q, seq_len_k)
        """
        seq_len_q, batch_size, _ = query.size()
        seq_len_k, _, _ = key.size()
        seq_len_v, _, _ = value.size() # Should be same as seq_len_k

        # 1. Linear projections
        Q = self.W_q(query) # Shape: (seq_len_q, batch_size, d_model)
        K = self.W_k(key)   # Shape: (seq_len_k, batch_size, d_model)
        V = self.W_v(value) # Shape: (seq_len_v, batch_size, d_model)

        # 2. Reshape for multi-head attention
        # (seq_len, batch_size, d_model) -> (seq_len, batch_size, nhead, d_k) -> (batch_size, nhead, seq_len, d_k)
        Q = Q.view(seq_len_q, batch_size, self.nhead, self.d_k).permute(1, 2, 0, 3)
        K = K.view(seq_len_k, batch_size, self.nhead, self.d_k).permute(1, 2, 0, 3)
        V = V.view(seq_len_v, batch_size, self.nhead, self.d_k).permute(1, 2, 0, 3)

        # 3. Scaled dot-product attention
        # attn_output shape: (batch_size, nhead, seq_len_q, d_k)
        # attn_weights shape: (batch_size, nhead, seq_len_q, seq_len_k)
        attn_output, attn_weights = self._scaled_dot_product_attention(Q, K, V, mask=mask)

        # 4. Concatenate heads and project
        # (batch_size, nhead, seq_len_q, d_k) -> (seq_len_q, batch_size, nhead, d_k) -> (seq_len_q, batch_size, d_model)
        attn_output = attn_output.permute(2, 0, 1, 3).contiguous().view(seq_len_q, batch_size, self.d_model)

        # 5. Final linear projection
        output = self.W_o(attn_output) # Shape: (seq_len_q, batch_size, d_model)

        return output, attn_weights

class PositionwiseFeedForward(nn.Module):
    """
    Implements the two-layer feed-forward network applied to each position separately and identically.
    FFN(x) = max(0, xW_1 + b_1)W_2 + b_2
    """
    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        """
        Args:
            d_model (int): Input and output dimension.
            d_ff (int): Inner dimension of the feed-forward layer.
            dropout (float): Dropout probability.
        """
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.activation = nn.ReLU() # Or GELU

    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x (Tensor): Input tensor, shape (seq_len, batch_size, d_model).

        Returns:
            Tensor: Output tensor, same shape as input.
        """
        x = self.linear1(x)
        x = self.activation(x)
        x = self.dropout(x)
        x = self.linear2(x)
        return x

class EncoderLayer(nn.Module):
    """
    Represents one layer of the Transformer Encoder.
    Consists of a self-attention mechanism followed by a position-wise feed-forward network.
    Includes residual connections and layer normalization.
    """
    def __init__(self, d_model: int, nhead: int, d_ff: int, dropout: float = 0.1):
        """
        Args:
            d_model (int): The number of expected features in the input.
            nhead (int): The number of heads in the multiheadattention models.
            d_ff (int): The dimension of the feedforward network model.
            dropout (float): The dropout value.
        """
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, nhead, dropout=dropout)
        self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout=dropout)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, src: Tensor, src_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None) -> Tensor:
        """
        Forward pass for the Encoder Layer.

        Args:
            src (Tensor): Input sequence to the encoder layer, shape (src_seq_len, batch_size, d_model).
            src_mask (Optional[Tensor]): Mask for the source sequence (rarely used in encoder).
            src_key_padding_mask (Optional[Tensor]): Mask for padding tokens in the source sequence,
                                                      shape (batch_size, src_seq_len).

        Returns:
            Tensor: Output tensor from the encoder layer, shape (src_seq_len, batch_size, d_model).
        """
        # 1. Multi-Head Self-Attention + Add & Norm
        # The mask passed here should combine src_mask and src_key_padding_mask if needed.
        # For self-attention, query, key, and value are all 'src'.
        # Note: PyTorch's MultiheadAttention expects mask shape (N, S) for key padding
        #       and (L, S) or (N*num_heads, L, S) for attn_mask.
        #       Our custom implementation handles padding mask broadcasting.
        attn_output, _ = self.self_attn(src, src, src, mask=src_key_padding_mask) # Pass padding mask
        src = src + self.dropout1(attn_output) # Residual connection
        src = self.norm1(src) # Layer normalization

        # 2. Feed Forward + Add & Norm
        ff_output = self.feed_forward(src)
        src = src + self.dropout2(ff_output) # Residual connection
        src = self.norm2(src) # Layer normalization

        return src

class DecoderLayer(nn.Module):
    """
    Represents one layer of the Transformer Decoder.
    Consists of self-attention, encoder-decoder attention, and a feed-forward network.
    Includes residual connections and layer normalization.
    """
    def __init__(self, d_model: int, nhead: int, d_ff: int, dropout: float = 0.1):
        """
        Args:
            d_model (int): The number of expected features in the input.
            nhead (int): The number of heads in the multiheadattention models.
            d_ff (int): The dimension of the feedforward network model.
            dropout (float): The dropout value.
        """
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, nhead, dropout=dropout)
        self.multihead_attn = MultiHeadAttention(d_model, nhead, dropout=dropout) # Encoder-Decoder Attention
        self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout=dropout)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

    def forward(self, tgt: Tensor, memory: Tensor,
                tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None) -> Tensor:
        """
        Forward pass for the Decoder Layer.

        Args:
            tgt (Tensor): Target sequence input to the decoder layer, shape (tgt_seq_len, batch_size, d_model).
            memory (Tensor): Output from the encoder (memory), shape (src_seq_len, batch_size, d_model).
            tgt_mask (Optional[Tensor]): Mask to prevent attending to future tokens in the target sequence,
                                         shape (tgt_seq_len, tgt_seq_len).
            memory_mask (Optional[Tensor]): Mask for the encoder output (rarely used).
            tgt_key_padding_mask (Optional[Tensor]): Mask for padding tokens in the target sequence,
                                                      shape (batch_size, tgt_seq_len).
            memory_key_padding_mask (Optional[Tensor]): Mask for padding tokens in the source sequence
                                                         (passed from encoder), shape (batch_size, src_seq_len).

        Returns:
            Tensor: Output tensor from the decoder layer, shape (tgt_seq_len, batch_size, d_model).
        """
        # 1. Masked Multi-Head Self-Attention (on target sequence) + Add & Norm
        # Combine tgt_mask (look-ahead) and tgt_key_padding_mask
        # Our attention implementation expects the combined mask.
        # If using PyTorch's MultiheadAttention: pass tgt_mask to attn_mask, tgt_key_padding_mask to key_padding_mask
        self_attn_output, _ = self.self_attn(tgt, tgt, tgt, mask=tgt_mask) # Pass look-ahead mask
        tgt = tgt + self.dropout1(self_attn_output) # Residual connection
        tgt = self.norm1(tgt) # Layer normalization

        # 2. Multi-Head Encoder-Decoder Attention + Add & Norm
        # Query=tgt, Key=memory, Value=memory
        # Mask comes from the source padding (memory_key_padding_mask)
        enc_dec_attn_output, _ = self.multihead_attn(tgt, memory, memory, mask=memory_key_padding_mask) # Pass src padding mask
        tgt = tgt + self.dropout2(enc_dec_attn_output) # Residual connection
        tgt = self.norm2(tgt) # Layer normalization

        # 3. Feed Forward + Add & Norm
        ff_output = self.feed_forward(tgt)
        tgt = tgt + self.dropout3(ff_output) # Residual connection
        tgt = self.norm3(tgt) # Layer normalization

        return tgt

class Encoder(nn.Module):
    """
    The Transformer Encoder stack.
    """
    def __init__(self, encoder_layer: nn.Module, num_layers: int, norm: Optional[nn.Module] = None):
        """
        Args:
            encoder_layer (nn.Module): An instance of the EncoderLayer.
            num_layers (int): The number of sub-encoder-layers in the encoder.
            norm (Optional[nn.Module]): An optional layer normalization module.
        """
        super().__init__()
        self.layers = nn.ModuleList([encoder_layer for _ in range(num_layers)])
        self.num_layers = num_layers
        self.norm = norm

    def forward(self, src: Tensor, mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None) -> Tensor:
        """
        Pass the input through the encoder layers.

        Args:
            src (Tensor): The sequence to the encoder, shape (src_seq_len, batch_size, d_model).
            mask (Optional[Tensor]): The mask for the src sequence (rarely used).
            src_key_padding_mask (Optional[Tensor]): The mask for the src keys per batch,
                                                      shape (batch_size, src_seq_len).

        Returns:
            Tensor: Output tensor from the encoder, shape (src_seq_len, batch_size, d_model).
        """
        output = src
        for mod in self.layers:
            output = mod(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask)

        if self.norm is not None:
            output = self.norm(output)

        return output

class Decoder(nn.Module):
    """
    The Transformer Decoder stack.
    """
    def __init__(self, decoder_layer: nn.Module, num_layers: int, norm: Optional[nn.Module] = None):
        """
        Args:
            decoder_layer (nn.Module): An instance of the DecoderLayer.
            num_layers (int): The number of sub-decoder-layers in the decoder.
            norm (Optional[nn.Module]): An optional layer normalization module.
        """
        super().__init__()
        self.layers = nn.ModuleList([decoder_layer for _ in range(num_layers)])
        self.num_layers = num_layers
        self.norm = norm

    def forward(self, tgt: Tensor, memory: Tensor,
                tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None) -> Tensor:
        """
        Pass the inputs (and mask) through the decoder layer stack.

        Args:
            tgt (Tensor): The sequence to the decoder, shape (tgt_seq_len, batch_size, d_model).
            memory (Tensor): The sequence from the last layer of the encoder, shape (src_seq_len, batch_size, d_model).
            tgt_mask (Optional[Tensor]): The mask for the tgt sequence, shape (tgt_seq_len, tgt_seq_len).
            memory_mask (Optional[Tensor]): The mask for the memory sequence (rarely used).
            tgt_key_padding_mask (Optional[Tensor]): The mask for the tgt keys per batch,
                                                      shape (batch_size, tgt_seq_len).
            memory_key_padding_mask (Optional[Tensor]): The mask for the memory keys per batch,
                                                         shape (batch_size, src_seq_len).

        Returns:
            Tensor: Output tensor from the decoder, shape (tgt_seq_len, batch_size, d_model).
        """
        output = tgt
        for mod in self.layers:
            output = mod(output, memory,
                         tgt_mask=tgt_mask,
                         memory_mask=memory_mask,
                         tgt_key_padding_mask=tgt_key_padding_mask,
                         memory_key_padding_mask=memory_key_padding_mask)

        if self.norm is not None:
            output = self.norm(output)

        return output

class Transformer(nn.Module):
    """
    The main Transformer model, combining Encoder and Decoder.
    Based on the paper "Attention Is All You Need".
    """
    def __init__(self, d_model: int = 512, nhead: int = 8, num_encoder_layers: int = 6,
                 num_decoder_layers: int = 6, d_ff: int = 2048, dropout: float = 0.1,
                 activation: str = "relu", # activation is used in PositionwiseFeedForward, not directly here
                 src_vocab_size: int = 10000, tgt_vocab_size: int = 10000,
                 max_seq_len: int = 512):
        """
        Args:
            d_model (int): The number of expected features in the encoder/decoder inputs.
            nhead (int): The number of heads in the multiheadattention models.
            num_encoder_layers (int): The number of sub-encoder-layers in the encoder.
            num_decoder_layers (int): The number of sub-decoder-layers in the decoder.
            d_ff (int): The dimension of the feedforward network model.
            dropout (float): The dropout value.
            activation (str): The activation function of encoder/decoder intermediate layer, relu or gelu. (Not directly used here, passed to layers)
            src_vocab_size (int): Size of the source vocabulary.
            tgt_vocab_size (int): Size of the target vocabulary.
            max_seq_len (int): Maximum sequence length for positional encoding.
        """
        super().__init__()

        # --- Embeddings and Positional Encoding ---
        self.src_embedding = nn.Embedding(src_vocab_size, d_model)
        self.tgt_embedding = nn.Embedding(tgt_vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model, dropout, max_len=max_seq_len)

        # --- Encoder ---
        encoder_layer = EncoderLayer(d_model, nhead, d_ff, dropout)
        encoder_norm = nn.LayerNorm(d_model)
        self.encoder = Encoder(encoder_layer, num_encoder_layers, encoder_norm)

        # --- Decoder ---
        decoder_layer = DecoderLayer(d_model, nhead, d_ff, dropout)
        decoder_norm = nn.LayerNorm(d_model)
        self.decoder = Decoder(decoder_layer, num_decoder_layers, decoder_norm)

        # --- Final Output Layer ---
        self.fc_out = nn.Linear(d_model, tgt_vocab_size)

        # --- Initialization ---
        self._reset_parameters()

        self.d_model = d_model
        self.nhead = nhead

    def _reset_parameters(self):
        """Initiate parameters in the transformer model."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def _generate_square_subsequent_mask(self, sz: int) -> Tensor:
        """
        Generates a square mask for the sequence. The masked positions are filled with float('-inf').
        Unmasked positions are filled with float(0.0).
        Used for the decoder's self-attention to prevent looking ahead.
        Shape: (sz, sz)
        """
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

    def _create_padding_mask(self, sequence: Tensor, pad_idx: int) -> Tensor:
        """
        Creates a mask for padding tokens.
        Args:
            sequence (Tensor): Input sequence tensor, shape (batch_size, seq_len).
            pad_idx (int): Index of the padding token.
        Returns:
            Tensor: Padding mask, shape (batch_size, 1, seq_len). Returns 0 where padded, 1 otherwise.
                    (Note: Attention mechanisms often expect mask==0 for positions to *ignore*)
                    Let's return shape (batch_size, seq_len) where True means padding.
                    The attention layer will handle broadcasting.
        """
        # return (sequence == pad_idx).unsqueeze(1) # Shape: (batch_size, 1, seq_len)
        return (sequence == pad_idx) # Shape: (batch_size, seq_len)


    def forward(self, src: Tensor, tgt: Tensor,
                src_pad_idx: int = 0, tgt_pad_idx: int = 0) -> Tensor:
        """
        Forward pass of the Transformer model.

        Args:
            src (Tensor): Source sequence tensor, shape (batch_size, src_seq_len).
                          Contains token indices.
            tgt (Tensor): Target sequence tensor, shape (batch_size, tgt_seq_len).
                          Contains token indices. Typically starts with <SOS> token.
            src_pad_idx (int): Index of the padding token in the source vocabulary.
            tgt_pad_idx (int): Index of the padding token in the target vocabulary.

        Returns:
            Tensor: Output tensor, shape (batch_size, tgt_seq_len, tgt_vocab_size).
                    Represents the probability distribution over the target vocabulary for each position.
        """
        # 1. Create Masks
        src_seq_len = src.shape[1]
        tgt_seq_len = tgt.shape[1]

        # Look-ahead mask for target self-attention (prevents attending to future tokens)
        # Shape: (tgt_seq_len, tgt_seq_len)
        tgt_mask = self._generate_square_subsequent_mask(tgt_seq_len).to(tgt.device)

        # Padding masks (identify padding tokens)
        # Shape: (batch_size, src_seq_len) and (batch_size, tgt_seq_len)
        # True where padded, False otherwise.
        src_padding_mask = self._create_padding_mask(src, src_pad_idx)
        tgt_padding_mask = self._create_padding_mask(tgt, tgt_pad_idx)


        # 2. Embeddings and Positional Encoding
        # Input shape: (batch_size, seq_len)
        # Embedding output shape: (batch_size, seq_len, d_model)
        # Transformer layers expect (seq_len, batch_size, d_model), so transpose.
        src_emb = self.pos_encoder(self.src_embedding(src).transpose(0, 1) * math.sqrt(self.d_model))
        tgt_emb = self.pos_encoder(self.tgt_embedding(tgt).transpose(0, 1) * math.sqrt(self.d_model))
        # Shape after pos_encoder: (seq_len, batch_size, d_model)

        # 3. Encoder
        # memory shape: (src_seq_len, batch_size, d_model)
        memory = self.encoder(src_emb, src_key_padding_mask=src_padding_mask)

        # 4. Decoder
        # decoder_output shape: (tgt_seq_len, batch_size, d_model)
        decoder_output = self.decoder(tgt_emb, memory,
                                      tgt_mask=tgt_mask, # Look-ahead mask
                                      tgt_key_padding_mask=tgt_padding_mask, # Target padding mask
                                      memory_key_padding_mask=src_padding_mask) # Source padding mask

        # 5. Final Linear Layer
        # Input shape: (tgt_seq_len, batch_size, d_model)
        # Output shape: (tgt_seq_len, batch_size, tgt_vocab_size)
        output = self.fc_out(decoder_output)

        # Transpose back to (batch_size, tgt_seq_len, tgt_vocab_size) for standard loss calculation
        return output.transpose(0, 1)


# --- Example Usage ---
if __name__ == '__main__':
    # Hyperparameters (example values)
    SRC_VOCAB_SIZE = 5000
    TGT_VOCAB_SIZE = 6000
    D_MODEL = 512       # Embedding dimension, must be divisible by NHEAD
    NHEAD = 8           # Number of attention heads
    NUM_ENCODER_LAYERS = 3 # Number of encoder layers
    NUM_DECODER_LAYERS = 3 # Number of decoder layers
    D_FF = 2048         # Dimension of the feed-forward layer
    MAX_SEQ_LEN = 100   # Maximum sequence length
    DROPOUT = 0.1
    BATCH_SIZE = 32
    SRC_SEQ_LEN = 60    # Example source sequence length
    TGT_SEQ_LEN = 55    # Example target sequence length (e.g., during training with teacher forcing)
    PAD_IDX = 0         # Example padding index

    # Device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Create the model
    transformer_model = Transformer(
        d_model=D_MODEL,
        nhead=NHEAD,
        num_encoder_layers=NUM_ENCODER_LAYERS,
        num_decoder_layers=NUM_DECODER_LAYERS,
        d_ff=D_FF,
        dropout=DROPOUT,
        src_vocab_size=SRC_VOCAB_SIZE,
        tgt_vocab_size=TGT_VOCAB_SIZE,
        max_seq_len=MAX_SEQ_LEN
    ).to(device)

    # Generate dummy input data
    # Source sequence (batch_size, src_seq_len)
    src_input = torch.randint(1, SRC_VOCAB_SIZE, (BATCH_SIZE, SRC_SEQ_LEN), device=device)
    # Target sequence (batch_size, tgt_seq_len) - for training (teacher forcing)
    tgt_input = torch.randint(1, TGT_VOCAB_SIZE, (BATCH_SIZE, TGT_SEQ_LEN), device=device)

    # Add some padding for demonstration
    src_input[0, -10:] = PAD_IDX # Pad last 10 tokens of first sequence in batch
    tgt_input[1, -5:] = PAD_IDX  # Pad last 5 tokens of second sequence in batch

    print(f"Source Input Shape: {src_input.shape}") # (BATCH_SIZE, SRC_SEQ_LEN)
    print(f"Target Input Shape: {tgt_input.shape}") # (BATCH_SIZE, TGT_SEQ_LEN)

    # --- Forward Pass ---
    # The model expects src/tgt padding indices to create masks internally
    output = transformer_model(src_input, tgt_input, src_pad_idx=PAD_IDX, tgt_pad_idx=PAD_IDX)

    print(f"Output Shape: {output.shape}") # Should be (BATCH_SIZE, TGT_SEQ_LEN, TGT_VOCAB_SIZE)

    # --- Inference Example (Greedy Decoding - very basic) ---
    # During inference, we generate the target sequence token by token.
    print("\n--- Basic Inference Example (Greedy) ---")
    transformer_model.eval() # Set model to evaluation mode

    # Start with a <SOS> token (assuming index 1)
    tgt_sequence = torch.ones((BATCH_SIZE, 1), dtype=torch.long, device=device) * 1 # Shape: (BATCH_SIZE, 1)

    max_output_len = 20 # Maximum length of the generated sequence

    with torch.no_grad(): # No need to track gradients during inference
        # Encode the source sequence once
        src_padding_mask_inf = transformer_model._create_padding_mask(src_input, PAD_IDX)
        src_emb_inf = transformer_model.pos_encoder(transformer_model.src_embedding(src_input).transpose(0, 1) * math.sqrt(D_MODEL))
        memory_inf = transformer_model.encoder(src_emb_inf, src_key_padding_mask=src_padding_mask_inf)
        # memory_inf shape: (src_seq_len, batch_size, d_model)

        for _ in range(max_output_len):
            tgt_seq_len_inf = tgt_sequence.shape[1]
            tgt_mask_inf = transformer_model._generate_square_subsequent_mask(tgt_seq_len_inf).to(device)
            tgt_padding_mask_inf = transformer_model._create_padding_mask(tgt_sequence, PAD_IDX)

            # Embed the current target sequence
            tgt_emb_inf = transformer_model.pos_encoder(transformer_model.tgt_embedding(tgt_sequence).transpose(0, 1) * math.sqrt(D_MODEL))

            # Decode using the current target sequence and encoder memory
            decoder_output_inf = transformer_model.decoder(
                tgt_emb_inf, memory_inf,
                tgt_mask=tgt_mask_inf,
                tgt_key_padding_mask=tgt_padding_mask_inf,
                memory_key_padding_mask=src_padding_mask_inf
            ) # Shape: (tgt_seq_len, batch_size, d_model)

            # Get the output for the *last* token only
            last_token_output = transformer_model.fc_out(decoder_output_inf[-1, :, :]) # Shape: (batch_size, tgt_vocab_size)

            # Find the token with the highest probability (greedy decoding)
            next_token = last_token_output.argmax(dim=-1) # Shape: (batch_size)

            # Append the predicted token to the target sequence
            tgt_sequence = torch.cat([tgt_sequence, next_token.unsqueeze(1)], dim=1) # Shape: (batch_size, current_len + 1)

            # Optional: Stop if all sequences in the batch generated an <EOS> token (e.g., index 2)
            # if (next_token == 2).all():
            #     break

    print(f"Generated Target Sequence Shape: {tgt_sequence.shape}") # (BATCH_SIZE, generated_len)
    print(f"Example Generated Sequence (first in batch):\n{tgt_sequence[0]}")



Using device: cpu
Source Input Shape: torch.Size([32, 60])
Target Input Shape: torch.Size([32, 55])


RuntimeError: The size of tensor a (32) must match the size of tensor b (60) at non-singleton dimension 2

In [6]:
import torch
import torch.nn as nn
import math
from torch import Tensor
from typing import Optional, Tuple

# Note: Positional Encoding is omitted for simplicity in this primitive version.
# This means the model has no inherent understanding of word order.

class SimpleAttention(nn.Module):
    """
    A very basic single-head scaled dot-product attention mechanism.
    Simplified version without multi-head complexity.
    """
    def __init__(self, d_model: int, dropout: float = 0.1):
        """
        Args:
            d_model (int): Dimension of the model.
            dropout (float): Dropout probability.
        """
        super().__init__()
        self.d_model = d_model
        # Use d_model directly as d_k since it's single-head
        self.d_k = d_model

        # Linear layers for Query, Key, Value
        # Note: In a truly minimal version, these could even be omitted,
        # using the input directly, but we keep them for closer analogy.
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        # No final output projection Wo needed as we don't concatenate heads

        self.dropout = nn.Dropout(dropout)
        self.scale = math.sqrt(self.d_k)

    def forward(self, query: Tensor, key: Tensor, value: Tensor, mask: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
        """
        Forward pass for Simple Attention.

        Args:
            query (Tensor): Query tensor, shape (seq_len_q, batch_size, d_model).
            key (Tensor): Key tensor, shape (seq_len_k, batch_size, d_model).
            value (Tensor): Value tensor, shape (seq_len_v, batch_size, d_model) (seq_len_k == seq_len_v).
            mask (Optional[Tensor]): Mask, shape depends on use case.
                                     Boolean tensor: True indicates position should be masked.
                                     Expected shapes: (tgt_seq_len, src_seq_len) or (batch_size, src_seq_len).

        Returns:
            Tuple[Tensor, Tensor]: The attention output and the attention weights.
                                    Output shape: (seq_len_q, batch_size, d_model)
                                    Attention weights shape: (batch_size, seq_len_q, seq_len_k)
        """
        seq_len_q, batch_size, _ = query.size()
        seq_len_k, _, _ = key.size()

        # 1. Linear projections
        Q = self.W_q(query) # Shape: (seq_len_q, batch_size, d_model)
        K = self.W_k(key)   # Shape: (seq_len_k, batch_size, d_model)
        V = self.W_v(value) # Shape: (seq_len_v, batch_size, d_model)

        # Transpose batch and sequence length dimensions for attention calculation
        # (seq_len, batch_size, d_model) -> (batch_size, seq_len, d_model)
        Q = Q.transpose(0, 1)
        K = K.transpose(0, 1)
        V = V.transpose(0, 1)

        # 2. Scaled dot-product attention
        # (batch_size, seq_len_q, d_model) @ (batch_size, d_model, seq_len_k) -> (batch_size, seq_len_q, seq_len_k)
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale

        # Apply mask if provided
        if mask is not None:
            # Ensure mask has compatible dimensions for broadcasting to attn_scores (batch_size, seq_len_q, seq_len_k)
            if mask.dim() == 2:
                # Case 1: Look-ahead mask (tgt_len, tgt_len) or (seq_len_q, seq_len_k)
                if mask.shape[0] == seq_len_q and mask.shape[1] == seq_len_k:
                    # Unsqueeze batch dimension
                    mask = mask.unsqueeze(0) # Shape: (1, seq_len_q, seq_len_k)
                # Case 2: Padding mask (batch_size, src_len) or (batch_size, seq_len_k)
                elif mask.shape[0] == batch_size and mask.shape[1] == seq_len_k:
                    # Unsqueeze query sequence length dimension
                    mask = mask.unsqueeze(1) # Shape: (batch_size, 1, seq_len_k)
                else:
                     # Handle potential mismatch or unexpected shape
                     raise ValueError(f"Mask shape {mask.shape} incompatible with attention scores shape {attn_scores.shape}")

            elif mask.dim() == 3:
                 # Assume mask is already broadcastable, e.g., (batch_size, 1, seq_len_k) or (batch_size, seq_len_q, seq_len_k)
                 # No unsqueezing needed if dimensions match or broadcast correctly.
                 pass # Keep mask as is
            elif mask.dim() == 1: # Less common, maybe (src_len)?
                 # Needs careful handling depending on intent. Let's assume it's (seq_len_k)
                 # Needs expansion to (1, 1, seq_len_k) for broadcasting
                 if mask.shape[0] == seq_len_k:
                     mask = mask.unsqueeze(0).unsqueeze(0) # Shape: (1, 1, seq_len_k)
                 else:
                     raise ValueError(f"Mask shape {mask.shape} incompatible with attention scores shape {attn_scores.shape}")

            # Check for final broadcast compatibility before applying
            try:
                # Apply the mask (True values in mask indicate positions to ignore)
                # Use a large negative number for masked positions before softmax
                attn_scores = attn_scores.masked_fill(mask, -1e9)
            except RuntimeError as e:
                print(f"Error applying mask. Attn scores shape: {attn_scores.shape}, Mask shape after potential unsqueeze: {mask.shape}")
                raise e


        # Apply softmax to get attention weights
        attn_weights = torch.softmax(attn_scores, dim=-1) # Shape: (batch_size, seq_len_q, seq_len_k)
        attn_weights = self.dropout(attn_weights)

        # Multiply weights by values
        # (batch_size, seq_len_q, seq_len_k) @ (batch_size, seq_len_v, d_model) -> (batch_size, seq_len_q, d_model)
        output = torch.matmul(attn_weights, V)

        # Transpose back to (seq_len_q, batch_size, d_model)
        output = output.transpose(0, 1)

        # Return attention weights in (batch_size, seq_len_q, seq_len_k) format
        return output, attn_weights

# Optional: Simplified FeedForward (could even be removed entirely)
class SimpleFeedForward(nn.Module):
    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.activation = nn.ReLU()

    def forward(self, x: Tensor) -> Tensor:
        return self.linear2(self.dropout(self.activation(self.linear1(x))))

class SimpleEncoderLayer(nn.Module):
    """ Simplified Encoder Layer with single attention and optional feedforward """
    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.self_attn = SimpleAttention(d_model, dropout=dropout)
        self.feed_forward = SimpleFeedForward(d_model, d_ff, dropout=dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout) # Single dropout for simplicity

    def forward(self, src: Tensor, src_padding_mask: Optional[Tensor] = None) -> Tensor:
        # Self Attention + Add & Norm
        attn_output, _ = self.self_attn(src, src, src, mask=src_padding_mask)
        src = self.norm1(src + self.dropout(attn_output)) # Residual connection

        # Feed Forward + Add & Norm
        ff_output = self.feed_forward(src)
        src = self.norm2(src + self.dropout(ff_output)) # Residual connection
        return src

class SimpleDecoderLayer(nn.Module):
    """ Simplified Decoder Layer """
    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.self_attn = SimpleAttention(d_model, dropout=dropout)
        self.cross_attn = SimpleAttention(d_model, dropout=dropout) # Encoder-Decoder Attention
        self.feed_forward = SimpleFeedForward(d_model, d_ff, dropout=dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout) # Single dropout for simplicity

    def forward(self, tgt: Tensor, memory: Tensor,
                tgt_mask: Optional[Tensor] = None, # Look-ahead mask
                memory_padding_mask: Optional[Tensor] = None, # Src padding mask
                tgt_padding_mask: Optional[Tensor] = None) -> Tensor: # Tgt padding mask

        # Masked Self-Attention (on target) + Add & Norm
        # Combine look-ahead and target padding mask if needed (simplified here)
        # Note: SimpleAttention expects boolean mask where True means MASK
        # Need to combine tgt_mask (float -inf/0) and tgt_padding_mask (bool True/False)
        # For simplicity, we only pass the look-ahead mask here. Proper handling is more complex.
        self_attn_output, _ = self.self_attn(tgt, tgt, tgt, mask=tgt_mask)
        tgt = self.norm1(tgt + self.dropout(self_attn_output))

        # Encoder-Decoder Attention + Add & Norm
        # Query=tgt, Key=memory, Value=memory
        cross_attn_output, _ = self.cross_attn(tgt, memory, memory, mask=memory_padding_mask)
        tgt = self.norm2(tgt + self.dropout(cross_attn_output))

        # Feed Forward + Add & Norm
        ff_output = self.feed_forward(tgt)
        tgt = self.norm3(tgt + self.dropout(ff_output))
        return tgt

class SimpleTransformer(nn.Module):
    """
    A very simplified Transformer model.
    - No positional encoding
    - Single-layer Encoder & Decoder
    - Simple single-head attention
    """
    def __init__(self, d_model: int = 128, d_ff: int = 256, dropout: float = 0.1,
                 src_vocab_size: int = 1000, tgt_vocab_size: int = 1200,
                 pad_idx: int = 0):
        """
        Args:
            d_model (int): Embedding dimension.
            d_ff (int): Feedforward inner dimension.
            dropout (float): Dropout rate.
            src_vocab_size (int): Size of source vocabulary.
            tgt_vocab_size (int): Size of target vocabulary.
            pad_idx (int): Index for padding token.
        """
        super().__init__()
        self.pad_idx = pad_idx
        self.d_model = d_model

        # --- Embeddings ---
        self.src_embedding = nn.Embedding(src_vocab_size, d_model, padding_idx=pad_idx)
        self.tgt_embedding = nn.Embedding(tgt_vocab_size, d_model, padding_idx=pad_idx)
        # --- NO Positional Encoding ---

        # --- Single Encoder Layer ---
        self.encoder_layer = SimpleEncoderLayer(d_model, d_ff, dropout)
        # --- Single Decoder Layer ---
        self.decoder_layer = SimpleDecoderLayer(d_model, d_ff, dropout)

        # --- Final Output Layer ---
        self.fc_out = nn.Linear(d_model, tgt_vocab_size)

        # --- Initialization ---
        self._reset_parameters()

    def _reset_parameters(self):
        """Initiate parameters."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def _generate_square_subsequent_mask(self, sz: int, device: torch.device) -> Tensor:
        """
        Generates a boolean mask where True means position should be masked.
        Shape: (sz, sz)
        """
        mask = torch.triu(torch.ones(sz, sz, device=device), diagonal=1)
        return mask.bool() # True for upper triangle (masked positions)

    def _create_padding_mask(self, sequence: Tensor) -> Tensor:
        """
        Creates a boolean mask for padding tokens (True where padded).
        Shape: (batch_size, seq_len)
        """
        return (sequence == self.pad_idx)

    def encode(self, src: Tensor) -> Tensor:
        """ Encodes the source sequence. """
        src_padding_mask = self._create_padding_mask(src) # (batch_size, src_seq_len)
        # Embeddings: (batch_size, src_seq_len) -> (batch_size, src_seq_len, d_model)
        # Transpose for Encoder Layer: -> (src_seq_len, batch_size, d_model)
        src_emb = self.src_embedding(src).transpose(0, 1) * math.sqrt(self.d_model)
        # NO Positional Encoding
        # Encoder expects mask as (batch_size, src_seq_len)
        memory = self.encoder_layer(src_emb, src_padding_mask=src_padding_mask)
        return memory, src_padding_mask

    def decode(self, tgt: Tensor, memory: Tensor, memory_padding_mask: Optional[Tensor] = None) -> Tensor:
         """ Decodes the target sequence given memory. """
         tgt_seq_len = tgt.shape[1]
         # Create masks
         # Look-ahead mask (True means mask): (tgt_seq_len, tgt_seq_len)
         tgt_mask = self._generate_square_subsequent_mask(tgt_seq_len, tgt.device)
         # Target padding mask (True means mask): (batch_size, tgt_seq_len)
         tgt_padding_mask = self._create_padding_mask(tgt)

         # Embeddings: (batch_size, tgt_seq_len) -> (batch_size, tgt_seq_len, d_model)
         # Transpose for Decoder Layer: -> (tgt_seq_len, batch_size, d_model)
         tgt_emb = self.tgt_embedding(tgt).transpose(0, 1) * math.sqrt(self.d_model)
         # NO Positional Encoding

         # Decoder Layer
         # Pass necessary masks: look-ahead, source padding, target padding
         decoder_output = self.decoder_layer(
             tgt_emb, memory,
             tgt_mask=tgt_mask,
             memory_padding_mask=memory_padding_mask,
             tgt_padding_mask=tgt_padding_mask # Pass target padding mask if needed by attention
         )
         return decoder_output # Shape: (tgt_seq_len, batch_size, d_model)

    def forward(self, src: Tensor, tgt: Tensor) -> Tensor:
        """
        Full forward pass for training.

        Args:
            src (Tensor): Source sequence tensor, shape (batch_size, src_seq_len).
            tgt (Tensor): Target sequence tensor (shifted right), shape (batch_size, tgt_seq_len).

        Returns:
            Tensor: Output logits, shape (batch_size, tgt_seq_len, tgt_vocab_size).
        """
        # Encode source
        memory, src_padding_mask = self.encode(src) # memory: (src_seq_len, batch_size, d_model)

        # Decode target
        decoder_output = self.decode(tgt, memory, memory_padding_mask=src_padding_mask)
        # decoder_output: (tgt_seq_len, batch_size, d_model)

        # Final linear layer
        output_logits = self.fc_out(decoder_output) # (tgt_seq_len, batch_size, tgt_vocab_size)

        # Transpose back to (batch_size, tgt_seq_len, tgt_vocab_size)
        return output_logits.transpose(0, 1)


# --- Example Usage ---
if __name__ == '__main__':
    # Hyperparameters (smaller values for simplicity)
    SRC_VOCAB_SIZE = 1000
    TGT_VOCAB_SIZE = 1200
    D_MODEL = 128       # Embedding dimension
    D_FF = 256          # Feed-forward dimension
    DROPOUT = 0.1
    PAD_IDX = 0         # Padding index

    BATCH_SIZE = 16
    SRC_SEQ_LEN = 20
    TGT_SEQ_LEN = 18

    # Device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Create the simplified model
    simple_transformer = SimpleTransformer(
        d_model=D_MODEL,
        d_ff=D_FF,
        dropout=DROPOUT,
        src_vocab_size=SRC_VOCAB_SIZE,
        tgt_vocab_size=TGT_VOCAB_SIZE,
        pad_idx=PAD_IDX
    ).to(device)

    # Generate dummy input data
    src_input = torch.randint(1, SRC_VOCAB_SIZE, (BATCH_SIZE, SRC_SEQ_LEN), device=device)
    tgt_input = torch.randint(1, TGT_VOCAB_SIZE, (BATCH_SIZE, TGT_SEQ_LEN), device=device)

    # Add padding
    src_input[0, -5:] = PAD_IDX
    tgt_input[1, -3:] = PAD_IDX

    print(f"Source Input Shape: {src_input.shape}")
    print(f"Target Input Shape: {tgt_input.shape}")

    # --- Forward Pass (Training) ---
    output = simple_transformer(src_input, tgt_input)
    print(f"Output Shape (Training): {output.shape}") # Should be (BATCH_SIZE, TGT_SEQ_LEN, TGT_VOCAB_SIZE)

    # --- Basic Inference Example (Greedy) ---
    print("\n--- Basic Inference Example (Greedy) ---")
    simple_transformer.eval()

    # Assume SOS token index is 1, EOS is 2
    SOS_IDX = 1
    EOS_IDX = 2
    max_output_len = 15

    with torch.no_grad():
        # Encode source sequence once
        memory_inf, src_padding_mask_inf = simple_transformer.encode(src_input)

        # Start with SOS token for each sequence in the batch
        tgt_sequence = torch.full((BATCH_SIZE, 1), SOS_IDX, dtype=torch.long, device=device) # (BATCH_SIZE, 1)

        for _ in range(max_output_len):
            # Decode the current target sequence
            decoder_output_inf = simple_transformer.decode(
                tgt_sequence, memory_inf, memory_padding_mask=src_padding_mask_inf
            ) # Shape: (current_tgt_len, batch_size, d_model)

            # Get logits for the *last* predicted token
            last_token_logits = simple_transformer.fc_out(decoder_output_inf[-1, :, :]) # Shape: (batch_size, tgt_vocab_size)

            # Greedy choice: pick token with highest logit
            next_token = last_token_logits.argmax(dim=-1) # Shape: (batch_size)

            # Append predicted token to the sequence
            tgt_sequence = torch.cat([tgt_sequence, next_token.unsqueeze(1)], dim=1)

            # Simple stopping condition: Check if EOS was generated for all sequences
            # (A more robust implementation would handle finished sequences individually)
            if (next_token == EOS_IDX).all():
                 break

    print(f"Generated Target Sequence Shape: {tgt_sequence.shape}")
    print(f"Example Generated Sequence (first in batch):\n{tgt_sequence[0]}")



Using device: cpu
Source Input Shape: torch.Size([16, 20])
Target Input Shape: torch.Size([16, 18])
Output Shape (Training): torch.Size([16, 18, 1200])

--- Basic Inference Example (Greedy) ---
Generated Target Sequence Shape: torch.Size([16, 16])
Example Generated Sequence (first in batch):
tensor([   1,  965,  415, 1104,  491,  848,  106,  436,   89, 1122,  415,  950,
         698, 1161,  890, 1046])
