In [1]:
import os
# Pin CUDA_VISIBLE_DEVICES to the hard-coded value "6" for this kernel process.
# NOTE(review): the index is machine-specific, and presumably this has to run
# before any CUDA-using library initialises for it to take effect — confirm.
os.environ["CUDA_VISIBLE_DEVICES"]= "6"

# Train

In [2]:
# Maximum sequence length handed to the transformer encoder when the model is built.
max_seq_length = 512
# Base checkpoint name; it is also embedded in the output directory name.
model_name = "distilbert-base-uncased" 
# Dataset tag: reused as the Elasticsearch index name and in artifact file names.
dataset = "msmarco_tiny"

# Directory prefix for the data files below. Paths are built by plain string
# concatenation (f"{dataset_path}{file}"), so the trailing slash is required.
dataset_path = "../beir/datasets/msmarco_tiny/"
# Passage collection, read with load_corpus_json().
corpus_file = "tiny_collection.json"
# Test queries, read with load_queries() as tab-separated "id<TAB>text" lines.
queries_file = "topics.dl20.txt"
# Test relevance judgements, parsed with pytrec_eval.parse_qrel().
qrels_test_file = "qrels.dl20-passage.txt"
# Training triples, read with load_triplets() as three tab-separated fields per line.
training_set = "msmarco_triples.train.tiny.tsv"

In [3]:

from sentence_transformers import losses, models, SentenceTransformer
from beir import util, LoggingHandler
from beir.datasets.data_loader import GenericDataLoader
from beir.retrieval.search.lexical import BM25Search as BM25
from beir.retrieval.evaluation import EvaluateRetrieval
from beir.retrieval.train import TrainRetriever
import pathlib, os, tqdm
import logging

#### Send INFO-level (and above) log records through BEIR's LoggingHandler,
#### prefixed with a "YYYY-MM-DD HH:MM:SS" timestamp.
logging.basicConfig(
    level=logging.INFO,
    handlers=[LoggingHandler()],
    format='%(asctime)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
)
#### /logging setup

  from tqdm.autonotebook import tqdm, trange


In [15]:
# Dataset directory derived from `dataset`. In the visible cells it is only
# referenced by commented-out GenericDataLoader calls; the data actually used is
# loaded through `dataset_path` and the custom load_* helpers instead.
data_path = f"../beir/datasets/{dataset}"
# corpus, queries, qrels = GenericDataLoader(data_path).load(split="train")

In [5]:
import collections
import pytrec_eval
import json

def load_triplets(path):
    """Read (query, positive passage, negative passage) triplets from a TSV file.

    The file is decoded as UTF-8. Each non-blank line is stripped of surrounding
    whitespace and must then consist of exactly three tab-separated fields.
    Blank lines are skipped instead of aborting the whole load.

    Args:
        path: Location of the tab-separated triples file.

    Returns:
        A list of ``[query, positive_passage, negative_passage]`` lists, in
        file order.

    Raises:
        ValueError: If a non-blank line does not have exactly three fields.
            The message names the file and the 1-based line number.
    """
    triplets = []
    # Decode explicitly as UTF-8 so the result does not depend on the locale.
    with open(path, encoding="utf-8") as f:
        for line_no, line in enumerate(f, start=1):
            record = line.strip()
            if not record:
                # A stray empty line (e.g. at the end of the file) is harmless.
                continue
            fields = record.split('\t')
            if len(fields) != 3:
                raise ValueError(
                    f"{path}:{line_no}: expected 3 tab-separated fields, got {len(fields)}"
                )
            triplets.append(fields)
    return triplets

def load_corpus_json(path):
    """Load the passage corpus from a JSON file and return the parsed object.

    The file is decoded explicitly as UTF-8 so that non-ASCII passage text is
    read the same way regardless of the platform's default locale encoding.

    Args:
        path: Location of the JSON corpus file.

    Returns:
        Whatever ``json.load`` produces for the file. Other cells index the
        result as ``corpus[doc_id]["text"]``.
    """
    with open(path, 'r', encoding="utf-8") as corpus_f:
        return json.load(corpus_f)


# Load the raw training triplets and the passage corpus. Both paths are built by
# concatenating `dataset_path` (which ends in "/") with the configured file name.
triplets_temp = load_triplets(f"{dataset_path}{training_set}")
corpus = load_corpus_json(f"{dataset_path}{corpus_file}")

In [6]:
#### BM25 lexical retrieval backed by Elasticsearch ####

## Elasticsearch connection / index settings
index_name = dataset      # the index is named after the dataset tag
hostname = "localhost"
number_of_shards = 1
initialize = True         # True => delete any existing index and re-index everything from scratch

model = BM25(
    hostname=hostname,
    index_name=index_name,
    number_of_shards=number_of_shards,
    initialize=initialize,
)
bm25 = EvaluateRetrieval(model)

#### Index the passages (a separate step from retrieval)
bm25.retriever.index(corpus)


2024-06-02 16:06:55 - Activating Elasticsearch....
2024-06-02 16:06:55 - Elastic Search Credentials: {'hostname': 'localhost', 'index_name': 'msmarco_tiny', 'keys': {'title': 'title', 'body': 'txt'}, 'timeout': 100, 'retry_on_timeout': True, 'maxsize': 24, 'number_of_shards': 1, 'language': 'english'}
2024-06-02 16:06:55 - Deleting previous Elasticsearch-Index named - msmarco_tiny
2024-06-02 16:06:58 - Creating fresh Elasticsearch-Index named - msmarco_tiny


  0%|          | 0/510585 [00:00<?, ?docs/s]                


In [7]:

triplets = []
hard_negatives_max = 10

#### Retrieve BM25 hard negatives => Given a positive document, find most similar lexical documents
for query_text, pos_doc_text, neg_doc_text in tqdm.tqdm(triplets_temp, desc="Retrieve Hard Negatives using BM25"):
    # One extra hit is requested: presumably the positive passage, when it is in
    # the index, comes back as its own best lexical match and has to be dropped.
    hits = bm25.retriever.es.lexical_multisearch(texts=[pos_doc_text], top_hits=hard_negatives_max+1)
    # A single text was submitted, so only the first result entry is relevant.
    for (neg_id, _) in hits[0].get("hits"):
        neg_text = corpus[neg_id]["text"]
        # Never train with the positive passage as its own negative (this check was
        # missing), and keep skipping the negative already paired in the source triple.
        if neg_text == pos_doc_text or neg_text == neg_doc_text:
            continue
        triplets.append([query_text, pos_doc_text, neg_text])


Retrieve Hard Negatives using BM25: 100%|██████████| 11000/11000 [03:54<00:00, 46.86it/s]


In [8]:
import pickle

# Persist the mined triplets so the BM25 retrieval step does not have to be repeated.
with open(f"{dataset_path}{dataset}bm25_triplets.pickle", 'wb') as triplets_out:
    pickle.dump(triplets, triplets_out, protocol=pickle.HIGHEST_PROTOCOL)

In [7]:
import pickle

# Restore previously mined triplets from disk.
# NOTE: unpickling can execute arbitrary code; only load files written by this notebook.
with open(f"{dataset_path}{dataset}bm25_triplets.pickle", 'rb') as triplets_in:
    triplets = pickle.load(triplets_in)

In [8]:
# First raw triplet as read from the training TSV: [query, positive, negative].
triplets_temp[0]

['is a little caffeine ok during pregnancy',
 'We donâ\x80\x99t know a lot about the effects of caffeine during pregnancy on you and your baby. So itâ\x80\x99s best to limit the amount you get each day. If youâ\x80\x99re pregnant, limit caffeine to 200 milligrams each day. This is about the amount in 1Â½ 8-ounce cups of coffee or one 12-ounce cup of coffee.',
 'It is generally safe for pregnant women to eat chocolate because studies have shown to prove certain benefits of eating chocolate during pregnancy. However, pregnant women should ensure their caffeine intake is below 200 mg per day.']

In [9]:
# First mined triplet: same query and positive, with a BM25-retrieved negative.
triplets[0]

['is a little caffeine ok during pregnancy',
 'We donâ\x80\x99t know a lot about the effects of caffeine during pregnancy on you and your baby. So itâ\x80\x99s best to limit the amount you get each day. If youâ\x80\x99re pregnant, limit caffeine to 200 milligrams each day. This is about the amount in 1Â½ 8-ounce cups of coffee or one 12-ounce cup of coffee.',
 'Should I limit caffeine during pregnancy? If youâ\x80\x99re pregnant, you should limit the amount of caffeine you have to 200 milligrams (mg) a day â\x80\x93 the equivalent of two mugs of instant coffee. Caffeine is found naturally in lots of foods, such as coffee, tea and chocolate.']

In [10]:
# Second mined triplet, for comparison with the first one.
triplets[1]

['is a little caffeine ok during pregnancy',
 'We donâ\x80\x99t know a lot about the effects of caffeine during pregnancy on you and your baby. So itâ\x80\x99s best to limit the amount you get each day. If youâ\x80\x99re pregnant, limit caffeine to 200 milligrams each day. This is about the amount in 1Â½ 8-ounce cups of coffee or one 12-ounce cup of coffee.',
 'Limit the amount of caffeine you get each day to 200 mg during pregnancy. Drinks and foods with caffeine incldue coffee, tea, energy drinks, soft drinks and chocolate. Limit the amount of caffeine you get each day to 200 mg during pregnancy.']

In [11]:
#### Provide any sentence-transformers or HF model
# Build the encoder from two modules: the transformer named by `model_name`
# (limited to `max_seq_length`) followed by a pooling layer sized to the
# transformer's word-embedding dimension.
# NOTE(review): `models` here must be sentence_transformers.models; the evaluation
# section re-imports `models` from beir.retrieval, so this cell fails if run after it.
word_embedding_model = models.Transformer(model_name, max_seq_length=max_seq_length)
pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension())
# Rebinds `model`, which previously held the BM25 search object.
model = SentenceTransformer(modules=[word_embedding_model, pooling_model])

#### Provide a high batch-size to train better with triplets!
# NOTE(review): batch_size is fixed at 12 here despite the advice above —
# presumably limited by GPU memory; confirm.
retriever = TrainRetriever(model=model, batch_size=12)



2024-06-02 16:20:38 - Use pytorch device_name: cuda


In [12]:

#### Turn the mined triplets into training samples and a dataloader
train_samples = retriever.load_train_triplets(triplets=triplets)
train_dataloader = retriever.prepare_train_triplets(train_samples)

#### Loss: MultipleNegativesRankingLoss over the retriever's model
train_loss = losses.MultipleNegativesRankingLoss(model=retriever.model)

#### Dev evaluator (enable when a dev split is available)
# ir_evaluator = retriever.load_ir_evaluator(dev_corpus, dev_queries, dev_qrels)

#### No dev set is loaded above, so fall back to the dummy evaluator
ir_evaluator = retriever.load_dummy_evaluator()

#### Output directory for the trained model (created if it does not exist yet)
model_save_path = os.path.join(
    os.getcwd(),
    "../output",
    "{}-v2-{}-bm25-hard-negs".format(model_name, dataset),
)
os.makedirs(model_save_path, exist_ok=True)


Adding Input Examples: 100%|██████████| 10081/10081 [00:00<00:00, 170816.04it/s]

2024-06-02 16:20:41 - Loaded 120965 training pairs.





In [13]:
#### Training hyper-parameters
num_epochs = 10
evaluation_steps = 5000
# Warm up over 10% of the total number of optimisation steps
# (samples * epochs / batch size).
warmup_steps = int(len(train_samples) * num_epochs / retriever.batch_size * 0.1)

#### Run training; the model is written to `model_save_path`
retriever.fit(
    train_objectives=[(train_dataloader, train_loss)],
    evaluator=ir_evaluator,
    evaluation_steps=evaluation_steps,
    epochs=num_epochs,
    warmup_steps=warmup_steps,
    output_path=model_save_path,
    use_amp=True,
)


2024-06-02 16:20:42 - Starting to Train...




Step,Training Loss,Validation Loss,Sequential Score
5000,0.1168,No log,1717370986.346945
10000,0.0537,No log,1717371535.171653
10080,0.0537,No log,1717371545.809072
15000,0.0236,No log,1717372080.916198
20000,0.0118,No log,1717372631.116839
20160,0.0118,No log,1717372651.581016
25000,0.0083,No log,1717373189.319853
30000,0.0072,No log,1717373748.967721
30240,0.0072,No log,1717373778.751595
35000,0.0062,No log,1717374304.727505


2024-06-02 16:29:46 - Save model to /data/addullah/253_proj/slm4search/src/../output/distilbert-base-uncased-v2-msmarco_tiny-bm25-hard-negs


                                                                             

2024-06-02 16:38:55 - Save model to /data/addullah/253_proj/slm4search/src/../output/distilbert-base-uncased-v2-msmarco_tiny-bm25-hard-negs


                                                                             

2024-06-02 16:39:05 - Save model to /data/addullah/253_proj/slm4search/src/../output/distilbert-base-uncased-v2-msmarco_tiny-bm25-hard-negs


                                                                             

2024-06-02 16:48:00 - Save model to /data/addullah/253_proj/slm4search/src/../output/distilbert-base-uncased-v2-msmarco_tiny-bm25-hard-negs


                                                                             

2024-06-02 16:57:11 - Save model to /data/addullah/253_proj/slm4search/src/../output/distilbert-base-uncased-v2-msmarco_tiny-bm25-hard-negs


                                                                             

2024-06-02 16:57:31 - Save model to /data/addullah/253_proj/slm4search/src/../output/distilbert-base-uncased-v2-msmarco_tiny-bm25-hard-negs


                                                                             

2024-06-02 17:06:29 - Save model to /data/addullah/253_proj/slm4search/src/../output/distilbert-base-uncased-v2-msmarco_tiny-bm25-hard-negs


                                                                             

2024-06-02 17:15:48 - Save model to /data/addullah/253_proj/slm4search/src/../output/distilbert-base-uncased-v2-msmarco_tiny-bm25-hard-negs


                                                                             

2024-06-02 17:16:18 - Save model to /data/addullah/253_proj/slm4search/src/../output/distilbert-base-uncased-v2-msmarco_tiny-bm25-hard-negs


                                                                             

2024-06-02 17:25:04 - Save model to /data/addullah/253_proj/slm4search/src/../output/distilbert-base-uncased-v2-msmarco_tiny-bm25-hard-negs


                                                                             

2024-06-02 17:34:08 - Save model to /data/addullah/253_proj/slm4search/src/../output/distilbert-base-uncased-v2-msmarco_tiny-bm25-hard-negs


                                                                             

2024-06-02 17:34:45 - Save model to /data/addullah/253_proj/slm4search/src/../output/distilbert-base-uncased-v2-msmarco_tiny-bm25-hard-negs


                                                                             

2024-06-02 17:43:15 - Save model to /data/addullah/253_proj/slm4search/src/../output/distilbert-base-uncased-v2-msmarco_tiny-bm25-hard-negs


                                                                             

2024-06-02 17:52:32 - Save model to /data/addullah/253_proj/slm4search/src/../output/distilbert-base-uncased-v2-msmarco_tiny-bm25-hard-negs


                                                                             

2024-06-02 17:53:19 - Save model to /data/addullah/253_proj/slm4search/src/../output/distilbert-base-uncased-v2-msmarco_tiny-bm25-hard-negs


                                                                             

2024-06-02 18:01:46 - Save model to /data/addullah/253_proj/slm4search/src/../output/distilbert-base-uncased-v2-msmarco_tiny-bm25-hard-negs


                                                                             

2024-06-02 18:10:58 - Save model to /data/addullah/253_proj/slm4search/src/../output/distilbert-base-uncased-v2-msmarco_tiny-bm25-hard-negs


                                                                             

2024-06-02 18:11:54 - Save model to /data/addullah/253_proj/slm4search/src/../output/distilbert-base-uncased-v2-msmarco_tiny-bm25-hard-negs


                                                                             

2024-06-02 18:20:21 - Save model to /data/addullah/253_proj/slm4search/src/../output/distilbert-base-uncased-v2-msmarco_tiny-bm25-hard-negs


                                                                             

2024-06-02 18:29:26 - Save model to /data/addullah/253_proj/slm4search/src/../output/distilbert-base-uncased-v2-msmarco_tiny-bm25-hard-negs


                                                                             

2024-06-02 18:30:23 - Save model to /data/addullah/253_proj/slm4search/src/../output/distilbert-base-uncased-v2-msmarco_tiny-bm25-hard-negs


                                                                             

2024-06-02 18:37:46 - Save model to /data/addullah/253_proj/slm4search/src/../output/distilbert-base-uncased-v2-msmarco_tiny-bm25-hard-negs


                                                                             

2024-06-02 18:45:44 - Save model to /data/addullah/253_proj/slm4search/src/../output/distilbert-base-uncased-v2-msmarco_tiny-bm25-hard-negs


                                                                             

2024-06-02 18:46:49 - Save model to /data/addullah/253_proj/slm4search/src/../output/distilbert-base-uncased-v2-msmarco_tiny-bm25-hard-negs


                                                                             

2024-06-02 18:53:41 - Save model to /data/addullah/253_proj/slm4search/src/../output/distilbert-base-uncased-v2-msmarco_tiny-bm25-hard-negs


                                                                             

2024-06-02 19:01:48 - Save model to /data/addullah/253_proj/slm4search/src/../output/distilbert-base-uncased-v2-msmarco_tiny-bm25-hard-negs


                                                                             

2024-06-02 19:03:00 - Save model to /data/addullah/253_proj/slm4search/src/../output/distilbert-base-uncased-v2-msmarco_tiny-bm25-hard-negs


                                                                             

2024-06-02 19:09:56 - Save model to /data/addullah/253_proj/slm4search/src/../output/distilbert-base-uncased-v2-msmarco_tiny-bm25-hard-negs


                                                                             

2024-06-02 19:17:53 - Save model to /data/addullah/253_proj/slm4search/src/../output/distilbert-base-uncased-v2-msmarco_tiny-bm25-hard-negs


                                                                             

2024-06-02 19:19:14 - Save model to /data/addullah/253_proj/slm4search/src/../output/distilbert-base-uncased-v2-msmarco_tiny-bm25-hard-negs


                                                                             

# Evaluate

In [16]:
# Loading test set
# corpus, queries, qrels = GenericDataLoader(data_path).load(split="test")

import collections
import pytrec_eval
import json

def load_queries(path):
    """Returns a dictionary whose keys are query ids and values are query texts.

    The file is decoded as UTF-8. Each non-blank line is stripped of surrounding
    whitespace and must then be ``query_id<TAB>query_text``. Blank lines are
    skipped instead of aborting the whole load.

    Args:
        path: Location of the tab-separated queries file.

    Returns:
        A dict mapping query id (str) to query text (str). If an id occurs more
        than once, the last occurrence wins.

    Raises:
        ValueError: If a non-blank line does not have exactly two fields. The
            message names the file and the 1-based line number.
    """
    queries = {}
    # Decode explicitly as UTF-8 so the result does not depend on the locale.
    with open(path, encoding="utf-8") as f:
        for line_no, line in enumerate(f, start=1):
            record = line.strip()
            if not record:
                # A stray empty line (e.g. at the end of the file) is harmless.
                continue
            fields = record.split('\t')
            if len(fields) != 2:
                raise ValueError(
                    f"{path}:{line_no}: expected 2 tab-separated fields, got {len(fields)}"
                )
            query_id, query_text = fields
            queries[query_id] = query_text
    return queries


def load_qrels(path):
    """Open the qrels file at ``path`` and return ``pytrec_eval.parse_qrel``'s result for it."""
    with open(path, 'r') as qrel_stream:
        return pytrec_eval.parse_qrel(qrel_stream)


def load_corpus_json(path):
    """Load the passage corpus from a JSON file and return the parsed object.

    The file is decoded explicitly as UTF-8 so that non-ASCII passage text is
    read the same way regardless of the platform's default locale encoding.

    NOTE: a function with this name is also defined in the training section of
    this notebook; this later definition replaces it once this cell runs, so
    keep the two in sync.

    Args:
        path: Location of the JSON corpus file.

    Returns:
        Whatever ``json.load`` produces for the file.
    """
    with open(path, 'r', encoding="utf-8") as corpus_f:
        return json.load(corpus_f)


# Load the test relevance judgements, the test queries and the passage corpus.
# Paths are built by concatenating `dataset_path` (which ends in "/") with the
# configured file names. `corpus` is re-read here from the same file the
# training section used.
qrels = load_qrels(f"{dataset_path}{qrels_test_file}")
queries = load_queries(f"{dataset_path}{queries_file}")
corpus = load_corpus_json(f"{dataset_path}{corpus_file}")

In [17]:
from beir.retrieval.evaluation import EvaluateRetrieval
# NOTE(review): this rebinds `models`, which the training section imported from
# sentence_transformers. Re-running the model-construction cell after this one
# would therefore resolve `models.Transformer` against beir.retrieval.models.
from beir.retrieval import models
from beir.retrieval.search.dense import DenseRetrievalExactSearch as DRES

## Load retriever from saved model
# Reads the model back from `model_save_path` (set in the training section) and
# wraps it for exact dense search. `model` and `retriever` are rebound here and
# no longer refer to the training objects.

model = DRES(models.SentenceBERT(model_save_path), batch_size=128)
retriever = EvaluateRetrieval(model, score_function="cos_sim")

#### Retrieve dense results (format of results is identical to qrels)
results = retriever.retrieve(corpus, queries)

2024-06-02 19:38:42 - Loading faiss with AVX2 support.
2024-06-02 19:38:42 - Successfully loaded faiss with AVX2 support.
2024-06-02 19:38:42 - Use pytorch device_name: cuda
2024-06-02 19:38:42 - Load pretrained SentenceTransformer: /data/addullah/253_proj/slm4search/src/../output/distilbert-base-uncased-v2-msmarco_tiny-bm25-hard-negs
2024-06-02 19:38:43 - Encoding Queries...


Batches: 100%|██████████| 1/1 [00:00<00:00, 72.24it/s]

2024-06-02 19:38:43 - Sorting Corpus by document length (Longest first)...





2024-06-02 19:38:43 - Scoring Function: Cosine Similarity (cos_sim)
2024-06-02 19:38:43 - Encoding Batch 1/11...


Batches: 100%|██████████| 391/391 [01:03<00:00,  6.18it/s]


2024-06-02 19:39:48 - Encoding Batch 2/11...


Batches: 100%|██████████| 391/391 [00:51<00:00,  7.57it/s]


2024-06-02 19:40:41 - Encoding Batch 3/11...


Batches: 100%|██████████| 391/391 [00:46<00:00,  8.46it/s]


2024-06-02 19:41:28 - Encoding Batch 4/11...


Batches: 100%|██████████| 391/391 [00:37<00:00, 10.52it/s]


2024-06-02 19:42:06 - Encoding Batch 5/11...


Batches: 100%|██████████| 391/391 [00:33<00:00, 11.69it/s]


2024-06-02 19:42:40 - Encoding Batch 6/11...


Batches: 100%|██████████| 391/391 [00:31<00:00, 12.42it/s]


2024-06-02 19:43:12 - Encoding Batch 7/11...


Batches: 100%|██████████| 391/391 [00:30<00:00, 12.99it/s]


2024-06-02 19:43:43 - Encoding Batch 8/11...


Batches: 100%|██████████| 391/391 [00:28<00:00, 13.71it/s]


2024-06-02 19:44:12 - Encoding Batch 9/11...


Batches: 100%|██████████| 391/391 [00:26<00:00, 14.87it/s]


2024-06-02 19:44:39 - Encoding Batch 10/11...


Batches: 100%|██████████| 391/391 [00:23<00:00, 16.93it/s]


2024-06-02 19:45:03 - Encoding Batch 11/11...


Batches: 100%|██████████| 83/83 [00:03<00:00, 22.29it/s]


In [20]:
# Persist the dense retrieval results so evaluation can be repeated without re-encoding the corpus.
with open(f"{dataset_path}{dataset}_distilBertBM_scores.pickle", 'wb') as scores_out:
    pickle.dump(results, scores_out, protocol=pickle.HIGHEST_PROTOCOL)

In [19]:
#### Evaluate your retrieval using NDCG@k, MAP@K ...
# Score `results` against `qrels` at every cutoff in retriever.k_values; the call
# returns four values, unpacked here as NDCG, MAP, recall and precision.
logging.info("Retriever evaluation for k in: {}".format(retriever.k_values))
ndcg, _map, recall, precision = retriever.evaluate(qrels, results, retriever.k_values)
# Bare tuple as the last expression so the notebook displays all four results.
ndcg, _map, recall, precision

2024-06-02 20:07:23 - Retriever evaluation for k in: [1, 3, 5, 10, 100, 1000]
2024-06-02 20:07:23 - For evaluation, we ignore identical query and document ids (default), please explicitly set ``ignore_identical_ids=False`` to ignore this.
2024-06-02 20:07:23 - 

2024-06-02 20:07:23 - NDCG@1: 0.5803
2024-06-02 20:07:23 - NDCG@3: 0.5167
2024-06-02 20:07:23 - NDCG@5: 0.5176
2024-06-02 20:07:23 - NDCG@10: 0.4875
2024-06-02 20:07:23 - NDCG@100: 0.4226
2024-06-02 20:07:23 - NDCG@1000: 0.4880
2024-06-02 20:07:23 - 

2024-06-02 20:07:23 - MAP@1: 0.0302
2024-06-02 20:07:23 - MAP@3: 0.0688
2024-06-02 20:07:23 - MAP@5: 0.0930
2024-06-02 20:07:23 - MAP@10: 0.1331
2024-06-02 20:07:23 - MAP@100: 0.2299
2024-06-02 20:07:23 - MAP@1000: 0.2506
2024-06-02 20:07:23 - 

2024-06-02 20:07:23 - Recall@1: 0.0302
2024-06-02 20:07:23 - Recall@3: 0.0721
2024-06-02 20:07:23 - Recall@5: 0.1036
2024-06-02 20:07:23 - Recall@10: 0.1567
2024-06-02 20:07:23 - Recall@100: 0.3703
2024-06-02 20:07:23 - Recall@1000: 0.5426

({'NDCG@1': 0.58025,
  'NDCG@3': 0.51669,
  'NDCG@5': 0.51756,
  'NDCG@10': 0.4875,
  'NDCG@100': 0.4226,
  'NDCG@1000': 0.48799},
 {'MAP@1': 0.03016,
  'MAP@3': 0.0688,
  'MAP@5': 0.09296,
  'MAP@10': 0.13311,
  'MAP@100': 0.22992,
  'MAP@1000': 0.25063},
 {'Recall@1': 0.03016,
  'Recall@3': 0.0721,
  'Recall@5': 0.10357,
  'Recall@10': 0.15667,
  'Recall@100': 0.37034,
  'Recall@1000': 0.54264},
 {'P@1': 0.7037,
  'P@3': 0.64815,
  'P@5': 0.62963,
  'P@10': 0.53889,
  'P@100': 0.1813,
  'P@1000': 0.02998})