In [1]:
from transformers import AutoTokenizer, SplinterForQuestionAnswering
import torch

  from .autonotebook import tqdm as notebook_tqdm


# Downloading the Model
This checkpoint includes the pretrained weights of the QASS layer.

The QASS (Question-Aware Span Selection) layer contextualizes the token representations with respect to the question representation.

In [3]:
# Load the tokenizer and the question-answering model from the same
# "tau/splinter-base-qass" checkpoint. Both objects are module-level globals
# that search_span_text reads directly.
tokenizer = AutoTokenizer.from_pretrained("tau/splinter-base-qass")
model = SplinterForQuestionAnswering.from_pretrained("tau/splinter-base-qass")

In [55]:
def search_span_text(question, passage):
    """Extract the answer span for ``question`` from ``passage``.

    Relies on the module-level ``tokenizer`` and ``model``. The
    question/passage pair is tokenized, run through the model without
    gradient tracking, and the tokens between the predicted start and end
    positions are decoded back to text.

    Args:
        question: Question text.
        passage: Context text in which to look for the answer. When the
            encoded pair would exceed the model's position limit, only the
            passage is truncated (the question is kept intact), so an answer
            located past the limit cannot be found.

    Returns:
        str: The decoded tokens from the predicted start index to the
        predicted end index, both inclusive.

    Side effects:
        Prints the selected start and end token indices.
    """
    # Truncate only the passage so that long articles do not overflow the
    # model's position embeddings and crash the forward pass.
    inputs = tokenizer(
        question,
        passage,
        truncation="only_second",
        max_length=model.config.max_position_embeddings,
        return_tensors="pt",
    )
    with torch.no_grad():
        outputs = model(**inputs)

    # A single example is encoded, so work on the first (only) row of logits.
    start_logits = outputs.start_logits[0]
    end_logits = outputs.end_logits[0]

    answer_start_index = int(start_logits.argmax())
    # Pick the best end position at or after the start. Taking the two argmax
    # values independently can put the end before the start, which yields an
    # empty slice and therefore an empty answer. When the unconstrained end
    # already lies at or after the start, this selects the same index.
    answer_end_index = answer_start_index + int(
        end_logits[answer_start_index:].argmax()
    )

    print(f"start index: {answer_start_index}")
    print(f"end index: {answer_end_index}")
    predict_answer_tokens = inputs.input_ids[0, answer_start_index : answer_end_index + 1]

    return tokenizer.decode(predict_answer_tokens)

In [57]:
# First example pair. NOTE: both names are reassigned below before they are
# used, so this pair is never passed to the model in this cell.
question = "Who was Jim Henson?"
text = 'Jim Henson was a nice puppet'

# Second example pair; these are the values actually passed to search_span_text.
question = 'What is the capital city of Indonesia?'
text = "Indonesia is a big country, it has a lot of island. Jakarta, the capital city, located in Java island."

# Last expression of the cell, so the returned answer string is displayed.
search_span_text(question,text)

start index: 1
end index: 7


'What is the capital city of Indonesia'

In [44]:
# Target start/end token positions for a batch of one example.
# NOTE(review): 14 and 15 are hard-coded token indices — confirm they match
# the intended answer span for the current `question`/`text`.
target_start_index = torch.tensor([14])
target_end_index = torch.tensor([15])

# Relies on `question` and `text` assigned in another cell.
inputs = tokenizer(question, text, return_tensors="pt")

# Supplying start/end positions makes the model output carry a loss value.
outputs = model(**inputs, start_positions=target_start_index, end_positions=target_end_index)
loss = outputs.loss
# Bare expression so the notebook displays the loss tensor.
loss

tensor(7.9873, grad_fn=<DivBackward0>)

# Testing on the Data

In [66]:
import pandas as pd
import re
import pprint

In [70]:
def clean_text(text):
    """Tidy stray '?' placeholder characters and whitespace in a string.

    The substitutions run in a fixed order, each on the result of the
    previous one:

    1. a whitespace character followed by '?' becomes a single space;
    2. every run of whitespace collapses to one space;
    3. a '?' sitting directly between two word characters becomes an
       apostrophe;
    4. ',?' becomes ',';
    5. any remaining run of '?' collapses to a single '?'.

    Leading and trailing whitespace is removed at the end.

    Args:
        text: The string to clean.

    Returns:
        str: The cleaned string.
    """
    # (pattern, replacement) pairs; the order matters because later rules
    # operate on the output of earlier ones.
    substitutions = (
        (r'(\s\?)', ' '),
        (r'\s+', ' '),
        (r"\b\?\b", "\'"),
        (r"(,\?)", ","),
        (r"\?+", "?"),
    )
    cleaned = text
    for pattern, replacement in substitutions:
        cleaned = re.sub(pattern, replacement, cleaned)
    return cleaned.strip()

In [60]:
# Read the news articles CSV, decoding the file as ISO-8859-1.
# NOTE(review): the "Unnamed: 0" column in the preview looks like a saved row
# index; index_col=0 would presumably absorb it — confirm before changing.
df = pd.read_csv('news_dataset.csv', encoding='iso-8859-1')
# Preview the first rows.
df.head()

Unnamed: 0,id,author,date,year,month,topic,article
0,17307,Marlise Simons,1/01/2017,2017,1,architecture,PARIS ? When the Islamic State was about to...
1,17292,Andy Newman,31/12/2016,2016,12,art,Angels are everywhere in the Mu?iz family?s ap...
2,17298,Emma G. Fitzsimmons,2/01/2017,2017,1,business,Finally. The Second Avenue subway opened in Ne...
3,17311,Carl Hulse,3/01/2017,2017,1,business,WASHINGTON ? It?s or time for Republica...
4,17339,Jim Rutenberg,5/01/2017,2017,1,business,"For Megyn Kelly, the shift from Fox News to NB..."


In [73]:
# Select one article by row position, clean it, and pretty-print the result.
i = 2
passage = clean_text(df.article.iloc[i])
pprint.pprint(passage)

('Finally. The Second Avenue subway opened in New York City on Sunday, with '
 'thousands of riders flooding into its polished stations to witness a piece '
 'of history nearly a century in the making. They descended beneath the '
 'streets of the Upper East Side of Manhattan to board Q trains bound for '
 'Coney Island in Brooklyn. They cheered. Their eyes filled with tears. They '
 'snapped selfies in front of colorful mosaics lining the walls of the '
 'stations. It was the first day of 2017, and it felt like a new day for a '
 'city that for so long struggled to build this sorely needed subway line. In '
 'a rare display of unbridled optimism from hardened New Yorkers, they arrived '
 'with huge grins and wide eyes, taking in the bells and whistles at three new '
 'stations. I was very choked up, Betsy Morris, 70, said as she rode the first '
 'train to leave the 96th Street station, at noon. How do you explain '
 "something that you never thought would happen? It's going to change

In [74]:
# Question to ask about the article held in `passage`; this reassigns the
# global `question` used by earlier cells.
question = 'When did the Second Avenue subway open?'

In [80]:
# Only the first 151 characters of the passage are passed as context.
search_span_text(question,passage[:151])


start index: 14
end index: 18


'The Second Avenue subway opened'