#### In the previous notebook, we tried out a simple BM25 retrieval system and saw that it performed poorly: the average F1 score for top-5 retrieval was only about 9%. Now we will train a Dense Passage Retrieval (DPR) model and see whether we can do better.

In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import DistilBertTokenizerFast
from collections import Counter
import csv
import random
from tqdm import tqdm
import psutil
import json
import os
from DPR_biencoder_simple import *

%load_ext autoreload
%autoreload 2

First, let's load the data and filter out some of the bad passages as we did before.

In [2]:
# load the evidence passages
with open("project-data/evidence.json", "r") as evidence_file:
    document_store = json.load(evidence_file)
print(f"Number of evidence passages: {len(document_store)}")

# load the training data instances
with open("project-data/train-claims.json", "r") as train_file:
    train_data = json.load(train_file)
print(f"Number of training instances: {len(train_data)}")

# load the validation data instances
with open("project-data/dev-claims.json", "r") as dev_file:
    val_data = json.load(dev_file)    
print(f"Number of validation instances: {len(val_data)}")

# we remove duplicate values from the document_store dictionary (we arbitrarily keep the first one)
seen = set()
document_store_no_duplicates = {}
for key, value in document_store.items():
    if value not in seen:
        document_store_no_duplicates[key] = value
        seen.add(value)

# remove all "bad" documents from the document store, except those that occur in a claim's gold evidence list; we define "bad" documents as ones with fewer than 30 characters
claim_evidence_list = [claim['evidences'] for claim in train_data.values()]
claim_evidence_list = claim_evidence_list + [claim['evidences'] for claim in val_data.values()]
# keep the gold evidence ids in a set so that the membership test below is O(1) per passage
claim_evidence_list = set(evidence for evidence_list in claim_evidence_list for evidence in evidence_list)

document_store_cleaned = {i: evidence_text for i, evidence_text in document_store_no_duplicates.items() if len(evidence_text) >= 30 or i in claim_evidence_list}
print(f"Number of evidence passages remaining after cleaning: {len(document_store_cleaned)}")

Number of evidence passages: 1208827
Number of training instances: 1228
Number of validation instances: 154
Number of evidence passages remaining after cleaning: 1190647


In [5]:
claim_evidence_list_train = set([ev for claim in train_data.values() for ev in claim['evidences']])
claim_evidence_list_val = set([ev for claim in val_data.values() for ev in claim['evidences']])
print(f"Number of unique evidence passages in training data: {len(claim_evidence_list_train)}")
print(f"Number of unique evidence passages in validation data: {len(claim_evidence_list_val)}")

Number of unique evidence passages in training data: 3121
Number of unique evidence passages in validation data: 463


To train the DPR retriever, we need to create pairs of $(\text{claim}_i, \text{evidence passage}_i)$. However, each claim can have multiple evidence passages, so we create multiple pairs per claim: $(\text{claim}_i, \text{evidence passage}_{i,1})$, $(\text{claim}_i, \text{evidence passage}_{i,2})$, ...

Next, we prepare a minibatch of $B$ claims and their corresponding passages. Given a matrix $C$ of shape $(B,d)$ containing the batch of encoded claim vectors (where $d$ is the hidden dimension of the encoded vectors) and a matrix $P$ of the same shape containing the batch of encoded passage vectors, we compute the matrix $CP^T$ of shape $(B,B)$, whose $(i,j)$-th entry is the dot product between the $i$-th claim and the $j$-th passage. The entries along the diagonal of $CP^T$ are the scores for positive pairs, and the off-diagonal entries are the scores for negative pairs. We can then train a softmax classifier that treats the diagonal term in each row as the score of the positive class and the remaining $B-1$ terms as negative classes. This is the `in-batch negatives` trick.
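A minimal NumPy sketch of the in-batch-negatives objective described above, assuming the encoders have already produced $C$ and $P$ (random stand-ins here). The function name `in_batch_negatives_loss` is hypothetical and this is not the implementation in `DPR_biencoder_simple`; it only shows that the loss is a row-wise cross-entropy whose target for row $i$ is column $i$.

```python
import numpy as np

def in_batch_negatives_loss(C, P):
    """Mean cross-entropy of the in-batch-negatives objective.

    C, P: (B, d) arrays of encoded claims and passages; row i of P is the positive for row i of C.
    """
    scores = C @ P.T                                   # (B, B): entry (i, j) = claim_i . passage_j
    # numerically stable row-wise log-softmax
    shifted = scores - scores.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    B = scores.shape[0]
    # the positive for row i sits on the diagonal, i.e. target class i
    return -log_probs[np.arange(B), np.arange(B)].mean()

# toy batch: random stand-ins for encoder outputs, B = 4 claims, d = 8 hidden dimensions
rng = np.random.default_rng(0)
C = rng.normal(size=(4, 8))
P = rng.normal(size=(4, 8))
print(f"in-batch negatives loss: {in_batch_negatives_loss(C, P):.4f}")
```

In PyTorch the same objective can be written as `torch.nn.functional.cross_entropy(C @ P.T, torch.arange(B))`, since the target class for row $i$ is simply $i$.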

However, one issue is that out of the ~1.2M passages, only ~3,000 appear as positive evidence in (claim, evidence) pairs (3,121 in the training set), so the in-batch negatives are restricted to these few passages. For good performance, we need to be able to draw negatives from all of the evidence passages in the document store. That's why, in addition to a minibatch of claims $C$ and their positive passages $P$, we also pass in a batch of $B$ negatives $N$ selected from the passages outside the ~3,000 that appear as positives. We then compute the matrix $CN^T$, also of shape $(B,B)$, whose $i$-th row holds the scores of the $i$-th claim against the $B$ negatives. Horizontally concatenating the two matrices gives $[CP^T, CN^T]$ of shape $(B,2B)$, and we train a softmax classifier that treats the diagonal term of $CP^T$ in each row as the positive class and the remaining $2B-1$ terms as negative classes.
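Writing $c_i$, $p_j$ and $n_j$ for the $i$-th row of $C$ and the $j$-th rows of $P$ and $N$, the per-claim loss over the concatenated scores is

$$\mathcal{L}_i = -\log \frac{\exp(c_i^\top p_i)}{\sum_{j=1}^{B} \exp(c_i^\top p_j) + \sum_{j=1}^{B} \exp(c_i^\top n_j)}$$

A minimal NumPy sketch of this loss, averaged over the batch. The function name `dpr_loss_with_negatives` is hypothetical and the inputs are random stand-ins for encoder outputs; the only change from plain in-batch negatives is the extra block of $B$ columns, while the target for row $i$ stays column $i$.

```python
import numpy as np

def dpr_loss_with_negatives(C, P, N):
    """Mean cross-entropy over [C P^T, C N^T].

    C, P, N: (B, d) arrays of encoded claims, their positive passages, and sampled negative passages.
    """
    B = C.shape[0]
    # (B, 2B) scores: columns 0..B-1 are the in-batch passages, columns B..2B-1 the extra negatives
    scores = np.concatenate([C @ P.T, C @ N.T], axis=1)
    # numerically stable row-wise log-softmax over all 2B candidates
    shifted = scores - scores.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    # the positive for claim i is still column i
    return -log_probs[np.arange(B), np.arange(B)].mean()

rng = np.random.default_rng(0)
C, P, N = (rng.normal(size=(4, 8)) for _ in range(3))
print(f"loss with B extra negatives: {dpr_loss_with_negatives(C, P, N):.4f}")
```

Because the extra negatives only add terms to the softmax denominator, they make the task harder than in-batch negatives alone, which is the point: the model now also has to push away passages that never appear as anyone's positive.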

`Hard-negative mining`: We could simply create the batch of negatives $N$ by sampling uniformly at random from the set of all negatives. However, a slightly better option would be to select "hard" negatives: passages that are very similar to the positive ones but are not gold evidence. We could use the highest-scoring non-positive passages from a BM25 retriever as hard negatives. Alternatively, we could first train our model with random negatives, then use the highest-scoring non-positives from the trained model as hard negatives and fine-tune further.
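A minimal sketch of the selection step in hard-negative mining, assuming we already have a matrix of retriever scores between claims and candidate passages (from BM25, or from dot products of DPR claim and passage embeddings). The function `mine_hard_negatives` and its arguments are hypothetical; the returned column indices would still need to be mapped back to evidence ids, e.g. through a list such as `all_passages_ids`.

```python
import numpy as np

def mine_hard_negatives(scores, gold_indices, k=1):
    """Return, for each claim, the k highest-scoring passages that are not gold evidence.

    scores: (num_claims, num_passages) array of retriever scores.
    gold_indices: list of sets; gold_indices[i] holds the passage columns that are positives for claim i.
    """
    hard_negatives = []
    for i, row in enumerate(scores):
        gold = gold_indices[i]
        # at most len(gold) of the top (k + len(gold)) passages can be positives,
        # so this candidate pool always contains k non-positives (given enough passages)
        m = min(k + len(gold), row.shape[0])
        candidates = np.argpartition(-row, m - 1)[:m]
        candidates = candidates[np.argsort(-row[candidates])]   # best score first
        hard_negatives.append([int(j) for j in candidates if int(j) not in gold][:k])
    return hard_negatives

# toy example: 2 claims, 4 passages; passage 0 is gold for claim 0, passage 1 is gold for claim 1
scores = np.array([[0.9, 0.8, 0.1, 0.7],
                   [0.2, 0.95, 0.5, 0.3]])
print(mine_hard_negatives(scores, [{0}, {1}], k=2))  # → [[1, 3], [2, 3]]
```

`np.argpartition` avoids fully sorting each row, which matters when a row has ~1.2M passage scores. Holding a dense score matrix for all claims and passages at once would be memory-heavy, so in practice this would likely be run on one batch of claims at a time.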

In [12]:
# let's separate out all ~3000 positive evidence passages from the document store and define the remaining as negatives.
all_passages_ids = list(document_store_cleaned.keys())
positives_ids_train = claim_evidence_list_train
negatives_ids_train = list(set(all_passages_ids) - set(positives_ids_train))
                         
# create claim-positive pairs
claim_positive_pairs_train = []
for claim_id in train_data.keys():
    for evidence_id in train_data[claim_id]['evidences']:
        claim_positive_pairs_train.append((claim_id, evidence_id))    


In [None]:
# now let's create a PyTorch dataset that yields (claim, positive passage, random negative passage) triples
class ClaimsDataset(Dataset):
    def __init__(self, claims_data, document_store, block_size=256, claim_text_key='claim_text'):
        self.claims_data = claims_data
        self.document_store = document_store
        self.tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')
        self.block_size = block_size
        # assumption: the text of each claim is stored under this key in the claims JSON
        self.claim_text_key = claim_text_key
        self.create_pairs()

    def create_pairs(self):
        # every gold evidence passage of every claim counts as a positive
        positive_ids = set(ev for claim in self.claims_data.values() for ev in claim['evidences'])
        # all remaining passages in the document store are candidate (random) negatives
        self.negative_ids = list(set(self.document_store.keys()) - positive_ids)
        # one (claim_id, evidence_id) pair per gold evidence passage; skip evidence ids that are
        # missing from the document store (e.g. if they were dropped during de-duplication)
        self.claim_positive_pairs = [(claim_id, evidence_id)
                                     for claim_id, claim in self.claims_data.items()
                                     for evidence_id in claim['evidences']
                                     if evidence_id in self.document_store]

    def __len__(self):
        return len(self.claim_positive_pairs)

    def encode(self, text):
        # tokenize as [CLS] text [SEP], truncated/padded to block_size;
        # the attention mask is 1 for real tokens and 0 for padding
        encoding = self.tokenizer(text, max_length=self.block_size, truncation=True,
                                  padding='max_length', return_tensors='pt')
        return encoding['input_ids'].squeeze(0), encoding['attention_mask'].squeeze(0)

    def __getitem__(self, idx):
        # get the claim and its positive evidence passage
        claim_id, positive_id = self.claim_positive_pairs[idx]
        claim_text = self.claims_data[claim_id][self.claim_text_key]
        positive_passage = self.document_store[positive_id]
        # sample one random negative passage (hard-negative mining could replace this later)
        negative_passage = self.document_store[random.choice(self.negative_ids)]

        claim_idx, claim_attn_mask = self.encode(claim_text)
        positive_idx, positive_attn_mask = self.encode(positive_passage)
        negative_idx, negative_attn_mask = self.encode(negative_passage)

        return claim_idx, claim_attn_mask, positive_idx, positive_attn_mask, negative_idx, negative_attn_mask