In [27]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchtext.legacy.data import Field, BucketIterator
from torchtext.legacy.datasets import Multi30k
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from nltk.translate.bleu_score import corpus_bleu
import numpy as np
import random
import spacy
import torch.nn.functional as F
import math
import time

SEED = 1234
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

spacy_de = spacy.load('de_core_news_sm')
spacy_en = spacy.load('en_core_web_sm')

def tokenize_de(text):
    return [tok.text for tok in spacy_de.tokenizer(text)][::-1]

def tokenize_en(text):
    return [tok.text for tok in spacy_en.tokenizer(text)]

SRC = Field(tokenize=tokenize_de, init_token='<sos>', eos_token='<eos>', lower=True)
TRG = Field(tokenize=tokenize_en, init_token='<sos>', eos_token='<eos>', lower=True)
data_folder = "/home/yhz2023/code_file/data"
train_data, valid_data, test_data = Multi30k.splits(exts=('.de', '.en'), fields=(SRC, TRG), root=data_folder)
SRC.build_vocab(train_data, min_freq=2)
TRG.build_vocab(train_data, min_freq=2)

BATCH_SIZE = 128
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

train_iterator, valid_iterator, test_iterator = BucketIterator.splits(
    (train_data, valid_data, test_data), 
    batch_size=BATCH_SIZE, 
    device=device)

class Encoder(nn.Module):
    def __init__(self, input_dim, emb_dim, enc_hid_dim, dec_hid_dim, dropout, rnn_type='GRU'):
        super().__init__()
        self.embedding = nn.Embedding(input_dim, emb_dim)
        self.rnn_type = rnn_type
        if rnn_type == 'GRU':
            self.rnn = nn.GRU(emb_dim, enc_hid_dim, bidirectional=True)
        elif rnn_type == 'LSTM':
            self.rnn = nn.LSTM(emb_dim, enc_hid_dim, bidirectional=True)
        self.fc = nn.Linear(enc_hid_dim * 2, dec_hid_dim)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, src):
        embedded = self.dropout(self.embedding(src))
        outputs, hidden = self.rnn(embedded)
        if self.rnn_type == 'LSTM':
            hidden = hidden[0]  # LSTM returns hidden state and cell state, we take hidden state
        s = torch.tanh(self.fc(torch.cat((hidden[-2,:,:], hidden[-1,:,:]), dim=1)))
        return outputs, s

class Attention(nn.Module):
    def __init__(self, enc_hid_dim, dec_hid_dim):
        super().__init__()
        self.attn = nn.Linear((enc_hid_dim * 2) + dec_hid_dim, dec_hid_dim, bias=False)
        self.v = nn.Linear(dec_hid_dim, 1, bias = False)
        
    def forward(self, s, enc_output):
        batch_size = enc_output.shape[1]
        src_len = enc_output.shape[0]
        s = s.unsqueeze(1).repeat(1, src_len, 1)
        enc_output = enc_output.transpose(0, 1)
        energy = torch.tanh(self.attn(torch.cat((s, enc_output), dim = 2)))
        attention = self.v(energy).squeeze(2)
        
        return F.softmax(attention, dim=1)

class BahdanauAttention(nn.Module):
    def __init__(self, enc_hid_dim, dec_hid_dim):
        super().__init__()
        self.attn = nn.Linear((enc_hid_dim * 2) + dec_hid_dim, dec_hid_dim)
        self.v = nn.Linear(dec_hid_dim, 1, bias=False)
        
    def forward(self, hidden, encoder_outputs):
        src_len = encoder_outputs.shape[0]
        hidden = hidden.unsqueeze(1).repeat(1, src_len, 1)
        encoder_outputs = encoder_outputs.permute(1, 0, 2)
        energy = torch.tanh(self.attn(torch.cat((hidden, encoder_outputs), dim=2)))
        attention = self.v(energy).squeeze(2)
        return F.softmax(attention, dim=1)

class LuongAttention(nn.Module):
    def __init__(self, enc_hid_dim, dec_hid_dim):
        super().__init__()
        self.attn = nn.Linear(dec_hid_dim, enc_hid_dim * 2) 

    def forward(self, hidden, encoder_outputs):
        # hidden: [batch size, dec_hid_dim]
        hidden = self.attn(hidden) 
        hidden = hidden.unsqueeze(2)  # [batch size, enc_hid_dim * 2, 1]

        # encoder_outputs: [src_len, batch size, enc_hid_dim * 2]
        encoder_outputs = encoder_outputs.permute(1, 0, 2)  # [batch size, enc_hid_dim * 2, src_len]
        energy = torch.bmm(encoder_outputs, hidden).squeeze(2)  # [batch size, src_len]

        return F.softmax(energy, dim=1)

class Decoder(nn.Module):
    def __init__(self, output_dim, emb_dim, enc_hid_dim, dec_hid_dim, dropout, attention):
        super().__init__()
        self.output_dim = output_dim
        self.attention = attention
        self.embedding = nn.Embedding(output_dim, emb_dim)
        self.rnn = nn.GRU((enc_hid_dim * 2) + emb_dim, dec_hid_dim)
        self.fc_out = nn.Linear((enc_hid_dim * 2) + dec_hid_dim + emb_dim, output_dim)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, dec_input, s, enc_output):
        dec_input = dec_input.unsqueeze(1)
        embedded = self.dropout(self.embedding(dec_input)).transpose(0, 1)
        a = self.attention(s, enc_output).unsqueeze(1)
        enc_output = enc_output.transpose(0, 1)
        c = torch.bmm(a, enc_output).transpose(0, 1)
        rnn_input = torch.cat((embedded, c), dim = 2)
        dec_output, dec_hidden = self.rnn(rnn_input, s.unsqueeze(0))
        embedded = embedded.squeeze(0)
        dec_output = dec_output.squeeze(0)
        c = c.squeeze(0)
        pred = self.fc_out(torch.cat((dec_output, c, embedded), dim = 1))
        
        return pred, dec_hidden.squeeze(0)

class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder, device):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device
        
    def forward(self, src, trg, teacher_forcing_ratio = 0.5):
        batch_size = src.shape[1]
        trg_len = trg.shape[0]
        trg_vocab_size = self.decoder.output_dim
       
        outputs = torch.zeros(trg_len, batch_size, trg_vocab_size).to(self.device)
        enc_output, s = self.encoder(src)
        dec_input = trg[0,:]
        
        for t in range(1, trg_len):
            dec_output, s = self.decoder(dec_input, s, enc_output)
            outputs[t] = dec_output
            teacher_force = random.random() < teacher_forcing_ratio
            top1 = dec_output.argmax(1) 
            dec_input = trg[t] if teacher_force else top1

        return outputs

INPUT_DIM = len(SRC.vocab)
OUTPUT_DIM = len(TRG.vocab)
ENC_EMB_DIM = 256
DEC_EMB_DIM = 256
ENC_HID_DIM = 512
DEC_HID_DIM = 512
ENC_DROPOUT = 0.5
DEC_DROPOUT = 0.5

def initialize_model(rnn_type, attn_type, input_dim, output_dim, emb_dim, enc_hid_dim, dec_hid_dim, dropout):
    encoder = Encoder(input_dim, emb_dim, enc_hid_dim, dec_hid_dim, dropout, rnn_type=rnn_type)
    
    if attn_type == 'bahdanau':
        attention = BahdanauAttention(enc_hid_dim, dec_hid_dim)
    elif attn_type == 'luong':
        attention = LuongAttention(enc_hid_dim, dec_hid_dim)
    else:
        raise ValueError("Unsupported attention type")

    decoder = Decoder(output_dim, emb_dim, enc_hid_dim, dec_hid_dim, dropout, attention)
    model = Seq2Seq(encoder, decoder, device).to(device)
    return model

def train_and_evaluate(model, train_iterator, valid_iterator, optimizer, criterion, n_epochs):
    for epoch in range(n_epochs):
        model.train()
        train_loss = 0
        for i, batch in enumerate(train_iterator):
            src, trg = batch.src, batch.trg
            optimizer.zero_grad()
            output = model(src, trg)
            output_dim = output.shape[-1]
            output = output[1:].view(-1, output_dim)
            trg = trg[1:].view(-1)
            loss = criterion(output, trg)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        valid_loss = evaluate(model, valid_iterator, criterion)
        print(f'Epoch: {epoch+1}, Train Loss: {train_loss / len(train_iterator)}, Valid Loss: {valid_loss}')

    return model

def evaluate(model, iterator, criterion):
    model.eval()
    epoch_loss = 0
    with torch.no_grad():
        for i, batch in enumerate(iterator):
            src, trg = batch.src, batch.trg
            output = model(src, trg, 0)  # Turn off teacher forcing
            output_dim = output.shape[-1]
            output = output[1:].view(-1, output_dim)
            trg = trg[1:].view(-1)
            loss = criterion(output, trg)
            epoch_loss += loss.item()
    return epoch_loss / len(iterator)

def translate(model, iterator, trg_field):
    model.eval()
    original_texts = []
    generated_texts = []

    with torch.no_grad():
        for i, batch in enumerate(iterator):
            src = batch.src.to(device)
            trg = batch.trg.to(device)

            output = model(src, trg, 0)  # Turn off teacher forcing
            output = output.argmax(2).transpose(0, 1).tolist()

            # Convert indexes to strings
            for j in range(trg.shape[1]):
                original_text = [trg_field.vocab.itos[token] for token in trg[:, j] if token != trg_field.vocab.stoi['<pad>']]
                generated_text = [trg_field.vocab.itos[token] for token in output[j] if token != trg_field.vocab.stoi['<pad>']]

                original_texts.append([original_text])
                generated_texts.append(generated_text)

    return original_texts, generated_texts

def calculate_bleu(original_texts, generated_texts):
    return corpus_bleu(original_texts, generated_texts)

configurations = [
    ("GRU", "bahdanau"),
    ("GRU", "luong"),
    ("LSTM", "bahdanau"),
    ("LSTM", "luong")
]

results = []

for rnn_type, attn_type in configurations:
    model = initialize_model(rnn_type, attn_type, INPUT_DIM, OUTPUT_DIM, ENC_EMB_DIM, ENC_HID_DIM, DEC_HID_DIM, ENC_DROPOUT)
    optimizer = optim.Adam(model.parameters())
    criterion = nn.CrossEntropyLoss(ignore_index=TRG_PAD_IDX)
    
    print(f'Training with {rnn_type} RNN and {attn_type} attention...')
    train_and_evaluate(model, train_iterator, valid_iterator, optimizer, criterion, n_epochs=10)
    
    test_loss = evaluate(model, test_iterator, criterion)
    print(f'Test Loss for {rnn_type} + {attn_type}: {test_loss}')
    results.append((rnn_type, attn_type, test_loss))

    original_texts, generated_texts = translate(model, test_iterator, TRG)
    bleu_score = calculate_bleu(original_texts, generated_texts)
    print(f'BLEU Score for {rnn_type} + {attn_type}: {bleu_score:.2f}')

for result in results:
    print(f"RNN Type: {result[0]}, Attention Type: {result[1]}, Test Loss: {result[2]}")


Training with GRU RNN and bahdanau attention...
Epoch: 1, Train Loss: 4.340084726064741, Valid Loss: 3.65247443318367
Epoch: 2, Train Loss: 3.1268660707095646, Valid Loss: 3.295613706111908
Epoch: 3, Train Loss: 2.675323183841117, Valid Loss: 3.2827012836933136
Epoch: 4, Train Loss: 2.374042877541765, Valid Loss: 3.2510056495666504
Epoch: 5, Train Loss: 2.153458273358282, Valid Loss: 3.2489570677280426
Epoch: 6, Train Loss: 2.0075077480156516, Valid Loss: 3.2760799527168274
Epoch: 7, Train Loss: 1.9041924266563113, Valid Loss: 3.23938250541687
Epoch: 8, Train Loss: 1.7949174437753954, Valid Loss: 3.299146056175232
Epoch: 9, Train Loss: 1.7038813587852513, Valid Loss: 3.3199033737182617
Epoch: 10, Train Loss: 1.6570884798066732, Valid Loss: 3.4353507459163666
Test Loss for GRU + bahdanau: 3.431431472301483
BLEU Score for GRU + bahdanau: 0.24
Training with GRU RNN and luong attention...
Epoch: 1, Train Loss: 4.621248229484726, Valid Loss: 4.406205058097839
Epoch: 2, Train Loss: 3.7314060