# Fine-tuning MT5 for Reverse Dictionary

## Requirements

In [None]:
!pip3 install -q -U datasets==2.18.0
!pip3 install -q -U accelerate==0.27.1
!pip3 install -q -U datasets==2.17.0
!pip3 install -q -U transformers==4.39.2
!pip3 install -q -U evaluate==0.4.1
!pip3 install -q -U scikit-learn==1.4.1.post1
!pip3 install -q -U pytorch_lightning==2.2.2

In [None]:
import pandas as pd
import datasets
import torch
from transformers import AutoTokenizer
from transformers import (
    MT5ForConditionalGeneration,
    MT5EncoderModel,
    AdamW,
    get_linear_schedule_with_warmup,
)

In [None]:
# Read the shared-task training split into a DataFrame and preview its first rows.
train_path = "../data/shared-task/train.json"
data = pd.read_json(train_path, encoding="utf-8")
data.head()


In [None]:
# Build one {"query": definition, "word": target word} record per entry.
#
# NOTE: the "examples" column is intentionally not read here. It used to be
# parsed with `eval()`, which executes arbitrary code found in the data file
# and could raise when an entry held fewer than two examples, even though the
# parsed value was never stored anywhere. If usage examples are needed later,
# parse them with `ast.literal_eval` instead of `eval`.
data_dict = data.to_dict(orient="records")
data_list = [
    {"query": str(row["Definition"]), "word": str(row["UnDiacWord"])}
    for row in data_dict
]

In [None]:
# Turn the records into a Hugging Face dataset, hold out 10% of it, and cut the
# held-out part in half to obtain the validation and test splits.
dataset = datasets.Dataset.from_list(data_list).train_test_split(test_size=0.1)
test_val_split = dataset["test"].train_test_split(test_size=0.5)
dataset = datasets.DatasetDict(
    train=dataset["train"],
    validation=test_val_split["test"],
    test=test_val_split["train"],
)

# `checkpoint` is otherwise only assigned in the training-configuration cell
# further down, so running the notebook top to bottom raised a NameError on the
# next line. Define it before first use, with the same value.
checkpoint = "UBC-NLP/AraT5v2-base-1024"
tokenizer = AutoTokenizer.from_pretrained(checkpoint, legacy=False)

# Tokenization settings read by `preprocess_function`.
padding = "max_length"  # pad every sequence up to the max length given below
max_input_length = 256  # max number of tokens kept for a query (definition)
max_target_length = 128  # max number of tokens kept for a target word


def preprocess_function(examples):
    """Tokenize a batch of query/word pairs for seq2seq training.

    Uses the module-level ``tokenizer``, ``padding``, ``max_input_length`` and
    ``max_target_length``.

    Args:
        examples: Batch mapping with a ``"query"`` sequence (model inputs) and
            a ``"word"`` sequence (targets).

    Returns:
        The tokenizer output for the queries, extended with a ``"labels"``
        entry holding the target token ids in which every padding id has been
        replaced by ``-100``.
    """
    model_inputs = tokenizer(
        list(examples["query"]),
        max_length=max_input_length,
        padding=padding,
        truncation=True,
    )

    # Encode the target words.
    labels = tokenizer(
        list(examples["word"]),
        max_length=max_target_length,
        padding=padding,
        truncation=True,
    ).input_ids

    # Important: replace the padding token id by -100 so that those positions
    # are not taken into account by the CrossEntropyLoss. The id is read from
    # the tokenizer instead of assuming that it is 0.
    pad_token_id = tokenizer.pad_token_id
    model_inputs["labels"] = [
        [token_id if token_id != pad_token_id else -100 for token_id in label_ids]
        for label_ids in labels
    ]

    return model_inputs


# Apply the tokenization to every split, one batch of examples at a time.
dataset = dataset.map(preprocess_function, batched=True)

In [None]:
from torch.utils.data import DataLoader

# Return only the columns the model consumes, as PyTorch tensors.
model_columns = ["input_ids", "attention_mask", "labels"]
dataset.set_format(type="torch", columns=model_columns)

# Only the training batches are shuffled; validation/test keep dataset order.
train_dataloader = DataLoader(dataset["train"], batch_size=8, shuffle=True)
valid_dataloader = DataLoader(dataset["validation"], batch_size=4)
test_dataloader = DataLoader(dataset["test"], batch_size=4)

# Sanity check: pull one training batch and decode its first example.
batch = next(iter(train_dataloader))
print(batch.keys())
tokenizer.decode(batch["input_ids"][0])
labels = batch["labels"][0]
# Drop the -100 placeholders before decoding the target.
tokenizer.decode([label for label in labels if label != -100])

## Training

In [None]:
import math

from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers import AutoModel, AutoConfig, AutoModelForSeq2SeqLM, set_seed
from transformers.modeling_outputs import Seq2SeqLMOutput

# For reproducibility
set_seed(123)
# Define needed parameters
# Hub identifier of the pretrained checkpoint that is fine-tuned below.
checkpoint = "UBC-NLP/AraT5v2-base-1024"
# Number of full passes over the training dataloader.
epochs = 3
# NOTE(review): the optimizer cell passes `lr=2e-5` literally instead of reading
# this value - confirm which learning rate is intended.
lr = 5e-4
# The max length of the embeddings, this should match the length of the shared task target embeddings
max_length = 128
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
def train(dataloader, optimizer_, scheduler_, device_):
    r"""Train the global ``model`` for one pass over ``dataloader``.

    Every batch must be a dictionary of tensors that can be passed straight
    into the model as ``model(**batch)``, and must contain a ``labels`` entry.
    The model call must return a sequence whose first two items are the loss
    and the logits.

    Arguments:
        dataloader (:obj:`torch.utils.data.dataloader.DataLoader`):
            Parsed data into batches of tensors.
        optimizer_ (:obj:`transformers.optimization.AdamW`):
            Optimizer used for training.
        scheduler_ (:obj:`torch.optim.lr_scheduler.LambdaLR`):
            Learning-rate scheduler, stepped once per batch.
        device_ (:obj:`torch.device`):
            Device the batch tensors are moved to before the forward pass.

    Returns:
        :obj:`Tuple[List[int], List[int], float]`: ``(true_labels,
        predicted_labels, average_loss)``. Both label lists are flattened
        over all batches; ``true_labels`` is unfiltered, i.e. it contains
        every label position exactly as delivered by the dataloader.
    """

    # Use global variable for model.
    global model

    # Tracking variables.
    predictions_labels = []
    true_labels = []
    # Total loss for this epoch.
    total_loss = 0

    # Put the model into training mode.
    model.train()

    # For each batch of training data...
    for batch in tqdm(dataloader, total=len(dataloader)):

        # Keep the reference labels for later evaluation. This is done before
        # the batch is moved to `device_`, because `.numpy()` needs CPU tensors.
        true_labels += batch['labels'].numpy().flatten().tolist()

        # Cast every tensor to long and move it to the target device.
        batch = {k: v.type(torch.long).to(device_) for k, v in batch.items()}

        # Always clear any previously calculated gradients before performing a
        # backward pass.
        model.zero_grad()

        # Forward pass. Because `labels` are part of the batch, the model
        # output carries the loss in addition to the logits.
        outputs = model(**batch)
        loss, logits = outputs[:2]

        # `loss` is a tensor containing a single value; `.item()` returns it
        # as a Python number so it can be accumulated for the epoch average.
        total_loss += loss.item()

        # Perform a backward pass to calculate the gradients.
        loss.backward()

        # Clip the norm of the gradients to 1.0.
        # This is to help prevent the "exploding gradients" problem.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

        # Update the parameters, then the learning rate.
        optimizer_.step()
        scheduler_.step()

        # Reduce the logits to predicted ids on the device first, so that only
        # the ids (not the full logits tensor) are copied back to the host.
        predictions_labels += logits.detach().argmax(dim=-1).flatten().tolist()

    # Calculate the average loss over the training data.
    avg_epoch_loss = total_loss / len(dataloader)

    # Return all true labels and predictions for future evaluations.
    return true_labels, predictions_labels, avg_epoch_loss

In [None]:
class RevDictModel(nn.Module):
    """Seq2seq model that additionally returns a pooled encoder embedding.

    Wraps a sequence-to-sequence model (``self.base_model``) and adds a linear
    layer that projects the mean-pooled encoder states to ``max_len`` values.
    """

    def __init__(self, from_pretrained: bool, max_len: int, checkpoint: str):
        """Create the wrapped model and the projection layer.

        Args:
            from_pretrained: When true, load the pretrained weights of
                ``checkpoint``; otherwise build the model from the
                checkpoint's configuration only.
            max_len: Output size of the projection layer.
            checkpoint: Checkpoint name or path handed to the ``Auto*`` loaders.
        """
        super().__init__()
        if from_pretrained:
            self.base_model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
        else:
            model_config = AutoConfig.from_pretrained(checkpoint)
            self.base_model = AutoModelForSeq2SeqLM.from_config(model_config)

        self.linear = nn.Linear(self.base_model.config.hidden_size, max_len)

    def forward(self, input_ids, attention_mask, labels):
        """Run the wrapped model and pool its encoder output.

        Args:
            input_ids: Token ids of the input sequences.
            attention_mask: Mask that is non-zero on the tokens to pool over.
            labels: Target token ids, forwarded to the wrapped model.

        Returns:
            Tuple ``(loss, embedding)``: the loss reported by the wrapped model
            and the projected, mask-weighted mean of its last encoder states.
        """
        outputs: Seq2SeqLMOutput = self.base_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
            return_dict=True,
        )
        # Sum the encoder states over the sequence axis, counting only the
        # positions selected by the attention mask.
        masked_sum = (
            outputs.encoder_last_hidden_state * attention_mask.unsqueeze(2)
        ).sum(dim=1)
        # Clamp the token count so that a row whose mask is entirely zero
        # yields a zero vector instead of NaN (0 / 0).
        token_count = attention_mask.sum(dim=1, keepdim=True).clamp(min=1)
        pooled_emb = masked_sum / token_count
        embedding = self.linear(pooled_emb)
        return outputs.loss, embedding

    def save(self, file):
        """Serialize the whole module object to ``file`` with ``torch.save``."""
        torch.save(self, file)

    @staticmethod
    def load(file):
        """Load an object previously written by :meth:`save`.

        NOTE(review): this unpickles ``file``; only use it on files you trust.
        """
        return torch.load(file)

In [None]:
# Build the model from the pretrained checkpoint, with a projection of size
# `max_length`, and move it to the selected device.
model = RevDictModel(
    from_pretrained=True, max_len=max_length, checkpoint=checkpoint
)
model.to(device)

In [None]:
# Import the progress-bar callable, not the module: a bare `import tqdm` rebinds
# the name `tqdm` to the module, which is not callable, and therefore breaks
# both the epoch loop below and the `tqdm(...)` call inside `train()`.
from tqdm import tqdm

# NOTE(review): `accuracy_score`, `validation` and `plot_dict` are neither
# defined nor imported in the visible part of this notebook. They have to be
# provided (e.g. `accuracy_score` presumably comes from `sklearn.metrics`)
# before this cell can run.

# NOTE(review): the learning rate is given literally here; the configuration
# cell also defines an `lr` value that is not used - confirm which is intended.
optimizer = AdamW(
    model.parameters(),
    lr=2e-5,
    eps=1e-8,
)
# Total number of training steps is number of batches * number of epochs.
# `train_dataloader` contains batched data so `len(train_dataloader)` gives
# us the number of batches.
total_steps = len(train_dataloader) * epochs

# Create the learning rate scheduler.
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=0,  # Default value in run_glue.py
    num_training_steps=total_steps,
)

# Store the average loss and accuracy after each epoch so we can plot them.
all_loss = {"train_loss": [], "val_loss": []}
all_acc = {"train_acc": [], "val_acc": []}

# Loop through each epoch.
print("Epoch")
for epoch in tqdm(range(epochs)):
    print()
    print("Training on batches...")
    # Perform one full pass over the training set.
    train_labels, train_predict, train_loss = train(
        train_dataloader, optimizer, scheduler, device
    )
    train_acc = accuracy_score(train_labels, train_predict)

    # Get predictions from the model on the validation data.
    print("Validation on batches...")
    valid_labels, valid_predict, val_loss = validation(valid_dataloader, device)
    val_acc = accuracy_score(valid_labels, valid_predict)

    # Print loss and accuracy values to see how training evolves.
    print(
        f"  train_loss: {train_loss:.5f} - val_loss: {val_loss:.5f}"
        f" - train_acc: {train_acc:.5f} - valid_acc: {val_acc:.5f}"
    )
    print()

    # Store the values for plotting the learning curves.
    all_loss["train_loss"].append(train_loss)
    all_loss["val_loss"].append(val_loss)
    all_acc["train_acc"].append(train_acc)
    all_acc["val_acc"].append(val_acc)

# Plot loss curves.
plot_dict(all_loss, use_xlabel="Epochs", use_ylabel="Value", use_linestyles=["-", "--"])

# Plot accuracy curves.
plot_dict(all_acc, use_xlabel="Epochs", use_ylabel="Value", use_linestyles=["-", "--"])

In [None]:
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor

# for early stopping, see https://pytorch-lightning.readthedocs.io/en/1.0.0/early_stopping.html?highlight=early%20stopping
# Stop when the monitored "validation_loss" has not decreased for 3 checks;
# `strict=False` is passed so a missing metric is not treated as an error.
early_stop_callback = EarlyStopping(
    monitor="validation_loss", patience=3, strict=False, verbose=False, mode="min"
)
# Record the learning rate at every step.
lr_monitor = LearningRateMonitor(logging_interval="step")

# NOTE(review): `model` is a plain `nn.Module` (RevDictModel) in this notebook,
# while `Trainer.fit` is documented to take a LightningModule, and no
# "validation_loss" metric is logged anywhere in the visible code - confirm
# whether this cell is still meant to be run after the manual training loop.
trainer = Trainer(default_root_dir=".", callbacks=[early_stop_callback, lr_monitor])
trainer.fit(model)

In [None]:
# Directory the fine-tuned seq2seq weights are written to (presumably a mounted
# Google Drive folder; change it to any writable location).
save_directory = "/content/drive/MyDrive/mt5_checkpoint_4"
# `model` is a RevDictModel, which stores the Hugging Face network under
# `base_model`; it has no `model` attribute, so `model.model` raised an
# AttributeError.
# NOTE(review): the evaluation section reloads from "../checkpoints/mt5_v3",
# which is a different path - confirm the intended location.
model.base_model.save_pretrained(save_directory)

## Evaluation

In [None]:
import torch

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Reload the fine-tuned checkpoint twice: once as the full generation model and
# once as an encoder-only model used to embed words.
finetuned_dir = "../checkpoints/mt5_v3"
model = MT5ForConditionalGeneration.from_pretrained(finetuned_dir).to(device)
encoder_model = MT5EncoderModel.from_pretrained(finetuned_dir).to(device)

In [None]:
# Test split, with its values returned in PyTorch format.
eval_dataset = dataset["test"].with_format("torch")

### Measuring Exact Match
How often does the model predict exactly the correct word?

In [None]:
import evaluate
# Import the progress-bar callable, not the module: after a bare `import tqdm`
# the call `tqdm(eval_dataset)` below fails because a module is not callable.
from tqdm import tqdm


def run_evaluate(sample):
    """Generate a prediction for one tokenized sample.

    Uses the module-level ``model``, ``tokenizer`` and ``device``.

    Args:
        sample: Mapping with the token ids of one example under
            ``"input_ids"`` and the reference word under ``"word"``.

    Returns:
        Tuple ``(prediction, reference)``: the decoded generated text and
        ``sample["word"]``.
    """
    # Build an integer tensor directly (no float round-trip through
    # `torch.Tensor`) and add the batch dimension `generate` expects.
    input_ids = torch.as_tensor(sample["input_ids"], dtype=torch.long)
    outputs = model.generate(input_ids.unsqueeze(0).to(device))
    prediction = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return prediction, sample["word"]


# Collect the predicted and the reference word for every evaluation sample.
predictions, references = [], []
for sample in tqdm(eval_dataset):
    predicted, expected = run_evaluate(sample)
    predictions.append(predicted)
    references.append(expected)

In [None]:
# Score the predictions against the reference words with the "exact_match" metric.
metric = evaluate.load("exact_match")
accuracy = metric.compute(predictions=predictions, references=references)

### P@3 with Cosine Similarity

Here we use the output embeddings to find the 3 most similar words in the test set, then count how many of them match the ground truth.

In [None]:
from torch.nn.functional import cosine_similarity
from typing import List

# Switch the generation model to evaluation mode before running inference.
model.eval()


def _mean_pool(last_hidden_states, attention_mask):
    """Average the token states over the positions selected by the mask."""
    mask = attention_mask.unsqueeze(-1).to(last_hidden_states.dtype)
    summed = (last_hidden_states * mask).sum(dim=1)
    # Clamp so that a row without any selected token does not divide by zero.
    counts = mask.sum(dim=1).clamp(min=1)
    return summed / counts


def test_set_embeddings(test_set):
    """Embed every word of ``test_set`` with the encoder model.

    Uses the module-level ``tokenizer`` and ``encoder_model``.

    Args:
        test_set: Object whose ``"word"`` entry is the list of words to embed.

    Returns:
        One mean-pooled encoder embedding per word, stacked in a 2-D tensor.
    """
    # Keep the whole tokenizer output: the attention mask is needed both by
    # the encoder (to ignore padding) and by the pooling step. The tensors
    # are moved to the device the encoder lives on.
    encoded = tokenizer(
        test_set["word"], return_tensors="pt", padding=True, truncation=True
    ).to(encoder_model.device)
    # Inference only: do not record gradients.
    with torch.no_grad():
        outputs = encoder_model(
            input_ids=encoded.input_ids, attention_mask=encoded.attention_mask
        )
    return _mean_pool(outputs.last_hidden_state, encoded.attention_mask)


# Generate the word embedding for a query
def get_prediction_embeddings(sample):
    """Predict a word for ``sample["query"]`` and embed that prediction.

    Uses the module-level ``tokenizer``, ``model`` and ``encoder_model``.

    Args:
        sample: Mapping whose ``"query"`` entry is the text to generate from.

    Returns:
        The mean-pooled encoder embedding of the generated text.
    """
    # First, generate the prediction.
    query = tokenizer(
        sample["query"], return_tensors="pt", padding=True, truncation=True
    ).to(model.device)
    output_word = model.generate(
        query.input_ids, attention_mask=query.attention_mask
    )
    decoded_output = tokenizer.decode(output_word[0], skip_special_tokens=True)

    # Next, generate the embedding of the predicted word.
    encoded = tokenizer(
        decoded_output, return_tensors="pt", padding=True, truncation=True
    ).to(encoder_model.device)
    with torch.no_grad():
        outputs = encoder_model(
            input_ids=encoded.input_ids, attention_mask=encoded.attention_mask
        )
    return _mean_pool(outputs.last_hidden_state, encoded.attention_mask)


# Calculate the top n similarity scores for the output embedding
def get_top_n(emb: torch.Tensor, data: List[torch.Tensor], n: int = 5):
    """Return the ``n`` highest cosine similarities between ``emb`` and ``data``.

    Args:
        emb: A single query embedding.
        data: Candidate embeddings, given either as a sequence of tensors or
            as one stacked tensor with one candidate per row.
        n: Number of scores to return.

    Returns:
        The similarity scores as Python floats, sorted from highest to lowest
        and cut to at most ``n`` entries. The candidates themselves are not
        returned. An empty ``data`` yields an empty list.
    """
    if not isinstance(data, torch.Tensor):
        data = list(data)
        if not data:
            return []
        data = torch.stack([item.reshape(-1) for item in data])
    if data.numel() == 0:
        return []

    query = emb.reshape(1, -1)
    candidates = data.reshape(-1, query.shape[-1])
    # One vectorized call for all candidates instead of one call (and one
    # device-to-host transfer) per candidate.
    scores = cosine_similarity(query, candidates, dim=1)
    return torch.sort(scores, descending=True).values.tolist()[:n]

In [None]:
# Embeddings of the words in the evaluation split, computed once up front.
test_embeds = test_set_embeddings(eval_dataset)

In [None]:
# Score at most the first 100 evaluation samples: embed the model's prediction
# for each one and keep its best similarity against the test-set embeddings.
scores = []
i = 0
for i, item in enumerate(eval_dataset, start=1):
    test_word = get_prediction_embeddings(item)
    scores.append(get_top_n(test_word, test_embeds, n=1))
    if i == 100:
        break

In [None]:
# Average the scores per rank position over all scored samples, then average
# those per-rank means into a single number.
average_at_3 = [
    sum(rank_scores) / len(rank_scores) for rank_scores in zip(*scores)
]
average = sum(average_at_3) / len(average_at_3)
average

# Use it

In [None]:
input_text = "البلاغة"
# Encode the text as PyTorch tensors ("pt") and move them to the model's device.
input_ids = tokenizer.encode(input_text, return_tensors="pt")
input_ids = input_ids.to(device)
# Generate, then decode the first (only) generated sequence.
output_ids = model.generate(input_ids)
output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
print(output_text)