In [None]:
from sentence_transformers import SentenceTransformer
import json
from sentence_transformers import losses
from torch.utils.data import DataLoader
from sentence_transformers import InputExample
from sentence_transformers.evaluation import InformationRetrievalEvaluator
import json
from tqdm.notebook import tqdm
import pandas as pd
from llama_index.core import ServiceContext, VectorStoreIndex
from llama_index.core.schema import TextNode
# from llama_index.embeddings import OpenAIEmbedding

## queries: Dict[str, str],  # qid => query
## corpus: Dict[str, str],  # cid => doc
## relevant_docs: Dict[str, Set[str]],  # qid => Set[cid]

In [None]:
# Base embedding checkpoint that will be fine-tuned below.
model_id = "BAAI/bge-small-en"
model = SentenceTransformer(model_name_or_path=model_id)

In [None]:
# Deliberately tiny batch size so this toy example fits on a local machine;
# a real training run would normally use a much larger value.
BATCH_SIZE = 10

# Locations of the training / validation dataset JSON files.
TRAIN_DATASET_FPATH = './data/train_dataset.json'
VAL_DATASET_FPATH = './data/val_dataset.json'

In [None]:
# Load the training and validation splits from disk.
# The files are only read here, so they are opened read-only: 'r+' would also
# request write access and fail on read-only files or mounts.
# The encoding is pinned to UTF-8 (the standard JSON encoding) rather than
# relying on the platform default.
with open(TRAIN_DATASET_FPATH, 'r', encoding='utf-8') as f:
    train_dataset = json.load(f)

with open(VAL_DATASET_FPATH, 'r', encoding='utf-8') as f:
    val_dataset = json.load(f)

In [None]:
# Turn the training split into (query, passage) pairs for training.
dataset = train_dataset

queries = dataset['queries']
corpus = dataset['corpus']
relevant_docs = dataset['relevant_docs']

# Each query is paired with the text of the FIRST document id listed as
# relevant to it; any additional relevant ids are not used.
examples = [
    InputExample(texts=[question, corpus[relevant_docs[qid][0]]])
    for qid, question in queries.items()
]

In [None]:
# Batches of training pairs fed to model.fit().
# shuffle=True re-draws the batch composition every epoch. The loss used in
# this notebook (MultipleNegativesRankingLoss) takes its negatives from the
# other pairs in the same batch, so fixed, file-ordered batches would present
# the same negatives each epoch.
# NOTE(review): if consecutive queries in the dataset share a source document,
# unshuffled batches would also be more likely to contain false negatives.
# Confirm against how the dataset was generated.
loader = DataLoader(
    examples,
    batch_size=BATCH_SIZE,
    shuffle=True,
)

In [None]:

# Training objective attached to the model being fine-tuned.
loss = losses.MultipleNegativesRankingLoss(model=model)

In [None]:

# Build the retrieval evaluator from the validation split.
dataset = val_dataset

corpus = dataset['corpus']
queries = dataset['queries']
relevant_docs = dataset['relevant_docs']

# The evaluator's relevant_docs argument is documented as qid => Set[cid], but
# values loaded from JSON cannot be sets. Convert them so the argument matches
# that contract; this also collapses any duplicated ids.
# `relevant_docs` itself is left as loaded so other cells see the same object.
# NOTE(review): assumes each value is a list of ids, as in the training split.
relevant_doc_sets = {
    query_id: set(doc_ids) for query_id, doc_ids in relevant_docs.items()
}

evaluator = InformationRetrievalEvaluator(queries, corpus, relevant_doc_sets)

In [None]:
# Training schedule: the first 10% of all optimisation steps are warm-up.
EPOCHS = 2
total_steps = len(loader) * EPOCHS
warmup_steps = int(total_steps * 0.1)

# Fine-tune the model, running the validation evaluator every 50 steps and
# writing output under 'exp_finetune'.
model.fit(
    train_objectives=[(loader, loss)],
    evaluator=evaluator,
    evaluation_steps=50,
    epochs=EPOCHS,
    warmup_steps=warmup_steps,
    show_progress_bar=True,
    output_path='exp_finetune',
)

In [None]:
# Persist the fine-tuned model to the local "models" directory.
model.save(path="models")