Question answering on the CoQA dataset with a BERT model fine-tuned on SQuAD. The results on test passages are displayed using a Gradio app.

References:
1. https://towardsdatascience.com/question-answering-with-a-fine-tuned-bert-bc4dafd45626
2. https://gradio.app/ml_examples

In [1]:
!pip install transformers gradio --quiet 


# Importing Libraries

In [2]:
#importing libraries 
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

import torch
from transformers import BertForQuestionAnswering
from transformers import BertTokenizer



# Data Loading

In [3]:
coqa = pd.read_json('http://downloads.cs.stanford.edu/nlp/data/coqa/coqa-train-v1.0.json')
coqa.head()

,version,data
0,1,"{'source': 'wikipedia', 'id': '3zotghdk5ibi9ce..."
1,1,"{'source': 'cnn', 'id': '3wj1oxy92agboo5nlq4r7..."
2,1,"{'source': 'gutenberg', 'id': '3bdcf01ogxu7zdn..."
3,1,"{'source': 'cnn', 'id': '3ewijtffvo7wwchw6rtya..."
4,1,"{'source': 'gutenberg', 'id': '3urfvvm165iantk..."


# Data Preprocessing

In [4]:
del coqa["version"]

In [5]:
cols = ["text", "question", "answer"]

# one row per (story, question, answer) triple
comp_list = []
for _, row in coqa.iterrows():
    story = row["data"]["story"]
    questions = row["data"]["questions"]
    answers = row["data"]["answers"]
    for i in range(len(questions)):
        comp_list.append([story, questions[i]["input_text"], answers[i]["input_text"]])
new_df = pd.DataFrame(comp_list, columns=cols)
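The loop above walks each story's `questions` and `answers` lists in parallel by index. A minimal, self-contained sketch of the same flattening, run on a hypothetical toy record shaped like one entry of the `data` column (only the `story`, `questions[i]["input_text"]` and `answers[i]["input_text"]` fields the loop reads; the values are invented):

```python
# Hypothetical toy record with the same fields the loop above reads.
toy_record = {
    "story": "The cat sat on the mat. It was happy.",
    "questions": [{"input_text": "Where did the cat sit?"}, {"input_text": "Was it happy?"}],
    "answers": [{"input_text": "on the mat"}, {"input_text": "yes"}],
}

def flatten_record(record):
    """Return one (text, question, answer) tuple per question in a CoQA-style record."""
    story = record["story"]
    return [
        (story, q["input_text"], a["input_text"])
        for q, a in zip(record["questions"], record["answers"])
    ]

rows = flatten_record(toy_record)
for _, question, answer in rows:
    print(question, "->", answer)
```

`zip` pairs each question with the answer at the same position. Unlike indexing with `range(len(...))`, it silently stops at the shorter list instead of raising an `IndexError` if the two lists ever differed in length.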

In [6]:
new_df.to_csv("CoQA_data.csv", index=False)


In [7]:
data = pd.read_csv("CoQA_data.csv")
data.head()

,text,question,answer
0,"The Vatican Apostolic Library (), more commonl...",When was the Vat formally opened?,It was formally established in 1475
1,"The Vatican Apostolic Library (), more commonl...",what is the library for?,research
2,"The Vatican Apostolic Library (), more commonl...",for what subjects?,"history, and law"
3,"The Vatican Apostolic Library (), more commonl...",and?,"philosophy, science and theology"
4,"The Vatican Apostolic Library (), more commonl...",what was started in 2014?,a project


In [23]:
print("Number of question-answer pairs:", len(data))
print(data.text[0])

Number of question-answer pairs: 108647
The Vatican Apostolic Library (), more commonly called the Vatican Library or simply the Vat, is the library of the Holy See, located in Vatican City. Formally established in 1475, although it is much older, it is one of the oldest libraries in the world and contains one of the most significant collections of historical texts. It has 75,000 codices from throughout history, as well as 1.1 million printed books, which include some 8,500 incunabula. 

The Vatican Library is a research library for history, law, philosophy, science and theology. The Vatican Library is open to anyone who can document their qualifications and research needs. Photocopies for private study of pages from books published between 1801 and 1990 can be requested in person or by mail. 

In March 2014, the Vatican Library began an initial four-year project of digitising its collection of manuscripts, to be made available online. 

The Vatican Secret Archives were separated from 

# Downloading the pretrained question-answering BERT model from Hugging Face

In [9]:
model = BertForQuestionAnswering.from_pretrained('bert-large-uncased-whole-word-masking-finetuned-squad')
tokenizer = BertTokenizer.from_pretrained('bert-large-uncased-whole-word-masking-finetuned-squad')

In [10]:
random_num = np.random.randint(0,len(data))

question = data["question"][random_num]
text = data["text"][random_num]

In [11]:
input_ids = tokenizer.encode(question, text)
print("The input has a total of {} tokens.".format(len(input_ids)))

The input has a total of 344 tokens.
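BERT's position embeddings cover at most 512 tokens, so a question plus a long CoQA story can exceed what the model accepts (this example has 344). A minimal pure-Python sketch of trimming only the passage to fit, assuming the ids 101 and 102 that `[CLS]` and `[SEP]` map to in this model's vocabulary. `build_pair` is a hypothetical helper, and the 700-token passage is made up; the question ids are the ones from the token listing in the next cell.

```python
CLS_ID, SEP_ID = 101, 102  # [CLS] and [SEP] ids in the uncased BERT vocabulary
MAX_LEN = 512  # longest sequence BERT's position embeddings support

def build_pair(question_ids, context_ids, max_len=MAX_LEN):
    """Build [CLS] question [SEP] context [SEP], truncating only the context to fit max_len."""
    budget = max_len - len(question_ids) - 3  # room left after the question and three special tokens
    if budget <= 0:
        raise ValueError("the question alone does not fit in max_len tokens")
    return [CLS_ID] + question_ids + [SEP_ID] + context_ids[:budget] + [SEP_ID]

question_ids = [2040, 2003, 2879, 3401, 1029]  # "who is boy ##ce ?"
context_ids = list(range(1000, 1700))          # hypothetical 700-token passage
ids = build_pair(question_ids, context_ids)
print(len(ids))  # 512
```

With the real tokenizer, passing `truncation="only_second"` and `max_length=512` to `tokenizer.encode` should have the same effect of trimming only the passage. Either way, an answer that lies past the cut is lost; overlapping windows over the passage are the usual remedy and are not shown here.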


In [12]:
tokens = tokenizer.convert_ids_to_tokens(input_ids)

# token_id rather than id, to avoid shadowing the built-in id()
for token, token_id in zip(tokens, input_ids):
    print('{:8}{:8,}'.format(token, token_id))

[CLS]        101
who        2,040
is         2,003
boy        2,879
##ce       3,401
?          1,029
[SEP]        102
new        2,047
york       2,259
(          1,006
cnn       13,229
)          1,007
-          1,011
-          1,011
a          1,037
new        2,047
york       2,259
man        2,158
arrested   4,727
in         1,999
connection   4,434
with       2,007
the        1,996
stabbing  21,690
of         1,997
two        2,048
children   2,336
in         1,999
brooklyn   6,613
may        2,089
be         2,022
linked     5,799
to         2,000
another    2,178
stabbing  21,690
in         1,999
a          1,037
manhattan   7,128
subway    10,798
,          1,010
a          1,037
law        2,375
enforcement   7,285
official   2,880
told       2,409
cnn       13,229
thursday   9,432
.          1,012
police     2,610
believe    2,903
daniel     3,817
st         2,358
.          1,012
hubert    15,346
,          1,010
27         2,676
,          1,010
was        2,001
out     

In [13]:

#first occurrence of the [SEP] token
sep_idx = input_ids.index(tokenizer.sep_token_id)
print(sep_idx)

#number of tokens in segment A - question
num_seg_a = sep_idx+1
print(num_seg_a)

#number of tokens in segment B - text
num_seg_b = len(input_ids) - num_seg_a
print(num_seg_b)

segment_ids = [0]*num_seg_a + [1]*num_seg_b
print(segment_ids)

assert len(segment_ids) == len(input_ids)

6
7
337
[0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1

In [14]:
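#input_ids: token ids representing the question-text pair
#segment_ids: token type ids that distinguish the two segments (question = 0, text = 1)
with torch.no_grad():  # inference only, so skip gradient tracking
    output = model(torch.tensor([input_ids]), token_type_ids=torch.tensor([segment_ids]))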
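`output.start_logits` and `output.end_logits` hold one unnormalized score per token. A softmax turns each into a probability distribution over positions, which can be read as a rough confidence; it does not change which token scores highest. A dependency-free sketch with made-up logits (on the actual tensors, `torch.softmax(output.start_logits, dim=-1)` should give the same result):

```python
import math

def softmax(logits):
    """Convert raw scores to probabilities; subtracting the max avoids overflow in exp."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical start logits for four tokens.
probs = softmax([2.0, 1.0, 0.1, -3.0])
print([round(p, 3) for p in probs])
```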

In [15]:
#tokens with highest start and end scores
answer_start = torch.argmax(output.start_logits)
answer_end = torch.argmax(output.end_logits)
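Taking the two argmaxes independently can yield an end position before the start, which the next cell has to reject. A common alternative, sketched here in plain Python, scores every span with start <= end and a bounded length and keeps the best one. `best_span`, `max_answer_len` and `min_start` are hypothetical names; `min_start` could be set to `num_seg_a` to keep the answer inside the passage.

```python
def best_span(start_scores, end_scores, max_answer_len=30, min_start=0):
    """Return the (start, end) pair with the highest start + end score, subject to
    min_start <= start <= end < start + max_answer_len; None if no span qualifies."""
    best, best_score = None, float("-inf")
    for s in range(min_start, len(start_scores)):
        for e in range(s, min(s + max_answer_len, len(end_scores))):
            score = start_scores[s] + end_scores[e]
            if score > best_score:
                best, best_score = (s, e), score
    return best

# Made-up scores where independent argmaxes give start=3, end=2 (an invalid span).
start = [0.1, 2.0, 0.5, 3.0]
end = [0.2, 1.5, 2.5, 0.0]
print(best_span(start, end))  # (1, 2)
```

Applied to the `start_scores` and `end_scores` arrays computed a few cells later, this could replace the two `torch.argmax` calls; its cost grows with sequence length times `max_answer_len`.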

In [16]:
if answer_end >= answer_start:
    answer = " ".join(tokens[answer_start:answer_end+1])
else:
    #define answer here too, so the print below does not raise a NameError
    answer = "I am unable to find the answer to this question. Can you please ask another question?"

print("Text:\n{}".format(text.capitalize()))
print("\nQuestion:\n{}".format(question.capitalize()))
print("\nAnswer:\n{}.".format(answer.capitalize()))

Text:
New york (cnn) -- a new york man arrested in connection with the stabbing of two children in brooklyn may be linked to another stabbing in a manhattan subway, a law enforcement official told cnn thursday. 

police believe daniel st. hubert, 27, was out on parole when he stabbed two young children inside an elevator -- killing one of them. 

st. hubert was arrested by detectives around 8 p.m. wednesday. 

he was arrested around the same time that chief of detectives robert boyce identified st. hubert by name for the first time as the suspect in the attack. 

detectives were obtaining evidence thursday that could link him to a fatal stabbing on the subway in the chelsea neighborhood of manhattan, the official said. 

investigators were executing search warrants to see if he is linked to additional stabbings since his release from prison on may 23, a law enforcement official said. 

law enforcement has been involved with st. hubert plenty in the past, including nine arrests, though 

In [17]:
start_scores = output.start_logits.detach().numpy().flatten()
end_scores = output.end_logits.detach().numpy().flatten()

token_labels = []
for i, token in enumerate(tokens):
    token_labels.append("{}-{}".format(token,i))
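`start_scores`, `end_scores` and `token_labels` pair naturally for inspection, for instance with the matplotlib/seaborn imports at the top. A text-only alternative, sketched with a hypothetical `top_k_labels` helper and made-up scores, lists the highest-scoring positions:

```python
def top_k_labels(scores, labels, k=5):
    """Return the k (label, score) pairs with the highest scores, best first."""
    ranked = sorted(zip(labels, scores), key=lambda pair: pair[1], reverse=True)
    return ranked[:k]

# Toy labels in the "token-index" format used by token_labels, with made-up scores.
labels = ["[CLS]-0", "who-1", "is-2", "boy-3", "##ce-4"]
scores = [0.5, -1.0, 2.0, 1.0, 1.5]
for label, score in top_k_labels(scores, labels, k=3):
    print(label, score)
```

Calling `top_k_labels(start_scores, token_labels)`, and the same on `end_scores`, would show whether the runner-up positions are plausible answer boundaries.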

In [18]:
print(len(token_labels))


344


In [19]:
answer = tokens[answer_start]

for i in range(answer_start+1, answer_end+1):
    if tokens[i][0:2] == "##":
        answer += tokens[i][2:]
    else:
        answer += " " + tokens[i]
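The loop above undoes WordPiece splitting: a token beginning with `##` continues the previous word, so it is appended without a space. The same rule as a small reusable function (hypothetical name `merge_wordpieces`), exercised on the question tokens from the earlier listing:

```python
def merge_wordpieces(tokens):
    """Join WordPiece tokens into text; a '##' prefix marks a continuation of the previous token."""
    words = []
    for token in tokens:
        if token.startswith("##") and words:
            words[-1] += token[2:]  # glue the continuation onto the previous word
        else:
            words.append(token)
    return " ".join(words)

print(merge_wordpieces(["who", "is", "boy", "##ce", "?"]))  # who is boyce ?
```

Hugging Face tokenizers also expose `tokenizer.convert_tokens_to_string(tokens)`, which for BERT's WordPiece vocabulary should produce the same joined string.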

# Function to get a prediction from the model

In [20]:
def question_answer(context, question):
    
    #tokenize question and text in ids as a pair
    input_ids = tokenizer.encode(question, context)
    
    #string version of tokenized ids
    tokens = tokenizer.convert_ids_to_tokens(input_ids)
    
    #segment IDs
    #first occurrence of [SEP] token
    sep_idx = input_ids.index(tokenizer.sep_token_id)

    #number of tokens in segment A - question
    num_seg_a = sep_idx+1

    #number of tokens in segment B - text
    num_seg_b = len(input_ids) - num_seg_a
    
    #list of 0s and 1s
    segment_ids = [0]*num_seg_a + [1]*num_seg_b
    
    assert len(segment_ids) == len(input_ids)
    
    #model output using input_ids and segment_ids (no gradients needed for inference)
    with torch.no_grad():
        output = model(torch.tensor([input_ids]), token_type_ids=torch.tensor([segment_ids]))
    
    #reconstructing the answer
    answer_start = torch.argmax(output.start_logits).item()
    answer_end = torch.argmax(output.end_logits).item()

    #start empty so an invalid span (end before start) does not leave answer undefined
    answer = ""
    if answer_end >= answer_start:
        answer = tokens[answer_start]
        for i in range(answer_start+1, answer_end+1):
            if tokens[i][0:2] == "##":
                answer += tokens[i][2:]
            else:
                answer += " " + tokens[i]

    #an empty answer or a span starting at [CLS] means no answer was found
    if not answer or answer.startswith("[CLS]"):
        answer = "Unable to find the answer to your question."

    return answer.capitalize()

# Testing the function

In [21]:
text = """Victoria has a written constitution enacted in 1975, but based on the 1855 colonial constitution, passed by the United Kingdom Parliament as the Victoria Constitution Act 1855, which establishes the Parliament as the state's law-making body for matters coming under state responsibility. The Victorian Constitution can be amended by the Parliament of Victoria, except for certain 'entrenched' provisions that require either an absolute majority in both houses, a three-fifths majority in both houses, or the approval of the Victorian people in a referendum, depending on the provision."""
question = "When did Victoria enact its constitution?"

question_answer(text, question)

'1975'
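The check above is qualitative. To put a number on answer quality over many CoQA rows, one option is the word-overlap F1 used by SQuAD and CoQA. A simplified sketch follows; the official CoQA evaluation script also averages over multiple reference answers and handles other details not reproduced here.

```python
import re
import string
from collections import Counter

PUNCTUATION = set(string.punctuation)

def normalize(text):
    """Lowercase, drop punctuation and the articles a/an/the, and collapse whitespace."""
    text = "".join(ch for ch in text.lower() if ch not in PUNCTUATION)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())

def f1_score(prediction, reference):
    """Harmonic mean of word-level precision and recall between two answers."""
    pred_tokens = normalize(prediction).split()
    ref_tokens = normalize(reference).split()
    overlap = sum((Counter(pred_tokens) & Counter(ref_tokens)).values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred_tokens)
    recall = overlap / len(ref_tokens)
    return 2 * precision * recall / (precision + recall)

print(f1_score("1975", "in 1975"))  # precision 1, recall 0.5
```

Averaging `f1_score(question_answer(row.text, row.question), row.answer)` over a random sample of `data` rows would give a rough score. CoQA questions are conversational, though, and follow-ups such as "and?" lack the dialogue history this single-turn setup never passes to the model.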

# Gradio App

In [22]:
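import gradio as gr

#gr.inputs / gr.outputs were deprecated in Gradio 3.x and removed in 4.0;
#gr.Textbox serves as both input and output component in current releases.
#(The log below was produced with an older Gradio release.)
gr.Interface(question_answer,
    [
        gr.Textbox(lines=7, label="Context"),
        gr.Textbox(label="Question"),
    ],
    gr.Textbox(label="Answer")).launch()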

Colab notebook detected. To show errors in colab notebook, set `debug=True` in `launch()`
This share link will expire in 72 hours. If you need a permanent link, visit: https://gradio.app/introducing-hosted
Running on External URL: https://21429.gradio.app
Interface loading below...


(<Flask 'gradio.networking'>,
 'http://127.0.0.1:7860/',
 'https://21429.gradio.app')