In [1]:
from typing import NamedTuple, Generator
import pandas as pd
import torch
import torch.nn.functional as F
from torch import Tensor, optim, nn
from tqdm import tqdm
from src.transformer import Transformer

# https://www.kaggle.com/datasets/dhruvildave/en-fr-translation-dataset

In [2]:
class PairedSentences(NamedTuple):
    """One French sentence together with its English counterpart."""
    # French text of the pair
    fr: str
    # English text of the pair
    en: str

class ListPairedSentences(NamedTuple):
    """Column-oriented group of sentence pairs: two parallel lists, one per language.

    Note that ``__getitem__`` is overridden, so ``batch[i]`` returns the i-th
    sentence pair instead of the i-th tuple field. Attribute access
    (``batch.fr`` / ``batch.en``) still returns the full lists.
    """
    # French sentences, aligned position-by-position with ``en``
    fr: list[str]
    # English sentences, aligned position-by-position with ``fr``
    en: list[str]

    def __getitem__(self, index: int) -> PairedSentences:
        """Return the pair stored at position ``index`` of both lists."""
        french, english = self.fr, self.en
        return PairedSentences(fr=french[index], en=english[index])

In [3]:
# Pretrained tokenizer, loaded once and shared by tokenize() below.
# NOTE(review): "bert-base-uncased" is presumably an English-oriented vocabulary,
# yet make_batch() applies it to the French sentences as well — confirm this is intended.
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
# The vocabulary size is reused later as the model's vocab_size.
print(tokenizer.vocab_size)

class TrainingBatch(NamedTuple):
    x: Tensor
    y: Tensor
    encoder_mask: Tensor | None
    # no need for decoder mask, we'll set it after

    def __repr__(self):
        return f"TrainingBatch(x.shape={self.x.shape}, y.shape={self.y.shape}, encoder_mask.shape={self.encoder_mask.shape})"

def tokenize(text):
    """Encode ``text`` with the module-level ``tokenizer``.

    The call requests PyTorch tensors, truncates anything longer than 64
    tokens and pads shorter inputs up to 64, so every encoding has the same
    sequence length and can be stacked into a batch.
    """
    encoding = tokenizer(
        text,
        padding="max_length",
        truncation=True,
        max_length=64,
        return_tensors="pt",
    )
    return encoding

def make_batch(paired_sentences: ListPairedSentences) -> TrainingBatch:
    """Turn raw sentence lists into the tensors used for one training step.

    Each French and English sentence is tokenized on its own (all French
    first, then all English), and the per-sentence tensors are stacked along
    a new leading batch dimension.

    Returns:
        TrainingBatch with ``x`` (French token IDs), ``y`` (English token IDs)
        and ``encoder_mask`` (float32 mask derived from the French attention
        mask, with two singleton dimensions inserted after the batch axis).
    """
    def _encode_all(sentences):
        # One tokenizer call per sentence; each encoding carries a leading dim of 1.
        return [tokenize(sentence) for sentence in sentences]

    def _stack(encodings, key):
        # Drop the per-sentence leading dim, then stack into a batch.
        return torch.stack([encoding[key].squeeze(0) for encoding in encodings])

    source_encodings = _encode_all(paired_sentences.fr)
    target_encodings = _encode_all(paired_sentences.en)

    source_ids = _stack(source_encodings, 'input_ids')
    target_ids = _stack(target_encodings, 'input_ids')

    # Encoder padding mask: 1 for real tokens, 0 for padding.
    # The extra singleton dims are presumably there so the mask broadcasts
    # against the attention scores — TODO confirm against the Transformer.
    source_attention = _stack(source_encodings, 'attention_mask')
    encoder_mask = source_attention.unsqueeze(1)
    encoder_mask = encoder_mask.unsqueeze(2)
    encoder_mask = encoder_mask.to(torch.float32)

    return TrainingBatch(x=source_ids, y=target_ids, encoder_mask=encoder_mask)

  from .autonotebook import tqdm as notebook_tqdm


30522




In [4]:
def get_page(csv_path: str, page: int, rows_per_page: int):
    """Read one fixed-size page of rows from the CSV at ``csv_path``.

    Args:
        csv_path: Path of the CSV file; its first line is treated as a header.
        page: Zero-based page number.
        rows_per_page: Maximum number of rows to return.

    Returns:
        A DataFrame with columns ``"en"`` and ``"fr"`` holding at most
        ``rows_per_page`` rows.
    """
    # Skip the header line plus every row belonging to earlier pages.
    lines_to_skip = 1 + page * rows_per_page
    column_names = ["en", "fr"]
    return pd.read_csv(
        csv_path,
        header=None,
        names=column_names,
        skiprows=lines_to_skip,
        nrows=rows_per_page,
    )

def make_generator(csv_path: str, rows_per_page: int) -> Generator[ListPairedSentences, None, None]:
    """Yield the CSV one page at a time, in file order, then stop.

    The previous implementation advanced its page counter with ``i += i``,
    which leaves it at 0 forever: the first page was yielded endlessly and the
    generator never terminated. This version makes a single streaming pass
    over the file, so each row is read once and iteration ends at end of file.

    Args:
        csv_path: Path of the CSV file; its first line is skipped as a header
            and the columns are named ``"en"`` and ``"fr"``.
        rows_per_page: Maximum number of rows per yielded page.

    Yields:
        ListPairedSentences holding the French and English sentences of one
        page. The last page may be shorter than ``rows_per_page``.
    """
    # Same parsing options as get_page(), but chunked so pandas keeps its
    # position in the file instead of re-skipping earlier rows for every page.
    reader = pd.read_csv(
        csv_path,
        skiprows=1,
        header=None,
        names=["en", "fr"],
        chunksize=rows_per_page,
    )
    try:
        for page in reader:
            yield ListPairedSentences(page["fr"].to_list(), page["en"].to_list())
    finally:
        # Release the file handle even if the consumer stops iterating early.
        reader.close()

def get_num_steps(csv_path: str, rows_per_page: int) -> int:
    """Return how many pages of ``rows_per_page`` rows the CSV contains.

    The count is based on the number of lines in the file minus one header
    line, rounded up so a trailing partial page counts as a step.

    Args:
        csv_path: Path of the CSV file.
        rows_per_page: Number of rows consumed per step.

    Returns:
        The number of steps needed to cover every data line (0 when the file
        has no data lines).
    """
    # Use a context manager so the file handle is closed deterministically.
    with open(csv_path) as csv_file:
        line_count = sum(1 for _ in csv_file)
    # Minus one for the header row; clamp so an empty file cannot go negative.
    total_rows = max(line_count - 1, 0)
    num_steps = (total_rows + rows_per_page - 1) // rows_per_page  # Round up
    return num_steps

In [5]:
# --- Hyperparameters and training objects ---
# Step size for the Adam optimizer created below.
learning_rate = 1e-4
# Number of passes of the outer training loop.
num_epochs = 10
# Number of sentence pairs per step; also used as the CSV page size.
batch_size = 32
# Location of the en/fr CSV (see the dataset link in the first cell).
csv_path = "data/en-fr.csv"
# Steps per epoch, derived from the CSV line count; used as the progress-bar total.
num_steps = get_num_steps(csv_path, batch_size)
# Generator yielding ListPairedSentences pages of batch_size rows from the CSV.
data_generator = make_generator(csv_path, batch_size)
# Model vocabulary matches the tokenizer's vocabulary size.
model = Transformer(vocab_size=tokenizer.vocab_size)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
# Padding positions in the targets do not contribute to the loss.
loss_fn = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)

In [6]:
# Training loop
for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0.0
    steps_done = 0  # batches actually processed this epoch

    # A generator can only be consumed once and iter() on it returns the same
    # object, so build a fresh generator at the start of each epoch.
    epoch_batches = make_generator(csv_path, batch_size)
    progress_bar = tqdm(epoch_batches, total=num_steps)
    progress_bar.set_description(f"Epoch {epoch + 1}")

    for step, raw_batch in enumerate(progress_bar, start=1):
        # Prepare data for the model
        training_batch = make_batch(raw_batch)  # Converts batch to `TrainingBatch` format
        input_ids = training_batch.x  # Source sentences token IDs (French)
        target_ids = training_batch.y  # Target sentences token IDs (English)
        encoder_mask = training_batch.encoder_mask  # Mask for encoder

        # Shift target_ids for teacher forcing
        decoder_input_ids = target_ids[:, :-1]  # All except last token as input
        labels = target_ids[:, 1:]  # All except first token as target

        # Forward pass
        optimizer.zero_grad()
        # NOTE(review): nn.CrossEntropyLoss expects unnormalised logits; the name
        # `output_probs` suggests probabilities — confirm the model does not
        # apply a softmax before returning.
        output_probs = model(input_ids, decoder_input_ids, encoder_mask=encoder_mask)

        # Calculate the loss over flattened (batch * sequence) positions.
        # reshape (not view) so a non-contiguous model output is also accepted.
        loss = loss_fn(
            output_probs.reshape(-1, output_probs.size(-1)),
            labels.reshape(-1))
        loss.backward()
        optimizer.step()

        # Track loss
        step_loss = loss.item()
        epoch_loss += step_loss
        steps_done = step

        # Optionally, print progress
        progress_bar.set_postfix_str(f"current loss : {step_loss:.4f} ;"
                                     f"epoch loss : {epoch_loss / step:.4f}")

        # Safety bound: never run more than num_steps batches in one epoch,
        # even if the generator would yield more.
        if step >= num_steps:
            break

    # Leaving the loop early does not close the bar automatically.
    progress_bar.close()

    # Print average loss per epoch, averaged over the batches actually run
    print(f"Epoch [{epoch+1}/{num_epochs}] completed, Average Loss: {epoch_loss / max(steps_done, 1):.4f}")

Epoch 1:   0%|          | 12/703762 [01:14<1219:30:15,  6.24s/it, current loss : 6.9620 ;epoch loss : 8.2534] 


KeyboardInterrupt: 