#### Cross-Encoder Passage Retrieval Model

We will now explore an alternative model for document retrieval, called a `cross-encoder`. Unlike the bi-encoder, which embeds claims and passages separately, the cross-encoder uses a single BERT model and feeds it one input sequence formed by concatenating a claim-passage pair. Using the output embedding of the [CLS] token, we then perform a `binary classification` of whether or not the passage is relevant to the claim. We can set up training instances of both positive (relevant) and negative (non-relevant) pairs and train with binary cross-entropy loss. The sigmoid of the output logit can then be interpreted as a relevance score in $[0,1]$.
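To make the scoring step concrete, here is a minimal sketch of the classification head and loss described above. It is not the notebook's actual model: `CrossEncoderHead` is a hypothetical name, and a random tensor stands in for the BERT hidden states so the example runs without downloading any weights.

```python
import torch
import torch.nn as nn


class CrossEncoderHead(nn.Module):
    """Hypothetical head: maps the [CLS] embedding of a claim-passage sequence to one logit."""

    def __init__(self, hidden_size: int = 768):
        super().__init__()
        self.classifier = nn.Linear(hidden_size, 1)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # The [CLS] token sits at position 0 of every sequence.
        cls_embedding = hidden_states[:, 0]  # (batch, hidden)
        return self.classifier(cls_embedding).squeeze(-1)  # (batch,)


torch.manual_seed(0)
batch_size, seq_len, hidden_size = 4, 16, 768

# Assumption: random values stand in for the encoder's last hidden states.
hidden_states = torch.randn(batch_size, seq_len, hidden_size)
# 1.0 = relevant pair, 0.0 = non-relevant pair.
labels = torch.tensor([1.0, 0.0, 1.0, 0.0])

head = CrossEncoderHead(hidden_size)
logits = head(hidden_states)

# BCEWithLogitsLoss applies the sigmoid internally, so it takes raw logits.
loss = nn.BCEWithLogitsLoss()(logits, labels)

# At inference time, the sigmoid of the logit is the relevance score.
scores = torch.sigmoid(logits)
```

In a full model, `hidden_states` would come from encoding the tokenized claim-passage pair rather than from `torch.randn`.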

Since each claim can have multiple relevant evidence passages, we can create multiple positive pairs per claim. To keep the two classes balanced, we also create the same number of negative pairs, which means we need a way to select the negative passages. The simplest way is to `randomly sample` passages from the document store that are not in the list of positive passages. However, these random samples may be too easy for our model to detect, so ideally we want to use some form of `hard-negative mining` to select good negative passages.
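The balanced random-sampling baseline can be sketched as follows. This is an illustration only: `build_balanced_pairs` and the toy passage IDs are hypothetical and do not come from the notebook's `utils` module.

```python
import random


def build_balanced_pairs(claim_id, gold_ids, all_passage_ids, rng):
    """Return (claim_id, passage_id, label) triples with as many negatives as positives."""
    gold = set(gold_ids)
    if len(all_passage_ids) - len(gold) < len(gold_ids):
        raise ValueError("not enough non-gold passages to balance the classes")

    positives = [(claim_id, passage_id, 1) for passage_id in gold_ids]

    # Rejection sampling: draw passages at random, skipping gold evidence and repeats.
    # This avoids building a filtered copy of a very large document store.
    negative_ids = []
    while len(negative_ids) < len(gold_ids):
        candidate = rng.choice(all_passage_ids)
        if candidate not in gold and candidate not in negative_ids:
            negative_ids.append(candidate)

    negatives = [(claim_id, passage_id, 0) for passage_id in negative_ids]
    return positives + negatives


# Toy data (hypothetical IDs), seeded for repeatability.
rng = random.Random(0)
toy_passage_ids = [f"evidence-{i}" for i in range(10)]
toy_gold_ids = ["evidence-2", "evidence-7"]
pairs = build_balanced_pairs("claim-1", toy_gold_ids, toy_passage_ids, rng)
```

Swapping the random draw for a draw from a list of mined hard negatives keeps the same balanced structure while making the negatives harder.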

So first, let's set up our dataset and implement hard-negative mining.

In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import DistilBertTokenizerFast
from collections import Counter
import random
from tqdm import tqdm
import psutil
from utils import *
from DPR_biencoder_simple import *
import wandb

%load_ext autoreload
%autoreload 2

wandb.login()
print(torch.cuda.is_available())

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
wandb: Currently logged in as: tanzids. Use `wandb login --relogin` to force relogin


True


In [2]:
# load the data
document_store, train_data, val_data = load_data(clean=True)

Number of evidence passages: 1208827
Number of training instances: 1228
Number of validation instances: 154
Number of evidence passages remaining after cleaning: 1190647


Hard-negative Mining: We will use our pre-trained DPR model to mine for hard negatives. We retrieve the top-k passages for each training claim, drop any gold evidence from that list, and then sample hard negatives from what remains.
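The idea can be illustrated on toy embeddings. This sketch assumes dot-product similarity between claim and passage embeddings; `mine_hard_negatives` is a hypothetical helper, and the notebook's own `find_topk_evidence_dpr` may score passages differently.

```python
import torch


def mine_hard_negatives(claim_embed, passage_embeds, passage_ids, gold_ids, k=5):
    """Return the top-k most similar passage IDs, excluding gold evidence."""
    # Assumption: similarity is the dot product of claim and passage embeddings.
    scores = passage_embeds @ claim_embed  # (num_passages,)
    topk = torch.topk(scores, k=min(k, len(passage_ids)))  # sorted, highest first
    gold = set(gold_ids)
    return [passage_ids[i] for i in topk.indices.tolist() if passage_ids[i] not in gold]


# Toy 2-d embeddings (hypothetical values): p0 is the gold passage,
# p1 and p3 are close to it, p2 is unrelated.
toy_ids = ["p0", "p1", "p2", "p3"]
toy_passage_embeds = torch.tensor([[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.8, 0.2]])
toy_claim_embed = torch.tensor([1.0, 0.0])

hard_negatives = mine_hard_negatives(
    toy_claim_embed, toy_passage_embeds, toy_ids, gold_ids=["p0"], k=3
)
print(hard_negatives)  # ['p1', 'p3']
```

The unrelated passage `p2` never makes it into the top-k, which is the point: only passages the retriever already finds plausible are kept as negatives.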

In [3]:
# load pretrained dpr model
DEVICE = "cuda"
model = BERTBiEncoder().to(DEVICE)
model = load_dpr_model_checkpoint(model)

# load dpr passage embeddings
evidence_passage_embeds, passage_ids = load_dpr_passage_embeddings()

Loaded model from checkpoint!


In [5]:
tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')
block_size = 196

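# now get the hard negatives for each claim
def get_hard_negatives(dpr_model, tokenizer, data, passage_ids, block_size, k=100):
    """
    Retrieve the top-k DPR passages for each claim and drop the gold evidence,
    leaving up to k hard negatives per claim.
    Note: reads the module-level `evidence_passage_embeds` loaded above.
    """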
    claims_list = list(data.items()) 
    hard_negatives = {}
    for claim_id, claim in tqdm(claims_list):
        claim_text = claim["claim_text"]
        gold_evidence_list = claim["evidences"]
        topk_passage_ids, topk_scores = find_topk_evidence_dpr(dpr_model, tokenizer, claim_text, evidence_passage_embeds, passage_ids, block_size, k=k)
        # remove the gold evidence from the topk passages
        topk_passage_ids = [p_id for p_id in topk_passage_ids if p_id not in gold_evidence_list]
        hard_negatives[claim_id] = topk_passage_ids

    return hard_negatives

In [6]:
train_hard_negatives = get_hard_negatives(model, tokenizer, train_data, passage_ids, block_size)

100%|██████████| 1228/1228 [00:27<00:00, 44.10it/s]


In [7]:
val_hard_negatives = get_hard_negatives(model, tokenizer, val_data, passage_ids, block_size)

100%|██████████| 154/154 [00:03<00:00, 41.33it/s]


In [14]:
# show some examples of the hard negatives
claim_id = random.choice(list(train_hard_negatives.keys()))
print("Claim:", train_data[claim_id]["claim_text"])
print(f"\nGold evidence: ")
for evidence in train_data[claim_id]["evidences"]:
    print(f"\t{document_store[evidence]}")
print("\nTop-5 Hard negatives: ")
for passage_id in train_hard_negatives[claim_id][:5]:  # slice matches the "Top-5" label
    print(f"\t{document_store[passage_id]}")

Claim: an airplane is contributing to the emissions that put the frozen continent at risk.

Gold evidence: 
	This enables the entire craft to contribute to lift generation with the result of potentially increased fuel economy.
	The airline industry is responsible for about 11% of greenhouse gases emitted by the U.S. transportation sector.
	Aviation's share of the greenhouse gas emissions is poised to grow, as air travel increases and ground vehicles use more alternative fuels like ethanol and biodiesel.
	Boeing estimates that biofuels could reduce flight-related greenhouse-gas emissions by 60 to 80%.
	In November 2017, a statement by 15,364 scientists from 184 countries indicated that increasing levels of greenhouse gases from use of fossil fuels, human population growth, deforestation, and overuse of land for agricultural production, particularly by farming ruminants for meat consumption, are trending in ways that forecast an increase in human misery over coming decades.

Top-5 Hard n