The following code is the evaluation script for the bi-encoder using MeanMSME Loss based on the paper and found at the following GitHub repository: UKPLab/sentence-transformers

In [1]:
!pip install sentence-transformers

Collecting sentence-transformers
  Downloading sentence-transformers-2.2.2.tar.gz (85 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/86.0 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m86.0/86.0 kB[0m [31m2.6 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting sentencepiece (from sentence-transformers)
  Downloading sentencepiece-0.1.99-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m13.5 MB/s[0m eta [36m0:00:00[0m
Building wheels for collected packages: sentence-transformers
  Building wheel for sentence-transformers (setup.py) ... [?25l[?25hdone
  Created wheel for sentence-transformers: filename=sentence_transformers-2.2.2-py3-none-any.whl size=125923 sha256=9bb586101f323971d953a24a60963ab79b4736da3ad4592273f48afed1a75adc
  Stored in directory: 

In [2]:
import sys

sys.argv[1] = '/content/drive/My Drive/MSMARCO-ClincalBERT'
sys.argv[2] = '100'

In [None]:
"""
This script runs the evaluation of an SBERT msmarco model on the
MS MARCO dev dataset and reports different performances metrices for cossine similarity & dot-product.

Usage:
python eval_msmarco.py model_name [max_corpus_size_in_thousands]

Code taken from https://github.com/UKPLab/sentence-transformers/blob/master/examples/training/ms_marco/eval_msmarco.py
"""

from sentence_transformers import  LoggingHandler, SentenceTransformer, evaluation, util, models
import logging
import sys
import os
import tarfile


#Name of the SBERT model
model_name = 'sravn/msmarco-clincalbert'

# You can limit the approx. max size of the corpus. Pass 100 as second parameter and the corpus has a size of approx 100k docs
corpus_max_size = int(sys.argv[2])*1000 if len(sys.argv) >= 3 else 0


####  Load model

model = SentenceTransformer(model_name)

### Data files
data_folder = 'msmarco-data'
os.makedirs(data_folder, exist_ok=True)

collection_filepath = os.path.join(data_folder, 'collection.tsv')
dev_queries_file = os.path.join(data_folder, 'queries.dev.small.tsv')
qrels_filepath = os.path.join(data_folder, 'qrels.dev.tsv')

### Download files if needed
if not os.path.exists(collection_filepath) or not os.path.exists(dev_queries_file):
    tar_filepath = os.path.join(data_folder, 'collectionandqueries.tar.gz')
    if not os.path.exists(tar_filepath):
        logging.info("Download: "+tar_filepath)
        util.http_get('https://msmarco.blob.core.windows.net/msmarcoranking/collectionandqueries.tar.gz', tar_filepath)

    with tarfile.open(tar_filepath, "r:gz") as tar:
        tar.extractall(path=data_folder)


if not os.path.exists(qrels_filepath):
    util.http_get('https://msmarco.blob.core.windows.net/msmarcoranking/qrels.dev.tsv', qrels_filepath)

### Load data

corpus = {}             #Our corpus pid => passage
dev_queries = {}        #Our dev queries. qid => query
dev_rel_docs = {}       #Mapping qid => set with relevant pids
needed_pids = set()     #Passage IDs we need
needed_qids = set()     #Query IDs we need

# Load the 6980 dev queries
with open(dev_queries_file, encoding='utf8') as fIn:
    for line in fIn:
        qid, query = line.strip().split("\t")
        dev_queries[qid] = query.strip()


# Load which passages are relevant for which queries
with open(qrels_filepath) as fIn:
    for line in fIn:
        qid, _, pid, _ = line.strip().split('\t')

        if qid not in dev_queries:
            continue

        if qid not in dev_rel_docs:
            dev_rel_docs[qid] = set()
        dev_rel_docs[qid].add(pid)

        needed_pids.add(pid)
        needed_qids.add(qid)


# Read passages
with open(collection_filepath, encoding='utf8') as fIn:
    for line in fIn:
        pid, passage = line.strip().split("\t")
        passage = passage

        if pid in needed_pids or corpus_max_size <= 0 or len(corpus) <= corpus_max_size:
            corpus[pid] = passage.strip()



## Run evaluator
logging.info("Queries: {}".format(len(dev_queries)))
logging.info("Corpus: {}".format(len(corpus)))

ir_evaluator = evaluation.InformationRetrievalEvaluator(dev_queries, corpus, dev_rel_docs,
                                                        show_progress_bar=True,
                                                        corpus_chunk_size=100000,
                                                        precision_recall_at_k=[10, 100],
                                                        name="msmarco dev")

ir_evaluator(model)

.gitattributes:   0%|          | 0.00/1.52k [00:00<?, ?B/s]

1_Pooling/config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

README.md:   0%|          | 0.00/3.80k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/610 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/123 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/539M [00:00<?, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/125 [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.92M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/1.31k [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/996k [00:00<?, ?B/s]

modules.json:   0%|          | 0.00/229 [00:00<?, ?B/s]

  0%|          | 0.00/1.06G [00:00<?, ?B/s]

  0%|          | 0.00/1.20M [00:00<?, ?B/s]

Batches:   0%|          | 0/219 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 2/2 [5:28:05<00:00, 9842.94s/it]


0.8108492163527864

In [4]:
ir_evaluator = evaluation.InformationRetrievalEvaluator(dev_queries, corpus, dev_rel_docs,
                                                        show_progress_bar=True,
                                                        corpus_chunk_size=100000,
                                                        precision_recall_at_k=[10, 100],
                                                        name="msmarco dev")

In [9]:
metrics = ir_evaluator.compute_metrices(model)

Batches:   0%|          | 0/219 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 2/2 [05:45<00:00, 172.59s/it]


In [57]:
print('The mertics for the Clincal BERT model are as follows:')
print("\nAccuracy: ")
print(metrics['cos_sim']['accuracy@k'][10])
print("\nPrecision: ")
print(metrics['cos_sim']['precision@k'][10])
print("\nRecall: ")
print(metrics['cos_sim']['recall@k'][10])
print("\nMean Reciprocal Rank: ")
print(metrics['cos_sim']['mrr@k'][10])
print("\nMean Average Precision: ")
print(metrics['cos_sim']['mrr@k'][10])
print("\nNDCG: ")
print(metrics['cos_sim']['ndcg@k'][10])

The mertics for the Clincal BERT model are as follows:

Accuracy: 
0.9290830945558739

Precision: 
0.09802292263610315

Recall: 
0.9240210124164278

Mean Reciprocal Rank: 
0.815079763041796

Mean Average Precision: 
0.815079763041796

NDCG: 
0.8381790579726461


## Comparison

We can now compare this to the model that was originally trained using the MultipleNegativesRankingLoss.

When trained using distilbert-base-uncased model, it should achieve a performance of about 33.79 MRR@10 on the MSMARCO Passages Dev-Corpus