In [1]:
import sentence_transformers as st
import pandas as pd

In [2]:
# All dataset files live under one bucket and share one file-name prefix; keep
# both in a single place so switching bucket or dataset is a one-line change.
DATA_ROOT = "gs://scraped-news-article-data-null"
DATASET_NAME = "fiqa-augmented"


def _dataset_path(split: str, part: str) -> str:
    """Return the parquet URL for one dataset file.

    Args:
        split: Split tag as it appears in the file name ("train", "dev", "test").
        part: File kind suffix: "q" (questions), "doc" (documents) or
            "rel" (query/document relevance pairs).
    """
    return f"{DATA_ROOT}/{DATASET_NAME}-{split}{part}.parquet"


TRAIN_QUESTION_PATH = _dataset_path("train", "q")
TRAIN_DOC_PATH = _dataset_path("train", "doc")
TRAIN_REL_PATH = _dataset_path("train", "rel")
# The evaluation split is stored under the "dev" tag.
EVAL_QUESTION_PATH = _dataset_path("dev", "q")
EVAL_DOC_PATH = _dataset_path("dev", "doc")
EVAL_REL_PATH = _dataset_path("dev", "rel")
TEST_QUESTION_PATH = _dataset_path("test", "q")
TEST_DOC_PATH = _dataset_path("test", "doc")
TEST_REL_PATH = _dataset_path("test", "rel")

# Model identifier handed to SentenceTransformer when loading the baseline.
BASE_MODEL = "thenlper/gte-base"

In [3]:
# Read the three training-split frames (questions, documents, relevance pairs)
# in one pass; the reads happen in the order the paths are listed.
train_q, train_doc, train_rel = (
    pd.read_parquet(parquet_path)
    for parquet_path in (TRAIN_QUESTION_PATH, TRAIN_DOC_PATH, TRAIN_REL_PATH)
)
# Preview the first rows of the relevance pairs as the cell output.
train_rel.head()

Unnamed: 0,query_id,doc_id,relevance
0,0,18850,1.0
1,4,196463,1.0
2,5,69306,1.0
3,6,560251,1.0
4,6,188530,1.0


In [4]:
# Read the three test-split frames (questions, documents, relevance pairs)
# in one pass; the reads happen in the order the paths are listed.
test_q, test_doc, test_rel = (
    pd.read_parquet(parquet_path)
    for parquet_path in (TEST_QUESTION_PATH, TEST_DOC_PATH, TEST_REL_PATH)
)
# Preview the first rows of the relevance pairs as the cell output.
test_rel.head()

Unnamed: 0,query_id,doc_id,relevance
0,8,566392,1.0
1,8,65404,1.0
2,15,325273,1.0
3,18,88124,1.0
4,26,285255,1.0


In [5]:
from sentence_transformers.evaluation import InformationRetrievalEvaluator


def to_input_examples(questions, documents, relations):
    """Yield one positive (query, document) example per relation row.

    Args:
        questions: DataFrame with ``query_id`` and ``query_text`` columns.
        documents: DataFrame with ``doc_id`` and ``doc_text`` columns.
        relations: DataFrame with ``query_id`` and ``doc_id`` columns linking
            queries to documents.

    Yields:
        ``st.InputExample`` with ``texts=[query_text, doc_text]`` and ``label=1``.

    Notes:
        Both joins are inner joins, so a relation row whose query or document
        is absent from the given frames produces no example. No other column
        of ``relations`` (e.g. ``relevance``) is consulted: every joined row is
        emitted with label 1.
    """
    # Join only the columns that are actually used. Merging the full frames
    # would suffix any other shared column (``_x``/``_y``) and can even break
    # the second join when ``questions`` happens to carry a ``doc_id`` column.
    positives = pd.merge(
        left=questions[["query_id", "query_text"]],
        right=relations[["query_id", "doc_id"]],
        on="query_id",
    )
    positives = pd.merge(
        left=positives,
        right=documents[["doc_id", "doc_text"]],
        on="doc_id",
    )
    # Walk the two text columns in lockstep instead of iterrows(), which
    # builds a full Series object for every row.
    for query_text, doc_text in zip(positives["query_text"], positives["doc_text"]):
        yield st.InputExample(texts=[query_text, doc_text], label=1)


def to_retrieval_evaluator(questions, documents, relations, **kwargs):
    """Build an ``InformationRetrievalEvaluator`` from three DataFrames.

    Args:
        questions: DataFrame with ``query_id`` and ``query_text`` columns.
        documents: DataFrame with ``doc_id`` and ``doc_text`` columns.
        relations: DataFrame with ``query_id`` and ``doc_id`` columns; each row
            marks the document as relevant for the query. No other column
            (e.g. ``relevance``) is consulted.
        **kwargs: Forwarded unchanged to ``InformationRetrievalEvaluator``.

    Returns:
        The evaluator, constructed with
        ``queries`` = ``{str(query_id): query_text}``,
        ``corpus`` = ``{str(doc_id): doc_text}`` and
        ``relevant_docs`` = ``{str(query_id): set of str(doc_id)}``.

    Notes:
        All ids are converted with ``str`` so the three mappings share the same
        key type. If an id occurs more than once in ``questions`` or
        ``documents``, the last row wins. ``relations`` is not modified.
    """
    # Column-wise construction avoids iterrows(), which materialises a Series
    # per row and is slow on large corpora.
    q_dict = dict(zip(map(str, questions["query_id"]), questions["query_text"]))
    doc_dict = dict(zip(map(str, documents["doc_id"]), documents["doc_text"]))
    # Collect each query's relevant doc ids straight into a set rather than
    # concatenating single-element lists through groupby().sum().
    rel_dict = {
        str(query_id): {str(doc_id) for doc_id in doc_ids}
        for query_id, doc_ids in relations.groupby("query_id")["doc_id"]
    }
    return InformationRetrievalEvaluator(queries=q_dict, corpus=doc_dict, relevant_docs=rel_dict, **kwargs)

# Sanity check: print only the first training example. Pulling a single item
# from the generator stops after one yield; nothing is printed if it is empty.
first_example = next(to_input_examples(train_q, train_doc, train_rel), None)
if first_example is not None:
    print(first_example)

<InputExample> label: 1, texts: What is considered a business expense on a business trip?; The IRS Guidance pertaining to the subject.  In general the best I can say is your business expense may be deductible.  But it depends on the circumstances and what it is you want to deduct. Travel Taxpayers who travel away from home on business may deduct related   expenses, including the cost of reaching their destination, the cost   of lodging and meals and other ordinary and necessary expenses.   Taxpayers are considered “traveling away from home” if their duties   require them to be away from home substantially longer than an   ordinary day’s work and they need to sleep or rest to meet the demands   of their work. The actual cost of meals and incidental expenses may be   deducted or the taxpayer may use a standard meal allowance and reduced   record keeping requirements. Regardless of the method used, meal   deductions are generally limited to 50 percent as stated earlier.    Only actual cos

In [6]:
# Baseline: score the model loaded from BASE_MODEL on the test split.
evaluator_options = dict(
    show_progress_bar=True,
    ndcg_at_k=[1, 3, 5, 10],
    mrr_at_k=[1, 3, 5, 10],
)
test_set_evaluator = to_retrieval_evaluator(test_q, test_doc, test_rel, **evaluator_options)
base_model = st.SentenceTransformer(model_name_or_path=BASE_MODEL)
# NOTE(review): `compute_metrices` is spelled as in the original call — confirm
# against the installed sentence-transformers version before renaming it.
base_scores = test_set_evaluator.compute_metrices(model=base_model)
# Show the 'cos_sim' entry (presumably the cosine-similarity scores) as a table.
pd.DataFrame(base_scores["cos_sim"])

Batches:   0%|          | 0/21 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:15<00:00, 15.13s/it]


Unnamed: 0,accuracy@k,precision@k,recall@k,ndcg@k,mrr@k,map@k
1,0.695988,0.695988,0.187505,0.695988,0.695988,
3,0.807099,0.584362,0.440857,0.663874,0.744084,
5,0.851852,0.473765,0.552121,0.652049,0.754347,
10,0.890432,0.311728,0.660335,0.658487,0.759256,
100,,,,,,0.600996
