# Import the Required Libraries

In [37]:
import math
from typing import Optional, Tuple, Dict

import torch
import torch.nn as nn
import torch.nn.functional as fun

from torch.optim import Adam
from torch.utils.data import TensorDataset, DataLoader
import lightning as lt

# [1] Encoder-Decoder Transformer Class

## [1.1] Positional Encoding Class

In [38]:
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding followed by dropout.

    The encoding table is computed once in ``__init__`` and stored as a buffer,
    so it moves with the module across devices but is never trained.
    """

    def __init__(self, d_model=512, max_len=5000, dropout=0.1):
        """
        Args:
            d_model: Dimension of the model embeddings (even or odd).
            max_len: Maximum length of input sequences.
            dropout: Dropout probability.
        """
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        # pe(pos, 2i)   = sin(pos / 10000^(2i/d_model))
        # pe(pos, 2i+1) = cos(pos / 10000^(2i/d_model))
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0)/d_model))       # Scale down sine/cosine frequencies
        angles = position * div_term                                                            # (max_len, ceil(d_model / 2))

        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # For an odd d_model there is one fewer cosine column than sine columns,
        # so trim the angle table to the number of odd-indexed slots.
        pe[:, 1::2] = torch.cos(angles[:, :d_model // 2])
        self.register_buffer('pe', pe)                                                          # ensure pe is not a learnable parameter

    def forward(self, x):
        """
        Args:
            x: Input tensor of shape (batch_size, seq_len, d_model), with seq_len <= max_len.

        Returns:
            Tensor of the same shape with the positional encoding added and dropout applied.
        """
        x = x + self.pe[:x.size(1), :].unsqueeze(0)                                             # Add positional encoding to input embeddings
        return self.dropout(x)

## [1.2] Multi-head Attention Class

In [39]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Queries, keys and values are linearly projected, split into ``num_heads``
    heads of size ``d_model // num_heads``, attended independently, then
    concatenated and projected back to ``d_model``.
    """

    def __init__(self, d_model=512, num_heads=8, dropout=0.1):
        """
        Args:
            d_model: Total dimension of input/output embeddings.
            num_heads: Number of attention heads; must divide d_model.
            dropout: Dropout probability applied to the attention weights.
        """
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model                                          # Total dimension of input/output embeddings.
        self.num_heads = num_heads                                      # Number of attention heads.
        self.d_k = d_model // num_heads                                 # Dimensionality per head.

        self.w_q = nn.Linear(d_model, d_model)                          # Linear projection for queries
        self.w_k = nn.Linear(d_model, d_model)                          # Linear projection for keys
        self.w_v = nn.Linear(d_model, d_model)                          # Linear projection for values
        self.w_o = nn.Linear(d_model, d_model)                          # Output projection after concatenating heads

        self.dropout = nn.Dropout(dropout)
        # Kept as a Python scalar: dividing by it never changes the dtype or device
        # of the scores, so the module also works after .to(torch.bfloat16) / .half().
        self.scale = self.d_k ** 0.5                                    # Scaling to stabilize gradients.

    def forward(self, query, key, value, mask=None):
        """
        Args:
            query: Query tensor (batch_size, q_len, d_model)
            key: Key tensor (batch_size, k_len, d_model)
            value: Value tensor (batch_size, v_len, d_model)
            mask: Optional tensor broadcastable to (batch_size, num_heads, q_len, k_len);
                  positions where ``mask == 0`` are excluded from attention.

        Returns:
            Attention output tensor (batch_size, q_len, d_model)
        """
        batch_size = query.size(0)

        # 1. Apply Linear Projections
        Q = self.w_q(query)                                             # (batch_size, q_len, d_model)
        K = self.w_k(key)                                               # (batch_size, k_len, d_model)
        V = self.w_v(value)                                             # (batch_size, v_len, d_model)

        # 2. Split into Multiple Heads -> (batch_size, num_heads, len, d_k)
        Q = Q.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        K = K.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        V = V.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

        # 3. Calculate Attention Scores
        scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale

        # 4. Apply Masking
        if mask is not None:
            # Use the most negative finite value of the score dtype; a fixed -1e9
            # is not representable in float16.
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)

        # 5. Softmax and Dropout
        attention = fun.softmax(scores, dim=-1)
        attention = self.dropout(attention)

        # 6. Apply Attention Weights
        x = torch.matmul(attention, V)

        # 7. Concatenate Heads
        x = x.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)

        # Final linear projection
        return self.w_o(x)

## [1.3] Encoder and Decoder Layers

In [40]:
# [1] Feedforward ANN
class PositionwiseFeedForward(nn.Module):
    """Two-layer feed-forward network applied to the last dimension of its input.

    Computes ``linear2(dropout(relu(linear1(x))))``: expand ``d_model -> d_ff``,
    apply ReLU and dropout, then project ``d_ff -> d_model``.
    """

    def __init__(self, d_model=512, d_ff=2048, dropout=0.1):
        """
        Args:
            d_model: Input and output feature size.
            d_ff: Hidden (expanded) feature size.
            dropout: Dropout probability applied after the activation.
        """
        super().__init__()
        self.linear1 = nn.Linear(in_features=d_model, out_features=d_ff)    # Expansion
        self.linear2 = nn.Linear(in_features=d_ff, out_features=d_model)    # Projection back to d_model
        self.dropout = nn.Dropout(p=dropout)                                # Regularization

    def forward(self, x):
        """Return the transformed tensor; the last dimension stays ``d_model``."""
        hidden = fun.relu(self.linear1(x))
        hidden = self.dropout(hidden)
        return self.linear2(hidden)

In [21]:
# [2] Encoder Layer
class EncoderLayer(nn.Module):
    """One encoder block: self-attention then feed-forward, each wrapped in
    dropout + residual connection + layer normalization (post-norm)."""

    def __init__(self, d_model=512, num_heads=8, d_ff=2048, dropout=0.1):
        """
        Args:
            d_model: Embedding dimension.
            num_heads: Number of attention heads.
            d_ff: Hidden size of the feed-forward sub-layer.
            dropout: Dropout probability used in every sub-layer.
        """
        super().__init__()
        # Sub-layers
        self.self_attn = MultiHeadAttention(d_model, num_heads, dropout)
        self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)

        # One LayerNorm / Dropout pair per sub-layer
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """
        Args:
            x: Input tensor passed as query, key and value to self-attention.
            mask: Optional attention mask forwarded to the self-attention module.

        Returns:
            Tensor produced by the second layer normalization.
        """
        # Sub-layer 1: self-attention -> dropout -> residual -> norm
        attended = self.self_attn(x, x, x, mask)
        x = self.norm1(x + self.dropout1(attended))

        # Sub-layer 2: feed-forward -> dropout -> residual -> norm
        transformed = self.feed_forward(x)
        x = self.norm2(x + self.dropout2(transformed))

        return x

In [41]:
# [3] Decoder Layer
class DecoderLayer(nn.Module):
    """One decoder block: masked self-attention, encoder-decoder (cross) attention
    and feed-forward, each wrapped in dropout + residual + layer normalization."""

    def __init__(self, d_model=512, num_heads=8, d_ff=2048, dropout=0.1):
        """
        Args:
            d_model: Embedding dimension.
            num_heads: Number of attention heads.
            d_ff: Hidden size of the feed-forward sub-layer.
            dropout: Dropout probability used in every sub-layer.
        """
        super().__init__()
        # Sub-layers
        self.self_attn = MultiHeadAttention(d_model, num_heads, dropout)
        self.enc_dec_attn = MultiHeadAttention(d_model, num_heads, dropout)
        self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)

        # One LayerNorm / Dropout pair per sub-layer
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

    def forward(self, x, enc_output, src_mask=None, tgt_mask=None):
        """
        Args:
            x: Decoder-side input tensor.
            enc_output: Encoder output, used as key and value in cross-attention.
            src_mask: Optional mask applied in cross-attention (over encoder positions).
            tgt_mask: Optional mask applied in decoder self-attention.

        Returns:
            Tensor produced by the third layer normalization.
        """
        # Sub-layer 1: decoder self-attention
        self_attended = self.self_attn(x, x, x, tgt_mask)
        x = self.norm1(x + self.dropout1(self_attended))

        # Sub-layer 2: encoder-decoder attention (queries from decoder, keys/values from encoder)
        cross_attended = self.enc_dec_attn(x, enc_output, enc_output, src_mask)
        x = self.norm2(x + self.dropout2(cross_attended))

        # Sub-layer 3: position-wise feed-forward
        transformed = self.feed_forward(x)
        x = self.norm3(x + self.dropout3(transformed))

        return x

## [1.4] Encoder Class

In [42]:
class Encoder(nn.Module):
    """Stack of ``num_layers`` encoder layers on top of token + positional embeddings."""

    def __init__(self, vocab_size, d_model=512, num_layers=6, num_heads=8, d_ff=2048, dropout=0.1, max_len=5000):
        """
        Args:
            vocab_size: Number of entries in the token embedding table.
            d_model: Embedding dimension.
            num_layers: Number of stacked EncoderLayer modules.
            num_heads: Attention heads per layer.
            d_ff: Hidden size of each feed-forward sub-layer.
            dropout: Dropout probability.
            max_len: Maximum sequence length supported by the positional encoding.
        """
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.position_encoding = PositionalEncoding(d_model, max_len, dropout)
        encoder_layers = [EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)]
        self.layers = nn.ModuleList(encoder_layers)
        self.dropout = nn.Dropout(dropout)

    def forward(self, src, mask=None):
        """
        Args:
            src: Token-id tensor looked up in the embedding table.
            mask: Optional attention mask passed to every encoder layer.

        Returns:
            Output of the last encoder layer.
        """
        # NOTE: dropout is applied to the raw embeddings here and once more
        # inside PositionalEncoding after the positions are added.
        embedded = self.token_embedding(src)
        hidden = self.position_encoding(self.dropout(embedded))

        # Run the layer stack in order
        for encoder_layer in self.layers:
            hidden = encoder_layer(hidden, mask)

        return hidden

## [1.5] Decoder Class

In [43]:
class Decoder(nn.Module):
    """Stack of ``num_layers`` decoder layers followed by a projection to vocabulary logits."""

    def __init__(self, vocab_size, d_model=512, num_layers=6, num_heads=8, d_ff=2048, dropout=0.1, max_len=5000):
        """
        Args:
            vocab_size: Size of the target vocabulary (embedding rows and output logits).
            d_model: Embedding dimension.
            num_layers: Number of stacked DecoderLayer modules.
            num_heads: Attention heads per layer.
            d_ff: Hidden size of each feed-forward sub-layer.
            dropout: Dropout probability.
            max_len: Maximum sequence length supported by the positional encoding.
        """
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.position_encoding = PositionalEncoding(d_model, max_len, dropout)
        decoder_layers = [DecoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)]
        self.layers = nn.ModuleList(decoder_layers)
        self.fc_out = nn.Linear(d_model, vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, tgt, enc_output, src_mask=None, tgt_mask=None):
        """
        Args:
            tgt: Target token-id tensor looked up in the embedding table.
            enc_output: Encoder output consumed by every layer's cross-attention.
            src_mask: Optional mask for cross-attention.
            tgt_mask: Optional mask for decoder self-attention.

        Returns:
            Unnormalized logits over the vocabulary (last dimension is ``vocab_size``).
        """
        # NOTE: dropout is applied to the raw embeddings here and once more
        # inside PositionalEncoding after the positions are added.
        embedded = self.token_embedding(tgt)
        hidden = self.position_encoding(self.dropout(embedded))

        # Run the layer stack in order
        for decoder_layer in self.layers:
            hidden = decoder_layer(hidden, enc_output, src_mask, tgt_mask)

        # Project to vocabulary logits
        logits = self.fc_out(hidden)
        return logits

## [1.6] Transformer Class

In [57]:
class Transformer(lt.LightningModule):
    """Encoder-decoder Transformer for sequence-to-sequence tasks, packaged as a LightningModule.

    Training uses teacher forcing: the decoder receives ``tgt[:, :-1]`` and is
    scored against ``tgt[:, 1:]`` with cross-entropy that ignores padding.
    """

    # 1. Initialization Function
    def __init__(self, src_vocab_size, tgt_vocab_size, d_model=512, num_layers=6, num_heads=8, d_ff=2048,
                        dropout=0.1, max_len=5000, lr=0.0001, warmup_steps=4000, pad_idx=0):
        """
        Args:
            src_vocab_size: Size of the source vocabulary.
            tgt_vocab_size: Size of the target vocabulary.
            d_model: Embedding dimension.
            num_layers: Number of encoder layers and of decoder layers.
            num_heads: Attention heads per layer.
            d_ff: Hidden size of the feed-forward sub-layers.
            dropout: Dropout probability.
            max_len: Maximum sequence length supported by the positional encodings.
            lr: Base learning rate handed to Adam (scaled by the warmup schedule).
            warmup_steps: Number of steps over which the schedule multiplier rises.
            pad_idx: Token id treated as padding in the masks and ignored by the loss.
        """
        super().__init__()
        self.save_hyperparameters()

        self.encoder = Encoder(src_vocab_size, d_model, num_layers, num_heads, d_ff, dropout, max_len)      # Define Encoder
        self.decoder = Decoder(tgt_vocab_size, d_model, num_layers, num_heads, d_ff, dropout, max_len)      # Define Decoder

        self.pad_idx = pad_idx
        self.max_len = max_len
        self.loss_fn = nn.CrossEntropyLoss(ignore_index=pad_idx)                                            # Define Loss Function
        self.lr = lr
        self.warmup_steps = warmup_steps

        # Initialize every weight matrix (dim > 1) with Glorot/Xavier uniform initialization
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    # 2. Forward Function
    def forward(self, src, tgt, src_mask=None, tgt_mask=None):
        """Encode ``src``, decode ``tgt`` against it and return the decoder logits."""
        enc_output = self.encoder(src, src_mask)
        output = self.decoder(tgt, enc_output, src_mask, tgt_mask)
        return output

    # Shared by training_step and validation_step
    def _shared_step(self, batch):
        """Compute the teacher-forced cross-entropy loss for one ``(src, tgt)`` batch."""
        src, tgt = batch
        tgt_input = tgt[:, :-1]                                     # Decoder sees everything but the last token
        tgt_output = tgt[:, 1:]                                     # ...and must predict everything but the first

        # Create Masks
        src_mask = self._make_src_mask(src)
        tgt_mask = self._make_tgt_mask(tgt_input)

        # Forward Pass
        output = self(src, tgt_input, src_mask, tgt_mask)

        # Flatten (batch, len, vocab) -> (batch * len, vocab) for CrossEntropyLoss
        return self.loss_fn(output.reshape(-1, output.shape[-1]), tgt_output.reshape(-1))

    # 3. Training Function
    def training_step(self, batch, batch_idx):
        """Return the training loss and log it as ``train_loss``."""
        loss = self._shared_step(batch)
        self.log('train_loss', loss, prog_bar=True)
        return loss

    # 4. Validation Function
    def validation_step(self, batch, batch_idx):
        """Log the validation loss as ``val_loss``."""
        loss = self._shared_step(batch)
        self.log('val_loss', loss, prog_bar=True)

    # 5. Optimizer Configuration Function
    def configure_optimizers(self):
        """Adam optimizer with a per-step LambdaLR warmup/decay schedule."""
        optimizer = Adam(self.parameters(), lr=self.lr, betas=(0.9, 0.98), eps=1e-9)
        scheduler = {
                        'scheduler': torch.optim.lr_scheduler.LambdaLR(
                                                                        optimizer,
                                                                        lr_lambda=self._get_lr_multiplier
                                                                    ),
                        'interval': 'step',
                        'frequency': 1
                    }
        return [optimizer], [scheduler]

    # 6. Learning Rate Warmup Function
    def _get_lr_multiplier(self, step):
        """Return ``min(step^-0.5, step * warmup_steps^-1.5)``, the factor applied to ``lr``."""
        step = max(step, 1)                                                                     # Avoid division by zero
        warmup = max(self.warmup_steps, 1)                                                      # warmup_steps=0 would raise 0 ** -1.5
        return min(step ** (-0.5), step * warmup ** (-1.5))

    # 7. Source Sequences Mask Function
    def _make_src_mask(self, src):
        """Boolean mask of shape (batch, 1, 1, src_len); True where ``src`` is not padding."""
        src_mask = (src != self.pad_idx).unsqueeze(1).unsqueeze(2)
        return src_mask

    # 8. Target Sequences Mask Function
    def _make_tgt_mask(self, tgt):
        """Boolean mask of shape (batch, 1, tgt_len, tgt_len) combining padding and causal masks."""
        tgt_pad_mask = (tgt != self.pad_idx).unsqueeze(1).unsqueeze(2)
        tgt_len = tgt.size(1)
        tgt_sub_mask = torch.tril(torch.ones((tgt_len, tgt_len), device=tgt.device)).bool()     # Lower Triangular Matrix
        return tgt_pad_mask & tgt_sub_mask

    # 9. Generate Sequences Function
    def generate(self, src, max_len=50, sos_idx=1, eos_idx=2):
        """Greedy decoding.

        Args:
            src: Source token ids of shape (batch, src_len).
            max_len: Maximum length of the generated sequences, SOS included.
            sos_idx: Token id used to start every sequence.
            eos_idx: Token id that terminates a sequence.

        Returns:
            Tensor of generated token ids of shape (batch, generated_len). Once a
            sequence has produced ``eos_idx`` its remaining positions are filled
            with ``pad_idx``.
        """
        was_training = self.training
        self.eval()

        # The decoder is fed at most max_len - 1 tokens, which must fit in the
        # positional-encoding table built for self.max_len positions.
        max_len = min(max_len, self.max_len + 1)

        try:
            with torch.no_grad():                                                       # Ensures Inference Mode
                src = src.to(self.device)                                               # Accept inputs living on another device
                src_mask = self._make_src_mask(src)                                     # Create Mask for the Input
                enc_output = self.encoder(src, src_mask)                                # Run Encoder on Source Input

                # Initialize with SOS Token
                tgt = torch.full((src.size(0), 1), sos_idx, dtype=src.dtype, device=src.device)
                finished = torch.zeros(src.size(0), dtype=torch.bool, device=src.device)

                for _ in range(max_len - 1):                                            # Generate One Token at a Time
                    tgt_mask = self._make_tgt_mask(tgt)                                 # Decoder Mask
                    output = self.decoder(tgt, enc_output, src_mask, tgt_mask)          # Run Decoder

                    # Get Last Predicted Token; sequences that already ended only receive padding
                    next_token = output[:, -1, :].argmax(-1)
                    next_token = next_token.masked_fill(finished, self.pad_idx)
                    tgt = torch.cat([tgt, next_token.unsqueeze(1)], dim=1)              # Append Token to Sequences

                    # Stop once every sequence has produced EOS at some step
                    finished = finished | (next_token == eos_idx)
                    if finished.all():                                                  # Early Stopping
                        break
        finally:
            self.train(was_training)                                                    # Restore the train/eval mode the caller had

        return tgt

## [1.7] Data Loaders

In [58]:
def create_dataloaders(src_vocab, tgt_vocab, batch_size=32, shuffle=True):
    """Build the toy English->Spanish training and validation dataloaders.

    Args:
        src_vocab: Mapping from source tokens to ids; must contain
            '<SOS>', 'i', 'love', 'speak' and 'spanish'.
        tgt_vocab: Mapping from target tokens to ids; must contain
            '<SOS>', '<EOS>', 'yo', 'amo', 'hablo' and 'espanol'.
        batch_size: Batch size used by both loaders.
        shuffle: Whether the training loader shuffles; the validation loader never does.

    Returns:
        ``(train_loader, val_loader)``. The validation set is a copy of the
        two-sentence training set.
    """
    source_sentences = [
        ['<SOS>', 'i', 'love', 'spanish'],
        ['<SOS>', 'i', 'speak', 'spanish'],
    ]
    target_sentences = [
        ['<SOS>', 'yo', 'amo', 'espanol', '<EOS>'],
        ['<SOS>', 'yo', 'hablo', 'espanol', '<EOS>'],
    ]

    # Map every token to its id, one row per sentence
    train_inputs = torch.tensor([[src_vocab[token] for token in sentence] for sentence in source_sentences])
    train_outputs = torch.tensor([[tgt_vocab[token] for token in sentence] for sentence in target_sentences])

    # Validation reuses the training pairs (independent copies of the tensors)
    val_inputs, val_outputs = train_inputs.clone(), train_outputs.clone()

    train_loader = DataLoader(TensorDataset(train_inputs, train_outputs), batch_size=batch_size, shuffle=shuffle)
    val_loader = DataLoader(TensorDataset(val_inputs, val_outputs), batch_size=batch_size, shuffle=False)

    return train_loader, val_loader

# [2] Build the Encoder-Decoder Model

In [59]:
# 1. Define Vocabulary
# Ids are assigned by position in each list, so <PAD> is 0 and <SOS> is 1 in both vocabularies.
input_vocab = {
    token: idx
    for idx, token in enumerate(["<PAD>", "<SOS>", "i", "love", "speak", "spanish"])
}
output_vocab = {
    token: idx
    for idx, token in enumerate(["<PAD>", "<SOS>", "<EOS>", "yo", "amo", "hablo", "espanol"])
}

In [60]:
# 2. Create Dataloaders
# Both toy sentence pairs fit in a single batch of size 2.
train_loader, val_loader = create_dataloaders(
    src_vocab=input_vocab,
    tgt_vocab=output_vocab,
    batch_size=2,
)

# [3] Initialize and Train Model

In [61]:
# 1. Initialize the Model
# Seed torch's RNG first so weight initialization, dropout and dataloader
# shuffling are repeatable on a fresh Restart & Run All.
torch.manual_seed(42)

model = Transformer(
                        src_vocab_size=len(input_vocab),
                        tgt_vocab_size=len(output_vocab),

                        # Deliberately tiny architecture for the two-sentence toy dataset
                        d_model=64,
                        num_layers=1,
                        num_heads=4,
                        d_ff=512,

                        dropout=0.1,
                        max_len=10,
                        lr=0.001,
                        warmup_steps=20
                )

In [66]:
# 2. Train the Model
# accelerator='auto' picks the GPU when one is available and the CPU otherwise;
# devices=1 is valid for both, whereas devices=None is rejected on CPU-only machines.
trainer = lt.Trainer(
                        max_epochs=150,
                        accelerator='auto',
                        devices=1,
                        enable_progress_bar=True,
                        logger=True
                    )

Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [67]:
# Fit on the toy dataset; validation runs on a copy of the same two sentence pairs.
trainer.fit(
    model,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type             | Params | Mode
----------------------------------------------------
0 | encoder | Encoder          | 83.4 K | eval
1 | decoder | Decoder          | 100 K  | eval
2 | loss_fn | CrossEntropyLoss | 0      | eval
----------------------------------------------------
184 K     Trainable params
0         Non-trainable params
184 K     Total params
0.736     Total estimated model params size (MB)
0         Modules in train mode
52        Modules in eval mode


Epoch 149: 100%|██████████| 1/1 [00:00<00:00, 18.83it/s, v_num=55, train_loss=0.00105, val_loss=0.00104]

`Trainer.fit` stopped: `max_epochs=150` reached.


Epoch 149: 100%|██████████| 1/1 [00:00<00:00, 11.22it/s, v_num=55, train_loss=0.00105, val_loss=0.00104]


# [4] Translate

In [68]:
# Test generation
# Encode one source sentence as a batch of size 1 and decode it greedily.
test_sentence = ['<SOS>', 'i', 'speak', 'spanish']
test_input = torch.tensor([[input_vocab[token] for token in test_sentence]])

output = model.generate(
    test_input,
    sos_idx=output_vocab['<SOS>'],
    eos_idx=output_vocab['<EOS>'],
)

print("Generated output:", output)

Generated output: tensor([[1, 3, 5, 6, 2]])


In [69]:
# Step 1: Invert the output vocabulary (id -> token)
output_id_to_token = dict((idx, token) for token, idx in output_vocab.items())

# Step 2: Take the first (and only) generated sequence as a plain list of ids
output_tokens = output[0].tolist()

# Step 3: Map ids back to words, skipping the <SOS>/<EOS> markers
special_ids = {output_vocab['<SOS>'], output_vocab['<EOS>']}
decoded_words = []
for token_id in output_tokens:
    if token_id in special_ids:
        continue
    decoded_words.append(output_id_to_token[token_id])

# Step 4: Print the sentence
print("Generated sentence:", " ".join(decoded_words))

Generated sentence: yo hablo espanol
