# Fine-tune Text Embeddings for AI Job Search

Code authored by: Shaw Talebi

[Video link](https://youtu.be/hOLBrIjRAj4) | [Blog link](https://shawhin.medium.com/fine-tuning-text-embeddings-f913b882b11c)<br>
[Dataset](https://huggingface.co/datasets/shawhin/ai-job-embedding-finetuning) | [Fine-tuned Model](https://huggingface.co/shawhin/distilroberta-ai-job-embeddings) <br>
Based on the [Sentence Transformers training overview example](https://sbert.net/docs/sentence_transformer/training_overview.html#trainer)

### imports

In [1]:
from datasets import load_dataset
from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
)
from sentence_transformers.losses import MultipleNegativesRankingLoss
from sentence_transformers.training_args import BatchSamplers
from sentence_transformers.evaluation import TripletEvaluator

### import pre-trained model

Various base models are ranked in the [Sentence Transformers training overview](https://sbert.net/docs/sentence_transformer/training_overview.html#best-base-embedding-models).

In [2]:
model_name = "sentence-transformers/all-distilroberta-v1"  # acc = 0.88 (validation cosine accuracy before fine-tuning)
model = SentenceTransformer(model_name)

# other base models
# model_name = "microsoft/mpnet-base" # acc = 0.57
# model_name = "sentence-transformers/msmarco-bert-base-dot-v5" # acc = 0.09
# model_name = "sentence-transformers/msmarco-distilbert-dot-v5" # acc = 0.13

### load dataset

In [3]:
dataset = load_dataset("shawhin/ai-job-embedding-finetuning")

### evaluate pre-trained model on eval data

In [4]:
evaluator_valid = TripletEvaluator(
    anchors=dataset["validation"]["query"],
    positives=dataset["validation"]["job_description_pos"],
    negatives=dataset["validation"]["job_description_neg"],
    name="ai-job-validation",
)
evaluator_valid(model)

{'ai-job-validation_cosine_accuracy': np.float64(0.8811881188118812)}
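The `cosine_accuracy` above is the fraction of triplets in which the query embedding is closer (by cosine similarity) to the positive job description than to the negative one. A minimal, dependency-free sketch of that calculation follows; the 2-D vectors are made-up toy embeddings, and the helper names `cosine_similarity` and `triplet_cosine_accuracy` are hypothetical, not part of `sentence_transformers`.

```python
import math

def cosine_similarity(u, v):
    """Cosine of the angle between two vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)

def triplet_cosine_accuracy(anchors, positives, negatives):
    """Share of triplets where sim(anchor, positive) > sim(anchor, negative)."""
    correct = sum(
        cosine_similarity(a, p) > cosine_similarity(a, n)
        for a, p, n in zip(anchors, positives, negatives)
    )
    return correct / len(anchors)

# Toy embeddings (assumed values, for illustration only)
anchors   = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
positives = [[0.9, 0.1], [0.1, 0.9], [1.0, -1.0]]  # third positive is a miss
negatives = [[0.0, 1.0], [1.0, 0.0], [1.0, 0.9]]

accuracy = triplet_cosine_accuracy(anchors, positives, negatives)
print(accuracy)  # two of the three triplets are ranked correctly
```

The reported 0.8812 is the value this fraction would take for 89 correct triplets out of 101, although the size of the validation split is not stated here.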

### define loss function

In [5]:
loss = MultipleNegativesRankingLoss(model)
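To give some intuition for what this loss optimizes: for each query in a batch, the matching job description should score higher than every other candidate in the batch. The sketch below is a plain-Python approximation of that idea (cross-entropy over scaled cosine similarities, with the library's default scale of 20), not the library implementation. The function name `mnr_loss` and the toy vectors are assumptions made for illustration.

```python
import math

def cosine_similarity(u, v):
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))

def mnr_loss(anchors, positives, negatives=(), scale=20.0):
    """Mean cross-entropy where anchor i must pick positive i among all candidates.

    Candidates are every positive in the batch plus any explicit negatives,
    so the other samples' positives act as in-batch negatives.
    """
    candidates = list(positives) + list(negatives)
    losses = []
    for i, anchor in enumerate(anchors):
        logits = [scale * cosine_similarity(anchor, c) for c in candidates]
        # Numerically stable log-sum-exp
        m = max(logits)
        log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
        losses.append(log_sum_exp - logits[i])  # -log softmax at the true index
    return sum(losses) / len(losses)

# Toy batch of two queries (assumed values)
anchors = [[1.0, 0.0], [0.0, 1.0]]
aligned = [[1.0, 0.1], [0.1, 1.0]]   # each positive points toward its own anchor
swapped = [[0.1, 1.0], [1.0, 0.1]]   # each positive points toward the other anchor

loss_aligned = mnr_loss(anchors, aligned)
loss_swapped = mnr_loss(anchors, swapped)
print(loss_aligned, loss_swapped)  # the aligned batch gives the much smaller loss
```

Because every other candidate in the batch is treated as a negative, a duplicated sample would be penalized as if it were wrong, which is why the training arguments below use a no-duplicates batch sampler.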

### define training args

In [6]:
num_epochs = 1
batch_size = 16
lr = 2e-5
finetuned_model_name = "distilroberta-ai-job-embeddings"

train_args = SentenceTransformerTrainingArguments(
    output_dir=f"models/{finetuned_model_name}",
    num_train_epochs=num_epochs,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    learning_rate=lr,
    warmup_ratio=0.1,
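    # MultipleNegativesRankingLoss treats the other samples in a batch as negatives,
    # so duplicate samples in a batch would act as false negatives
    batch_sampler=BatchSamplers.NO_DUPLICATES,
    eval_strategy="steps",
    # NOTE: this run takes only 51 steps in total (see the training output), so
    # evaluating and logging every 100 steps never triggers; lower both values
    # to see intermediate metrics
    eval_steps=100,
    logging_steps=100,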
)
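As a rough illustration of what `learning_rate=2e-5` and `warmup_ratio=0.1` imply, the sketch below computes the learning rate per step for a linear-warmup, linear-decay schedule. It assumes the trainer's default linear scheduler and uses the 51 total steps reported by the training output later in this notebook; `linear_schedule_lr` is a hypothetical helper, not a library function.

```python
import math

def linear_schedule_lr(step, base_lr, total_steps, warmup_ratio):
    """Learning rate at `step` for linear warmup followed by linear decay to zero."""
    warmup_steps = math.ceil(total_steps * warmup_ratio)
    if step < warmup_steps:
        # Ramp up from 0 to base_lr over the warmup steps
        return base_lr * step / max(1, warmup_steps)
    # Decay linearly from base_lr to 0 over the remaining steps
    remaining = max(1, total_steps - warmup_steps)
    return base_lr * max(0.0, (total_steps - step) / remaining)

base_lr = 2e-5
total_steps = 51      # taken from the training output of this run
warmup_ratio = 0.1

warmup_steps = math.ceil(total_steps * warmup_ratio)
schedule = [linear_schedule_lr(s, base_lr, total_steps, warmup_ratio)
            for s in range(total_steps + 1)]
print(warmup_steps)
print(max(schedule))
```

Under these assumptions the learning rate peaks at the end of warmup and then falls linearly to zero at the final step.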

### fine-tune model

In [7]:
%%time
trainer = SentenceTransformerTrainer(
    model=model,
    args=train_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["validation"],
    loss=loss,
    evaluator=evaluator_valid,
)
trainer.train()

CPU times: user 18.9 s, sys: 4.23 s, total: 23.2 s
Wall time: 37.3 s


TrainOutput(global_step=51, training_loss=0.8226340050790825, metrics={'train_runtime': 36.3893, 'train_samples_per_second': 22.232, 'train_steps_per_second': 1.402, 'total_flos': 0.0, 'train_loss': 0.8226340050790825, 'epoch': 1.0})
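The step count in this output can be sanity-checked with a little arithmetic. The sketch below assumes a single device, no gradient accumulation, and a training split of about 809 examples; that size is not stated in the notebook and is estimated here from the reported throughput (samples per second times runtime). It also checks which steps would trigger evaluation with `eval_steps=100`.

```python
import math

# Values reported in the TrainOutput above
train_samples_per_second = 22.232
train_runtime = 36.3893

# Values from the training arguments
batch_size = 16
num_epochs = 1
eval_steps = 100

# Assumption: dataset size estimated from throughput, not read from the dataset
num_examples = round(train_samples_per_second * train_runtime)

steps_per_epoch = math.ceil(num_examples / batch_size)
total_steps = steps_per_epoch * num_epochs

# Steps at which a step-based evaluation would run
eval_points = list(range(eval_steps, total_steps + 1, eval_steps))

print(num_examples, total_steps, eval_points)
```

With these numbers the estimate matches `global_step=51`, and no evaluation step falls within the run, so metrics only appear when the evaluator is called explicitly after training.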

### evaluate fine-tuned model

In [8]:
evaluator_test = TripletEvaluator(
    anchors=dataset["test"]["query"],
    positives=dataset["test"]["job_description_pos"],
    negatives=dataset["test"]["job_description_neg"],
    name="ai-job-test",
)
print("Validation:", evaluator_valid(model))
print("Test:", evaluator_test(model))

Validation: {'ai-job-validation_cosine_accuracy': np.float64(0.9900990099009901)}
Test: {'ai-job-test_cosine_accuracy': np.float64(1.0)}


### push fine-tuned model to HF hub

In [10]:
from huggingface_hub import notebook_login
notebook_login()

In [11]:
model.push_to_hub(f"shawhin/{finetuned_model_name}")

'https://huggingface.co/shawhin/distilroberta-ai-job-embeddings/commit/5158057316815c5f6415f7afe5060a5a5083d367'