**How to extract dataset names with BERT, framed as a simple question-answering task**

# Load Modules

In [None]:
from transformers import BertForQuestionAnswering
from transformers import BertTokenizer
import torch

import pandas as pd

from tqdm import tqdm
tqdm.pandas()

import os
import re
import json

# Helpers

In [None]:
def clean_text(txt):
    """Replace every run of non-alphanumeric ASCII characters with one space.

    The input is first coerced with ``str()`` (so ``None`` becomes ``"None"``),
    and the result has leading/trailing whitespace trimmed. Note that
    non-ASCII letters (e.g. accented characters) are treated as separators
    and therefore removed.
    """
    collapsed = re.sub('[^A-Za-z0-9]+', ' ', str(txt))
    return collapsed.strip()

def totally_clean_text(txt):
    """Clean ``txt`` with `clean_text`, then squeeze repeated spaces to one.

    `clean_text` already turns any run of non-alphanumeric characters
    (spaces included) into a single space, so the extra squeeze is a
    safety net that leaves its output unchanged.
    """
    return re.sub(' +', ' ', clean_text(txt))

# Constants

In [None]:
# Competition input directories; the loading cell reads one JSON file per
# entry of TRAIN.
TRAIN, TEST = (
    "../input/coleridgeinitiative-show-us-the-data/train",
    "../input/coleridgeinitiative-show-us-the-data/test",
)

# Load Data & Preprocess

In [None]:
## json to pandas
# Build one row per section of every training paper: [paper id, title, text].
# Files are visited in sorted order so that positional row indices (used by
# `question_answer` below) are reproducible; os.listdir order is arbitrary.
paper_sentense = []
for file in tqdm(sorted(os.listdir(TRAIN))):
    if not file.endswith(".json"):
        # Only JSON paper files are parsed.
        continue

    ids = os.path.splitext(file)[0]
    file_path = os.path.join(TRAIN, file)
    # Read explicitly as UTF-8 instead of relying on the platform locale.
    with open(file_path, "r", encoding="utf-8") as f:
        json_datasets = json.load(f)

    for json_dataset in json_datasets:
        # Reset per section so a section missing a key does not inherit the
        # previous section's title/text (or raise NameError on the first one).
        title, text = "", ""
        for k, v in json_dataset.items():
            if k == "text":
                text = v
            else:
                # Any key other than "text" is treated as the section title.
                title = v

        paper_sentense.append([ids, title, text])
paper_sentense_df = pd.DataFrame(paper_sentense, columns=["Id", "Title", "Sentense"])

In [None]:
## cleaned
# Normalise each section's raw text into a new column, keeping "Sentense" intact.
paper_sentense_df["CleanedSentense"] = (
    paper_sentense_df["Sentense"]
    .progress_apply(totally_clean_text)
)

In [None]:
# example: the cleaned text of the first row
paper_sentense_df["CleanedSentense"].iloc[0]

In [None]:
# load bert: per its name, this BERT-large checkpoint is already fine-tuned on
# SQuAD, so it is used directly for extractive question answering.
QA_CHECKPOINT = 'bert-large-uncased-whole-word-masking-finetuned-squad'

tokenizer = BertTokenizer.from_pretrained(QA_CHECKPOINT)
model = BertForQuestionAnswering.from_pretrained(QA_CHECKPOINT)

def answer_question(question, answer_text, max_length=512):
    '''
    Extract the span of `answer_text` that the model predicts answers `question`.

    The question and text are encoded as a BERT text pair
    ([CLS] question [SEP] answer_text [SEP]) and run through the module-level
    `model` / `tokenizer`. The tokens with the highest start and end scores
    delimit the answer, which is printed and also returned.

    Args:
        question: the question string (segment A).
        answer_text: the passage expected to contain the answer (segment B).
        max_length: maximum number of tokens fed to the model (BERT models
            accept at most 512). Longer inputs are cut from the end of the
            passage while keeping the closing [SEP].

    Returns:
        The answer string, with WordPiece continuation tokens ("##...")
        glued back onto the preceding token.
    '''
    # ======== Tokenize ========
    # Apply the tokenizer to the input text, treating them as a text-pair.
    input_ids = tokenizer.encode(question, answer_text)

    # Report how long the input sequence is (before any truncation).
    print('Query has {:,} tokens.\n'.format(len(input_ids)))

    if len(input_ids) > max_length:
        # Plain slicing would drop the trailing [SEP] and hand the model a
        # malformed text pair; re-append it after cutting the passage.
        input_ids = input_ids[:max_length - 1] + [tokenizer.sep_token_id]

    # ======== Set Segment IDs ========
    # The first [SEP] closes segment A (the question) and counts toward it;
    # everything after it belongs to segment B (the passage).
    num_seg_a = input_ids.index(tokenizer.sep_token_id) + 1
    num_seg_b = len(input_ids) - num_seg_a
    segment_ids = [0] * num_seg_a + [1] * num_seg_b

    # There should be a segment_id for every input token.
    assert len(segment_ids) == len(input_ids)

    # ======== Evaluate ========
    # Inference only: disable autograd to avoid storing activations for a
    # backward pass that never happens.
    with torch.no_grad():
        scores = model(torch.tensor([input_ids]),           # token ids
                       token_type_ids=torch.tensor([segment_ids]))  # question vs. passage

    # ======== Reconstruct Answer ========
    # Find the tokens with the highest `start` (scores[0]) and `end`
    # (scores[1]) scores; convert to plain ints for indexing.
    answer_start = int(torch.argmax(scores[0]))
    answer_end = int(torch.argmax(scores[1]))

    # Get the string versions of the input tokens.
    tokens = tokenizer.convert_ids_to_tokens(input_ids)

    # Start with the first token, then append the rest of the span. If
    # answer_end < answer_start the slice is empty and only the start token
    # is kept.
    answer = tokens[answer_start]
    for token in tokens[answer_start + 1:answer_end + 1]:
        if token.startswith('##'):
            # Sub-word piece: recombine it with the previous token.
            answer += token[2:]
        else:
            answer += ' ' + token

    print('Answer: "' + answer + '"')
    return answer

In [None]:
def question_answer(index, question="What is the name of the dataset you are using?"):
    '''
    Print one cleaned section of the corpus and the model's answer about it.

    Args:
        index: positional row index into `paper_sentense_df` (negative values
            count from the end, as with list indexing).
        question: question to ask; defaults to asking for the dataset name.
    '''
    # .iloc fetches a single value without materialising the whole column as
    # a Python list on every call.
    s = paper_sentense_df["CleanedSentense"].iloc[index]

    print("Base:", s)

    answer_question(question, s)

In [None]:
question_answer(index=0)

In [None]:
question_answer(index=1)

In [None]:
question_answer(index=7)

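**The answers look promising on these examples.** Note that this is a qualitative spot check of three sections, not a measured accuracy.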