In [41]:
import torch
import wikipedia as wiki
from collections import OrderedDict
from transformers import AutoModelForQuestionAnswering, AutoTokenizer, pipeline

In [128]:
class WikiAnswerFinder:
    """Answer a question by running an extractive QA model over a Wikipedia article.

    The top Wikipedia search hit for the question is downloaded, tokenized
    together with the question, split into chunks that fit the model's
    maximum sequence length, and each chunk is passed through the model.
    The per-chunk answers are concatenated into a single string.
    """

    def __init__(self, model_name="deepset/bert-base-cased-squad2"):
        """Load the tokenizer and question-answering model named ``model_name``."""
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForQuestionAnswering.from_pretrained(model_name)
        # Longest sequence (in tokens) the model can take in one forward pass.
        self.max_len = self.model.config.max_position_embeddings

    def tokenize(self, question, context):
        """Encode ``question``/``context`` as a pair and store the chunked result.

        Side effects:
            ``self.input_ids`` holds the full (un-chunked) token id list.
            ``self.inputs`` holds the chunked inputs produced by ``chunkify``.
        """
        # Calling the tokenizer directly is the supported replacement for the
        # deprecated ``encode_plus``. The full article is encoded in one go on
        # purpose (it is chunked right after), so the tokenizer's "sequence
        # length is longer than the maximum" warning is expected here.
        self.inputs = self.tokenizer(question, context, return_tensors='pt')
        self.input_ids = self.inputs["input_ids"].tolist()[0]
        self.inputs = self.chunkify()

    def chunkify(self):
        """Split the encoded pair in ``self.inputs`` into model-sized chunks.

        Every chunk repeats the question segment (the tokens whose
        ``token_type_ids`` is 0) followed by a slice of the remaining tokens.
        All chunks except the last get one extra separator token appended so
        that they end the same way the original encoding does.

        Returns:
            OrderedDict mapping chunk index -> dict with the same keys as
            ``self.inputs``, each value a tensor of shape ``(1, chunk_len)``
            with ``chunk_len <= self.max_len``.

        Raises:
            ValueError: if the question segment leaves no room for context.
        """
        qmask = self.inputs['token_type_ids'] < 1
        question_len = int(qmask.sum())

        # Reserve one position for the separator appended to non-final chunks.
        chunk_size = self.max_len - question_len - 1
        if chunk_size < 1:
            raise ValueError(
                f"Question uses {question_len} tokens, leaving no room for "
                f"context within the model limit of {self.max_len} tokens."
            )

        # Ask the tokenizer for its separator id instead of hard-coding BERT's 102.
        sep_id = self.tokenizer.sep_token_id

        chunked_input = OrderedDict()
        for k, v in self.inputs.items():
            q = torch.masked_select(v, qmask)
            c = torch.masked_select(v, ~qmask)
            chunks = torch.split(c, chunk_size)
            last = len(chunks) - 1

            for i, chunk in enumerate(chunks):
                parts = [q, chunk]
                if i != last:
                    # input_ids gets the separator token; attention_mask and
                    # token_type_ids get a 1 for that extra position.
                    fill = sep_id if k == "input_ids" else 1
                    parts.append(torch.tensor([fill], dtype=v.dtype, device=v.device))
                chunked_input.setdefault(i, {})[k] = torch.unsqueeze(torch.cat(parts), dim=0)
        return chunked_input

    def get_answer(self, question):
        """Search Wikipedia for ``question`` and extract an answer from the top hit.

        Returns:
            The non-empty per-chunk answers, each followed by ``" / "``,
            concatenated into one string. An empty string means no chunk
            produced an answer or the search returned nothing.
        """
        results = wiki.search(question)
        print("Question:", question)
        if not results:
            print("No Wikipedia search results found.")
            return ""
        print("Wikipedia top search result:", results[0])

        # The title comes straight from the search results, so disable
        # auto-suggest: it can silently rewrite the title to a different page.
        page = wiki.page(results[0], auto_suggest=False)
        context = page.content
        self.tokenize(question, context)

        answer = ""
        for chunk in self.inputs.values():
            tokens = self.tokenizer.convert_ids_to_tokens(chunk["input_ids"][0].tolist())
            # Inference only: skip building the autograd graph for every chunk.
            with torch.no_grad():
                predicts = self.model(**chunk)
            answer_start = int(torch.argmax(predicts.start_logits))
            answer_end = int(torch.argmax(predicts.end_logits)) + 1
            chunk_answer = self.tokenizer.convert_tokens_to_string(tokens[answer_start:answer_end])

            # Skip empty spans (end before start) and spans starting at [CLS],
            # which is how the model signals "no answer in this chunk".
            if chunk_answer.strip() and not chunk_answer.startswith("[CLS]"):
                answer += chunk_answer + " / "
        return answer

In [129]:
# Build the finder once with the default model; later cells reuse this instance.
finder = WikiAnswerFinder() 

In [130]:
# Example query; the returned string lists one candidate answer per article chunk, separated by " / ".
finder.get_answer('Who is the current president of Taiwan?')

Question: Who is the current president of Taiwan?
Wikipedia top search result: President of the Republic of China


Token indices sequence length is longer than the specified maximum sequence length for this model (3254 > 512). Running this sequence through the model will result in indexing errors


'Tsai Ing - wen / Lai Ching - te / Donald Trump / Chiang Kai - shek / '

In [136]:
# Example with a numeric answer; answers are decoded from word pieces, so spacing and punctuation may look split.
finder.get_answer('What is the length of the Great Wall of China?')

Question: What is the length of the Great Wall of China?
Wikipedia top search result: Great Wall of China


'12, 000 mi ) / 6 mi ) / '

In [147]:
# Example where no chunk yields an answer: the recorded result is an empty string.
finder.get_answer("What is the distance between the moon and the earth?")

Question: What is the distance between the moon and the earth?
Wikipedia top search result: Moon


''