In [1]:
import spacy
from spacy import displacy

import torch
import numpy as np
import pandas as pd

from transformers import BertForQuestionAnswering
from transformers import BertTokenizer


# Load spaCy's small English pipeline; it is used to find the entities that
# displacy renders.
# If the model is not installed yet, run: python -m spacy download en_core_web_sm
nlp = spacy.load("en_core_web_sm")

In [22]:
class ModelBertManager:
    """Extractive question answering backed by a pretrained BERT QA checkpoint.

    The constructor loads the model and its tokenizer; ``questions`` then picks
    the most likely answer span for a question inside a context text.
    """

    # Checkpoint used for both the model and the tokenizer unless overridden.
    MODEL_NAME = 'bert-large-uncased-whole-word-masking-finetuned-squad'
    # Cap on question + context tokens. 512 is the position-embedding limit of
    # the standard BERT checkpoints; longer inputs make the forward pass fail.
    MAX_INPUT_TOKENS = 512
    # Text returned when no usable answer span is found.
    NO_ANSWER = "Sorry! Try other question :("

    def __init__(self, model_name=MODEL_NAME):
        """Load the QA model and matching tokenizer.

        Args:
            model_name: Name or path passed to ``from_pretrained`` for both the
                model and the tokenizer. Defaults to ``MODEL_NAME``.
        """
        self.model = BertForQuestionAnswering.from_pretrained(model_name)
        self.tokenizer = BertTokenizer.from_pretrained(model_name)

    def questions(self, question, text):
        """Answer ``question`` using ``text`` as the context passage.

        Args:
            question: The question string.
            text: The context in which the answer span is searched. It is
                truncated if question + context exceed ``MAX_INPUT_TOKENS``.

        Returns:
            A 2-tuple ``(payload, ok)``:
            * ``(f"{question} {answer}", True)`` when an answer span was found;
            * ``(NO_ANSWER, False)`` when the predicted end precedes the start
              or the span begins at the ``[CLS]`` token;
            * ``(exception, False)`` when anything raised along the way.
        """
        try:
            # Only the context (second sequence) is shortened, never the question.
            input_ids = self.tokenizer.encode(
                question,
                text,
                truncation="only_second",
                max_length=self.MAX_INPUT_TOKENS,
            )
            tokens = self.tokenizer.convert_ids_to_tokens(input_ids)

            # Segment 0 = everything up to and including the first [SEP]
            # (the question); segment 1 = the remainder (the context).
            question_len = input_ids.index(self.tokenizer.sep_token_id) + 1
            segment_ids = [0] * question_len + [1] * (len(input_ids) - question_len)

            # Pure inference: skip autograd bookkeeping.
            with torch.no_grad():
                output = self.model(
                    torch.tensor([input_ids]),
                    token_type_ids=torch.tensor([segment_ids]),
                )

            # Plain ints so they can be used directly for slicing.
            start_index = int(torch.argmax(output.start_logits))
            end_index = int(torch.argmax(output.end_logits))

            if end_index < start_index:
                return (self.NO_ANSWER, False)

            answer = self._join_wordpieces(tokens[start_index:end_index + 1])
            if answer.startswith("[CLS]"):
                return (self.NO_ANSWER, False)
            return (f"{question} {answer}", True)
        except Exception as error:
            # Deliberate best-effort contract: failures are reported through the
            # tuple instead of propagating to the caller.
            return (error, False)

    @staticmethod
    def _join_wordpieces(pieces):
        """Join tokens with spaces, gluing '##' continuation pieces to the previous word."""
        words = []
        for piece in pieces:
            if piece.startswith("##") and words:
                words[-1] += piece[2:]
            else:
                words.append(piece)
        return " ".join(words)

In [23]:
# Build the QA helper once and reuse it for every question: its constructor
# calls from_pretrained() for both the BERT model and the tokenizer.
model = ModelBertManager()

In [24]:
# Read the whole context document into one string; this is the passage the
# questions are asked against.
# "ignore" is the codec error-handler name that drops undecodable bytes. The
# earlier spelling "ignored" is not a registered handler and raises LookupError
# as soon as an invalid byte is encountered.
with open("obama.txt", encoding="utf-8", errors="ignore") as fb:
    text_full = fb.read()

# Run spaCy over the full text and render its named entities inline.
doc = nlp(text_full)
displacy.render(doc, style="ent")

In [27]:
def beautiful_format(text_full: tuple) -> None:
    """Display the result of a question-answering call.

    Args:
        text_full: A ``(payload, ok)`` pair. When ``ok`` is truthy, ``payload``
            is the text to analyse; otherwise it is the failure message (or
            the exception) describing why no answer is available.

    On success the payload is run through spaCy and its named entities are
    rendered with displacy. On failure the payload is printed, so the reader
    sees why nothing was rendered instead of an empty cell output.
    """
    payload, succeeded = text_full[0], text_full[1]
    if not succeeded:
        print(payload)
        return
    doc = nlp(payload)
    displacy.render(doc, style="ent")

# Date-of-birth question. questions() returns a (text, ok) pair, and the
# answer's entities are rendered only when ok is true.
beautiful_format(model.questions("When Barack Obama born?", text_full))

In [28]:
# Presidency-period question, asked against the same context text.
beautiful_format(model.questions("When Barack Obama served as president?", text_full))

In [29]:
# Birthplace question, asked against the same context text.
beautiful_format(model.questions("Where Barack Obama born?", text_full))

In [30]:
# Party-affiliation question, asked against the same context text.
beautiful_format(model.questions("Which political party is Obama a part of?", text_full))

In [32]:
# University question, asked against the same context text.
beautiful_format(model.questions("What university did Obama graduate from?", text_full))