In [1]:
import argparse
import datetime
import json
import numpy as np
import re
import string
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from torch.utils.data import Dataset, DataLoader
from transformers import AlbertForQuestionAnswering, AlbertTokenizerFast, XLNetForQuestionAnswering, XLNetTokenizerFast

In [2]:
xlnet_dict = torch.load('../model/xlnet.pt',map_location=torch.device("cpu"))
albert_dict = torch.load('../model/albert.pt',map_location=torch.device("cpu"))

In [7]:
# instantiating models

# xlnet
xlnet = XLNetForQuestionAnswering.from_pretrained('xlnet-base-cased')
xlnet.load_state_dict(xlnet_dict["model_state_dict"])


# albert
albert = AlbertForQuestionAnswering.from_pretrained("albert-base-v2")
albert.load_state_dict(albert_dict["model_state_dict"])

Some weights of XLNetForQuestionAnswering were not initialized from the model checkpoint at xlnet-base-cased and are newly initialized: ['answer_class.dense_0.bias', 'end_logits.LayerNorm.weight', 'answer_class.dense_0.weight', 'end_logits.LayerNorm.bias', 'end_logits.dense_0.bias', 'end_logits.dense_1.bias', 'end_logits.dense_1.weight', 'start_logits.dense.weight', 'answer_class.dense_1.weight', 'start_logits.dense.bias', 'end_logits.dense_0.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Downloading model.safetensors:   0%|          | 0.00/47.4M [00:00<?, ?B/s]

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to see activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development
Some weights of AlbertForQuestionAnswering were not initialized from the model checkpoint at albert-base-v2 and are newly initialized: ['qa_outputs.bias', 'qa_outputs.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


<All keys matched successfully>

In [3]:
class SquadDataset(Dataset):

    def __init__(self, input_path, tokenizer_checkpoint):
        """
        input_path: path that contains all the files - contexts, questions, answers, answer spans and question ids
        """
        print(f"Reading in Dataset from {input_path}")
        
        with open(input_path + "/context", encoding='utf-8') as f:
            contexts = f.read().split("\t")
        with open(input_path + "/question", encoding='utf-8') as f:
            questions = f.read().split("\t")
        with open(input_path + "/answer", encoding='utf-8') as f:
            answers = f.read().split("\t")
        with open(input_path + "/answer_span", encoding='utf-8') as f:
            spans = f.read().split("\t")
        with open(input_path + "/question_id", encoding='utf-8') as f:
            qids = f.read().split("\t")

        self.contexts = [ctx.strip() for ctx in contexts][:100]
        self.questions = [qn.strip() for qn in questions][:100]
        self.answers = [ans.strip() for ans in answers][:100]
        self.spans = [span.strip().split() for span in spans][:100]
        self.start_indices = [int(x[0]) for x in self.spans]
        self.end_indices = [int(x[1]) for x in self.spans]
        self.qids = [qid.strip() for qid in qids][:100]

        # intialise XLNetTokenizerFast for input tokenization
        if tokenizer_checkpoint == "xlnet-base-cased":
          self.tokenizer = XLNetTokenizerFast.from_pretrained("xlnet-base-cased")
        elif tokenizer_checkpoint == "albert-base-v2":
          self.tokenizer = AlbertTokenizerFast.from_pretrained("albert-base-v2")
        self.tokenizer.padding_side = "right"

        # extract tokenization outputs
        self.tokenizer_dict = self.tokenize()
        self.sample_mapping, self.offset_mapping = self.preprocess()

        self.input_ids = self.tokenizer_dict["input_ids"]
        self.token_type_ids = self.tokenizer_dict["token_type_ids"]
        self.attention_mask = self.tokenizer_dict["attention_mask"]


    def tokenize(self, max_length=384, doc_stride=128):
        """
        inputs:
        1. max_length: specifies the length of the tokenized text
        2. doc_stride: defines the number of overlapping tokens

        output:
        1. tokenizer_dict, which contains
        - input_ids: list of integer values representing the tokenized text; each integer corresponds to a specific token
        - token_type_ids: to distinguish between question and context
        - attention_mask: a binary mask that tells the model which tokens to mask/not mask
        - sample_mapping: map from a feature to its corresponding example, since one question-context pair might give several features
        - offset_mapping: maps each input id with the corresponding start and end characters in the original text

        Tokenize examples (question-context pairs) with truncation and padding, but keep the overflows using a stride specified by `doc_stride`. 
        When the question-context input exceeds the `max_length`, it will contain more than one feature, and each of these features will have context
        that overlaps a bit with the previous features, and the overlapping is determined by `doc_stride`. This is to ensure that although truncation
        is performed, these overflows will ensure that no answer is missed as long as the answer span is shorter than the length of the overlap.
        """
        print("Performing tokenization on dataset")
        tokenizer_dict = self.tokenizer(
            self.questions,
            self.contexts,
            truncation="only_second",
            padding="max_length",
            max_length=max_length,
            stride=doc_stride,
            return_overflowing_tokens=True,
            return_offsets_mapping=True
        )
        return tokenizer_dict

    def preprocess(self):
        """
        This functions is to preprocess the outputs of the tokenizer dictionary.
        Due to the possibility that an example has multiple features, this functions ensure that the start_positions and end_positions are mapped
        correctly
        """
        print("Preprocessing Dataset")

        sample_mapping = self.tokenizer_dict.pop("overflow_to_sample_mapping")
        offset_mapping = self.tokenizer_dict.pop("offset_mapping")

        self.tokenizer_dict["start_positions"] = []
        self.tokenizer_dict["end_positions"] = []

        for i, offsets in enumerate(offset_mapping):
            input_ids = self.tokenizer_dict["input_ids"][i]
            cls_index = input_ids.index(self.tokenizer.cls_token_id)
            sequence_ids = self.tokenizer_dict.sequence_ids(i)

            sample_index = sample_mapping[i]
            answer = self.answers[sample_index]
            start_char = self.start_indices[sample_index]
            end_char = self.end_indices[sample_index]

            token_start_index = 0
            while sequence_ids[token_start_index] != 1:
                token_start_index += 1

            token_end_index = len(input_ids) - 1
            while sequence_ids[token_end_index] != 1:
                token_end_index -= 1

            # if answer is out of the span
            if not (offsets[token_start_index][0] <= start_char and offsets[token_end_index][1] >= end_char):
                self.tokenizer_dict["start_positions"].append(cls_index)
                self.tokenizer_dict["end_positions"].append(cls_index)
            else:
                while token_start_index < len(offsets) and offsets[token_start_index][0] <= start_char:
                    token_start_index += 1
                self.tokenizer_dict["start_positions"].append(token_start_index - 1)

                while offsets[token_end_index][1] >= end_char:
                    token_end_index -= 1
                self.tokenizer_dict["end_positions"].append(token_end_index + 1)
        return sample_mapping, offset_mapping


    def __len__(self):
        """
        Return the number of features in the data
        """
        return len(self.sample_mapping)

    def __getitem__(self, i):

        og_index = self.sample_mapping[i]

        item_dict = {
            "input_ids": torch.tensor(self.input_ids[i]),
            "attention_mask" : torch.tensor(self.attention_mask[i]),
            "start_positions" : torch.tensor(self.tokenizer_dict["start_positions"][i]),
            "end_positions" : torch.tensor(self.tokenizer_dict["end_positions"][i]),
            "og_indices": og_index,
            "og_contexts": self.contexts[og_index],
            "og_questions": self.questions[og_index],
            "og_answers": self.answers[og_index],
            "og_start_indices": self.start_indices[og_index],
            "og_end_indices": self.end_indices[og_index],
            "offset_mapping": torch.tensor(self.offset_mapping[i]),
            "og_question_ids": self.qids[og_index]

        }
        return item_dict

In [4]:
xlnet_data = SquadDataset("../data/curated/test_data", "xlnet-base-cased")
albert_data = SquadDataset("../data/curated/test_data", "albert-base-v2")

Reading in Dataset from ../data/curated/test_data
Performing tokenization on dataset
Preprocessing Dataset
Reading in Dataset from ../data/curated/test_data
Performing tokenization on dataset
Preprocessing Dataset


In [5]:
def test_albert(model, dataset, n_best_size=20, max_answer_length=30, device='cpu'):
    model.eval()

    test_dataloader = DataLoader(dataset, batch_size=16, shuffle=False)

    pred = {}

    print("Making Predictions on Test Dataset")
    with torch.no_grad():
        for data in test_dataloader:
            input_ids = data["input_ids"].to(device)
            attention_mask = data["attention_mask"].to(device)
            start = data["start_positions"].to(device)
            end = data["end_positions"].to(device)

            output = model(input_ids=input_ids, attention_mask=attention_mask)

            offset_mapping = data["offset_mapping"]
            context = data["og_contexts"]
            answer = data["og_answers"]
            question = data["og_questions"]
            qids = data["og_question_ids"]

            for i in range(len(input_ids)):

                start_logits = F.softmax(output.start_logits[i], dim=0).cpu().detach().numpy()
                end_logits = F.softmax(output.end_logits[i], dim=0).cpu().detach().numpy()
                start_indexes = np.argsort(start_logits)[-1 : -n_best_size - 1 : -1].tolist()
                end_indexes = np.argsort(end_logits)[-1 : -n_best_size - 1 : -1].tolist()

                offsets = offset_mapping[i]
                ctxt = context[i]
                qid = qids[i]
                ans = answer[i]

                start_candidates = {}
                end_candidates = {}

                for start in start_indexes:
                  logits = start_logits[start]
                  start_char = offsets[start][0].item()
                  if start_candidates.get(start_char) == None or start_candidates.get(start_char) < logits:
                    start_candidates[start_char] = logits

                for end in end_indexes:
                  logits = end_logits[end]
                  end_char = offsets[end][1].item()
                  if end_candidates.get(end_char) == None or end_candidates.get(end_char) < logits:
                    end_candidates[end_char] = logits

                pred[(qid, ans, ctxt)] = {"start": start_candidates, "end": end_candidates}

    return pred

def test_xlnet(model, dataset, n_best_size=20, max_answer_length=30, device='cpu'):
    model.eval()

    test_dataloader = DataLoader(dataset, batch_size=16, shuffle=False)

    pred = {}

    print("Making Predictions on Test Dataset")
    with torch.no_grad():
        for data in test_dataloader:
            input_ids = data["input_ids"].to(device)
            attention_mask = data["attention_mask"].to(device)
            start = data["start_positions"].to(device)
            end = data["end_positions"].to(device)

            output = model(input_ids=input_ids, attention_mask=attention_mask)

            offset_mapping = data["offset_mapping"]
            context = data["og_contexts"]
            answer = data["og_answers"]
            question = data["og_questions"]
            qids = data["og_question_ids"]

            for i in range(len(input_ids)):
                start_logits = output.start_top_log_probs[i].cpu().detach().numpy()
                end_logits = output.end_top_log_probs[i].cpu().detach().numpy()
                start_indexes = np.argsort(start_logits)[-1 : -n_best_size - 1 : -1].tolist()
                end_indexes = np.argsort(end_logits)[-1 : -n_best_size - 1 : -1].tolist()

                start_top_indexes = output.start_top_index[i]
                end_top_indexes = output.end_top_index[i]

                offsets = offset_mapping[i]
                ctxt = context[i]
                qid = qids[i]
                ans = answer[i]

                start_candidates = {}
                end_candidates = {}
                for start in start_indexes:
                  logits = start_logits[start]
                  start_index = start_top_indexes[start]
                  start_char = offsets[start_index][0].item()
                  if start_candidates.get(start_char) == None or start_candidates.get(start_char) < logits:
                    start_candidates[start_char] = logits

                for end in end_indexes:
                  logits = end_logits[end]
                  end_index = end_top_indexes[end]
                  end_char = offsets[end_index][1].item()
                  if end_candidates.get(end_char) == None or end_candidates.get(end_char) < logits:
                    end_candidates[end_char] = logits

                pred[(qid, ans, ctxt)] = {"start": start_candidates, "end": end_candidates}

    return pred

In [8]:
n_best_size = 20
max_answer_length = 30

In [9]:
albert_pred = test_albert(albert, albert_data, n_best_size, max_answer_length=30, device='cpu')
xlnet_pred = test_xlnet(xlnet, xlnet_data, n_best_size, max_answer_length=30, device='cpu')

Making Predictions on Test Dataset
Making Predictions on Test Dataset


In [13]:
albert_pred.keys() == xlnet_pred.keys()

True

## Test on One Question

In [17]:
example_key = next(iter(albert_pred.keys()))
example_key

('5729281baf94a219006aa122',
 'Morocco and Ethiopia',
 "Kenya is active in several sports, among them cricket, rallying, football, rugby union and boxing. The country is known chiefly for its dominance in middle-distance and long-distance athletics, having consistently produced Olympic and Commonwealth Games champions in various distance events, especially in 800 m, 1,500 m, 3,000 m steeplechase, 5,000 m, 10,000 m and the marathon. Kenyan athletes (particularly Kalenjin) continue to dominate the world of distance running, although competition from Morocco and Ethiopia has reduced this supremacy. Kenya's best-known athletes included the four-time women's Boston Marathon winner and two-time world champion Catherine Ndereba, 800m world record holder David Rudisha, former Marathon world record-holder Paul Tergat, and John Ngugi.")

In [23]:
print(albert_pred[example_key]["start"])
print(albert_pred[example_key]["end"])

{498: 0.9720771, 481: 0.017689329, 223: 0.0024769614, 510: 0.0023586175, 493: 0.0015495653, 410: 0.0009279516, 380: 0.00050270767, 149: 0.00038564613, 472: 0.0002784642, 432: 0.00019917496, 506: 0.00012109086, 235: 0.00010860085, 454: 0.00010008272, 598: 8.6340566e-05, 397: 7.65662e-05, 169: 6.5610875e-05, 445: 5.91012e-05, 214: 5.425412e-05, 547: 4.9051225e-05, 523: 4.2499672e-05}
{518: 0.9825853, 505: 0.0047336314, 545: 0.0024236734, 418: 0.0018414374, 546: 0.001636124, 247: 0.0009028191, 386: 0.0004802582, 470: 0.0004641103, 192: 0.00036206061, 509: 0.00031907737, 395: 0.00031776095, 522: 0.00028282977, 530: 0.00028102836, 263: 0.0002704699, 492: 0.00022276444, 552: 0.00019446915, 164: 0.00018576755, 378: 0.00016928012, 779: 0.00016635445, 603: 0.00013815687}


In [24]:
print(xlnet_pred[example_key]["start"])
print(xlnet_pred[example_key]["end"])

{498: 0.9991967, 481: 0.00046435793, 510: 0.00020733097, 149: 5.7797803e-05, 493: 1.7289407e-05}
{518: 0.9999844, 247: 0.36447152, 253: 0.13642289, 192: 0.11024575, 164: 0.107844435, 155: 0.1025703, 522: 0.00777565, 505: 0.0016211064, 509: 0.00014509284, 530: 0.00013444948, 470: 3.541387e-05}


In [31]:
common_start_keys = set(albert_pred[example_key]["start"]).intersection(xlnet_pred[example_key]["start"])
common_end_keys = set(albert_pred[example_key]["end"]).intersection(xlnet_pred[example_key]["end"])

In [32]:
print(common_start_keys)
print(common_end_keys)

{481, 493, 498, 149, 510}
{192, 164, 518, 522, 530, 470, 247, 505, 509}


In [36]:
example_key[2][498:518]

'Morocco and Ethiopia'

In [37]:
common_start = {}
common_end = {}

In [38]:
for start in common_start_keys:
    multiplied_score = albert_pred[example_key]["start"][start] * xlnet_pred[example_key]["start"][start]
    common_start[start] = multiplied_score

In [39]:
common_start

{481: 8.21418e-06,
 493: 2.6791065e-08,
 498: 0.9712962,
 149: 2.2289498e-08,
 510: 4.8901444e-07}

In [40]:
for end in common_end_keys:
    multiplied_score = albert_pred[example_key]["end"][end] * xlnet_pred[example_key]["end"][end]
    common_end[end] = multiplied_score

In [41]:
common_end

{192: 3.9915645e-05,
 164: 2.0033996e-05,
 518: 0.98257,
 522: 2.1991852e-06,
 530: 3.778412e-08,
 470: 1.643594e-08,
 247: 0.00032905187,
 505: 7.673721e-06,
 509: 4.6295842e-08}

In [45]:
highest_start = max(common_start.items(), key=lambda x:x[1])[0]
highest_end = max(common_end.items(), key=lambda x:x[1])[0]
print(highest_start, highest_end)

498 518


In [47]:
example_key[2][highest_start:highest_end]

'Morocco and Ethiopia'

In [58]:
valid_ans = []
for start in common_start.keys():
    for end in common_end.keys():
        if (end < start) or (end - start + 1) > max_answer_length:
            continue
        if start <= end:
            valid_ans.append({
                "score":common_start[start] + common_end[end],
                "text":example_key[2][start:end]
            })

In [63]:
sorted(valid_ans, key=lambda x: x["score"], reverse=True)[0]

{'score': 1.9538662, 'text': 'Morocco and Ethiopia'}

## Multiplication

In [67]:
def get_answer(model1_pred, model2_pred):
    assert model1_pred.keys() == model2_pred.keys(), "Predictions are not on the same dataset"
    final_predictions = {}
    questions = model1_pred.keys()
    for question in questions:
        qid, actual_ans, context = question
        common_start_keys = set(model1_pred[example_key]["start"]).intersection(model2_pred[example_key]["start"])
        common_end_keys = set(model1_pred[example_key]["end"]).intersection(model2_pred[example_key]["end"])
        
        common_start = {}
        common_end = {}
        for start in common_start_keys:
            multiplied_score = albert_pred[example_key]["start"][start] * xlnet_pred[example_key]["start"][start]
            common_start[start] = multiplied_score
        for end in common_end_keys:
            multiplied_score = albert_pred[example_key]["end"][end] * xlnet_pred[example_key]["end"][end]
            common_end[end] = multiplied_score
#         highest_start = max(common_start.items(), key=lambda x:x[1])[0]
#         highest_end = max(common_end.items(), key=lambda x:x[1])[0]
#         highest_score_ans = context[highest_start:highest_end]
#         final_predictions.update({actual_ans:highest_score_ans})

        valid_ans = []
        for start in common_start.keys():
            for end in common_end.keys():
                if (end < start) or (end - start + 1) > max_answer_length:
                    continue
                if start <= end:
                    valid_ans.append({
                        "score":common_start[start] + common_end[end],
                        "text":context[start:end]
                    })
        final_pred = sorted(valid_ans, key=lambda x: x["score"], reverse=True)[0]["text"]
        final_predictions.update({actual_ans:final_pred})
    return final_predictions

In [68]:
final_pred = get_answer(albert_pred, xlnet_pred)

In [69]:
final_pred

{'Morocco and Ethiopia': 'Morocco and Ethiopia',
 'The European Court of Justice': 'of Justice is the hi',
 'Waal': ', Nieuwe Maas ("New ',
 'the entertainment division': 'edesigned as part of',
 'fifty thousand dollars': 'f work, Tesla fulfil',
 '1895': 'a mixture of acetyle',
 'November 1979': 'ilitaries. By 1979, ',
 'the Privy Council': ' justice. Cases invo',
 'ideological struggle': 'ed HT as their key i',
 'Peter Howell': '05, Murray Gold prov',
 'Neutrophils and macrophages': 'ene of infection. Ma',
 'newspaper editor': '',
 'decline of organized labor': 'l pattern is clear; ',
 'immunomodulators': 'one and vitamin D.',
 'Brandon Marshall': 'ackles with 109, whi',
 'February 1, 2016': '',
 'unpaired electrons': 'net.[c]',
 'compounds of oxygen with a high oxidative': 'rates, and dichromat',
 'eight rows': 'rbances created by t',
 'BBC Three': 'alekmania" period (c',
 'Santa Clara Marriott.': '',
 'Jon Pertwee': 'ctorin\' the Tardis" ',
 'the United Kingdom, Australia, Canada an