In [None]:
# Install `datasets` from local wheels; `--no-index` keeps pip from contacting PyPI.
# fsspec is uninstalled first, presumably so the version shipped in the wheel
# directory is the one that gets installed — TODO confirm.
!pip uninstall fsspec -qq -y
!pip install --no-index --find-links "../input/hf-datasets/wheels" datasets -qq


In [None]:

import collections
from tqdm.auto import tqdm
import numpy as np
import pandas as pd
from datasets import Dataset
from transformers import AutoModelForQuestionAnswering, TrainingArguments, Trainer, AutoTokenizer, default_data_collator

import os
import sys
#os.environ["CUDA_VISIBLE_DEVICES"] = '0'
# Turn off Weights & Biases reporting in the HF Trainer.
os.environ["WANDB_DISABLED"] = 'True'

# TPU alternative (disabled):
#import torch_xla
#import torch_xla.core.xla_model as xm

# Device string used by get_trainer(); hard-coded to CUDA (xm.xla_device() was the TPU option).
device = 'cuda'#xm.xla_device()


# Input locations (Kaggle-style relative paths).
BASE_PATH = "../input/"
TORCH_EXTERNAL_CSV = "../input/external-data-mlqa-xquad-preprocessing/mlqa_hindi.csv"

# Pretrained QA checkpoint to fine-tune. Alternatives that were tried are kept commented out.
#MODEL_NAME = 'salti/bert-base-multilingual-cased-finetuned-squad'
MODEL_NAME = 'AlexKay/xlm-roberta-large-qa-multilingual-finedtuned-ru'
#MODEL_NAME = '../input/xlm-roberta-squad2/deepset/xlm-roberta-base-squad2'
#MODEL_NAME = 'deepset/xlm-roberta-large-squad2'
#MODEL_NAME = 'DHBaek/xlm-roberta-large-korquad-mask'

# Competition data: train (extended with external data in a later cell), test, and the submission template.
df_train_base = pd.read_csv(BASE_PATH + "chaii-hindi-and-tamil-question-answering/train.csv")
df_test = pd.read_csv(BASE_PATH + "chaii-hindi-and-tamil-question-answering/test.csv")
df_sub = pd.read_csv(BASE_PATH + "chaii-hindi-and-tamil-question-answering/sample_submission.csv")

max_length = 384 # Maximum number of tokens in one feature (question + context window).
doc_stride = 128 # Number of overlapping tokens between consecutive windows when a long context is split.
batch_size = 4 # Per-device batch size; get_trainer() derives gradient accumulation from it.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# When the tokenizer pads on the right, the question goes first and the context second
# (see prepare_train_features / prepare_validation_features).
pad_on_right = tokenizer.padding_side == "right"


# In[3]:


# External Hindi QA data (MLQA). It presumably has no id column, so synthesize
# string ids "id0", "id1", ... from the row position.
df_torch = pd.read_csv(TORCH_EXTERNAL_CSV)
df_torch = df_torch.reset_index().rename(columns={'index': 'id'})
df_torch['id'] = 'id' + df_torch['id'].astype(str)

# TyDi QA GoldP, Bengali + Telugu subset, used as-is.
df_tydiqa = pd.read_csv(BASE_PATH + 'tydiqa-goldp/tydiqa-goldp-v1.1-bengali-and-telugu.csv')

# Every extra source must provide the columns the training pipeline reads.
# Otherwise pd.concat silently fills NaN and training fails much later with a
# confusing error inside the feature builder.
_required_cols = {'question', 'context', 'answer_text', 'answer_start'}
for _source_name, _source_df in [('mlqa_hindi', df_torch), ('tydiqa', df_tydiqa)]:
    _missing_cols = _required_cols - set(_source_df.columns)
    assert not _missing_cols, f"{_source_name} is missing columns: {sorted(_missing_cols)}"

# NOTE: this appends to df_train_base, so re-running this cell without re-running
# the cell that reads train.csv duplicates the external rows.
df_train_base = pd.concat([df_train_base, df_torch, df_tydiqa], axis=0, ignore_index=True)

#df_train_base = pd.concat([df_train_base, df_torch], axis=0)


# In[4]:


# Inspect the combined training frame (competition train + MLQA Hindi + TyDi QA).
df_train_base


# In[5]:


def prepare_train_features(examples):
    """Tokenize a batch of QA examples into fixed-length training features.

    Each context is split into overlapping windows of ``max_length`` tokens
    (with ``doc_stride`` tokens of overlap), so one example may produce several
    features. Every feature is labelled with ``start_positions`` /
    ``end_positions``. These are the token indices of the answer span when the
    whole answer lies inside that window; otherwise both point at the CLS token.

    Args:
        examples: batched columns from ``Dataset.map(batched=True)`` with
            ``question``, ``context`` and ``answers`` entries. Each answer is
            ``{'answer_start': [start_char], 'text': [answer_text]}``, as built
            by ``convert_answers``; only the first answer is used.
            ``examples["question"]`` is overwritten with the left-stripped
            questions.

    Returns:
        The tokenizer output (``input_ids``, ``attention_mask``, ...) with
        ``start_positions`` and ``end_positions`` added. The
        ``overflow_to_sample_mapping`` and ``offset_mapping`` keys are removed.
    """
    examples["question"] = [q.lstrip() for q in examples["question"]]

    tokenized_examples = tokenizer(
        examples["question" if pad_on_right else "context"],
        examples["context" if pad_on_right else "question"],
        truncation="only_second" if pad_on_right else "only_first",
        max_length=max_length,
        stride=doc_stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )

    # Feature i was cut from example sample_mapping[i]. offsets[k] is the
    # (start_char, end_char) of token k in the original text.
    sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
    offset_mapping = tokenized_examples.pop("offset_mapping")
    # Sequence id of the context segment (second segment when padding on the right).
    context_seq_id = 1 if pad_on_right else 0

    tokenized_examples["start_positions"] = []
    tokenized_examples["end_positions"] = []

    for i, offsets in enumerate(offset_mapping):
        input_ids = tokenized_examples["input_ids"][i]
        # The CLS position is the "no answer in this window" label.
        cls_index = input_ids.index(tokenizer.cls_token_id)
        sequence_ids = tokenized_examples.sequence_ids(i)
        answers = examples["answers"][sample_mapping[i]]

        if len(answers["answer_start"]) == 0:
            tokenized_examples["start_positions"].append(cls_index)
            tokenized_examples["end_positions"].append(cls_index)
            continue

        start_char = answers["answer_start"][0]
        end_char = start_char + len(answers["text"][0])

        # First and last token of the context inside this window.
        context_start = 0
        while sequence_ids[context_start] != context_seq_id:
            context_start += 1
        context_end = len(input_ids) - 1
        while sequence_ids[context_end] != context_seq_id:
            context_end -= 1

        if not (offsets[context_start][0] <= start_char and offsets[context_end][1] >= end_char):
            # The answer is not fully contained in this window.
            tokenized_examples["start_positions"].append(cls_index)
            tokenized_examples["end_positions"].append(cls_index)
            continue

        # Walk inward to the tokens covering the answer. Both walks stay inside
        # the context span. Special and padding tokens typically get (0, 0)
        # offsets, so an unbounded start walk would run on into the padding when
        # the answer starts in the last context token, and start would end up > end.
        token_start_index = context_start
        while token_start_index <= context_end and offsets[token_start_index][0] <= start_char:
            token_start_index += 1
        tokenized_examples["start_positions"].append(token_start_index - 1)

        token_end_index = context_end
        while token_end_index >= context_start and offsets[token_end_index][1] >= end_char:
            token_end_index -= 1
        tokenized_examples["end_positions"].append(token_end_index + 1)

    return tokenized_examples

def prepare_validation_features(examples):
    """Tokenize a batch of QA examples into features for inference.

    Long contexts are split into overlapping ``max_length`` windows with
    ``doc_stride`` tokens of overlap. Each feature records the ``example_id`` it
    came from. Its ``offset_mapping`` is kept only for context tokens; question
    and special tokens are set to ``None``, so post-processing can map logits
    back to character spans of the context.

    Args:
        examples: batched columns with ``id``, ``question`` and ``context``.
            ``examples["question"]`` is overwritten with left-stripped questions.

    Returns:
        The tokenizer output with ``example_id`` added and
        ``overflow_to_sample_mapping`` removed.
    """
    examples["question"] = [question.lstrip() for question in examples["question"]]
    if pad_on_right:
        first_key, second_key, truncation_mode = "question", "context", "only_second"
    else:
        first_key, second_key, truncation_mode = "context", "question", "only_first"

    encoded = tokenizer(
        examples[first_key],
        examples[second_key],
        truncation=truncation_mode,
        max_length=max_length,
        stride=doc_stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )

    parent_of_feature = encoded.pop("overflow_to_sample_mapping")
    ctx_seq_id = 1 if pad_on_right else 0
    encoded["example_id"] = []

    for feat_idx, parent_idx in enumerate(parent_of_feature):
        seq_ids = encoded.sequence_ids(feat_idx)
        encoded["example_id"].append(examples["id"][parent_idx])
        # Blank out non-context offsets so they can never be chosen as an answer span.
        encoded["offset_mapping"][feat_idx] = [
            span if seq_ids[tok_idx] == ctx_seq_id else None
            for tok_idx, span in enumerate(encoded["offset_mapping"][feat_idx])
        ]

    return encoded


def postprocess_qa_predictions(examples, features, raw_predictions, n_best_size = 20, max_answer_length = 30):
    """Turn per-feature start/end logits into one answer string per example.

    For every feature of an example, the ``n_best_size`` highest start and end
    logits are paired. Pairs are discarded when either token has no context
    offset (``None`` or out of range), when end < start, or when the span is
    longer than ``max_answer_length`` tokens. The best-scoring remaining span
    (start logit + end logit) across all features of the example is sliced out
    of the example's context. An example with no valid span gets ``""``.

    Args:
        examples: dataset of examples; must support ``examples["id"]`` (the id
            column), ``len`` and iteration yielding rows with ``"id"`` and
            ``"context"``.
        features: dataset of features; iteration and indexing yield rows with
            ``"example_id"`` and ``"offset_mapping"`` (context tokens carry
            ``(start_char, end_char)``, all other tokens ``None``).
        raw_predictions: ``(all_start_logits, all_end_logits)``, indexed by
            feature position.
        n_best_size: how many top start/end candidates to combine per feature.
        max_answer_length: maximum span length in tokens.

    Returns:
        OrderedDict mapping example id -> predicted answer text.
    """
    all_start_logits, all_end_logits = raw_predictions

    # Map each example to the positions of the features cut from it.
    example_id_to_index = {k: i for i, k in enumerate(examples["id"])}
    features_per_example = collections.defaultdict(list)
    for i, feature in enumerate(features):
        features_per_example[example_id_to_index[feature["example_id"]]].append(i)

    predictions = collections.OrderedDict()

    print(f"Post-processing {len(examples)} example predictions split into {len(features)} features.")

    for example_index, example in enumerate(tqdm(examples)):
        context = example["context"]
        valid_answers = []

        for feature_index in features_per_example[example_index]:
            start_logits = all_start_logits[feature_index]
            end_logits = all_end_logits[feature_index]
            offset_mapping = features[feature_index]["offset_mapping"]

            # Indices of the n_best_size largest logits, best first.
            start_indexes = np.argsort(start_logits)[-1 : -n_best_size - 1 : -1].tolist()
            end_indexes = np.argsort(end_logits)[-1 : -n_best_size - 1 : -1].tolist()
            for start_index in start_indexes:
                for end_index in end_indexes:
                    # Skip out-of-range indices and tokens outside the context.
                    if (
                        start_index >= len(offset_mapping)
                        or end_index >= len(offset_mapping)
                        or offset_mapping[start_index] is None
                        or offset_mapping[end_index] is None
                    ):
                        continue
                    # Skip reversed spans and spans longer than max_answer_length tokens.
                    if end_index < start_index or end_index - start_index + 1 > max_answer_length:
                        continue

                    start_char = offset_mapping[start_index][0]
                    end_char = offset_mapping[end_index][1]
                    valid_answers.append(
                        {
                            "score": start_logits[start_index] + end_logits[end_index],
                            "text": context[start_char:end_char],
                        }
                    )

        if valid_answers:
            # max() returns the first of equally scored candidates, like a stable descending sort.
            best_answer = max(valid_answers, key=lambda x: x["score"])
        else:
            # No valid span in any feature: fall back to an empty answer instead of failing.
            best_answer = {"text": "", "score": 0.0}

        predictions[example["id"]] = best_answer["text"]

    return predictions


def submit(trainer, df_test):
    """Predict answers for ``df_test`` and write ``submission.csv``.

    Tokenizes the test rows with ``prepare_validation_features`` and runs
    ``trainer.predict`` on the model inputs only. It maps the logits back to
    answer strings with ``postprocess_qa_predictions``, then fills the
    module-level ``df_sub``'s ``PredictionString`` column by ``id``, saves it
    and displays the first rows.

    Args:
        trainer: trained Trainer whose ``predict`` yields (start, end) logits.
        df_test: DataFrame with ``id``, ``question`` and ``context`` columns.

    Raises:
        KeyError: if an id in ``df_sub`` has no prediction.
    """
    test_dataset = Dataset.from_pandas(df_test)

    test_features = test_dataset.map(prepare_validation_features, batched=True, remove_columns=test_dataset.column_names)
    # The model must only see tensor inputs. Drop the bookkeeping columns with
    # remove_columns() rather than an identity map(): Dataset.map documents that
    # columns returned by the function are kept even when listed in remove_columns.
    test_feats_small = test_features.remove_columns(['example_id', 'offset_mapping'])

    test_predictions = trainer.predict(test_feats_small)
    test_features.set_format(type=test_features.format["type"], columns=list(test_features.features.keys()))

    final_test_predictions = postprocess_qa_predictions(test_dataset, test_features, test_predictions.predictions)

    df_sub['PredictionString'] = df_sub['id'].apply(lambda r: final_test_predictions[r])
    df_sub.to_csv('submission.csv', index=False)
    display(df_sub.head())

def jaccard(row):
    """Word-level Jaccard similarity between ``row['answer']`` and ``row['prediction']``.

    Both strings are lower-cased and split on whitespace, and the score is
    |A ∩ B| / |A ∪ B|. Two empty strings (empty union) count as a perfect match
    and return 1.0 instead of raising ZeroDivisionError.

    Args:
        row: mapping (e.g. a DataFrame row) with string ``answer`` and ``prediction``.

    Returns:
        float in [0, 1].
    """
    answer_words = set(row['answer'].lower().split())
    predicted_words = set(row['prediction'].lower().split())
    union = answer_words | predicted_words
    if not union:
        return 1.0
    return len(answer_words & predicted_words) / len(union)

def convert_answers(row):
    """Wrap a row's single answer in the SQuAD-style ``answers`` structure.

    Args:
        row: mapping with ``answer_start`` and ``answer_text``.

    Returns:
        ``{'answer_start': [start], 'text': [text]}``, i.e. one-element lists.
    """
    start = row['answer_start']
    text = row['answer_text']
    return dict(answer_start=[start], text=[text])


# In[6]:


# Hold out the last VALID_SIZE rows of the shuffled data for validation.
VALID_SIZE = 64
SHUFFLE_SEED = 114

# Attach `answers` in the {'answer_start': [...], 'text': [...]} form the feature
# builders expect, then shuffle. Using a new name keeps this cell idempotent:
# re-running it does not re-shuffle an already shuffled frame.
df_shuffled = (
    df_train_base
    .assign(answers=df_train_base[['answer_start', 'answer_text']].apply(convert_answers, axis=1))
    .sample(frac=1, random_state=SHUFFLE_SEED)
)

# Train and validation rows are disjoint, so the validation Jaccard is not
# inflated by rows the model was trained on.
df_train = df_shuffled.iloc[:-VALID_SIZE].reset_index(drop=True)
df_valid = df_shuffled.iloc[-VALID_SIZE:].reset_index(drop=True)
assert len(df_train) + len(df_valid) == len(df_shuffled)

train_dataset = Dataset.from_pandas(df_train)
valid_dataset = Dataset.from_pandas(df_valid)
tokenized_train_ds = train_dataset.map(prepare_train_features, batched=True, remove_columns=train_dataset.column_names)
tokenized_valid_ds = valid_dataset.map(prepare_train_features, batched=True, remove_columns=valid_dataset.column_names)


# In[7]:


#import torch
#from numba import cuda
#torch.cuda.empty_cache()
#device = cuda.get_current_device()
#device.reset()
#!nvidia-smi


# In[8]:


import torch
def loss_fn(start_logits, end_logits,
            start_positions, end_positions):
    """Span-extraction loss: KL divergence from one-hot gold targets.

    The start and end logits are each turned into log-probabilities over the
    sequence and compared with one-hot distributions at the gold positions.
    With one-hot targets this equals per-batch-mean cross-entropy, so the scale
    matches the usual QA loss.

    Args:
        start_logits, end_logits: float tensors of shape (batch, seq_len).
        start_positions, end_positions: integer tensors of shape (batch,)
            holding gold token indices in [0, seq_len).

    Returns:
        Scalar tensor: start loss + end loss, each averaged over the batch.
    """
    log_softmax = torch.nn.LogSoftmax(dim=1)
    # 'batchmean' sums over positions and averages over the batch, which is the
    # actual KL divergence. The default 'mean' also divides by seq_len and is
    # documented by PyTorch as deprecated for KLDivLoss.
    loss_fct = torch.nn.KLDivLoss(reduction="batchmean")

    # Take the number of positions from the logits instead of the global max_length.
    num_positions = start_logits.size(1)
    start_target = torch.nn.functional.one_hot(start_positions, num_positions).to(start_logits.dtype)
    end_target = torch.nn.functional.one_hot(end_positions, num_positions).to(end_logits.dtype)

    start_loss = loss_fct(log_softmax(start_logits), start_target)
    end_loss = loss_fct(log_softmax(end_logits), end_target)
    return start_loss + end_loss

class CustomTrainer(Trainer):
    """Trainer that replaces the model's built-in QA loss with ``loss_fn``."""

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Run the QA model and compute ``loss_fn`` on its start/end logits.

        Args:
            model: the (possibly wrapped) QA model.
            inputs: batch with ``input_ids``, ``attention_mask``,
                ``start_positions`` and ``end_positions``. The Trainer moves
                these tensors to ``args.device`` before calling this method, so
                no explicit ``.to(...)`` is done here. (``model.device`` does not
                exist on a DataParallel-wrapped model.)
            return_outputs: also return the model outputs (used during evaluation).
            **kwargs: extra keyword arguments that newer transformers versions
                may pass (e.g. ``num_items_in_batch``); ignored here.

        Returns:
            ``loss`` or ``(loss, outputs)`` when ``return_outputs`` is true.
        """
        outputs = model(inputs['input_ids'], attention_mask=inputs['attention_mask'])

        loss = loss_fn(outputs.start_logits, outputs.end_logits,
                       inputs['start_positions'], inputs['end_positions'])

        return (loss, outputs) if return_outputs else loss


def get_trainer():
    """Build a ``CustomTrainer`` around a freshly loaded ``MODEL_NAME``.

    Uses the module-level ``tokenized_train_ds``, ``tokenized_valid_ds``,
    ``tokenizer`` and ``batch_size``. Each call reloads the pretrained weights,
    so calling it again restarts training from scratch.

    Returns:
        CustomTrainer configured to evaluate and checkpoint every epoch into "chaii-qa".
    """
    model = AutoModelForQuestionAnswering.from_pretrained(MODEL_NAME)
    # No explicit model.to(...): the Trainer places the model on args.device itself.
    args = TrainingArguments(
        "chaii-qa",
        evaluation_strategy="epoch",
        save_strategy="epoch",
        learning_rate=2e-5,
        # Accumulate gradients so each optimizer step sees ~32 examples; never less than 1 step.
        gradient_accumulation_steps=max(1, 32 // batch_size),
        warmup_ratio=0.1,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        num_train_epochs=2,
        weight_decay=0.01,
    )
    print(f"Training device: {args.device}")
    return CustomTrainer(
        model,
        args,
        train_dataset=tokenized_train_ds,
        eval_dataset=tokenized_valid_ds,
        data_collator=default_data_collator,
        tokenizer=tokenizer,
    )


# In[ ]:


# Build the model + trainer (reloads pretrained weights on every call).
trainer = get_trainer()


# In[ ]:


# Fine-tune, then save to a directory named after MODEL_NAME ('/' replaced so it
# is a single path component) plus an experiment tag.
trainer.train()
trainer.save_model(MODEL_NAME.replace('/', '_') + '_exp33')


# In[ ]:


# Inference-style features for the validation rows (keeps example_id and context offsets).
validation_features = valid_dataset.map(prepare_validation_features, batched=True,remove_columns=valid_dataset.column_names)


# In[ ]:


len(validation_features)


# In[ ]:


validation_features


# In[ ]:


# The model must only see tensor inputs. Drop the bookkeeping columns with
# remove_columns() rather than an identity map(): Dataset.map documents that
# columns returned by the function are kept even when listed in remove_columns.
valid_feats_small = validation_features.remove_columns(['example_id', 'offset_mapping'])
valid_feats_small


# In[ ]:


raw_predictions = trainer.predict(valid_feats_small)


# In[ ]:


# postprocess_qa_predictions builds the example -> features index itself.
final_predictions = postprocess_qa_predictions(valid_dataset, validation_features, raw_predictions.predictions)
# Gold answer per validation example (convert_answers stores exactly one answer text).
references = [{"id": ex["id"], "answer": ex["answers"]['text'][0]} for ex in valid_dataset]


# In[ ]:


# One row per validation example: gold answer, predicted answer and word-level Jaccard.
res = (
    pd.DataFrame(references)
    .assign(prediction=lambda frame: frame['id'].map(lambda example_id: final_predictions[example_id]))
)
res['jaccard'] = res[['answer', 'prediction']].apply(jaccard, axis=1)
res


# In[ ]:


mean_jaccard = res['jaccard'].mean()
print(mean_jaccard)


# ## Test predict and submit
# 
# The mean Jaccard score on the validation rows is printed above (about 0.47 in one run; the value varies between runs). The next step is to predict on the test set and write `submission.csv` by uncommenting the `submit(trainer, df_test)` call below.

# In[ ]:


#submit(trainer, df_test)

