In [2]:
import argparse
import datetime
import json
import numpy as np
import re
import string
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from torch.utils.data import Dataset, DataLoader
from transformers import AlbertForQuestionAnswering, AlbertTokenizerFast, XLNetForQuestionAnswering, XLNetTokenizerFast

In [3]:
# Load the saved checkpoints onto the CPU so the notebook runs without a GPU.
# NOTE: torch.load unpickles the file, so only load checkpoints from a trusted source.
xlnet_dict = torch.load("../model/xlnet.pt", map_location="cpu")
albert_dict = torch.load("../model/albert.pt", map_location="cpu")

In [4]:
# Instantiate both question-answering architectures from their base checkpoints, then load
# the weights stored under "model_state_dict" in each saved checkpoint.

# xlnet: from_pretrained() warns that the QA head weights are newly initialized
xlnet = XLNetForQuestionAnswering.from_pretrained('xlnet-base-cased')
xlnet.load_state_dict(xlnet_dict["model_state_dict"])


# albert: the last expression's result (the key-matching report) is what the cell displays
albert = AlbertForQuestionAnswering.from_pretrained("albert-base-v2")
albert.load_state_dict(albert_dict["model_state_dict"])

Some weights of XLNetForQuestionAnswering were not initialized from the model checkpoint at xlnet-base-cased and are newly initialized: ['end_logits.LayerNorm.weight', 'start_logits.dense.weight', 'end_logits.LayerNorm.bias', 'answer_class.dense_0.weight', 'end_logits.dense_1.weight', 'answer_class.dense_1.weight', 'end_logits.dense_1.bias', 'answer_class.dense_0.bias', 'end_logits.dense_0.bias', 'start_logits.dense.bias', 'end_logits.dense_0.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of AlbertForQuestionAnswering were not initialized from the model checkpoint at albert-base-v2 and are newly initialized: ['qa_outputs.bias', 'qa_outputs.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


<All keys matched successfully>

In [5]:
class SquadDataset(Dataset):
    """
    Question answering dataset read from tab-separated files and tokenized into
    fixed-length features.

    One example is a (question, context) pair. A pair that does not fit in the tokenizer's
    `max_length` is split into several overlapping features, so the dataset is indexed by
    feature; `sample_mapping` maps each feature index back to the example it came from.
    """

    def __init__(self, input_path, tokenizer_checkpoint, max_examples=100):
        """
        input_path: path that contains all the files - contexts, questions, answers, answer spans and question ids
        tokenizer_checkpoint: "xlnet-base-cased" or "albert-base-v2"; selects the fast tokenizer to use
        max_examples: only the first `max_examples` entries of every file are kept (None keeps all of them)

        Raises ValueError if `tokenizer_checkpoint` is not one of the supported names.
        """
        # Resolve the tokenizer class first so an unsupported checkpoint fails immediately
        # with a clear message instead of an AttributeError on `self.tokenizer` later on.
        if tokenizer_checkpoint == "xlnet-base-cased":
            tokenizer_cls = XLNetTokenizerFast
        elif tokenizer_checkpoint == "albert-base-v2":
            tokenizer_cls = AlbertTokenizerFast
        else:
            raise ValueError(
                f"Unsupported tokenizer_checkpoint {tokenizer_checkpoint!r}; "
                "expected 'xlnet-base-cased' or 'albert-base-v2'"
            )

        print(f"Reading in Dataset from {input_path}")

        self.contexts = self._read_field(input_path, "context", max_examples)
        self.questions = self._read_field(input_path, "question", max_examples)
        self.answers = self._read_field(input_path, "answer", max_examples)
        # every span entry is "<start> <end>", given as character offsets into the context
        self.spans = [span.split() for span in self._read_field(input_path, "answer_span", max_examples)]
        self.start_indices = [int(x[0]) for x in self.spans]
        self.end_indices = [int(x[1]) for x in self.spans]
        self.qids = self._read_field(input_path, "question_id", max_examples)

        # initialise the fast tokenizer matching the requested checkpoint
        self.tokenizer = tokenizer_cls.from_pretrained(tokenizer_checkpoint)
        self.tokenizer.padding_side = "right"

        # extract tokenization outputs
        self.tokenizer_dict = self.tokenize()
        self.sample_mapping, self.offset_mapping = self.preprocess()

        self.input_ids = self.tokenizer_dict["input_ids"]
        self.token_type_ids = self.tokenizer_dict["token_type_ids"]
        self.attention_mask = self.tokenizer_dict["attention_mask"]

    @staticmethod
    def _read_field(input_path, name, max_examples):
        """
        Read the tab-separated file `input_path/name` and return its first `max_examples`
        entries with surrounding whitespace stripped.
        """
        with open(input_path + "/" + name, encoding='utf-8') as f:
            entries = f.read().split("\t")
        return [entry.strip() for entry in entries[:max_examples]]

    def tokenize(self, max_length=384, doc_stride=128):
        """
        inputs:
        1. max_length: specifies the length of the tokenized text
        2. doc_stride: defines the number of overlapping tokens

        output:
        1. tokenizer_dict, which contains
        - input_ids: list of integer values representing the tokenized text; each integer corresponds to a specific token
        - token_type_ids: to distinguish between question and context
        - attention_mask: a binary mask that tells the model which tokens to mask/not mask
        - overflow_to_sample_mapping: map from a feature to its corresponding example, since one question-context pair might give several features
        - offset_mapping: maps each input id with the corresponding start and end characters in the original text

        Tokenize examples (question-context pairs) with truncation and padding, but keep the overflows using a stride specified by `doc_stride`.
        When the question-context input exceeds the `max_length`, it will contain more than one feature, and each of these features will have context
        that overlaps a bit with the previous features, and the overlapping is determined by `doc_stride`. This is to ensure that although truncation
        is performed, these overflows will ensure that no answer is missed as long as the answer span is shorter than the length of the overlap.
        """
        print("Performing tokenization on dataset")
        tokenizer_dict = self.tokenizer(
            self.questions,
            self.contexts,
            truncation="only_second",
            padding="max_length",
            max_length=max_length,
            stride=doc_stride,
            return_overflowing_tokens=True,
            return_offsets_mapping=True
        )
        return tokenizer_dict

    def preprocess(self):
        """
        Post-process the tokenizer output: since one example may produce several features,
        compute the token-level answer span (start_positions / end_positions) per feature.

        "overflow_to_sample_mapping" and "offset_mapping" are removed from `self.tokenizer_dict`
        and returned as (sample_mapping, offset_mapping); "start_positions" and "end_positions"
        are added to `self.tokenizer_dict`. A feature whose context window does not fully
        contain the answer is labelled with the position of the CLS token for both ends.
        """
        print("Preprocessing Dataset")

        sample_mapping = self.tokenizer_dict.pop("overflow_to_sample_mapping")
        offset_mapping = self.tokenizer_dict.pop("offset_mapping")

        self.tokenizer_dict["start_positions"] = []
        self.tokenizer_dict["end_positions"] = []

        for i, offsets in enumerate(offset_mapping):
            input_ids = self.tokenizer_dict["input_ids"][i]
            cls_index = input_ids.index(self.tokenizer.cls_token_id)
            # sequence id 1 marks context tokens (the second sequence passed to the tokenizer)
            sequence_ids = self.tokenizer_dict.sequence_ids(i)

            sample_index = sample_mapping[i]
            start_char = self.start_indices[sample_index]
            end_char = self.end_indices[sample_index]

            # first context token of this feature
            token_start_index = 0
            while sequence_ids[token_start_index] != 1:
                token_start_index += 1

            # last context token of this feature
            token_end_index = len(input_ids) - 1
            while sequence_ids[token_end_index] != 1:
                token_end_index -= 1

            # if answer is out of the span
            if not (offsets[token_start_index][0] <= start_char and offsets[token_end_index][1] >= end_char):
                self.tokenizer_dict["start_positions"].append(cls_index)
                self.tokenizer_dict["end_positions"].append(cls_index)
            else:
                # walk right past the last token starting at or before the answer start,
                # then step back one token
                while token_start_index < len(offsets) and offsets[token_start_index][0] <= start_char:
                    token_start_index += 1
                self.tokenizer_dict["start_positions"].append(token_start_index - 1)

                # walk left past the first token ending at or after the answer end,
                # then step forward one token
                while offsets[token_end_index][1] >= end_char:
                    token_end_index -= 1
                self.tokenizer_dict["end_positions"].append(token_end_index + 1)
        return sample_mapping, offset_mapping

    def __len__(self):
        """
        Return the number of features in the data
        """
        return len(self.sample_mapping)

    def __getitem__(self, i):
        """
        Return feature `i` as a dict: model inputs and span labels as tensors, together with
        the original ("og_") example fields the feature was derived from.
        """
        og_index = self.sample_mapping[i]

        item_dict = {
            "input_ids": torch.tensor(self.input_ids[i]),
            "attention_mask" : torch.tensor(self.attention_mask[i]),
            "start_positions" : torch.tensor(self.tokenizer_dict["start_positions"][i]),
            "end_positions" : torch.tensor(self.tokenizer_dict["end_positions"][i]),
            "og_indices": og_index,
            "og_contexts": self.contexts[og_index],
            "og_questions": self.questions[og_index],
            "og_answers": self.answers[og_index],
            "og_start_indices": self.start_indices[og_index],
            "og_end_indices": self.end_indices[og_index],
            "offset_mapping": torch.tensor(self.offset_mapping[i]),
            "og_question_ids": self.qids[og_index]
        }
        return item_dict

In [6]:
# Build the test set twice, once per tokenizer, because each model needs its own tokenization.
xlnet_data, albert_data = (
    SquadDataset("../data/curated/test_data", checkpoint)
    for checkpoint in ("xlnet-base-cased", "albert-base-v2")
)

Reading in Dataset from ../data/curated/test_data
Performing tokenization on dataset
Preprocessing Dataset
Reading in Dataset from ../data/curated/test_data
Performing tokenization on dataset
Preprocessing Dataset


In [7]:
def test_albert(model, dataset, n_best_size=20, max_answer_length=30, device='cpu'):
    model.eval()

    test_dataloader = DataLoader(dataset, batch_size=16, shuffle=False)

    pred = {}

    print("Making Predictions on Test Dataset")
    with torch.no_grad():
        for data in test_dataloader:
            input_ids = data["input_ids"].to(device)
            attention_mask = data["attention_mask"].to(device)
            start = data["start_positions"].to(device)
            end = data["end_positions"].to(device)

            output = model(input_ids=input_ids, attention_mask=attention_mask)

            offset_mapping = data["offset_mapping"]
            context = data["og_contexts"]
            answer = data["og_answers"]
            question = data["og_questions"]
            qids = data["og_question_ids"]

            for i in range(len(input_ids)):

                start_logits = F.softmax(output.start_logits[i], dim=0).cpu().detach().numpy()
                end_logits = F.softmax(output.end_logits[i], dim=0).cpu().detach().numpy()
                start_indexes = np.argsort(start_logits)[-1 : -n_best_size - 1 : -1].tolist()
                end_indexes = np.argsort(end_logits)[-1 : -n_best_size - 1 : -1].tolist()

                offsets = offset_mapping[i]
                ctxt = context[i]
                qid = qids[i]
                ans = answer[i]

                start_candidates = {}
                end_candidates = {}

                for start in start_indexes:
                  logits = start_logits[start]
                  start_char = offsets[start][0].item()
                  if start_candidates.get(start_char) == None or start_candidates.get(start_char) < logits:
                    start_candidates[start_char] = logits

                for end in end_indexes:
                  logits = end_logits[end]
                  end_char = offsets[end][1].item()
                  if end_candidates.get(end_char) == None or end_candidates.get(end_char) < logits:
                    end_candidates[end_char] = logits

                pred[(qid, ans, ctxt)] = {"start": start_candidates, "end": end_candidates}

    return pred

def test_xlnet(model, dataset, n_best_size=20, max_answer_length=30, device='cpu'):
    model.eval()

    test_dataloader = DataLoader(dataset, batch_size=16, shuffle=False)

    pred = {}

    print("Making Predictions on Test Dataset")
    with torch.no_grad():
        for data in test_dataloader:
            input_ids = data["input_ids"].to(device)
            attention_mask = data["attention_mask"].to(device)
            start = data["start_positions"].to(device)
            end = data["end_positions"].to(device)

            output = model(input_ids=input_ids, attention_mask=attention_mask)

            offset_mapping = data["offset_mapping"]
            context = data["og_contexts"]
            answer = data["og_answers"]
            question = data["og_questions"]
            qids = data["og_question_ids"]

            for i in range(len(input_ids)):
                start_logits = output.start_top_log_probs[i].cpu().detach().numpy()
                end_logits = output.end_top_log_probs[i].cpu().detach().numpy()
                start_indexes = np.argsort(start_logits)[-1 : -n_best_size - 1 : -1].tolist()
                end_indexes = np.argsort(end_logits)[-1 : -n_best_size - 1 : -1].tolist()

                start_top_indexes = output.start_top_index[i]
                end_top_indexes = output.end_top_index[i]

                offsets = offset_mapping[i]
                ctxt = context[i]
                qid = qids[i]
                ans = answer[i]

                start_candidates = {}
                end_candidates = {}
                for start in start_indexes:
                  logits = start_logits[start]
                  start_index = start_top_indexes[start]
                  start_char = offsets[start_index][0].item()
                  if start_candidates.get(start_char) == None or start_candidates.get(start_char) < logits:
                    start_candidates[start_char] = logits

                for end in end_indexes:
                  logits = end_logits[end]
                  end_index = end_top_indexes[end]
                  end_char = offsets[end_index][1].item()
                  if end_candidates.get(end_char) == None or end_candidates.get(end_char) < logits:
                    end_candidates[end_char] = logits

                pred[(qid, ans, ctxt)] = {"start": start_candidates, "end": end_candidates}

    return pred

In [8]:
# Ensemble settings: how many top start/end candidates to keep per feature, and the cap on
# the answer length used when pairing start and end character offsets.
n_best_size, max_answer_length = 20, 50

In [9]:
# Pass the configured max_answer_length instead of a second hard-coded value (30) that
# disagreed with the configuration cell. Neither function reads the argument, so the
# predictions themselves are unaffected.
albert_pred = test_albert(albert, albert_data, n_best_size, max_answer_length=max_answer_length, device='cpu')
xlnet_pred = test_xlnet(xlnet, xlnet_data, n_best_size, max_answer_length=max_answer_length, device='cpu')

Making Predictions on Test Dataset
Making Predictions on Test Dataset


In [11]:
# Sanity check: both models must have produced predictions for exactly the same examples.
set(albert_pred) == set(xlnet_pred)

True

## Test on One Question

In [None]:
# Walk through the ensembling on the first predicted example; keys are (qid, answer, context).
example_key = next(iter(albert_pred))
example_key

In [None]:
# ALBERT's start candidates, then its end candidates, as {character offset: score}
print(albert_pred[example_key]["start"], albert_pred[example_key]["end"], sep="\n")

In [None]:
# XLNet's start candidates, then its end candidates, as {character offset: score}
print(xlnet_pred[example_key]["start"], xlnet_pred[example_key]["end"], sep="\n")

In [None]:
# Keep only the character offsets that BOTH models proposed as a start (resp. end).
albert_example = albert_pred[example_key]
xlnet_example = xlnet_pred[example_key]
common_start_keys = set(albert_example["start"]).intersection(xlnet_example["start"])
common_end_keys = set(albert_example["end"]).intersection(xlnet_example["end"])

In [None]:
# shared start offsets on the first line, shared end offsets on the second
print(common_start_keys, common_end_keys, sep="\n")

In [None]:
# Spot-check the context text (element 2 of the key) between the hard-coded character offsets 498 and 518.
example_key[2][498:518]

In [None]:
# {character offset: combined score} for the offsets shared by both models
common_start, common_end = {}, {}

In [None]:
# Combined start score = ALBERT score * XLNet score, for every shared start offset.
for start in common_start_keys:
    common_start[start] = albert_pred[example_key]["start"][start] * xlnet_pred[example_key]["start"][start]

In [None]:
# Display the combined start scores, keyed by character offset.
common_start

In [None]:
# Combined end score = ALBERT score * XLNet score, for every shared end offset.
for end in common_end_keys:
    common_end[end] = albert_pred[example_key]["end"][end] * xlnet_pred[example_key]["end"][end]

In [None]:
# Display the combined end scores, keyed by character offset.
common_end

In [None]:
# Offsets with the highest combined score, picked independently for start and end.
highest_start = max(common_start, key=common_start.get)
highest_end = max(common_end, key=common_end.get)
print(highest_start, highest_end)

In [None]:
# Context text between the independently chosen best start and best end offsets.
example_key[2][highest_start:highest_end]

In [None]:
# Score every (start, end) pairing where the end is not before the start and the length
# check passes; a pair's score is the sum of its combined start and end scores.
valid_ans = []
for start in common_start:
    for end in common_end:
        if start <= end and (end - start + 1) <= max_answer_length:
            valid_ans.append({
                "score": common_start[start] + common_end[end],
                "text": example_key[2][start:end],
            })

In [None]:
# Highest-scoring valid span for the example (raises IndexError if valid_ans is empty).
sorted(valid_ans, key=lambda x: x["score"], reverse=True)[0]

## Ensembling by Multiplying the Two Models' Scores

In [12]:
def get_answer(model1_pred, model2_pred, max_answer_length=None):
    """
    Ensemble the candidate answer boundaries of two models into one answer per question.

    model1_pred, model2_pred: {(qid, answer, context): {"start": {char: score}, "end": {char: score}}}
        as returned by test_albert / test_xlnet; both must cover exactly the same keys.
    max_answer_length: spans with `end - start + 1` above this value are rejected. When
        omitted, the notebook-level `max_answer_length` setting is used (50 if none is defined).

    For every question, only character offsets proposed by BOTH models are kept, and the two
    models' scores for a shared offset are multiplied. Each (start, end) pairing with
    start <= end that passes the length check is scored as the sum of its combined start and
    end scores, and the text of the best pairing is the prediction.

    Returns {qid: predicted answer text}. A question for which the models share no valid
    pairing gets the empty string instead of raising IndexError.

    Raises AssertionError if the two prediction dicts do not have the same keys.

    NOTE(review): offsets are character positions with an exclusive end (text is
    context[start:end]), so the `+ 1` in the length check looks inherited from inclusive
    token indices - confirm the intended limit.
    """
    assert model1_pred.keys() == model2_pred.keys(), "Predictions are not on the same dataset"

    if max_answer_length is None:
        # The parameter shadows the notebook-level setting of the same name, so the global
        # has to be looked up explicitly; this keeps existing two-argument calls unchanged.
        max_answer_length = globals().get("max_answer_length", 50)

    final_predictions = {}
    for question in model1_pred:
        qid, _actual_ans, context = question
        starts1, starts2 = model1_pred[question]["start"], model2_pred[question]["start"]
        ends1, ends2 = model1_pred[question]["end"], model2_pred[question]["end"]

        # combined score per offset that both models proposed
        common_start = {s: starts1[s] * starts2[s] for s in set(starts1).intersection(starts2)}
        common_end = {e: ends1[e] * ends2[e] for e in set(ends1).intersection(ends2)}

        best_text, best_score = "", None
        for start, start_score in common_start.items():
            for end, end_score in common_end.items():
                if (end < start) or (end - start + 1) > max_answer_length:
                    continue
                score = start_score + end_score
                # strict comparison: on ties the first pairing encountered wins
                if best_score is None or score > best_score:
                    best_score, best_text = score, context[start:end]

        final_predictions[qid] = best_text
    return final_predictions

In [13]:
# Ensemble the ALBERT and XLNet candidates into one predicted answer per question id.
final_pred = get_answer(model1_pred=albert_pred, model2_pred=xlnet_pred)

In [14]:
# Display the ensembled predictions as {question id: answer text}.
final_pred

{'5729281baf94a219006aa122': 'Morocco and Ethiopia',
 '57268bf9dd62a815002e890c': 'The European Court of Justice',
 '572ff7ab04bcaa1900d76f53': 'Boven Merwede',
 '57277cf6dd62a815002e9e78': 'entertainment division',
 '56e0d6367aa994140058e773': 'fifty thousand dollars',
 '571c9348dd7acb1400e4c116': '1895',
 '57265526708984140094c2c1': 'November 1979',
 '57287ddf3acd2414000dfa3f': 'Privy Council',
 '57302efe04bcaa1900d772f6': 'to change Muslim public opinion',
 '5728177f2ca10214002d9db0': 'Peter Howell',
 '5729081d3f37b31900477fab': 'Neutrophils and macrophages',
 '56e11afbcd28a01900c675c8': 'newspaper editor',
 '5729e4291d04691400779653': 'decline of organized labor',
 '5729f9953f37b3190047861f': 'immunomodulators',
 '56beb6533aeaaa14008c9290': 'Brandon Marshall',
 '56d9b4ebdc89441400fdb70c': 'February 1, 2016',
 '571cc3dedd7acb1400e4c148': 'unpaired electrons',
 '571a50df4faf5e1900b8a962': 'Combustion hazards',
 '572659ea5951b619008f7051': 'eight',
 '57282036ff5b5019007d9da0': 'BBC On