In [None]:
!pip uninstall fsspec -qq -y
!pip install --no-index --find-links=../input/hf-datasets/wheels datasets -qq
!pip uninstall transformers -qq -y
!pip install --no-index --find-links=../input/transformers-latest-model transformers -qq

In [None]:
from datasets import Dataset
import pandas as pd
from tqdm import tqdm

import torch, gc
from transformers import default_data_collator
from hf_qa_utils import *
from passage_qa_utils import *

# Nothing in this notebook trains a model, so gradient tracking is switched off
# globally for every forward pass that follows.
torch.set_grad_enabled(False)

In [None]:
# Competition test set; later cells read its "id", "question", "context" and
# "language" columns.
test_csv_path = "../input/chaii-hindi-and-tamil-question-answering/test.csv"
test_df = pd.read_csv(test_csv_path)

# Get features from RemBERT

In [None]:
# Stage 1: run the passage checkpoint over the full contexts and keep, for every
# question, the feature (tokenised window) whose best candidate scores highest.
passage_ckpts = ["../input/rembert-finetuning-indicx-on-sq2-epoch3/rembert_indicx_over_squad2/checkpoint-1021"]

# NOTE(review): 384 / 128 are presumably max sequence length and stride — confirm
# against generate_output's signature in the helper module.
_, candidates, ds_feats = generate_output(
    passage_ckpts, test_df, 384, 128,
    batch_size=256, return_feats=True, pp_cleanup=False,
)

# For each question id pick the candidate with the largest start score.
best_candidates = []
for qid, cands in candidates.items():
    top = max(cands, key=lambda cand: cand["start_score"])
    top["id"] = qid  # tags the winning candidate dict in place with its question id
    best_candidates.append(top)

cois = ["id", "feature_index"]
passage_df = pd.DataFrame(best_candidates)[cois].drop_duplicates()

# Attach the per-feature character offsets, then the question metadata.
# NOTE(review): get_char_offsets_df is expected to provide "min_ix"/"max_ix" keyed by
# the columns in `cois` — verify in the helper module.
codf_main = get_char_offsets_df(ds_feats)
passage_df = passage_df.merge(codf_main, on=cois, how="left")
del codf_main
passage_df = passage_df.merge(
    test_df[["id", "question", "context", "language"]], how="left", on="id"
)


def _window_text(row):
    """Return the part of the row's context between its min_ix and max_ix offsets."""
    return row["context"][row["min_ix"]:row["max_ix"]]


passage_df["subc"] = passage_df.apply(_window_text, axis=1)

# Release the large intermediates before the QA models are loaded.
del ds_feats, candidates
torch.cuda.empty_cache()
gc.collect()

# Generate QA predictions from sub-contexts

In [None]:
# Keep only the columns the QA stage consumes and present the selected sub-context
# under the column name "context".
passage_df = (
    passage_df[["id", "question", "subc", "language"]]
    .reset_index(drop=True)
    .rename(columns={"subc": "context"})
)

In [None]:
# Stage 2: answer extraction on the sub-contexts, with a separate checkpoint list
# per language.
# NOTE(review): the Hindi list holds fold1, fold2 and fold4 only; fold3 is absent even
# though its dataset is named "folds-3-and-4" — confirm this is intentional.
hin_ckpts = [
    "../input/muril-finetuning-indicx-on-squad2-epoch-2/muril_lg_indix_sq2ep1/checkpoint-3480",
    "../input/folds-1-and-2-muril-indicx-sq2ep2-finetuning/muril_lg_indix_sq2ep1_fold1/checkpoint-3463",
    "../input/folds-1-and-2-muril-indicx-sq2ep2-finetuning/muril_lg_indix_sq2ep1_fold2/checkpoint-3452",
    "../input/folds-3-and-4-muril-indicx-sq2ep2-finetuning/muril_lg_indix_sq2ep1_fold4/checkpoint-3472",
]
tam_ckpts = [f"../input/folds-consolidated-xlmr-qa-finetune-on-indix/fold{i}" for i in range(5)]

hindi_passages = passage_df[passage_df["language"] == "hindi"]
tamil_passages = passage_df[passage_df["language"] == "tamil"]

# `all_preds` starts out holding the Hindi predictions only; the Tamil ones are
# kept separately in `tam_preds`.
all_preds, hin_candidates = generate_output(
    hin_ckpts, hindi_passages, 384, 128,
    batch_size=128, return_feats=False, pp_cleanup=True,
)
tam_preds, tam_candidates = generate_output(
    tam_ckpts, tamil_passages, 384, 128,
    batch_size=128, return_feats=False, pp_cleanup=True,
)

In [None]:
# Combine both languages into one id -> answer mapping and attach it to the test set.
all_preds.update(tam_preds)
answers = pd.Series(all_preds).reset_index()
answers.columns = ["id", "PredictionString"]

# Drop any PredictionString left over from a previous run of this cell so that the
# merge never produces PredictionString_x / PredictionString_y.
test_df = pd.merge(
    test_df.drop(columns=["PredictionString"], errors="ignore"),
    answers,
    on='id',
    how='left',
)

# Questions without a prediction get an empty answer. Assign the result back:
# a chained `df[col].fillna(..., inplace=True)` acts on a temporary object and is
# a no-op under pandas Copy-on-Write.
test_df['PredictionString'] = test_df['PredictionString'].fillna('')

# Postprocessing

In [None]:
import re
# Four consecutive digits anywhere in the text (a longer digit run can also yield a
# match). Raw string so that "\d" reaches the regex engine instead of being treated
# as an invalid string escape.
year_ptrn = re.compile(r"\d{4}")

# Tamil and Hindi era / year markers (presumably BCE/CE abbreviations and words for
# "year"); a prediction containing one of them is left untouched.
time_prefixes = ["கி.மு", "கி.பி", " ई", "ई.पू", "वर्ष", "सन"]


def update_year_answer(pred_ans):
    """Reduce a predicted answer to its bare four-digit year when that is unambiguous.

    Args:
        pred_ans: Predicted answer text. Non-string values (e.g. NaN or None from a
            missing prediction) are returned unchanged instead of raising.

    Returns:
        The single four-digit match if ``pred_ans`` contains exactly one and none of
        ``time_prefixes``; otherwise ``pred_ans`` as given.
    """
    if not isinstance(pred_ans, str):
        return pred_ans
    # An explicit era/year marker carries meaning, so keep the whole answer.
    if any(tp in pred_ans for tp in time_prefixes):
        return pred_ans
    ypreds = year_ptrn.findall(pred_ans)
    # Zero or several candidates: no safe way to choose, keep the prediction.
    if len(ypreds) != 1:
        return pred_ans
    return ypreds[0]

# Question fragments (Tamil and Hindi) that this step treats as asking for a year.
years = ["எந்த ஆண்ட", "किस वर्ष", "किस साल"]

# na=False keeps the mask strictly boolean even if a question is missing.
is_ans_year = test_df["question"].str.contains("|".join(years), regex=True, na=False)
if is_ans_year.any():
    # Fill missing predictions first so update_year_answer only ever sees strings.
    test_df.loc[is_ans_year, "PredictionString"] = (
        test_df.loc[is_ans_year, "PredictionString"].fillna("").apply(update_year_answer)
    )

# Assign the result back: a chained `df[col].fillna(..., inplace=True)` acts on a
# temporary object and is a no-op under pandas Copy-on-Write.
test_df['PredictionString'] = test_df['PredictionString'].fillna('')

In [None]:
import unicodedata
# Code points 2406..2415 (U+0966..U+096F) are the Devanagari digits; `enn` holds the
# ASCII digit at the same position.
hin = [chr(code_point) for code_point in range(2406, 2416)]
enn = [str(digit) for digit in range(10)]

hin_digit_set = set(hin)
hin_to_enn = str.maketrans("".join(hin), "".join(enn))

# Predictions made up solely of Devanagari digits.
is_pred_hin = test_df["PredictionString"].apply(lambda pred: set(pred) <= hin_digit_set)

# "trans" mirrors the prediction, with those all-digit answers rewritten in ASCII digits.
test_df["trans"] = test_df["PredictionString"].copy()
test_df.loc[is_pred_hin, "trans"] = test_df.loc[is_pred_hin, "trans"].apply(
    lambda txt: txt.translate(hin_to_enn)
)

# Use the ASCII form only where it literally occurs in the question's context.
is_trans_in_context = test_df.apply(lambda row: row["trans"] in row["context"], axis=1)
use_trans = is_pred_hin & is_trans_in_context
test_df.loc[use_trans, "PredictionString"] = test_df.loc[use_trans, "trans"]

In [None]:
# Cut every prediction at its first "[" (presumably to drop bracketed reference
# markers such as "[1]"); predictions without a "[" are left as they are.
test_df["PredictionString"] = test_df["PredictionString"].apply(
    lambda pred: pred.partition("[")[0]
)

In [None]:
# Final safety net: the submission must not contain missing answers. Assign the
# result back: a chained `df[col].fillna(..., inplace=True)` acts on a temporary
# object and is a no-op under pandas Copy-on-Write.
test_df['PredictionString'] = test_df['PredictionString'].fillna('')
test_df[['id', 'PredictionString']].to_csv('submission.csv', index=False)