In [None]:
from transformers import BertForQuestionAnswering
from transformers import BertTokenizer
import torch

In [None]:
# Checkpoint identifier handed to `from_pretrained`; per its name, a BERT-large
# uncased whole-word-masking model fine-tuned on SQuAD for extractive QA.
model_name = (
    "bert-large-uncased-whole-word-masking"
    "-finetuned-squad"
)

In [None]:
# Question-answering head + BERT encoder, loaded from the checkpoint above.
model = BertForQuestionAnswering.from_pretrained(
    pretrained_model_name_or_path=model_name,
)

In [None]:
# Tokenizer from the same checkpoint, so token ids match the model's vocabulary.
tokenizer = BertTokenizer.from_pretrained(
    pretrained_model_name_or_path=model_name,
)

In [None]:
# The query and the context passage the answer span will be extracted from.
question = "When was the first DVD released?"
answer_document = (
    "The first DVD (Digital Versatile Disc) was released on March 24, 1997. "
    "It was a movie titled 'Twister' and was released in Japan. "
    "DVDs quickly gained popularity as a replacement for VHS tapes and became "
    "a common format for storing and distributing digital video and data."
)

In [None]:
# Tokenize the (question, context) pair into a single model input.
# Calling the tokenizer directly is the recommended replacement for the
# deprecated `encode_plus`; the returned fields are the same.
# truncation="only_second" trims only the context (never the question) if the
# pair exceeds the tokenizer's maximum length, rather than producing a sequence
# longer than the model can accept.
encoding = tokenizer(
    question,
    answer_document,
    truncation="only_second",
    return_tensors="pt",
)
# Variable names kept unchanged because later cells refer to them:
# `input` holds the token ids (note: it shadows the builtin `input`), and
# `sentence_embeding` holds the segment / token-type ids, not embeddings.
input = encoding["input_ids"]
sentence_embeding = encoding["token_type_ids"]
tokens = tokenizer.convert_ids_to_tokens(input[0])
print({"input_ids": input, "token_type_ids": sentence_embeding})

In [None]:
# Show the classification token; look the id up on the tokenizer instead of
# hard-coding it (the original used 101, its id in bert-*-uncased vocabularies).
tokenizer.decode(tokenizer.cls_token_id)

In [None]:
# Show the separator token; look the id up on the tokenizer instead of
# hard-coding it (the original used 102, its id in bert-*-uncased vocabularies).
tokenizer.decode(tokenizer.sep_token_id)

In [None]:
# Inference-only forward pass: no_grad skips building the autograd graph.
# `input` and `sentence_embeding` are already tensors (return_tensors="pt"),
# so they are passed as-is; re-wrapping them with torch.tensor() copies the
# data and triggers a PyTorch "copy construct from a tensor" UserWarning.
with torch.no_grad():
    output = model(input_ids=input, token_type_ids=sentence_embeding)

In [None]:
# Pick the answer span: the token with the highest start logit, then the
# highest-scoring end token at or after it. An unconstrained end argmax can
# land before the start, which yields an empty answer; restricting the search
# to positions >= start always gives a valid span. When the unconstrained end
# is already >= start, the result is unchanged.
start_logits = output.start_logits[0]  # single example: drop the batch axis
end_logits = output.end_logits[0]

# .item() turns the 0-d tensors into plain ints for slicing and printing.
start_index = torch.argmax(start_logits).item()
end_index = start_index + torch.argmax(end_logits[start_index:]).item()

print({"start_index": start_index, "end_index": end_index})

In [None]:
# Rebuild the answer text from its tokens. convert_tokens_to_string re-attaches
# WordPiece "##" continuation pieces, which a plain ' '.join would leave split.
answer_tokens = tokens[start_index : end_index + 1]
answer = tokenizer.convert_tokens_to_string(answer_tokens)
if not answer_tokens:
    # Happens when the predicted end index precedes the start index.
    answer = "<no answer: end index precedes start index>"
print(f"Answer: {answer}")

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
# Per-token logits for the first (only) example as NumPy arrays, detached from
# the autograd graph so they can be handed to the plotting library.
start_scores, end_scores = (
    logits.detach().numpy()[0]
    for logits in (output.start_logits, output.end_logits)
)

In [None]:
# Label each bar "index - token" so repeated tokens stay distinguishable.
# The width-2 right-alignment belongs on the index; the original applied it to
# the token, which padded short tokens such as "," with a leading space.
token_labels = [f"{i:>2} - {token}" for i, token in enumerate(tokens)]

# One bar per token, so use a wide explicit figure/axes instead of the
# implicit pyplot state.
fig, ax = plt.subplots(figsize=(16, 6))
sns.barplot(x=token_labels, y=start_scores, color="blue", alpha=0.5, ax=ax)
ax.tick_params(axis="x", labelrotation=90)
ax.set(title="Start-position logits per token", xlabel="Token", ylabel="Start logit")
ax.grid(True)

In [None]:
# Bar chart of the end-position logit for every token, drawn on its own
# explicit figure/axes (wide, since there is one bar per token).
fig, ax = plt.subplots(figsize=(16, 6))
sns.barplot(x=token_labels, y=end_scores, color="red", alpha=0.5, ax=ax)
ax.tick_params(axis="x", labelrotation=90)
ax.set(title="End-position logits per token", xlabel="Token", ylabel="End logit")
ax.grid(True)