In [None]:
# https://www.kaggle.com/datasets/dhruvildave/en-fr-translation-dataset

from typing import NamedTuple, Generator, Callable
import pandas as pd
import torch
import torch.nn.functional as F
from torch import Tensor, optim, nn
from tqdm import tqdm
from transformers import AutoTokenizer
from src.transformer import Transformer

# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

In [None]:
class PairedSentences(NamedTuple):
    fr: str
    en: str

class ListPairedSentences(NamedTuple):
    fr: list[str]
    en: list[str]

    def __getitem__(self, index: int) -> PairedSentences:
        return PairedSentences(self.fr[index], self.en[index])

class TrainingBatch(NamedTuple):
    french: Tensor
    english: Tensor
    encoder_mask: Tensor | None
    decoder_mask: Tensor | None

    def __repr__(self):
        return f"TrainingBatch(x.shape={self.french.shape}, y.shape={self.english.shape}, encoder_mask.shape={self.encoder_mask.shape})"

In [None]:
class Tokenerizer:
    """Thin wrapper around a Hugging Face tokenizer with a fixed sequence length.

    By default ``tokenerize`` pads and truncates to ``sequence_length``, so
    encodings of different sentences can be stacked directly.
    """

    def __init__(self, sequence_length: int, name: str) -> None:
        """
        Args:
            sequence_length: ``max_length`` used for every encoding.
            name: model id/path forwarded to ``AutoTokenizer.from_pretrained``.
        """
        self._tokenizer = AutoTokenizer.from_pretrained(name)
        self._seq_length = sequence_length

    def tokenerize(self, text: str, padding="max_length", truncation=True):
        """Encode ``text`` with the wrapped tokenizer, returning PyTorch tensors.

        Args:
            text: sentence to encode.
            padding: forwarded to the tokenizer; the default pads to ``sequence_length``.
            truncation: forwarded to the tokenizer.

        Returns:
            The tokenizer's encoding (``return_tensors="pt"``, ``max_length=sequence_length``).
        """
        return self._tokenizer(
            text,
            return_tensors="pt",
            max_length=self._seq_length,
            padding=padding,
            truncation=truncation)

    def decode(self, token_ids: Tensor, **kwargs) -> str:
        """Decode ``token_ids`` to text; ``kwargs`` are forwarded to the tokenizer's ``decode``."""
        return self._tokenizer.decode(token_ids, **kwargs)

    @property
    def tokenizer(self) -> "AutoTokenizer":
        """The wrapped Hugging Face tokenizer instance."""
        return self._tokenizer

    @property
    def sequence_length(self) -> int:
        """Length every default encoding is padded/truncated to."""
        return self._seq_length

    @property
    def vocab_size(self) -> int:
        """Number of distinct token ids the tokenizer can produce.

        Uses ``len(tokenizer)``, which counts tokens added on top of the base
        vocabulary. ``tokenizer.vocab_size`` excludes those tokens, so a model
        embedding sized by it would be too small when tokens have been added.
        """
        return len(self._tokenizer)

def make_batch(paired_sentences: ListPairedSentences, tokenizer: Tokenerizer) -> TrainingBatch:
    """Turn a page of sentence pairs into a stacked ``TrainingBatch``.

    Every sentence is encoded on its own with the tokenizer's defaults and
    padded/truncated to ``tokenizer.sequence_length``. The per-sentence
    tensors are then stacked along a new leading batch dimension.

    Returns:
        ``TrainingBatch`` with French/English token ids. Its masks are float32
        padding masks (1 = real token, 0 = padding) with two singleton axes
        inserted after the batch axis.
    """
    fr_encodings = [tokenizer.tokenerize(text) for text in paired_sentences.fr]
    en_encodings = [tokenizer.tokenerize(text) for text in paired_sentences.en]

    def stacked(encodings, key):
        # squeeze(0) drops the leading size-1 dim each single-sentence encoding
        # presumably carries, then stack builds the batch dim.
        return torch.stack([encoding[key].squeeze(0) for encoding in encodings])

    def as_padding_mask(encodings):
        # (batch, seq_len) -> (batch, 1, 1, seq_len), cast to float32
        return stacked(encodings, 'attention_mask').unsqueeze(1).unsqueeze(2).float()

    french_ids = stacked(fr_encodings, 'input_ids')
    english_ids = stacked(en_encodings, 'input_ids')
    return TrainingBatch(
        french=french_ids,
        english=english_ids,
        encoder_mask=as_padding_mask(fr_encodings),
        decoder_mask=as_padding_mask(en_encodings),
    )

def get_first_masked_token(mask: torch.Tensor) -> torch.IntTensor:
    """Index of the first padded position in each row of a padding mask.

    Args:
        mask: mask shaped ``(bs, 1, 1, sequence_length)``, where 0 marks padding.

    Returns:
        Tensor of shape ``(bs,)`` holding the index of the first 0 in each row,
        or ``sequence_length`` for rows whose sum equals ``sequence_length``
        (i.e. no padding in a 0/1 mask). The dtype is whatever ``argmax``
        returns (int64), not int32.
    """
    flat_mask = mask.squeeze(1).squeeze(1)  # -> (bs, sequence_length)
    seq_len = flat_mask.size(1)
    # argmax returns the first maximal index, i.e. the first padded slot
    first_pad = (flat_mask == 0).int().argmax(dim=1)
    # A row without padding has argmax 0; report seq_len for it instead.
    no_padding = flat_mask.sum(dim=1) == seq_len
    return torch.where(no_padding, torch.full_like(first_pad, seq_len), first_pad)

def mask_last_token(current_mask: torch.Tensor) -> torch.Tensor:
    """Return a copy of a padding mask with each row's last real token masked.

    Args:
        current_mask: mask shaped ``(bs, 1, 1, sequence_length)`` with 1 for
            real tokens and 0 for padding.

    Returns:
        A new tensor equal to ``current_mask`` except that, in each row, the
        position just before the first padded slot is set to 0. This is the
        last slot when the row has no padding, and slot 0 when the row is all
        padding. ``current_mask`` itself is not modified.
    """
    masked = current_mask.clone()  # don't mutate the caller's tensor
    first_masked_indices = get_first_masked_token(masked)
    # clamp keeps rows that are entirely padding (first index 0) at a valid index
    last_token_indices = torch.clamp(first_masked_indices - 1, min=0)
    # build the row index on the mask's own device
    batch_indices = torch.arange(masked.size(0), device=masked.device)
    masked[batch_indices, 0, 0, last_token_indices] = 0
    return masked

In [None]:
def get_page(csv_path: str, page: int, rows_per_page: int):
    """Read page ``page`` (0-based) of the two-column CSV as a DataFrame.

    The header line and ``page * rows_per_page`` further lines are skipped.
    At most ``rows_per_page`` rows are then parsed, and the columns are named
    ``en`` and ``fr`` by position.
    """
    lines_to_skip = 1 + page * rows_per_page  # header + every earlier page
    return pd.read_csv(
        csv_path,
        skiprows=lines_to_skip,
        nrows=rows_per_page,
        header=None,
        names=["en", "fr"],
    )

def make_generator(csv_path: str, rows_per_page: int) -> Generator[ListPairedSentences, None, None]:
    """Yield the CSV as consecutive pages of up to ``rows_per_page`` sentence pairs.

    The file is streamed once with pandas' chunked reader, so each row is
    parsed exactly once. Iteration ends after the last (possibly shorter) page
    instead of yielding empty pages forever.

    Args:
        csv_path: CSV with one header line and two columns, English then French.
        rows_per_page: number of pairs per page; must be >= 1.

    Yields:
        ``ListPairedSentences`` with aligned ``fr``/``en`` lists.
    """
    # Same parsing options as get_page: skip the header line and name the two
    # columns by position.
    with pd.read_csv(csv_path, skiprows=1, header=None, names=["en", "fr"],
                     chunksize=rows_per_page) as reader:
        for page in reader:
            if page.empty:  # defensive: nothing left to yield
                return
            yield ListPairedSentences(page["fr"].to_list(), page["en"].to_list())


def get_num_steps(csv_path: str, rows_per_page: int) -> int:
    """Number of pages of ``rows_per_page`` rows in ``csv_path``; the last page may be short.

    Physical lines are counted and the header is subtracted. A quoted field
    containing a newline therefore counts more than once, so treat the result
    as an upper bound (it is used for progress reporting).

    Raises:
        ValueError: if ``rows_per_page`` is not positive.
    """
    if rows_per_page <= 0:
        raise ValueError(f"rows_per_page must be positive, got {rows_per_page}")
    # Binary mode: counting lines needs no text decoding. The context manager
    # closes the handle, which the bare open() previously leaked.
    with open(csv_path, "rb") as handle:
        total_rows = max(sum(1 for _ in handle) - 1, 0)  # minus header; empty file -> 0
    return -(-total_rows // rows_per_page)  # ceiling division

In [None]:
# Data source, tokenizer (200-token sequences, BERT uncased vocabulary) and
# a model sized to match it, moved to the selected device.
csv_path = "data/en-fr.csv"
tokenizer = Tokenerizer(sequence_length=200, name="bert-base-uncased")
model = Transformer(
    vocab_size=tokenizer.vocab_size,
    max_sequence_len=tokenizer.sequence_length,
).to(DEVICE)

In [None]:
# Optimizer and loss: Adam over all model parameters; positions holding the
# tokenizer's pad id are ignored by the cross-entropy.
learning_rate = 1e-4
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)
loss_fn = nn.CrossEntropyLoss(
    ignore_index=tokenizer.tokenizer.pad_token_id,
).to(DEVICE)

# Schedule: epochs, pairs per batch, and the estimated number of batches per epoch.
num_epochs, batch_size = 1, 32
num_steps = get_num_steps(csv_path, rows_per_page=batch_size)

In [None]:
# Training loop
for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0.0
    steps_done = 0  # optimizer steps actually taken this epoch

    # A fresh generator each epoch restarts reading from the top of the CSV.
    progress_bar = tqdm(make_generator(csv_path, batch_size), total=num_steps)
    progress_bar.set_description(f"Epoch {epoch + 1}")

    for step, raw_batch in enumerate(progress_bar, start=1):
        # A page can come back empty once the data is exhausted; torch.stack in
        # make_batch cannot handle that, so end the epoch here.
        if len(raw_batch.fr) == 0:
            break

        # Prepare data for the model
        training_batch = make_batch(raw_batch, tokenizer)
        input_ids = training_batch.french.to(DEVICE)  # source token ids (French)
        target_ids = training_batch.english.to(DEVICE)  # target token ids (English)
        encoder_mask = training_batch.encoder_mask.to(DEVICE)
        decoder_mask = training_batch.decoder_mask.to(DEVICE)

        # Forward pass
        optimizer.zero_grad()
        output_probs = model(input_ids,
                             target_ids,
                             encoder_mask=encoder_mask,
                             decoder_mask=mask_last_token(decoder_mask))  # mask last token for teacher forcing
        # NOTE(review): the loss compares output position t with target_ids[:, t]
        # (no shift). That is only correct teacher forcing if the decoder in
        # src.transformer cannot attend to position t when predicting it — confirm.
        # NOTE(review): CrossEntropyLoss expects unnormalised logits; confirm the
        # model does not already apply softmax.

        # Flatten to (N, vocab) vs (N,) for the loss; reshape tolerates
        # non-contiguous outputs, where view would raise.
        target_ids_flat = target_ids.reshape(-1)
        output_probs_flat = output_probs.reshape(-1, output_probs.size(-1))

        # Calculate the loss and update the weights
        loss = loss_fn(output_probs_flat, target_ids_flat)
        loss.backward()
        optimizer.step()

        # Track loss (read the scalar once to avoid a second device sync)
        loss_value = loss.item()
        epoch_loss += loss_value
        steps_done += 1

        progress_bar.set_postfix_str(f"current loss : {loss_value:.4f} ; "
                                     f"epoch loss : {epoch_loss / step:.4f}")
    progress_bar.close()

    # Average over the steps that actually ran; num_steps is only an estimate
    # from line counting.
    average_loss = epoch_loss / max(steps_done, 1)
    print(f"Epoch [{epoch+1}/{num_epochs}] completed, Average Loss: {average_loss:.4f}")

In [None]:
@torch.no_grad()
def infer(model: nn.Module, tokenerizer: Tokenerizer, french_sentence: str, max_length: int | None = None) -> str:
    """Greedily translate one French sentence to English.

    Args:
        model: model called as ``model(encoder_ids, target_ids,
            encoder_mask=..., decoder_mask=...)``, as in training.
        tokenerizer: the tokenizer used during training.
        french_sentence: source text.
        max_length: number of decoder positions to fill, including the start
            token. Defaults to, and is capped at, ``tokenerizer.sequence_length``.

    Returns:
        The generated English text, decoded with special tokens skipped.
    """
    # first model as eval (we don't train here)
    model.eval()
    buffer_len = tokenerizer.sequence_length
    # Cap at the buffer size: larger values would index past target_ids.
    sequence_length = buffer_len if max_length is None else min(max_length, buffer_len)
    hf_tokenizer = tokenerizer.tokenizer

    # Tokenize the input sequence in french
    tokens = tokenerizer.tokenerize(french_sentence)

    # Encoder input and mask, with the dtypes make_batch produces for training:
    # int64 ids and a float32 padding mask with two leading singleton axes added.
    encoder_ids = tokens["input_ids"].long().to(DEVICE)
    encoder_mask = tokens["attention_mask"].unsqueeze(0).unsqueeze(0).to(device=DEVICE, dtype=torch.float32)

    # Decoder state: only position 0 (start token) is visible. The rest of the
    # buffer holds the pad id (0 if the tokenizer defines none) until generated.
    pad_id = hf_tokenizer.pad_token_id if hf_tokenizer.pad_token_id is not None else 0
    decoder_mask = torch.zeros((1, 1, 1, buffer_len), device=DEVICE)
    decoder_mask[0, 0, 0, 0] = 1  # unmask the start of sequence token
    target_ids = torch.full((1, buffer_len), pad_id, dtype=torch.long, device=DEVICE)
    target_ids[0, 0] = hf_tokenizer.cls_token_id  # start of sequence

    # loop to generate output ids
    for idx in range(1, sequence_length):
        output_probs: Tensor = model(
            encoder_ids,
            target_ids,
            encoder_mask=encoder_mask,
            decoder_mask=decoder_mask)

        # Greedy choice at the position being generated (same as argmax over
        # the full output, then indexing [0, idx]).
        next_token_id = output_probs[0, idx].argmax(dim=-1)
        target_ids[0, idx] = next_token_id
        decoder_mask[0, 0, 0, idx] = 1  # unmask the generated token

        # early stop when encounter sep_token_id
        if next_token_id.item() == hf_tokenizer.sep_token_id:
            break

    return tokenerizer.decode(target_ids[0], skip_special_tokens=True)

output = infer(model=model, tokenerizer=tokenizer, french_sentence="ahah")