In [86]:
from datasets import load_dataset
import pandas as pd
from ir_eval.metrics import recall, precision, hole, ndcg
from sentence_transformers import SentenceTransformer, util, CrossEncoder
from ir_eval.utils_prompt import load_prompt_text, eval_prompt, preprocess_prompt
import collections
import os
import json
%load_ext autoreload
%autoreload 2


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## Load trec covid dataset
https://paperswithcode.com/dataset/trec-covid

In [3]:
corpus = load_dataset("BeIR/trec-covid", 'corpus')['corpus']
queries = load_dataset("BeIR/trec-covid", 'queries')['queries']
qrels = load_dataset("BeIR/trec-covid-qrels")


In [4]:
queries

Dataset({
    features: ['_id', 'title', 'text'],
    num_rows: 50
})

In [5]:
def combine_text(example):
    example['full_text'] = '[Title] ' + example['title'] + ' [TEXT] ' + example['text']
    return example
corpus = corpus.map(combine_text)

## Load model

In [6]:
model = SentenceTransformer("all-MiniLM-L6-v2")


In [7]:
corpus_embeddings = model.encode(corpus['full_text'], convert_to_tensor=True, show_progress_bar=True)


Batches:   0%|          | 0/5355 [00:00<?, ?it/s]

In [8]:
query_embeddings = model.encode(queries['text'], convert_to_tensor=True, show_progress_bar=True)


Batches:   0%|          | 0/2 [00:00<?, ?it/s]

In [23]:
query_embeddings.shape

torch.Size([50, 384])

In [24]:
hits = util.semantic_search(query_embeddings, corpus_embeddings, top_k=10)

In [25]:
len(hits)

50

In [26]:
len(hits[0])

10

In [27]:
type(corpus['_id'])

list

In [28]:

retrieval_results = collections.defaultdict(dict)
doc_id_map = corpus['_id']
query_id_map = queries['_id']
for qid, doc_score_list in enumerate(hits):
    qid_key = query_id_map[qid]
    result_dict = {}
    doc_ids = [doc_score['corpus_id'] for doc_score in doc_score_list]
    scores = [doc_score['score'] for doc_score in doc_score_list]
    
    doc_id_keys = list(map(lambda x: doc_id_map[x], doc_ids))
    result_dict = dict(zip(doc_id_keys, scores))
    retrieval_results[qid_key] = result_dict

In [29]:
qrels_for_eval = collections.defaultdict(dict)
for example in qrels['test']:
    qrels_for_eval[str(example['query-id'])][str(example['corpus-id'])] = example['score']



In [30]:
print(recall(qrels=qrels_for_eval, results=retrieval_results, k_values=[1, 3, 5, 10, 20, 30, 100, 500, 2000]))
print(precision(qrels=qrels_for_eval, results=retrieval_results, k_values=[1, 3, 5, 10, 20, 30, 100, 500, 2000]))
print(ndcg(qrels=qrels_for_eval, results=retrieval_results, k_values=[1, 3, 5, 10]))


{'Recall@1': 0.00191, 'Recall@3': 0.00488, 'Recall@5': 0.00775, 'Recall@10': 0.01431, 'Recall@20': 0.01431, 'Recall@30': 0.01431, 'Recall@100': 0.01431, 'Recall@500': 0.01431, 'Recall@2000': 0.01431}
{'Precision@1': 0.74, 'Precision@3': 0.64667, 'Precision@5': 0.612, 'Precision@10': 0.564, 'Precision@20': 0.282, 'Precision@30': 0.188, 'Precision@100': 0.0564, 'Precision@500': 0.01128, 'Precision@2000': 0.00282}
{'NDCG@1': 0.67, 'NDCG@3': 0.60222, 'NDCG@5': 0.57071, 'NDCG@10': 0.53573}


## GPT-4o reranker

### step 1: prepare input dicts

In [99]:
qid_2_query = dict(zip(queries['_id'], queries['text']))
docid_2_title = dict(zip(corpus['_id'], corpus['title']))
docid_2_text = dict(zip(corpus['_id'], corpus['text']))
docid_2_combined_text = dict(zip(corpus['_id'], [title + " " + text for title, text in zip(corpus['title'], corpus['text'])]))

In [54]:
all_input_dicts = []
for qid, doc_score_dict in retrieval_results.items():
    query = qid_2_query[qid]
    for docid, _ in doc_score_dict.items():
        title = docid_2_title[docid]
        text = docid_2_text[docid]
        uid = qid + '@' + docid
        
        record = {'prompt_id': uid,
                  'qid': qid,
                  'docid': docid,
                  'query': query,
                  'title': title,
                  'text': text
                  }
        all_input_dicts.append(record)



## Step2: populate prompt template

In [55]:
prompt_template_path = "/mnt/d/Dropbox/llm_book/repos/ir_eval/prompts/query_doc_rating.jinja"

In [56]:
output_path = "gpt4_scoring.jsonl"
force_rerun = False
model = 'gpt-4o-mini'

In [57]:
def llm_inference(prompt_template_path, all_input_dicts, model, output_path):
    if os.path.exists(output_path):
        print(f"output_path {output_path} exists")
        return
    prompt_template_text = load_prompt_text(prompt_template_path)
    for input_dict in all_input_dicts:
        prompt_info_dict = preprocess_prompt(prompt_template_text, input_dict)
        eval_results = eval_prompt(prompt_info_dict, model=model)
        input_dict.update({f"eval_result_{model}": eval_results})
        #break
    
    with open(output_path,"w") as json_file:
        json.dump(all_input_dicts, json_file, indent=4)
    return all_input_dicts
    


In [58]:
result_dict = llm_inference(prompt_template_path, all_input_dicts, model, output_path)

## Parse score and compute metrics

In [73]:
import re
def parse_score(text, tag='R'):
    pattern = fr"<{tag}>(.*?)</{tag}>"  # Dynamic tag in regex
    matches = re.findall(pattern, text, re.DOTALL)
    if matches:
        return int(matches[0])
    else:
        print("score parsing failure")
        return 0

In [75]:
_ = [d.update({"score": parse_score(d['eval_result_gpt-4o-mini'])}) for d in result_dict]

In [81]:
retrieval_results_gpt4 = collections.defaultdict(dict)
doc_id_map = corpus['_id']
query_id_map = queries['_id']
for d in result_dict:
    qid = d['qid']
    doc_id = d['docid']
    score = d['score']
    retrieval_results_gpt4[qid][doc_id] = score

In [83]:
# gpt4o mini result
print(recall(qrels=qrels_for_eval, results=retrieval_results_gpt4, k_values=[1, 3, 5, 10]))
print(precision(qrels=qrels_for_eval, results=retrieval_results_gpt4, k_values=[1, 3, 5, 10]))
print(ndcg(qrels=qrels_for_eval, results=retrieval_results_gpt4, k_values=[1, 3, 5, 10]))

# {'Recall@1': 0.00217, 'Recall@3': 0.00588, 'Recall@5': 0.00884, 'Recall@10': 0.01431}
# {'Precision@1': 0.8, 'Precision@3': 0.75333, 'Precision@5': 0.696, 'Precision@10': 0.564}
# {'NDCG@1': 0.72, 'NDCG@3': 0.68281, 'NDCG@5': 0.65106, 'NDCG@10': 0.56613}


{'Recall@1': 0.00217, 'Recall@3': 0.00588, 'Recall@5': 0.00884, 'Recall@10': 0.01431}
{'Precision@1': 0.8, 'Precision@3': 0.75333, 'Precision@5': 0.696, 'Precision@10': 0.564}
{'NDCG@1': 0.72, 'NDCG@3': 0.68281, 'NDCG@5': 0.65106, 'NDCG@10': 0.56613}


In [79]:
# biencoder result
print(recall(qrels=qrels_for_eval, results=retrieval_results, k_values=[1, 3, 5, 10]))
print(precision(qrels=qrels_for_eval, results=retrieval_results, k_values=[1, 3, 5, 10]))
print(ndcg(qrels=qrels_for_eval, results=retrieval_results, k_values=[1, 3, 5, 10]))

# biencoder result
# print(recall(qrels=qrels_for_eval, results=retrieval_results, k_values=[1, 3, 5, 10]))
# print(precision(qrels=qrels_for_eval, results=retrieval_results, k_values=[1, 3, 5, 10]))
# print(ndcg(qrels=qrels_for_eval, results=retrieval_results, k_values=[1, 3, 5, 10]))



{'Recall@1': 0.00191, 'Recall@3': 0.00488, 'Recall@5': 0.00775, 'Recall@10': 0.01431}
{'Precision@1': 0.74, 'Precision@3': 0.64667, 'Precision@5': 0.612, 'Precision@10': 0.564}
{'NDCG@1': 0.67, 'NDCG@3': 0.60222, 'NDCG@5': 0.57071, 'NDCG@10': 0.53573}


In [84]:
# perfect rank result
print(recall(qrels=qrels_for_eval, results=qrels_for_eval, k_values=[1, 3, 5, 10]))
print(precision(qrels=qrels_for_eval, results=qrels_for_eval, k_values=[1, 3, 5, 10]))
print(ndcg(qrels=qrels_for_eval, results=qrels_for_eval, k_values=[1, 3, 5, 10]))

# {'Recall@1': 0.00267, 'Recall@3': 0.00802, 'Recall@5': 0.01337, 'Recall@10': 0.02674}
# {'Precision@1': 1.0, 'Precision@3': 1.0, 'Precision@5': 1.0, 'Precision@10': 1.0}
# {'NDCG@1': 1.0, 'NDCG@3': 1.0, 'NDCG@5': 1.0, 'NDCG@10': 1.0}

{'Recall@1': 0.00267, 'Recall@3': 0.00802, 'Recall@5': 0.01337, 'Recall@10': 0.02674}
{'Precision@1': 1.0, 'Precision@3': 1.0, 'Precision@5': 1.0, 'Precision@10': 1.0}
{'NDCG@1': 1.0, 'NDCG@3': 1.0, 'NDCG@5': 1.0, 'NDCG@10': 1.0}


## Cross-encoder reranker

In [88]:
#The bi-encoder will retrieve 100 documents. We use a cross-encoder, to re-rank the results list to improve the quality
cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')

In [100]:
rerank_results = retrieval_results.copy()
for qid, doc_score_dict in rerank_results.items():
    query = qid_2_query[qid]
    cross_inputs = []
    doc_ids = []
    for docid, _ in doc_score_dict.items():
        
        combined_text = docid_2_combined_text[docid]
        cross_inputs.append([query, combined_text])
        doc_ids.append(docid)
    cross_encoder_scores = cross_encoder.predict(cross_inputs)
    doc_score_dict = dict(zip(doc_ids,cross_encoder_scores.tolist()))
    rerank_results[qid] = doc_score_dict

In [101]:
# cross encoder result
print(recall(qrels=qrels_for_eval, results=rerank_results, k_values=[1, 3, 5, 10]))
print(precision(qrels=qrels_for_eval, results=rerank_results, k_values=[1, 3, 5, 10]))
print(ndcg(qrels=qrels_for_eval, results=rerank_results, k_values=[1, 3, 5, 10]))

# {'Recall@1': 0.002, 'Recall@3': 0.00548, 'Recall@5': 0.00858, 'Recall@10': 0.01431}
# {'Precision@1': 0.78, 'Precision@3': 0.74, 'Precision@5': 0.684, 'Precision@10': 0.564}
# {'NDCG@1': 0.72, 'NDCG@3': 0.67842, 'NDCG@5': 0.64002, 'NDCG@10': 0.56481}

{'Recall@1': 0.002, 'Recall@3': 0.00548, 'Recall@5': 0.00858, 'Recall@10': 0.01431}
{'Precision@1': 0.78, 'Precision@3': 0.74, 'Precision@5': 0.684, 'Precision@10': 0.564}
{'NDCG@1': 0.72, 'NDCG@3': 0.67842, 'NDCG@5': 0.64002, 'NDCG@10': 0.56481}


## Late-Interactive Biencoder
Follow example here: https://github.com/lightonai/pylate?tab=readme-ov-file#retrieve
Set up model and index. Then we can add documents to the index using their embeddings and corresponding ids:

In [102]:
from pylate import indexes, models, retrieve

model = models.ColBERT(
    model_name_or_path="lightonai/colbertv2.0",
)

index = indexes.Voyager(
    index_folder="pylate-index",
    index_name="index",
    override=True,
)

# 
retriever = retrieve.ColBERT(index=index)






In [103]:

# Encode the documents
documents_embeddings = model.encode(
    list(docid_2_combined_text.values()),
    batch_size=32,
    is_query=False, # Encoding documents
    show_progress_bar=True,
)

# Add the documents ids and embeddings to the Voyager index
index.add_documents(
    documents_ids=list(docid_2_combined_text.keys()),
    documents_embeddings=documents_embeddings,
)

Encoding documents (bs=32):   0%|          | 0/5355 [00:00<?, ?it/s]

Adding documents to the index (bs=2000): 100%|██████████| 86/86 [47:54<00:00, 33.43s/it]


<pylate.indexes.voyager.Voyager at 0x7f3d1ce3f990>

In [105]:
# query encoding
queries_embeddings = model.encode(
    list(qid_2_query.values()),
    batch_size=32,
    is_query=True, # Encoding queries
    show_progress_bar=True,
)


Encoding queries (bs=32):   0%|          | 0/2 [00:00<?, ?it/s]

In [106]:
scores = retriever.retrieve(
    queries_embeddings=queries_embeddings, 
    k=10,
)

Retrieving documents (bs=50):  50%|█████     | 1/2 [03:29<03:29, 209.70s/it]


In [108]:
colbert_retrieval_results = collections.defaultdict(dict)
for qid, doc_scores in zip(qid_2_query.keys(), scores):
    for result in doc_scores:    
        colbert_retrieval_results[qid][result['id']] = result['score']


In [109]:
print(recall(qrels=qrels_for_eval, results=colbert_retrieval_results, k_values=[1, 3, 5, 10]))
print(precision(qrels=qrels_for_eval, results=colbert_retrieval_results, k_values=[1, 3, 5, 10]))
print(ndcg(qrels=qrels_for_eval, results=colbert_retrieval_results, k_values=[1, 3, 5, 10]))

# {'Recall@1': 0.00201, 'Recall@3': 0.00639, 'Recall@5': 0.01011, 'Recall@10': 0.01958}
# {'Precision@1': 0.86, 'Precision@3': 0.86, 'Precision@5': 0.816, 'Precision@10': 0.79}
# {'NDCG@1': 0.81, 'NDCG@3': 0.79041, 'NDCG@5': 0.76035, 'NDCG@10': 0.73834}

{'Recall@1': 0.00201, 'Recall@3': 0.00639, 'Recall@5': 0.01011, 'Recall@10': 0.01958}
{'Precision@1': 0.86, 'Precision@3': 0.86, 'Precision@5': 0.816, 'Precision@10': 0.79}
{'NDCG@1': 0.81, 'NDCG@3': 0.79041, 'NDCG@5': 0.76035, 'NDCG@10': 0.73834}
