In [13]:
from __future__ import unicode_literals, print_function, division
from io import open
import unicodedata
import re
import random

import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F

import numpy as np
from torch.utils.data import TensorDataset, DataLoader, RandomSampler

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [14]:
import json
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, random_split
import nltk
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments, AutoTokenizer, AutoModel, PreTrainedTokenizerFast
from tokenizers import Tokenizer
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping
import tensorflow as tf
import numpy as np
import random
nltk.download('punkt')

[nltk_data] Downloading package punkt to /usr/share/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

In [15]:
# Special token ids shared by the cells below: SOS_token seeds the decoder's
# first input and EOS_token ends decoding in `evaluate`.
# NOTE(review): these look like PhoBERT's <s>/</s> ids (the encoded Vietnamese
# example further down starts with 0 and ends with 2) — confirm.
SOS_token = 0
EOS_token = 2

class Lang:
    """Word-level vocabulary for a single language.

    Keeps three lookups in sync: word -> index, index -> word and
    word -> occurrence count. Indices 0 and 1 are pre-assigned to the
    "SOS" and "EOS" markers, so the first real word receives index 2.
    """

    def __init__(self, name):
        self.name = name
        self.word2index = {}
        self.word2count = {}
        self.index2word = {0: "SOS", 1: "EOS"}
        # The two markers above are already counted.
        self.n_words = 2

    def addSentence(self, sentence):
        """Register every token of ``sentence``, splitting on single spaces."""
        for token in sentence.split(' '):
            self.addWord(token)

    def addWord(self, word):
        """Insert ``word`` with the next free index, or bump its count if known."""
        if word in self.word2index:
            self.word2count[word] += 1
            return

        new_index = self.n_words
        self.word2index[word] = new_index
        self.word2count[word] = 1
        self.index2word[new_index] = word
        self.n_words = new_index + 1

In [16]:
def unicodeToAscii(s):
    """Return ``s`` with combining marks removed.

    The string is NFD-normalised so accented letters split into a base
    character plus combining marks, and every character whose Unicode
    category is 'Mn' is then dropped. Characters that have no decomposition
    are left unchanged.
    """
    decomposed = unicodedata.normalize('NFD', s)
    kept = [ch for ch in decomposed if unicodedata.category(ch) != 'Mn']
    return ''.join(kept)

def normalizeString(s):
    """Lower-case, strip accents and reduce ``s`` to letters plus ``!`` / ``?``.

    Steps, in order:
      1. lower-case and trim, then remove combining marks via ``unicodeToAscii``;
      2. put a space in front of each ``.``, ``!`` or ``?``;
      3. collapse every run of characters outside ``a-zA-Z!?`` into one space
         (this also removes the ``.`` separated in step 2);
      4. trim the result.
    """
    lowered = s.lower().strip()
    text = unicodeToAscii(lowered)
    text = re.sub(r"([.!?])", r" \1", text)
    text = re.sub(r"[^a-zA-Z!?]+", r" ", text)
    return text.strip()

In [17]:
def readLangs(lang1, lang2, reverse=False):
    """Load a tab-separated parallel corpus and create empty vocabularies.

    Reads ``data/<lang1>-<lang2>.txt`` (UTF-8), splits every line on tabs and
    normalises each field with ``normalizeString``.

    Args:
        lang1: Name of the language in the first column of the file.
        lang2: Name of the language in the second column of the file.
        reverse: When true, each pair is reversed and the input/output
            languages are swapped.

    Returns:
        ``(input_lang, output_lang, pairs)`` where the two ``Lang`` objects are
        still empty and ``pairs`` is a list of lists of normalised strings.
    """
    print("Reading lines...")

    # Use a context manager so the file handle is closed once read; the
    # previous one-liner left it open until garbage collection.
    with open('data/%s-%s.txt' % (lang1, lang2), encoding='utf-8') as corpus_file:
        lines = corpus_file.read().strip().split('\n')

    pairs = [[normalizeString(s) for s in l.split('\t')] for l in lines]

    if reverse:
        pairs = [list(reversed(p)) for p in pairs]
        input_lang = Lang(lang2)
        output_lang = Lang(lang1)
    else:
        input_lang = Lang(lang1)
        output_lang = Lang(lang2)

    return input_lang, output_lang, pairs

In [18]:
# Locations of the pre-tokenized corpora (Kaggle input paths). A later cell
# parses these files as one line of whitespace-separated integer ids per sentence.
# NOTE(review): presumably produced with the two tokenizers below — not verifiable here.
vi_sentences_path = "/kaggle/input/berttokenize/Bert/tokenize_vi.txt"
en_sentences_path = "/kaggle/input/berttokenize/Bert/tokenize_en.txt"
# Pretrained tokenizers: uncased BERT for English, PhoBERT for Vietnamese.
tokenizer_en = AutoTokenizer.from_pretrained("bert-base-uncased")
tokenizer_vi = AutoTokenizer.from_pretrained("vinai/phobert-base")

In [19]:
# Read the `vocab_size` attribute reported by each tokenizer and print it.
vi_vocab_size = tokenizer_vi.vocab_size
en_vocab_size = tokenizer_en.vocab_size

print(f"Vietnamese Vocabulary Size: {vi_vocab_size}")
print(f"English Vocabulary Size: {en_vocab_size}")

Vietnamese Vocabulary Size: 64000
English Vocabulary Size: 30522


In [20]:
# Round-trip check for the English tokenizer: encode a sample string to ids,
# print them, then decode back with special tokens skipped.
input_text = "Run!	Corre!"

input_ids = tokenizer_en.encode(input_text, return_tensors="pt")

print(input_ids)

# input_ids has a leading batch dimension, so decode its first row.
decoded_text = tokenizer_en.decode(input_ids[0], skip_special_tokens=True)

print("Decoded Text:", decoded_text)

tensor([[  101,  2448,   999,  2522, 14343,   999,   102]])
Decoded Text: run! corre!


In [21]:
# Same round-trip check for the Vietnamese tokenizer. This cell reuses the
# names input_text / input_ids / decoded_text from the English check.
input_text = "Với bài toán dịch Anh - Việt, việc kiểm tra cách mà tokenizer mã hóa câu tiếng Anh và tái mã hóa lại câu tiếng Việt là rất quan trọng. Dưới đây là hướng dẫn cụ thể"

input_ids = tokenizer_vi.encode(input_text, return_tensors="pt")

print(input_ids)

# input_ids has a leading batch dimension, so decode its first row.
decoded_text = tokenizer_vi.decode(input_ids[0], skip_special_tokens=True)

print("Decoded Text:", decoded_text)

tensor([[    0,   321,   387,  4698,  1626,   157,    31, 28171,  1187,     4,
            49,  7303,  5761,   139,    64, 52479, 10507, 15275,  1624,  1340,
         11095,  1517,   528,   355,   157,     6,  1278,  1624,  1340, 11095,
          1517,    44,   528,   355,   350,     8,    59,  2665, 37096, 10838,
          2543,    97,     8,   455,   376,  1591,  4623,     2]])
Decoded Text: Với bài toán dịch Anh - Việt, việc kiểm tra cách mà tokenizer mã hóa câu tiếng Anh và tái mã hóa lại câu tiếng Việt là rất quan trọng. Dưới đây là hướng dẫn cụ thể


In [22]:
def count_sentences(file_path):
    """Count the lines of a UTF-8 text file.

    The file is iterated lazily rather than loaded with ``readlines()``, so a
    multi-million-line corpus does not need to fit in memory to be counted.

    Args:
        file_path: Path of the file to inspect.

    Returns:
        Number of lines in the file (0 for an empty file).
    """
    with open(file_path, 'r', encoding='utf-8') as file:
        return sum(1 for _ in file)


# Count the lines of both tokenized files and print them side by side so a
# mismatch between the two corpora is visible.
num_sentences_vi = count_sentences(vi_sentences_path)
num_sentences_en = count_sentences(en_sentences_path)

print(f"Number of sentences in tokenized vietnamese: {num_sentences_vi}")
print(f"Number of sentences in tokenized english: {num_sentences_en}")

Number of sentences in tokenized vietnamese: 2977999
Number of sentences in tokenized english: 2977999


In [23]:
# Fetch each tokenizer's token -> id mapping.
vi_vocab = tokenizer_vi.get_vocab()  
en_vocab = tokenizer_en.get_vocab()  

# "First 20" here means the first 20 entries in dict iteration order; the ids
# are discarded (`_`) and the printed number is just the enumeration position.
# NOTE(review): iteration order is not necessarily id order — confirm before
# reading the listing as "ids 0..19".
print("First 20 tokens in the English vocabulary:")
for i, (token, _) in enumerate(list(en_vocab.items())[:20]):
    print(f"{i+1}. {token}")

print("\nFirst 20 tokens in the Vietnamese vocabulary:")
for i, (token, _) in enumerate(list(vi_vocab.items())[:20]):
    print(f"{i+1}. {token}")


First 20 tokens in the English vocabulary:
1. ##rce
2. insults
3. hugh
4. √
5. ##ened
6. three
7. ட
8. rub
9. ancestral
10. ##qu
11. langley
12. selfish
13. ##pton
14. pcs
15. ##alia
16. differs
17. morris
18. ##wee
19. cuts
20. categorized

First 20 tokens in the Vietnamese vocabulary:
1. <s>
2. <pad>
3. </s>
4. <unk>
5. ,
6. .
7. và
8. của
9. là
10. các
11. có
12. được
13. trong
14. cho
15. đã
16. với
17. một
18. không
19. người
20. )


In [24]:
def read_tokenized_sentences(file_path):
    """Parse a file of whitespace-separated integer token ids.

    Each line becomes one list of ints; a blank line becomes an empty list.
    Lines are parsed while streaming the file, so the raw text is never held
    in memory next to the parsed result.

    Args:
        file_path: Path of a UTF-8 text file with one sentence per line.

    Returns:
        A list with one ``list[int]`` per line.

    Raises:
        ValueError: If a token cannot be parsed as an integer.
    """
    with open(file_path, "r", encoding="utf-8") as file:
        # str.split() with no argument already ignores leading/trailing
        # whitespace, including the newline.
        return [[int(token) for token in line.split()] for line in file]

# Load both corpora fully into memory as lists of token-id lists.
tokenized_en = read_tokenized_sentences(en_sentences_path)
tokenized_vi = read_tokenized_sentences(vi_sentences_path)

In [25]:
def sample_data(english_sentences, vietnamese_sentences, sample_ratio=0.015, seed=None):
    """Draw an aligned random subset of a parallel corpus.

    The same indices are used for both languages, so pair ``i`` of the result
    is still a translation pair.

    Args:
        english_sentences: Sequence of English sentences.
        vietnamese_sentences: Sequence of Vietnamese sentences, same length.
        sample_ratio: Fraction of the corpus to keep; the sample size is
            ``int(sample_ratio * len(english_sentences))``.
        seed: Optional seed for a private ``RandomState``. When ``None``
            (default) NumPy's global generator is used, as before.

    Returns:
        ``(sampled_en, sampled_vi)`` — two lists of equal length.

    Raises:
        ValueError: If the two corpora differ in length, or if NumPy rejects
            the sample size (e.g. ``sample_ratio`` > 1).
    """
    dataset_size = len(english_sentences)
    if len(vietnamese_sentences) != dataset_size:
        # Sampling by shared index only makes sense for a parallel corpus.
        raise ValueError(
            "english_sentences and vietnamese_sentences must have the same length "
            "(got %d and %d)" % (dataset_size, len(vietnamese_sentences))
        )

    sample_size = int(sample_ratio * dataset_size)
    if sample_size == 0:
        return [], []

    chooser = np.random if seed is None else np.random.RandomState(seed)
    indices = chooser.choice(dataset_size, sample_size, replace=False)

    sampled_en = [english_sentences[i] for i in indices]
    sampled_vi = [vietnamese_sentences[i] for i in indices]

    return sampled_en, sampled_vi

# Keep a random 1.5% of the sentence pairs for the rest of the notebook.
sampled_en, sampled_vi = sample_data(tokenized_en, tokenized_vi, sample_ratio=0.015)

In [26]:
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np

class TranslationDataset(Dataset):
    """Map-style dataset over aligned English/Vietnamese token-id sequences."""

    def __init__(self, english_sentences, vietnamese_sentences, device='cpu'):
        """Store the two parallel corpora and the target device.

        Args:
            english_sentences (list): English sentences, each a list of token ids.
            vietnamese_sentences (list): Vietnamese sentences, each a list of token ids.
            device (str): Device on which item tensors are created ('cpu' or 'cuda').
        """
        self.english_sentences = english_sentences
        self.vietnamese_sentences = vietnamese_sentences
        self.device = device

    def __len__(self):
        """Number of pairs, taken from the English side."""
        return len(self.english_sentences)

    def __getitem__(self, idx):
        """Return pair ``idx`` as ``(english, vietnamese)`` int64 tensors on ``self.device``."""
        def as_ids(tokens):
            return torch.tensor(tokens, dtype=torch.long, device=self.device)

        return as_ids(self.english_sentences[idx]), as_ids(self.vietnamese_sentences[idx])

def create_pytorch_dataset(english_sentences, vietnamese_sentences, train_split=0.8, device='cpu'):
    """Shuffle a parallel corpus and split it into train/validation datasets.

    Indices are shuffled with NumPy's global generator; the first
    ``int(train_split * n)`` go to training and the remainder to validation.

    Args:
        english_sentences: English sentences (lists of token ids).
        vietnamese_sentences: Vietnamese sentences aligned with the English ones.
        train_split: Fraction of pairs assigned to the training set.
        device: Device handed to each ``TranslationDataset``.

    Returns:
        ``(train_dataset, val_dataset)``.
    """
    n_pairs = len(english_sentences)
    order = np.arange(n_pairs)
    np.random.shuffle(order)
    cut = int(train_split * n_pairs)

    def build(selected):
        # Materialise the chosen pairs, preserving alignment between languages.
        return TranslationDataset(
            [english_sentences[i] for i in selected],
            [vietnamese_sentences[i] for i in selected],
            device=device,
        )

    return build(order[:cut]), build(order[cut:])

# 80/20 train/validation split of the sampled pairs; item tensors are created
# directly on `device` by the dataset.
train_dataset, val_dataset = create_pytorch_dataset(sampled_en, sampled_vi, train_split=0.8, device = device)

# Batches of 32 with no custom collate function.
# NOTE(review): this presumably relies on every sentence being pre-padded to the
# same length (the printed batches end in runs of 0 / 1) — confirm.
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=32, shuffle=False)

# Peek at one training batch.
for en_batch, vi_batch in train_dataloader:
    print("English batch:", en_batch)
    print("Vietnamese batch:", vi_batch)
    break
# Second peek; because the loader shuffles, this prints a different batch.
for batch_idx, (inputs, targets) in enumerate(train_dataloader):
    print(inputs)
    print(targets)
    break


English batch: tensor([[ 101, 2043, 2057,  ...,    0,    0,    0],
        [ 101, 2054, 1996,  ...,    0,    0,    0],
        [ 101, 1045, 2572,  ...,    0,    0,    0],
        ...,
        [ 101, 2292, 1055,  ...,    0,    0,    0],
        [ 101, 1999, 2456,  ..., 2368, 1010,  102],
        [ 101, 2505, 2842,  ...,    0,    0,    0]], device='cuda:0')
Vietnamese batch: tensor([[    0,   251,   572,  ...,     1,     1,     1],
        [    0,  2510,  7493,  ...,     1,     1,     1],
        [    0,  8051,     8,  ...,     1,     1,     1],
        ...,
        [    0, 12127,    70,  ...,     1,     1,     1],
        [    0,   125,   266,  ...,  4704,  1493,     2],
        [    0,   631,   148,  ...,     1,     1,     1]], device='cuda:0')
tensor([[  101,  2129, 15140,  ...,     0,     0,     0],
        [  101,  2017, 22374,  ...,     0,     0,     0],
        [  101,  1045,  2064,  ...,     0,     0,     0],
        ...,
        [  101,  1996,  3537,  ...,     0,     0,     0],


In [27]:
import torch
import torch.nn.functional as F

def masked_loss(y_true, y_pred, pad_id=0):
    """Mean cross-entropy over the non-padding targets.

    Args:
        y_true: Integer class targets, in the layout ``F.cross_entropy`` expects.
        y_pred: Unnormalised scores (logits) matching ``y_true``.
        pad_id: Target id treated as padding and excluded from the mean.
            Defaults to 0, the value that was previously hard-coded.
            NOTE(review): the Vietnamese batches printed earlier in this
            notebook appear to be padded with 1 — pass ``pad_id=1`` for those.

    Returns:
        Scalar tensor: summed loss of the non-padding positions divided by
        their count, or 0 when every position is padding.
    """
    loss = F.cross_entropy(y_pred, y_true, reduction='none')

    mask = (y_true != pad_id).to(loss.dtype)

    # Clamp the denominator so an all-padding batch gives 0 instead of NaN (0/0).
    return (loss * mask).sum() / mask.sum().clamp(min=1)


In [28]:
import torch

def masked_acc(y_true, y_pred, pad_id=0):
    """Token accuracy over the non-padding targets.

    Args:
        y_true: Integer target ids.
        y_pred: Scores whose last dimension is the class dimension; the
            prediction is the argmax over that dimension.
        pad_id: Target id treated as padding and excluded from the average.
            Defaults to 0, the value that was previously hard-coded.

    Returns:
        Scalar tensor: fraction of non-padding positions predicted correctly,
        or 0 when every position is padding.
    """
    predicted = torch.argmax(y_pred, dim=-1)

    mask = (y_true != pad_id).float()
    correct = (y_true == predicted).float() * mask

    # Clamp the denominator so an all-padding batch gives 0 instead of NaN (0/0).
    return correct.sum() / mask.sum().clamp(min=1)


In [29]:
import torch
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction

def bleu_score(y_true, y_pred):
    """Mean sentence-level BLEU between target ids and argmax predictions.

    Args:
        y_true: Reference token ids, one sequence per sentence (tensor, array
            or nested list).
        y_pred: Per-token scores whose last dimension is the vocabulary
            (tensor or array); the predicted id is the argmax over it.

    Returns:
        Scalar float32 tensor with the average BLEU (smoothing method 4),
        or 0 when the batch is empty.

    Note:
        Ids are compared as-is; padding and other special ids are not stripped.
    """
    # Take the argmax while y_pred is still a tensor. The earlier version
    # converted to NumPy first and then called torch.argmax on the array,
    # which raises TypeError.
    predicted_ids = torch.as_tensor(y_pred).argmax(dim=-1).cpu().tolist()

    if isinstance(y_true, torch.Tensor):
        reference_ids = y_true.cpu().tolist()
    else:
        reference_ids = [list(sequence) for sequence in y_true]

    smoothing_function = SmoothingFunction().method4

    bleu_scores = []
    for reference, candidate in zip(reference_ids, predicted_ids):
        # sentence_bleu takes a list of references for each candidate.
        score = sentence_bleu([reference], candidate, smoothing_function=smoothing_function)
        bleu_scores.append(score)

    if not bleu_scores:
        # The mean of an empty tensor would be NaN.
        return torch.tensor(0.0, dtype=torch.float32)

    return torch.tensor(bleu_scores, dtype=torch.float32).mean()


In [30]:
# Embedding / output size passed to both the encoder and the decoder below.
# 64000 is the PhoBERT vocabulary size printed earlier; it is larger than the
# 30522 printed for the English tokenizer.
VOCAB_SIZE = 64000
# NOTE(review): UNITS is not referenced in the visible cells (the models are
# built with hidden_size = 256) — confirm whether it is still needed.
UNITS = 512
# Number of decoding steps run by AttnDecoderRNN.forward.
MAX_LENGTH = 50

In [31]:
class EncoderRNN(nn.Module):
    """Bidirectional GRU encoder over token ids.

    Tokens are embedded, passed through dropout and a batch-first
    bidirectional GRU; the concatenated forward/backward outputs are projected
    back to ``hidden_size``. All sub-modules are placed on the notebook-level
    ``device``.
    """

    def __init__(self, input_size, hidden_size, dropout_p=0.1):
        """
        Args:
            input_size: Number of rows in the embedding table (vocabulary size).
            hidden_size: Embedding width and GRU hidden width per direction.
            dropout_p: Dropout probability applied to the embeddings.
        """
        super().__init__()
        self.hidden_size = hidden_size

        self.embedding = nn.Embedding(input_size, hidden_size).to(device)
        self.gru = nn.GRU(
            input_size=hidden_size,
            hidden_size=hidden_size,
            batch_first=True,
            bidirectional=True,
        ).to(device)
        self.dropout = nn.Dropout(dropout_p).to(device)
        # Maps the 2 * hidden_size bidirectional output back to hidden_size.
        self.hidden_transform = nn.Linear(hidden_size * 2, hidden_size).to(device)

    def forward(self, input):
        """Encode a batch of token ids.

        Returns:
            ``(output, hidden)``: the projected per-token outputs and the GRU's
            final hidden state (not projected; it still holds both directions).
        """
        token_ids = input.to(device)
        embedded = self.dropout(self.embedding(token_ids))
        gru_output, hidden = self.gru(embedded)
        return self.hidden_transform(gru_output), hidden

In [32]:
class CrossAttention(nn.Module):
    """Multi-head scaled dot-product attention of a query over encoder states.

    Keys and values are both projected from the same ``keys`` input. No mask
    is applied, so every key position takes part in the softmax.
    """

    def __init__(self, hidden_size, num_heads):
        """
        Args:
            hidden_size: Width of the query/key inputs and of the output.
            num_heads: Number of heads; must divide ``hidden_size``.
        """
        super().__init__()
        assert hidden_size % num_heads == 0, "Hidden size must be divisible by the number of heads."
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads

        # Input projections for queries, keys and values.
        self.Wq = nn.Linear(hidden_size, hidden_size).to(device)
        self.Wk = nn.Linear(hidden_size, hidden_size).to(device)
        self.Wv = nn.Linear(hidden_size, hidden_size).to(device)

        # Output projection applied after the heads are merged.
        self.Wo = nn.Linear(hidden_size, hidden_size).to(device)

        self.softmax = nn.Softmax(dim=-1).to(device)

    def forward(self, query, keys):
        """Attend ``query`` over ``keys``.

        Returns:
            ``(output, attention_weights)`` — the merged, projected context and
            the per-head softmax weights.
        """
        query = query.to(device)
        keys = keys.to(device)
        batch_size = query.size(0)

        def split_heads(projected):
            # (batch, length, hidden) -> (batch, heads, length, head_dim)
            return projected.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)

        q_heads = split_heads(self.Wq(query))
        k_heads = split_heads(self.Wk(keys))
        v_heads = split_heads(self.Wv(keys))

        scale = torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32))
        scores = torch.matmul(q_heads, k_heads.transpose(-2, -1)) / scale

        attention_weights = self.softmax(scores)
        per_head_context = torch.matmul(attention_weights, v_heads)

        # (batch, heads, length, head_dim) -> (batch, length, hidden)
        merged = per_head_context.transpose(1, 2).contiguous().view(
            batch_size, -1, self.num_heads * self.head_dim
        )
        return self.Wo(merged), attention_weights

In [33]:
class AttnDecoderRNN(nn.Module):
    """GRU decoder with multi-head cross-attention over the encoder outputs.

    At every step the previous token is embedded, the current hidden state
    queries the encoder outputs, and the embedding and attention context are
    concatenated, projected to ``hidden_size`` and fed to the GRU.
    """

    def __init__(self, hidden_size, output_size, n_layers=1, dropout_p=0.1):
        """
        Args:
            hidden_size: Width of embeddings, GRU state and attention.
            output_size: Target vocabulary size.
            n_layers: Number of GRU layers.
                NOTE(review): ``forward_step`` and
                ``transform_bidirectional_hidden`` appear to assume a single
                layer — confirm before using ``n_layers > 1``.
            dropout_p: Dropout probability applied to the token embeddings.
        """
        super(AttnDecoderRNN, self).__init__()
        self.embedding = nn.Embedding(output_size, hidden_size).to(device)
        self.attention = CrossAttention(hidden_size, num_heads=2).to(device)
        self.gru = nn.GRU(hidden_size, hidden_size, n_layers, batch_first=True).to(device)
        self.out = nn.Linear(hidden_size, output_size).to(device)
        self.dropout = nn.Dropout(dropout_p)
        # Merges the encoder's forward/backward final states into one state.
        self.hidden_transform = nn.Linear(hidden_size * 2, hidden_size).to(device)
        # Projects [embedding ; attention context] down to the GRU input width.
        self.hidden_input_transform = nn.Linear(hidden_size * 2, hidden_size).to(device)
        self.hidden_size = hidden_size

    def forward(self, encoder_outputs, encoder_hidden, target_tensor=None):
        """Decode a batch, with teacher forcing when ``target_tensor`` is given.

        Args:
            encoder_outputs: Per-token encoder outputs (batch first).
            encoder_hidden: Final hidden state of the bidirectional encoder.
            target_tensor: Optional target ids, one row per sentence. When
                given, step ``i + 1`` is fed ``target_tensor[:, i]``; otherwise
                the previous step's argmax is fed back.

        Returns:
            ``(decoder_outputs, decoder_hidden, attentions)`` where
            ``decoder_outputs`` holds log-probabilities, one row per step.
        """
        encoder_outputs = encoder_outputs.to(device)
        encoder_hidden = encoder_hidden.to(device)
        if target_tensor is not None:
            target_tensor = target_tensor.to(device)
        batch_size = encoder_outputs.size(0)
        decoder_input = torch.empty(batch_size, 1, dtype=torch.long, device=device).fill_(SOS_token)
        decoder_hidden = self.transform_bidirectional_hidden(encoder_hidden)
        decoder_outputs = []
        attentions = []

        # With teacher forcing, run exactly one step per target position.
        # Always running MAX_LENGTH steps raised IndexError for shorter targets
        # and produced outputs that could not be aligned with longer ones.
        # For targets of length MAX_LENGTH the behaviour is unchanged.
        n_steps = MAX_LENGTH if target_tensor is None else target_tensor.size(1)

        for i in range(n_steps):
            decoder_output, decoder_hidden, attn_weights = self.forward_step(
                decoder_input, decoder_hidden, encoder_outputs
            )
            decoder_outputs.append(decoder_output)
            attentions.append(attn_weights)

            if target_tensor is not None:
                # Teacher forcing: the next input is the ground-truth token.
                decoder_input = target_tensor[:, i].unsqueeze(1)
            else:
                # Greedy decoding: feed back the most likely token.
                _, topi = decoder_output.topk(1)
                decoder_input = topi.squeeze(-1).detach()

        decoder_outputs = torch.cat(decoder_outputs, dim=1)
        decoder_outputs = F.log_softmax(decoder_outputs, dim=-1)
        # NOTE(review): the attention weights carry a head dimension at index 1,
        # so this concatenation stacks the steps along that dimension — confirm
        # this is the layout consumers expect.
        attentions = torch.cat(attentions, dim=1)

        return decoder_outputs, decoder_hidden, attentions

    def transform_bidirectional_hidden(self, encoder_hidden):
        """Merge the forward and backward encoder states into a decoder state."""
        forward_states = encoder_hidden[0::2, :, :]
        backward_states = encoder_hidden[1::2, :, :]
        combined_hidden = torch.cat((forward_states, backward_states), dim=2)
        combined_hidden = self.hidden_transform(combined_hidden)
        return combined_hidden

    def forward_step(self, input, hidden, encoder_outputs):
        """Run one decoding step.

        Returns:
            ``(output, hidden, attn_weights)`` where ``output`` holds
            unnormalised vocabulary scores for this step.
        """
        embedded = self.embedding(input)
        embedded = self.dropout(embedded)
        # The hidden state is stored layer-first; attention wants batch-first.
        query = hidden.permute(1, 0, 2)
        context, attn_weights = self.attention(query, encoder_outputs)
        input_gru = torch.cat((embedded, context), dim=2)
        input_gru = self.hidden_input_transform(input_gru)
        output, hidden = self.gru(input_gru, hidden)
        output = self.out(output)

        return output, hidden, attn_weights

In [34]:
class Translator(nn.Module):
    """Encoder/decoder pair wrapped as a single module.

    The previous ``eval`` override has been removed: it returned ``None``
    (``nn.Module.eval`` returns the module) and left the wrapper's own
    ``training`` flag set. The inherited ``eval()`` / ``train()`` already
    propagate to ``encoder`` and ``decoder`` because both are registered
    sub-modules.
    """

    def __init__(self, encoder, decoder, device):
        """
        Args:
            encoder: Module returning ``(encoder_outputs, encoder_hidden)``.
            decoder: Module called as ``decoder(outputs, hidden, target)`` and
                returning a 3-tuple whose first item is the decoder output.
            device: Device the target tensor is moved to in ``forward``.
        """
        super(Translator, self).__init__()

        self.encoder = encoder
        self.decoder = decoder
        self.device = device

    def forward(self, input_tensor, target_tensor=None):
        """Encode ``input_tensor`` and decode it.

        Args:
            input_tensor: Source batch, passed to the encoder unchanged.
            target_tensor: Optional target batch forwarded to the decoder
                (used there for teacher forcing).

        Returns:
            The first element of the decoder's result.
        """
        if target_tensor is not None:
            target_tensor = target_tensor.to(self.device)
        encoder_outputs, encoder_hidden = self.encoder(input_tensor)
        decoder_outputs, _, _ = self.decoder(encoder_outputs, encoder_hidden, target_tensor)
        return decoder_outputs

In [35]:
import time
import math

def asMinutes(s):
    """Format a duration in seconds as ``'<minutes>m <seconds>s'``.

    Both parts are rendered with ``%d``, so fractional seconds are truncated.
    """
    whole_minutes = math.floor(s / 60)
    leftover_seconds = s - whole_minutes * 60
    return '%dm %ds' % (whole_minutes, leftover_seconds)

def timeSince(since, percent):
    """Describe elapsed and estimated remaining time.

    Args:
        since: Start timestamp as returned by ``time.time()``.
        percent: Fraction of the work completed so far; must be non-zero.

    Returns:
        ``'<elapsed> (- <remaining>)'`` with both parts formatted by
        ``asMinutes``. The remaining time is a linear extrapolation of the
        elapsed time.
    """
    elapsed = time.time() - since
    estimated_total = elapsed / (percent)
    remaining = estimated_total - elapsed
    return '%s (- %s)' % (asMinutes(elapsed), asMinutes(remaining))

In [36]:
import matplotlib.pyplot as plt
plt.switch_backend('agg')
import matplotlib.ticker as ticker
import numpy as np

def showPlot(points):
    """Plot ``points`` as a line with y-axis major ticks every 0.2.

    A single figure is created per call. The earlier version called
    ``plt.figure()`` and then ``plt.subplots()``, which left an extra empty
    figure open each time.

    Args:
        points: Sequence of values to plot against their index.
    """
    fig, ax = plt.subplots()
    ax.yaxis.set_major_locator(ticker.MultipleLocator(base=0.2))
    ax.plot(points)

In [43]:
def save_model(model, path):
    """Write ``model``'s ``state_dict`` (not the module object) to ``path``."""
    weights = model.state_dict()
    torch.save(weights, path)


In [44]:
def train_epoch(dataloader, encoder, decoder, encoder_optimizer,
          decoder_optimizer, criterion):
    """Run one teacher-forced training pass over ``dataloader``.

    Args:
        dataloader: Iterable of ``(input_tensor, target_tensor)`` batches that
            also supports ``len()``.
        encoder: Module returning ``(encoder_outputs, encoder_hidden)``.
        decoder: Module called as ``decoder(outputs, hidden, target)``; the
            first element of its result holds per-token scores.
        encoder_optimizer: Optimizer for the encoder parameters.
        decoder_optimizer: Optimizer for the decoder parameters.
        criterion: Loss taking ``(scores, targets)`` flattened over tokens.

    Returns:
        Mean batch loss for the epoch as a Python float.
    """
    encoder.train()
    decoder.train()
    total_loss = 0
    for input_tensor, target_tensor in dataloader:
        encoder_optimizer.zero_grad()
        decoder_optimizer.zero_grad()

        encoder_outputs, encoder_hidden = encoder(input_tensor)
        decoder_outputs, _, _ = decoder(encoder_outputs, encoder_hidden, target_tensor)
        # The unused per-batch argmax(...).tolist() was removed: it scanned the
        # whole vocabulary dimension and forced a device-to-host copy.
        loss = criterion(
            decoder_outputs.view(-1, decoder_outputs.size(-1)),
            target_tensor.view(-1)
        )
        loss.backward()

        encoder_optimizer.step()
        decoder_optimizer.step()

        total_loss += loss.item()

    return total_loss / len(dataloader)

In [45]:
def val_epoch(dataloader, encoder, decoder, criterion):
    """Compute the mean teacher-forced loss over ``dataloader`` without gradients.

    Both modules are switched to evaluation mode. The targets are still passed
    to the decoder, so this measures teacher-forced loss rather than
    free-running decoding.

    Returns:
        Mean batch loss as a Python float.
    """
    encoder.eval()
    decoder.eval()

    batch_losses = []
    with torch.no_grad():
        for input_tensor, target_tensor in dataloader:
            encoder_outputs, encoder_hidden = encoder(input_tensor)
            decoder_outputs, _, _ = decoder(encoder_outputs, encoder_hidden, target_tensor)

            vocab_size = decoder_outputs.size(-1)
            flat_scores = decoder_outputs.view(-1, vocab_size)
            flat_targets = target_tensor.view(-1)
            batch_losses.append(criterion(flat_scores, flat_targets).item())

    return sum(batch_losses) / len(dataloader)

In [46]:
def train(train_dataloader, val_dataloader, encoder, decoder, n_epochs, learning_rate=0.001,
          print_every=100, plot_every=100, patience=5):
    """Train the encoder/decoder with Adam, checkpointing and early stopping.

    Each epoch runs ``train_epoch`` then ``val_epoch``. Whenever the
    validation loss improves, both models are saved to ``encoder.pth`` and
    ``decoder.pth``; after ``patience`` consecutive epochs without improvement
    training stops. Averaged losses are printed every ``print_every`` epochs,
    and the training-loss averages collected every ``plot_every`` epochs are
    plotted at the end.

    Args:
        train_dataloader: Batches used for optimisation.
        val_dataloader: Batches used for validation.
        encoder: Encoder module.
        decoder: Decoder module.
        n_epochs: Maximum number of epochs.
        learning_rate: Learning rate for both Adam optimizers.
        print_every: Epoch interval between progress reports.
        plot_every: Epoch interval between plotted points.
        patience: Number of non-improving epochs tolerated before stopping.
    """
    started_at = time.time()
    encoder_checkpoint = "encoder.pth"
    decoder_checkpoint = "decoder.pth"

    encoder_optimizer = optim.Adam(encoder.parameters(), lr=learning_rate)
    decoder_optimizer = optim.Adam(decoder.parameters(), lr=learning_rate)
    criterion = nn.CrossEntropyLoss()

    averaged_losses = []
    train_running = 0
    val_running = 0
    plot_running = 0

    lowest_val_loss = float('inf')
    stale_epochs = 0

    def report(template, epoch, average):
        # One progress line: elapsed/remaining time, epoch, percent, loss.
        progress = epoch / n_epochs
        print(template % (timeSince(started_at, progress), epoch, progress * 100, average))

    for epoch in range(1, n_epochs + 1):
        train_loss = train_epoch(train_dataloader, encoder, decoder,
                                 encoder_optimizer, decoder_optimizer, criterion)
        val_loss = val_epoch(val_dataloader, encoder, decoder, criterion)

        train_running += train_loss
        val_running += val_loss
        plot_running += train_loss

        if val_loss < lowest_val_loss:
            # New best: remember it and checkpoint both models.
            lowest_val_loss = val_loss
            stale_epochs = 0
            save_model(encoder, encoder_checkpoint)
            save_model(decoder, decoder_checkpoint)
        else:
            stale_epochs += 1

        if stale_epochs >= patience:
            print(f"Early stopping triggered after {epoch} epochs.")
            break

        if epoch % print_every == 0:
            train_average = train_running / print_every
            train_running = 0
            report('%s (%d %d%%) Train Loss: %.4f', epoch, train_average)

            val_average = val_running / print_every
            val_running = 0
            report('%s (%d %d%%) Val Loss: %.4f', epoch, val_average)

        if epoch % plot_every == 0:
            averaged_losses.append(plot_running / plot_every)
            plot_running = 0

    showPlot(averaged_losses)


In [47]:
def evaluate(encoder, decoder, sentence, input_lang, output_lang):
    """Greedy-decode ``sentence`` and map the predicted ids to words.

    NOTE(review): ``tensorFromSentence`` is not defined in the visible cells,
    and the lookup goes through ``Lang`` vocabularies rather than the
    tokenizers used elsewhere — confirm this helper is still meant to be used.

    Returns:
        ``(decoded_words, decoder_attn)`` — the words up to and including the
        ``'<EOS>'`` marker (if one is produced), and the decoder's attention
        output.
    """
    with torch.no_grad():
        input_tensor = tensorFromSentence(input_lang, sentence)

        encoder_outputs, encoder_hidden = encoder(input_tensor)
        decoder_outputs, decoder_hidden, decoder_attn = decoder(encoder_outputs, encoder_hidden)

        # Best-scoring id at every decoding step.
        _, best_ids = decoder_outputs.topk(1)

        decoded_words = []
        for step_id in best_ids.squeeze():
            if step_id.item() == EOS_token:
                decoded_words.append('<EOS>')
                break
            decoded_words.append(output_lang.index2word[step_id.item()])
    return decoded_words, decoder_attn

In [48]:
# Model width shared by the encoder and the decoder.
hidden_size = 256
# NOTE(review): batch_size is not passed anywhere in the visible cells; the
# DataLoaders were built earlier with a literal batch_size=32 — confirm.
batch_size = 32

# Both models use VOCAB_SIZE for their embedding tables.
encoder = EncoderRNN(VOCAB_SIZE, hidden_size).to(device)
decoder = AttnDecoderRNN(hidden_size, VOCAB_SIZE).to(device)

# Up to 15 epochs, reporting every epoch, one plot point every 2 epochs, and
# early stopping after 2 epochs without validation improvement.
train(train_dataloader,val_dataloader, encoder, decoder, 15, print_every=1, plot_every=2, patience = 2)

18m 1s (- 207m 20s) (2 8%) Train Loss: 1.9037
18m 1s (- 207m 20s) (2 8%) Val Loss: 1.7591
36m 4s (- 189m 24s) (4 16%) Train Loss: 2.6265
36m 4s (- 189m 24s) (4 16%) Val Loss: 3.5604
54m 4s (- 171m 14s) (6 24%) Train Loss: 2.9599
54m 4s (- 171m 14s) (6 24%) Val Loss: 3.0734
Early stopping triggered after 7 epochs.


In [51]:
# NOTE(review): this re-declares `save_model` with the same body as the
# definition in the earlier cell; the later definition shadows the earlier one.
def save_model(model, path):
    """Write ``model``'s ``state_dict`` to ``path`` with ``torch.save``."""
    torch.save(model.state_dict(), path)

def load_model(model, path):
    """Load a saved ``state_dict`` into ``model`` and switch it to eval mode.

    The checkpoint is mapped onto the device that already holds the model's
    parameters, so weights saved from a GPU run can be restored on a CPU-only
    machine (and the other way round) without a deserialisation error.

    Args:
        model: Module whose parameters are overwritten in place.
        path: File previously written by ``save_model``.
    """
    first_parameter = next(model.parameters(), None)
    # A parameter-free module gives no device hint; keep torch's default then.
    map_location = first_parameter.device if first_parameter is not None else None
    model.load_state_dict(torch.load(path, map_location=map_location))
    model.eval()


In [52]:
save_encode_path = "encoder.pth"
save_decode_path = "decoder.pth"
# train() already wrote the best-validation weights to these two paths. After
# early stopping, the in-memory models hold the last (non-improving) epoch, so
# saving them here would overwrite the best checkpoint with worse weights.
# Restore the best weights into the models instead.
load_model(encoder, save_encode_path)
load_model(decoder, save_decode_path)

In [53]:
def beam_search_decode(encoder, decoder, input_tensor, max_length, beam_width, sos_token, eos_token, device):
    """Decode a single source sentence with beam search.

    Changes from the earlier version:
      * hypotheses are scored by summed log-probabilities (log-softmax of the
        decoder scores) instead of summed raw logits;
      * a hypothesis that emits ``eos_token`` is frozen rather than extended
        until ``max_length``;
      * the unused encoder mask was dropped.

    Args:
        encoder: Module returning ``(encoder_outputs, encoder_hidden)``.
        decoder: Object providing ``transform_bidirectional_hidden(hidden)``
            and ``forward_step(input, hidden, encoder_outputs)``.
        input_tensor: Source token ids with a leading batch dimension of 1.
        max_length: Maximum number of tokens generated after ``sos_token``.
        beam_width: Number of hypotheses kept after every step.
        sos_token: Id used as the first decoder input.
        eos_token: Id that terminates a hypothesis.
        device: Device the tensors are created on or moved to.

    Returns:
        Tensor holding one row: the best hypothesis, starting with
        ``sos_token`` and ending with ``eos_token`` if that was generated.

    Raises:
        ValueError: If ``input_tensor`` holds more than one sentence.
    """
    input_tensor = input_tensor.to(device)
    if input_tensor.size(0) != 1:
        # Scores are tracked as Python floats per hypothesis, which only makes
        # sense for a single sentence.
        raise ValueError("beam_search_decode expects a batch of exactly one sentence")

    encoder_outputs, encoder_hidden = encoder(input_tensor)

    start_sequence = torch.tensor([[sos_token]], dtype=torch.long, device=device)
    start_hidden = decoder.transform_bidirectional_hidden(encoder_hidden)

    # Each hypothesis: (token sequence, cumulative log-prob, hidden state, finished?)
    beams = [(start_sequence, 0.0, start_hidden, False)]

    for _ in range(max_length):
        if all(finished for _seq, _score, _hidden, finished in beams):
            break

        candidates = []
        for seq, score, hidden, finished in beams:
            if finished:
                # Carry finished hypotheses over unchanged.
                candidates.append((seq, score, hidden, True))
                continue

            step_scores, next_hidden, _ = decoder.forward_step(seq[:, -1:], hidden, encoder_outputs)
            log_probs = torch.log_softmax(step_scores.squeeze(1), dim=-1)

            n_expand = min(beam_width, log_probs.size(-1))
            top_log_probs, top_ids = log_probs.topk(n_expand)

            for j in range(n_expand):
                token = top_ids[:, j].unsqueeze(1)
                candidates.append((
                    torch.cat([seq, token], dim=1),
                    score + top_log_probs[0, j].item(),
                    next_hidden,
                    token.item() == eos_token,
                ))

        beams = sorted(candidates, key=lambda beam: beam[1], reverse=True)[:beam_width]

    return beams[0][0]

def evaluate_with_beam_search(translator, input_tensor, max_length, beam_width, sos_token, eos_token, device):
    """Run ``beam_search_decode`` on ``translator``'s encoder/decoder without gradients.

    Args:
        translator: Object exposing ``encoder`` and ``decoder`` attributes.
        input_tensor, max_length, beam_width, sos_token, eos_token, device:
            Forwarded unchanged to ``beam_search_decode``.

    Returns:
        Whatever ``beam_search_decode`` returns for the best hypothesis.
    """
    source_encoder, target_decoder = translator.encoder, translator.decoder

    with torch.no_grad():
        return beam_search_decode(
            source_encoder, target_decoder, input_tensor,
            max_length, beam_width, sos_token, eos_token, device,
        )



In [56]:
def clean_decoded_sentence(sentence):
    """Remove the ``<s>``, ``</s>``, ``<pad>`` and ``<unk>`` markers from ``sentence``.

    Surrounding whitespace is trimmed after each marker is removed; whitespace
    left inside the sentence is not collapsed.
    """
    cleaned = sentence
    for marker in ("<s>", "</s>", "<pad>", "<unk>"):
        cleaned = cleaned.replace(marker, "")
        cleaned = cleaned.strip()
    return cleaned

def translate_english_to_vietnamese(model, english_tokenizer, vietnamese_tokenizer, device):
    """Prompt for an English sentence and return its Vietnamese translation.

    Args:
        model: Translator exposing ``encoder`` and ``decoder``; it is put in
            evaluation mode and used for decoding. (The earlier version decoded
            with an undefined global ``translator`` and ignored this argument.)
        english_tokenizer: Tokenizer whose ``encode`` turns text into ids.
        vietnamese_tokenizer: Tokenizer whose ``decode`` turns ids into text.
        device: Device the input tensor is moved to.

    Returns:
        The decoded sentence with special-token markers removed.
    """
    beam_width = 5

    model.eval()

    english_sentence = input("Enter an English sentence: ").strip()

    english_tokens = english_tokenizer.encode(english_sentence)  # Convert sentence to token IDs
    english_tensor = torch.tensor(english_tokens).unsqueeze(0).to(device)  # Add batch dimension

    # MAX_LENGTH / SOS_token / EOS_token are the notebook-level constants; they
    # hold the same values (50, 0, 2) that were previously passed as literals.
    best_sentence = evaluate_with_beam_search(
        model, english_tensor, MAX_LENGTH, beam_width, SOS_token, EOS_token, device
    )

    vietnamese_sentence = vietnamese_tokenizer.decode(best_sentence[0])

    return clean_decoded_sentence(vietnamese_sentence)

