# Fine-tuning

### Load pretrained model

In [52]:
from sentence_transformers import SentenceTransformer

# Identifier of the base embedding model used throughout the notebook.
model_id = "BAAI/bge-small-en"

# Base (not fine-tuned) model; evaluated at the end of the notebook as the baseline.
pretrained_model = SentenceTransformer(model_name_or_path=model_id)

In [53]:
# Show the module stack of the loaded model (its repr is the cell output).
pretrained_model

SentenceTransformer(
  (0): Transformer({'max_seq_length': 512, 'do_lower_case': True}) with Transformer model: BertModel 
  (1): Pooling({'word_embedding_dimension': 384, 'pooling_mode_cls_token': True, 'pooling_mode_mean_tokens': False, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False})
)

### Load data

In [28]:
import pickle

In [29]:
# Read the pickled training set from disk.
# NOTE(review): unpickling executes arbitrary code -- only load files you trust.
with open('./dataset.pkl', mode='rb') as dataset_file:
    dataset = pickle.load(dataset_file)

In [48]:
# `InputExample` was never imported anywhere in the notebook, so this cell
# raised NameError on a fresh kernel; import it next to its only use.
from sentence_transformers import InputExample

# One (question, text) pair per training record. Each record in `dataset`
# is indexed by the 'question' and 'text' keys.
examples = [
    InputExample(texts=[record['question'], record['text']])
    for record in dataset
]

In [58]:
# from sentence_transformers import datasets
from torch.utils.data import DataLoader

batch_size = 10

# Shuffle every epoch: the loss used below (MultipleNegativesRankingLoss)
# treats the other pairs in a batch as negatives, so batches should not be
# fixed by the on-disk order of `examples`.
loader = DataLoader(
    examples,
    batch_size=batch_size,
    shuffle=True,
)

### Define loss

In [59]:
from sentence_transformers import losses

# `model` was not defined anywhere above this cell (it only existed as
# leftover kernel state), so create the trainable model here. It is a
# separate instance from `pretrained_model`, which must stay untouched so it
# can serve as the baseline in the evaluation section.
model = SentenceTransformer(model_id)

# Ranking loss over (question, text) pairs, bound to the model being trained.
loss = losses.MultipleNegativesRankingLoss(model)

### Train loop

In [60]:
epochs = 1
# Warm up the learning rate over the first 10% of all training steps.
warmup_steps = int(len(loader) * epochs * 0.1)

# Fine-tune `model` on the (loader, loss) objective.
# The checkpoint is written to './test_ckpt' because that is the directory
# the "Load local trained model" cell reads from; the previous value
# ('mpnet-mnr-squad2') did not match, so a fresh run evaluated a missing or
# stale checkpoint.
model.fit(
    train_objectives=[(loader, loss)],
    epochs=epochs,
    warmup_steps=warmup_steps,
    output_path='./test_ckpt',
    show_progress_bar=True,
)

Epoch:   0%|          | 0/1 [00:00<?, ?it/s]

Iteration:   0%|          | 0/20 [00:00<?, ?it/s]

### Load the locally fine-tuned model

In [54]:
from sentence_transformers import SentenceTransformer
# Rebind `model` to the checkpoint stored in the local './test_ckpt' directory.
model = SentenceTransformer(model_name_or_path='./test_ckpt')

### Evaluate

In [41]:
import pickle

In [42]:
# Read the pickled test nodes; they are used below to build the retrieval corpus.
# NOTE(review): unpickling executes arbitrary code -- only load files you trust.
with open('./nodes_pojo_test.pkl', mode='rb') as nodes_file:
    nodes = pickle.load(nodes_file)

In [44]:
# Read the pickled test set. This rebinds `dataset`, replacing the training
# data loaded earlier in the notebook.
with open('./dataset_test.pkl', mode='rb') as test_file:
    dataset = pickle.load(test_file)

In [45]:
# Query id (position of the record in `dataset`) -> question text.
queries = {idx: record['question'] for idx, record in enumerate(dataset)}

In [46]:
# Node id -> node text. Entries of `nodes` that are not dicts are skipped.
corpus = {
    entry['node_id']: entry['text']
    for entry in nodes
    if isinstance(entry, dict)
}

In [47]:
# Query id -> single-element list holding the node id stored on that record.
# Ids are the same enumerate() positions used as keys of `queries`.
relevant_docs = {
    idx: [record['node_id']]
    for idx, record in enumerate(dataset)
}

In [48]:
from sentence_transformers.evaluation import InformationRetrievalEvaluator

In [49]:
# Build the retrieval evaluator from the three mappings prepared above.
# NOTE(review): the name `eval` shadows Python's builtin; it is kept because
# the following cells call it under this name.
eval = InformationRetrievalEvaluator(
    queries=queries,
    corpus=corpus,
    relevant_docs=relevant_docs,
)

In [55]:
# Run the evaluator on `model`, the checkpoint loaded from './test_ckpt'.
eval(model)

0.7280796315117085

In [56]:
# Run the same evaluator on `pretrained_model` (loaded from `model_id`) for comparison.
eval(pretrained_model)

0.6919232861022638