# Train – English to Spanish Translation

In [1]:
import sys
import os

project_root = os.path.abspath(os.path.join(os.getcwd(), ".."))
sys.path.append(project_root)

from data.data import EnglishToSpanish
from transformer.models.transformer import Transformer
from torch.utils.data import DataLoader
import torch
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


### Train

In [2]:
device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')

from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained('Helsinki-NLP/opus-mt-en-es')
pad_token_id = tokenizer.pad_token_id




In [3]:
train_dataset = EnglishToSpanish(split='train')
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

In [4]:
EPOCHS = 30

model = Transformer(
    source_vocab_size=tokenizer.vocab_size,
    target_vocab_size=tokenizer.vocab_size,
    max_len=128,
    embed_dim=512,
    num_heads=8,
    ffn_hidden_dim=2048,
    N=6,
    dropout=0.1
).to(device)

criterion = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

for epoch in range(EPOCHS):
    model.train()
    total_loss = 0
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch + 1}", leave=False)

    for batch in progress_bar:
        src = batch['input_ids'].to(device)
        tgt = batch['labels'].to(device)

        tgt_input = tgt[:, :-1]
        tgt_output = tgt[:, 1:]

        src_mask = (src != pad_token_id).unsqueeze(1).unsqueeze(2)
        L = tgt_input.size(1)
        tgt_mask = torch.tril(torch.ones((L, L), device=device)).bool()
        tgt_mask = tgt_mask.unsqueeze(0).unsqueeze(0)

        logits = model(src, tgt_input, src_mask, tgt_mask, src_mask)
        loss = criterion(logits.reshape(-1, logits.size(-1)), tgt_output.reshape(-1))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        progress_bar.set_postfix(loss=loss.item())

    print(f'Epoch {epoch + 1} - Total Loss: {total_loss:.4f}')
    
    if (epoch + 1) % 3 == 0:
        torch.save(model.state_dict(), f'transformer_epoch{epoch+1}.pt')


                                                                       

Epoch 1 - Total Loss: 12999.4714


                                                                       

Epoch 2 - Total Loss: 10406.2217


                                                                       

Epoch 3 - Total Loss: 9529.5641


                                                                       

Epoch 4 - Total Loss: 8902.9232


                                                                       

Epoch 5 - Total Loss: 8368.5333


                                                                       

Epoch 6 - Total Loss: 7945.5740


                                                                       

Epoch 7 - Total Loss: 7586.8106


                                                                       

Epoch 8 - Total Loss: 7274.3715


                                                                       

Epoch 9 - Total Loss: 7000.3477


                                                                        

Epoch 10 - Total Loss: 6752.4823


                                                                        

Epoch 11 - Total Loss: 6535.1290


                                                                        

Epoch 12 - Total Loss: 6331.8116


                                                                        

Epoch 13 - Total Loss: 6147.9780


                                                                        

Epoch 14 - Total Loss: 5979.4386


                                                                        

Epoch 15 - Total Loss: 5827.1593


                                                                        

Epoch 16 - Total Loss: 5680.6494


                                                                        

Epoch 17 - Total Loss: 5545.9932


                                                                        

Epoch 18 - Total Loss: 5422.1001


                                                                        

Epoch 19 - Total Loss: 5303.5667


                                                                        

Epoch 20 - Total Loss: 5194.2645


                                                                        

Epoch 21 - Total Loss: 5089.9528


                                                                        

Epoch 22 - Total Loss: 4991.8034


                                                                        

Epoch 23 - Total Loss: 4899.5368


                                                                        

Epoch 24 - Total Loss: 4811.7374


                                                                        

Epoch 25 - Total Loss: 4727.9812


                                                                        

Epoch 26 - Total Loss: 4649.9095


                                                                        

Epoch 27 - Total Loss: 4572.7969


                                                                        

Epoch 28 - Total Loss: 4496.5549


                                                                        

Epoch 29 - Total Loss: 4426.0569


                                                                        

Epoch 30 - Total Loss: 4362.8466


### Inference

In [10]:
def inference(sentence: str) -> str:
    """Runs inference on English to Spanish translation (limited to 128 tokens)

    Args:
        sentence: English sentence

    Returns:
        output: translated sentence in Spanish
    """
    model.eval()
    enc = tokenizer(sentence, return_tensors='pt', padding='max_length', truncation=True, max_length=128)
    src = enc['input_ids'].to(device)
    src_mask = (src != pad_token_id).unsqueeze(1).unsqueeze(2)

    start_token_id = tokenizer.bos_token_id if tokenizer.bos_token_id is not None else tokenizer.cls_token_id or pad_token_id
    eos_token_id = tokenizer.eos_token_id if tokenizer.eos_token_id is not None else tokenizer.sep_token_id or pad_token_id

    generated = torch.tensor([[start_token_id]], dtype=torch.long, device=device)

    for _ in range(127):
        tgt_mask = torch.tril(torch.ones((generated.size(1), generated.size(1)), device=device)).unsqueeze(0).unsqueeze(0).bool()
        with torch.no_grad():
            logits = model(src, generated, src_mask, tgt_mask, src_mask)
        next_token = logits[:, -1].argmax(-1).unsqueeze(0)
        generated = torch.cat([generated, next_token], dim=1)
        if next_token.item() == eos_token_id:
            break

    output = tokenizer.decode(generated[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
    return output


In [11]:
example_sentences = [["I like to eat apples and bananas.",
                      "She is reading a book in the sun.",
                      "They play soccer every Sunday afternoon."],
                     ["We couldn't find the restaurant despite using the map.",
                      "If it rains tomorrow, we'll cancel the hike.",
                      "The teacher explained the problem in a different way."],
                     ["Although the train was late, we still made it on time.",
                      "The decision, which had been debated for months, was finally announced.",
                      "He acted as though nothing had happened, despite knowing the consequences."]]

levels_to_labels = ['Simple',
                    'Medium',
                    'Complex']
for l, level in enumerate(example_sentences):
    print(f'DIFFICULTY: {levels_to_labels[l]}')
    for s, sentence in enumerate(level):
        print(f'\t{s + 1}: {sentence}')
        print(f'\t\t=> {inference(sentence)}')

DIFFICULTY: Simple
	1: I like to eat apples and bananas.
		=> me gustan las comidas y la hambre.
	2: She is reading a book in the sun.
		=> leyó un libro en el sol.
	3: They play soccer every Sunday afternoon.
		=> aban los gritos de la mañana.
DIFFICULTY: Medium
	1: We couldn't find the restaurant despite using the map.
		=> utábamos el diván, pese a la tragedia.
	2: If it rains tomorrow, we'll cancel the hike.
		=> che, si no llovíamos aprender la tumba.
	3: The teacher explained the problem in a different way.
		=> tó el profesor de camino, lo que debían modo.
DIFFICULTY: Complex
	1: Although the train was late, we still made it on time.
		=> , el tren nos llevaba, todavía mucho tiempo.
	2: The decision, which had been debated for months, was finally announced.
		=> ía la decisión, que había sido anunciado por meses, fue anunciado.
	3: He acted as though nothing had happened, despite knowing the consequences.
		=> to, no había nada, no había nada, a pesar de lo que sabía, a pesar de