# Train: fine-tune `distilbert-base-uncased` as a dense retriever on the MS MARCO tiny triplets

In [6]:
import os

# Pin the notebook to one GPU. CUDA_VISIBLE_DEVICES only takes effect if it is set
# before torch initializes CUDA, which is why this is the very first code cell.
GPU_ID = "6"
os.environ["CUDA_VISIBLE_DEVICES"] = GPU_ID

In [7]:
# Dataset selection. Every file name below is resolved relative to dataset_path.
dataset = "msmarco_tiny"
dataset_path = f"../beir/datasets/{dataset}/"

corpus_file = "tiny_collection.json"              # corpus, parsed with json.load
queries_file = "topics.dl20.txt"                  # "<query_id>\t<text>" lines
qrels_test_file = "qrels.dl20-passage.txt"        # parsed with pytrec_eval.parse_qrel
training_set = "msmarco_triples.train.tiny.tsv"   # query / positive / negative rows

In [8]:
# Encoder configuration.
model_name = "distilbert-base-uncased"  # starting checkpoint for the bi-encoder
max_seq_length = 512  # passed to models.Transformer as its max_seq_length

In [9]:

from sentence_transformers import losses, models, SentenceTransformer
from beir import util, LoggingHandler
from beir.datasets.data_loader import GenericDataLoader
from beir.retrieval.search.lexical import BM25Search as BM25
from beir.retrieval.evaluation import EvaluateRetrieval
from beir.retrieval.train import TrainRetriever
import pathlib, os, tqdm
import logging

#### Route log records (INFO and above) through BEIR's LoggingHandler so progress
#### messages are printed to stdout in the notebook.
#### Note: logging.basicConfig does nothing if the root logger already has
#### handlers, so re-running this cell does not stack duplicate handlers.
_log_config = {
    "format": '%(asctime)s - %(message)s',
    "datefmt": '%Y-%m-%d %H:%M:%S',
    "level": logging.INFO,
    "handlers": [LoggingHandler()],
}
logging.basicConfig(**_log_config)

In [10]:
def load_triplets(path):
    """Load (query, positive passage, negative passage) training triplets.

    Each non-blank line of the file must contain exactly three tab-separated
    fields: ``query``, ``positive_passage``, ``negative_passage``. Blank lines
    (e.g. a trailing newline at end of file) are skipped.

    Args:
        path: Path to the triplets TSV file, read as UTF-8.

    Returns:
        list of ``[query, positive_passage, negative_passage]`` lists, in file order.

    Raises:
        ValueError: if a non-blank line does not have exactly three fields. The
            message contains the path and the 1-based line number.
    """
    triplets = []
    # Explicit UTF-8: open()'s default encoding is platform/locale dependent.
    with open(path, encoding="utf-8") as f:
        for line_no, line in enumerate(f, start=1):
            line = line.strip()
            if not line:
                continue  # tolerate blank lines instead of failing to unpack
            fields = line.split('\t')
            if len(fields) != 3:
                raise ValueError(
                    f"{path}:{line_no}: expected 3 tab-separated fields, got {len(fields)}"
                )
            triplets.append(fields)
    return triplets

# os.path.join keeps the path correct whether or not dataset_path ends in "/";
# plain string concatenation silently depends on that trailing separator.
triplets = load_triplets(os.path.join(dataset_path, training_set))

In [11]:
#### Provide any sentence-transformers or HF model.
#### The bi-encoder is the transformer encoder followed by a pooling layer that
#### reduces token embeddings to one sentence embedding.
word_embedding_model = models.Transformer(
    model_name_or_path=model_name,
    max_seq_length=max_seq_length,
)
pooling_model = models.Pooling(
    word_embedding_dimension=word_embedding_model.get_word_embedding_dimension(),
)
model = SentenceTransformer(modules=[word_embedding_model, pooling_model])



2024-06-02 20:23:43 - Use pytorch device_name: cuda


In [7]:
import random

import numpy as np
import torch

#### Seed the Python, NumPy and torch RNGs so training randomness (e.g. any data
#### shuffling or dropout inside TrainRetriever / the model) repeats across runs.
#### GPU kernels, especially with use_amp=True, may still add small differences.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

retriever = TrainRetriever(model=model, batch_size=12)

#### Prepare triplets samples
train_samples = retriever.load_train_triplets(triplets=triplets)
train_dataloader = retriever.prepare_train_triplets(train_samples)

#### Train the SBERT bi-encoder with MultipleNegativesRankingLoss
#### (in-batch negatives; cosine similarity by default in sentence-transformers)
train_loss = losses.MultipleNegativesRankingLoss(model=retriever.model)

#### No validation set: BEIR's dummy evaluator is used. Its score appears to be a
#### timestamp (see the "Sequential Score" column of the training log), so every
#### evaluation counts as an improvement and the saved checkpoint is overwritten.
ir_evaluator = retriever.load_dummy_evaluator()

#### Provide model save path
model_save_path = os.path.join("../", "output", "{}-v1-{}".format(model_name, dataset))
os.makedirs(model_save_path, exist_ok=True)

#### Configure Train params
num_epochs = 10
evaluation_steps = 5000
# Warm up over ~10% of all optimizer steps (samples * epochs / batch size).
warmup_steps = int(len(train_samples) * num_epochs / retriever.batch_size * 0.1)


Adding Input Examples:   0%|          | 0/917 [00:00<?, ?it/s]

Adding Input Examples: 100%|██████████| 917/917 [00:00<00:00, 41252.06it/s]

2024-05-29 20:25:34 - Loaded 11000 training pairs.





In [8]:
#### Fine-tune; checkpoints are written to model_save_path ("Save model to ..." log lines)
fit_kwargs = dict(
    evaluator=ir_evaluator,
    epochs=num_epochs,
    output_path=model_save_path,
    warmup_steps=warmup_steps,
    evaluation_steps=evaluation_steps,
    use_amp=True,  # automatic mixed precision
)
retriever.fit(train_objectives=[(train_dataloader, train_loss)], **fit_kwargs)

2024-05-29 20:25:41 - Starting to Train...




Step,Training Loss,Validation Loss,Sequential Score
916,0.3614,No log,1717039698.707815
1832,0.0456,No log,1717039844.021595
2748,0.0047,No log,1717039987.830306
3664,0.0013,No log,1717040130.348752
4580,0.0008,No log,1717040276.147916
5000,0.0007,No log,1717040343.546629
5496,0.0007,No log,1717040424.160553
6412,0.0004,No log,1717040563.502927
7328,0.0003,No log,1717040701.996164
8244,0.0002,No log,1717040838.653505


2024-05-29 20:28:18 - Save model to ../output/distilbert-base-uncased-v1-msmarco_tiny


                                                                             

2024-05-29 20:30:44 - Save model to ../output/distilbert-base-uncased-v1-msmarco_tiny


                                                                             

2024-05-29 20:33:07 - Save model to ../output/distilbert-base-uncased-v1-msmarco_tiny


                                                                             

2024-05-29 20:35:30 - Save model to ../output/distilbert-base-uncased-v1-msmarco_tiny


                                                                             

2024-05-29 20:37:56 - Save model to ../output/distilbert-base-uncased-v1-msmarco_tiny


                                                                             

2024-05-29 20:39:03 - Save model to ../output/distilbert-base-uncased-v1-msmarco_tiny


                                                                             

2024-05-29 20:40:24 - Save model to ../output/distilbert-base-uncased-v1-msmarco_tiny


                                                                             

2024-05-29 20:42:43 - Save model to ../output/distilbert-base-uncased-v1-msmarco_tiny


                                                                             

2024-05-29 20:45:02 - Save model to ../output/distilbert-base-uncased-v1-msmarco_tiny


                                                                             

2024-05-29 20:47:18 - Save model to ../output/distilbert-base-uncased-v1-msmarco_tiny


                                                                             

2024-05-29 20:49:36 - Save model to ../output/distilbert-base-uncased-v1-msmarco_tiny


                                                                             

# Evaluate: dense retrieval with the fine-tuned model on the DL20 test queries


In [12]:
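# Load the test set. These files (a JSON corpus, a tab-separated topics file and a
# TREC-format qrels file) are read with the custom loaders below instead of BEIR's
# GenericDataLoader(data_path).load(split="test").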

import collections
import pytrec_eval
import json

def load_queries(path):
    """Return a dict mapping query ids to query texts.

    Each non-blank line is ``<query_id><TAB><query text>``. Only the first tab
    separates the id from the text, so a query text that itself contains a tab
    is kept intact instead of crashing the parse. Blank lines are skipped.

    Args:
        path: Path to the queries file, read as UTF-8.

    Returns:
        dict of query id (str) -> query text (str).

    Raises:
        ValueError: if a non-blank line has no tab separator. The message
            contains the path and the 1-based line number.
    """
    queries = {}
    # Explicit UTF-8: open()'s default encoding is platform/locale dependent.
    with open(path, encoding="utf-8") as f:
        for line_no, line in enumerate(f, start=1):
            line = line.strip()
            if not line:
                continue  # tolerate blank lines, e.g. a trailing newline at EOF
            query_id, sep, query_text = line.partition('\t')
            if not sep:
                raise ValueError(
                    f"{path}:{line_no}: expected '<query_id><TAB><text>', got {line!r}"
                )
            queries[query_id] = query_text
    return queries


def load_qrels(path):
    """Parse a TREC-format qrels file using ``pytrec_eval.parse_qrel``.

    Returns:
        Nested dict ``{query_id: {doc_id: relevance}}`` as built by pytrec_eval.
    """
    with open(path, 'r') as qrels_file:
        return pytrec_eval.parse_qrel(qrels_file)


def load_corpus_json(path):
    """Load the retrieval corpus from a JSON file.

    The file is decoded as UTF-8 explicitly. JSON text is UTF-8 by spec, while
    ``open()`` without ``encoding`` uses a platform-dependent locale encoding
    that can garble or reject non-ASCII passages.

    Args:
        path: Path to the corpus JSON file.

    Returns:
        The parsed JSON object, unchanged. NOTE(review): it is passed to
        ``EvaluateRetrieval.retrieve`` as the corpus, which presumably expects
        ``{doc_id: {"title": ..., "text": ...}}`` — confirm against the file.
    """
    with open(path, 'r', encoding='utf-8') as corpus_f:
        return json.load(corpus_f)


# os.path.join keeps these paths correct whether or not dataset_path ends in "/";
# plain string concatenation silently depends on that trailing separator.
qrels = load_qrels(os.path.join(dataset_path, qrels_test_file))
queries = load_queries(os.path.join(dataset_path, queries_file))
corpus = load_corpus_json(os.path.join(dataset_path, corpus_file))

In [14]:
from beir.retrieval.evaluation import EvaluateRetrieval
# Import BEIR's models module under an alias. A plain
# `from beir.retrieval import models` would shadow the sentence_transformers
# `models` used to build the encoder for training. Re-running that cell afterwards
# would then break, because BEIR's module is picked up instead.
from beir.retrieval import models as beir_models
from beir.retrieval.search.dense import DenseRetrievalExactSearch as DRES

## Load retriever from saved model (same path the training cell writes to)
model_save_path = os.path.join("../", "output", "{}-v1-{}".format(model_name, dataset))
model = DRES(beir_models.SentenceBERT(model_save_path), batch_size=128)
# cos_sim matches the similarity MultipleNegativesRankingLoss uses by default
retriever = EvaluateRetrieval(model, score_function="cos_sim")

#### Retrieve dense results (format of results is identical to qrels)
results = retriever.retrieve(corpus, queries)

2024-06-02 20:30:11 - Use pytorch device_name: cuda
2024-06-02 20:30:11 - Load pretrained SentenceTransformer: ../output/distilbert-base-uncased-v1-msmarco_tiny
2024-06-02 20:30:12 - Encoding Queries...


Batches: 100%|██████████| 1/1 [00:00<00:00,  2.23it/s]


2024-06-02 20:30:12 - Sorting Corpus by document length (Longest first)...
2024-06-02 20:30:13 - Scoring Function: Cosine Similarity (cos_sim)
2024-06-02 20:30:13 - Encoding Batch 1/11...


Batches: 100%|██████████| 391/391 [01:03<00:00,  6.17it/s]


2024-06-02 20:31:17 - Encoding Batch 2/11...


Batches: 100%|██████████| 391/391 [00:51<00:00,  7.53it/s]


2024-06-02 20:32:10 - Encoding Batch 3/11...


Batches: 100%|██████████| 391/391 [00:45<00:00,  8.52it/s]


2024-06-02 20:32:57 - Encoding Batch 4/11...


Batches: 100%|██████████| 391/391 [00:37<00:00, 10.46it/s]


2024-06-02 20:33:35 - Encoding Batch 5/11...


Batches: 100%|██████████| 391/391 [00:33<00:00, 11.80it/s]


2024-06-02 20:34:09 - Encoding Batch 6/11...


Batches: 100%|██████████| 391/391 [00:31<00:00, 12.28it/s]


2024-06-02 20:34:41 - Encoding Batch 7/11...


Batches: 100%|██████████| 391/391 [00:30<00:00, 12.97it/s]


2024-06-02 20:35:12 - Encoding Batch 8/11...


Batches: 100%|██████████| 391/391 [00:28<00:00, 13.68it/s]


2024-06-02 20:35:41 - Encoding Batch 9/11...


Batches: 100%|██████████| 391/391 [00:26<00:00, 14.87it/s]


2024-06-02 20:36:08 - Encoding Batch 10/11...


Batches: 100%|██████████| 391/391 [00:23<00:00, 16.71it/s]


2024-06-02 20:36:32 - Encoding Batch 11/11...


Batches: 100%|██████████| 83/83 [00:03<00:00, 21.27it/s]


In [15]:
import pickle

# Persist the dense retrieval run next to the dataset, presumably so it can be
# reloaded without re-encoding the corpus. Only unpickle this from trusted sources.
scores_path = f"{dataset_path}{dataset}_distilBertR_scores.pickle"
with open(scores_path, 'wb') as scores_f:
    pickle.dump(results, scores_f, protocol=pickle.HIGHEST_PROTOCOL)

In [16]:
#### Evaluate your retrieval using NDCG@k, MAP@K ...
logging.info("Retriever evaluation for k in: {}".format(retriever.k_values))
ndcg, _map, recall, precision = retriever.evaluate(qrels, results, retriever.k_values)
ndcg, _map, recall, precision

2024-06-02 20:55:55 - Retriever evaluation for k in: [1, 3, 5, 10, 100, 1000]
2024-06-02 20:55:55 - For evaluation, we ignore identical query and document ids (default), please explicitly set ``ignore_identical_ids=False`` to ignore this.
2024-06-02 20:55:55 - 

2024-06-02 20:55:55 - NDCG@1: 0.6265
2024-06-02 20:55:55 - NDCG@3: 0.5957
2024-06-02 20:55:55 - NDCG@5: 0.5976
2024-06-02 20:55:55 - NDCG@10: 0.5535
2024-06-02 20:55:55 - NDCG@100: 0.5523
2024-06-02 20:55:55 - NDCG@1000: 0.6391
2024-06-02 20:55:55 - 

2024-06-02 20:55:55 - MAP@1: 0.0367
2024-06-02 20:55:55 - MAP@3: 0.0747
2024-06-02 20:55:55 - MAP@5: 0.1051
2024-06-02 20:55:55 - MAP@10: 0.1539
2024-06-02 20:55:55 - MAP@100: 0.3390
2024-06-02 20:55:55 - MAP@1000: 0.3811
2024-06-02 20:55:55 - 

2024-06-02 20:55:55 - Recall@1: 0.0367
2024-06-02 20:55:55 - Recall@3: 0.0780
2024-06-02 20:55:55 - Recall@5: 0.1130
2024-06-02 20:55:55 - Recall@10: 0.1819
2024-06-02 20:55:55 - Recall@100: 0.5381
2024-06-02 20:55:55 - Recall@1000: 0.7627

({'NDCG@1': 0.62654,
  'NDCG@3': 0.59571,
  'NDCG@5': 0.59759,
  'NDCG@10': 0.55349,
  'NDCG@100': 0.55234,
  'NDCG@1000': 0.63912},
 {'MAP@1': 0.03666,
  'MAP@3': 0.07473,
  'MAP@5': 0.10514,
  'MAP@10': 0.15388,
  'MAP@100': 0.33904,
  'MAP@1000': 0.38106},
 {'Recall@1': 0.03666,
  'Recall@3': 0.078,
  'Recall@5': 0.11304,
  'Recall@10': 0.18187,
  'Recall@100': 0.53806,
  'Recall@1000': 0.76265},
 {'P@1': 0.7963,
  'P@3': 0.75309,
  'P@5': 0.72222,
  'P@10': 0.60556,
  'P@100': 0.27833,
  'P@1000': 0.04596})