<a href="https://colab.research.google.com/github/vineelkondapalli/multi30k_transformer/blob/main/firsttransformer_multi30k.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [28]:
# !pip install torch==2.3.0 torchtext==0.18.0 torchvision --upgrade
# !pip install -U spacy
# !pip install datasets
# !python -m spacy download en_core_web_sm
# !python -m spacy download de_core_news_sm
# !pip install transformers



In [20]:
# =============================================================================
# 1. IMPORTS AND SETUP
# =============================================================================
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR
import torch.nn.functional as F

from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
from datasets import load_dataset
from torchtext.vocab import build_vocab_from_iterator

import spacy
import math
import time

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# =============================================================================
# 2. DATA PIPELINE
# =============================================================================
# Load tokenizers
spacy_de = spacy.load('de_core_news_sm')
spacy_en = spacy.load('en_core_web_sm')

def tokenize_de(text):
    return [tok.text for tok in spacy_de.tokenizer(text)]

def tokenize_en(text):
    return [tok.text for tok in spacy_en.tokenizer(text)]

# Load the dataset
multi30k_dataset = load_dataset("bentrevett/multi30k")

# Build vocabularies
def yield_tokens(data_iter, language):
    language_tokenizers = {'de': tokenize_de, 'en': tokenize_en}
    for data_sample in data_iter:
        yield language_tokenizers[language](data_sample[language])

UNK_IDX, PAD_IDX, SOS_IDX, EOS_IDX = 0, 1, 2, 3
special_symbols = ['<unk>', '<pad>', '<sos>', '<eos>']

SRC_VOCAB = build_vocab_from_iterator(yield_tokens(multi30k_dataset['train'], 'de'), min_freq=1, specials=special_symbols, special_first=True)
SRC_VOCAB.set_default_index(UNK_IDX)

TRG_VOCAB = build_vocab_from_iterator(yield_tokens(multi30k_dataset['train'], 'en'), min_freq=1, specials=special_symbols, special_first=True)
TRG_VOCAB.set_default_index(UNK_IDX)

# Define DataLoader and collate_fn
def text_transform(tokenizer, vocab, sos_idx, eos_idx):
    def transform(text_sample):
        tokens = tokenizer(text_sample.rstrip("\n"))
        return torch.cat((torch.tensor([sos_idx]), torch.tensor(vocab(tokens)), torch.tensor([eos_idx])))
    return transform

BATCH_SIZE = 128

def collate_fn(batch):
    src_batch, trg_batch = [], []
    for sample in batch:
        # The fix is to add SOS_IDX and EOS_IDX to the calls below
        src_batch.append(text_transform(tokenize_de, SRC_VOCAB, SOS_IDX, EOS_IDX)(sample['de']))
        trg_batch.append(text_transform(tokenize_en, TRG_VOCAB, SOS_IDX, EOS_IDX)(sample['en']))

    src_batch = pad_sequence(src_batch, batch_first=True, padding_value=PAD_IDX)
    trg_batch = pad_sequence(trg_batch, batch_first=True, padding_value=PAD_IDX)

    return src_batch.to(device), trg_batch.to(device)

# --- IMPORTANT: Re-run the DataLoader definitions after changing the function ---
train_dataloader = DataLoader(multi30k_dataset['train'], batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)
valid_dataloader = DataLoader(multi30k_dataset['validation'], batch_size=BATCH_SIZE, collate_fn=collate_fn)
test_dataloader = DataLoader(multi30k_dataset['test'], batch_size=BATCH_SIZE, collate_fn=collate_fn)

print("Data Pipeline Ready.")

# =============================================================================
# 3. MODEL DEFINITION
# =============================================================================
class MultiHeadAttentionLayer(nn.Module):
    def __init__(self, hid_dim, n_heads, dropout, device):
        super().__init__()
        assert hid_dim % n_heads == 0
        self.hid_dim = hid_dim
        self.n_heads = n_heads
        self.head_dim = hid_dim // n_heads
        self.fc_q = nn.Linear(hid_dim, hid_dim)
        self.fc_k = nn.Linear(hid_dim, hid_dim)
        self.fc_v = nn.Linear(hid_dim, hid_dim)
        self.fc_o = nn.Linear(hid_dim, hid_dim)
        self.dropout = nn.Dropout(dropout)
        self.scale = torch.sqrt(torch.FloatTensor([self.head_dim])).to(device)
    def forward(self, query, key, value, mask=None):
        batch_size = query.shape[0]
        Q, K, V = self.fc_q(query), self.fc_k(key), self.fc_v(value)
        Q = Q.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        K = K.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        V = V.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        energy = torch.matmul(Q, K.permute(0, 1, 3, 2)) / self.scale
        if mask is not None: energy = energy.masked_fill(mask == 0, -1e10)
        attention = torch.softmax(energy, dim=-1)
        x = torch.matmul(self.dropout(attention), V)
        x = x.permute(0, 2, 1, 3).contiguous()
        x = x.view(batch_size, -1, self.hid_dim)
        x = self.fc_o(x)
        return x, attention

class PositionwiseFeedforwardLayer(nn.Module):
    def __init__(self, hid_dim, pf_dim, dropout):
        super().__init__()
        self.fc_1 = nn.Linear(hid_dim, pf_dim)
        self.fc_2 = nn.Linear(pf_dim, hid_dim)
        self.dropout = nn.Dropout(dropout)
    def forward(self, x):
        x = self.dropout(torch.relu(self.fc_1(x)))
        x = self.fc_2(x)
        return x

class EncoderLayer(nn.Module):
    def __init__(self, hid_dim, n_heads, pf_dim, dropout, device):
        super().__init__()
        self.self_attention = MultiHeadAttentionLayer(hid_dim, n_heads, dropout, device)
        self.self_attention_layer_norm = nn.LayerNorm(hid_dim)
        self.positionwise_feedforward = PositionwiseFeedforwardLayer(hid_dim, pf_dim, dropout)
        self.feedforward_layer_norm = nn.LayerNorm(hid_dim)
        self.dropout = nn.Dropout(dropout)
    def forward(self, src, src_mask):
        _src, _ = self.self_attention(src, src, src, src_mask)
        src = self.self_attention_layer_norm(src + self.dropout(_src))
        _src = self.positionwise_feedforward(src)
        src = self.feedforward_layer_norm(src + self.dropout(_src))
        return src

class Encoder(nn.Module):
    def __init__(self, input_dim, hid_dim, n_layers, n_heads, pf_dim, dropout, device, max_length=200):
        super().__init__()
        self.device = device
        self.tok_embedding = nn.Embedding(input_dim, hid_dim)
        self.pos_embedding = nn.Embedding(max_length, hid_dim)
        self.layers = nn.ModuleList([EncoderLayer(hid_dim, n_heads, pf_dim, dropout, device) for _ in range(n_layers)])
        self.dropout = nn.Dropout(dropout)
        self.scale = torch.sqrt(torch.FloatTensor([hid_dim])).to(device)
    def forward(self, src, src_mask):
        batch_size, src_len = src.shape
        pos = torch.arange(0, src_len).unsqueeze(0).repeat(batch_size, 1).to(self.device)
        src = self.dropout((self.tok_embedding(src) * self.scale) + self.pos_embedding(pos))
        for layer in self.layers:
            src = layer(src, src_mask)
        return src

class DecoderLayer(nn.Module):
    def __init__(self, hid_dim, n_heads, pf_dim, dropout, device):
        super().__init__()
        self.self_attention = MultiHeadAttentionLayer(hid_dim, n_heads, dropout, device)
        self.self_attention_layer_norm = nn.LayerNorm(hid_dim)
        self.encoder_attention = MultiHeadAttentionLayer(hid_dim, n_heads, dropout, device)
        self.encoder_attention_layer_norm = nn.LayerNorm(hid_dim)
        self.positionwise_feedforward = PositionwiseFeedforwardLayer(hid_dim, pf_dim, dropout)
        self.feedforward_layer_norm = nn.LayerNorm(hid_dim)
        self.dropout = nn.Dropout(dropout)
    def forward(self, trg, enc_src, trg_mask, src_mask):
        _trg, _ = self.self_attention(trg, trg, trg, trg_mask)
        trg = self.self_attention_layer_norm(trg + self.dropout(_trg))
        _trg, attention = self.encoder_attention(trg, enc_src, enc_src, src_mask)
        trg = self.encoder_attention_layer_norm(trg + self.dropout(_trg))
        _trg = self.positionwise_feedforward(trg)
        trg = self.feedforward_layer_norm(trg + self.dropout(_trg))
        return trg, attention

class Decoder(nn.Module):
    def __init__(self, output_dim, hid_dim, n_layers, n_heads, pf_dim, dropout, device, max_length=200):
        super().__init__()
        self.device = device
        self.tok_embedding = nn.Embedding(output_dim, hid_dim)
        self.pos_embedding = nn.Embedding(max_length, hid_dim)
        self.layers = nn.ModuleList([DecoderLayer(hid_dim, n_heads, pf_dim, dropout, device) for _ in range(n_layers)])
        self.fc_out = nn.Linear(hid_dim, output_dim)
        self.dropout = nn.Dropout(dropout)
        self.scale = torch.sqrt(torch.FloatTensor([hid_dim])).to(device)
    def forward(self, trg, enc_src, trg_mask, src_mask):
        batch_size, trg_len = trg.shape
        pos = torch.arange(0, trg_len).unsqueeze(0).repeat(batch_size, 1).to(self.device)
        trg = self.dropout((self.tok_embedding(trg) * self.scale) + self.pos_embedding(pos))
        for layer in self.layers:
            trg, attention = layer(trg, enc_src, trg_mask, src_mask)
        output = self.fc_out(trg)
        return output, attention

class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder, src_pad_idx, trg_pad_idx, device):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_pad_idx = src_pad_idx
        self.trg_pad_idx = trg_pad_idx
        self.device = device
    def make_src_mask(self, src):
        src_mask = (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)
        return src_mask
    def make_trg_mask(self, trg):
        trg_pad_mask = (trg != self.trg_pad_idx).unsqueeze(1).unsqueeze(2)
        trg_len = trg.shape[1]
        trg_sub_mask = torch.tril(torch.ones((trg_len, trg_len), device=self.device)).bool()
        trg_mask = trg_pad_mask & trg_sub_mask
        return trg_mask
    def forward(self, src, trg):
        src_mask = self.make_src_mask(src)
        trg_mask = self.make_trg_mask(trg)
        enc_src = self.encoder(src, src_mask)
        output, attention = self.decoder(trg, enc_src, trg_mask, src_mask)
        return output, attention

print("All model classes defined.")

# =============================================================================
# 4. INSTANTIATION AND TRAINING SETUP
# =============================================================================
INPUT_DIM = len(SRC_VOCAB)
OUTPUT_DIM = len(TRG_VOCAB)
HID_DIM = 256
ENC_LAYERS = 3
DEC_LAYERS = 3
ENC_HEADS = 8
DEC_HEADS = 8
ENC_PF_DIM = 512
DEC_PF_DIM = 512
ENC_DROPOUT = 0.3
DEC_DROPOUT = 0.3

enc = Encoder(INPUT_DIM, HID_DIM, ENC_LAYERS, ENC_HEADS, ENC_PF_DIM, ENC_DROPOUT, device)
dec = Decoder(OUTPUT_DIM, HID_DIM, DEC_LAYERS, DEC_HEADS, DEC_PF_DIM, DEC_DROPOUT, device)

net = Seq2Seq(enc, dec, PAD_IDX, PAD_IDX, device).to(device)

def initialize_weights(m):
    if hasattr(m, 'weight') and m.weight.dim() > 1:
        nn.init.xavier_uniform_(m.weight.data)
net.apply(initialize_weights)

NUM_EPOCHS = 20
optimizer = optim.SGD(net.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-3)
scheduler = CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS)
criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX, label_smoothing=.1)

print("Model instantiated and ready for training.")

# =============================================================================
# 5. TRAINING LOOP
# =============================================================================
def train(model, iterator, optimizer, criterion, clip):
    model.train()
    epoch_loss = 0
    for i, batch in enumerate(iterator):
        src, trg = batch
        optimizer.zero_grad()
        output, _ = model(src, trg[:,:-1])
        output_dim = output.shape[-1]
        output = output.contiguous().view(-1, output_dim)
        trg = trg[:,1:].contiguous().view(-1)
        loss = criterion(output, trg)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()
        epoch_loss += loss.item()
    return epoch_loss / len(iterator)

def evaluate(model, iterator, criterion):
    model.eval()
    epoch_loss = 0
    with torch.no_grad():
        for i, batch in enumerate(iterator):
            src, trg = batch
            output, _ = model(src, trg[:,:-1])
            output_dim = output.shape[-1]
            output = output.contiguous().view(-1, output_dim)
            trg = trg[:,1:].contiguous().view(-1)
            loss = criterion(output, trg)
            epoch_loss += loss.item()
    return epoch_loss / len(iterator)

CLIP = 1
best_valid_loss = float('inf')

EARLY_STOPPING_PATIENCE = 5
patience_counter = 0

print("\nStarting training...")
for epoch in range(NUM_EPOCHS):
    start_time = time.time()
    train_loss = train(net, train_dataloader, optimizer, criterion, CLIP)
    valid_loss = evaluate(net, valid_dataloader, criterion)
    end_time = time.time()

    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(net.state_dict(), 'transformer-model.pt')
        patience_counter = 0
    else:
        patience_counter += 1

    if patience_counter >= EARLY_STOPPING_PATIENCE:
        print("Early stopping triggered.")
        break

    scheduler.step()

    print(f'Epoch: {epoch+1:02} | Time: {end_time - start_time:.0f}s')
    print(f'\tTrain Loss: {train_loss:.3f} | Train PPL: {math.exp(train_loss):7.3f}')
    print(f'\t Val. Loss: {valid_loss:.3f} |  Val. PPL: {math.exp(valid_loss):7.3f}')

print("\nFinished Training.")

Using device: cuda
Data Pipeline Ready.
All model classes defined.
Model instantiated and ready for training.

Starting training...
Epoch: 01 | Time: 9s
	Train Loss: 5.332 | Train PPL: 206.892
	 Val. Loss: 4.487 |  Val. PPL:  88.818
Epoch: 02 | Time: 9s
	Train Loss: 4.329 | Train PPL:  75.851
	 Val. Loss: 4.107 |  Val. PPL:  60.750
Epoch: 03 | Time: 9s
	Train Loss: 3.965 | Train PPL:  52.728
	 Val. Loss: 3.808 |  Val. PPL:  45.080
Epoch: 04 | Time: 9s
	Train Loss: 3.692 | Train PPL:  40.137
	 Val. Loss: 3.613 |  Val. PPL:  37.072
Epoch: 05 | Time: 9s
	Train Loss: 3.504 | Train PPL:  33.250
	 Val. Loss: 3.510 |  Val. PPL:  33.433
Epoch: 06 | Time: 9s
	Train Loss: 3.380 | Train PPL:  29.372
	 Val. Loss: 3.460 |  Val. PPL:  31.808
Epoch: 07 | Time: 9s
	Train Loss: 3.292 | Train PPL:  26.888
	 Val. Loss: 3.428 |  Val. PPL:  30.829
Epoch: 08 | Time: 9s
	Train Loss: 3.229 | Train PPL:  25.258
	 Val. Loss: 3.394 |  Val. PPL:  29.773
Epoch: 09 | Time: 9s
	Train Loss: 3.180 | Train PPL:  24.048

In [23]:
def translate_sentence(sentence, src_vocab, trg_vocab, model, device, max_len=50):
    model.eval() # Set the model to evaluation mode

    # Tokenize the source sentence if it's not already a list
    if isinstance(sentence, str):
        tokens = [token.lower() for token in tokenize_de(sentence)]
    else:
        tokens = [token.lower() for token in sentence]

    # Add <sos> and <eos> tokens
    tokens = [SRC_VOCAB.lookup_token(SOS_IDX)] + tokens + [SRC_VOCAB.lookup_token(EOS_IDX)]

    # Convert tokens to numerical indices
    src_indexes = [SRC_VOCAB[token] for token in tokens]

    # Convert to a tensor and add a batch dimension
    src_tensor = torch.LongTensor(src_indexes).unsqueeze(0).to(device)

    # Create the source mask
    src_mask = model.make_src_mask(src_tensor)

    # Pass the source through the encoder
    with torch.no_grad():
        enc_src = model.encoder(src_tensor, src_mask)

    # Start the decoder output with the <sos> token
    trg_indexes = [TRG_VOCAB[special_symbols[SOS_IDX]]]

    # Loop to generate the translation word by word
    for i in range(max_len):
        trg_tensor = torch.LongTensor(trg_indexes).unsqueeze(0).to(device)
        trg_mask = model.make_trg_mask(trg_tensor)

        with torch.no_grad():
            output, attention = model.decoder(trg_tensor, enc_src, trg_mask, src_mask)

        # Get the predicted next token (the last one in the sequence)
        pred_token = output.argmax(2)[:,-1].item()

        trg_indexes.append(pred_token)

        # If the model predicts the <eos> token, stop
        if pred_token == TRG_VOCAB[special_symbols[EOS_IDX]]:
            break

    # Convert the output indices back to tokens
    trg_tokens = [TRG_VOCAB.lookup_token(i) for i in trg_indexes]

    # Return the translation (without the <sos> token)
    return trg_tokens[1:], attention

# --- Let's try it out! ---

# 1. Load the weights from your best model
# Make sure you've saved a model from your best epoch
net.load_state_dict(torch.load('transformer-model.pt'))

# 2. Get an example from the test set using the new dataset object
example_idx = 234
sample = multi30k_dataset['test'][example_idx]
src_text = sample['de']
trg_text = sample['en']

# Tokenize the source text for printing and for the function
src_tokens = tokenize_de(src_text)

print(f'Source Sentence: {" ".join(src_tokens)}')
print(f'Target Sentence: {trg_text}')

# 3. Translate the sentence
# The function expects the sentence as a list of tokens
translation, attention = translate_sentence(src_tokens, SRC_VOCAB, TRG_VOCAB, net, device)

print(f'Predicted Translation: {" ".join(translation)}')

Source Sentence: Zwei braune Hunde spielen grob miteinander .
Target Sentence: Two brown dogs playing in a rough manner.
Predicted Translation: Two brown are playing in a wooden floor . <eos>


introducing beam translation(picking best result out of top k translations)

In [24]:
def beam_search_translate(sentence, src_vocab, trg_vocab, model, device, beam_width=3, max_len=50):
    model.eval()

    if isinstance(sentence, str):
        tokens = [token.lower() for token in tokenize_de(sentence)]
    else:
        tokens = [token.lower() for token in sentence]

    tokens = [SRC_VOCAB.lookup_token(SOS_IDX)] + tokens + [SRC_VOCAB.lookup_token(EOS_IDX)]
    src_indexes = [SRC_VOCAB[token] for token in tokens]
    src_tensor = torch.LongTensor(src_indexes).unsqueeze(0).to(device)
    src_mask = model.make_src_mask(src_tensor)

    with torch.no_grad():
        enc_src = model.encoder(src_tensor, src_mask)

    # Start with a single beam: ([<sos>], 0 score)
    beams = [([TRG_VOCAB[special_symbols[SOS_IDX]]], 0.0)]
    completed_beams = []

    for _ in range(max_len):
        new_beams = []
        for seq, score in beams:
            # If a beam has ended, add it to completed_beams and skip it
            if seq[-1] == TRG_VOCAB[special_symbols[EOS_IDX]]:
                completed_beams.append((seq, score))
                continue

            trg_tensor = torch.LongTensor(seq).unsqueeze(0).to(device)
            trg_mask = model.make_trg_mask(trg_tensor)

            with torch.no_grad():
                output, attention = model.decoder(trg_tensor, enc_src, trg_mask, src_mask)

            # Get log probabilities for the next word
            pred_log_probs = F.log_softmax(output[:,-1], dim=-1)

            # Get the top 'k' next words and their log probabilities
            top_k_log_probs, top_k_indexes = torch.topk(pred_log_probs, beam_width)

            # Create new beams from the top k options
            for i in range(beam_width):
                new_seq = seq + [top_k_indexes[0][i].item()]
                new_score = score + top_k_log_probs[0][i].item()
                new_beams.append((new_seq, new_score))

        # If all beams have finished, we can stop early
        if not new_beams:
            break

        # Add any completed beams from this step
        completed_beams.extend([beam for beam in new_beams if beam[0][-1] == TRG_VOCAB[special_symbols[EOS_IDX]]])

        # Prune the new beams: keep only the top 'k' overall
        # Exclude beams that have just completed
        uncompleted_new_beams = [beam for beam in new_beams if beam[0][-1] != TRG_VOCAB[special_symbols[EOS_IDX]]]
        beams = sorted(uncompleted_new_beams, key=lambda x: x[1], reverse=True)[:beam_width]

        # If we have enough completed beams, we can stop
        if len(completed_beams) >= beam_width:
            break

    # If no beams completed, use the current best ones
    if not completed_beams:
        completed_beams = beams

    # Normalize scores by length and find the best one
    completed_beams.sort(key=lambda x: x[1]/len(x[0]), reverse=True)
    best_seq = completed_beams[0][0]

    trg_tokens = [TRG_VOCAB.lookup_token(i) for i in best_seq]

    return trg_tokens[1:], None # Attention is not returned in this simplified version

In [27]:
# Load the weights from your best model
net.load_state_dict(torch.load('transformer-model.pt'))

# Get an example from the test set
example_idx = 128
sample = multi30k_dataset['test'][example_idx]
src_text = sample['de']
trg_text = sample['en']
src_tokens = tokenize_de(src_text)

print(f'Source Sentence: {" ".join(src_tokens)}')
print(f'Target Sentence: {trg_text}\n')

# --- Greedy Search (beam_width = 1) ---
greedy_translation, _ = beam_search_translate(src_tokens, SRC_VOCAB, TRG_VOCAB, net, device, beam_width=1)
print(f'Greedy Translation: {" ".join(greedy_translation)}')

# --- Beam Search (beam_width = 3) ---
beam_translation, _ = beam_search_translate(src_tokens, SRC_VOCAB, TRG_VOCAB, net, device, beam_width=3)
print(f'Beam Search (k=3) Translation: {" ".join(beam_translation)}')

Source Sentence: Eine alte Frau sitzt an einem Webstuhl und stellt Stoff her .
Target Sentence: An old woman working at a loom making cloth.

Greedy Translation: An old worker is sitting on a metal ramp and an outdoor area . <eos>
Beam Search (k=3) Translation: An old worker is sitting on a ramp and an urban area <eos>
