In [1]:
from datasets import load_dataset
import pandas as pd
from ir_eval.metrics import recall, precision, hole, ndcg
from sentence_transformers import SentenceTransformer, util, CrossEncoder
from ir_eval.utils_prompt import load_prompt_text, eval_prompt, preprocess_prompt
import collections
import os
import json
%load_ext autoreload
%autoreload 2


## Load trec covid dataset
https://paperswithcode.com/dataset/trec-covid

In [2]:
corpus = load_dataset("BeIR/trec-covid", 'corpus')['corpus']
queries = load_dataset("BeIR/trec-covid", 'queries')['queries']
qrels = load_dataset("BeIR/trec-covid-qrels")


In [3]:
def combine_text(example):
    example['full_text'] = '[Title] ' + example['title'] + ' [TEXT] ' + example['text']
    return example
corpus = corpus.map(combine_text)

In [4]:
qrels_for_eval = collections.defaultdict(dict)
for example in qrels['test']:
    qrels_for_eval[str(example['query-id'])][str(example['corpus-id'])] = example['score']

## Load model

In [5]:
model_name = 'sentence-transformers/all-mpnet-base-v2'
model = SentenceTransformer(model_name)


modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

README.md:   0%|          | 0.00/10.6k [00:00<?, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/571 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/438M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/363 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/239 [00:00<?, ?B/s]

1_Pooling%2Fconfig.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

In [11]:
def get_retrieval_results(query_embeddings, corpus_embeddings, top_k=10):
    hits = util.semantic_search(query_embeddings, corpus_embeddings, top_k=top_k)
    retrieval_results = collections.defaultdict(dict)
    doc_id_map = corpus['_id']
    query_id_map = queries['_id']
    for qid, doc_score_list in enumerate(hits):
        qid_key = query_id_map[qid]
        result_dict = {}
        doc_ids = [doc_score['corpus_id'] for doc_score in doc_score_list]
        scores = [doc_score['score'] for doc_score in doc_score_list]
        
        doc_id_keys = list(map(lambda x: doc_id_map[x], doc_ids))
        result_dict = dict(zip(doc_id_keys, scores))
        retrieval_results[qid_key] = result_dict
    return retrieval_results

In [7]:
def eval_model(retrieval_results, qrels_for_eval):
    
    print(recall(qrels=qrels_for_eval, results=retrieval_results, k_values=[1, 3, 5, 10, 20, 30, 100, 500, 2000]))
    print(precision(qrels=qrels_for_eval, results=retrieval_results, k_values=[1, 3, 5, 10, 20, 30, 100, 500, 2000]))
    print(ndcg(qrels=qrels_for_eval, results=retrieval_results, k_values=[1, 3, 5, 10]))

In [None]:
corpus_embeddings = model.encode(corpus['full_text'], convert_to_tensor=True, show_progress_bar=True)
query_embeddings = model.encode(queries['text'], convert_to_tensor=True, show_progress_bar=True)



Batches:   0%|          | 0/5355 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

AttributeError: 'NoneType' object has no attribute 'items'

In [12]:
retrieval_results = get_retrieval_results(query_embeddings, corpus_embeddings, top_k=10)
eval_model(retrieval_results, qrels_for_eval)

{'Recall@1': 0.00179, 'Recall@3': 0.00488, 'Recall@5': 0.0079, 'Recall@10': 0.01548, 'Recall@20': 0.01548, 'Recall@30': 0.01548, 'Recall@100': 0.01548, 'Recall@500': 0.01548, 'Recall@2000': 0.01548}
{'Precision@1': 0.6, 'Precision@3': 0.6, 'Precision@5': 0.592, 'Precision@10': 0.588, 'Precision@20': 0.294, 'Precision@30': 0.196, 'Precision@100': 0.0588, 'Precision@500': 0.01176, 'Precision@2000': 0.00294}
{'NDCG@1': 0.54, 'NDCG@3': 0.54842, 'NDCG@5': 0.54485, 'NDCG@10': 0.53753}


## Model performance record

### sentence-transformers/all-mpnet-base-v2
{'Recall@1': 0.00179, 'Recall@3': 0.00488, 'Recall@5': 0.0079, 'Recall@10': 0.01548, 'Recall@20': 0.01548, 'Recall@30': 0.01548, 'Recall@100': 0.01548, 'Recall@500': 0.01548, 'Recall@2000': 0.01548}
{'Precision@1': 0.6, 'Precision@3': 0.6, 'Precision@5': 0.592, 'Precision@10': 0.588, 'Precision@20': 0.294, 'Precision@30': 0.196, 'Precision@100': 0.0588, 'Precision@500': 0.01176, 'Precision@2000': 0.00294}
{'NDCG@1': 0.54, 'NDCG@3': 0.54842, 'NDCG@5': 0.54485, 'NDCG@10': 0.53753}