In [43]:
!pip install transformers



In [162]:
LEARNING_RATE=5e-5
EPOCHS=20
BATCH_SIZE=3

In [163]:
import json

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.optim import Adam

from transformers import BertTokenizer, BertForQuestionAnswering
from transformers.tokenization_utils_base import PaddingStrategy

from pprint import pprint
import textwrap

# Wrap text to 80 characters.
wrapper = textwrap.TextWrapper(width=80) 

In [164]:
CPU_DEVICE = 'cpu'
CUDA_DEVICE = 'cuda'
DEVICE = CUDA_DEVICE if torch.cuda.is_available() else CPU_DEVICE

In [165]:
class Question:
    def __init__(self, text, answer, context, isImpossible = False) -> None:
        self.text = text
        self.context = context
        if isImpossible:
            self.answer = (-1, -1)
        else:
            endCharIndex = answer['answer_start'] + len(answer['text']) - 1
            whitespacesBeforeAnswer = 0
            whitespacesInAnswer = 0
            for i in context.whitespaces:
                if i >= answer['answer_start']:
                    if i < endCharIndex:
                        whitespacesInAnswer += 1
                    else:
                        break
                else:
                    whitespacesBeforeAnswer += 1
            noWhitespaceStart = answer['answer_start'] - whitespacesBeforeAnswer
            noWhitespaceEnd = noWhitespaceStart + len(answer['text']) - 1 - whitespacesInAnswer
            self.answer = context.getAnswerTokenIndexes(noWhitespaceStart, noWhitespaceEnd)

    def __repr__(self) -> str:
        return str({
            "text": self.text,
            "answer_start": self.answer[0],
            "answer_end": self.answer[1],
            "answer": ' '.join(self.context.tokens[self.answer[0]:self.answer[1]+1])
        })

class QuestionContext:
    def __init__(self, text, tokenizer) -> None:
        self.text = text
        self.tokenIds = tokenizer(text)['input_ids']
        self.tokens = tokenizer.convert_ids_to_tokens(self.tokenIds)
        whitespaces = []
        for i, c in enumerate(text):
            if c == ' ':
                whitespaces.append(i)
        
        self.whitespaces = tuple(whitespaces)

    def getAnswerTokenIndexes(self, startCharIndex, endCharIndex):
        answerStart = -1
        answerEnd = -1
        currChar = 0
        for index, token in enumerate(self.tokens):
            if (index != 0) and (index != len(self.tokens) - 1):
                cleanToken = token.replace('##', '')
                for c in cleanToken:
                    if currChar == startCharIndex:
                        answerStart = index
                    if currChar == endCharIndex:
                        answerEnd = index
                        return (answerStart, answerEnd)
                    currChar += 1

In [166]:
questions = []
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

with open('sample.json') as samplesFile:
    samplesRaw = json.load(samplesFile)['data']
    for group in samplesRaw:
        for paragraph in group['paragraphs']:
            context = QuestionContext(paragraph['context'], tokenizer)
            for qa in paragraph['qas']:
                questions.append(Question(qa['question'], qa['answers'][0], context, qa['is_impossible']))
#pprint(questions)

In [167]:
class QuestionsDataset(torch.utils.data.Dataset):
    def __init__(self, questions) -> None:
        super().__init__()
        self.questions = [q.text for q in questions]
        self.contexts = [q.context.text for q in questions]
        self.answers = [torch.tensor(q.answer) for q in questions]

    def __len__(self):
        return len(self.questions)

    def __getitem__(self, index):
        return self.questions[index], self.contexts[index], self.answers[index]

In [168]:
#ret = tokenizer._batch_encode_plus([['is', 'hi oops'], ['of', 'hello'], ['i am good, thanks', 'haha']], max_length=10, padding_strategy=PaddingStrategy.MAX_LENGTH)
#for id in ret["input_ids"][2]:
#    print(tokenizer.convert_ids_to_tokens(id))

In [169]:
trainDataset = QuestionsDataset(questions)
trainSetLoader = DataLoader(trainDataset, batch_size=BATCH_SIZE, shuffle=False)
model = BertForQuestionAnswering.from_pretrained('bert-base-uncased').to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForQuestionAnswering: ['cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertForQuestionAnswering from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForQuestionAnswering from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForQuestionAnswering were not initialized from the model checkpoint at bert-base-uncased a

In [170]:
def predictionsF1Score(modelAnswers, trueAnswers):

    def findAnswerF1(modelAnswer, trueAnswer):
        modelSequence = range(modelAnswer[0], modelAnswer[1] + 1 )
        trueSequence = range(trueAnswer[0], trueAnswer[1] + 1 )
        numCommon = len(set(trueSequence).intersection(modelSequence))

        if numCommon == 0:
            return 0
        
        precision = 1.0 * numCommon / len(trueSequence)
        recall = 1.0 * numCommon / len(modelSequence)
        f1 = (2 * precision * recall) / (precision + recall)
        return f1

    totalF1 = 0
    for model, gold in zip(modelAnswers, trueAnswers):
        totalF1 += findAnswerF1(model, gold)

    return totalF1/len(trueAnswers)

def predictionsExactScore(modelAnswers, trueAnswers):
    correct = 0
    for model, true in zip(modelAnswers, trueAnswers):
        correct += int( (model[0] == true[0]) and (model[1] == true[1]) )
    
    return correct/len(trueAnswers)

def getPredictedAnswers(startLogits, endLogits):
    softmaxStart = torch.log_softmax(startLogits, dim = 1)
    _, start = torch.max(softmaxStart, dim = 1)

    softmaxEnd = torch.log_softmax(endLogits, dim = 1)
    _, end = torch.max(softmaxEnd, dim = 1)
    return (start.cpu().detach().numpy(), end.cpu().detach().numpy())

In [171]:
for epoch in range(EPOCHS):
    model.train()
    epochExactBatchScores = []
    epochBatchLosses = []
    epochBatchF1 = []
    for batchQuestions, batchContexts, batchAnswers in trainSetLoader:
        qaPairs = [[question, answer] for question, answer in zip(batchQuestions, batchContexts)]
        tok = tokenizer._batch_encode_plus(qaPairs, padding_strategy=PaddingStrategy.LONGEST, return_tensors="pt")
        inputIds = tok['input_ids'].to(DEVICE)
        segmentIds = tok['token_type_ids'].to(DEVICE)
        attentionMask = tok['attention_mask'].to(DEVICE)
        startPositions = batchAnswers[:, 0].to(DEVICE)
        endPositions = batchAnswers[:, 1].to(DEVICE)
        
        outputs = model(input_ids=inputIds, token_type_ids=segmentIds, attention_mask=attentionMask, start_positions=startPositions, end_positions=endPositions)    
        batchLoss = outputs[0]

        optimizer.zero_grad()

        batchLoss.backward()

        optimizer.step()

        startPredictions, endPredictions = getPredictedAnswers(outputs.start_logits, outputs.end_logits)
        modelAnswers = np.vstack((startPredictions, endPredictions)).T
        
        epochExactBatchScores.append(predictionsExactScore(modelAnswers, batchAnswers))
        epochBatchLosses.append(batchLoss.item())
        epochBatchF1.append(predictionsF1Score(modelAnswers, batchAnswers))
    
    print(f"Exact: {sum(epochExactBatchScores)/len(epochExactBatchScores)}")
    print(f"F1: {sum(epochBatchF1)/len(epochBatchF1)}")
    print(f"Loss: {sum(epochBatchLosses)/len(epochBatchLosses)}")

Exact: 0.0
F1: 0.03068988553926689
Loss: 4.978862285614014
Exact: 0.0
F1: 0.05583333499787404
Loss: 3.8857779502868652
Exact: 0.0
F1: 0.053619466994532364
Loss: 3.3339975357055662
Exact: 0.0
F1: 0.05793304125304495
Loss: 2.9933157444000242
Exact: 0.06666666666666667
F1: 0.14472502120291963
Loss: 2.6169959545135497
Exact: 0.0
F1: 0.10071221140063616
Loss: 2.6307623386383057
Exact: 0.06666666666666667
F1: 0.23193998092303175
Loss: 2.3168453216552733
Exact: 0.13333333333333333
F1: 0.3329004329004329
Loss: 2.1080045223236086
Exact: 0.26666666666666666
F1: 0.4761604010025063
Loss: 1.7295691251754761
Exact: 0.26666666666666666
F1: 0.4208080227396123
Loss: 1.5060613632202149
Exact: 0.6
F1: 0.7871264367816091
Loss: 1.2938761115074158
Exact: 0.2
F1: 0.47111111111111104
Loss: 1.6968835592269897
Exact: 0.5333333333333334
F1: 0.6677931387608808
Loss: 0.9333142399787903
Exact: 0.39999999999999997
F1: 0.5750839002267574
Loss: 1.1183129668235778
Exact: 0.5333333333333333
F1: 0.7044444444444444
Loss: 