#### Haystack Passage Retrievers

Previously, we implemented our own BM25 and DPR retriever models. We now switch to the retriever models provided by the Haystack library, which are optimized for performance and come with many useful features.

In [32]:
from haystack.document_stores import InMemoryDocumentStore
from haystack.nodes import BM25Retriever
from utils import *
import random
import numpy as np
from tqdm import tqdm

In [3]:
# load data from file
passages, train_data, val_data = load_data(clean=True, clean_threshold=30)


Number of evidence passages: 1208827
Number of training instances: 1228
Number of validation instances: 154
Number of evidence passages remaining after cleaning: 1204715


In [20]:
documents = [{"id": p_id, "content": p_text} for p_id, p_text in passages.items()]

# create haystack in-memory document store
document_store = InMemoryDocumentStore(use_bm25=True)
document_store.write_documents(documents)

Updating BM25 representation...: 100%|██████████| 1204715/1204715 [00:17<00:00, 69241.03 docs/s]


In [21]:
# set up a Haystack BM25 retriever
retreiver = BM25Retriever(document_store=document_store)

In [35]:
train_claims = list(train_data.items())
val_claims = list(val_data.items())

Let's test this BM25 retriever on some example claims.

In [29]:
# now do a quick test of the retriever
idx = random.randrange(len(train_claims))  # randint's upper bound is inclusive and could index out of range
claim_text = train_claims[idx][1]['claim_text']
gold_evidence_list = train_claims[idx][1]["evidences"]

# retrieve BM25 top-5 documents
topk_documents = retreiver.retrieve(query=claim_text, top_k=5)

print(f"Claim --> {claim_text}")
print(f"\nGold evidences: ")
for evidence in gold_evidence_list:
    print(f"\t {evidence} --> {passages[evidence]}")

print(f"\nBM25 top-5 documents:")
for doc in topk_documents:
    print(f"\t{doc.id} --> {doc.content}")


Claim --> that atmospheric CO2 increase that we observe is a product of temperature  increase, and not the other way around, meaning it is a product of  natural variation...

Gold evidences: 
	 evidence-368192 --> Increases in atmospheric concentrations of CO 2 and other long-lived greenhouse gases such as methane, nitrous oxide and ozone have correspondingly strengthened their absorption and emission of infrared radiation, causing the rise in average global temperature since the mid-20th century.
	 evidence-423643 --> During the late 20th century, a scientific consensus evolved that increasing concentrations of greenhouse gases in the atmosphere cause a substantial rise in global temperatures and changes to other parts of the climate system, with consequences for the environment and for human health.

BM25 top-5 documents:
	evidence-100018 --> The ice core data shows that temperature change causes the level of atmospheric CO2 to change - not the other way round.
	evidence-548766 --> D

Let's compute the evaluation metrics for this BM25 retriever.

In [30]:
train_claims[0]

('claim-1937',
 {'claim_text': 'Not only is there no scientific evidence that CO2 is a pollutant, higher CO2 concentrations actually help ecosystems support more plant and animal life.',
  'claim_label': 'DISPUTED',
  'evidences': ['evidence-442946', 'evidence-1194317', 'evidence-12171']})

In [34]:
# note: this name shadows Python's built-in eval(); it is kept so the calls below still work
def eval(claims_list, retreiver, topk=(5,)):
    precision_total = np.zeros(len(topk))
    recall_total = np.zeros(len(topk))
    f1_total = np.zeros(len(topk))

    for claim_id, claim in tqdm(claims_list):
        claim_text = claim['claim_text']
        gold_evidence_list = claim['evidences']
        # get the BM25 top-k passages, retrieving once for the largest k
        topk_documents = retreiver.retrieve(query=claim_text, top_k=max(topk))

        # compute the metrics on the first k retrieved passages for each k
        for i,k in enumerate(topk):
            retreived_doc_ids = [doc.id for doc in topk_documents[:k]]
            intersection = set(retreived_doc_ids).intersection(gold_evidence_list)
            # guard against an empty result list to avoid division by zero
            precision = len(intersection) / len(retreived_doc_ids) if retreived_doc_ids else 0
            recall = len(intersection) / len(gold_evidence_list)
            f1 = (2*precision*recall/(precision + recall)) if (precision + recall) > 0 else 0
            precision_total[i] += precision
            recall_total[i] += recall
            f1_total[i] += f1

    precision_avg = precision_total / len(claims_list)
    recall_avg = recall_total / len(claims_list)
    f1_avg = f1_total / len(claims_list)    

    # convert to dictionary
    precision_avg = {f"Precision@{k}":v for k,v in zip(topk, precision_avg)}
    recall_avg = {f"Recall@{k}":v for k,v in zip(topk, recall_avg)}
    f1_avg = {f"F1@{k}":v for k,v in zip(topk, f1_avg)}

    print(f"\nAvg Precision: {precision_avg}, Avg Recall: {recall_avg}, Avg F1: {f1_avg}")
    return precision_avg, recall_avg, f1_avg
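
To make the per-claim metric definitions concrete, here is a minimal, standalone sketch that computes precision, recall and F1 at several cutoffs for a single claim. The helper name `precision_recall_f1_at_k` and the evidence ids are hypothetical and chosen only for illustration; the formulas mirror those used in the evaluation function above.

```python
def precision_recall_f1_at_k(retrieved_ids, gold_ids, k):
    """Precision, recall and F1 over the first k retrieved ids for one claim."""
    top_k_ids = retrieved_ids[:k]
    hits = len(set(top_k_ids) & set(gold_ids))
    # fall back to 0.0 when a denominator would be zero
    precision = hits / len(top_k_ids) if top_k_ids else 0.0
    recall = hits / len(gold_ids) if gold_ids else 0.0
    denom = precision + recall
    f1 = 2 * precision * recall / denom if denom > 0 else 0.0
    return precision, recall, f1


# hypothetical ranked retrieval result and gold evidence (made-up ids)
retrieved = ["e1", "e7", "e3", "e9", "e4"]
gold = ["e3", "e4", "e8"]

for k in (1, 3, 5):
    p, r, f1 = precision_recall_f1_at_k(retrieved, gold, k)
    print(f"k={k}: precision={p:.3f}, recall={r:.3f}, f1={f1:.3f}")
```

With this toy ranking, recall can only grow as k increases, while precision depends on how many of the extra passages are gold evidence. That trade-off is why the evaluation reports all three metrics over a range of cutoffs.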

In [36]:
# eval on training set
print("Evaluating on training set...")
eval(train_claims, retreiver, topk=[1, 3, 5, 10, 20, 50, 100, 250, 500, 1000])

print("Evaluating on validation set...")
eval(val_claims, retreiver, topk=[1, 3, 5, 10, 20, 50, 100, 250, 500, 1000])

Evaluating on training set...


  1%|          | 13/1228 [01:00<1:45:25,  5.21s/it]