# Script for model evaluation
## COVID-QA Analysis
### Yash Khandelwal, Kaushik Ravindran

GitHub: https://github.com/yashskhandelwal/Covid_QA_Analysis





In [None]:
%%capture
# env setup
# install relevant libraries
# (%%capture must stay the first line of the cell; it hides the installers' output)
!pip install datasets transformers
!pip install accelerate
!pip install humanize
!pip install millify
!pip install tqdm
!apt-get install git-lfs
!pip install codecarbon
!git lfs install

In [None]:
%%capture
# for running on tpu
# pins torch 1.9.0 together with the torch_xla 1.9 wheel (the wheel name targets cp37, i.e. CPython 3.7)
!pip install cloud-tpu-client==0.10 torch==1.9.0 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.9-cp37-cp37m-linux_x86_64.whl

In [None]:
# imports
import math, statistics, time
from collections import defaultdict
import numpy as np
from tqdm.auto import tqdm
from datetime import datetime
import torch_xla
import torch_xla.core.xla_model as xm

import torch
from codecarbon import EmissionsTracker
from datasets import load_dataset, load_metric
from transformers import AutoTokenizer, AutoModelForQuestionAnswering, Trainer, TrainingArguments

import warnings
warnings.filterwarnings("ignore")

In [None]:
# login to hugging face
# Interactive step; the model download later in this notebook passes
# use_auth_token=True, so it presumably relies on the credentials stored here.
from huggingface_hub import notebook_login
notebook_login()

In [None]:
# constants
dataset = "covid_qa_deepset"  # name passed to load_dataset()
pre_trained_model_checkpoint = "twmkn9/bert-base-uncased-squad2"  # checkpoint the tokenizer is loaded from
model_name = "covid_qa_analysis_bert_base_uncased_squad2"  # first positional argument given to TrainingArguments
hub_model_id = "armageddon/covid_qa_analysis_bert_base_uncased_squad2"  # Hub repo the fine-tuned model is loaded from
stride = 150  # forwarded as the tokenizer's `stride` argument when contexts overflow
max_answer_length=150  # longest answer span, counted in tokens, that compute_metrics will accept

### Get the dataset

In [None]:
# load the dataset identified by the `dataset` constant
raw_datasets = load_dataset(dataset)

In [None]:
#Split dataset into train and test.
# 90% train / 10% test with a fixed seed.
# NOTE(review): for the "test" metrics to be meaningful this presumably has to be
# the same split (train_size and seed) that was used when fine-tuning — confirm.
raw_datasets = raw_datasets["train"].train_test_split(train_size=0.9, seed=42)

### Tokenization code section

In [None]:
# The tokenizer comes from the base checkpoint, while the model evaluated below
# is loaded from `hub_model_id`.
tokenizer = AutoTokenizer.from_pretrained(pre_trained_model_checkpoint)

In [None]:
# pre-processing for validation examples
def preprocess_validation_examples(examples):
    """Tokenize a batch of question/context pairs into evaluation features.

    Long contexts are split into several overlapping features (the overlap is
    controlled by the module-level ``stride``), so one input example can yield
    more than one output row.

    Args:
        examples: Batch mapping with at least the columns ``"question"``,
            ``"context"`` and ``"id"``, each holding one entry per example.

    Returns:
        The tokenizer output with two modifications:

        * ``"offset_mapping"`` keeps offsets only for context tokens; entries
          for question and special tokens are replaced by ``None``.
        * ``"example_id"`` is added, giving for every feature the ``"id"`` of
          the example it was cut from.

        ``"overflow_to_sample_mapping"`` is removed from the output.
    """
    questions = [q.strip() for q in examples["question"]]
    contexts = examples["context"]
    # Looked up once here instead of once per feature inside the loop.
    source_ids = examples["id"]

    # use model tokenizer to tokenize examples
    # return_overflowing_tokens -- for each feature, records the original example it belonged to
    # return_offsets_mapping -- for each token, returns the start and end character position in the original text
    inputs = tokenizer(
        questions,
        contexts,
        truncation="only_second",
        stride=stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )

    # One entry per produced feature: the index of the example it came from.
    sample_map = inputs.pop("overflow_to_sample_mapping")

    example_ids = []
    for feature_idx, sample_idx in enumerate(sample_map):
        example_ids.append(source_ids[sample_idx])

        # Per-token labels: question (0), context (1), or special token (None).
        sequence_ids = inputs.sequence_ids(feature_idx)

        # Keep offsets for context tokens only, so that downstream code can
        # recognise (and skip) positions that are not part of the context.
        inputs["offset_mapping"][feature_idx] = [
            span if seq_id == 1 else None
            for span, seq_id in zip(inputs["offset_mapping"][feature_idx], sequence_ids)
        ]

    # add a new column to inputs and return
    inputs["example_id"] = example_ids
    return inputs

In [None]:
# Turn the training split into tokenized evaluation features.
# All original columns are dropped, so only what preprocess_validation_examples
# returns is kept; load_from_cache_file=False asks map() not to reuse a cached result.
train_valid_dataset = raw_datasets["train"].map(
    preprocess_validation_examples,
    batched=True,
    remove_columns=raw_datasets["train"].column_names,
    load_from_cache_file=False
)

In [None]:
# Same preprocessing as for the training split, applied to the held-out test split.
test_valid_dataset = raw_datasets["test"].map(
    preprocess_validation_examples,
    batched=True,
    remove_columns=raw_datasets["test"].column_names,
    load_from_cache_file=False
)

### Evaluation Code Section

In [None]:
n_best = 20  # how many top-scoring start and end positions compute_metrics tries per feature
metric = load_metric("squad")  # metric object whose compute() produces the scores returned by compute_metrics

In [None]:
def compute_metrics(start_logits, end_logits, features, examples, max_answer_length):
    """Turn start/end logits into text answers and score them with ``metric``.

    For every example, each of its features contributes candidate spans built
    from the ``n_best`` highest start logits paired with the ``n_best`` highest
    end logits. The candidate with the largest summed logit becomes the
    prediction; an example without any valid candidate gets an empty string.

    Args:
        start_logits: Per-feature start scores, indexable by feature row.
        end_logits: Per-feature end scores, indexable by feature row.
        features: Tokenized features; each row provides ``"example_id"`` and
            ``"offset_mapping"``.
        examples: Original examples; each provides ``"id"``, ``"context"``
            and ``"answers"``.
        max_answer_length: Maximum number of tokens a predicted span may cover.

    Returns:
        The result of ``metric.compute`` for the predictions against the
        reference answers.
    """
    # Group feature row numbers by the example they were cut from.
    rows_by_example = defaultdict(list)
    for row, feat in enumerate(features):
        rows_by_example[feat["example_id"]].append(row)

    # Offsets that mark a token as lying outside the context.
    not_in_context = [None, []]

    predicted_answers = []
    for example in tqdm(examples):
        example_id = example["id"]
        context = example["context"]

        # (score, text) pairs gathered over all features of this example.
        candidates = []

        for row in rows_by_example[example_id]:
            start_scores = start_logits[row]
            end_scores = end_logits[row]
            offsets = features[row]["offset_mapping"]

            # Token positions ordered from highest to lowest score, top n_best only.
            top_starts = np.argsort(start_scores)[-1 : -n_best - 1 : -1].tolist()
            top_ends = np.argsort(end_scores)[-1 : -n_best - 1 : -1].tolist()

            for first_tok in top_starts:
                for last_tok in top_ends:
                    # The span must lie fully inside the context.
                    if offsets[first_tok] in not_in_context:
                        continue
                    if offsets[last_tok] in not_in_context:
                        continue

                    # The span must be non-empty and no longer than allowed.
                    span_tokens = last_tok - first_tok + 1
                    if not 1 <= span_tokens <= max_answer_length:
                        continue

                    char_begin = offsets[first_tok][0]
                    char_end = offsets[last_tok][1]
                    candidates.append(
                        (
                            start_scores[first_tok] + end_scores[last_tok],
                            context[char_begin:char_end],
                        )
                    )

        # Highest-scoring candidate wins; fall back to an empty answer.
        if candidates:
            prediction_text = max(candidates, key=lambda cand: cand[0])[1]
        else:
            prediction_text = ""
        predicted_answers.append({"id": example_id, "prediction_text": prediction_text})

    theoretical_answers = [
        {"id": ex["id"], "answers": ex["answers"]} for ex in examples
    ]
    return metric.compute(predictions=predicted_answers, references=theoretical_answers)

### Validating the model

In [None]:
# get finetuned model
# use_auth_token=True: the download presumably uses the credentials stored by notebook_login()
finetuned_model = AutoModelForQuestionAnswering.from_pretrained(hub_model_id, use_auth_token=True)

# set training args
# NOTE(review): in this notebook the Trainer is only used for predict(); the
# training hyper-parameters (learning rate, epochs, weight decay, save strategy)
# and push_to_hub=True look like carry-overs from the training script —
# confirm they are actually wanted for an evaluation-only run.
args = TrainingArguments(
    model_name,
    evaluation_strategy="no",
    save_strategy="epoch",
    learning_rate=2e-5,
    num_train_epochs=3,
    weight_decay=0.01,
    fp16=False,
    hub_model_id=hub_model_id,
    push_to_hub=True,
)

# create trainer for prediction
# no train/eval datasets are attached; datasets are passed to predict() directly
trainer = Trainer(
      model=finetuned_model,
      args=args,
      train_dataset=None,
      eval_dataset=None,
      tokenizer=tokenizer,
  )

In [None]:
# validate on training dataset
# predictions are unpacked as a (start_logits, end_logits) pair, one row per feature
predictions_tv = trainer.predict(train_valid_dataset)
start_logits_tv, end_logits_tv = predictions_tv.predictions
# features (tokenized) and examples (raw split) must come from the same split
print("Metrics on training dataset:\n", compute_metrics(start_logits_tv, end_logits_tv, train_valid_dataset, raw_datasets["train"], max_answer_length))

In [None]:
# validate on test dataset
# predictions are unpacked as a (start_logits, end_logits) pair, one row per feature
predictions = trainer.predict(test_valid_dataset)
start_logits, end_logits = predictions.predictions
# features (tokenized) and examples (raw split) must come from the same split
print("Metrics on test dataset:\n", compute_metrics(start_logits, end_logits, test_valid_dataset, raw_datasets["test"], max_answer_length))

### Printing average response time

In [None]:
# code to find average response time

# set training args
# Same values as the TrainingArguments used for validation, plus
# log_level='critical', logging_strategy='no' and disable_tqdm=True —
# presumably to keep the many predict() calls of the timing loop quiet.
args = TrainingArguments(
    model_name,
    evaluation_strategy="no",
    save_strategy="epoch",
    learning_rate=2e-5,
    num_train_epochs=3,
    weight_decay=0.01,
    fp16=False,
    hub_model_id=hub_model_id,
    push_to_hub=True,
    log_level='critical',
    logging_strategy='no',
    disable_tqdm=True
)

# create trainer for prediction
# rebinds `trainer`; the already-loaded finetuned_model is reused
trainer = Trainer(
      model=finetuned_model,
      args=args,
      train_dataset=None,
      eval_dataset=None,
      tokenizer=tokenizer,
)


def current_milli_time():
    """Return the current ``time.time()`` reading as a whole number of milliseconds."""
    seconds_now = time.time()
    return round(seconds_now * 1000)

# Measure the mean latency of predict() when it is called with one feature at a time.
# Cap the sample at the dataset size: selecting an index past the end of the
# dataset fails, so a fixed count of 500 would crash on a smaller test set.
test_count = min(500, len(test_valid_dataset))

start_time = current_milli_time()
for i in range(test_count):
    ds_temp = test_valid_dataset.select([i])
    # The result is deliberately discarded so the `predictions` object produced
    # by the test-set evaluation cell is not overwritten.
    trainer.predict(ds_temp)
total_time = current_milli_time() - start_time

if test_count:
    print("Average response time is:", total_time/test_count, "ms")
else:
    # Nothing was timed; avoid dividing by zero.
    print("Average response time is: n/a (no test features)")