#### Haystack Passage Retrievers

Previously we implemented our own BM25 and DPR retriever models. We will now switch to the retriever models provided by the Haystack library, which are optimized for performance and offer many additional features.

In [1]:
from haystack.document_stores import InMemoryDocumentStore, FAISSDocumentStore
from haystack.nodes import BM25Retriever, DensePassageRetriever
from utils import *
import random
import numpy as np
from tqdm import tqdm

In [2]:
# load data from file
passages, train_data, val_data = load_data(clean=True, clean_threshold=30)

Number of evidence passages: 1208827
Number of training instances: 1228
Number of validation instances: 154
Number of evidence passages remaining after cleaning: 1204715


In [20]:
documents = [{"id":p_id, "content": p_text} for p_id, p_text in list(passages.items())]

# create haystack in-memory document store
document_store = InMemoryDocumentStore(use_bm25=True)
document_store.write_documents(documents)

Updating BM25 representation...: 100%|██████████| 1204715/1204715 [00:17<00:00, 69241.03 docs/s]


In [21]:
# set up a Haystack BM25 retriever
retreiver = BM25Retriever(document_store=document_store)

In [35]:
train_claims = list(train_data.items())
val_claims = list(val_data.items())

Let's test this BM25 retriever on some example claims.
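
For intuition about why the top hits tend to share surface vocabulary with the claim, here is a minimal sketch of Okapi BM25 scoring in plain Python. It is not Haystack's implementation: the whitespace tokenisation, the `k1=1.5` and `b=0.75` values, and the toy documents are all assumptions made for illustration.

```python
import math
from collections import Counter


def bm25_scores(query, docs, k1=1.5, b=0.75):
    """Score each document in `docs` against `query` with Okapi BM25.

    Tokenisation is lowercase whitespace splitting (an assumption made
    for brevity; real retrievers use a proper tokenizer).
    """
    tokenized = [d.lower().split() for d in docs]
    n_docs = len(tokenized)
    avg_len = sum(len(t) for t in tokenized) / n_docs

    # document frequency: number of documents containing each term
    df = Counter()
    for tokens in tokenized:
        df.update(set(tokens))

    scores = []
    for tokens in tokenized:
        tf = Counter(tokens)
        score = 0.0
        for term in set(query.lower().split()):
            if term not in tf:
                continue
            # rare terms get a larger weight than common ones
            idf = math.log((n_docs - df[term] + 0.5) / (df[term] + 0.5) + 1.0)
            # term frequency saturates with k1 and is length-normalised with b
            norm = tf[term] + k1 * (1 - b + b * len(tokens) / avg_len)
            score += idf * tf[term] * (k1 + 1) / norm
        scores.append(score)
    return scores


# toy documents, for illustration only
toy_docs = [
    "temperature change causes atmospheric co2 to change",
    "greenhouse gases cause a rise in global temperatures",
    "the ocean absorbs heat",
]
toy_scores = bm25_scores("atmospheric co2 temperature", toy_docs)
best = max(range(len(toy_docs)), key=toy_scores.__getitem__)
print(best, [round(s, 3) for s in toy_scores])
```

Note that the second toy document scores zero even though it is topically related: "temperatures" does not match "temperature" exactly. This kind of lexical mismatch is one motivation for the dense retriever trained later in this notebook.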

In [29]:
# now do a quick test of the retriever
# pick a random claim (randrange excludes the upper bound, so idx is always valid)
idx = random.randrange(len(train_claims))
claim_text = train_claims[idx][1]['claim_text']
gold_evidence_list = train_claims[idx][1]["evidences"]

# retrieve BM25 top-5 documents
topk_documents = retreiver.retrieve(query=claim_text, top_k=5)

print(f"Claim --> {claim_text}")
print(f"\nGold evidences: ")
for evidence in gold_evidence_list:
    print(f"\t {evidence} --> {passages[evidence]}")

print(f"\nBM25 top-5 documents:")
for doc in topk_documents:
    print(f"\t{doc.id} --> {doc.content}")


Claim --> that atmospheric CO2 increase that we observe is a product of temperature  increase, and not the other way around, meaning it is a product of  natural variation...

Gold evidences: 
	 evidence-368192 --> Increases in atmospheric concentrations of CO 2 and other long-lived greenhouse gases such as methane, nitrous oxide and ozone have correspondingly strengthened their absorption and emission of infrared radiation, causing the rise in average global temperature since the mid-20th century.
	 evidence-423643 --> During the late 20th century, a scientific consensus evolved that increasing concentrations of greenhouse gases in the atmosphere cause a substantial rise in global temperatures and changes to other parts of the climate system, with consequences for the environment and for human health.

BM25 top-5 documents:
	evidence-100018 --> The ice core data shows that temperature change causes the level of atmospheric CO2 to change - not the other way round.
	evidence-548766 --> D

Let's compute the evaluation metrics for this BM25 retriever and see how they compare with our "home-made" BM25 implementation.
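
Before looping over the whole dataset, it may help to see the three metrics for a single claim. The sketch below uses made-up evidence IDs (`e1`, `e3`, ...) and a hypothetical helper name; it only illustrates the set arithmetic behind Precision@k, Recall@k and F1@k.

```python
def precision_recall_f1_at_k(ranked_ids, gold_ids, k):
    """Precision, recall and F1 over the top-k entries of a ranked ID list."""
    top_k = ranked_ids[:k]
    hits = len(set(top_k) & set(gold_ids))
    precision = hits / len(top_k) if top_k else 0.0
    recall = hits / len(gold_ids) if gold_ids else 0.0
    f1 = 2 * precision * recall / (precision + recall) if hits else 0.0
    return precision, recall, f1


# hypothetical IDs, for illustration only
ranked = ["e1", "e7", "e3", "e9", "e4"]
gold = ["e3", "e5"]
for k in (1, 3, 5):
    print(k, precision_recall_f1_at_k(ranked, gold, k))
```

In this toy ranking the only gold passage retrieved sits at rank 3, so precision and recall are both zero at k=1, and precision falls again between k=3 and k=5 while recall stays flat.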

In [34]:
def eval(claims_list, retriever, topk=(5,)):
    """Average Precision@k, Recall@k and F1@k of `retriever` over `claims_list`, for each k in `topk`."""
    precision_total = np.zeros(len(topk))
    recall_total = np.zeros(len(topk))
    f1_total = np.zeros(len(topk))

    for claim_id, claim in tqdm(claims_list):
        claim_text = claim['claim_text']
        gold_evidence_list = claim['evidences']
        # retrieve once with the largest k; smaller cutoffs are prefixes of this ranking
        topk_documents = retriever.retrieve(query=claim_text, top_k=max(topk))

        # score each cutoff k on the first k retrieved passages
        for i,k in enumerate(topk):
            retrieved_doc_ids = [doc.id for doc in topk_documents[:k]]
            intersection = set(retrieved_doc_ids).intersection(gold_evidence_list)
            # guard against an empty result list
            precision = len(intersection) / len(retrieved_doc_ids) if retrieved_doc_ids else 0
            recall = len(intersection) / len(gold_evidence_list)
            f1 = (2*precision*recall/(precision + recall)) if (precision + recall) > 0 else 0
            precision_total[i] += precision
            recall_total[i] += recall
            f1_total[i] += f1

    precision_avg = precision_total / len(claims_list)
    recall_avg = recall_total / len(claims_list)
    f1_avg = f1_total / len(claims_list)    

    # convert to dictionary
    precision_avg = {f"Precision@{k}":v for k,v in zip(topk, precision_avg)}
    recall_avg = {f"Recall@{k}":v for k,v in zip(topk, recall_avg)}
    f1_avg = {f"F1@{k}":v for k,v in zip(topk, f1_avg)}

    print(f"\nAvg Precision: {precision_avg}, Avg Recall: {recall_avg}, Avg F1: {f1_avg}")
    return precision_avg, recall_avg, f1_avg

In [36]:
# eval on training set
print("Evaluating on training set...")
eval(train_claims, retreiver, topk=[1, 3, 5, 10, 20, 50, 100, 250, 500, 1000])

print("Evaluating on validation set...")
eval(val_claims, retreiver, topk=[1, 3, 5, 10, 20, 50, 100, 250, 500, 1000])

Evaluating on training set...


100%|██████████| 1228/1228 [1:21:53<00:00,  4.00s/it]



Avg Precision: {'Precision@1': 0.16042345276872963, 'Precision@3': 0.11264929424538521, 'Precision@5': 0.08762214983713407, 'Precision@10': 0.06001628664495125, 'Precision@20': 0.03811074918566769, 'Precision@50': 0.020114006514657747, 'Precision@100': 0.011995114006514512, 'Precision@250': 0.005954397394136726, 'Precision@500': 0.0034055374592833368, 'Precision@1000': 0.0019144951140064824}, Avg Recall: {'Recall@1': 0.05359663409337685, 'Recall@3': 0.11235070575461469, 'Recall@5': 0.14132736156351794, 'Recall@10': 0.1941096634093376, 'Recall@20': 0.24214169381107453, 'Recall@50': 0.3116042345276867, 'Recall@100': 0.3753528773072739, 'Recall@250': 0.46114277958740424, 'Recall@500': 0.5226520086862111, 'Recall@1000': 0.5875814332247552}, Avg F1: {'F1@1': 0.07581433224755703, 'F1@3': 0.10515549868155738, 'F1@5': 0.1021747841373252, 'F1@10': 0.08788494350383635, 'F1@20': 0.06414886876329999, 'F1@50': 0.03733490936246122, 'F1@100': 0.02309283869110577, 'F1@250': 0.011725165815050799, 'F1@

100%|██████████| 154/154 [10:13<00:00,  3.98s/it]


Avg Precision: {'Precision@1': 0.15584415584415584, 'Precision@3': 0.11471861471861472, 'Precision@5': 0.08831168831168827, 'Precision@10': 0.0571428571428571, 'Precision@20': 0.036363636363636334, 'Precision@50': 0.021038961038961055, 'Precision@100': 0.012597402597402607, 'Precision@250': 0.006077922077922082, 'Precision@500': 0.003428571428571431, 'Precision@1000': 0.0019480519480519494}, Avg Recall: {'Recall@1': 0.07835497835497834, 'Recall@3': 0.12911255411255412, 'Recall@5': 0.16320346320346313, 'Recall@10': 0.21277056277056272, 'Recall@20': 0.25930735930735943, 'Recall@50': 0.3464285714285715, 'Recall@100': 0.4102813852813853, 'Recall@250': 0.49058441558441546, 'Recall@500': 0.5627705627705626, 'Recall@1000': 0.6457792207792207}, Avg F1: {'F1@1': 0.09632034632034633, 'F1@3': 0.11162646876932593, 'F1@5': 0.10682333539476398, 'F1@10': 0.08540420618342694, 'F1@20': 0.061792711335342615, 'F1@50': 0.039133716128973914, 'F1@100': 0.024275545967757394, 'F1@250': 0.011972195531450198, 





The best F1-score is reached with the top-3 retrieved documents: about 10.5% on the training set and 11.2% on the validation set.
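
One likely reason F1 peaks at a small k is that each claim has only a few gold evidence passages, which caps precision at large k no matter how good the retriever is. The sketch below computes the best achievable F1@k for a claim with `num_gold` gold passages; the value `num_gold=2` is an illustrative assumption (the example claim above happens to have two), and the helper name is hypothetical.

```python
def max_f1_at_k(num_gold, k):
    """Best achievable F1@k when a claim has `num_gold` gold passages."""
    hits = min(num_gold, k)  # a perfect retriever ranks all gold passages first
    precision = hits / k
    recall = hits / num_gold
    return 2 * precision * recall / (precision + recall)


for k in (1, 2, 3, 5, 10, 100, 1000):
    print(f"k={k:<5} max F1 = {max_f1_at_k(2, k):.4f}")
```

The bound simplifies to `2 * hits / (k + num_gold)`, so even a perfect ranking cannot score well on F1 once k is much larger than the number of gold passages. For large k, recall is the more informative number.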

Now we will fine-tune a DPR model using Haystack and see how well it performs. First, let's put the data in the right format and save it to JSON files.
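
As a small-scale illustration of that format, the sketch below builds a single training record with the same keys the next cell uses (`question`, `positive_ctxs`, `hard_negative_ctxs`) and serialises it. The passages, claim text and the helper `make_dpr_instance` are made up for this example.

```python
import json

# toy stand-ins for the real `passages` dict and hard-negative lists
toy_passages = {
    "evidence-1": "Greenhouse gases trap heat in the atmosphere.",
    "evidence-2": "The Moon orbits the Earth.",
}


def make_dpr_instance(claim_text, positive_id, negative_ids, passages):
    """Build one training record with the keys used in this notebook."""
    def ctx(p_id):
        return {"title": "no_title", "text": passages[p_id], "passages_id": p_id}

    return {
        "question": claim_text,
        "positive_ctxs": [ctx(positive_id)],
        "hard_negative_ctxs": [ctx(n) for n in negative_ids],
    }


instance = make_dpr_instance(
    "CO2 warms the planet.", "evidence-1", ["evidence-2"], toy_passages
)
serialized = json.dumps([instance], indent=2)
print(serialized)
```

Each record pairs one claim with exactly one positive passage, so a claim with several gold evidences produces several records.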

In [4]:
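# create pairs of claim and evidence
# (the code below is wrapped in a string literal, so this cell does not run as written)
"""
import json
import pickle

train_pairs = []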
for claim_id, claim in train_data.items():
    claim_text = claim['claim_text']
    for evidence_id in claim['evidences']:
        train_pairs.append((claim_id, evidence_id)) 
        
val_pairs = []
for claim_id, claim in val_data.items():
    claim_text = claim['claim_text']
    for evidence_id in claim['evidences']:
        val_pairs.append((claim_id, evidence_id)) 

# now let's get some hard negatives that we obtained using our old DPR implementation
with open("dpr_embeddings/train_hard_negatives.pkl", "rb") as f:
    train_hard_negatives = pickle.load(f)
with open("dpr_embeddings/val_hard_negatives.pkl", "rb") as f:
    val_hard_negatives = pickle.load(f)  

# now let's create the DPR training and validation instances, keeping the top 5 hard negatives
train_instances = []
for claim_id, evidence_id in train_pairs:
    train_instances.append({"question": train_data[claim_id]['claim_text'], "positive_ctxs": [{"title": "no_title", "text":passages[evidence_id], "passages_id":evidence_id}], "hard_negative_ctxs": [{"title": "no_title", "text":passages[neg_id], "passages_id":neg_id} for neg_id in train_hard_negatives[claim_id][:5]]}) 
val_instances = []
for claim_id, evidence_id in val_pairs:
    val_instances.append({"question": val_data[claim_id]['claim_text'], "positive_ctxs": [{"title": "no_title", "text":passages[evidence_id], "passages_id":evidence_id}], "hard_negative_ctxs": [{"title": "no_title", "text":passages[neg_id], "passages_id":neg_id} for neg_id in val_hard_negatives[claim_id][:5]]})        

# Write to JSON file
with open('project-data/train_data_dpr.json', 'w') as f:
    json.dump(train_instances, f)       

with open('project-data/dev_data_dpr.json', 'w') as f:
    json.dump(val_instances, f)       
"""

Now, let's set up the Haystack retriever and train it.
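
As background for the training step, the following is a sketch of the objective DPR is usually trained with: a softmax over dot-product similarities in which the other passages in the batch serve as negatives. It is written in NumPy with toy embeddings and is not Haystack's training code, which may differ in details (for example, it also uses the hard negatives supplied per question). The function name is hypothetical.

```python
import numpy as np


def in_batch_nll(query_emb, passage_emb):
    """Mean negative log-likelihood with in-batch negatives.

    Row i of `passage_emb` is the positive passage for query i; every
    other row in the batch acts as a negative for that query.
    """
    scores = query_emb @ passage_emb.T                    # (B, B) dot products
    scores = scores - scores.max(axis=1, keepdims=True)   # numerical stability
    log_probs = scores - np.log(np.exp(scores).sum(axis=1, keepdims=True))
    return -np.mean(np.diag(log_probs))


# toy embeddings: four orthogonal query vectors
queries = np.eye(4)
# positives that point in the same direction as their queries
aligned_loss = in_batch_nll(queries, 5.0 * queries)
# positives shifted by one row, so each query's "positive" is a mismatch
shuffled_loss = in_batch_nll(queries, 5.0 * np.roll(queries, 1, axis=0))
print(aligned_loss, shuffled_loss)
```

The loss is low when each query's own passage has the highest similarity in its row and high when another passage in the batch scores higher, which is what pushes the two encoders toward a shared embedding space.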

In [3]:
# Initialize the document store
document_store = FAISSDocumentStore()

# Initialize the retriever
retriever = DensePassageRetriever(
    document_store=document_store,
    query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
    passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
    use_gpu=True,
)

# Specify the directory where your training data is located
doc_dir = "project-data"

# Specify the names of your training and development files
train_filename = "train_data_dpr.json"
dev_filename = "dev_data_dpr.json"

# Specify the directory where the trained model should be saved
save_dir = "haystack_dpr_finetuned"




In [4]:
# Start training the model
retriever.train(
    data_dir=doc_dir,
    train_filename=train_filename,
    dev_filename=dev_filename,
    n_epochs=1,
    batch_size=16,
    grad_acc_steps=8,
    save_dir=save_dir,
    evaluate_every=100,
    embed_title=False,
    num_positives=1,
    num_hard_negatives=1,
)

# load trained retriever from file
#retriever = DensePassageRetriever.load(load_dir=save_dir, document_store=document_store)

Preprocessing dataset: 100%|██████████| 9/9 [00:02<00:00,  3.69 Dicts/s]
Preprocessing dataset: 100%|██████████| 1/1 [00:00<00:00,  3.49 Dicts/s]
Train epoch 0/0 (Cur. train loss: 10.8034):   0%|          | 1/258 [00:08<36:22,  8.49s/it]

In [None]:
# write documents to document store
documents = [{"id":p_id, "content": p_text} for p_id, p_text in list(passages.items())]
document_store.write_documents(documents)
document_store.update_embeddings(retriever)

In [None]:

# Assuming 'retriever' is your trained DPR model and 'document_store' is your document store
query = "What is artificial intelligence?"

# Retrieve top-k documents
results = retriever.retrieve(query, top_k=5)

# 'results' is a list of Document objects; the passage text is in the '.content' attribute
# (the same attribute used for the BM25 results earlier in this notebook).
for result in results:
    print(result.content)
