In [1]:
from tevatron.datasets.dataset import load_dataset

2022-11-24 03:01:30.242847: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcudart.so.11.0


In [2]:
# MS MARCO passage-ranking data; its 'train' split provides query_id, query and
# positive_passages (each with docid/text), as used by the cells below.
dataset = load_dataset(
    "Tevatron/msmarco-passage",
    "default",
    data_files=None,
    cache_dir=None,
)

Found cached dataset msmarco-passage (/home/y247xie/.cache/huggingface/datasets/Tevatron___msmarco-passage/default/0.0.1/1874f5d9ae5257b9dbc7d8f89c76f8d4c321be6b660bb5df208e5e64decfa978)


  0%|          | 0/2 [00:00<?, ?it/s]

In [3]:
# FiQA corpus from the BEIR collection; each 'train' row carries a docid and text.
corpus_fiqa = load_dataset(
    "Tevatron/beir-corpus",
    "fiqa",
    data_files=None,
    cache_dir=None,
)

Found cached dataset beir-corpus (/home/y247xie/.cache/huggingface/datasets/Tevatron___beir-corpus/fiqa/1.1.0/02e1318cd9412cdf85d3f039bf36bec0af49ddeeab2279d4cf19fe556af6f29a)


  0%|          | 0/1 [00:00<?, ?it/s]

In [4]:
from sentence_transformers import SentenceTransformer

# A single encoder is used for both the MS MARCO and the FiQA embeddings below,
# so their vectors can be compared with a dot product.
MODEL_NAME = 'all-mpnet-base-v2'
model = SentenceTransformer(MODEL_NAME)




## Compute MS MARCO positive-passage embeddings

In [5]:
%%time

import tqdm
docid2doc_embedding = {}
doc_batch = []
docid_batch = []
docs_msmarco = []
docids_msmarco = []
doc_embeddings_msmarco = []
docids = set()
for example in dataset['train']:
    for doc in example['positive_passages']:
        if doc['docid'] in docids:
            continue
        docids.add(doc['docid'])
        doc_batch.append(doc['text'])
        docid_batch.append(doc['docid'])
        docs_msmarco.append(doc['text'])
        docids_msmarco.append(doc['docid'])
        if len(docid_batch) == 512:
            doc_embeddings = model.encode(doc_batch)
            for i, docid in enumerate(docid_batch):
                docid2doc_embedding[docid] = doc_embeddings[i]
                doc_embeddings_msmarco.append(doc_embeddings[i])
            doc_batch = []
            docid_batch = []

if len(docid_batch) > 0 :
    doc_embeddings = model.encode(doc_batch)
    for i, docid in enumerate(docid_batch):
        docid2doc_embedding[docid] = doc_embeddings[i]
        doc_embeddings_msmarco.append(doc_embeddings[i])
        


CPU times: user 1h 16min 15s, sys: 3min 12s, total: 1h 19min 27s
Wall time: 26min 29s


## Compute FiQA document embeddings and their mean (domain embedding)

In [10]:
# Embed every FiQA corpus document (512 texts per model.encode call), then
# average all vectors into a single FiQA "domain" embedding.
import tqdm
import numpy as np

docid2doc_embedding_fiqa = {}
doc_batch = []
docid_batch = []


def _store_fiqa_batch(texts, ids):
    """Encode `texts` together and record each vector under its docid."""
    vectors = model.encode(texts)
    docid2doc_embedding_fiqa.update(zip(ids, vectors))


for example in tqdm.tqdm(corpus_fiqa['train']):
    doc_batch.append(example['text'])
    docid_batch.append(example['docid'])
    if len(docid_batch) == 512:
        _store_fiqa_batch(doc_batch, docid_batch)
        doc_batch, docid_batch = [], []

# Encode the final partial batch, if any.
if docid_batch:
    _store_fiqa_batch(doc_batch, docid_batch)

# Element-wise mean (centroid) of all FiQA document embeddings.
doc_embedding_fiqa = sum(docid2doc_embedding_fiqa.values()) / len(docid2doc_embedding_fiqa)

100%|███████| 57638/57638 [03:57<00:00, 242.39it/s]


In [11]:
import numpy as np

# Dot product of the FiQA centroid with each MS MARCO passage embedding:
# one score per passage, index-aligned with docs_msmarco / docids_msmarco.
scores = np.dot(doc_embedding_fiqa, np.array(doc_embeddings_msmarco).T)

# (passage text, score, docid) triples, highest score first.
sim_docids = sorted(
    zip(docs_msmarco, scores, docids_msmarco),
    key=lambda triple: triple[1],
    reverse=True,
)

## Derive per-query scores (qid2score) from passage scores (docid2score)

In [12]:
# Passage id -> similarity of that passage to the FiQA centroid.
docid2score = {docid: score for _text, score, docid in sim_docids}

# A query's score is the maximum score over its positive passages; if the same
# query id appears in several examples, the maximum across all of them is kept.
# Queries without any positive passage get no entry in qid2score.
qid2query = {}
qid2score = {}
for example in dataset['train']:
    qid = example['query_id']
    qid2query[qid] = example['query']
    running = [qid2score[qid]] if qid in qid2score else []
    running.extend(docid2score[p['docid']] for p in example['positive_passages'])
    if running:
        qid2score[qid] = max(running)

In [13]:
# Rank queries by score (highest first) and print the 100 most FiQA-like ones.
scores = sorted(qid2score.items(), key=lambda qid_score: qid_score[1], reverse=True)
for qid, score in scores[:100]:
    print(score, qid2query[qid])

0.1280104 what is warren buffett investing in
0.12713526 who is gregory mannarino
0.12344028 why health care is a market failure
0.12269017 advantages of having money
0.12235844 most expensive stock firms
0.12130409 do energy management systems make sense from a business point of view
0.119668044 it finance definition
0.11939504 define the term opportunity cost
0.11831266 what type of industry is a holding company
0.11812482 how much is ge pension underfunded
0.1180791 what is growth low
0.116239846 preferred stock 1x liquidation preference
0.115887 what challenges faces a financial manager
0.11457144 Good Financial Status definition
0.11410122 what is overcapitalization
0.11396207 define nonfungible role
0.11349765 how obamacare failed
0.11339473 does tesla negotiate price
0.112693764 stock price maximization requires _____.
0.11240898 what is uma investment account
0.1119784 is ge a good buy
0.111773185 amazon wage study
0.11173489 why do people lose money investing?
0.111103155 what

In [14]:
import pickle

# Persist the per-query scores. The `with` block closes (and therefore flushes)
# the file deterministically; a bare open() left the handle open until garbage
# collection, so the pickle on disk could be incomplete.
with open("qid2score_by_doc_fiqa2.pkl", "wb") as fout:
    pickle.dump(qid2score, fout)

## Segment training queries by score and report segment sizes

In [15]:
def add_score(example):
    """Set example["score"] to the score of the example's query and return it."""
    query_score = qid2score[example['query_id']]
    example["score"] = query_score
    return example


# Training split with an extra "score" column.
dataset_train = dataset['train'].map(add_score)

  0%|          | 0/400782 [00:00<?, ?ex/s]

In [16]:
import pandas as pd

# Summary statistics (count/mean/std/quantiles) of the per-query scores.
pd.DataFrame(list(qid2score.values())).describe()

Unnamed: 0,0
count,400782.0
mean,0.007696
std,0.018191
min,-0.049535
25%,-0.004508
50%,0.004566
75%,0.016453
max,0.12801


In [17]:
# Order training examples by ascending score; the segmentation below relies on this.
dataset_train = dataset_train.sort('score', reverse=False)

In [18]:
# Split the score-sorted training set into contiguous segments.
# score_segments holds ascending upper bounds: segment k holds the examples
# with score_segments[k-1] < score <= score_segments[k]. The last bound acts as
# a sentinel, so everything above the second-to-last bound lands in the final
# segment. The result always has exactly len(score_segments) segments, some of
# which may be empty.
# Each segment is a slice of the Dataset (a dict of columns), so
# d['query'] is the list of that segment's queries.
from bisect import bisect_right

score_segments = [0.03, 0.08, 1]

# Read the score column once instead of materializing every row via
# dataset_train[i], which is slow for ~400k rows.
sorted_scores = dataset_train['score']
assert all(a <= b for a, b in zip(sorted_scores, sorted_scores[1:])), \
    "dataset_train must be sorted by 'score' ascending (run the sort cell first)"

# Cut points: index of the first score strictly greater than each interior bound.
# Using bisect handles a jump over several bounds at once, which the previous
# row-by-row scan mis-assigned, and it cannot run past the end of score_segments.
boundaries = (
    [0]
    + [bisect_right(sorted_scores, bound) for bound in score_segments[:-1]]
    + [len(sorted_scores)]
)
dataset_segments_train = [
    dataset_train[start:end] for start, end in zip(boundaries, boundaries[1:])
]

In [19]:
[len(d['query']) for d in dataset_segments_train]

[356268, 43523, 991]