In [2]:
from datasets import load_dataset
import pandas as pd
from ir_eval.metrics import recall, precision, hole, ndcg
from sentence_transformers import SentenceTransformer, util, CrossEncoder
from ir_eval.utils_prompt import load_prompt_text, eval_prompt, preprocess_prompt, tag_parser
import collections
import os
import json
import numpy as np
%load_ext autoreload
%autoreload 2


## Load trec covid dataset
https://paperswithcode.com/dataset/trec-covid

In [5]:
corpus = load_dataset("BeIR/trec-covid", 'corpus')['corpus']
queries = load_dataset("BeIR/trec-covid", 'queries')['queries']
qrels = load_dataset("BeIR/trec-covid-qrels")


In [12]:
queries

Dataset({
    features: ['_id', 'title', 'text'],
    num_rows: 50
})

In [6]:
def combine_text(example):
    example['full_text'] = '[Title] ' + example['title'] + ' [TEXT] ' + example['text']
    return example
corpus = corpus.map(combine_text)

In [7]:
qrels_for_eval = collections.defaultdict(dict)
for example in qrels['test']:
    qrels_for_eval[str(example['query-id'])][str(example['corpus-id'])] = example['score']

## Flag Embedding dense

In [8]:
from FlagEmbedding import FlagAutoModel

In [9]:
model = FlagAutoModel.from_finetuned('BAAI/bge-m3',
                                     use_fp16=True)

Fetching 30 files:   0%|          | 0/30 [00:00<?, ?it/s]

In [10]:
def get_retrieval_results(query_embeddings, corpus_embeddings, top_k=10):
    hits = util.semantic_search(query_embeddings, corpus_embeddings, top_k=top_k)
    retrieval_results = collections.defaultdict(dict)
    doc_id_map = corpus['_id']
    query_id_map = queries['_id']
    for qid, doc_score_list in enumerate(hits):
        qid_key = query_id_map[qid]
        result_dict = {}
        doc_ids = [doc_score['corpus_id'] for doc_score in doc_score_list]
        scores = [doc_score['score'] for doc_score in doc_score_list]
        
        doc_id_keys = list(map(lambda x: doc_id_map[x], doc_ids))
        result_dict = dict(zip(doc_id_keys, scores))
        retrieval_results[qid_key] = result_dict
    return retrieval_results

In [11]:
def eval_model(retrieval_results, qrels_for_eval):
    print(recall(qrels=qrels_for_eval, results=retrieval_results, k_values=[1, 3, 5, 10, 20, 30, 100, 500, 2000]))
    print(precision(qrels=qrels_for_eval, results=retrieval_results, k_values=[1, 3, 5, 10, 20, 30, 100, 500, 2000]))
    print(ndcg(qrels=qrels_for_eval, results=retrieval_results, k_values=[1, 3, 5, 10]))

In [12]:
corpus_embeddings = model.encode(corpus['full_text'], return_dense=True, return_sparse=True, return_colbert_vecs=False)
query_embeddings = model.encode(queries['text'], return_dense=True, return_sparse=True, return_colbert_vecs=False)


pre tokenize: 100%|██████████| 670/670 [00:12<00:00, 53.55it/s]
You're using a XLMRobertaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
Inference Embeddings: 100%|██████████| 670/670 [09:27<00:00,  1.18it/s]


In [33]:
dense_retrieval_results = get_retrieval_results(
    query_embeddings["dense_vecs"], corpus_embeddings["dense_vecs"], top_k=10
)
eval_model(dense_retrieval_results, qrels_for_eval)

# {'Recall@1': 0.00216, 'Recall@3': 0.00667, 'Recall@5': 0.0106, 'Recall@10': 0.01958, 'Recall@20': 0.01958, 'Recall@30': 0.01958, 'Recall@100': 0.01958, 'Recall@500': 0.01958, 'Recall@2000': 0.01958}
# {'Precision@1': 0.84, 'Precision@3': 0.84, 'Precision@5': 0.82, 'Precision@10': 0.778, 'Precision@20': 0.389, 'Precision@30': 0.25933, 'Precision@100': 0.0778, 'Precision@500': 0.01556, 'Precision@2000': 0.00389}
# {'NDCG@1': 0.82, 'NDCG@3': 0.80181, 'NDCG@5': 0.77883, 'NDCG@10': 0.74813}

{'Recall@1': 0.00216, 'Recall@3': 0.00667, 'Recall@5': 0.0106, 'Recall@10': 0.01958, 'Recall@20': 0.01958, 'Recall@30': 0.01958, 'Recall@100': 0.01958, 'Recall@500': 0.01958, 'Recall@2000': 0.01958}
{'Precision@1': 0.84, 'Precision@3': 0.84, 'Precision@5': 0.82, 'Precision@10': 0.778, 'Precision@20': 0.389, 'Precision@30': 0.25933, 'Precision@100': 0.0778, 'Precision@500': 0.01556, 'Precision@2000': 0.00389}
{'NDCG@1': 0.82, 'NDCG@3': 0.80181, 'NDCG@5': 0.77883, 'NDCG@10': 0.74813}


## Flag Embedding sparse

In [34]:
# compute the scores via lexical matching
lexical_scores = model.compute_lexical_matching_score(
    query_embeddings["lexical_weights"], corpus_embeddings["lexical_weights"]
)

In [None]:
def get_sparse_retrieval_results(query_lexical_weights, corpus_lexcial_weights, top_k=10):
    # note that here all corpus documents are used as candidates to compute score with the query
    # ideally, we should use an inverted index and a matching phase to reduce the number of candidates.
    
    lexical_scores = model.compute_lexical_matching_score(
        query_lexical_weights, corpus_lexcial_weights
    )
    sorted_doc_ids = np.argsort(-lexical_scores, axis=1)[:,:top_k]
    sorted_scores = lexical_scores[np.arange(sorted_doc_ids.shape[0])[:, None], sorted_doc_ids]
    raw_results = [
    [{"corpus_id": doc_id, "score": score} for doc_id, score in zip(doc_ids, scores)]
    for doc_ids, scores in zip(sorted_doc_ids, sorted_scores)
    ]   
    retrieval_results = collections.defaultdict(dict)
    doc_id_map = corpus['_id']
    query_id_map = queries['_id']
    for qid, doc_score_list in enumerate(raw_results):
        qid_key = query_id_map[qid]
        result_dict = {}
        doc_ids = [doc_score['corpus_id'] for doc_score in doc_score_list]
        scores = [doc_score['score'] for doc_score in doc_score_list]
        
        doc_id_keys = list(map(lambda x: doc_id_map[x], doc_ids))
        result_dict = dict(zip(doc_id_keys, scores))
        retrieval_results[qid_key] = result_dict
    return retrieval_results

In [36]:

sparse_retrieval_results = get_sparse_retrieval_results(
    query_embeddings["lexical_weights"], corpus_embeddings["lexical_weights"], top_k=10
)
eval_model(sparse_retrieval_results, qrels_for_eval)
# {'Recall@1': 0.00177, 'Recall@3': 0.0056, 'Recall@5': 0.00864, 'Recall@10': 0.01524, 'Recall@20': 0.01524, 'Recall@30': 0.01524, 'Recall@100': 0.01524, 'Recall@500': 0.01524, 'Recall@2000': 0.01524}
# {'Precision@1': 0.68, 'Precision@3': 0.68667, 'Precision@5': 0.64, 'Precision@10': 0.58, 'Precision@20': 0.29, 'Precision@30': 0.19333, 'Precision@100': 0.058, 'Precision@500': 0.0116, 'Precision@2000': 0.0029}
# {'NDCG@1': 0.59, 'NDCG@3': 0.5986, 'NDCG@5': 0.5708, 'NDCG@10': 0.53365}

{'Recall@1': 0.00177, 'Recall@3': 0.0056, 'Recall@5': 0.00864, 'Recall@10': 0.01524, 'Recall@20': 0.01524, 'Recall@30': 0.01524, 'Recall@100': 0.01524, 'Recall@500': 0.01524, 'Recall@2000': 0.01524}
{'Precision@1': 0.68, 'Precision@3': 0.68667, 'Precision@5': 0.64, 'Precision@10': 0.58, 'Precision@20': 0.29, 'Precision@30': 0.19333, 'Precision@100': 0.058, 'Precision@500': 0.0116, 'Precision@2000': 0.0029}
{'NDCG@1': 0.59, 'NDCG@3': 0.5986, 'NDCG@5': 0.5708, 'NDCG@10': 0.53365}


### Inspecting lexical weights

In [37]:
token_ids = list(query_embeddings['lexical_weights'][0].keys())
token_weight = query_embeddings['lexical_weights'][0].values()

In [38]:
model.convert_id_to_token(query_embeddings['lexical_weights'][0])

{'what': np.float16(0.1346),
 'is': np.float16(0.1025),
 'the': np.float16(0.1328),
 'origin': np.float16(0.298),
 'of': np.float16(0.11835),
 'CO': np.float16(0.1787),
 'VID': np.float16(0.2443),
 '-19': np.float16(0.2563)}

In [39]:
model.convert_id_to_token(query_embeddings['lexical_weights'][1])

{'how': np.float16(0.0609),
 'does': np.float16(0.03775),
 'the': np.float16(0.05612),
 'corona': np.float16(0.1761),
 'virus': np.float16(0.2336),
 'respond': np.float16(0.2029),
 'to': np.float16(0.05173),
 'changes': np.float16(0.1879),
 'in': np.float16(0.06058),
 'weather': np.float16(0.2391)}

In [40]:
query_embeddings['lexical_weights'][1]

defaultdict(int,
            {'3642': np.float16(0.0609),
             '14602': np.float16(0.03775),
             '70': np.float16(0.05612),
             '109728': np.float16(0.1761),
             '76912': np.float16(0.2336),
             '35644': np.float16(0.2029),
             '47': np.float16(0.05173),
             '65572': np.float16(0.1879),
             '23': np.float16(0.06058),
             '92949': np.float16(0.2391)})

In [41]:
# Assuming `token_weights` is the dictionary you want to sort
token_weights = model.convert_id_to_token(corpus_embeddings['lexical_weights'][1])

# Sort the dictionary by value in descending order
sorted_token_weights = dict(sorted(token_weights.items(), key=lambda item: item[1], reverse=True))

# Print the sorted dictionary
print(sorted_token_weights)



{'NO': np.float16(0.2354), '•': np.float16(0.2316), 'inflammation': np.float16(0.2189), 'tric': np.float16(0.2039), 'lung': np.float16(0.2026), 'media': np.float16(0.2024), 'dependent': np.float16(0.192), 'respirator': np.float16(0.1849), 'oxid': np.float16(0.1842), 'flam': np.float16(0.1842), 'tract': np.float16(0.176), 'disease': np.float16(0.1725), 'phil': np.float16(0.1674), 'injury': np.float16(0.1669), 'rite': np.float16(0.1613), 'ni': np.float16(0.1595), 'associated': np.float16(0.1565), 'Ni': np.float16(0.1549), 'stress': np.float16(0.1522), 'nit': np.float16(0.149), 'contribution': np.float16(0.1477), 'pro': np.float16(0.1472), 'eleva': np.float16(0.1465), 'models': np.float16(0.1448), 'mechanism': np.float16(0.1442), 'infla': np.float16(0.1442), 'ma': np.float16(0.1381), 'production': np.float16(0.1366), 'oxy': np.float16(0.1357), 'protein': np.float16(0.1345), 'support': np.float16(0.1311), 'tor': np.float16(0.1283), 'evidence': np.float16(0.1272), 'mye': np.float16(0.1249),

In [42]:
corpus[1]

{'_id': '02tnwd4m',
 'title': 'Nitric oxide: a pro-inflammatory mediator in lung disease?',
 'text': 'Inflammatory diseases of the respiratory tract are commonly associated with elevated production of nitric oxide (NO•) and increased indices of NO• -dependent oxidative stress. Although NO• is known to have anti-microbial, anti-inflammatory and anti-oxidant properties, various lines of evidence support the contribution of NO• to lung injury in several disease models. On the basis of biochemical evidence, it is often presumed that such NO• -dependent oxidations are due to the formation of the oxidant peroxynitrite, although alternative mechanisms involving the phagocyte-derived heme proteins myeloperoxidase and eosinophil peroxidase might be operative during conditions of inflammation. Because of the overwhelming literature on NO• generation and activities in the respiratory tract, it would be beyond the scope of this commentary to review this area comprehensively. Instead, it focuses on

### Built an inverted index

In [43]:
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-m3')

In [44]:
tokenizer.decode([76912])

'virus'

In [45]:
tokenizer.vocab_size

250002

In [46]:
import collections
index = collections.defaultdict(list)

In [47]:
for docid, doc_terms in enumerate(corpus_embeddings['lexical_weights']):
    for term_id, _ in doc_terms.items():
        index[term_id].append(docid)

### Speed up sparse retrieval using inverted index

In [48]:
def _get_sorted_retrieved_documents(query_lexical_weights, index, corpus_lexcial_weights):
    
    retrieved_doc_per_query_list = []
    for query_term_dict in query_lexical_weights:
        retrieved_doc_per_query = collections.defaultdict(float)
        for query_term, query_term_weight in query_term_dict.items():
            matched_doc_list = index[query_term]
            for doc_id in matched_doc_list:
                doc_term_weights = corpus_lexcial_weights[doc_id]
                retrieved_doc_per_query[doc_id] += query_term_weight * doc_term_weights.get(query_term, 0.0)
        sorted_retrieved_doc_per_query = sorted(retrieved_doc_per_query.items(), key=lambda item: item[1], reverse=True)
        retrieved_doc_per_query_list.append(sorted_retrieved_doc_per_query)
    return retrieved_doc_per_query_list

In [49]:
def get_sparse_retrieval_results_using_inverted_index(query_lexical_weights, corpus_lexcial_weights, index, top_k=10):
    
    retrieved_doc_per_query_list = _get_sorted_retrieved_documents(query_lexical_weights, index, corpus_lexcial_weights)
    raw_results = [
    [{"corpus_id": doc_id, "score": score} for doc_id, score in retrieved_doc_per_query[:top_k]]
    for retrieved_doc_per_query in retrieved_doc_per_query_list
    ]   
    retrieval_results = collections.defaultdict(dict)
    doc_id_map = corpus['_id']
    query_id_map = queries['_id']
    for qid, doc_score_list in enumerate(raw_results):
        qid_key = query_id_map[qid]
        result_dict = {}
        doc_ids = [doc_score['corpus_id'] for doc_score in doc_score_list]
        scores = [doc_score['score'] for doc_score in doc_score_list]
        
        doc_id_keys = list(map(lambda x: doc_id_map[x], doc_ids))
        result_dict = dict(zip(doc_id_keys, scores))
        retrieval_results[qid_key] = result_dict
    return retrieval_results

In [50]:
retrieval_results = get_sparse_retrieval_results_using_inverted_index(
    query_embeddings["lexical_weights"], corpus_embeddings["lexical_weights"], index, top_k=10
)
eval_model(retrieval_results, qrels_for_eval)

# {'Recall@1': 0.00177, 'Recall@3': 0.00555, 'Recall@5': 0.00861, 'Recall@10': 0.01524, 'Recall@20': 0.01524, 'Recall@30': 0.01524, 'Recall@100': 0.01524, 'Recall@500': 0.01524, 'Recall@2000': 0.01524}
# {'Precision@1': 0.68, 'Precision@3': 0.68, 'Precision@5': 0.64, 'Precision@10': 0.58, 'Precision@20': 0.29, 'Precision@30': 0.19333, 'Precision@100': 0.058, 'Precision@500': 0.0116, 'Precision@2000': 0.0029}
# {'NDCG@1': 0.59, 'NDCG@3': 0.5938, 'NDCG@5': 0.57128, 'NDCG@10': 0.53323}

{'Recall@1': 0.00177, 'Recall@3': 0.00555, 'Recall@5': 0.00861, 'Recall@10': 0.01524, 'Recall@20': 0.01524, 'Recall@30': 0.01524, 'Recall@100': 0.01524, 'Recall@500': 0.01524, 'Recall@2000': 0.01524}
{'Precision@1': 0.68, 'Precision@3': 0.68, 'Precision@5': 0.64, 'Precision@10': 0.58, 'Precision@20': 0.29, 'Precision@30': 0.19333, 'Precision@100': 0.058, 'Precision@500': 0.0116, 'Precision@2000': 0.0029}
{'NDCG@1': 0.59, 'NDCG@3': 0.5938, 'NDCG@5': 0.57128, 'NDCG@10': 0.53323}


## Flag Embedding hybrid

In [51]:
def get_hybrid_retrieval_results(query_embeddings, corpus_embeddings, top_k=10):
    
    # compute the scores via lexical mathcing
    dense_scores = query_embeddings["dense_vecs"] @ corpus_embeddings["dense_vecs"].T
    lexical_scores = model.compute_lexical_matching_score(
        query_embeddings["lexical_weights"], corpus_embeddings["lexical_weights"]
    )
    hybrid_scores = dense_scores * 0.6 + lexical_scores * 0.4
    sorted_doc_ids = np.argsort(-hybrid_scores, axis=1)[:,:top_k]
    sorted_scores = hybrid_scores[np.arange(sorted_doc_ids.shape[0])[:, None], sorted_doc_ids]
    raw_results = [
    [{"corpus_id": doc_id, "score": score} for doc_id, score in zip(doc_ids, scores)]
    for doc_ids, scores in zip(sorted_doc_ids, sorted_scores)
    ]   
    retrieval_results = collections.defaultdict(dict)
    doc_id_map = corpus['_id']
    query_id_map = queries['_id']
    for qid, doc_score_list in enumerate(raw_results):
        qid_key = query_id_map[qid]
        result_dict = {}
        doc_ids = [doc_score['corpus_id'] for doc_score in doc_score_list]
        scores = [doc_score['score'] for doc_score in doc_score_list]
        
        doc_id_keys = list(map(lambda x: doc_id_map[x], doc_ids))
        result_dict = dict(zip(doc_id_keys, scores))
        retrieval_results[qid_key] = result_dict
    return retrieval_results

In [52]:
hybrid_retrieval_results = get_hybrid_retrieval_results(
    query_embeddings, corpus_embeddings, top_k=10
)
eval_model(hybrid_retrieval_results, qrels_for_eval)

# {'Recall@1': 0.00235, 'Recall@3': 0.00717, 'Recall@5': 0.01144, 'Recall@10': 0.02229, 'Recall@20': 0.02229, 'Recall@30': 0.02229, 'Recall@100': 0.02229, 'Recall@500': 0.02229, 'Recall@2000': 0.02229}
# {'Precision@1': 0.92, 'Precision@3': 0.90667, 'Precision@5': 0.872, 'Precision@10': 0.852, 'Precision@20': 0.426, 'Precision@30': 0.284, 'Precision@100': 0.0852, 'Precision@500': 0.01704, 'Precision@2000': 0.00426}
# {'NDCG@1': 0.86, 'NDCG@3': 0.8552, 'NDCG@5': 0.81812, 'NDCG@10': 0.79824}

{'Recall@1': 0.00235, 'Recall@3': 0.00717, 'Recall@5': 0.01144, 'Recall@10': 0.02229, 'Recall@20': 0.02229, 'Recall@30': 0.02229, 'Recall@100': 0.02229, 'Recall@500': 0.02229, 'Recall@2000': 0.02229}
{'Precision@1': 0.92, 'Precision@3': 0.90667, 'Precision@5': 0.872, 'Precision@10': 0.852, 'Precision@20': 0.426, 'Precision@30': 0.284, 'Precision@100': 0.0852, 'Precision@500': 0.01704, 'Precision@2000': 0.00426}
{'NDCG@1': 0.86, 'NDCG@3': 0.8552, 'NDCG@5': 0.81812, 'NDCG@10': 0.79824}


## Enrich documents with queries

In [53]:
qid_2_query = dict(zip(queries['_id'], queries['text']))
docid_2_title = dict(zip(corpus['_id'], corpus['title']))
docid_2_text = dict(zip(corpus['_id'], corpus['text']))
docid_2_combined_text = dict(zip(corpus['_id'], [title + " " + text for title, text in zip(corpus['title'], corpus['text'])]))

In [54]:
dense_retrieval_results_top100 = get_retrieval_results(
    query_embeddings["dense_vecs"], corpus_embeddings["dense_vecs"], top_k=100
)

In [55]:
enrich_doc_ids = set()
for qid, doc_score_dict in dense_retrieval_results_top100.items():
    for docid, _ in doc_score_dict.items():
        enrich_doc_ids.add(docid)

In [56]:
all_input_dicts = []

for docid in enrich_doc_ids:
    title = docid_2_title[docid]
    text = docid_2_text[docid]        
    record = {
                'docid': docid,
                'title': title,
                'text': text
                }
    all_input_dicts.append(record)

print(len(all_input_dicts))

4485


In [57]:
prompt_template_path = "/mnt/d/Dropbox/llm_book/repos/ir_eval/prompts/doc_to_query.jinja"

In [67]:
def read_json(file_path):
    data = []
    with open(file_path, 'r', encoding="utf-8") as f:
        data = json.load(f)
    return data

In [68]:

def llm_inference(prompt_template_path, all_input_dicts, model, output_path):

    prompt_template_text = load_prompt_text(prompt_template_path)
    count = 0
    for input_dict in all_input_dicts:
        print(count)
        count += 1
        prompt_info_dict = preprocess_prompt(prompt_template_text, input_dict)
        eval_results = eval_prompt(prompt_info_dict, model=model)
        input_dict.update({f"eval_result_{model}": eval_results})
        #break
    
    with open(output_path,"w") as json_file:
        json.dump(all_input_dicts, json_file, indent=4)
    return all_input_dicts
    


In [71]:
output_path="./data/trec_covoid_doc_2_query.jsonl"
if os.path.exists(output_path):
    print(f"output_path {output_path} exists, return the existing jsonl")
    enriched = read_json(output_path)
else:
    enriched = llm_inference(prompt_template_path, all_input_dicts, "gpt-4o-mini", output_path)

output_path ./data/trec_covoid_doc_2_query.jsonl exists, return the existing jsonl


In [72]:
docid_2_enriched = {}
for doc_dict in enriched:
    docid = doc_dict['docid']
    enrichment_raw = doc_dict['eval_result_gpt-4o-mini']
    enrichment = ' '.join(tag_parser(enrichment_raw, "Q"))
    docid_2_enriched[docid] = enrichment

In [None]:
docid_2_enriched

{'k36e2sob': 'What are the key infection prevention guidelines for reopening primary schools during the COVID-19 pandemic in Norway? What specific considerations are outlined for paediatric risk groups in the context of school reopening during COVID-19? In what month and year were the infection prevention guidelines for reopening primary schools published in Norway? What is the primary focus of the document regarding the impact of COVID-19 on children in schools? How does the document address the challenges faced by paediatric risk groups during the pandemic in relation to school environments?',
 '6msznh9u': 'What associations were explored between mental health conditions and lifestyle changes during COVID-19 in Australian adults? How did physical activity levels change among Australian adults during the COVID-19 pandemic? What impact did COVID-19 have on sleep patterns in the Australian adult population? In what ways did tobacco and alcohol use fluctuate among Australian adults durin

In [73]:
def combine_enriched_text(example):
    example['full_text'] = '[Title] '  + example['title'] + ' [QUERY] '+ docid_2_enriched.get(example['_id'],"") + ' [TEXT] ' + example['text']  
    return example
enriched_corpus = corpus.map(combine_enriched_text)

In [None]:
enriched_corpus[0]

{'_id': 'ug7v899j',
 'title': 'Clinical features of culture-proven Mycoplasma pneumoniae infections at King Abdulaziz University Hospital, Jeddah, Saudi Arabia',
 'text': 'OBJECTIVE: This retrospective chart review describes the epidemiology and clinical features of 40 patients with culture-proven Mycoplasma pneumoniae infections at King Abdulaziz University Hospital, Jeddah, Saudi Arabia. METHODS: Patients with positive M. pneumoniae cultures from respiratory specimens from January 1997 through December 1998 were identified through the Microbiology records. Charts of patients were reviewed. RESULTS: 40 patients were identified, 33 (82.5%) of whom required admission. Most infections (92.5%) were community-acquired. The infection affected all age groups but was most common in infants (32.5%) and pre-school children (22.5%). It occurred year-round but was most common in the fall (35%) and spring (30%). More than three-quarters of patients (77.5%) had comorbidities. Twenty-four isolates (

In [None]:
enriched_corpus_embeddings = model.encode(enriched_corpus['full_text'], return_dense=True, return_sparse=True, return_colbert_vecs=False)


pre tokenize: 100%|██████████| 670/670 [00:14<00:00, 46.56it/s]
Inference Embeddings: 100%|██████████| 670/670 [09:24<00:00,  1.19it/s]


In [75]:
enriched_only_corpus_embeddings = model.encode(list(docid_2_enriched.values()), return_dense=True, return_sparse=True, return_colbert_vecs=False)

pre tokenize: 100%|██████████| 18/18 [00:00<00:00, 104.79it/s]
Inference Embeddings: 100%|██████████| 18/18 [00:06<00:00,  2.61it/s]


In [None]:
sparse_retrieval_results = get_sparse_retrieval_results(
    query_embeddings["lexical_weights"], corpus_embeddings["lexical_weights"], top_k=100
)
eval_model(sparse_retrieval_results, qrels_for_eval)

{'Recall@1': 0.00177, 'Recall@3': 0.0056, 'Recall@5': 0.00864, 'Recall@10': 0.01524, 'Recall@20': 0.02755, 'Recall@30': 0.03794, 'Recall@100': 0.0905, 'Recall@500': 0.0905, 'Recall@2000': 0.0905}
{'Precision@1': 0.68, 'Precision@3': 0.68667, 'Precision@5': 0.64, 'Precision@10': 0.58, 'Precision@20': 0.55, 'Precision@30': 0.52333, 'Precision@100': 0.3984, 'Precision@500': 0.07968, 'Precision@2000': 0.01992}
{'NDCG@1': 0.59, 'NDCG@3': 0.5986, 'NDCG@5': 0.5708, 'NDCG@10': 0.53365}


In [None]:
sparse_retrieval_results = get_sparse_retrieval_results(
    query_embeddings["lexical_weights"], enriched_corpus_embeddings["lexical_weights"], top_k=100
)
eval_model(sparse_retrieval_results, qrels_for_eval)

{'Recall@1': 0.00149, 'Recall@3': 0.00443, 'Recall@5': 0.00752, 'Recall@10': 0.01422, 'Recall@20': 0.02654, 'Recall@30': 0.03701, 'Recall@100': 0.08941, 'Recall@500': 0.08941, 'Recall@2000': 0.08941}
{'Precision@1': 0.6, 'Precision@3': 0.56667, 'Precision@5': 0.576, 'Precision@10': 0.546, 'Precision@20': 0.523, 'Precision@30': 0.48733, 'Precision@100': 0.3874, 'Precision@500': 0.07748, 'Precision@2000': 0.01937}
{'NDCG@1': 0.54, 'NDCG@3': 0.49, 'NDCG@5': 0.49846, 'NDCG@10': 0.4852}


In [76]:
def get_sparse_retrieval_results_with_id_map(query_lexical_weights, corpus_lexcial_weights, query_id_map, doc_id_map, top_k=10):
    # note that here all corpus documents are used as candidates to compute score with the query
    # ideally, we should use an inverted index and a matching phase to reduce the number of candidates.
    
    lexical_scores = model.compute_lexical_matching_score(
        query_lexical_weights, corpus_lexcial_weights
    )
    sorted_doc_ids = np.argsort(-lexical_scores, axis=1)[:,:top_k]
    sorted_scores = lexical_scores[np.arange(sorted_doc_ids.shape[0])[:, None], sorted_doc_ids]
    raw_results = [
    [{"corpus_id": doc_id, "score": score} for doc_id, score in zip(doc_ids, scores)]
    for doc_ids, scores in zip(sorted_doc_ids, sorted_scores)
    ]   
    retrieval_results = collections.defaultdict(dict)
    for qid, doc_score_list in enumerate(raw_results):
        qid_key = query_id_map[qid]
        result_dict = {}
        doc_ids = [doc_score['corpus_id'] for doc_score in doc_score_list]
        scores = [doc_score['score'] for doc_score in doc_score_list]
        
        doc_id_keys = list(map(lambda x: doc_id_map[x], doc_ids))
        result_dict = dict(zip(doc_id_keys, scores))
        retrieval_results[qid_key] = result_dict
    return retrieval_results

In [77]:
query_id_map = queries["_id"]
doc_id_map = list(docid_2_enriched.keys())
enriched_only_sparse_retrieval_results = get_sparse_retrieval_results_with_id_map(
    query_embeddings["lexical_weights"],
    enriched_only_corpus_embeddings["lexical_weights"],
    query_id_map,
    doc_id_map,
    top_k=100
)

In [78]:
eval_model(enriched_only_sparse_retrieval_results, qrels_for_eval)

{'Recall@1': 0.00206, 'Recall@3': 0.00589, 'Recall@5': 0.00909, 'Recall@10': 0.01634, 'Recall@20': 0.02723, 'Recall@30': 0.036, 'Recall@100': 0.07542, 'Recall@500': 0.07542, 'Recall@2000': 0.07542}
{'Precision@1': 0.78, 'Precision@3': 0.73333, 'Precision@5': 0.684, 'Precision@10': 0.614, 'Precision@20': 0.55, 'Precision@30': 0.50067, 'Precision@100': 0.3328, 'Precision@500': 0.06656, 'Precision@2000': 0.01664}
{'NDCG@1': 0.71, 'NDCG@3': 0.67163, 'NDCG@5': 0.64683, 'NDCG@10': 0.59709}


In [None]:
dense_retrieval_results = get_retrieval_results(
    query_embeddings["dense_vecs"], corpus_embeddings["dense_vecs"], top_k=100
)
eval_model(dense_retrieval_results, qrels_for_eval)

{'Recall@1': 0.00216, 'Recall@3': 0.00667, 'Recall@5': 0.0106, 'Recall@10': 0.0195, 'Recall@20': 0.03764, 'Recall@30': 0.05296, 'Recall@100': 0.12817, 'Recall@500': 0.12817, 'Recall@2000': 0.12817}
{'Precision@1': 0.84, 'Precision@3': 0.84, 'Precision@5': 0.82, 'Precision@10': 0.776, 'Precision@20': 0.758, 'Precision@30': 0.71867, 'Precision@100': 0.5574, 'Precision@500': 0.11148, 'Precision@2000': 0.02787}
{'NDCG@1': 0.81, 'NDCG@3': 0.80181, 'NDCG@5': 0.77762, 'NDCG@10': 0.74684}


In [None]:
dense_retrieval_results = get_retrieval_results(
    query_embeddings["dense_vecs"], enriched_corpus_embeddings["dense_vecs"], top_k=100
)
eval_model(dense_retrieval_results, qrels_for_eval)

{'Recall@1': 0.00231, 'Recall@3': 0.00674, 'Recall@5': 0.01122, 'Recall@10': 0.02099, 'Recall@20': 0.03808, 'Recall@30': 0.05218, 'Recall@100': 0.12253, 'Recall@500': 0.12253, 'Recall@2000': 0.12253}
{'Precision@1': 0.88, 'Precision@3': 0.86, 'Precision@5': 0.864, 'Precision@10': 0.82, 'Precision@20': 0.749, 'Precision@30': 0.702, 'Precision@100': 0.5328, 'Precision@500': 0.10656, 'Precision@2000': 0.02664}
{'NDCG@1': 0.82, 'NDCG@3': 0.819, 'NDCG@5': 0.80944, 'NDCG@10': 0.78199}


## Using reranker

In [None]:
from FlagEmbedding import FlagReranker
reranker = FlagReranker('BAAI/bge-reranker-v2-m3', use_fp16=True)

In [None]:
def get_rerank_results(retrieved_results, qid_2_query, docid_2_combined_text, reranker, top_k=10):
    
    rerank_results = collections.defaultdict(dict)
    for qid, doc_score_dict in retrieved_results.items():
        print(qid)
        inputs = [[qid_2_query[qid], docid_2_combined_text[docid]] for docid in doc_score_dict.keys()]
        
        scores = reranker.compute_score(inputs)
        sorted_doc_ids = np.argsort(-np.array(scores))[:top_k]
        sorted_corpus_id = [list(doc_score_dict.keys())[i] for i in sorted_doc_ids]
        sorted_scores = [scores[i] for i in sorted_doc_ids]
        rerank_result_dict = dict(zip(sorted_corpus_id, sorted_scores))
        
        rerank_results[qid] = rerank_result_dict
    return rerank_results

In [None]:
rerank_results = get_rerank_results(dense_retrieval_results, qid_2_query, docid_2_combined_text, reranker, top_k=100)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50


In [None]:
eval_model(rerank_results, qrels_for_eval)

{'Recall@1': 0.00208, 'Recall@3': 0.00672, 'Recall@5': 0.01134, 'Recall@10': 0.02332, 'Recall@20': 0.04309, 'Recall@30': 0.06031, 'Recall@100': 0.12817, 'Recall@500': 0.12817, 'Recall@2000': 0.12817}
{'Precision@1': 0.82, 'Precision@3': 0.86667, 'Precision@5': 0.872, 'Precision@10': 0.886, 'Precision@20': 0.851, 'Precision@30': 0.80867, 'Precision@100': 0.5574, 'Precision@500': 0.11148, 'Precision@2000': 0.02787}
{'NDCG@1': 0.81, 'NDCG@3': 0.82959, 'NDCG@5': 0.82467, 'NDCG@10': 0.82245}
