# Hindi and Tamil Question Answering with Lightning Flash

In [None]:
# ! pip install -q pip install torch==1.9.1+cu111 torchvision==0.10.1+cu111 torchaudio==0.9.1 -f https://download.pytorch.org/whl/torch_stable.html
! pip install -q torch==1.10.0+cu111 torchvision==0.11.1+cu111 torchaudio==0.10.0+cu111 -f https://download.pytorch.org/whl/cu111/torch_stable.html
! pip install -qU torchtext 
! pip install -q pytorch-lightning==1.4.9
! pip install -q "git+https://github.com/PyTorchLightning/lightning-flash.git#egg=lightning-flash[text]"
! pip install -qU torchmetrics>=0.5.0 rich wandb

In [None]:
! mkdir /kaggle/temp
! pip list | grep torch
! pip list | grep tokenizers
! pip list | grep transformers
! pip list | grep datasets

%matplotlib inline
%reload_ext autoreload
%autoreload 2

In [None]:
# Show the GPU(s) visible to this session.
! nvidia-smi
# Toggle for torch-ort; HyperParams.enable_ort takes its default from this flag.
ENABLE_ORT = False
if ENABLE_ORT:
    import torch
    # torch-ort is only installed and configured when a CUDA device is available.
    if torch.cuda.is_available():
        ! pip install -q torch-ort -f https://download.onnxruntime.ai/onnxruntime_stable_cu111.html
        ! python -m torch_ort.configure

In [None]:
import gc
gc.enable()
import os
import numpy as np
import pandas as pd
import torch
import pytorch_lightning as pl

import wandb

from dataclasses import asdict, dataclass
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Type, Union
from torch.optim.lr_scheduler import (
    _LRScheduler, StepLR, MultiStepLR, ReduceLROnPlateau, CosineAnnealingWarmRestarts
)

from transformers import (
    AdamW,
    get_cosine_schedule_with_warmup,
    get_linear_schedule_with_warmup,
    get_cosine_with_hard_restarts_schedule_with_warmup
)
from transformers.trainer_pt_utils import get_parameter_names

import flash
from flash import Trainer
from flash.core.optimizers import LinearWarmupCosineAnnealingLR
from flash.core.finetuning import NoFreeze
from flash.text import QuestionAnsweringData, QuestionAnsweringTask
from flash.text.question_answering.finetuning import QuestionAnsweringFreezeEmbeddings

In [None]:
# Setup and login into wandb
from kaggle_secrets import UserSecretsClient

# The API key is read from the Kaggle secret store, never hard-coded.
user_secrets = UserSecretsClient()
WANDB_API_KEY = user_secrets.get_secret("WANDB_API_KEY")

# Authenticate through the Python API rather than `! wandb login $WANDB_API_KEY`,
# which would place the secret on a shell command line.
wandb.login(key=WANDB_API_KEY)

In [None]:
# Root locations: competition input directory and the directory for generated files.
INPUT_DIR = "/kaggle/input/chaii-hindi-and-tamil-question-answering"
TEMP_PATH = "./"

# Files read from the competition input directory.
INPUT_DATA_PATH = os.path.join(INPUT_DIR, "train.csv")
PREDICT_DATA_PATH = os.path.join(INPUT_DIR, "test.csv")

# Files written by this notebook (train / validation splits).
TRAIN_DATA_PATH = os.path.join(TEMP_PATH, "_train.csv")
VAL_DATA_PATH = os.path.join(TEMP_PATH, "_val.csv")

# Local import: `field` is required for the mutable dict default below.
from dataclasses import field


@dataclass
class HyperParams:
    """Every tunable setting of the notebook, gathered in one place.

    Each attribute carries a type annotation so that it is a real dataclass field:
    un-annotated class attributes are ignored by ``dataclasses.asdict``, which would
    silently drop them from the config dict sent to the experiment logger.
    """

    seed: int = 42

    # dataset specific
    train_val_split: float = 0.1
    batch_size: int = 4

    # model specific
    backbone: str = "xlm-roberta-base"
    pretrained: bool = True
    # Intentionally left un-annotated (plain class attribute, not a dataclass field)
    # so the torch.device object is kept out of ``asdict`` output.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    ## Optimizer and Scheduler specific
    # lr=2e-7,
    # lr=5.531681197617226e-05, ## Adam
    # lr=0.0001096478196143185, ## AdamW
    optimizer: str = 'AdamW'
    learning_rate: float = 3e-5
    weight_decay: float = 1e-2
    epsilon: float = 1e-8
    max_grad_norm: float = 1.0
    lr_scheduler: str = "cosine_schedule_with_warmup"
    num_warmup_steps: float = 0.1
    # A dict default must come from a factory; dataclasses reject mutable defaults.
    lr_scheduler_config: Dict[str, Any] = field(
        default_factory=lambda: {
            "interval": "step",
            "frequency": 1,
        }
    )

    # Training/Finetuning args
    debug: bool = False
    # Evaluated once, when the class body is executed.
    num_gpus: int = torch.cuda.device_count()
    accumulate_grad_batches: int = 2
    enable_ort: bool = ENABLE_ORT
    max_epochs: int = 10
    finetuning_strategy: str = "no_freeze"
    stochastic_weight_avg: bool = False


HYPER_PARAMS = HyperParams()
# Seed the random number generators through Lightning for reproducibility.
pl.seed_everything(HYPER_PARAMS.seed)

In [None]:
# Read the full training CSV, then preview only its first rows.
df = pd.read_csv(filepath_or_buffer=INPUT_DATA_PATH)
display(df.head(n=5))

## 1. Create the DataModule

In [None]:
# Share of every language subset that goes to the training split.
fraction = 1 - HYPER_PARAMS.train_val_split

# Splitting data into train and val beforehand since preprocessing will be different for datasets.
# Each language is sampled separately with the same fraction and seed, so both
# splits keep the language proportions. Rows whose `language` is neither "tamil"
# nor "hindi" end up in neither split.
train_parts = []
val_parts = []
for language in ("tamil", "hindi"):
    language_rows = df[df["language"] == language]
    sampled_rows = language_rows.sample(frac=fraction, random_state=200)
    train_parts.append(sampled_rows)
    # Whatever was not sampled for training becomes validation data.
    val_parts.append(language_rows.drop(sampled_rows.index))

train_split = pd.concat(train_parts).reset_index(drop=True)
val_split = pd.concat(val_parts).reset_index(drop=True)

# Persist both splits as CSV files for the datamodule to read.
for split_frame, split_path in ((train_split, TRAIN_DATA_PATH), (val_split, VAL_DATA_PATH)):
    split_frame.to_csv(split_path, index=False)

In [None]:
# 1. Create the DataModule
# Built from the train / validation CSV paths defined in the config cell
# (TRAIN_DATA_PATH, VAL_DATA_PATH).
datamodule = QuestionAnsweringData.from_csv(
    train_file=TRAIN_DATA_PATH,
    val_file=VAL_DATA_PATH,
    batch_size=HYPER_PARAMS.batch_size,
    # NOTE(review): presumably selects the tokenizer matching the model backbone — confirm.
    backbone=HYPER_PARAMS.backbone
)

## 2. Build the task

In [None]:
class ChaiiQuestionAnswering(QuestionAnsweringTask):
    """Question-answering task scored with a word-level Jaccard metric.

    Extends ``QuestionAnsweringTask`` with a ``jaccard_score`` metric and a custom
    AdamW optimizer configuration.
    """

    def __init__(
        self,
        backbone: str = "distilbert-base-uncased",
        optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam,
        lr_scheduler: Optional[Union[Type[_LRScheduler], str, _LRScheduler]] = None,
        metrics: Union[Callable, Mapping, Sequence, None] = None,
        learning_rate: float = 5e-5,
        enable_ort: bool = False,
    ):
        """Forward every argument unchanged to ``QuestionAnsweringTask``.

        Note that ``configure_optimizers`` below always builds an ``AdamW``
        optimizer itself, regardless of the ``optimizer`` argument.
        """
        super().__init__(
            backbone=backbone,
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            metrics=metrics,
            learning_rate=learning_rate,
            enable_ort=enable_ort,
        )

    @staticmethod
    def jaccard(str1, str2):
        """Return the Jaccard similarity of the lower-cased, whitespace-split word sets.

        Args:
            str1: First string.
            str2: Second string.

        Returns:
            ``|A & B| / |A | B|`` as a float. When both strings contain no words the
            union is empty; the two (empty) answers are identical, so 1.0 is returned
            instead of raising ``ZeroDivisionError``.
        """
        a = set(str1.lower().split())
        b = set(str2.lower().split())
        union_size = len(a | b)
        if union_size == 0:
            return 1.0
        return float(len(a & b)) / union_size

    def compute_metrics(self, generated_tokens, batch):
        """Compute the mean Jaccard score of the predictions for one batch.

        Args:
            generated_tokens: Mapping indexed by ``example_id`` whose values are the
                predicted answer strings.
            batch: Iterable of examples, each providing ``example["example_id"]`` and
                ``example["answer"]["text"]`` (a sequence of reference answers).

        Returns:
            ``{"jaccard_score": tensor}`` holding the mean score over the batch.
        """
        scores = []
        for example in batch:
            predicted_answer = generated_tokens[example["example_id"]]
            # Only the first reference answer is used; no reference means empty target.
            target_answer = example["answer"]["text"][0] if len(example["answer"]["text"]) > 0 else ""
            scores.append(ChaiiQuestionAnswering.jaccard(predicted_answer, target_answer))

        result = {"jaccard_score": torch.mean(torch.tensor(scores))}
        return result

    def _configure_layerwise_optimizers(self):
        """Alternative optimizer setup: four parameter groups with growing learning rates.

        This was previously a first ``configure_optimizers`` definition that a later
        definition of the same name silently replaced, so it never ran. It is kept
        under its own name and is NOT called by the trainer.

        NOTE(review): the slice bounds assume 199 named parameters, presumably the
        layout of the default xlm-roberta-base backbone — verify before using it with
        another backbone.
        NOTE(review): the wandb run name in this notebook mentions "4grpLRD"; confirm
        whether this variant was meant to be the active one.
        """
        # To be passed to the optimizer (only parameters of the layers you want to update).
        opt_parameters = []
        weight_decay = HYPER_PARAMS.weight_decay
        no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
        named_parameters = list(self.model.named_parameters())

        def get_lrs(start_lr, last_lr):
            # Four learning rates from start_lr to last_lr, evenly spaced in log space.
            step = np.abs(np.log(start_lr) - np.log(last_lr)) / 3
            return [start_lr, np.exp(np.log(start_lr) + step), np.exp(np.log(start_lr) + 2 * step), last_lr]

        def is_decay_param(name: str):
            return not any(p in name for p in no_decay)

        ranges = [(0, 69), (69, 133), (133, 197), (197, 199)]
        lrs = get_lrs(1e-8, self.learning_rate)

        for _range, lr in zip(ranges, lrs):
            params = named_parameters[_range[0]:_range[1]]
            decay_parameters = [p for n, p in params if is_decay_param(n)]
            no_decay_parameters = [p for n, p in params if not is_decay_param(n)]

            opt_parameters.append(
                {"params": decay_parameters, "weight_decay": weight_decay, "lr": lr}
            )

            # According to AAAMLP book by A. Thakur, we generally do not use any decay
            # for bias and LayerNorm.weight layers.
            opt_parameters.append(
                {"params": no_decay_parameters, "weight_decay": 0.0, "lr": lr}
            )

        optimizer = AdamW(opt_parameters, lr=self.learning_rate, correct_bias=True)
        if self.lr_scheduler is not None:
            return [optimizer], [self._instantiate_lr_scheduler(optimizer)]
        return optimizer

    def configure_optimizers(self):
        """Build an AdamW optimizer over the ``qa_outputs`` parameters only.

        Parameters whose name does not contain ``"qa_outputs"`` are not handed to the
        optimizer and therefore are not updated by it. Bias parameters get no weight
        decay; the others use ``HYPER_PARAMS.weight_decay``.

        Returns:
            The optimizer alone, or ``([optimizer], [scheduler])`` when
            ``self.lr_scheduler`` is set.
        """
        weight_decay = HYPER_PARAMS.weight_decay
        # To be passed to the optimizer (only parameters of the layers you want to update).
        opt_parameters = []
        for n, p in self.model.named_parameters():
            if "qa_outputs" not in n:
                continue
            opt_parameters.append(
                {
                    "params": p,
                    "weight_decay": 0.0 if "bias" in n else weight_decay,
                    "lr": self.learning_rate,
                }
            )

        optimizer = AdamW(opt_parameters, lr=self.learning_rate, correct_bias=True)
        if self.lr_scheduler is not None:
            return [optimizer], [self._instantiate_lr_scheduler(optimizer)]
        return optimizer

In [None]:
# Instantiate the task with the backbone, learning rate and ORT flag from HYPER_PARAMS.
model = ChaiiQuestionAnswering(
    backbone=HYPER_PARAMS.backbone,
    learning_rate=HYPER_PARAMS.learning_rate,
    # Scheduler given as a 3-tuple: (scheduler name, scheduler kwargs, scheduler config).
    # NOTE(review): num_warmup_steps is 0.1 here, presumably read as a fraction of the
    # total training steps — confirm against the Flash scheduler docs.
    lr_scheduler=(HYPER_PARAMS.lr_scheduler, {"num_warmup_steps": HYPER_PARAMS.num_warmup_steps}, HYPER_PARAMS.lr_scheduler_config),
    enable_ort=HYPER_PARAMS.enable_ort,
)

## 3. Create the trainer and finetune the model

In [None]:
# Record the learning rate at every step.
lr_monitor = pl.callbacks.LearningRateMonitor(logging_interval='step')

# Keep only the single best checkpoint. `val_loss` is a loss, so the best model is the
# one with the LOWEST value: mode must be 'min' ('max' would keep the worst checkpoint).
# This now matches the early-stopping callback below.
checkpoint_callback = pl.callbacks.ModelCheckpoint(
    monitor='val_loss',
    save_top_k=1,
    filename='checkpoint/{epoch:02d}-{val_loss:.4f}',
    mode='min',
)

# NOTE(review): the run name is hard-coded and says "LR=1e-5", while
# HYPER_PARAMS.learning_rate defaults to 3e-5 — update it by hand when settings change.
wandb_logger = pl.loggers.WandbLogger(
    project='chaii-competition',
    group=f"{HYPER_PARAMS.backbone}",
    job_type=f"{HYPER_PARAMS.finetuning_strategy}",
    name="ORT=False_LR=1e-5_NoFreeze_CW=0.1_AG=2_WRate=0.5_4grpLRD",
    log_model=True,
    config=asdict(HYPER_PARAMS),
)

# Stop training after 3 validation checks without a lower `val_loss`.
earlystopping = pl.callbacks.EarlyStopping(monitor='val_loss', patience=3, mode='min')

swa = pl.callbacks.StochasticWeightAveraging()

In [None]:
# When True the notebook runs the learning-rate finder instead of training.
FIND_LR = False

# Callbacks used for every run, plus a variant that also includes SWA.
callbacks = [lr_monitor, earlystopping, checkpoint_callback]
callbacks_with_swa = [*callbacks, swa]

if FIND_LR:
    # Only the LR-finder run adds the embeddings-freezing callback.
    embeddings_freezer = QuestionAnsweringFreezeEmbeddings(model_type=model.model.config.model_type)
    callbacks.append(embeddings_freezer)

# Debug runs use the trainer's default logger (logger=True) instead of wandb.
trainer_logger = True if HYPER_PARAMS.debug else wandb_logger

trainer = Trainer(
    fast_dev_run=HYPER_PARAMS.debug,
    logger=trainer_logger,
    callbacks=callbacks,
    gpus=HYPER_PARAMS.num_gpus,
    log_every_n_steps=10,
    weights_summary='top',
    auto_lr_find=FIND_LR,
    max_epochs=HYPER_PARAMS.max_epochs,
    accumulate_grad_batches=HYPER_PARAMS.accumulate_grad_batches,
    num_sanity_val_steps=0,
    stochastic_weight_avg=HYPER_PARAMS.stochastic_weight_avg,
)

In [None]:
def get_num_training_steps(datamodule, trainer) -> int:
    """Estimate the total number of training steps for a run.

    Args:
        datamodule: Object whose ``train_dataloader()`` result supports ``len()``.
        trainer: Object exposing ``num_gpus``, ``num_processes``,
            ``accumulate_grad_batches`` and ``max_epochs``.

    Returns:
        The dataloader length floor-divided by (gradient-accumulation factor times
        device count), multiplied by the number of epochs.
    """
    batches_per_epoch = len(datamodule.train_dataloader())
    # At least one device is always assumed.
    device_count = max(1, trainer.num_gpus, trainer.num_processes)
    steps_per_epoch = batches_per_epoch // (trainer.accumulate_grad_batches * device_count)
    return steps_per_epoch * trainer.max_epochs

# Step estimate that the learning-rate finder uses as its number of training iterations.
num_training_steps_for_lr_find = get_num_training_steps(datamodule=datamodule, trainer=trainer)
print(f"Num of training steps: {num_training_steps_for_lr_find}")

In [None]:
if not FIND_LR:
    ## Finetune the model ##
    # wandb is only involved outside debug mode.
    track_with_wandb = not HYPER_PARAMS.debug
    if track_with_wandb:
        wandb_logger.watch(model)

    if HYPER_PARAMS.finetuning_strategy == "freeze":
        print("Frozen Model")
        trainer.finetune(model, datamodule=datamodule)
    else:
        print("Unfrozen Model")
        trainer.fit(model, datamodule=datamodule)

    if track_with_wandb:
        wandb.finish()
else:
    ## Find LR ##
    # Run learning rate finder
    lr_finder_options = dict(
        min_lr=1e-8,
        max_lr=1,
        num_training=num_training_steps_for_lr_find,
        mode="linear",
        early_stop_threshold=None,
        update_attr=False,
    )
    lr_finder = trainer.tuner.lr_find(model=model, datamodule=datamodule, **lr_finder_options)

    # Plot the finder's curve together with its suggestion, then report the suggestion.
    fig = lr_finder.plot(suggest=True)
    fig.show()
    new_lr = lr_finder.suggestion()
    print(f"Suggested Learning Rate: {new_lr}")

In [None]:
# Collect unreachable Python objects first, so any CUDA tensors they hold are freed,
# and only then release PyTorch's cached GPU memory. In the opposite order, memory
# freed by the collector would stay in the CUDA cache.
gc.collect()
torch.cuda.empty_cache()

## Prediction

In [None]:
# Convert the prediction queries to dictionary format.
# Only the first three columns of the test CSV are kept
# (presumably id, context and question — verify against the file).
predict_df = pd.read_csv(PREDICT_DATA_PATH)
query_columns = predict_df.columns[:3]
predict_data = predict_df[query_columns].to_dict(orient="list")

# Answer some Questions!
predictions = model.predict(predict_data)
print(predictions)

# Create submission.
# Every prediction is a mapping of example id -> answer string; flatten them all
# into two parallel columns.
submission = pd.DataFrame(
    {
        "id": [example_id for prediction in predictions for example_id in prediction.keys()],
        "PredictionString": [answer for prediction in predictions for answer in prediction.values()],
    }
)
submission.to_csv("./submission.csv", index=False)