# Fine tuning

### Load pretrained model

In [6]:
from sentence_transformers import SentenceTransformer

In [7]:
model_id = "BAAI/bge-small-en"
model = SentenceTransformer(model_id)



Downloading pytorch_model.bin:   0%|                                                        | 0.00/134M [00:00<?, ?B/s][A[A

Downloading pytorch_model.bin:   8%|███▊                                            | 10.5M/134M [00:25<05:03, 406kB/s][A[A

Downloading pytorch_model.bin:   8%|███▊                                            | 10.5M/134M [00:41<05:03, 406kB/s][A[A

Downloading pytorch_model.bin:  16%|███████▌                                        | 21.0M/134M [00:59<05:24, 347kB/s][A[A

Downloading pytorch_model.bin:  16%|███████▌                                        | 21.0M/134M [01:11<05:24, 347kB/s][A[A

Downloading pytorch_model.bin:  24%|███████████▎                                    | 31.5M/134M [01:32<05:09, 330kB/s][A[A

Downloading pytorch_model.bin:  24%|███████████▎                                    | 31.5M/134M [01:51<05:09, 330kB/s][A[A

Downloading pytorch_model.bin:  31%|███████████████                                 | 41.9M/134M [01:59<04:19

In [8]:
model

SentenceTransformer(
  (0): Transformer({'max_seq_length': 512, 'do_lower_case': True}) with Transformer model: BertModel 
  (1): Pooling({'word_embedding_dimension': 384, 'pooling_mode_cls_token': True, 'pooling_mode_mean_tokens': False, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False})
  (2): Normalize()
)

### Define dataloader

In [9]:
import json
import os 

from torch.utils.data import DataLoader
from sentence_transformers import InputExample

In [36]:
TRAIN_DATASET_FPATH = os.getcwd() + '/finetune_data/train_dataset.json'
VAL_DATASET_FPATH = os.getcwd() + '/finetune_data/val_dataset.json'

# We use a very small batchsize to run this toy example on a local machine. 
# This should typically be much larger. 
BATCH_SIZE = 10

In [39]:
with open(TRAIN_DATASET_FPATH, 'r+') as f:
    train_dataset = json.load(f)

with open(VAL_DATASET_FPATH, 'r+') as f:
    val_dataset = json.load(f)

In [40]:
dataset = train_dataset

corpus = dataset['corpus']
queries = dataset['queries']
relevant_docs = dataset['relevant_docs']

examples = []
for query_id, query in queries.items():
    node_id = relevant_docs[query_id][0]
    text = corpus[node_id]
    example = InputExample(texts=[query, text])
    examples.append(example)

In [41]:
loader = DataLoader(
    examples, batch_size=BATCH_SIZE
)

### Define loss

In [42]:
# https://www.sbert.net/docs/package_reference/losses.html#multiplenegativesrankingloss
from sentence_transformers import losses

In [43]:
loss = losses.MultipleNegativesRankingLoss(model)

### Define evaluator 

In [44]:
from sentence_transformers.evaluation import InformationRetrievalEvaluator

In [45]:
dataset = val_dataset

corpus = dataset['corpus']
queries = dataset['queries']
relevant_docs = dataset['relevant_docs']

evaluator = InformationRetrievalEvaluator(queries, corpus, relevant_docs)

### Run training 

In [46]:
EPOCHS = 2

In [51]:
import mlflow

# Definition of of callbak should be after model init
class MLFlowCallback:
    def __init__(self, model):
        self.model = model
    
    def __call__(self, score, epoch, steps) -> None:
        print(self.model)
        print(score, epoch, steps)
        mlflow.log_metric('score', score)
        # https://mlflow.org/docs/latest/tracking/artifacts-stores.html

mlflow_callback = MLFlowCallback(model)
           
warmup_steps = int(len(loader) * EPOCHS * 0.1)

with mlflow.start_run():
    model.fit(
        train_objectives=[(loader, loss)],
        epochs=EPOCHS,
        warmup_steps=warmup_steps,
        output_path='exp_finetune',
        show_progress_bar=True,
        evaluator=evaluator, 
        evaluation_steps=50,
        callback=mlflow_callback 
    )

Epoch:   0%|                                                                                     | 0/2 [00:00<?, ?it/s]
Iteration:   0%|                                                                                | 0/67 [00:00<?, ?it/s][A
Iteration:   1%|█                                                                       | 1/67 [00:18<20:17, 18.45s/it][A
Iteration:   3%|██▏                                                                     | 2/67 [00:38<20:45, 19.16s/it][A
Iteration:   4%|███▏                                                                    | 3/67 [00:58<20:51, 19.56s/it][A
Iteration:   6%|████▎                                                                   | 4/67 [01:17<20:20, 19.38s/it][A
Iteration:   7%|█████▎                                                                  | 5/67 [01:35<19:44, 19.11s/it][A
Iteration:   9%|██████▍                                                                 | 6/67 [01:54<19:18, 18.99s/it][A
Iteration:  10%|███

SentenceTransformer(
  (0): Transformer({'max_seq_length': 512, 'do_lower_case': True}) with Transformer model: BertModel 
  (1): Pooling({'word_embedding_dimension': 384, 'pooling_mode_cls_token': True, 'pooling_mode_mean_tokens': False, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False})
  (2): Normalize()
)
0.7049405523677822 0 50



Iteration:  75%|████████████████████████████████████████████████████▉                  | 50/67 [19:31<08:57, 31.64s/it][A
Iteration:  76%|██████████████████████████████████████████████████████                 | 51/67 [19:51<08:16, 31.01s/it][A
Iteration:  78%|███████████████████████████████████████████████████████                | 52/67 [20:10<07:35, 30.35s/it][A
Iteration:  79%|████████████████████████████████████████████████████████▏              | 53/67 [20:25<06:53, 29.53s/it][A
Iteration:  81%|█████████████████████████████████████████████████████████▏             | 54/67 [20:41<06:14, 28.83s/it][A
Iteration:  82%|██████████████████████████████████████████████████████████▎            | 55/67 [20:59<05:39, 28.27s/it][A
Iteration:  84%|███████████████████████████████████████████████████████████▎           | 56/67 [21:17<05:05, 27.73s/it][A
Iteration:  85%|████████████████████████████████████████████████████████████▍          | 57/67 [21:33<04:31, 27.10s/it][A
Iteration:  87%

SentenceTransformer(
  (0): Transformer({'max_seq_length': 512, 'do_lower_case': True}) with Transformer model: BertModel 
  (1): Pooling({'word_embedding_dimension': 384, 'pooling_mode_cls_token': True, 'pooling_mode_mean_tokens': False, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False})
  (2): Normalize()
)
0.7104460850199539 0 -1


Epoch:  50%|█████████████████████████████████████▌                                     | 1/2 [28:28<28:28, 1708.18s/it]
Iteration:   0%|                                                                                | 0/67 [00:00<?, ?it/s][A
Iteration:   1%|█                                                                       | 1/67 [00:18<20:33, 18.70s/it][A
Iteration:   3%|██▏                                                                     | 2/67 [00:38<20:48, 19.20s/it][A
Iteration:   4%|███▏                                                                    | 3/67 [00:58<20:46, 19.48s/it][A
Iteration:   6%|████▎                                                                   | 4/67 [01:17<20:16, 19.30s/it][A
Iteration:   7%|█████▎                                                                  | 5/67 [01:35<19:42, 19.07s/it][A
Iteration:   9%|██████▍                                                                 | 6/67 [01:54<19:17, 18.97s/it][A
Iteration:  10%|███

KeyboardInterrupt: 

In [60]:
### llamaindex

In [78]:
from llama_index import ServiceContext, VectorStoreIndex
from llama_index.schema import TextNode
from llama_index.embeddings import HuggingFaceEmbedding# OpenAIEmbedding

MODEL_PATH = os.getcwd() + r'\exp_finetune'
embed_model = HuggingFaceEmbedding(MODEL_PATH)

print(embed_model)

model_name='C:\\Users\\tempdelta\\Desktop\\temp_l\\exp_finetune' embed_batch_size=10 callback_manager=<llama_index.callbacks.base.CallbackManager object at 0x000001F2820750D0> tokenizer_name='C:\\Users\\tempdelta\\Desktop\\temp_l\\exp_finetune' max_length=512 pooling=<Pooling.CLS: 'cls'> normalize=True query_instruction=None text_instruction=None cache_folder=None


In [92]:
# https://github.com/run-llama/llama_index/issues/10051
top_k = 5
service_context = ServiceContext.from_defaults(embed_model=embed_model, llm=None)

nodes = [TextNode(id_=id_, text=text) for id_, text in corpus.items()] 
index = VectorStoreIndex(
    nodes, 
    service_context=service_context, 
    show_progress=True
)
retriever = index.as_retriever(similarity_top_k=top_k)

LLM is explicitly disabled. Using MockLLM.






Generating embeddings:   0%|                                                                   | 0/395 [00:00<?, ?it/s][A[A[A[A



Generating embeddings:   3%|█▍                                                        | 10/395 [00:07<04:53,  1.31it/s][A[A[A[A



Generating embeddings:   5%|██▉                                                       | 20/395 [00:15<04:51,  1.29it/s][A[A[A[A

KeyboardInterrupt: 

# 
##index.storage_context.persist(persist_dir="./storage")
##
loaded_index = load_index_from_disk(StorageContext.from_defaults(persist_dir="./storage"))