<a target="_blank" href="https://colab.research.google.com/github/pr-Mais/arabic-reverse-dictionary/blob/main/code/mt5_training_shared_task.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

# Fine-tuning MT5 for Reverse Dictionary

In [1]:
import sys

# Detect Colab by checking whether its `google.colab` module has been imported.
IN_COLAB = "google.colab" in sys.modules

# In Colab, the data and checkpoints are read from Google Drive, so mount it
# under /content/drive.
if IN_COLAB:
    from google.colab import drive as colab_drive

    colab_drive.mount("/content/drive")

## Requirements

We require the following libraries for the modeling process:

In [None]:
!pip3 install -q -U datasets==2.18.0
!pip3 install -q -U transformers==4.39.2
!pip3 install -q -U evaluate==0.4.1
!pip3 install -q -U scikit-learn==1.4.1.post1
!pip3 install -q -U torch==2.2.1
!pip3 install -q -U tokenizers==0.15.2
!pip3 install -q -U tqdm==4.66.2
!pip3 install -q -U pandas==2.2.1
!pip3 install -q -U numpy==1.26.4

Next, we import all required libraries and modules.

In [2]:
import pandas as pd
from tqdm import tqdm
from typing import Literal
from datasets import Dataset, DatasetDict

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

from transformers import (
    AutoConfig,
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    AdamW,
    set_seed,
    get_linear_schedule_with_warmup,
)
from transformers.modeling_outputs import Seq2SeqLMOutput

# Fix random seeds (via transformers.set_seed) so runs are reproducible.
set_seed(seed=123)
# Alternative Arabic-specific checkpoint: "UBC-NLP/AraT5v2-base-1024"
checkpoint = "google/mt5-base"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

TargetEmbeddingType = Literal["electra", "bertseg", "bertmsa"]

# Dimensionality of each target embedding. The model's regression head must
# output vectors of exactly this size.
embedding_dims = {
    "electra": 256,
    "bertseg": 768,
    "bertmsa": 768,
}

# Which target embedding to use as the target for the model?
target_embedding: TargetEmbeddingType = "electra"

# Output size of the regression head for the selected target. The name is kept
# because later cells read `max_length`; it is an embedding dimension, not a
# sequence length.
max_length = embedding_dims[target_embedding]

  from .autonotebook import tqdm as notebook_tqdm


### Loading and preprocessing data

In [3]:
# The shared-task JSON files sit next to the repo locally, or on the mounted
# Google Drive when running in Colab.
data_root = (
    "/content/drive/MyDrive/shared-task" if IN_COLAB else "../data/shared-task"
)
train_ds_path = f"{data_root}/train_with_examples.json"
val_ds_path = f"{data_root}/dev.json"
test_ds_path = f"{data_root}/test.json"

In [4]:
# Read the train / dev / test splits (UTF-8 JSON) into DataFrames.
train_df, val_df, test_df = (
    pd.read_json(split_path, encoding="utf-8")
    for split_path in (train_ds_path, val_ds_path, test_ds_path)
)

In [5]:
train_df.iloc[:20]

Unnamed: 0,id,word,pos,gloss,electra,bertseg,bertmsa,examples
0,ar.976497,اشتط,V,اشتطَّ/ اشتطَّ على/ اشتطَّ في يشتطّ، اشْتَطِط/...,"[-1.2550759315, -0.5122892857, 0.3778211772, 0...","[-0.0769249499, -0.1658652574, -0.1849669367, ...","[-0.5870065689, 0.9182425141, 0.53450250630000...",[اسم الفاعل، والصفة المشبهة، واسم المفعول، واس...
1,ar.974603,خصب,N,نامٍ، كثير العشب,"[0.31928840280000004, -0.739033699, 0.70923948...","[-0.1230165064, -0.18066638710000002, -0.07713...","[-1.0599673986, 1.4727288485, -0.1896873564, 0...",[نبات عشبي بري زراعي يبلغ ارتفاعه 20-80 سم ساق...
2,ar.993088,كَفَّن,V,كَفَّن المَيِّتَ: لَفَّهُ بالكَفَن. (مُبَالَغَ...,"[-0.6270618439, 0.1627743989, 0.142576322, -0....","[0.17214727400000002, -0.056175738600000004, 0...","[-0.9546464086, 1.7046545744000001, -1.1613812...",[أَيْن: اسمُ استِفهامٍ مَبْنيٌ على الفَتحِ فِي...
3,ar.995571,حكم,N,خبير في قوانين الألعاب يتولَّى إدارة المباراة ...,"[-0.1196034774, -1.0221087933, -0.754472792100...","[0.047808088400000004, -0.1173066497, 0.378811...","[-1.3533554077, 0.9892686605000001, 0.29662185...",[الحكم في المصارعة المحترفة هو المشرف على المب...
4,ar.979447,كذاب,N,صيغة مبالغة من كذَبَ/ كذَبَ على: كثير الكذب .,"[-0.9610854983, 0.4087179005, 0.054193534, -0....","[-0.0205224976, 0.19622370600000003, 0.3586411...","[-1.8391281366, 1.2455878258, 0.5053068399, 0....",[وردت لفظ الكذب ومشتقاتها في القرآن الكريم في ...
5,ar.976482,شرم,N,كلُّ شقٍّ غير نافذ في جبل أو حائط.,"[-1.5637322664, -0.7942423820000001, 0.0531375...","[-0.1907260567, -0.2457073182, 0.1192496195, 0...","[-0.22798559070000002, 0.650627017, 0.25319314...",[كان الهرم والمعبد الجنائزي غير المكتمل محاطًا...
6,ar.995025,بَسْط,N,عدد أعلى في الكسر الاعتيادي كالعدد (2) في الكس...,"[0.0865183845, -0.6144404411000001, 0.25060844...","[-0.128886193, -0.079362981, 0.0060265856, 0.0...","[-0.29289975760000003, 0.8322463036000001, 0.8...",[الكسر هو ناتج قسمة، أو العدد الذي يحصل عليه ب...
7,ar.978185,غنى,V,تغنَّى به؛ ترنَّم به.,"[-0.3357913494, 0.0016047321, 0.7248107791, 0....","[0.1466675252, -0.2847951353, 0.1616912037, -0...","[-0.9279320240000001, 0.6518034935, -0.0197837...",[يُزْهِرُ إِزْهَارًا وَيَبْتَهِجُ ابْتِهَاجًا ...
8,ar.965439,بَيْعِيّ,,منسوب إلى بيع،صفقة يتم بموجبها تبادل الشيء بال...,"[-0.7022292614, -0.0916798115, 0.3377537727, -...","[0.3020823896, -0.1332699656, 0.2473745793, -0...","[-0.7345462441, 0.5775005817000001, 0.58639520...",[وتعرف عملية البيع أيضاً بأنها عملية مفاوضات ت...
9,ar.980856,هجن,V,كان معيبًا مرذولاً، دخَل فيه عيبٌ,"[0.1757310033, -2.0628080368, -0.7020929456, -...","[0.0246690679, -0.6263778806, 0.0322731957, 0....","[-1.0371968746, 1.4605400562, 0.3261591494, -1...",[ومن شعره التأمُّلي قولُه مخاطبًا نفسه، التي ك...


In [6]:
# Report the size and columns of each split (test has no target columns).
splits_to_describe = {"Train": train_df, "Validation": val_df, "Test": test_df}
for position, (split_label, split_frame) in enumerate(splits_to_describe.items()):
    if position:
        print()
    print(f"{split_label} dataset has {len(split_frame)} examples, and the following columns:")
    print(split_frame.columns.tolist())

Train dataset has 31372 examples, and the following columns:
['id', 'word', 'pos', 'gloss', 'electra', 'bertseg', 'bertmsa', 'examples']

Validation dataset has 3921 examples, and the following columns:
['id', 'word', 'pos', 'gloss', 'electra', 'bertseg', 'bertmsa']

Test dataset has 3922 examples, and the following columns:
['id', 'word', 'pos', 'gloss']


In [7]:
# Convert the DataFrames to Hugging Face datasets. Only train and val carry
# target embeddings, so only they go into the DatasetDict used for training and
# development; the test split is kept separate.
train_ds = Dataset.from_pandas(train_df, split="train")
val_ds = Dataset.from_pandas(val_df, split="validation")
test_ds = Dataset.from_pandas(test_df, split="test")
dataset = DatasetDict({"train": train_ds, "val": val_ds})

In the next steps, we prepare the data for modeling.

The features we care about from the dataset are the `gloss` and `examples`, and the target is either `electra`, `bertseg` or `bertmsa`. This means we train 3 different models, one per target.

In the preprocessing step, we tokenize the data and convert it to a format that can be fed into the model: the input text is tokenized and the tokens are converted to token ids. Merging `gloss` and `examples` into a single input string is currently disabled (only the train split has an `examples` column), so only the `gloss` is used as input.

In [8]:
# Tokenizer for the selected checkpoint.
tokenizer = AutoTokenizer.from_pretrained(checkpoint, legacy=False)

# Each gloss is padded/truncated to exactly `max_input_length` tokens.
max_input_length = 256
padding = "max_length"


def preprocess_function(items):
    """Tokenize a batch of glosses and attach targets, for `Dataset.map(batched=True)`.

    Args:
        items: a batch of rows. It must contain `gloss` and `word` columns and
            may contain `electra`, `bertseg` and `bertmsa` embedding columns.

    Returns:
        The tokenizer output (`input_ids`, `attention_mask`), extended with any
        available target-embedding columns and a `labels` column. `labels`
        holds the tokenized `word`, with padding positions set to -100.
    """
    # Only the gloss is used as model input. Usage examples could be appended
    # as f"{gloss}. {example[0]}، {example[1]}" from items["examples"], but
    # that column exists only in the train split.
    glosses = items["gloss"]

    model_inputs = tokenizer(
        glosses,
        max_length=max_input_length,
        padding=padding,
        truncation=True,
        return_tensors="pt",
    )

    # Carry over whichever target embeddings this split provides.
    for embedding_column in ("electra", "bertseg", "bertmsa"):
        if embedding_column in items:
            model_inputs[embedding_column] = items[embedding_column]

    # NOTE(review): `max_length` is the target *embedding* size (256/768), not
    # a token budget, so every word is padded to that many tokens. The visible
    # RevDictModel.forward ignores `labels`, so this only costs memory/time.
    labels = tokenizer(
        list(items["word"]),
        max_length=max_length,
        padding=padding,
        truncation=True,
        return_tensors="pt",
    ).input_ids

    # Replace padding ids by -100 so a CrossEntropyLoss would ignore them.
    # Use the tokenizer's pad id instead of assuming it is 0.
    model_inputs["labels"] = labels.masked_fill(
        labels == tokenizer.pad_token_id, -100
    ).tolist()

    return model_inputs


# Tokenize every split. With batched=True, `preprocess_function` receives many
# rows per call.
dataset = dataset.map(function=preprocess_function, batched=True)
test_ds = test_ds.map(function=preprocess_function, batched=True)

Map: 100%|██████████| 31372/31372 [00:47<00:00, 659.81 examples/s]
Map: 100%|██████████| 3921/3921 [00:06<00:00, 649.84 examples/s]
Map: 100%|██████████| 3922/3922 [00:04<00:00, 894.39 examples/s]


Finally, data splits are converted into PyTorch `DataLoader` objects for training.

In [9]:
# Expose, as torch tensors, only the columns the training/validation loops read.
trainval_columns = [
    "id", "input_ids", "attention_mask", "labels",
    "electra", "bertseg", "bertmsa", "word",
]
dataset.set_format(type="torch", columns=trainval_columns, output_all_columns=False)
train_dataloader = DataLoader(dataset["train"], batch_size=8, shuffle=True)
valid_dataloader = DataLoader(dataset["val"], batch_size=4)

# The test split has no target embeddings: keep only model inputs and the id.
test_columns = ["input_ids", "attention_mask", "id"]
test_ds.set_format(type="torch", columns=test_columns, output_all_columns=False)
test_dataloader = DataLoader(test_ds, batch_size=4)

In [10]:
# Inspect the dataset backing the test loader.
loaded_test_ds = test_dataloader.dataset
loaded_test_ds

Dataset({
    features: ['id', 'word', 'pos', 'gloss', 'input_ids', 'attention_mask', 'labels'],
    num_rows: 3922
})

In [11]:
# Read a single row instead of a whole column: dataset["train"]["electra"]
# materialises every row of the split into one tensor just to read its width.
first_train_row = dataset["train"][0]
print(
    f"Electra embeddings shape: 1, {first_train_row['electra'].shape[0]}\n"
    f"BERTseg embeddings shape: 1, {first_train_row['bertseg'].shape[0]}\n"
    f"BERTmsa embeddings shape: 1, {first_train_row['bertmsa'].shape[0]}\n"
)

Electra embeddings shape: 1, 256
BERTseg embeddings shape: 1, 768
BERTmsa embeddings shape: 1, 768



## Training

The training pipeline is prepared to accept 3 types of targets: `electra`, `bertseg`, and `bertmsa`. The training process is the same for all targets.

In [None]:
# Loss functions
mse_loss = nn.MSELoss()


def train(
    dataloader,
    optimizer_,
    scheduler_,
    device_,
    target: TargetEmbeddingType = "electra",
    validate=False,
):
    """Run one training epoch, or one validation pass, of the global `model`.

    The model output is regressed onto the chosen target embedding with MSE.

    Args:
        dataloader: yields batches with `input_ids`, `attention_mask`, `labels`
            and the `target` embedding column.
        optimizer_: stepped once per batch when training; unused when validating.
        scheduler_: stepped once per batch when training; unused when validating.
        device_: device that inputs, labels and targets are moved to.
        target: which embedding column to regress onto.
        validate: if True, run in eval mode, without gradients or parameter updates.

    Returns:
        (ground_truth, predictions, avg_epoch_loss): the target and predicted
        embeddings as nested Python lists, and the mean per-batch loss.

    Raises:
        ValueError: if `dataloader` has no batches (the average is undefined).
    """
    if len(dataloader) == 0:
        raise ValueError("dataloader is empty; cannot compute an average loss")

    predictions = []
    ground_truth = []
    total_loss = 0.0

    # train(False) is equivalent to eval().
    model.train(not validate)

    # No autograd graph is needed for validation; skipping it saves memory/time.
    with torch.set_grad_enabled(not validate):
        for batch in tqdm(dataloader):
            target_embeddings = batch[target]
            ground_truth += target_embeddings.tolist()
            inputs = {
                key: batch[key].to(device_)
                for key in ("input_ids", "attention_mask")
            }
            labels = batch["labels"].to(device_)
            if not validate:
                model.zero_grad()
            embeddings = model(**inputs, labels=labels, return_dict=True)
            # The loss is computed on target embeddings, outside the model.
            loss = mse_loss(embeddings, target_embeddings.to(device_))
            total_loss += loss.item()
            if not validate:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer_.step()
                scheduler_.step()
            predictions += embeddings.detach().tolist()

    # Average loss over the batches of this pass.
    avg_epoch_loss = total_loss / len(dataloader)

    # Return all true labels and predictions for later evaluation.
    return ground_truth, predictions, avg_epoch_loss

In [12]:
class RevDictModel(nn.Module):
    """Seq2seq encoder + masked mean pooling + linear head: gloss -> embedding."""

    def __init__(self, max_length: int, checkpoint: str):
        """
        Args:
            max_length: size of the output embedding, i.e. the target embedding
                dimension (e.g. 256 for electra, 768 for bertseg/bertmsa).
            checkpoint: Hugging Face model id or local path of the seq2seq model.
        """
        super().__init__()
        # from_pretrained loads the pretrained weights. The previous from_config
        # only built the architecture with random weights, so "fine-tuning"
        # actually started from scratch.
        self.base_model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)

        # Regression head: pooled encoder state -> target embedding size.
        self.linear = nn.Linear(self.base_model.config.hidden_size, max_length)

    def forward(self, input_ids, attention_mask, **kwargs):
        """Return one embedding per input sequence.

        Extra keyword arguments (e.g. `labels`, `return_dict`) are accepted so
        existing call sites keep working; they are ignored.
        """
        # Only the encoder is used to produce embeddings.
        outputs = self.base_model.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True,
        )
        hidden = outputs.last_hidden_state
        # Masked mean over the sequence dimension: padding positions
        # (attention_mask == 0) contribute nothing. The clamp avoids 0/0 -> NaN
        # for a row with no attended tokens.
        mask = attention_mask.unsqueeze(-1).to(hidden.dtype)
        summed = (hidden * mask).sum(dim=1)
        counts = mask.sum(dim=1).clamp(min=1)
        return self.linear(summed / counts)

    def save(self, file):
        """Pickle the whole module (architecture and weights) to `file`."""
        torch.save(self, file)

    @staticmethod
    def load(file, map_location=None):
        """Load a module written by `save`.

        Args:
            file: path of the saved module.
            map_location: where to map tensors; defaults to the global `device`.

        Security: torch.load unpickles arbitrary objects, so only load
        checkpoints from trusted sources.
        """
        target_location = device if map_location is None else map_location
        return torch.load(file, map_location=target_location)

### Model training per target

In [None]:
# Hyperparameters.
lr = 3e-5  # learning rate
epochs = 1  # passes over the training set

In [None]:
# Build the model for the selected target and move it to the compute device.
model = RevDictModel(max_length=max_length, checkpoint=checkpoint)
model = model.to(device)
model

In [None]:
# Total number of training steps is number of batches * number of epochs.
# `train_dataloader` is batched, so `len(train_dataloader)` is the number of
# batches per epoch.
total_steps = len(train_dataloader) * epochs

# torch.optim.AdamW replaces the deprecated transformers.AdamW.
# weight_decay=0.0 keeps the previous behaviour: transformers.AdamW defaults
# to 0.0, while torch's AdamW defaults to 0.01.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=lr,
    eps=1e-8,
    weight_decay=0.0,
)

scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=0,
    num_training_steps=total_steps,
)

# Average train/validation loss per epoch, kept for plotting learning curves.
all_loss = {"train_loss": [], "val_loss": []}

for epoch in range(epochs):
    print(f"Epoch {epoch}")
    train_labels, train_predict, train_loss = train(
        train_dataloader, optimizer, scheduler, device, target=target_embedding
    )
    # With validate=True, `train` does not step the optimizer or scheduler.
    valid_labels, valid_predict, val_loss = train(
        valid_dataloader, optimizer, scheduler, device, validate=True, target=target_embedding
    )

    # Print loss values to see how training evolves.
    print("  train_loss: %.5f - val_loss: %.5f" % (train_loss, val_loss))
    print()

    all_loss["train_loss"].append(train_loss)
    all_loss["val_loss"].append(val_loss)

# TODO: plot `all_loss` (the `plot_dict` helper previously referenced here is
# not defined in this notebook).

Save the trained model.

In [None]:
# Save the trained model: to Google Drive in Colab, otherwise to ../checkpoints
# (the local directory the evaluation/export cells load from).
import os

checkpoint_root = "/content/drive/MyDrive/arat5-v0" if IN_COLAB else "../checkpoints"
# Despite its name, this is the path of a single checkpoint file.
save_directory = os.path.join(checkpoint_root, f"mt5_{target_embedding}_checkpoint_0")
os.makedirs(checkpoint_root, exist_ok=True)
model.save(save_directory)

## Evaluation

In [9]:
# Re-load a trained model: from Google Drive in Colab, otherwise from ../checkpoints.
# NOTE(review): the training cell saves `_checkpoint_0`, but this loads
# `_checkpoint_1`; confirm which checkpoint is intended.
checkpoint_root = "/content/drive/MyDrive/arat5-v0" if IN_COLAB else "../checkpoints"
# eval() returns the module itself; inference must not depend on the mode the
# model happened to be saved in.
model = RevDictModel.load(f"{checkpoint_root}/mt5_{target_embedding}_checkpoint_1").eval()

In [None]:
# Collect per-example predicted and target embeddings for the validation set.
model.eval()  # inference mode, regardless of how the model was left
predictions = []
targets = []
with torch.no_grad():
    for sample in tqdm(valid_dataloader):
        inputs = {
            key: sample[key].to(device)
            for key in ("input_ids", "attention_mask")
        }
        outputs = model(**inputs)
        # Extending by a 2-D tensor appends one row (one example) at a time.
        predictions += outputs
        targets += sample[target_embedding]

### Cosine Similarity

Here we compute the cosine similarity between each predicted embedding and its target embedding.

In [None]:
# Cosine similarity between each prediction and its target, computed in one
# batched call (one row per example) instead of one call and GPU sync per pair.
if predictions:
    scores = F.cosine_similarity(
        torch.stack(predictions).to(device),
        torch.stack(targets).to(device),
        dim=1,
    ).tolist()
else:
    scores = []

In [None]:
# Average cosine similarity over the validation set.
total_score, n_scores = sum(scores), len(scores)
mean = total_score / n_scores
mean

## Export the development-set results for submission

In [13]:
# Load the three per-target models for the submission export: from Google Drive
# in Colab, otherwise from ../checkpoints. eval() returns the module itself and
# guarantees inference mode.
checkpoint_root = "/content/drive/MyDrive/arat5-v0" if IN_COLAB else "../checkpoints"
model_electra = RevDictModel.load(f"{checkpoint_root}/mt5_electra_checkpoint_0").eval()
model_bertseg = RevDictModel.load(f"{checkpoint_root}/mt5_bertseg_checkpoint_0").eval()
model_bertmsa = RevDictModel.load(f"{checkpoint_root}/mt5_bertmsa_checkpoint_0").eval()

In [22]:
# Export predicted embeddings to JSON, one record per example:
# {
#     "id": "word id",
#     "electra": [1, 2, 3, ...],
#     "bertseg": [1, 2, 3, ...],
#     "bertmsa": [1, 2, 3, ...]
# }

import json

predictions = []
with torch.no_grad():
    for sample in tqdm(valid_dataloader):
        # Inputs must be on the same device as the models, which
        # RevDictModel.load maps to `device`.
        inputs = {
            key: sample[key].to(device)
            for key in ("input_ids", "attention_mask")
        }
        outputs_electra = model_electra(**inputs)
        outputs_bertseg = model_bertseg(**inputs)
        outputs_bertmsa = model_bertmsa(**inputs)
        # zip over the batch rows instead of a hard-coded range(4), so any
        # batch size, including a short final batch, is handled.
        for example_id, electra, bertseg, bertmsa in zip(
            sample["id"], outputs_electra, outputs_bertseg, outputs_bertmsa
        ):
            predictions.append(
                {
                    "id": example_id,
                    "electra": electra.tolist(),
                    "bertseg": bertseg.tolist(),
                    "bertmsa": bertmsa.tolist(),
                }
            )

100%|██████████| 981/981 [14:12<00:00,  1.15it/s]


In [23]:
# Write the per-example dev predictions to disk for submission.
submission_path = "dev_submission_ex.json"
with open(submission_path, "w") as submission_file:
    json.dump(predictions, submission_file)

In [411]:
# Export dev-set predictions to dev_submission.json with one record per example.
# The previous version appended one record per *batch*: it stored the whole list
# of batch ids under "id" and only the first example's embeddings.
import json

predictions = []
with torch.no_grad():
    for sample in tqdm(valid_dataloader):
        # Inputs must be on the same device as the models.
        inputs = {
            key: sample[key].to(device)
            for key in ("input_ids", "attention_mask")
        }
        outputs_electra = model_electra(**inputs)
        outputs_bertseg = model_bertseg(**inputs)
        outputs_bertmsa = model_bertmsa(**inputs)
        for example_id, electra, bertseg, bertmsa in zip(
            sample["id"], outputs_electra, outputs_bertseg, outputs_bertmsa
        ):
            predictions.append(
                {
                    "id": example_id,
                    "electra": electra.tolist(),
                    "bertseg": bertseg.tolist(),
                    "bertmsa": bertmsa.tolist(),
                }
            )

with open("dev_submission.json", "w") as f:
    json.dump(predictions, f)

100%|██████████| 981/981 [11:58<00:00,  1.37it/s]


# Using the model for reverse-dictionary lookup

In [None]:
# Training words and their target embeddings form the search space for the
# reverse-dictionary lookup below.
train_split = dataset["train"]
words = train_split["word"]
targets = train_split[target_embedding]

In [None]:
from torch.nn import functional as F
from typing import List

model.eval()

query = "منسوب إلى بيع،صفقة يتم بموجبها تبادل الشيء بالشيء أو بما يساوي قيمته"

tokenized_query = tokenizer(
    query,
    max_length=max_input_length,
    padding=padding,
    truncation=True,
    return_tensors="pt",
)
inputs = {
    key: tokenized_query[key].to(device)
    for key in ("input_ids", "attention_mask")
}

with torch.no_grad():
    output = model(**inputs)  # one row: the query's predicted embedding
    # Compare the query against every target embedding at once. The previous
    # loop used targets[i][0], a single scalar, so each score only reflected
    # that scalar's sign and the ranking was meaningless.
    similarities = F.cosine_similarity(output, targets.to(device), dim=1)

# Map word -> cosine similarity. A word that occurs several times keeps the
# score of its last occurrence, as the original loop did.
scores = dict(zip(words, similarities.tolist()))

In [None]:
# The 10 training words whose embeddings are most similar to the query.
[word for word, _ in sorted(scores.items(), key=lambda item: item[1], reverse=True)[:10]]