In [10]:
!pip install transformers



In [11]:
import json

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from torch.optim import Adam

from transformers import BertTokenizer, BertForQuestionAnswering

from pprint import pprint
import textwrap

# Shared helper for displaying long passages: folds text at 80 columns.
wrapper = textwrap.TextWrapper(80)

In [12]:
class Question:
    """A question paired with the token span of its answer inside a context.

    Attributes:
        text: The question string.
        context: The context object the answer is located in. It must expose
            ``whitespaces`` (ascending character positions of whitespace in
            the context text) and ``getAnswerTokenIds(start, end)``.
        answer: ``(start, end)`` pair returned by
            ``context.getAnswerTokenIds``, or ``(-1, -1)`` when the question
            has no answer or the span could not be located.
    """

    # Sentinel span stored when there is no (locatable) answer.
    _NO_ANSWER = (-1, -1)

    def __init__(self, text, answer, context, isImpossible = False) -> None:
        """Locate ``answer`` inside ``context`` and store its token span.

        Args:
            text: The question string.
            answer: Mapping with ``'answer_start'`` (character offset of the
                answer in the context text) and ``'text'`` (the answer
                string). May be ``None`` for questions without an answer.
            context: Context object (see class docstring).
            isImpossible: When true, ``answer`` is ignored and the sentinel
                span ``(-1, -1)`` is stored.
        """
        self.text = text
        self.context = context
        if isImpossible or answer is None:
            self.answer = self._NO_ANSWER
            return

        startCharIndex = answer['answer_start']
        # Inclusive index of the last character of the answer.
        endCharIndex = startCharIndex + len(answer['text']) - 1

        # The context's token lookup works on character offsets into the text
        # with all whitespace removed, so count the whitespace that precedes
        # the answer and the whitespace that lies inside it.
        whitespacesBeforeAnswer = 0
        whitespacesInAnswer = 0
        for position in context.whitespaces:
            if position < startCharIndex:
                whitespacesBeforeAnswer += 1
            elif position <= endCharIndex:
                # Inclusive upper bound: a whitespace sitting on the answer's
                # last character (trailing whitespace in the answer text) must
                # be discounted too, otherwise the end offset lands one
                # character past the real end of the answer.
                whitespacesInAnswer += 1
            else:
                # Positions are ascending, so nothing later can matter.
                break

        noWhitespaceStart = startCharIndex - whitespacesBeforeAnswer
        noWhitespaceEnd = noWhitespaceStart + len(answer['text']) - 1 - whitespacesInAnswer
        span = context.getAnswerTokenIds(noWhitespaceStart, noWhitespaceEnd)
        # Guard against a lookup that found nothing, so __repr__ (which
        # indexes self.answer) never fails on a None span.
        self.answer = span if span is not None else self._NO_ANSWER

    def __repr__(self) -> str:
        """Return a dict-like string with the question text and answer span."""
        return str({
            "text": self.text,
            "answer_start": self.answer[0],
            "answer_end": self.answer[1]
        })

class QuestionContext:
    """A context passage together with its tokenization.

    Attributes:
        text: The raw context string.
        tokenIds: ``tokenizer(text)['input_ids']``.
        tokens: Token strings for ``tokenIds``, from
            ``tokenizer.convert_ids_to_tokens``.
        whitespaces: Tuple of ascending character positions in ``text`` that
            hold a whitespace character.
    """

    def __init__(self, text, tokenizer) -> None:
        """Tokenize ``text`` and record where its whitespace characters are.

        Args:
            text: The context passage.
            tokenizer: Callable returning a mapping with an ``'input_ids'``
                entry, and offering ``convert_ids_to_tokens``.
        """
        self.text = text
        self.tokenIds = tokenizer(text)['input_ids']
        self.tokens = tokenizer.convert_ids_to_tokens(self.tokenIds)
        # Record every whitespace character, not only ' ': newlines, tabs and
        # non-breaking spaces contribute no characters to the tokens either,
        # so missing them would shift every later character offset.
        self.whitespaces = tuple(
            position for position, char in enumerate(text) if char.isspace()
        )

    def getAnswerTokenIds(self, startCharIndex, endCharIndex):
        """Map whitespace-free character offsets to token indices.

        Walks the characters of every token except the first and the last one
        (assumed to be the tokenizer's special boundary tokens, e.g.
        [CLS]/[SEP] - TODO confirm for other tokenizers), ignoring a leading
        '##' continuation marker.

        NOTE(review): the walk assumes each token contributes exactly as many
        characters as the text it came from; a token such as an
        unknown-token placeholder would presumably break that assumption.

        Args:
            startCharIndex: Offset of the answer's first character in the
                text with whitespace removed.
            endCharIndex: Offset of the answer's last character (inclusive)
                in the text with whitespace removed.

        Returns:
            ``(startTokenIndex, endTokenIndex)``, both inclusive. The start is
            ``-1`` if ``startCharIndex`` was not reached before the end
            offset. ``(-1, -1)`` if ``endCharIndex`` is never reached.
        """
        answerStart = -1
        currChar = 0
        # Skip the first and last tokens (see docstring).
        for index in range(1, len(self.tokens) - 1):
            token = self.tokens[index]
            # Strip only a leading continuation marker; '#' characters that
            # are part of the text itself must still be counted.
            piece = token[2:] if token.startswith('##') else token
            for _ in piece:
                if currChar == startCharIndex:
                    answerStart = index
                if currChar == endCharIndex:
                    return (answerStart, index)
                currChar += 1
        # End offset never reached: report "no span" explicitly instead of
        # implicitly returning None.
        return (-1, -1)


In [13]:
# Single source of truth for the pretrained checkpoint, so the tokenizer and
# the model can never be loaded from different checkpoints by accident.
MODEL_NAME = 'bert-base-uncased'

tokenizer = BertTokenizer.from_pretrained(MODEL_NAME)
# NOTE(review): the load warning printed below reports that some
# BertForQuestionAnswering weights are not initialized from this checkpoint,
# so the model presumably needs fine-tuning before its answers mean anything.
model = BertForQuestionAnswering.from_pretrained(MODEL_NAME)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForQuestionAnswering: ['cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertForQuestionAnswering from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForQuestionAnswering from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForQuestionAnswering were not initialized from the model checkpoint at bert-base-uncased a

In [14]:
# Build one Question per QA pair found in the SQuAD-format sample file.
questions = []
# Explicit UTF-8: the passages contain non-ASCII characters, and the platform
# default encoding is not guaranteed to be UTF-8.
with open('sample.json', encoding='utf-8') as samplesFile:
    samplesRaw = json.load(samplesFile)['data']

for group in samplesRaw:
    for paragraph in group['paragraphs']:
        # Tokenize each passage once and share it across its questions.
        context = QuestionContext(paragraph['context'], tokenizer)
        for qa in paragraph['qas']:
            # Unanswerable questions carry an empty 'answers' list, so
            # indexing [0] unconditionally would raise IndexError before
            # Question is even constructed. Files without an 'is_impossible'
            # flag are treated as answerable whenever an answer is present.
            answers = qa.get('answers') or []
            isImpossible = qa.get('is_impossible', False) or not answers
            questions.append(Question(qa['question'], answers[0] if answers else None, context, isImpossible))
pprint(questions)

[{'text': 'When did Beyonce start becoming popular?', 'answer_start': 67, 'answer_end': 70},
 {'text': 'What areas did Beyonce compete in when she was growing up?', 'answer_start': 55, 'answer_end': 57},
 {'text': "When did Beyonce leave Destiny's Child and become a solo singer?", 'answer_start': 128, 'answer_end': 128},
 {'text': 'In what city and state did Beyonce  grow up? ', 'answer_start': 47, 'answer_end': 49},
 {'text': 'In which decade did Beyonce become famous?', 'answer_start': 69, 'answer_end': 70},
 {'text': 'In what R&B group was she the lead singer?', 'answer_start': 81, 'answer_end': 84},
 {'text': 'What album made her a worldwide known artist?', 'answer_start': 124, 'answer_end': 126},
 {'text': "Who managed the Destiny's Child group?", 'answer_start': 91, 'answer_end': 92},
 {'text': 'When did Beyoncé rise to fame?', 'answer_start': 69, 'answer_end': 70},
 {'text': "What role did Beyoncé have in Destiny's Child?", 'answer_start': 72, 'answer_end': 73},
 {'text': 'What 

In [15]:
# BERT only needs the token IDs, but for the purpose of inspecting the
# tokenizer's behavior, let's also get the token strings and display them.
# `context` is the passage left over from the loading loop above.
input_ids = tokenizer.encode(context.text)
tokens = tokenizer.convert_ids_to_tokens(input_ids)

# For each token and its id...
# (named `token_id`, not `id`, so the builtin id() is not rebound at notebook
# scope for every cell that runs afterwards)
for index, (token, token_id) in enumerate(zip(tokens, input_ids)):
    is_separator = token_id == tokenizer.sep_token_id

    # If this is the [SEP] token, add some space around it to make it stand out.
    if is_separator:
        print('')

    # Print the token string and its ID in two columns.
    print('{}) {:<12} {:>6,}'.format(index, token, token_id))

    if is_separator:
        print('')
    

0) [CLS]           101
1) beyonce      20,773
2) gi           21,025
3) ##selle      19,358
4) knowles      22,815
5) -             1,011
6) carter        5,708
7) (             1,006
8) /             1,013
9) bi           12,170
10) ##ː          23,432
11) ##ˈ          29,715
12) ##j           3,501
13) ##ɒ          29,678
14) ##nse        12,325
15) ##ɪ          29,685
16) /             1,013
17) bee          10,506
18) -             1,011
19) yo           10,930
20) ##n           2,078
21) -             1,011
22) say           2,360
23) )             1,007
24) (             1,006
25) born          2,141
26) september     2,244
27) 4             1,018
28) ,             1,010
29) 1981          3,261
30) )             1,007
31) is            2,003
32) an            2,019
33) american      2,137
34) singer        3,220
35) ,             1,010
36) songwriter    6,009
37) ,             1,010
38) record        2,501
39) producer      3,135
40) and           1,998
41) actress       3,883
42

In [16]:
#print(wrapper.fill(context.text))