# Train – English to Spanish Translation

In [1]:
import sys
import os

# Make the parent of the current working directory importable so that the
# project-local imports further down this cell can be resolved.
project_root = os.path.abspath(os.pardir)
# Guard the append: without it every re-run of this cell adds another
# duplicate entry to sys.path.
if project_root not in sys.path:
    sys.path.append(project_root)

from data.data import EnglishToSpanish
from transformer.models.transformer import Transformer
from torch.utils.data import DataLoader
import torch
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


### Train

In [3]:
from transformer.models.transformer import Transformer
from data.data import EnglishToSpanish
from torch.utils.data import DataLoader
import torch

# Pick the compute backend. MPS keeps first priority (unchanged from before);
# CUDA is now used on machines that have it but no MPS, and CPU is the fallback.
if torch.backends.mps.is_available():
    device = torch.device('mps')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

from transformers import AutoTokenizer

# Pretrained Helsinki-NLP English->Spanish tokenizer (fetched or loaded from the
# local Hugging Face cache on first use).
tokenizer = AutoTokenizer.from_pretrained('Helsinki-NLP/opus-mt-en-es')
# Id of the padding token, kept as a module-level name for later cells.
pad_token_id = tokenizer.pad_token_id




In [4]:
# Training split of the project's dataset, served as shuffled mini-batches.
BATCH_SIZE = 32
train_dataset = EnglishToSpanish(split='train')
train_loader = DataLoader(dataset=train_dataset, shuffle=True, batch_size=BATCH_SIZE)

In [None]:
EPOCHS = 30
SEED = 42

# Seed before the model is built so that weight initialisation (and the other
# torch-driven randomness in this cell) is repeatable between runs, as far as
# the selected backend allows.
torch.manual_seed(SEED)

# Source and target share one tokenizer, hence the identical vocabulary sizes.
model = Transformer(
    source_vocab_size=tokenizer.vocab_size,
    target_vocab_size=tokenizer.vocab_size,
    max_len=128,
    embed_dim=512,
    num_heads=8,
    ffn_hidden_dim=2048,
    N=6,
    dropout=0.1
).to(device)

# Positions whose target is the padding id do not contribute to the loss.
criterion = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

for epoch in range(EPOCHS):
    model.train()
    total_loss = 0.0
    num_batches = 0
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch + 1}", leave=False)

    for batch in progress_bar:
        src = batch['input_ids'].to(device)
        tgt = batch['labels'].to(device)

        # Teacher forcing: the decoder reads tokens [0, T-1) and is scored
        # against tokens [1, T), i.e. it learns to predict the next token.
        # NOTE(review): this assumes `labels` already begin with the same start
        # token that inference seeds the decoder with. If the dataset does not
        # prepend one, the first target token is never predicted during
        # training -- confirm against EnglishToSpanish.
        tgt_input = tgt[:, :-1]
        tgt_output = tgt[:, 1:]

        # True where the source holds a real (non-padding) token; two singleton
        # axes are inserted after the first axis for broadcasting.
        src_mask = (src != pad_token_id).unsqueeze(1).unsqueeze(2)
        # Lower-triangular causal mask: position i may look at positions <= i.
        tgt_len = tgt_input.size(1)
        tgt_mask = torch.tril(torch.ones((tgt_len, tgt_len), device=device)).bool()
        tgt_mask = tgt_mask.unsqueeze(0).unsqueeze(0)

        # `src_mask` is passed twice; the last argument is presumably the
        # encoder-decoder attention mask -- confirm against Transformer.forward.
        logits = model(src, tgt_input, src_mask, tgt_mask, src_mask)
        # Flatten every leading axis so each position is one classification row.
        loss = criterion(logits.reshape(-1, logits.size(-1)), tgt_output.reshape(-1))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Read the scalar once per step instead of calling .item() twice.
        batch_loss = loss.item()
        total_loss += batch_loss
        num_batches += 1
        progress_bar.set_postfix(loss=batch_loss)

    # The summed loss scales with the number of batches, so the per-batch mean
    # is reported as well; max() guards against an empty loader.
    avg_loss = total_loss / max(num_batches, 1)
    print(f'Epoch {epoch + 1} - Total Loss: {total_loss:.4f} - Avg Loss: {avg_loss:.4f}')

    # Write a checkpoint of the weights every third epoch.
    if (epoch + 1) % 3 == 0:
        torch.save(model.state_dict(), f'transformer_epoch{epoch+1}.pt')



                                                                       

Epoch 1 - Total Loss: 12997.1249


                                                                       

Epoch 2 - Total Loss: 10436.1374


                                                                      

KeyboardInterrupt: 

In [6]:
# Snapshot the weights currently held in memory to a checkpoint file.
checkpoint_path = "transformer_2epochs.pt"
torch.save(model.state_dict(), checkpoint_path)

### Inference

In [9]:
def inference(sentence: str, max_len: int = 128) -> str:
    """Greedily translate an English sentence into Spanish.

    Uses the notebook-level ``model``, ``tokenizer``, ``device`` and
    ``pad_token_id``. Decoding stops at the first end-of-sequence token or once
    the generated sequence holds ``max_len`` tokens (start token included).

    Args:
        sentence: English sentence
        max_len: length the source is padded/truncated to, and the cap on the
            generated sequence length. Should not exceed the ``max_len`` the
            model was constructed with.

    Returns:
        output: translated sentence in Spanish
    """
    def first_defined(*token_ids):
        # `a or b` treats a legitimate token id of 0 as "missing" (0 is a common
        # id for `</s>`), which made the EOS check below unreachable. Compare
        # against None explicitly instead.
        return next((tid for tid in token_ids if tid is not None), None)

    model.eval()
    enc = tokenizer(sentence, return_tensors='pt', padding='max_length', truncation=True, max_length=max_len)
    src = enc['input_ids'].to(device)
    # True where the source holds a real (non-padding) token.
    src_mask = (src != pad_token_id).unsqueeze(1).unsqueeze(2)

    # NOTE(review): the decoder is seeded with bos/cls/pad, whichever exists
    # first. Confirm that training sequences start with the same token.
    start_token_id = first_defined(tokenizer.bos_token_id, tokenizer.cls_token_id, tokenizer.pad_token_id)
    eos_token_id = first_defined(tokenizer.eos_token_id, tokenizer.sep_token_id, pad_token_id)

    generated = torch.tensor([[start_token_id]], dtype=torch.long, device=device)
    # No gradients are needed anywhere in the decode loop.
    with torch.no_grad():
        for _ in range(max_len - 1):
            cur_len = generated.size(1)
            # Causal mask over the tokens generated so far.
            tgt_mask = torch.tril(torch.ones((cur_len, cur_len), device=device)).unsqueeze(0).unsqueeze(0).bool()
            logits = model(src, generated, src_mask, tgt_mask, src_mask)
            # Greedy choice at the last position; keepdim keeps a (1, 1) shape
            # so it can be concatenated along the sequence axis.
            next_token = logits[:, -1].argmax(-1, keepdim=True)
            generated = torch.cat([generated, next_token], dim=1)
            if next_token.item() == eos_token_id:
                break

    output = tokenizer.decode(generated[0], skip_special_tokens=True)
    return output

In [10]:
# Sample English inputs, arranged in three groups of three sentences each.
example_sentences = [
    [
        "I like to eat apples and bananas.",
        "She is reading a book in the sun.",
        "They play soccer every Sunday afternoon.",
    ],
    [
        "We couldn't find the restaurant despite using the map.",
        "If it rains tomorrow, we'll cancel the hike.",
        "The teacher explained the problem in a different way.",
    ],
    [
        "Although the train was late, we still made it on time.",
        "The decision, which had been debated for months, was finally announced.",
        "He acted as though nothing had happened, despite knowing the consequences.",
    ],
]

# Difficulty names, index-aligned with the groups in `example_sentences`.
levels_to_labels = ['Simple', 'Medium', 'Complex']

# Print every sentence next to the model's translation, one group at a time.
for label, group in zip(levels_to_labels, example_sentences):
    print(f'DIFFICULTY: {label}')
    for number, text in enumerate(group, start=1):
        print(f'\t{number}: {text} => {inference(text)}')

DIFFICULTY: Simple
	1: I like to eat apples and bananas. => o que nos habían aquel momento.o.o.o.o.o.o.ó..o.o.ó........o................................
	2: She is reading a book in the sun. => a es un hombre.ó el cielo. la ciudad.ó el cielo..ó..ó.ó..ó......ó.....ó.ó el cielo.......ó..........amente...ó...
	3: They play soccer every Sunday afternoon. => o habían habían ninguna.ban a la vista.ó el último. los último.ó el nuevo.ó el nuevo.o...ó.....ó.ó el nuevo.o......tá......ó..ó...o.ó...
DIFFICULTY: Medium
	1: We couldn't find the restaurant despite using the map. => o no habíamos a la vista de la isla Lincoln.ó el único.ó el último.ó el cielo.ó el cielo. el cielo...ó el cielo...... el único...................o.....
	2: If it rains tomorrow, we'll cancel the hike. => o, ¿qué hacemos a la vez.ó el último.ó el único.o.ó el último. la vez.. el último.ó....o.....o.....ó....o...ó.o....o.....
	3: The teacher explained the problem in a different way. => o habían habían habían a lado de la vis

In [12]:
# Minimal smoke test: the bare call's return value is shown as the cell output.
# Other short inputs can be tried here as well, e.g. "good night.".
inference(sentence="hello.")


'o.ó..ó......ó...ó.............................................'