# Train

In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"]= "6"

In [2]:
max_seq_length = 512
model_name = "distilbert-base-uncased" 
dataset = "msmarco_tiny"

dataset_path = "../beir/datasets/msmarco_tiny/"
corpus_file = "tiny_collection.json"
queries_file = "topics.dl20.txt"
qrels_test_file = "qrels.dl20-passage.txt"
training_set = "msmarco_triples.train.tiny.tsv"

In [3]:

from sentence_transformers import losses, models, SentenceTransformer
from beir import util, LoggingHandler
from beir.datasets.data_loader import GenericDataLoader
from beir.retrieval.search.lexical import BM25Search as BM25
from beir.retrieval.evaluation import EvaluateRetrieval
from beir.retrieval.train import TrainRetriever
import pathlib, os, tqdm
import logging

#### Just some code to print debug information to stdout
logging.basicConfig(format='%(asctime)s - %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S',
                    level=logging.INFO,
                    handlers=[LoggingHandler()])
#### /print debug information to stdout

  from tqdm.autonotebook import tqdm, trange


In [4]:
# Load repLlama scores

import pickle
import torch


with open(f"{dataset_path}tiny_collection_llamaEmbed.pickle", 'rb') as f:
    doc_embeddings = pickle.load(f)

with open(f"{dataset_path}queries_llamaEmbed.pickle", 'rb') as f:
    query_embeddings = pickle.load(f)


from tqdm import tqdm

results_dense = {}
for q_id, q_embed in tqdm(query_embeddings.items()):
    results_dense[q_id] = {}
    for d_id, d_embed in doc_embeddings.items():
        # compute similarity score
        score = torch.dot(q_embed, d_embed)
        results_dense[q_id][d_id] = score.item() #.item() to get value out of tensor


100%|██████████| 54/54 [11:06<00:00, 12.34s/it]


In [7]:
# with open(f"{dataset_path}{dataset}_repLlama_scores.pickle", 'wb') as f:
#     pickle.dump(results_dense, f, protocol=pickle.HIGHEST_PROTOCOL)

In [44]:

with open(f"{dataset_path}{dataset}_repLlama_scores.pickle", 'rb') as f:
    results_dense = pickle.load(f)


# Evaluate-bm25

In [45]:
# Load BM25 scores
import pickle 

with open(f"{dataset_path}{dataset}_bm25_scores.pickle", 'rb') as f:
    results_bm25 = pickle.load(f)

# Ensemble

In [46]:
def get_maxmin(results):
    max_score = -1
    min_score = 999999
    for q_id, q in results.items():
        for doc_id, score in q.items():
            max_score = max(score, max_score)
            min_score = min(score, min_score)

    return min_score, max_score

# Get range to normalize both
min_distilbert_score, max_distilbert_score = get_maxmin(results_dense)
min_bm25_score, max_bm25_score = get_maxmin(results_bm25)

min_distilbert_score, max_distilbert_score, min_bm25_score, max_bm25_score

(0.5468878746032715, 0.9585447907447815, 4.215731, 49.542587)

In [47]:
# Normalize
def normalize_results(results, min_score, max_score):
    for q_id, q in results.items():
        for doc_id, score in q.items():
            results[q_id][doc_id] = (score-min_score)/(max_score-min_score)

    return results

results = normalize_results(results_dense, min_distilbert_score, max_distilbert_score)
results_bm25 = normalize_results(results_bm25, min_bm25_score, max_bm25_score)
# results

In [48]:
# results_bm25

In [49]:
def ensemble_score(x,y):
    mu = 0.8
    return mu*x + (1-mu)*y

combined_result = {}

for q_id_1, q_1 in results.items():
        combined_result[q_id_1] = {}
        for doc_id_1, score_1 in q_1.items():
            
            score_2 = 0
            if results_bm25[q_id_1].get(doc_id_1,None)!=None:
                score_2 = results_bm25[q_id_1][doc_id_1]
                del results_bm25[q_id_1][doc_id_1] # So that same query-doc pair is not added to combined result twice
            
            combined_score = ensemble_score(score_1, score_2)
            combined_result[q_id_1][doc_id_1] = combined_score


# Now add remaining bm25 results in combined dict
for q_id_2, q_2 in results_bm25.items():
    for doc_id_2, score_2 in q_2.items():
         score_1 = 0
         combined_score = ensemble_score(score_1, score_2)
         combined_result[q_id_1][doc_id_1] = combined_score

In [25]:
from beir.retrieval.search.lexical import BM25Search as BM25
from beir.retrieval.evaluation import EvaluateRetrieval

## elasticsearch settings
hostname = "localhost" #localhost
index_name = dataset+'_1' # scifact
initialize = True # True - Delete existing index and re-index all documents from scratch 

model_bm25 = BM25(index_name=index_name, hostname=hostname, initialize=initialize)
retriever_bm25 = EvaluateRetrieval(model_bm25)

2024-06-02 21:50:17 - Activating Elasticsearch....
2024-06-02 21:50:17 - Elastic Search Credentials: {'hostname': 'localhost', 'index_name': 'msmarco_tiny_1', 'keys': {'title': 'title', 'body': 'txt'}, 'timeout': 100, 'retry_on_timeout': True, 'maxsize': 24, 'number_of_shards': 'default', 'language': 'english'}
2024-06-02 21:50:17 - Deleting previous Elasticsearch-Index named - msmarco_tiny_1
2024-06-02 21:50:17 - Unable to create Index in Elastic Search. Reason: ConnectionError(('Connection aborted.', BadStatusLine('ÿ\x00\x00\x00\x00\x00\x00\x00\x01\x7fg\x00 identity\r\n'))) caused by: ProtocolError(('Connection aborted.', BadStatusLine('ÿ\x00\x00\x00\x00\x00\x00\x00\x01\x7fg\x00 identity\r\n')))
2024-06-02 21:50:19 - Creating fresh Elasticsearch-Index named - msmarco_tiny_1
2024-06-02 21:50:19 - Unable to create Index in Elastic Search. Reason: ConnectionError(('Connection aborted.', BadStatusLine('ÿ\x00\x00\x00\x00\x00\x00\x00\x01\x7fn\x00ent-Length: 117\r\n'))) caused by: ProtocolE

In [50]:
import collections
import pytrec_eval
import json

def load_qrels(path):
    with open(path, 'r') as f_qrel:
        qrels = pytrec_eval.parse_qrel(f_qrel)

    return qrels

qrels = load_qrels(f"{dataset_path}{qrels_test_file}")

In [51]:
ndcg, _map, recall, precision = retriever_bm25.evaluate(qrels, combined_result, retriever_bm25.k_values)

2024-06-02 21:54:09 - For evaluation, we ignore identical query and document ids (default), please explicitly set ``ignore_identical_ids=False`` to ignore this.
2024-06-02 21:54:31 - 

2024-06-02 21:54:31 - NDCG@1: 0.7623
2024-06-02 21:54:31 - NDCG@3: 0.7373
2024-06-02 21:54:31 - NDCG@5: 0.7066
2024-06-02 21:54:31 - NDCG@10: 0.6922
2024-06-02 21:54:31 - NDCG@100: 0.7147
2024-06-02 21:54:31 - NDCG@1000: 0.8100
2024-06-02 21:54:31 - 

2024-06-02 21:54:31 - MAP@1: 0.0396
2024-06-02 21:54:31 - MAP@3: 0.0864
2024-06-02 21:54:31 - MAP@5: 0.1211
2024-06-02 21:54:31 - MAP@10: 0.2026
2024-06-02 21:54:31 - MAP@100: 0.5245
2024-06-02 21:54:31 - MAP@1000: 0.6027
2024-06-02 21:54:31 - 

2024-06-02 21:54:31 - Recall@1: 0.0396
2024-06-02 21:54:31 - Recall@3: 0.0872
2024-06-02 21:54:31 - Recall@5: 0.1255
2024-06-02 21:54:31 - Recall@10: 0.2351
2024-06-02 21:54:31 - Recall@100: 0.7199
2024-06-02 21:54:31 - Recall@1000: 0.9492
2024-06-02 21:54:31 - 

2024-06-02 21:54:31 - P@1: 0.8889
2024-06-02 21:54:31