# 🛠 BERT+GPT Hybrid Transformer (D2L Transformer Chapter Pattern)

This notebook is based on the **Transformer Encoder/Decoder pattern** and **PositionalEncoding**  
introduced in the D2L "Transformer" chapter, but reconfigured into a **BERT-style encoder**  
and a **GPT-style decoder**.

> ⚠ **Important:**  
> - Because the dataset is very small and the number of training epochs is limited, the actual translation quality will be poor.  
> - The purpose is to understand the architecture and training procedure, not to achieve production-quality translation.  
> - High-quality translation would require a much larger dataset and extended training.

## Learning Goals
- Implement Transformer encoder and decoder from scratch
- Apply masking and padding correctly
- Train with Teacher Forcing
- Observe epoch-by-epoch loss changes

In [None]:
# ==============================================================================
# 1. Environment Setup
# ==============================================================================
# !pip install numpy==1.26.4
# !pip install d2l --no-deps

import torch
import torch.nn as nn
from d2l import torch as d2l
import random

# ==============================================================================
# 2. Helper Functions
# ==============================================================================

def safe_tokenize(sentence_tokens, vocab):
    """
    Map each token to its vocabulary index, substituting '<unk>' for unknown tokens.

    A token counts as known when it is a key of ``vocab.token_to_idx``; the
    index itself is always obtained through ``vocab[...]``.

    Args:
        sentence_tokens: Iterable of token strings.
        vocab: Object exposing ``token_to_idx`` (membership test) and
            ``__getitem__`` (token -> index lookup).

    Returns:
        A list of indices, one per input token, in the same order.
    """
    known_tokens = vocab.token_to_idx
    indices = []
    for token in sentence_tokens:
        # Out-of-vocabulary tokens are looked up under the '<unk>' entry instead.
        lookup_key = token if token in known_tokens else '<unk>'
        indices.append(vocab[lookup_key])
    return indices

def init_weights(module):
    """
    Re-initialize one module in place; intended to be passed to ``nn.Module.apply``.

    Handled module types:
        * ``nn.LayerNorm``: weight set to ones, bias set to zeros.
        * ``nn.Embedding``: weight drawn from a normal distribution (mean 0, std 0.02).
        * ``nn.Linear``: Xavier-uniform weight; bias zeroed when the layer has one.

    Any other module type is left untouched.
    """
    if isinstance(module, nn.LayerNorm):
        nn.init.ones_(module.weight)
        nn.init.zeros_(module.bias)
        return

    if isinstance(module, nn.Embedding):
        nn.init.normal_(module.weight, mean=0, std=0.02)
        return

    if not isinstance(module, nn.Linear):
        return

    nn.init.xavier_uniform_(module.weight)
    # Linear layers built with bias=False expose ``bias`` as None.
    if module.bias is None:
        return
    nn.init.zeros_(module.bias)

# ==============================================================================
# 3. Model Definition: Building a BERT-GPT Style Seq2Seq Model
# ==============================================================================

class BertStyleEncoder(nn.Module):
    """
    Transformer encoder stack over source token ids (the "BERT-style" half of the model).

    Pipeline: token embedding -> scale by sqrt(d_model) -> positional encoding
    -> ``num_layers`` ``nn.TransformerEncoderLayer`` blocks (GELU activation,
    ``batch_first=True``). No attention mask is applied unless the caller
    passes ``src_mask``.
    """
    def __init__(self, vocab_size, d_model, num_layers, num_heads, ff_dim, dropout=0.1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = d2l.PositionalEncoding(d_model, dropout)

        layer_config = dict(
            d_model=d_model,
            nhead=num_heads,
            dim_feedforward=ff_dim,
            dropout=dropout,
            activation='gelu',
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(**layer_config),
            num_layers=num_layers,
        )

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        """
        Encode a batch of source token ids.

        Args:
            src: Tensor of token indices fed to the embedding layer.
            src_mask: Optional attention mask forwarded to the encoder stack as ``mask``.
            src_key_padding_mask: Optional padding mask forwarded to the encoder stack.

        Returns:
            The output of the encoder stack.
        """
        # Embeddings are multiplied by sqrt(d_model), following "Attention Is All You Need".
        scale = self.embedding.embedding_dim ** 0.5
        embedded = self.embedding(src) * scale
        positioned = self.pos_encoding(embedded)
        return self.encoder(
            positioned,
            mask=src_mask,
            src_key_padding_mask=src_key_padding_mask,
        )

class GptStyleDecoder(nn.Module):
    """
    Transformer decoder stack that maps target token ids to vocabulary logits
    (the "GPT-style" half of the model).

    Pipeline: token embedding -> scale by sqrt(d_model) -> positional encoding
    -> ``num_layers`` ``nn.TransformerDecoderLayer`` blocks (GELU activation,
    ``batch_first=True``) attending to the encoder ``memory`` -> linear
    projection to ``vocab_size`` logits. The projection shares its weight
    matrix with the input embedding.
    """
    def __init__(self, vocab_size, d_model, num_layers, num_heads, ff_dim, dropout=0.1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = d2l.PositionalEncoding(d_model, dropout)

        layer_config = dict(
            d_model=d_model,
            nhead=num_heads,
            dim_feedforward=ff_dim,
            dropout=dropout,
            activation='gelu',
            batch_first=True,
        )
        self.decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(**layer_config),
            num_layers=num_layers,
        )

        self.fc_out = nn.Linear(d_model, vocab_size)
        # Weight tying: the output projection reuses the embedding matrix, so
        # both attributes refer to the same Parameter from here on.
        self.fc_out.weight = self.embedding.weight

    def forward(self, tgt, memory, tgt_mask=None, tgt_key_padding_mask=None, memory_key_padding_mask=None):
        """
        Decode target token ids against the encoder output.

        Args:
            tgt: Tensor of target token indices fed to the embedding layer.
            memory: Encoder output the decoder layers attend to.
            tgt_mask: Optional attention mask over target positions.
            tgt_key_padding_mask: Optional padding mask for ``tgt``.
            memory_key_padding_mask: Optional padding mask for ``memory``.

        Returns:
            Logits produced by ``fc_out`` (last dimension is the vocabulary size).
        """
        scale = self.embedding.embedding_dim ** 0.5
        embedded = self.embedding(tgt) * scale
        positioned = self.pos_encoding(embedded)
        hidden = self.decoder(
            positioned,
            memory,
            tgt_mask=tgt_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=memory_key_padding_mask,
        )
        logits = self.fc_out(hidden)
        return logits

class TransformerSeq2Seq(nn.Module):
    """
    Encoder-decoder wrapper: a ``BertStyleEncoder`` for the source side and a
    ``GptStyleDecoder`` for the target side, built with the same width, depth,
    head count, feed-forward size and dropout.
    """
    def __init__(self, src_vocab_size, tgt_vocab_size, d_model=512, num_layers=6, num_heads=8, ff_dim=2048, dropout=0.1):
        super().__init__()
        # Both halves share every hyperparameter except the vocabulary size.
        shared_args = (d_model, num_layers, num_heads, ff_dim, dropout)
        self.encoder = BertStyleEncoder(src_vocab_size, *shared_args)
        self.decoder = GptStyleDecoder(tgt_vocab_size, *shared_args)

    def forward(self, src, tgt, src_mask=None, tgt_mask=None,
                src_key_padding_mask=None, tgt_key_padding_mask=None, memory_key_padding_mask=None):
        """
        Run the encoder on ``src`` and decode ``tgt`` against its output.

        Args:
            src: Source token indices.
            tgt: Target (decoder input) token indices.
            src_mask: Optional attention mask for the encoder.
            tgt_mask: Optional attention mask for the decoder.
            src_key_padding_mask: Optional padding mask for ``src`` (encoder only).
            tgt_key_padding_mask: Optional padding mask for ``tgt``.
            memory_key_padding_mask: Optional padding mask applied to the
                encoder output inside the decoder. It is NOT derived from
                ``src_key_padding_mask``; pass it explicitly.

        Returns:
            The decoder's output logits.
        """
        encoded = self.encoder(
            src,
            src_mask=src_mask,
            src_key_padding_mask=src_key_padding_mask,
        )
        return self.decoder(
            tgt,
            encoded,
            tgt_mask=tgt_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=memory_key_padding_mask,
        )

# ==============================================================================
# 4. Inference Function: Beam Search
# ==============================================================================

def predict_translation_beam_search(model, sentence_tokens, src_vocab, tgt_vocab, device, max_len=20, beam_size=3):
    """
    Translate one tokenized sentence with beam search.

    The source sentence is encoded once; the decoder is then re-run on every
    live hypothesis at each step. A hypothesis' score is the sum of the
    log-probabilities of its tokens (no length normalization). Decoding stops
    after ``max_len`` steps or as soon as the best-scoring hypothesis ends in
    '<eos>'.

    Side effect: the model is switched to evaluation mode and left there.

    Args:
        model: Seq2Seq model exposing ``encoder`` and ``decoder`` submodules.
        sentence_tokens: List of source-language token strings.
        src_vocab: Source vocabulary (needs '<eos>' and '<pad>' entries).
        tgt_vocab: Target vocabulary (needs '<bos>' and '<eos>' entries and ``to_tokens``).
        device: Device on which the input tensors are created.
        max_len: Maximum number of tokens to generate.
        beam_size: Number of hypotheses kept after each step; must be >= 1.

    Returns:
        The best hypothesis as a space-joined string, without '<bos>' and
        truncated before the first '<eos>'.

    Raises:
        ValueError: If ``beam_size`` is smaller than 1.
    """
    if beam_size < 1:
        # Otherwise topk(k=0) yields no candidates and beams[0] below fails
        # with an unhelpful IndexError.
        raise ValueError(f"beam_size must be at least 1, got {beam_size}")

    model.eval()  # Disable dropout for deterministic decoding

    # Special-token ids are loop invariants; look them up once.
    bos_idx = tgt_vocab['<bos>']
    eos_idx = tgt_vocab['<eos>']

    # Preprocess the source sentence: indices + '<eos>', as a batch of one.
    src_indices = safe_tokenize(sentence_tokens, src_vocab) + [src_vocab['<eos>']]
    src_tensor = torch.tensor(src_indices, device=device).unsqueeze(0)
    # Built from src_tensor, so it already lives on `device`.
    src_key_padding_mask = src_tensor == src_vocab['<pad>']

    with torch.no_grad():
        # Encoder output is calculated only once and shared by all hypotheses.
        memory = model.encoder(src_tensor, src_key_padding_mask=src_key_padding_mask)

        # Each beam entry is a tuple (token-id tensor of shape (1, length), score).
        beams = [(torch.tensor([[bos_idx]], device=device), 0.0)]

        # Autoregressive decoding loop
        for _ in range(max_len):
            candidates = []
            for seq, score in beams:
                # Finished hypotheses are carried over unchanged.
                if seq[0, -1].item() == eos_idx:
                    candidates.append((seq, score))
                    continue

                # Causal mask so each position only attends to earlier ones.
                tgt_mask = nn.Transformer.generate_square_subsequent_mask(seq.size(1)).to(device)

                out = model.decoder(seq, memory, tgt_mask=tgt_mask,
                                    memory_key_padding_mask=src_key_padding_mask)
                # Only the distribution for the next token (last position) matters.
                log_probs = torch.log_softmax(out[:, -1, :], dim=-1)

                # torch.topk raises if k exceeds the dimension size, so clamp
                # to the vocabulary size for tiny vocabularies / large beams.
                num_expansions = min(beam_size, log_probs.size(-1))
                topk_log_probs, topk_indices = torch.topk(log_probs, num_expansions, dim=-1)

                # Expand this hypothesis with each of its best next tokens.
                for k in range(num_expansions):
                    next_seq = torch.cat([seq, topk_indices[:, k].unsqueeze(0)], dim=1)
                    next_score = score + topk_log_probs[0, k].item()
                    candidates.append((next_seq, next_score))

            # Prune: keep only the `beam_size` highest-scoring hypotheses.
            beams = sorted(candidates, key=lambda x: x[1], reverse=True)[:beam_size]

            # Scores can only decrease as tokens are appended (log-probs <= 0),
            # so once the best hypothesis has ended nothing can overtake it.
            if beams[0][0][0, -1].item() == eos_idx:
                break

    # Post-process the best sequence: drop the leading '<bos>' and everything
    # from the first '<eos>' onward.
    best_seq_indices = beams[0][0][0].tolist()[1:]
    if eos_idx in best_seq_indices:
        best_seq_indices = best_seq_indices[:best_seq_indices.index(eos_idx)]

    return ' '.join([tgt_vocab.to_tokens(idx) for idx in best_seq_indices])

# ==============================================================================
# 5. Training and Execution
# ==============================================================================

# Load the dataset (e.g., English to French translation).
# The call is unpacked into a batch iterator plus the source and target
# vocabularies; every name below is a module-level global used by the
# training code further down.
# NOTE(review): per the D2L documentation, `num_steps` is the fixed
# (truncated/padded) sequence length and target rows end in '<eos>' with no
# leading '<bos>' -- confirm against the installed d2l version.
train_iter, src_vocab, tgt_vocab = d2l.load_data_nmt(batch_size=64, num_steps=20)

def train_loop(model, data_iter, lr, num_epochs, tgt_vocab, device, src_vocab):
    """
    Train the Seq2Seq model with teacher forcing and report progress every epoch.

    After each epoch the mean batch loss is printed, followed by a beam-search
    translation of a fixed sample sentence.

    Args:
        model: Seq2Seq model called as ``model(src, dec_input, **masks)``.
        data_iter: Iterable of 4-tuples ``(src, _, tgt, _)`` as produced by
            ``d2l.load_data_nmt``. ``tgt`` rows are expected to end in '<eos>'
            (then '<pad>') and to have NO leading '<bos>'; this function
            prepends it.
        lr: Learning rate for the Adam optimizer.
        num_epochs: Number of passes over ``data_iter``.
        tgt_vocab: Target vocabulary (needs '<pad>' and '<bos>' entries).
        device: Device to train on.
        src_vocab: Source vocabulary (needs a '<pad>' entry).
    """
    model.to(device)
    # Padding positions contribute nothing to the loss.
    loss_fn = nn.CrossEntropyLoss(ignore_index=tgt_vocab['<pad>'])
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    bos_idx = tgt_vocab['<bos>']
    src_pad_idx = src_vocab['<pad>']
    tgt_pad_idx = tgt_vocab['<pad>']
    test_sentence_en = "i ate a sandwich .".split()

    for epoch in range(num_epochs):
        model.train()  # Re-enable dropout (the per-epoch prediction switches to eval mode)
        total_loss = 0
        batch_count = 0

        for batch in data_iter:
            src, _, tgt, _ = batch
            src, tgt = src.to(device), tgt.to(device)
            optimizer.zero_grad()

            # Teacher forcing. The dataset's target rows carry no '<bos>', so
            # slicing tgt[:, :-1] / tgt[:, 1:] would never teach the model what
            # follows '<bos>' (which is what inference starts from) and would
            # never score the first target word. Instead:
            #   decoder input = '<bos>' + target without its last position
            #   label         = the full target (ending in '<eos>' / padding)
            bos = torch.full((tgt.size(0), 1), bos_idx, dtype=tgt.dtype, device=tgt.device)
            dec_input = torch.cat([bos, tgt[:, :-1]], dim=1)
            dec_target = tgt

            # Causal mask plus padding masks (True marks positions to ignore).
            tgt_mask = nn.Transformer.generate_square_subsequent_mask(dec_input.size(1)).to(device)
            src_key_padding_mask = src == src_pad_idx
            tgt_key_padding_mask = dec_input == tgt_pad_idx

            # Forward pass; the source padding mask is reused for the encoder
            # output the decoder attends to.
            output = model(src, dec_input,
                           tgt_mask=tgt_mask,
                           src_key_padding_mask=src_key_padding_mask,
                           tgt_key_padding_mask=tgt_key_padding_mask,
                           memory_key_padding_mask=src_key_padding_mask)

            # Flatten (batch, steps, vocab) logits and (batch, steps) labels for the loss.
            loss = loss_fn(output.reshape(-1, output.shape[-1]), dec_target.reshape(-1))

            # Backward pass and optimization
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            batch_count += 1

        avg_loss = total_loss / batch_count
        print(f"Epoch {epoch+1}, Loss {avg_loss:.4f}")

        # ----- Display a translation example after each epoch -----
        translated_sentence = predict_translation_beam_search(model, test_sentence_en, src_vocab, tgt_vocab, device)
        print(f"  Input (en): {' '.join(test_sentence_en)}")
        print(f"  Translation (fr): {translated_sentence}\n")

# ----- Main Execution Block -----
if __name__ == '__main__':
    device = d2l.try_gpu()

    # Model sized to the two vocabularies; every other hyperparameter keeps
    # the TransformerSeq2Seq default.
    model = TransformerSeq2Seq(
        src_vocab_size=len(src_vocab),
        tgt_vocab_size=len(tgt_vocab),
    )
    # Module.apply calls init_weights on every submodule of the model.
    model.apply(init_weights)

    train_loop(
        model,
        train_iter,
        lr=0.001,
        num_epochs=100,
        tgt_vocab=tgt_vocab,
        device=device,
        src_vocab=src_vocab,
    )

Epoch 1, Loss 8.2201
  Input (en): i ate a sandwich .
  Translation (fr): suis suis suis suis suis suis suis suis suis suis suis suis suis suis suis suis suis suis suis suis

Epoch 2, Loss 3.4390
  Input (en): i ate a sandwich .
  Translation (fr): 

Epoch 3, Loss 2.9938
  Input (en): i ate a sandwich .
  Translation (fr): 

Epoch 4, Loss 2.9396
  Input (en): i ate a sandwich .
  Translation (fr): 

Epoch 5, Loss 2.9198
  Input (en): i ate a sandwich .
  Translation (fr): 

Epoch 6, Loss 2.8560
  Input (en): i ate a sandwich .
  Translation (fr): 

Epoch 7, Loss 2.8916
  Input (en): i ate a sandwich .
  Translation (fr): 

Epoch 8, Loss 2.9152
  Input (en): i ate a sandwich .
  Translation (fr): 

Epoch 9, Loss 2.8922
  Input (en): i ate a sandwich .
  Translation (fr): 

Epoch 10, Loss 2.8805
  Input (en): i ate a sandwich .
  Translation (fr): 

Epoch 11, Loss 2.8835
  Input (en): i ate a sandwich .
  Translation (fr): 

Epoch 12, Loss 2.8621
  Input (en): i ate a sandwich .
  Transl