<a target="_blank" href="https://colab.research.google.com/github/pr-Mais/arabic-reverse-dictionary/blob/main/code/mt5_training_shared_task.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

# Fine-tuning MT5 for Reverse Dictionary

In [None]:
# When this notebook runs inside Google Colab, mount Google Drive at
# /content/drive (the save cell writes the checkpoint under
# /content/drive/MyDrive). Outside Colab nothing is mounted.
import sys

# True when the `google.colab` module is already loaded in this interpreter,
# which is how the notebook detects that it is running on Colab.
IN_COLAB = "google.colab" in sys.modules

if IN_COLAB:
    from google.colab import drive

    drive.mount("/content/drive")

## Requirements

We require the following libraries for the modeling process:

In [None]:
# Install the pinned library versions this notebook was written against
# (-q: quiet output, -U: upgrade an already-installed older version).
!pip3 install -q -U datasets==2.18.0
!pip3 install -q -U transformers==4.39.2
!pip3 install -q -U evaluate==0.4.1
!pip3 install -q -U scikit-learn==1.4.1.post1
!pip3 install -q -U torch==2.2.1
!pip3 install -q -U tokenizers==0.15.2
!pip3 install -q -U tqdm==4.66.2
!pip3 install -q -U pandas==2.2.1
!pip3 install -q -U numpy==1.26.4

Next, we import all required libraries and modules.

In [None]:
import pandas as pd
from tqdm import tqdm
from typing import Literal
from datasets import Dataset, DatasetDict

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

from transformers import (
    AutoConfig,
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    AdamW,
    set_seed,
    get_linear_schedule_with_warmup,
)
from transformers.modeling_outputs import Seq2SeqLMOutput

# Seed the random number generators (via `transformers.set_seed`) so runs
# are reproducible.
set_seed(123)

# Pretrained checkpoint the tokenizer and model are built from.
checkpoint = "UBC-NLP/AraT5v2-base-1024"

# Prefer the GPU when one is available.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# The three kinds of target embeddings a model can be trained to predict.
TargetEmbeddingType = Literal["electra", "bertseg", "bertmsa"]

# Dimensionality of each target embedding. These values must equal the length
# of the embedding vectors stored in the dataset, because the model's output
# layer is sized from them.
embedding_dims = dict(electra=256, bertseg=768, bertmsa=768)

# Which target embedding to use as the target for the model?
target_embedding: TargetEmbeddingType = "bertseg"
# Kept under the historical name `max_length`, which later cells read.
max_length = embedding_dims[target_embedding]

### Loading and preprocessing data

In [None]:
# All shared-task splits live in the same data folder.
data_dir = "../data/shared-task"
train_ds_path = f"{data_dir}/train_with_examples.json"
val_ds_path = f"{data_dir}/dev.json"
test_ds_path = f"{data_dir}/test.json"

In [None]:
# Load each split (UTF-8 JSON) into a DataFrame, in train/val/test order.
train_df, val_df, test_df = (
    pd.read_json(path, encoding="utf-8")
    for path in (train_ds_path, val_ds_path, test_ds_path)
)

In [None]:
# Report the size and column names of every split, separated by blank lines.
_split_frames = [("Train", train_df), ("Validation", val_df), ("Test", test_df)]
for _position, (_split_name, _frame) in enumerate(_split_frames):
    if _position > 0:
        print()
    print(f"{_split_name} dataset has {len(_frame)} examples, and the following columns:")
    print(_frame.columns.tolist())

In [None]:
# Raw train/val records as plain Python dicts.
# NOTE(review): nothing visible in this notebook reads `train_val_dict`;
# the DatasetDict below is what the pipeline uses.
train_val_dict = dict(
    train=train_df.to_dict("records"),
    val=val_df.to_dict("records"),
)

# Wrap every DataFrame in a Hugging Face Dataset. Only train and val go into
# the DatasetDict used for training/development; the test split has no targets.
train_ds = Dataset.from_pandas(train_df, split="train")
val_ds = Dataset.from_pandas(val_df, split="validation")
test_ds = Dataset.from_pandas(test_df, split="test")
dataset = DatasetDict({"train": train_ds, "val": val_ds})

In the next steps, we prepare the data for modeling. 

The features we care about from the dataset are the `gloss` and `examples`, and the target is one of `electra`, `bertseg` or `bertmsa`. This means we train 3 different models, one per target embedding.

In the preprocessing step, we will tokenize the data and convert it to a format that can be fed into the model. This includes merging the `gloss` and `examples` into a single string, tokenizing the string, and converting the tokens to token ids.

In [None]:
# Model inputs (gloss + examples) are padded/truncated to a fixed length.
max_input_length = 256
padding = "max_length"

# Tokenizer for the same checkpoint the model is built from.
tokenizer = AutoTokenizer.from_pretrained(checkpoint, legacy=False)


def _format_input(gloss, examples):
    """Join a gloss with (up to) its first two usage examples."""
    # `examples` may be missing (None) or shorter than two entries.
    picked = [str(example) for example in (examples or [])[:2]]
    if not picked:
        return gloss
    return f"{gloss}. " + "، ".join(picked)


def preprocess_function(items):
    """Tokenize a batch of entries for ``Dataset.map(batched=True)``.

    Args:
        items: Batch mapping column names to lists. Reads ``gloss`` and
            ``word``; ``examples``, ``electra``, ``bertseg`` and ``bertmsa``
            are used when present.

    Returns:
        A mapping with ``input_ids`` and ``attention_mask`` for the model
        input, the target-embedding columns present in ``items`` (copied
        unchanged), and ``labels``: the tokenized ``word`` with padding
        positions replaced by -100.
    """
    # The inputs are the glosses, followed by usage examples when available.
    if "examples" in items:
        texts = [
            _format_input(gloss, examples)
            for gloss, examples in zip(items["gloss"], items["examples"])
        ]
    else:
        texts = items["gloss"]

    # Plain Python lists (no return_tensors): `Dataset.map` stores the result
    # in Arrow, and `set_format(type="torch")` converts columns to tensors.
    model_inputs = tokenizer(
        texts,
        max_length=max_input_length,
        padding=padding,
        truncation=True,
    )

    # Carry the target embeddings through unchanged, if they are available.
    for name in ("electra", "bertseg", "bertmsa"):
        if name in items:
            model_inputs[name] = items[name]

    # NOTE(review): `max_length` is the target-embedding dimensionality
    # (256/768), not a natural token limit for a single word. The MSE loss in
    # `train` does not use the labels, so this only costs memory. Confirm
    # before shrinking it.
    labels = tokenizer(
        list(items["word"]),
        max_length=max_length,
        padding=padding,
        truncation=True,
    ).input_ids

    # Replace padding ids by -100 so a CrossEntropyLoss would ignore them.
    pad_id = tokenizer.pad_token_id
    model_inputs["labels"] = [
        [-100 if token == pad_id else token for token in sequence]
        for sequence in labels
    ]

    return model_inputs


# Final mapping of the dataset: tokenize every split (train and val) in
# batches with `preprocess_function`, adding its output columns.
dataset = dataset.map(preprocess_function, batched=True)

Finally, data splits are converted into PyTorch `DataLoader` objects for training.

In [None]:
# Return model inputs, token labels and all three target embeddings as torch
# tensors; columns not listed here are left out of the batches.
torch_columns = ["input_ids", "attention_mask", "labels", "electra", "bertseg", "bertmsa"]
dataset.set_format(type="torch", columns=torch_columns, output_all_columns=False)

# Only the training split is shuffled; validation keeps its original order.
train_dataloader = DataLoader(dataset["train"], batch_size=8, shuffle=True)
valid_dataloader = DataLoader(dataset["val"], batch_size=4)

In [None]:
# Report the dimensionality of each target embedding stored in the train split.
_embedding_columns = (("Electra", "electra"), ("BERTseg", "bertseg"), ("BERTmsa", "bertmsa"))
print(
    "".join(
        f"{label} embeddings shape: 1, {dataset['train'][column].shape[1]}\n"
        for label, column in _embedding_columns
    )
)

## Training

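The training pipeline accepts 3 types of targets: `electra`, `bertseg`, and `bertmsa`. The training process is the same for all targets: the model's pooled encoder output is compared with the chosen target embedding using the mean squared error (MSE) loss.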

In [None]:
# Loss function: mean-squared error between the model's output embedding and
# the target embedding (used by `train` below).
mse_loss = nn.MSELoss()


def train(
    dataloader,
    optimizer_,
    scheduler_,
    device_,
    target: TargetEmbeddingType = "electra",
    validate=False,
):
    """Run one epoch of training (or validation) on the module-level ``model``.

    The model output for each example is regressed onto its ``target``
    embedding with the module-level ``mse_loss``.

    Args:
        dataloader: Iterable of batch dicts holding ``input_ids``,
            ``attention_mask``, ``labels`` and the ``target`` column.
        optimizer_: Optimizer stepped once per batch (unused when validating).
        scheduler_: LR scheduler stepped once per batch (unused when validating).
        device_: Device every batch tensor is moved to.
        target: Embedding column to regress onto; its size must match the
            model's output size.
        validate: If True, run in eval mode with gradients disabled and no
            parameter updates.

    Returns:
        ``(ground_truth, predictions, avg_epoch_loss)``: per-example target and
        predicted embeddings as nested lists, and the mean batch loss.
    """
    predictions = []
    ground_truth = []
    total_loss = 0.0

    # train(False) is equivalent to eval().
    model.train(not validate)

    # No autograd graph is needed when validating; this saves memory and time.
    with torch.set_grad_enabled(not validate):
        for batch in tqdm(dataloader):
            ground_truth += batch[target].tolist()
            inputs = {
                k: v.to(device_)
                for k, v in batch.items()
                if k in ("input_ids", "attention_mask")
            }
            # Labels are forwarded for call compatibility; the loss below
            # is computed on the target embeddings only.
            labels = batch["labels"].to(device_)
            if not validate:
                model.zero_grad()
            embeddings = model(**inputs, labels=labels, return_dict=True)
            # Loss is calculated on target embeddings, outside the model.
            loss = mse_loss(embeddings, batch[target].to(device_))
            total_loss += loss.item()
            if not validate:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer_.step()
                scheduler_.step()
            predictions += embeddings.detach().tolist()

    # Average loss over the batches of this epoch.
    avg_epoch_loss = total_loss / len(dataloader)

    return ground_truth, predictions, avg_epoch_loss

In [None]:
class RevDictModel(nn.Module):
    """Encoder-based regressor mapping a definition to a target word embedding.

    The encoder of a seq2seq checkpoint produces token states. These are
    mean-pooled over non-padding tokens and projected by a linear layer to
    the size of the target embedding.
    """

    def __init__(self, max_length: int, checkpoint: str):
        """
        Args:
            max_length: Output dimensionality, i.e. the length of the target
                embedding vectors (name kept for backward compatibility).
            checkpoint: Hugging Face model id or local path of the pretrained
                seq2seq model.
        """
        super().__init__()
        # Load the *pretrained* weights. `from_config` only builds the
        # architecture with random weights, which trains from scratch instead
        # of fine-tuning.
        self.base_model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)

        # Projection from the encoder hidden size to the target embedding size.
        self.linear = nn.Linear(self.base_model.config.hidden_size, max_length)

    def forward(self, input_ids, attention_mask, **kwargs):
        """Return one embedding per input sequence.

        Extra keyword arguments (e.g. ``labels``, ``return_dict``) are
        accepted for call-site compatibility and ignored.
        """
        # Only the encoder part is used to generate embeddings.
        outputs = self.base_model.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True,
        )
        hidden = outputs.last_hidden_state
        mask = attention_mask.unsqueeze(-1).to(hidden.dtype)
        # Masked mean pooling. The clamp guards against a row with no
        # attended tokens (division by zero).
        summed = (hidden * mask).sum(dim=1)
        counts = mask.sum(dim=1).clamp(min=1.0)
        pooled_emb = summed / counts
        return self.linear(pooled_emb)

    def save(self, file):
        """Serialize the whole module (architecture and weights) with torch.save."""
        torch.save(self, file)

    @staticmethod
    def load(file):
        """Load a module written by ``save`` onto the module-level ``device``.

        SECURITY: this unpickles arbitrary Python objects, so only load files
        you trust. ``weights_only=False`` is needed to restore a full module.
        It is passed explicitly because newer torch releases default to True.
        """
        return torch.load(file, map_location=device, weights_only=False)

### Model training per target

In [None]:
# Hyperparameters: number of passes over the training set, and the learning
# rate handed to the optimizer.
epochs, lr = 5, 4e-5

In [None]:
# Build the model with an output size equal to the selected target
# embedding's dimensionality, then move it to the training device.
model = RevDictModel(max_length=max_length, checkpoint=checkpoint)
model.to(device)

In [None]:
# Total number of training steps is number of batches * number of epochs.
# `train_dataloader` contains batched data, so `len(train_dataloader)` is the
# number of batches (= optimizer/scheduler steps) per epoch.
total_steps = len(train_dataloader) * epochs

# torch.optim.AdamW replaces the deprecated `transformers.AdamW`.
# weight_decay=0.0 keeps the transformers default (torch defaults to 0.01).
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=lr,
    eps=1e-8,
    weight_decay=0.0,
)

scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=0,
    num_training_steps=total_steps,
)

# Store the average loss after each epoch so we can plot them.
all_loss = {"train_loss": [], "val_loss": []}

for epoch in range(epochs):
    print(f"Epoch {epoch}")
    train_labels, train_predict, train_loss = train(
        train_dataloader, optimizer, scheduler, device, target=target_embedding
    )
    # Validate against the SAME target the model is trained on. The default
    # target ("electra", 256-d) does not match e.g. a "bertseg" (768-d) model.
    valid_labels, valid_predict, val_loss = train(
        valid_dataloader,
        optimizer,
        scheduler,
        device,
        target=target_embedding,
        validate=True,
    )

    # Print loss values to see how training evolves.
    print(f"  train_loss: {train_loss:.5f} - val_loss: {val_loss:.5f}")
    print()

    # Store the loss value for plotting the learning curve.
    all_loss["train_loss"].append(train_loss)
    all_loss["val_loss"].append(val_loss)

# Plot loss curves.
# plot_dict(all_loss, use_xlabel="Epochs", use_ylabel="Value", use_linestyles=["-", "--"])

Save the trained model.

In [None]:
# Serialize the trained model to Google Drive. The file name encodes the
# target embedding the model was trained on.
# NOTE(review): despite its name this is a file path (passed to torch.save
# via `RevDictModel.save`), and the evaluation cell reloads from a different
# path ("../checkpoints/mt5_shared_task_checkpoint_0"). Confirm which one is
# intended.
save_directory = f"/content/drive/MyDrive/mt5_{target_embedding}_checkpoint_0"
model.save(save_directory)

## Evaluation

In [None]:
# re-load the model
# NOTE(review): this path differs from the one written by the save cell
# (/content/drive/MyDrive/mt5_<target>_checkpoint_0) and does not encode the
# target embedding. Confirm that this checkpoint was trained on
# `target_embedding`, otherwise the evaluation compares vectors of different
# sizes.
model = RevDictModel.load("../checkpoints/mt5_shared_task_checkpoint_0")

In [None]:
# Embed every validation example with the loaded model.
# Switch to eval mode explicitly so dropout does not affect inference. A
# loaded checkpoint may keep whichever train/eval flag it was saved with.
model.eval()
predictions = []
with torch.no_grad():
    for sample in tqdm(valid_dataloader):
        inputs = {
            k: v.to(device)
            for k, v in sample.items()
            if k in ("input_ids", "attention_mask")
        }
        outputs = model(**inputs)
        # Extending with a 2-D tensor appends one 1-D embedding per example.
        predictions += outputs

### P@K with Cosine Similarity

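Here we compare the model's output embeddings for the validation set with the ground-truth target embeddings using cosine similarity. For each ground-truth embedding, we keep the 3 highest similarity scores among all predicted embeddings and then average them. Note that this is an average similarity score rather than a strict precision@k, which would instead check whether the ground-truth word appears among the top-k retrieved words.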

In [None]:
from torch.nn import functional as F
from typing import List


# Calculate the top k cosine similarities between an embedding and a set of
# predicted embeddings.
def get_top_n(emb: torch.Tensor, predictions: List[torch.Tensor], k: int = 5):
    """Return the ``k`` highest cosine similarities between ``emb`` and ``predictions``.

    Args:
        emb: 1-D reference embedding.
        predictions: Sequence of 1-D embeddings of the same length as ``emb``.
        k: Number of scores to return.

    Returns:
        Up to ``k`` similarity scores as Python floats, in descending order.
        Fewer are returned if ``predictions`` has fewer than ``k`` items, and
        ``[]`` if it is empty.
    """
    if len(predictions) == 0 or k <= 0:
        return []
    candidates = torch.stack(list(predictions))
    # Score all candidates in one vectorized call instead of one call per item.
    # `emb` is moved to wherever the predictions live.
    scores = F.cosine_similarity(
        emb.to(candidates.device).unsqueeze(0), candidates, dim=1
    )
    return torch.topk(scores, min(k, scores.numel())).values.tolist()

In [None]:
import itertools

# Number of validation examples to score; set to None to score all of them.
# NOTE(review): 5 examples is far too few for a representative number.
max_eval_samples = 5

scores = []
for item in itertools.islice(valid_dataloader.dataset, max_eval_samples):
    # Must be the same target the model was trained on. Embeddings of a
    # different target (e.g. electra, 256-d) cannot be compared with
    # bertseg-sized (768-d) predictions.
    emb = item[target_embedding]
    scores.append(get_top_n(emb, predictions, k=3))

In [None]:
# Mean score at each rank position (best, 2nd best, ...) across the scored
# examples, then the mean over those positions.
average_at_k = []
for rank_scores in zip(*scores):
    average_at_k.append(sum(rank_scores) / len(rank_scores))
average = sum(average_at_k) / len(average_at_k)
average

# Use it