# Hindi and Tamil Question Answering with Lightning Flash

## Install the required dependencies

In [None]:
! pip install -q torch==1.10.0+cu111 torchvision==0.11.1+cu111 torchaudio==0.10.0+cu111 -f https://download.pytorch.org/whl/cu111/torch_stable.html
! pip install -qU torchtext wandb
! pip install -q pytorch-lightning==1.4.9
! pip install -q "git+https://github.com/PyTorchLightning/lightning-flash.git#egg=lightning-flash[text]"

In [None]:
! nvidia-smi
! mkdir /kaggle/temp

%matplotlib inline
%reload_ext autoreload
%autoreload 2

In [None]:
import os
from dataclasses import asdict, dataclass
from typing import Tuple

import numpy as np
import pandas as pd

import torch
import pytorch_lightning as pl
import flash
import wandb

from flash import Trainer
from flash.text import QuestionAnsweringData, QuestionAnsweringTask

In [None]:
# Setup and login into wandb.
# The API key comes from Kaggle's secret store. It is handed to wandb's Python API
# rather than interpolated into a shell command, where it would show up in the
# process's command line.
from kaggle_secrets import UserSecretsClient

user_secrets = UserSecretsClient()
WANDB_API_KEY = user_secrets.get_secret("WANDB_API_KEY")
wandb.login(key=WANDB_API_KEY)

In [None]:
@dataclass
class HyperParams:
    """Hyper-parameters for the chaii question-answering fine-tuning run.

    Every setting is declared as an annotated dataclass field. Fields appear in
    ``asdict(...)``, which is used as the experiment-logger config, and can be
    overridden via the constructor. Un-annotated class attributes would be
    silently excluded from both.
    """

    seed: int = 42

    # Fraction of each language's examples held out for validation.
    train_val_split: float = 0.1
    batch_size: int = 2

    # Backbone name shared by the datamodule and the task.
    backbone: str = "xlm-roberta-base"

    ## Optimizer and Scheduler specific
    optimizer: str = "adamw"
    learning_rate: float = 1e-5

    # Training/Finetuning args
    # Smoke-test switch; it only has an effect where the trainer is configured with it.
    debug: bool = False
    # The default is evaluated once, when this class is defined.
    num_gpus: int = torch.cuda.device_count()
    accumulate_grad_batches: int = 2
    max_epochs: int = 5
    # (strategy name, int) passed to ``Trainer.finetune``; the int is presumably the
    # unfreeze epoch for "freeze_unfreeze" — confirm against the Flash finetuning docs.
    finetuning_strategy: Tuple[str, int] = ("freeze_unfreeze", 2)

# Instantiate the default configuration once, then seed Lightning's global RNGs from it.
HYPER_PARAMS = HyperParams()
pl.seed_everything(seed=HYPER_PARAMS.seed)

## 1. Create the DataModule

In [None]:
# Read-only competition inputs, and a writable location for the derived split CSVs.
INPUT_DIR = "/kaggle/input/chaii-hindi-and-tamil-question-answering"
TEMP_PATH = "./"

INPUT_DATA_PATH, PREDICT_DATA_PATH = (
    os.path.join(INPUT_DIR, file_name) for file_name in ("train.csv", "test.csv")
)
TRAIN_DATA_PATH, VAL_DATA_PATH = (
    os.path.join(TEMP_PATH, file_name) for file_name in ("_train.csv", "_val.csv")
)

# Peek at the first rows of the labelled data.
df = pd.read_csv(INPUT_DATA_PATH)
display(df.head())

In [None]:
# Split into train and val files up front, since preprocessing differs between the two
# datasets. Each language is split on its own with the same fraction, so both splits
# keep a similar Tamil/Hindi mix.
SPLIT_RANDOM_STATE = 200  # seed used for the original split; keeping it keeps the split unchanged
fraction = 1 - HYPER_PARAMS.train_val_split


def split_language(frame: pd.DataFrame, language: str, train_fraction: float, random_state: int):
    """Return ``(train, val)`` rows of ``frame`` whose ``language`` column equals ``language``.

    ``train`` is a random ``train_fraction`` sample of those rows. ``val`` is the rest
    (dropped by index, so ``frame`` is assumed to have a unique index).
    """
    examples = frame[frame["language"] == language]
    train = examples.sample(frac=train_fraction, random_state=random_state)
    val = examples.drop(train.index)
    return train, val


train_parts, val_parts = [], []
for language in ("tamil", "hindi"):
    train_part, val_part = split_language(df, language, fraction, SPLIT_RANDOM_STATE)
    train_parts.append(train_part)
    val_parts.append(val_part)

train_split = pd.concat(train_parts).reset_index(drop=True)
val_split = pd.concat(val_parts).reset_index(drop=True)

# Every row must land in exactly one split; rows of any other language would otherwise vanish silently.
assert len(train_split) + len(val_split) == len(df), "rows outside the tamil/hindi languages were dropped"

train_split.to_csv(TRAIN_DATA_PATH, index=False)
val_split.to_csv(VAL_DATA_PATH, index=False)

In [None]:
# Build the datamodule from the split CSVs. It uses the same backbone name as the task
# (presumably so the tokenizer matches the model).
datamodule = QuestionAnsweringData.from_csv(
    backbone=HYPER_PARAMS.backbone,
    batch_size=HYPER_PARAMS.batch_size,
    train_file=TRAIN_DATA_PATH,
    val_file=VAL_DATA_PATH,
)

## 2. Build the task

In [None]:
# Question-answering task on the configured backbone, with the optimizer and learning rate from HYPER_PARAMS.
model = QuestionAnsweringTask(
    optimizer=HYPER_PARAMS.optimizer,
    learning_rate=HYPER_PARAMS.learning_rate,
    backbone=HYPER_PARAMS.backbone,
)

## 3. Create the trainer and finetune the model

In [None]:
# Keep only the single best checkpoint, ranked by the logged 'rouge2_fmeasure' metric
# (higher is better; presumably the validation ROUGE-2 F-measure — confirm).
checkpoint_callback = pl.callbacks.ModelCheckpoint(
    monitor='rouge2_fmeasure',
    mode='max',
    save_top_k=1,
    filename='checkpoint/{epoch:02d}-{rouge2_fmeasure:.4f}',
)
callbacks = [checkpoint_callback]

# Weights & Biases logger; the hyper-parameter dataclass is recorded as the run config.
wandb_logger = pl.loggers.WandbLogger(
    log_model=False,
    job_type='finetune',
    group='XLM Roberta',
    config=asdict(HYPER_PARAMS),
    project='chaii-competition',
)

In [None]:
# Flash trainer. `debug` maps to Lightning's fast_dev_run (a quick smoke run over a
# single batch); with the default False, training runs normally.
trainer = Trainer(
    logger=wandb_logger,
    callbacks=callbacks,
    gpus=HYPER_PARAMS.num_gpus,
    max_epochs=HYPER_PARAMS.max_epochs,
    accumulate_grad_batches=HYPER_PARAMS.accumulate_grad_batches,
    fast_dev_run=HYPER_PARAMS.debug,
)

try:
    # Ask wandb to watch the model (gradient logging by default).
    wandb_logger.watch(model)
    # Pass the datamodule by keyword: finetune's second positional parameter is a train dataloader.
    trainer.finetune(model, datamodule=datamodule, strategy=HYPER_PARAMS.finetuning_strategy)
finally:
    # Close the wandb run even if training raises, so a later run in this kernel starts clean.
    wandb.finish()

## 4. Predict on the test set and write the submission

In [None]:
# Convert the prediction queries to a column-oriented dict ({column: [values, ...]}).
# Columns are selected by name rather than by position, so a reordered or extended
# test.csv still yields the fields the task needs.
PREDICT_COLUMNS = ["id", "context", "question"]
predict_df = pd.read_csv(PREDICT_DATA_PATH)
predict_data = predict_df[PREDICT_COLUMNS].to_dict(orient="list")

# Answer some Questions!
predictions = model.predict(predict_data)
display(predictions)

In [None]:
# Create submission: flatten each prediction mapping (presumably {example id: answer text}
# — verify against model.predict's output) into one row per example.
ids, answers = [], []
for batch_answers in predictions:
    for example_id, answer_text in batch_answers.items():
        ids.append(example_id)
        answers.append(answer_text)

submission = pd.DataFrame({"id": ids, "PredictionString": answers})
submission.to_csv("./submission.csv", index=False)