# Train – English to Spanish Translation

In [8]:
import sys
import os

# Resolve the directory one level above the current working directory so the
# project-local `data` and `transformer` packages imported below can be found.
project_root = os.path.abspath(os.path.join(os.getcwd(), ".."))

# Notebook cells get re-executed; only register the path once so repeated runs
# do not pile up duplicate entries on sys.path.
if project_root not in sys.path:
    sys.path.append(project_root)

from data.data import EnglishToSpanish
from transformer.models.transformer import Transformer
from torch.utils.data import DataLoader
import torch
from tqdm import tqdm

### Train

In [2]:
from transformer.models.transformer import Transformer
from data.data import EnglishToSpanish
from torch.utils.data import DataLoader
import torch

def select_device():
    """Return the torch device to train on: CUDA, else Apple MPS, else CPU.

    Each accelerator is only chosen after checking *its own* availability
    flag, so the returned device always exists on the current machine.

    Returns:
        torch.device: ``cuda`` if ``torch.cuda.is_available()``, otherwise
        ``mps`` if the MPS backend reports itself available, otherwise ``cpu``.
    """
    if torch.cuda.is_available():
        return torch.device('cuda')

    # Older torch builds have no `torch.backends.mps`; treat that as "no MPS".
    mps_backend = getattr(torch.backends, 'mps', None)
    if mps_backend is not None and mps_backend.is_available():
        return torch.device('mps')

    return torch.device('cpu')


device = select_device()

# Load the pretrained tokenizer by its model identifier; judging by the name
# this is the English->Spanish OPUS-MT tokenizer.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path='Helsinki-NLP/opus-mt-en-es',
)

# Id of the padding token; the training cell uses it to build the source mask
# and as the loss `ignore_index`.
pad_token_id = tokenizer.pad_token_id


In [3]:
# Training split of the project's English/Spanish dataset.
train_dataset = EnglishToSpanish(split='train')

# Serve the split in shuffled mini-batches of 32 examples.
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=32,
    shuffle=True,
)

In [None]:
# One tokenizer serves both languages, so source and target share a vocab size.
vocab_size = tokenizer.vocab_size

# Build the project's Transformer with the hyperparameters below, then move it
# onto the selected device.
model = Transformer(
    source_vocab_size=vocab_size,
    target_vocab_size=vocab_size,
    max_len=128,
    embed_dim=512,
    num_heads=8,
    ffn_hidden_dim=2048,
    N=6,
    dropout=0.1,
)
model = model.to(device)

# Positions whose target id equals the padding id contribute nothing to the loss.
criterion = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id)

# Adam over all model parameters with a fixed learning rate.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4)

# Train for three passes over `train_loader`, accumulating the per-batch loss
# and printing the epoch total at the end of each pass.
for epoch in range(3):
    model.train()
    total_loss = 0
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch + 1}", leave=False)

    for batch in progress_bar:
        src_ids = batch['input_ids'].to(device)
        tgt_ids = batch['labels'].to(device)

        # Shift by one position: the decoder is fed every token but the last
        # and is scored against every token but the first.
        decoder_input = tgt_ids[:, :-1]
        expected = tgt_ids[:, 1:]

        # True where the source token is not padding, with two singleton axes
        # inserted after the first one.
        # NOTE(review): assumes src_ids is (batch, seq) — confirm against the dataset.
        src_mask = (src_ids != pad_token_id)[:, None, None, :]

        # Lower-triangular boolean mask over decoder positions, with two
        # leading singleton axes.
        tgt_len = decoder_input.size(1)
        causal = torch.ones((tgt_len, tgt_len), device=device).tril().bool()
        tgt_mask = causal[None, None, :, :]

        logits = model(src_ids, decoder_input, src_mask, tgt_mask, src_mask)

        # Collapse all leading axes so each row of logits pairs with one target id.
        flat_logits = logits.reshape(-1, logits.size(-1))
        flat_expected = expected.reshape(-1)
        loss = criterion(flat_logits, flat_expected)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        total_loss += batch_loss
        progress_bar.set_postfix(loss=batch_loss)

    print(f'Epoch {epoch + 1} - Total Loss: {total_loss:.4f}')


                                                                       

KeyboardInterrupt: 