# Fine-tuning MT5 for Reverse Dictionary

## Requirements

In [None]:
!pip3 install -q -U datasets==2.18.0
!pip3 install -q -U accelerate==0.27.1
!pip3 install -q -U datasets==2.17.0
!pip3 install -q -U transformers==4.39.2
!pip3 install -q -U evaluate==0.4.1
!pip3 install -q -U scikit-learn==1.4.1.post1
!pip3 install -q -U pytorch_lightning==2.2.2

In [None]:
import pandas as pd
from tqdm import tqdm

from datasets import Dataset, DatasetDict

import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers import (
    AutoConfig,
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    AdamW,
    set_seed,
    get_linear_schedule_with_warmup,
)
from transformers.modeling_outputs import Seq2SeqLMOutput

# Seed the random number generators so repeated runs are comparable.
set_seed(123)

# --- Run configuration ---
# Checkpoint identifier handed to the tokenizer and model loaders.
checkpoint = "UBC-NLP/AraT5v2-base-1024"
# Number of epochs the training loop iterates over.
epochs = 3
# Learning-rate setting.
lr = 5e-4
# Size of the shared-task target embeddings; it must match their length.
# The same value also caps the tokenized label length during preprocessing.
max_length = 256
# Use the GPU when PyTorch can see one, otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

### Loading and preprocessing data

In [None]:
# Load the shared-task splits; each JSON file becomes one DataFrame.
# The preprocessing further down reads the `gloss` and `word` columns and,
# when present, `examples` and the `electra` / `bertseg` / `bertmsa` columns.
train_df = pd.read_json(
    "../data/shared-task/train_with_examples.json", encoding="utf-8"
)
val_df = pd.read_json("../data/shared-task/dev.json", encoding="utf-8")
test_df = pd.read_json("../data/shared-task/test.json", encoding="utf-8")

In [None]:
# Preview the first rows of the training frame as a quick sanity check.
train_df.head()

In [None]:
# Collect the train and validation frames in a single dict keyed by split
# name ("train" / "val"); each value is that frame as a list of row records.
train_val_dict = {
    split_name: frame.to_dict("records")
    for split_name, frame in (("train", train_df), ("val", val_df))
}

In [None]:
# Convert to HF dataset
# Each DataFrame becomes its own `datasets.Dataset`, tagged with a split name.
train_ds = Dataset.from_pandas(train_df, split="train")
val_ds = Dataset.from_pandas(val_df, split="validation")
test_ds = Dataset.from_pandas(test_df, split="test")
# NOTE: only the train and validation sets are placed in the DatasetDict
# (under the keys "train" and "val"). `test_ds` is therefore not affected by
# the `map` / `set_format` calls applied to `dataset` later on.
dataset = DatasetDict({"train": train_ds, "val": val_ds})

In [None]:
# Take a look at the dataset: display the DatasetDict built above
# (its "train" and "val" entries).
dataset

In [None]:
# Tokenization step
# Load the tokenizer that belongs to the checkpoint being fine-tuned.
tokenizer = AutoTokenizer.from_pretrained(checkpoint, legacy=False)

# Pad every sequence to a fixed length instead of the longest item in a batch.
padding = "max_length"
# Token budget for the input text; longer inputs are truncated to this length.
max_input_length = 256


def preprocess_function(items):
    """Tokenize a batch of glosses and attach the training targets.

    Intended for ``Dataset.map(..., batched=True)``: every value in ``items``
    is a list holding one entry per example.

    Args:
        items: Batch dict. Must contain ``"gloss"``. May contain
            ``"examples"`` (per entry, an indexable collection of usage
            examples; up to the first two are appended to the gloss),
            ``"word"`` (tokenized into ``labels``) and any of the
            ``"electra"`` / ``"bertseg"`` / ``"bertmsa"`` target embeddings.

    Returns:
        The tokenizer output for the inputs, plus whichever target-embedding
        columns are present and, when ``"word"`` is present, ``"labels"`` with
        padding positions replaced by -100.
    """
    # The inputs are the glosses, optionally followed by up to two examples.
    if "examples" in items:
        glosses = []
        for gloss, example in zip(items["gloss"], items["examples"]):
            # Slicing instead of indexing [0] and [1] keeps entries with fewer
            # than two examples from raising IndexError.
            picked = [str(text) for text in example[:2]]
            joined = "، ".join(picked)
            glosses.append(f"{gloss}. {joined}" if picked else gloss)
    else:
        glosses = items["gloss"]

    model_inputs = tokenizer(
        glosses,
        max_length=max_input_length,
        padding=padding,
        truncation=True,
    )

    # Carry over the target embeddings that are available in this batch.
    for embedding_name in ("electra", "bertseg", "bertmsa"):
        if embedding_name in items:
            model_inputs[embedding_name] = items[embedding_name]

    # Splits without the gold word get no labels.
    if "word" in items:
        labels = tokenizer(
            list(items["word"]),
            max_length=max_length,
            padding=padding,
            truncation=True,
        ).input_ids

        # Replace padding ids by -100 so CrossEntropyLoss ignores them. The pad
        # id is read from the tokenizer rather than assumed to be 0.
        pad_id = tokenizer.pad_token_id
        model_inputs["labels"] = [
            [token if token != pad_id else -100 for token in label_ids]
            for label_ids in labels
        ]

    return model_inputs


# Final mapping of the dataset
# `batched=True` hands `preprocess_function` a dict of column lists per call;
# the result replaces `dataset`.
dataset = dataset.map(preprocess_function, batched=True)

In [None]:
from torch.utils.data import DataLoader

# Return only the model inputs, the labels and the three target-embedding
# columns, formatted as torch tensors; all other columns are hidden.
dataset.set_format(
    "torch",
    columns=["input_ids", "attention_mask", "labels", "electra", "bertseg", "bertmsa"],
    output_all_columns=False,
)

# Training batches are shuffled; validation batches keep the dataset order.
train_dataloader = DataLoader(dataset=dataset["train"], batch_size=8, shuffle=True)
valid_dataloader = DataLoader(dataset=dataset["val"], batch_size=4)

In [None]:
# Report the width of each target-embedding column of the training split.
# Each column is indexed by name and `.shape[1]` (its second dimension) is
# printed; the leading "1," in the message is literal text.
print(
    f"Electra embeddings shape: 1, {dataset['train']['electra'].shape[1]}\n"
    f"BERTseg embeddings shape: 1, {dataset['train']['bertseg'].shape[1]}\n"
    f"BERTmsa embeddings shape: 1, {dataset['train']['bertmsa'].shape[1]}\n"
)

## Training

In [None]:
from typing import Literal
import torch.nn as nn

# Loss functions
# Mean-squared error, used between predicted and target embeddings.
mse_loss = nn.MSELoss()

# Names of the target-embedding columns that can be selected for training.
TargetEmbeddingType = Literal["electra", "bertseg", "bertmsa"]


def train(
    dataloader,
    optimizer_,
    scheduler_,
    device_,
    target: TargetEmbeddingType = "electra",
    validate=False,
):
    """Run one pass over ``dataloader`` in training or validation mode.

    The module-level ``model`` is called with ``input_ids``,
    ``attention_mask`` and ``labels`` and must return the predicted
    embeddings. The loss is the MSE between those predictions and
    ``batch[target]``, computed outside the model.

    Args:
        dataloader: Iterable of batches (dicts of tensors) containing
            ``"input_ids"``, ``"attention_mask"``, ``"labels"`` and the
            ``target`` column.
        optimizer_: Optimizer stepped after every training batch.
        scheduler_: Learning-rate scheduler stepped after every training batch.
        device_: Device that inputs, labels and targets are moved to.
        target: Name of the target-embedding column to regress onto.
        validate: When true, the model is put in eval mode, gradients are
            disabled and neither the optimizer nor the scheduler is stepped.

    Returns:
        ``(ground_truth, predictions, avg_epoch_loss)``: the target and
        predicted embeddings as nested Python lists, and the mean per-batch
        loss (``0.0`` when the dataloader yields no batches).
    """
    # Tracking variables.
    predictions = []
    ground_truth = []
    total_loss = 0.0
    num_batches = 0

    # `train(False)` is equivalent to `eval()`.
    model.train(not validate)

    # No autograd graph is needed while validating.
    grad_mode = torch.no_grad() if validate else torch.enable_grad()
    with grad_mode:
        for batch in tqdm(dataloader):
            ground_truth += batch[target].numpy().tolist()
            inputs = {
                name: batch[name].to(device_)
                for name in ("input_ids", "attention_mask")
            }
            # Use the `device_` argument, not the notebook-level `device`.
            labels = batch["labels"].to(device_)

            if not validate:
                model.zero_grad()

            embeddings = model(**inputs, labels=labels, return_dict=True)

            # The targets must live on the same device (and have the same
            # dtype) as the predictions for the loss to be computable.
            targets = batch[target].to(device=device_, dtype=embeddings.dtype)
            loss = mse_loss(embeddings, targets)
            total_loss += loss.item()
            num_batches += 1

            if not validate:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer_.step()
                scheduler_.step()

            predictions += embeddings.detach().cpu().tolist()
            # The leftover debugging `break` was removed here: it stopped the
            # epoch after a single batch while the loss was still averaged
            # over the full dataloader length.

    # Average over the batches that were actually processed.
    avg_epoch_loss = total_loss / num_batches if num_batches else 0.0

    # Return all targets and predictions for later evaluation.
    return ground_truth, predictions, avg_epoch_loss

In [None]:
class RevDictModel(nn.Module):
    """Encoder of a seq2seq checkpoint plus a linear head predicting an embedding.

    The input tokens are encoded, the encoder states are mean-pooled over the
    non-padding positions, and the pooled vector is projected to ``max_len``
    dimensions.
    """

    def __init__(self, from_pretrained: bool, max_len: int, checkpoint: str):
        """Build the backbone and the projection head.

        Args:
            from_pretrained: Load the pretrained weights of ``checkpoint`` when
                true; otherwise build the architecture from its config only.
            max_len: Output size of the linear head (target embedding size).
            checkpoint: Checkpoint name or path understood by the ``Auto*``
                loaders.
        """
        super().__init__()
        if from_pretrained:
            self.base_model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
        else:
            model_config = AutoConfig.from_pretrained(checkpoint)
            self.base_model = AutoModelForSeq2SeqLM.from_config(model_config)

        # Linear layer mapping the pooled hidden state to the target
        # embedding size (max_len).
        self.linear = nn.Linear(self.base_model.config.hidden_size, max_len)

    def forward(self, input_ids, attention_mask, labels=None, **kwargs):
        """Return one predicted embedding per input sequence.

        Args:
            input_ids: Token ids, indexed as (batch, sequence).
            attention_mask: Mask with the same shape as ``input_ids``; zero
                marks padding.
            labels: Accepted for call compatibility and ignored; only the
                encoder is run.
            **kwargs: Ignored.

        Returns:
            The linear head's output for the mean-pooled encoder states.
        """
        # Only the encoder output is used, so the decoder is not run at all.
        encoder_outputs = self.base_model.get_encoder()(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True,
        )
        hidden_states = encoder_outputs.last_hidden_state

        # Mean over real tokens only; the clamp avoids a division by zero for
        # a row whose mask is entirely zero.
        mask = attention_mask.unsqueeze(-1).to(hidden_states.dtype)
        token_counts = mask.sum(dim=1).clamp(min=1)
        pooled_emb = (hidden_states * mask).sum(dim=1) / token_counts
        return self.linear(pooled_emb)

    def save(self, file):
        """Pickle the whole module (architecture and weights) to ``file``."""
        torch.save(self, file)

    @staticmethod
    def load(file):
        """Load a module written by :meth:`save`.

        SECURITY: this unpickles arbitrary Python objects. Only load files
        from a trusted source.
        """
        # A full module is pickled, so the weights-only loader cannot be used.
        return torch.load(file, weights_only=False)

In [None]:
# Build the model from the pretrained checkpoint; `max_len` sets the output
# size of its linear head.
model = RevDictModel(from_pretrained=True, max_len=max_length, checkpoint=checkpoint)
# Move the model's parameters to the selected device.
model.to(device)

In [None]:
# Optimizer over every parameter of the wrapper model.
# NOTE(review): the step size is fixed right here rather than taken from the
# notebook-level `lr` setting. Confirm which value is intended.
optimizer = AdamW(params=model.parameters(), lr=2e-5, eps=1e-8)

# `len(train_dataloader)` is the number of batches per epoch, so the schedule
# spans (batches per epoch) x (epochs) steps in total.
total_steps = epochs * len(train_dataloader)

# Linear learning-rate schedule without a warm-up phase.
scheduler = get_linear_schedule_with_warmup(
    optimizer, num_warmup_steps=0, num_training_steps=total_steps
)

# Average loss per epoch, kept so a learning curve can be plotted afterwards.
all_loss = {"train_loss": [], "val_loss": []}

print("Epoch")
for epoch in tqdm(range(epochs)):
    # Training pass.
    print()
    print("Training on batches...")
    train_labels, train_predict, train_loss = train(
        dataloader=train_dataloader,
        optimizer_=optimizer,
        scheduler_=scheduler,
        device_=device,
    )

    # Validation pass: same routine, with `validate=True`.
    print("Validation on batches...")
    valid_labels, valid_predict, val_loss = train(
        dataloader=valid_dataloader,
        optimizer_=optimizer,
        scheduler_=scheduler,
        device_=device,
        validate=True,
    )

    # Report both losses to follow how training evolves.
    print("  train_loss: %.5f - val_loss: %.5f" % (train_loss, val_loss))
    print()

    # Record the losses of this epoch.
    all_loss["train_loss"].append(train_loss)
    all_loss["val_loss"].append(val_loss)

# `all_loss` can be handed to a plotting helper to draw the loss curves.

In [None]:
save_directory = "/content/drive/MyDrive/mt5_shared_task_checkpoint_0"
# RevDictModel keeps the Hugging Face model in `base_model`; it has no `model`
# attribute, so `model.model` raised AttributeError.
# NOTE: this stores the backbone only. The linear head (`model.linear`) is not
# part of the saved directory; use `model.save(...)` to keep the whole module.
model.base_model.save_pretrained(save_directory)

## Evaluation

In [None]:
import torch

# Same device selection as at the top of the notebook, repeated for the
# evaluation cells: GPU when available, CPU otherwise.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Re-load the fine-tuned checkpoint for evaluation.
# RevDictModel defines no `from_pretrained`, so the previous calls raised
# AttributeError. The evaluation cells need a model exposing `.generate()` and
# an encoder whose output has `.last_hidden_state`, which is what the plain
# Hugging Face seq2seq model and its encoder provide.
# NOTE(review): assumes "../checkpoints/mt5_v3" is a `save_pretrained`
# directory. Confirm the checkpoint format.
model = AutoModelForSeq2SeqLM.from_pretrained("../checkpoints/mt5_v3").to(device)
# The encoder shares its weights with `model`; a second copy is not needed
# for inference.
encoder_model = model.get_encoder()

In [78]:
# Validation split in torch format; no column list is passed to `with_format`.
eval_dataset = dataset["val"].with_format("torch")

In [None]:
import evaluate
import tqdm


def run_evaluate(sample, target: TargetEmbeddingType = "electra"):
    """Generate an output for one dataset row.

    Args:
        sample: Mapping with ``"input_ids"`` (one tokenized input), the
            ``target`` column and, optionally, ``"attention_mask"``.
        target: Name of the reference column returned next to the output.

    Returns:
        ``(outputs, reference)``: whatever ``model.generate`` returns for the
        single-row batch, and ``sample[target]`` unchanged.
    """
    # Build the (1, seq_len) batch directly as integers. The legacy
    # `torch.Tensor(...)` constructor is float-typed, so it round-tripped the
    # token ids through float32 and may reject integer tensors.
    input_ids = (
        torch.as_tensor(sample["input_ids"], dtype=torch.long)
        .unsqueeze(0)
        .to(device)
    )

    generate_kwargs = {}
    # Pass the padding mask along when the row provides one.
    if "attention_mask" in sample:
        generate_kwargs["attention_mask"] = (
            torch.as_tensor(sample["attention_mask"], dtype=torch.long)
            .unsqueeze(0)
            .to(device)
        )

    outputs = model.generate(input_ids, **generate_kwargs)
    return outputs, sample[target]


# `import tqdm` earlier in this cell rebinds the name `tqdm` to the module,
# which is not callable. Re-import the progress-bar class before using it.
from tqdm import tqdm

# Run the model on every evaluation row, keeping outputs and references in
# matching order.
predictions, references = [], []
for sample in tqdm(eval_dataset):
    prediction, reference = run_evaluate(sample)
    predictions.append(prediction)
    references.append(reference)

### P@3 with Cosine Similarity

Here we use the output embeddings to find the top 3 most similar words in the test set, then count how many of them match the ground truth.

In [None]:
from torch.nn.functional import cosine_similarity
from typing import List

# Put the model in evaluation mode before running inference.
model.eval()


def test_set_embeddings(test_set, batch_size: int = 32):
    """Return one mean-pooled encoder embedding per row of ``test_set``.

    Args:
        test_set: Object whose ``["input_ids"]`` yields the tokenized rows,
            either as a single 2-D tensor or as a sequence of equally long
            rows.
        batch_size: Number of rows sent through the encoder at a time.

    Returns:
        A 2-D tensor with one row per input row, on the notebook-level
        ``device``; an empty tensor when ``test_set`` has no rows.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be a positive integer")

    # The parameter is used here; the previous body referenced the builtin
    # `input` instead of `test_set`.
    all_ids = test_set["input_ids"]
    if not torch.is_tensor(all_ids):
        all_ids = torch.stack([torch.as_tensor(row) for row in all_ids])

    pooled_chunks = []
    with torch.no_grad():  # inference only
        for start in range(0, all_ids.size(0), batch_size):
            batch_ids = all_ids[start : start + batch_size].to(device)
            input_mask = batch_ids != tokenizer.pad_token_id
            outputs = encoder_model(
                input_ids=batch_ids, attention_mask=input_mask.long()
            )
            last_hidden_states = outputs.last_hidden_state
            weights = input_mask.unsqueeze(-1).to(last_hidden_states.dtype)
            sum_embeddings = torch.sum(last_hidden_states * weights, 1)
            # The clamp avoids dividing by zero for an all-padding row.
            sum_mask = weights.sum(1).clamp(min=1)
            pooled_chunks.append(sum_embeddings / sum_mask)

    if not pooled_chunks:
        return torch.empty((0, 0), device=device)
    return torch.cat(pooled_chunks)


# Generate the word embedding for a query
def get_prediction_embeddings(sample):
    """Generate a word for ``sample`` and return that word's embedding.

    The query text is fed to ``model.generate``, the generated ids are decoded
    to text, and the decoded text is encoded with ``encoder_model`` and
    mean-pooled over its non-padding tokens.

    Args:
        sample: Mapping holding the query text under ``"query"`` or, when that
            key is absent, under ``"gloss"``.

    Returns:
        The mean-pooled encoder states of the generated word, with a leading
        batch dimension of 1, on the notebook-level ``device``.
    """
    # The dataset rows built in this notebook carry the text under "gloss";
    # "query" is still honoured when it is present.
    query = sample["query"] if "query" in sample else sample["gloss"]

    with torch.no_grad():  # inference only
        # First, generate the prediction. The inputs must sit on the same
        # device as the model.
        query_ids = tokenizer(
            query, return_tensors="pt", padding=True, truncation=True
        ).input_ids.to(device)
        output_word = model.generate(query_ids)
        decoded_output = tokenizer.decode(output_word[0], skip_special_tokens=True)

        # Next, generate the embeddings of the predicted word.
        word_ids = tokenizer(
            decoded_output, return_tensors="pt", padding=True, truncation=True
        ).input_ids.to(device)
        outputs = encoder_model(word_ids)
        last_hidden_states = outputs.last_hidden_state

        # Mean over the non-padding positions; the clamp guards against a
        # division by zero.
        input_mask = word_ids != tokenizer.pad_token_id
        weights = input_mask.unsqueeze(-1).to(last_hidden_states.dtype)
        sum_embeddings = torch.sum(last_hidden_states * weights, 1)
        sum_mask = weights.sum(1).clamp(min=1)
        mean_embeddings = sum_embeddings / sum_mask

    return mean_embeddings


# Rank candidate embeddings by cosine similarity to the output embedding
def get_top_n(emb: torch.Tensor, data: List[torch.Tensor], n: int = 5):
    """Return the ``n`` highest cosine-similarity scores between ``emb`` and ``data``.

    Args:
        emb: Query embedding; each comparison must yield a single value.
        data: Candidate embeddings, compared with ``emb`` one at a time.
        n: Maximum number of scores to return.

    Returns:
        The similarity scores as Python floats, sorted from highest to lowest
        and cut to at most ``n`` entries. Only scores are returned, not the
        candidates themselves.
    """
    ranked = [cosine_similarity(emb, candidate).item() for candidate in data]
    ranked.sort(reverse=True)
    return ranked[:n]

In [None]:
# Embed the evaluation split once; the result is reused as the candidate pool
# for every query in the scoring loop.
test_embeds = test_set_embeddings(eval_dataset)

In [None]:
# Score only the first `max_eval_samples` items to keep this cell fast.
max_eval_samples = 100

scores = []
for i, item in enumerate(eval_dataset, start=1):
    test_word = get_prediction_embeddings(item)
    # Keep the top 3 scores per item. The metric of this section is P@3 and
    # the next cell aggregates into `average_at_3`; the previous `n=1` kept
    # only the single best score.
    scores.append(get_top_n(test_word, test_embeds, n=3))
    if i >= max_eval_samples:
        break

In [None]:
# Average the scores rank by rank across all evaluated items
# (`zip(*scores)` groups the first scores together, then the second, ...),
# then average those per-rank means into one number.
average_at_3 = []
for rank_scores in zip(*scores):
    average_at_3.append(sum(rank_scores) / len(rank_scores))

average = sum(average_at_3) / len(average_at_3)
average

# Use it

In [None]:
# Text to run through the model.
input_text = "البلاغة"

# Encode to PyTorch tensors ("pt") and move them to the model's device.
input_ids = tokenizer.encode(input_text, return_tensors="pt")
input_ids = input_ids.to(device)

# Generate, then decode the first returned sequence without special tokens.
output_ids = model.generate(input_ids)
first_sequence = output_ids[0]
output_text = tokenizer.decode(first_sequence, skip_special_tokens=True)
print(output_text)