In [None]:
!pip install datasets


Collecting datasets
  Downloading datasets-3.5.0-py3-none-any.whl.metadata (19 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py311-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2024.12.0,>=2023.1.0 (from fsspec[http]<=2024.12.0,>=2023.1.0->datasets)
  Downloading fsspec-2024.12.0-py3-none-any.whl.metadata (11 kB)
Downloading datasets-3.5.0-py3-none-any.whl (491 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m491.2/491.2 kB[0m [31m9.4 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m9.5 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading fsspec-2024.12.0-py3-none-any.wh

In [None]:
import os
os.environ["WANDB_DISABLED"] = "true"


In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import json
from datasets import load_dataset
from transformers import (
    AutoConfig,
    AutoModel,
    get_linear_schedule_with_warmup,
    TrainerCallback
)
from torch.optim import SGD, RMSprop, AdamW
from sentence_transformers import (
    models,
    SentenceTransformer,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
    SentenceTransformerModelCardData,
)
from sentence_transformers.losses import TripletLoss, MultipleNegativesRankingLoss
from sentence_transformers.training_args import BatchSamplers
from sentence_transformers.evaluation import TripletEvaluator, SimilarityFunction


In [None]:
big_patent_dataset = load_dataset("json", data_files="/content/dataset_big_patent_v2.json")


Generating train split: 0 examples [00:00, ? examples/s]

In [None]:
big_patent_dataset


DatasetDict({
    train: Dataset({
        features: ['query', 'pos', 'neg'],
        num_rows: 499
    })
})

In [None]:
split_dataset = big_patent_dataset["train"].train_test_split(test_size=0.3)

# Second split: Temp into Test (10%) and Validation (10%)
temp_split = split_dataset["test"].train_test_split(test_size=0.5)

# Merge into final dataset
final_dataset = {
    "train": split_dataset["train"],
    "test": temp_split["test"],
    "validation": temp_split["train"],
}


In [None]:
train_dataset = final_dataset["train"]
test_dataset = final_dataset["test"]
eval_dataset = final_dataset["validation"]


In [None]:
checkpoint = "OrdalieTech/Solon-embeddings-large-0.1"
config = AutoConfig.from_pretrained(checkpoint, trust_remote_code=True)
config.hidden_dropout_prob = 0.5
config.attention_probs_dropout_prob = 0.5

hf_model = AutoModel.from_pretrained(checkpoint, config=config, trust_remote_code=True)

transformer_model = models.Transformer(checkpoint)  # Initialize with the checkpoint
transformer_model.auto_model = hf_model

pooling_model = models.Pooling(transformer_model.get_word_embedding_dimension())
model = SentenceTransformer(modules=[transformer_model, pooling_model], trust_remote_code=True)

# # for param in model[0].auto_model.encoder.parameters():
# #     param.requires_grad = False

# model = SentenceTransformer(checkpoint, trust_remote_code=True)

loss = MultipleNegativesRankingLoss(model)


config.json:   0%|          | 0.00/704 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/2.24G [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/1.17k [00:00<?, ?B/s]

sentencepiece.bpe.model:   0%|          | 0.00/5.07M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.1M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/964 [00:00<?, ?B/s]

In [None]:
num_epochs = 8
num_train_steps = len(train_dataset) * num_epochs
num_warmup_steps = int(0.1 * num_train_steps)
#trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = AdamW(params=model.parameters(), lr=2e-5, weight_decay=0.03)

lr_scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=num_warmup_steps,
    num_training_steps=num_train_steps
)

optimizers = (optimizer, lr_scheduler)


In [10]:
args = SentenceTransformerTrainingArguments(
    # Required parameter:
    output_dir="models/OrdalieTech/Solon-embeddings-large-0.1",
    # Optional training parameters:
    num_train_epochs=num_epochs,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    warmup_ratio=0.1,
    #fp16=True,  # Set to False if your GPU can't handle FP16
    #bf16=False,  # Set to True if your GPU supports BF16
    batch_sampler=BatchSamplers.NO_DUPLICATES,  # Losses using "in-batch negatives" benefit from no duplicates
    # Optional tracking/debugging parameters:
    eval_strategy="steps",
    eval_steps=50,
    save_strategy="steps",
    save_steps=50,
    save_total_limit=2,
    logging_steps=10,
    #run_name="mpnet-base-all-nli-triplet",  # Used in W&B if `wandb` is installed
)



Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


In [None]:
dev_evaluator = TripletEvaluator(
    anchors=eval_dataset["query"],
    positives=eval_dataset["pos"],
    negatives=eval_dataset["neg"],
    #main_similarity_function=SimilarityFunction.COSINE,
    name='eval'
)

dev_evaluator(model)


{'eval_cosine_accuracy': 0.7066666483879089}

In [None]:
train_evaluator = TripletEvaluator(
        anchors=train_dataset["query"],
        positives=train_dataset["pos"],
        negatives=train_dataset["neg"],
        #main_similarity_function=SimilarityFunction.COSINE,
        name='train'
    )


In [None]:
# 7. Create a trainer & train
trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    loss=loss,
    evaluator=[train_evaluator, dev_evaluator],
    optimizers=optimizers,
)
trainer.train()


Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


Computing widget examples:   0%|          | 0/1 [00:00<?, ?example/s]

Step,Training Loss,Validation Loss,Cosine Accuracy,Train Cosine Accuracy,Sequential Score
10,3.547,0.871379,0.706667,0.727794,0.706667
20,3.513,0.871643,0.706667,0.727794,0.706667
30,3.4189,0.872516,0.706667,0.724928,0.706667
40,3.5186,0.872424,0.706667,0.724928,0.706667
50,3.4639,0.872955,0.706667,0.724928,0.706667
60,3.4945,0.873643,0.706667,0.727794,0.706667
70,3.4554,0.876967,0.706667,0.727794,0.706667
80,3.4273,0.878633,0.72,0.724928,0.72
90,3.429,0.877739,0.72,0.722063,0.72
100,3.4658,0.881177,0.72,0.724928,0.72


TrainOutput(global_step=176, training_loss=3.4540568698536265, metrics={'train_runtime': 673.3995, 'train_samples_per_second': 4.146, 'train_steps_per_second': 0.261, 'total_flos': 0.0, 'train_loss': 3.4540568698536265, 'epoch': 8.0})

In [None]:
test_evaluator = TripletEvaluator(
    anchors=test_dataset["query"],
    positives=test_dataset["pos"],
    negatives=test_dataset["neg"],
    #main_similarity_function=SimilarityFunction.COSINE,
    name="test",
)
test_evaluator(model)


{'test_cosine_accuracy': 0.7333333492279053}

In [None]:
model.eval()


SentenceTransformer(
  (0): Transformer({'max_seq_length': 512, 'do_lower_case': False}) with Transformer model: XLMRobertaModel 
  (1): Pooling({'word_embedding_dimension': 1024, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False, 'pooling_mode_weightedmean_tokens': False, 'pooling_mode_lasttoken': False, 'include_prompt': True})
)

## References
- https://huggingface.co/blog/train-sentence-transformers
- https://www.marqo.ai/course/training-fine-tuning-sentence-transformers
- https://dagshub.com/blog/how-to-train-a-custom-llm-embedding-model/
- https://medium.com/@aisagescribe/multiple-negative-ranking-loss-mnrl-explained-5b4741e38d8f