In [None]:
import json
import pandas as pd
import re
import numpy as np
import torch
import random
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizerFast, BertForSequenceClassification
from sklearn.metrics import accuracy_score, f1_score
import matplotlib.pyplot as plt
from IPython.display import clear_output
import time

In [None]:
# Path of the Natural Questions dump; one JSON object per line (JSON Lines).
DATA_PATH = 'v1.0-simplified_nq-dev-all.jsonl'

# Parse the file line by line instead of first materialising every raw line in
# a list: the raw strings would otherwise stay in memory next to the parsed
# objects. The encoding is pinned so the result does not depend on the
# platform default, and blank lines (e.g. a trailing newline) are skipped
# rather than crashing `json.loads`.
with open(DATA_PATH, 'r', encoding='utf-8') as json_file:
    data = [json.loads(line) for line in json_file if line.strip()]

In [None]:
# The tokenizer and the model must be loaded from the SAME checkpoint.
# The original cell paired the `bert-base-cased` tokenizer with the
# `bert-base-uncased` weights; the two checkpoints ship different vocabularies,
# so the ids produced by the tokenizer would not match the model's embeddings.
MODEL_NAME = 'bert-base-uncased'

tokenizer = BertTokenizerFast.from_pretrained(MODEL_NAME)
model = BertForSequenceClassification.from_pretrained(
    MODEL_NAME,
    num_labels=2,  # two classes: candidate is / is not the long answer
)

In [None]:
# Learning rate handed to Adam, which updates every parameter of `model`.
lr = 1e-5
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

In [None]:
import re

import pandas as pd
from nltk.corpus import stopwords
from torch.utils.data import Dataset

class LongAnswerDataset(Dataset):
    """Question / long-answer-candidate pairs built from Natural Questions examples.

    Every example contributes its annotated long answer (label 1) plus every
    ``SAMPLE_RATE``-th candidate (label 0 unless it is the annotated one). The
    question and candidate texts are stripped of English stopwords and HTML
    tags and are tokenized lazily in ``__getitem__``.

    Args:
        data: iterable of example dicts. With ``kaggle_format=True`` each one
            must already carry ``document_text``; otherwise it must carry
            ``document_tokens`` and is simplified first.
        tokenizer: callable tokenizer exposing ``sep_token``.
        max_len: length every encoded pair is padded / truncated to.
        kaggle_format: whether ``data`` is already in the simplified format.

    The input examples are never modified: byte offsets are dropped from
    copies, so the same ``data`` list can be reused safely.
    """

    # Besides the positive candidate, keep candidates whose index is a multiple
    # of this value (down-sampling of the negatives).
    SAMPLE_RATE = 15

    # Columns of the intermediate DataFrame; fixed so an empty input still
    # yields a frame with the expected columns.
    _COLUMNS = ('question', 'long_answer', 'is_long_answer')
    # Keys removed from spans when simplifying a raw example.
    _BYTE_OFFSET_KEYS = ('start_byte', 'end_byte')
    # Non-greedy match of anything between angle brackets.
    _HTML_TAG_RE = re.compile(r'<.*?>')

    def __init__(self, data, tokenizer, max_len=150, kaggle_format=True):
        self._tokenizer = tokenizer
        self._max_len = max_len
        self._kaggle_format = kaggle_format
        # Load the stopword list once; the lookup below runs for every word of
        # every text, so a set membership test replaces a per-word reload.
        self._stopwords = frozenset(stopwords.words('english'))

        data = self._preprocess_data(data)
        data = self._clean_df(data)
        self._questions = data.question.values
        self._long_answers = data.long_answer.values
        self._targets = data.is_long_answer.values

    def _get_nq_tokens(self, simplified_nq_example):
        """Return the document tokens of a simplified example.

        Raises:
            ValueError: if the example has no ``document_text`` field.
        """
        if "document_text" not in simplified_nq_example:
            raise ValueError("`get_nq_tokens` should be called on a simplified NQ"
                         "example that contains the `document_text` field.")

        return simplified_nq_example["document_text"].split(" ")

    def _clean_token(self, token):
        """Replace spaces inside a token by underscores so that splitting the
        joined document text on spaces gives the tokens back."""
        return re.sub(u" ", "_", token["token"])

    def _remove_html_byte_offsets(self, span):
        """Return a copy of ``span`` without ``start_byte`` / ``end_byte``.

        A new dict is returned; the caller's dict is left untouched.
        """
        return {
            key: value
            for key, value in span.items()
            if key not in self._BYTE_OFFSET_KEYS
        }

    def _clean_annotation(self, annotation):
        """Return a copy of ``annotation`` whose long and short answers have
        their byte offsets removed."""
        cleaned = dict(annotation)
        cleaned["long_answer"] = self._remove_html_byte_offsets(
            annotation["long_answer"])
        cleaned["short_answers"] = [
            self._remove_html_byte_offsets(sa) for sa in annotation["short_answers"]
        ]
        return cleaned

    def _simplify_nq_example(self, nq_example):
        """Convert a raw example (with ``document_tokens``) into the simplified
        format (with ``document_text``).

        Raises:
            ValueError: if splitting the rebuilt text does not give back the
                original number of tokens.
        """
        text = " ".join([self._clean_token(t) for t in nq_example["document_tokens"]])

        simplified_nq_example = {
          "question_text": nq_example["question_text"],
          "example_id": nq_example["example_id"],
          "document_url": nq_example["document_url"],
          "document_text": text,
          "long_answer_candidates": [
              self._remove_html_byte_offsets(c)
              for c in nq_example["long_answer_candidates"]
          ],
          "annotations": [self._clean_annotation(a) for a in nq_example["annotations"]]
        }

        if len(self._get_nq_tokens(simplified_nq_example)) != len(
          nq_example["document_tokens"]):
            raise ValueError("Incorrect number of tokens.")

        return simplified_nq_example

    def _get_question_and_document(self, line):
        """Return the question, the document as a token list and the FIRST
        annotation of a simplified example (other annotations are ignored)."""
        question = line['question_text']
        text = line['document_text'].split(' ')
        annotations = line['annotations'][0]

        return question, text, annotations

    def _get_long_candidate(self, i, annotations, candidate):
        """Return ``(label, start_token, end_token)`` for candidate number ``i``.

        ``label`` is 1 when ``i`` is the annotated long-answer candidate index
        and 0 otherwise.
        """
        label = 1 if i == annotations['long_answer']['candidate_index'] else 0

        # token span of the candidate inside the document text
        long_start = candidate['start_token']
        long_end = candidate['end_token']

        return label, long_start, long_end

    def _form_data_row(self, question, label, text, long_start, long_end):
        """Build one DataFrame row for a (question, candidate) pair."""
        return {
            'question': question,
            'long_answer': ' '.join(text[long_start:long_end]),
            'is_long_answer': label,
        }

    def _preprocess_data(self, data):
        """Turn the examples into a DataFrame of sampled (question, candidate)
        rows with columns ``question``, ``long_answer`` and ``is_long_answer``."""
        rows = []

        for line in data:
            if not self._kaggle_format:
                line = self._simplify_nq_example(line)
            question, text, annotations = self._get_question_and_document(line)
            for i, candidate in enumerate(line['long_answer_candidates']):
                label, long_start, long_end = self._get_long_candidate(i, annotations, candidate)

                # keep the positive candidate and every SAMPLE_RATE-th candidate
                if label == 1 or (i % self.SAMPLE_RATE == 0):
                    rows.append(
                        self._form_data_row(question, label, text, long_start, long_end)
                    )

        # Explicit columns keep the schema stable even when `rows` is empty.
        return pd.DataFrame(rows, columns=list(self._COLUMNS))

    def _remove_stopwords(self, sentence):
        """Drop whitespace-separated words found in the cached stopword set.

        The comparison is case-sensitive, exactly as written in the lookup.
        """
        return ' '.join(
            word for word in sentence.split() if word not in self._stopwords
        )

    def _remove_html(self, sentence):
        """Delete every ``<...>`` tag from ``sentence``."""
        return self._HTML_TAG_RE.sub(r'', sentence)

    def _clean_df_by_column(self, df, column):
        """Remove stopwords, then HTML tags, from ``df[column]`` (in place)."""
        df[column] = df[column].apply(self._remove_stopwords)
        df[column] = df[column].apply(self._remove_html)
        return df

    def _clean_df(self, df):
        """Clean both text columns of the pair DataFrame."""
        df = self._clean_df_by_column(df, 'long_answer')
        df = self._clean_df_by_column(df, 'question')
        return df

    def __getitem__(self, idx):
        """Return ``(encoding, target)`` for pair number ``idx``.

        The question words, the tokenizer's separator token and the candidate
        words are concatenated into one pre-split sequence that is padded /
        truncated to ``max_len``.
        """
        input_tokens = self._questions[idx].split()
        input_tokens.append(' ' + self._tokenizer.sep_token + ' ')
        long_answer_tokens = self._long_answers[idx].split()
        input_tokens.extend(long_answer_tokens)
        # NOTE(review): with return_tensors='pt' the tensors presumably carry a
        # leading dimension of 1, which the training loop squeezes away.
        encoding = self._tokenizer(input_tokens,
                          is_split_into_words=True,
                          return_offsets_mapping=False,
                          return_token_type_ids=False,
                          padding='max_length',
                          truncation=True,
                          max_length=self._max_len,
                          return_tensors='pt')
        return encoding, self._targets[idx]

    def __len__(self):
        """Number of (question, candidate) pairs."""
        return self._targets.shape[0]

In [None]:
# Length every encoded (question, candidate) pair is padded / truncated to.
max_len = 312
# Examples before this index are used for training, the rest for validation.
TRAIN_SIZE = 7000
BATCH_SIZE = 8

# Separate names for the two splits: rebinding a single `dataset` variable made
# the training set unreachable after this cell had run.
train_dataset = LongAnswerDataset(data[:TRAIN_SIZE], tokenizer, max_len, kaggle_format=False)
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

val_dataset = LongAnswerDataset(data[TRAIN_SIZE:], tokenizer, max_len, kaggle_format=False)
# Validation needs no shuffling; a fixed order keeps evaluation reproducible.
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

# Backward-compatible alias: `dataset` used to end up bound to the validation set.
dataset = val_dataset

In [None]:
# Use CUDA when PyTorch reports it as available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
device

In [None]:
class LongAnswerModel():
    """Training / evaluation wrapper around a sequence-classification model.

    The wrapped model is called as
    ``model(input_ids=..., attention_mask=..., labels=...)`` and must return a
    mapping with ``'loss'`` and ``'logits'`` and expose ``num_labels``.

    Args:
        model: the model to train; it is moved to ``device`` on construction.
        device: device on which batches and the model are placed.
    """

    def __init__(self, model, device):
        self.model = model
        self.device = device
        self.model.to(device)

    def plot_log(self, train_losses, val_losses, val_fscores):
        """Redraw the three per-epoch curves (train loss, val loss, val F1),
        clearing the previous cell output first."""
        clear_output()
        fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(7, 3))
        fig.suptitle('Training Log', fontsize=8)
        ax1.plot(train_losses)
        ax1.set_title('Train Loss', fontsize=8)
        ax1.tick_params(labelsize=6)
        ax2.plot(val_losses)
        ax2.set_title('Val Loss', fontsize=8)
        ax2.tick_params(labelsize=6)
        ax3.plot(val_fscores)
        ax3.set_title("Val F1", fontsize=8)
        ax3.tick_params(labelsize=6)
        plt.show()
        # One figure is created per epoch; release it once it has been shown.
        plt.close(fig)

    def _run_batch(self, batch):
        """Run the model on one ``(tokens, labels)`` batch.

        Returns:
            ``(loss, gold, pred)`` where ``gold`` / ``pred`` are CPU tensors of
            the true and predicted classes, restricted to labels != -100.
        """
        tokens, labels = batch
        # squeeze(dim=1) drops the per-example singleton dimension of the
        # encodings, if present.
        ids = tokens['input_ids'].to(self.device).squeeze(dim=1)
        attention_mask = tokens['attention_mask'].to(self.device).squeeze(dim=1)
        labels = labels.to(self.device)

        output = self.model(input_ids=ids, attention_mask=attention_mask, labels=labels)
        loss = output['loss']
        logits = output['logits']

        flattened_gold = labels.view(-1)
        # Use the wrapped model's own label count rather than a notebook global.
        active_logits = logits.detach().view(-1, self.model.num_labels)
        flattened_pred = torch.argmax(active_logits, dim=1)

        keep = flattened_gold != -100
        gold = torch.masked_select(flattened_gold, keep).cpu()
        pred = torch.masked_select(flattened_pred, keep).cpu()
        return loss, gold, pred

    @staticmethod
    def _micro_f1(golds, preds):
        """Micro-averaged F1 over all collected batches (0.0 if there are none)."""
        if not golds:
            return 0.0
        return f1_score(torch.cat(golds).numpy(), torch.cat(preds).numpy(), average='micro')

    def validate(self, val_dataloader):
        """Evaluate the model without updating it.

        Returns:
            ``(average batch loss, micro F1)``. The F1 is computed once over
            all predictions, so a smaller last batch is not over-weighted.
        """
        self.model.eval()

        val_loss = 0.0
        golds, preds = [], []

        # No gradients are needed for evaluation; skipping them avoids
        # building the autograd graph for every validation batch.
        with torch.no_grad():
            for batch in val_dataloader:
                loss, gold, pred = self._run_batch(batch)
                val_loss += loss.item()
                golds.append(gold)
                preds.append(pred)

        # max(..., 1) avoids a ZeroDivisionError on an empty dataloader.
        avg_val_loss = val_loss / max(len(val_dataloader), 1)
        avg_val_f1 = self._micro_f1(golds, preds)
        return avg_val_loss, avg_val_f1

    def train(self, train_dataloader, val_dataloader, n_epoch, optimizer):
        """Train for ``n_epoch`` epochs, validating, plotting and printing
        the metrics after each one.

        Args:
            train_dataloader: yields ``(tokens, labels)`` training batches.
            val_dataloader: yields ``(tokens, labels)`` validation batches.
            n_epoch: number of passes over ``train_dataloader``.
            optimizer: optimizer whose ``step()`` is called once per batch.
        """
        train_losses, val_losses = [], []
        val_fscores = []

        for epoch in range(n_epoch):

            start_time = time.time()

            self.model.train()

            train_loss = 0.0
            golds, preds = [], []

            for batch in train_dataloader:
                # clear gradients left over from the previous step
                self.model.zero_grad()

                loss, gold, pred = self._run_batch(batch)
                loss.backward()
                optimizer.step()

                train_loss += loss.item()
                golds.append(gold)
                preds.append(pred)

            avg_train_loss = train_loss / max(len(train_dataloader), 1)
            avg_train_f1 = self._micro_f1(golds, preds)

            avg_val_loss, avg_val_f1 = self.validate(val_dataloader)

            train_losses.append(avg_train_loss)
            val_losses.append(avg_val_loss)
            val_fscores.append(avg_val_f1)

            self.plot_log(train_losses, val_losses, val_fscores)

            print(f'Epoch {epoch}')
            print(f'Train loss: {avg_train_loss:.3f}')
            print(f'Train micro F1: {avg_train_f1:.3f}')
            print(f'Validation loss: {avg_val_loss:.3f}')
            print(f'Validation micro F1: {avg_val_f1:.3f}')
            curr_time = time.time() - start_time
            print(f'Epoch time: {curr_time:.3f}s')

In [None]:
# Wrap the model for training; the constructor also moves it to `device`.
answer_model = LongAnswerModel(model=model, device=device)

In [None]:
# Train for three epochs, validating after each one.
answer_model.train(
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    n_epoch=3,
    optimizer=optimizer,
)