In [15]:
import torch
import torch.nn as nn
import torch.optim as optim
import random
import numpy as np
import re
from torch.utils.data import Dataset, DataLoader
from collections import Counter
from nltk.translate.bleu_score import sentence_bleu
import matplotlib.pyplot as plt
from tqdm import tqdm
import warnings 
warnings.filterwarnings('ignore')

device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print("Using device:", device)

Using device: mps


In [3]:
def clean_text(text):
    text = text.lower().strip()
    text = re.sub(r"[^a-zA-Z¿?¡!.']+", " ", text)
    return text

pairs = []

with open("spa.txt", encoding="utf-8") as f:
    for line in f:
        parts = line.strip().split("\t")
        if len(parts) >= 2:
            pairs.append((clean_text(parts[0]), clean_text(parts[1])))

pairs = pairs[:10000]
random.shuffle(pairs)

train_size = int(0.8 * len(pairs))
val_size = int(0.1 * len(pairs))

train_pairs = pairs[:train_size]
val_pairs = pairs[train_size:train_size+val_size]
test_pairs = pairs[train_size+val_size:]

print("Train:", len(train_pairs), "Test:", len(test_pairs))

Train: 8000 Test: 1000


In [4]:
class Vocab:
    def __init__(self, sentences):
        counter = Counter()
        for sent in sentences:
            counter.update(sent.split())

        self.itos = ["<pad>", "<sos>", "<eos>", "<unk>"]
        self.itos += list(counter.keys())
        self.stoi = {w:i for i,w in enumerate(self.itos)}

    def numericalize(self, text):
        return [self.stoi.get(w, self.stoi["<unk>"]) for w in text.split()]

src_vocab = Vocab([p[0] for p in train_pairs])
trg_vocab = Vocab([p[1] for p in train_pairs])

INPUT_DIM = len(src_vocab.itos)
OUTPUT_DIM = len(trg_vocab.itos)

print("Input vocab:", INPUT_DIM)
print("Output vocab:", OUTPUT_DIM)

Input vocab: 2931
Output vocab: 5122


In [5]:
class TranslationDataset(Dataset):
    def __init__(self, pairs):
        self.pairs = pairs

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        return self.pairs[idx]

def collate_fn(batch):
    src_batch, trg_batch = [], []

    for src, trg in batch:
        src_tensor = torch.tensor(
            [src_vocab.stoi["<sos>"]] +
            src_vocab.numericalize(src) +
            [src_vocab.stoi["<eos>"]]
        )

        trg_tensor = torch.tensor(
            [trg_vocab.stoi["<sos>"]] +
            trg_vocab.numericalize(trg) +
            [trg_vocab.stoi["<eos>"]]
        )

        src_batch.append(src_tensor)
        trg_batch.append(trg_tensor)

    src_batch = nn.utils.rnn.pad_sequence(src_batch, padding_value=0)
    trg_batch = nn.utils.rnn.pad_sequence(trg_batch, padding_value=0)

    return src_batch.to(device), trg_batch.to(device)

train_loader = DataLoader(
    TranslationDataset(train_pairs),
    batch_size=32,
    shuffle=True,
    collate_fn=collate_fn
)

In [6]:
class Encoder(nn.Module):
    def __init__(self, input_dim, emb_dim, hidden_dim):
        super().__init__()
        self.embedding = nn.Embedding(input_dim, emb_dim)
        self.lstm = nn.LSTM(emb_dim, hidden_dim)

    def forward(self, src):
        embedded = self.embedding(src)
        outputs, (hidden, cell) = self.lstm(embedded)
        return outputs, hidden, cell

class Decoder(nn.Module):
    def __init__(self, output_dim, emb_dim, hidden_dim):
        super().__init__()
        self.embedding = nn.Embedding(output_dim, emb_dim)
        self.lstm = nn.LSTM(emb_dim, hidden_dim)
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, input, hidden, cell):
        input = input.unsqueeze(0)
        embedded = self.embedding(input)
        output, (hidden, cell) = self.lstm(embedded, (hidden, cell))
        prediction = self.fc(output.squeeze(0))
        return prediction, hidden, cell


In [7]:
class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, src, trg, teacher_forcing_ratio=0.5):

        trg_len = trg.shape[0]
        batch_size = trg.shape[1]
        vocab_size = self.decoder.fc.out_features

        outputs = torch.zeros(trg_len, batch_size, vocab_size).to(device)

        _, hidden, cell = self.encoder(src)
        input = trg[0]

        for t in range(1, trg_len):
            output, hidden, cell = self.decoder(input, hidden, cell)
            outputs[t] = output

            teacher_force = random.random() < teacher_forcing_ratio
            top1 = output.argmax(1)

            input = trg[t] if teacher_force else top1

        return outputs

In [8]:
EMB_DIM = 256
HIDDEN_DIM = 512

vanilla_encoder = Encoder(INPUT_DIM, EMB_DIM, HIDDEN_DIM)
vanilla_decoder = Decoder(OUTPUT_DIM, EMB_DIM, HIDDEN_DIM)

vanilla_model = Seq2Seq(vanilla_encoder, vanilla_decoder).to(device)

vanilla_optimizer = optim.Adam(vanilla_model.parameters())
criterion = nn.CrossEntropyLoss(ignore_index=0)

In [9]:
class BahdanauAttention(nn.Module):
    def __init__(self, hidden_dim):
        super().__init__()
        self.W = nn.Linear(hidden_dim*2, hidden_dim)
        self.v = nn.Linear(hidden_dim, 1, bias=False)

    def forward(self, hidden, encoder_outputs):
        src_len = encoder_outputs.shape[0]

        hidden = hidden[-1].unsqueeze(1).repeat(1, src_len, 1)
        encoder_outputs = encoder_outputs.permute(1,0,2)

        energy = torch.tanh(self.W(torch.cat((hidden, encoder_outputs), dim=2)))
        attention = self.v(energy).squeeze(2)

        return torch.softmax(attention, dim=1)


class BahdanauDecoder(nn.Module):
    def __init__(self, output_dim, emb_dim, hidden_dim, attention):
        super().__init__()
        self.embedding = nn.Embedding(output_dim, emb_dim)
        self.lstm = nn.LSTM(emb_dim + hidden_dim, hidden_dim)
        self.fc = nn.Linear(hidden_dim*2, output_dim)
        self.attention = attention

    def forward(self, input, hidden, cell, encoder_outputs):

        input = input.unsqueeze(0)
        embedded = self.embedding(input)

        attn_weights = self.attention(hidden, encoder_outputs)
        attn_weights = attn_weights.unsqueeze(1)

        encoder_outputs = encoder_outputs.permute(1,0,2)
        context = torch.bmm(attn_weights, encoder_outputs)
        context = context.permute(1,0,2)

        rnn_input = torch.cat((embedded, context), dim=2)
        output, (hidden, cell) = self.lstm(rnn_input, (hidden, cell))

        prediction = self.fc(
            torch.cat((output.squeeze(0), context.squeeze(0)), dim=1)
        )

        return prediction, hidden, cell, attn_weights

In [10]:
class LuongAttention(nn.Module):
    def __init__(self, hidden_dim):
        super().__init__()
        self.W = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, hidden, encoder_outputs):
        hidden = hidden[-1].unsqueeze(2)
        encoder_outputs = encoder_outputs.permute(1,0,2)

        energy = torch.bmm(self.W(encoder_outputs), hidden).squeeze(2)
        return torch.softmax(energy, dim=1)


class LuongDecoder(nn.Module):
    def __init__(self, output_dim, emb_dim, hidden_dim, attention):
        super().__init__()
        self.embedding = nn.Embedding(output_dim, emb_dim)
        self.lstm = nn.LSTM(emb_dim, hidden_dim)
        self.fc = nn.Linear(hidden_dim*2, output_dim)
        self.attention = attention

    def forward(self, input, hidden, cell, encoder_outputs):

        input = input.unsqueeze(0)
        embedded = self.embedding(input)

        output, (hidden, cell) = self.lstm(embedded, (hidden, cell))

        attn_weights = self.attention(hidden, encoder_outputs)
        attn_weights = attn_weights.unsqueeze(1)

        encoder_outputs = encoder_outputs.permute(1,0,2)
        context = torch.bmm(attn_weights, encoder_outputs)

        prediction = self.fc(
            torch.cat((output.squeeze(0), context.squeeze(1)), dim=1)
        )

        return prediction, hidden, cell, attn_weights

In [11]:
def train_model(model, optimizer, model_name="Model"):
    model.train()
    epoch_loss = 0

    progress_bar = tqdm(train_loader, desc=f"{model_name}", leave=False)

    for src, trg in progress_bar:

        optimizer.zero_grad()

        # Vanilla
        if isinstance(model.decoder, Decoder):
            output = model(src, trg)

        # Attention Models
        else:
            trg_len = trg.shape[0]
            batch_size = trg.shape[1]
            vocab_size = OUTPUT_DIM
            outputs = torch.zeros(trg_len, batch_size, vocab_size).to(device)

            encoder_outputs, hidden, cell = model.encoder(src)
            input = trg[0]

            for t in range(1, trg_len):
                output_step, hidden, cell, _ = model.decoder(
                    input, hidden, cell, encoder_outputs
                )
                outputs[t] = output_step

                teacher_force = random.random() < 0.5
                top1 = output_step.argmax(1)
                input = trg[t] if teacher_force else top1

            output = outputs

        output_dim = output.shape[-1]
        output = output[1:].reshape(-1, output_dim)
        trg = trg[1:].reshape(-1)

        loss = criterion(output, trg)
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

        progress_bar.set_postfix(loss=loss.item())

    return epoch_loss / len(train_loader)

In [12]:
# Bahdanau
bahdanau_encoder = Encoder(INPUT_DIM, EMB_DIM, HIDDEN_DIM)
bahdanau_attention = BahdanauAttention(HIDDEN_DIM)
bahdanau_decoder = BahdanauDecoder(
    OUTPUT_DIM, EMB_DIM, HIDDEN_DIM, bahdanau_attention
)
bahdanau_model = Seq2Seq(bahdanau_encoder, bahdanau_decoder).to(device)
bahdanau_optimizer = optim.Adam(bahdanau_model.parameters())

# Luong
luong_encoder = Encoder(INPUT_DIM, EMB_DIM, HIDDEN_DIM)
luong_attention = LuongAttention(HIDDEN_DIM)
luong_decoder = LuongDecoder(
    OUTPUT_DIM, EMB_DIM, HIDDEN_DIM, luong_attention
)
luong_model = Seq2Seq(luong_encoder, luong_decoder).to(device)
luong_optimizer = optim.Adam(luong_model.parameters())

In [16]:
EPOCHS = 50

for epoch in range(EPOCHS):

    print(f"\nEpoch {epoch+1}:")
    vanilla_loss = train_model(vanilla_model, vanilla_optimizer, "Vanilla" )
    bahdanau_loss = train_model(bahdanau_model, bahdanau_optimizer, "Bahdanau")
    luong_loss = train_model(luong_model, luong_optimizer, "Luong")

    print(f"Vanilla Loss:   {vanilla_loss:.4f}")
    print(f"Bahdanau Loss:  {bahdanau_loss:.4f}")
    print(f"Luong Loss:     {luong_loss:.4f}")


Epoch 1:


                                                                       

Vanilla Loss:   0.4997
Bahdanau Loss:  0.5395
Luong Loss:     0.5402

Epoch 2:


                                                                       

Vanilla Loss:   0.4851
Bahdanau Loss:  0.5272
Luong Loss:     0.5332

Epoch 3:


                                                                       

Vanilla Loss:   0.4673
Bahdanau Loss:  0.5036
Luong Loss:     0.5237

Epoch 4:


                                                                       

Vanilla Loss:   0.4743
Bahdanau Loss:  0.4935
Luong Loss:     0.5072

Epoch 5:


                                                                       

Vanilla Loss:   0.4650
Bahdanau Loss:  0.4809
Luong Loss:     0.4999

Epoch 6:


                                                                       

Vanilla Loss:   0.4507
Bahdanau Loss:  0.4614
Luong Loss:     0.4998

Epoch 7:


                                                                       

Vanilla Loss:   0.4500
Bahdanau Loss:  0.4546
Luong Loss:     0.4815

Epoch 8:


                                                                       

Vanilla Loss:   0.4530
Bahdanau Loss:  0.4445
Luong Loss:     0.4728

Epoch 9:


                                                                       

Vanilla Loss:   0.4363
Bahdanau Loss:  0.4308
Luong Loss:     0.4744

Epoch 10:


                                                                       

Vanilla Loss:   0.4272
Bahdanau Loss:  0.4210
Luong Loss:     0.4466

Epoch 11:


                                                                       

Vanilla Loss:   0.4310
Bahdanau Loss:  0.4260
Luong Loss:     0.4466

Epoch 12:


                                                                       

Vanilla Loss:   0.4157
Bahdanau Loss:  0.4184
Luong Loss:     0.4373

Epoch 13:


                                                                       

Vanilla Loss:   0.4133
Bahdanau Loss:  0.4123
Luong Loss:     0.4271

Epoch 14:


                                                                       

Vanilla Loss:   0.4165
Bahdanau Loss:  0.4037
Luong Loss:     0.4202

Epoch 15:


                                                                       

Vanilla Loss:   0.4007
Bahdanau Loss:  0.4069
Luong Loss:     0.4272

Epoch 16:


                                                                       

Vanilla Loss:   0.4115
Bahdanau Loss:  0.3857
Luong Loss:     0.4155

Epoch 17:


                                                                       

Vanilla Loss:   0.4113
Bahdanau Loss:  0.3814
Luong Loss:     0.4090

Epoch 18:


                                                                       

Vanilla Loss:   0.3924
Bahdanau Loss:  0.3700
Luong Loss:     0.3994

Epoch 19:


                                                                        

Vanilla Loss:   0.3893
Bahdanau Loss:  0.3737
Luong Loss:     0.3953

Epoch 20:


                                                                       

Vanilla Loss:   0.3837
Bahdanau Loss:  0.3647
Luong Loss:     0.3926

Epoch 21:


                                                                       

Vanilla Loss:   0.3735
Bahdanau Loss:  0.3641
Luong Loss:     0.3892

Epoch 22:


                                                                       

Vanilla Loss:   0.3805
Bahdanau Loss:  0.3636
Luong Loss:     0.3745

Epoch 23:


                                                                       

Vanilla Loss:   0.3763
Bahdanau Loss:  0.3575
Luong Loss:     0.3734

Epoch 24:


                                                                       

Vanilla Loss:   0.3757
Bahdanau Loss:  0.3551
Luong Loss:     0.3679

Epoch 25:


                                                                       

Vanilla Loss:   0.3708
Bahdanau Loss:  0.3486
Luong Loss:     0.3639

Epoch 26:


                                                                        

Vanilla Loss:   0.3739
Bahdanau Loss:  0.3396
Luong Loss:     0.3594

Epoch 27:


                                                                        

Vanilla Loss:   0.3673
Bahdanau Loss:  0.3416
Luong Loss:     0.3557

Epoch 28:


                                                                       

Vanilla Loss:   0.3641
Bahdanau Loss:  0.3349
Luong Loss:     0.3510

Epoch 29:


                                                                       

Vanilla Loss:   0.3588
Bahdanau Loss:  0.3308
Luong Loss:     0.3514

Epoch 30:


                                                                       

Vanilla Loss:   0.3550
Bahdanau Loss:  0.3283
Luong Loss:     0.3464

Epoch 31:


                                                                       

Vanilla Loss:   0.3490
Bahdanau Loss:  0.3293
Luong Loss:     0.3482

Epoch 32:


                                                                       

Vanilla Loss:   0.3497
Bahdanau Loss:  0.3322
Luong Loss:     0.3463

Epoch 33:


                                                                       

Vanilla Loss:   0.3401
Bahdanau Loss:  0.3290
Luong Loss:     0.3426

Epoch 34:


                                                                       

Vanilla Loss:   0.3378
Bahdanau Loss:  0.3232
Luong Loss:     0.3439

Epoch 35:


                                                                       

Vanilla Loss:   0.3415
Bahdanau Loss:  0.3200
Luong Loss:     0.3405

Epoch 36:


                                                                       

Vanilla Loss:   0.3352
Bahdanau Loss:  0.3146
Luong Loss:     0.3308

Epoch 37:


                                                                       

Vanilla Loss:   0.3347
Bahdanau Loss:  0.3095
Luong Loss:     0.3213

Epoch 38:


                                                                       

Vanilla Loss:   0.3416
Bahdanau Loss:  0.3126
Luong Loss:     0.3211

Epoch 39:


                                                                       

Vanilla Loss:   0.3344
Bahdanau Loss:  0.3070
Luong Loss:     0.3209

Epoch 40:


                                                                       

Vanilla Loss:   0.3264
Bahdanau Loss:  0.3022
Luong Loss:     0.3158

Epoch 41:


                                                                       

Vanilla Loss:   0.3153
Bahdanau Loss:  0.3037
Luong Loss:     0.3152

Epoch 42:


                                                                       

Vanilla Loss:   0.3233
Bahdanau Loss:  0.2973
Luong Loss:     0.3152

Epoch 43:


                                                                       

Vanilla Loss:   0.3165
Bahdanau Loss:  0.3049
Luong Loss:     0.3092

Epoch 44:


                                                                        

Vanilla Loss:   0.3237
Bahdanau Loss:  0.3001
Luong Loss:     0.3111

Epoch 45:


                                                                       

Vanilla Loss:   0.3150
Bahdanau Loss:  0.2900
Luong Loss:     0.3173

Epoch 46:


                                                                       

Vanilla Loss:   0.3127
Bahdanau Loss:  0.2993
Luong Loss:     0.3045

Epoch 47:


                                                                       

Vanilla Loss:   0.3125
Bahdanau Loss:  0.2925
Luong Loss:     0.3031

Epoch 48:


                                                                       

Vanilla Loss:   0.3185
Bahdanau Loss:  0.2888
Luong Loss:     0.3072

Epoch 49:


                                                                       

Vanilla Loss:   0.3117
Bahdanau Loss:  0.2965
Luong Loss:     0.3023

Epoch 50:


                                                                       

Vanilla Loss:   0.3057
Bahdanau Loss:  0.2868
Luong Loss:     0.3106




In [17]:
def evaluate_bleu(model):
    model.eval()
    scores = []

    with torch.no_grad():
        for src_sentence, trg_sentence in test_pairs[:1000]:

            src_tensor = torch.tensor(
                [src_vocab.stoi["<sos>"]] +
                src_vocab.numericalize(src_sentence) +
                [src_vocab.stoi["<eos>"]]
            ).unsqueeze(1).to(device)

            encoder_outputs, hidden, cell = model.encoder(src_tensor)
            input_token = torch.tensor([trg_vocab.stoi["<sos>"]]).to(device)

            generated = []

            for _ in range(20):

                if isinstance(model.decoder, Decoder):
                    output, hidden, cell = model.decoder(
                        input_token, hidden, cell
                    )
                else:
                    output, hidden, cell, _ = model.decoder(
                        input_token, hidden, cell, encoder_outputs
                    )

                top1 = output.argmax(1)

                if top1.item() == trg_vocab.stoi["<eos>"]:
                    break

                generated.append(trg_vocab.itos[top1.item()])
                input_token = top1

            reference = [trg_sentence.split()]
            scores.append(sentence_bleu(reference, generated))

    return np.mean(scores)


print("\nBLEU Scores:")
print("Vanilla:", evaluate_bleu(vanilla_model))
print("Bahdanau:", evaluate_bleu(bahdanau_model))
print("Luong:", evaluate_bleu(luong_model))


BLEU Scores:
Vanilla: 0.034573508129145496
Bahdanau: 0.03648960622222519
Luong: 0.035924854104106264
