In [96]:
import json
import re
import pandas as pd
import numpy as np
from tqdm import tqdm
from nltk.corpus import stopwords
from tqdm import tqdm, tqdm_notebook
import time

In [2]:
import torch
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

In [78]:
from transformers import AutoTokenizer, BertForTokenClassification

In [271]:
file_path = 'v1.0-simplified_nq-dev-all.jsonl'
with open(file_path, 'r') as json_file:
    json_list = list(json_file)

data = []
for json_str in json_list:
    data.append(json.loads(json_str))

In [5]:
def get_nq_tokens(simplified_nq_example):

    if "document_text" not in simplified_nq_example:
        raise ValueError("`get_nq_tokens` should be called on a simplified NQ"
                     "example that contains the `document_text` field.")

    return simplified_nq_example["document_text"].split(" ")


def simplify_nq_example(nq_example):

    def _clean_token(token):
        return re.sub(u" ", "_", token["token"])

    text = " ".join([_clean_token(t) for t in nq_example["document_tokens"]])

    def _remove_html_byte_offsets(span):
        if "start_byte" in span:
            del span["start_byte"]

        if "end_byte" in span:
            del span["end_byte"]

        return span

    def _clean_annotation(annotation):
        annotation["long_answer"] = _remove_html_byte_offsets(
            annotation["long_answer"])
        annotation["short_answers"] = [
            _remove_html_byte_offsets(sa) for sa in annotation["short_answers"]
        ]
        return annotation

    simplified_nq_example = {
      "question_text": nq_example["question_text"],
      "example_id": nq_example["example_id"],
      "document_url": nq_example["document_url"],
      "document_text": text,
      "long_answer_candidates": [
          _remove_html_byte_offsets(c)
          for c in nq_example["long_answer_candidates"]
      ],
      "annotations": [_clean_annotation(a) for a in nq_example["annotations"]]
    }

    if len(get_nq_tokens(simplified_nq_example)) != len(
      nq_example["document_tokens"]):
        raise ValueError("Incorrect number of tokens.")

    return simplified_nq_example

In [82]:
SAMPLE_RATE = 15
KAGGLE_FORMAT = False

def get_question_and_document(line):
    question = line['question_text']
    text = line['document_text'].split(' ')
    annotations = line['annotations'][0]
    
    return question, text, annotations


def get_long_candidate(i, annotations, candidate):
    if i == annotations['long_answer']['candidate_index']:
        label = 1
    else:
        label = 0

    # get place where long answer starts and ends in the document text
    long_start = candidate['start_token']
    long_end = candidate['end_token']
    
    return label, long_start, long_end


def form_data_row(question, label, text, long_start, long_end):
    row = {
        'question': question,
        'long_answer': ' '.join(text[long_start:long_end]),
        'is_long_answer': label,
    }
    
    return row


def preprocess_data(data):
    rows = []

    for line in data:
        if not KAGGLE_FORMAT:
            line = simplify_nq_example(line)
        question, text, annotations = get_question_and_document(line)
        for i, candidate in enumerate(line['long_answer_candidates']):
            label, long_start, long_end = get_long_candidate(i, annotations, candidate)

            if label == True or (i % SAMPLE_RATE == 0):
                rows.append(
                    form_data_row(question, label, text, long_start, long_end)
                )
        
    return pd.DataFrame(rows)

In [83]:
data = preprocess_data(data)
data.head()

Unnamed: 0,question,long_answer,is_long_answer
0,what do the 3 dots mean in math,"<Table> <Tr> <Th_colspan=""2""> ∴ </Th> </Tr> <T...",0
1,what do the 3 dots mean in math,<Tr> <Td> hyphen </Td> <Td> ‐ </Td> </Tr>,0
2,what do the 3 dots mean in math,<Tr> <Td> asterisk </Td> <Td> * </Td> </Tr>,0
3,what do the 3 dots mean in math,<Tr> <Td> ordinal indicator </Td> <Td> o a </T...,0
4,what do the 3 dots mean in math,<Tr> <Td> sound - recording copyright </Td> <T...,0


In [84]:
def remove_stopwords(sentence):
    words = sentence.split()
    words = [word for word in words if word not in stopwords.words('english')]
    
    return ' '.join(words)

def remove_html(sentence):
    html = re.compile(r'<.*?>')
    return html.sub(r'', sentence)

def clean_df_by_column(df, column):
    df[column] = df[column].apply(lambda x : remove_stopwords(x))
    df[column] = df[column].apply(lambda x : remove_html(x))
    return df

def clean_df(df):
    df = clean_df_by_column(df, 'long_answer')
    df = clean_df_by_column(df, 'question')
    return df

In [85]:
data = clean_df(data)
data.sample(5)

Unnamed: 0,question,long_answer,is_long_answer
30037,channel abc phoenix arizona,This list broadcast television stations servi...,0
69029,sang cold outside tom jones,Catatonia formed 1992 . She subsequently sang...,1
26220,list celebrities hollywood star,Tim McCoy Motion pictures 1600 Vine Stre...,0
44150,head parliament uk,Where Government lost confidence House Common...,0
62902,last time michigan basketball championship,2014 -- Nik Stauskas,0


In [86]:
data[data.is_long_answer == 1].count()

question          3771
long_answer       3771
is_long_answer    3771
dtype: int64

In [87]:
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

In [12]:
example = data.loc[0].question
example

'3 dots mean math'

In [13]:
print(tokenizer(example))
tokens = tokenizer.tokenize(example)
print(tokens)
ids = tokenizer.convert_tokens_to_ids(tokens)
print(ids)
print(tokenizer.decode(ids))

{'input_ids': [101, 1017, 14981, 2812, 8785, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1]}
['3', 'dots', 'mean', 'math']
[1017, 14981, 2812, 8785]
3 dots mean math


In [28]:
tokenizer.sep_token

'[SEP]'

In [227]:
class LongAnswerDataset(Dataset):
    def __init__(self, data, tokenizer, max_len=150):
        self._tokenizer = tokenizer
        self._max_len = max_len
        self._questions = data.question.values
        self._long_answers = data.long_answer.values
        self._targets = data.is_long_answer.values
        
    
    def __getitem__(self, idx):
        input_tokens = self._questions[idx].split()
        input_tokens.append(' ' + self._tokenizer.sep_token + ' ')
        long_answer_tokens = self._long_answers[idx].split()
        input_tokens.extend(long_answer_tokens)
        encoding = self._tokenizer(input_tokens,
                          is_split_into_words=True,
                          return_offsets_mapping=True,
                          padding='max_length',
                          truncation=True,
                          max_length=self._max_len,
                          return_tensors='pt')
        encoding.pop('token_type_ids')
        encoding.pop('offset_mapping')
        encoding.pop('attention_mask')
        return encoding, self._targets[idx]
        
        
        
    def __len__(self):
        return self._targets.shape[0]

In [164]:
data.shape

(71656, 3)

In [208]:
lad = LongAnswerDataset(test_df, tokenizer)
lad[0][0].keys()

dict_keys(['input_ids'])

In [166]:
example, target = lad[10]
# example

In [167]:
train_loader = DataLoader(lad, batch_size=4096)


In [168]:
example['input_ids'].shape, example['attention_mask'].shape

(torch.Size([1, 150]), torch.Size([1, 150]))

In [169]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cpu')

In [228]:
dataset = LongAnswerDataset(data[:60000], tokenizer)
train_dataloader = DataLoader(dataset, batch_size=8, shuffle=True)

dataset = LongAnswerDataset(data[60000:], tokenizer)
val_dataloader = DataLoader(dataset, batch_size=8, shuffle=True)

In [252]:
class LongAnswerClassifier(nn.Module):
    def __init__(self, vocab_size, hidden1, hidden2):
        super(LongAnswerClassifier, self).__init__()
        self.fc1 = nn.Linear(vocab_size, hidden1)
        self.fc2 = nn.Linear(hidden1, hidden2)
        self.fc3 = nn.Linear(hidden2, 1)
    
    def forward(self, inputs):
        x = F.relu(self.fc1(inputs.squeeze(1).float()))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

In [257]:
model = LongAnswerClassifier(dataset._max_len, 128, 64)
model

LongAnswerClassifier(
  (fc1): Linear(in_features=150, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=64, bias=True)
  (fc3): Linear(in_features=64, out_features=1, bias=True)
)

In [258]:
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam([p for p in model.parameters() if p.requires_grad], lr=0.001)

In [259]:
model.train()
train_losses = []
for epoch in range(10):
    losses = []
    total = 0
    for inputs, target in train_loader:
        model.zero_grad()

        output = model(inputs['input_ids'])
        loss = criterion(output.squeeze(), target.float())
        
        loss.backward()
              
        nn.utils.clip_grad_norm_(model.parameters(), 3)

        optimizer.step()
        
        progress_bar.set_description(f'Loss: {loss.item():.3f}')
        
        losses.append(loss.item())
        total += 1
    
    epoch_loss = sum(losses) / total
    train_losses.append(epoch_loss)
        
    tqdm.write(f'Epoch #{epoch + 1}\tTrain Loss: {epoch_loss:.3f}')

Epoch #1	Train Loss: 36.251
Epoch #2	Train Loss: 46.888
Epoch #3	Train Loss: 29.829
Epoch #4	Train Loss: 18.933
Epoch #5	Train Loss: 15.849
Epoch #6	Train Loss: 6.891
Epoch #7	Train Loss: 4.301
Epoch #8	Train Loss: 2.132
Epoch #9	Train Loss: 2.454
Epoch #10	Train Loss: 1.378


In [269]:
correct = 0
total = 0
model.eval()
with torch.no_grad():
    for data in val_dataloader:
        inputs, labels = data
        # calculate outputs by running images through the network
        outputs = model(inputs['input_ids'])
        # the class with the highest energy is what we choose as prediction
        predicted = torch.sigmoid(outputs)
        predicted = predicted.apply_(lambda x: 1 if x >= 0.5 else 0).squeeze(1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        # print(predicted)
        # print(labels)
print(f'Accuracy: {100 * correct // total} %')

Accuracy: 90 %
