In [None]:
# !pip install transformers

import pandas as pd
import json
from pathlib import Path
from sklearn.model_selection import train_test_split


import torch
from transformers import DistilBertTokenizerFast, DistilBertForQuestionAnswering
from transformers import Trainer, TrainingArguments
from torch.utils.data import Dataset

## Data

In [None]:
!mkdir squad
!wget https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v2.0.json -O squad/train-v2.0.json
!wget https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v2.0.json -O squad/dev-v2.0.json

In [None]:
# Each example is a (context, question, answer) triple: context and question are strings; answer is a dictionary holding the answer text and its starting character index

def read_squad(path):
    """Flatten a SQuAD-format JSON file into three parallel lists.

    Walks ``data -> paragraphs -> qas -> answers`` and emits one entry per
    answer, so a question with several answers appears several times and a
    question with an empty ``answers`` list contributes nothing.

    Args:
        path: Location of the JSON file (string or ``Path``).

    Returns:
        A ``(contexts, questions, answers)`` tuple of equal-length lists.
        ``answers`` holds the answer dictionaries exactly as found in the
        file (not copied).
    """
    with Path(path).open('rb') as handle:
        squad_dict = json.load(handle)

    contexts, questions, answers = [], [], []
    for article in squad_dict['data']:
        for paragraph in article['paragraphs']:
            passage_text = paragraph['context']
            for qa_entry in paragraph['qas']:
                question_text = qa_entry['question']
                # One output row per answer: context and question are repeated.
                for answer_entry in qa_entry['answers']:
                    contexts.append(passage_text)
                    questions.append(question_text)
                    answers.append(answer_entry)

    return contexts, questions, answers

# Load both splits from the files downloaded into squad/; each call returns
# parallel lists of contexts, questions and answer dictionaries.
train_contexts, train_questions, train_answers = read_squad('squad/train-v2.0.json')
val_contexts, val_questions, val_answers = read_squad('squad/dev-v2.0.json')

In [None]:
# Compute the ending character index for each answer - SQuAD only provides starting indices

def add_end_idx(answers, contexts):
    """Add an ``'answer_end'`` character index to every answer, in place.

    The end index is ``answer_start + len(text)`` (exclusive). Some gold
    start offsets are off by one or two characters, so the answer text is
    looked for at the recorded start and then one and two characters
    earlier; the first position whose slice equals the answer text wins and
    ``'answer_start'`` is corrected to match.

    If no probe matches, ``'answer_start'`` is left unchanged and
    ``'answer_end'`` is still set to ``answer_start + len(text)``, so every
    answer has the key that later processing reads.

    Args:
        answers: List of dicts with ``'text'`` and ``'answer_start'`` keys;
            mutated in place.
        contexts: List of context strings, parallel to ``answers``.

    Returns:
        None.
    """
    for answer, context in zip(answers, contexts):
        gold_text = answer['text']
        start_idx = answer['answer_start']
        end_idx = start_idx + len(gold_text)

        # Probe shifts 0, 1 and 2, but never step before the beginning of
        # the context (a negative slice start would wrap around to the end).
        for shift in range(min(start_idx, 2) + 1):
            if context[start_idx - shift:end_idx - shift] == gold_text:
                answer['answer_start'] = start_idx - shift
                answer['answer_end'] = end_idx - shift
                break
        else:
            # No aligned match found: keep the recorded start and still
            # provide an end index so downstream lookups do not KeyError.
            answer['answer_end'] = end_idx

# Annotate both splits in place; the answer dictionaries gain 'answer_end'.
add_end_idx(train_answers, train_contexts)
add_end_idx(val_answers, val_contexts)

In [None]:
# Tokenize the contexts and questions

# Fast tokenizer for the 'distilbert-base-uncased' checkpoint.
tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')

# Encode each (context, question) pair with the context as the first
# sequence; truncation and padding are both enabled.
train_encodings = tokenizer(train_contexts, train_questions, truncation=True, padding=True)
val_encodings = tokenizer(val_contexts, val_questions, truncation=True, padding=True)

In [None]:
# Convert our character start/end positions to token start/end positions

def add_token_positions(encodings, answers):
    """Attach token-level answer spans to ``encodings``, in place.

    For each example the character offsets ``answer_start`` and
    ``answer_end - 1`` (the last character of the answer) are mapped to
    token indices with ``encodings.char_to_token``. A lookup that yields
    ``None`` is replaced by ``tokenizer.model_max_length``.

    Args:
        encodings: Tokenizer output providing ``char_to_token`` and
            ``update``; receives ``'start_positions'`` and
            ``'end_positions'`` lists.
        answers: List of dicts with ``'answer_start'`` and ``'answer_end'``,
            parallel to the encoded examples.

    Returns:
        None.
    """
    start_positions, end_positions = [], []
    for row, answer in enumerate(answers):
        first_token = encodings.char_to_token(row, answer['answer_start'])
        last_token = encodings.char_to_token(row, answer['answer_end'] - 1)

        # None means the answer passage has been truncated away; substitute
        # the tokenizer's maximum length as the position.
        if first_token is None:
            first_token = tokenizer.model_max_length
        if last_token is None:
            last_token = tokenizer.model_max_length

        start_positions.append(first_token)
        end_positions.append(last_token)

    encodings.update({'start_positions': start_positions, 'end_positions': end_positions})

# Add 'start_positions' / 'end_positions' to both encodings in place.
add_token_positions(train_encodings, train_answers)
add_token_positions(val_encodings, val_answers)

In [None]:
# Dataset class

class SquadDataset(Dataset):
    """Torch ``Dataset`` view over a tokenizer encoding.

    Each item is a dict with one tensor per encoding field, built from that
    field's entry at the requested index. The length is the number of
    ``input_ids`` rows.
    """

    def __init__(self, encodings):
        # Keep a reference only; tensors are created lazily per item.
        self.encodings = encodings

    def __len__(self):
        return len(self.encodings.input_ids)

    def __getitem__(self, idx):
        sample = {}
        for field, column in self.encodings.items():
            sample[field] = torch.tensor(column[idx])
        return sample

# Wrap the encodings (now including the answer token positions) for the Trainer.
train_dataset = SquadDataset(train_encodings)
val_dataset = SquadDataset(val_encodings)

## Train

In [None]:
# Load the 'distilbert-base-uncased' checkpoint into the question-answering model class.
model = DistilBertForQuestionAnswering.from_pretrained("distilbert-base-uncased")

In [None]:
# Train the model

# Run index; only used to build a distinct output directory name.
idx = 0
# NOTE(review): relative 'gdrive/MyDrive' path - presumably a mounted Google
# Drive; nothing visible in this notebook mounts it, so confirm it exists.
model_path = f'gdrive/MyDrive/model_{idx}'

training_args = TrainingArguments(
    output_dir=model_path,          # output directory
    num_train_epochs=1,             # total number of training epochs
    evaluation_strategy="epoch"     # evaluation schedule: once per epoch
)

# Trainer object: ties the model and training arguments to the train/eval datasets

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset
)

# Start fine-tuning.
trainer.train()

## Evaluate

In [None]:
# Evaluate with the trainer's eval_dataset (val_dataset); as the last
# expression of the cell, the returned value is displayed.
trainer.evaluate()

## Save

In [None]:
!mkdir qa_model
trainer.save_model("qa_model")

## Predict

In [None]:
# Reload the model from the qa_model directory written by the save step.
test_model = DistilBertForQuestionAnswering.from_pretrained("qa_model")

In [None]:
# The passage containing the answer is the context; the query is the
# question. (These two values were previously assigned to the opposite
# variables, so the model received them in the wrong slots.)
test_context = "Attack happened on 11 September 2001"
test_question = "When did attack happen?"
# Same (context, question) argument order as the training encodings.
test_encodings = tokenizer(test_context, test_question, truncation=True, padding=True, return_tensors="pt")

In [None]:
# Inference only: disable gradient tracking so no autograd graph is built
# for the forward pass.
with torch.no_grad():
    pred = test_model(**test_encodings)

In [None]:
# Highest-scoring start token, and one past the highest-scoring end token
# so the end can be used as an exclusive slice bound.
answer_start = torch.argmax(pred["start_logits"])
answer_end = torch.argmax(pred["end_logits"]) + 1

# Slice the predicted span out of the single encoded example, then turn the
# token ids back into text.
span_ids = test_encodings["input_ids"][0][answer_start:answer_end]
span_tokens = tokenizer.convert_ids_to_tokens(span_ids)
answer = tokenizer.convert_tokens_to_string(span_tokens)

In [None]:
# Bare expression so the notebook displays the decoded answer.
answer