# Training a Matryoshka Embedding Model 🪆

This notebook uses `MultipleNegativesRankingLoss` together with `MatryoshkaLoss` to train a strong embedding model at output dimensions `[768, 512, 256, 128, 64]` using Natural Language Inference datasets (`AllNLI` in this case).



> Colab by: [mrm8488](https://twitter.com/mrm8488), adapted from a [Sentence-Transformers](https://www.sbert.net/examples) example script

In [1]:
# Print the NVIDIA driver / GPU status table for the attached runtime.
! nvidia-smi

Sun Jun  2 14:08:51 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.05             Driver Version: 535.104.05   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  Tesla T4                       Off | 00000000:00:04.0 Off |                    0 |
| N/A   47C    P8               9W /  70W |      0MiB / 15360MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    

### Install required dependencies 📦

In [2]:
! pip install -q sentence-transformers datasets "accelerate>=0.21.0" wandb

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m224.7/224.7 kB[0m [31m3.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m542.0/542.0 kB[0m [31m9.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m302.6/302.6 kB[0m [31m9.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.7/6.7 MB[0m [31m18.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m9.9 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m194.1/194.1 kB[0m [31m6.9 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m134.8/134.8 kB[0m [31m7.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m207.3/207.3 kB[0m [31m18.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━

### Imports

In [3]:
from datasets import load_dataset
from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
    losses,
)
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator, SequentialEvaluator, SimilarityFunction
from sentence_transformers.training_args import BatchSamplers

### Set main variables ⚙️

In [4]:
# Base checkpoint to fine-tune; swap in any model you want.
model_name = "distilroberta-base"

# Output dimensions the embedding model is trained and evaluated at.
matryoshka_dims = [768, 512, 256, 128, 64]

# Larger batches usually give better results, but need more GPU memory.
batch_size = 128
num_train_epochs = 1

In [5]:
# Directory the run writes to. Any "/" in the model id is replaced with "-" so the
# model name contributes a single path component.
output_dir = "output/matryoshka_nli_{}_{}_bs_{}_e".format(
    model_name.replace("/", "-"), batch_size, num_train_epochs
)

In [6]:
# 1. Build the SentenceTransformer from `model_name`. If the checkpoint is not already a
# Sentence Transformer model, one is created automatically with "mean" pooling.
model = SentenceTransformer(model_name)
# Optionally, cap the maximum sequence length of the model:
# model.max_seq_length = 75



config.json:   0%|          | 0.00/480 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/331M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/25.0 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

### Load the Dataset 📚

In [7]:
# 2. Fetch the "triplet" subset of AllNLI (https://huggingface.co/datasets/sentence-transformers/all-nli):
# the "train" split becomes the training set and the "dev" split the evaluation set.
train_dataset, eval_dataset = (
    load_dataset("sentence-transformers/all-nli", "triplet", split=split_name)
    for split_name in ("train", "dev")
)

Downloading readme:   0%|          | 0.00/5.15k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/38.4M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/782k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/810k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/557850 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/6584 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/6609 [00:00<?, ? examples/s]

In [8]:
# Show the dataset summary together with its first training example.
train_dataset, train_dataset[0]

(Dataset({
     features: ['anchor', 'positive', 'negative'],
     num_rows: 557850
 }),
 {'anchor': 'A person on a horse jumps over a broken down airplane.',
  'positive': 'A person is outdoors, on a horse.',
  'negative': 'A person is at a diner, ordering an omelette.'})

#### (Optional) Training on the entire dataset can take a long time, so for demonstration purposes, let's use only a small portion.



In [9]:
# Upper bound on the number of training examples kept for this demo run.
MAX_EXAMPLES = 10000
# Shuffle with a fixed seed so the subset is reproducible, then keep at most
# MAX_EXAMPLES rows. Clamping to the dataset size avoids an out-of-range
# `select` when the dataset has fewer rows than the cap.
train_dataset = train_dataset.shuffle(seed=21).select(
    range(min(MAX_EXAMPLES, len(train_dataset)))
)

### Define our training loss functions 📉

In [10]:
# MultipleNegativesRankingLoss is the base objective; MatryoshkaLoss wraps it and is
# given the list of output dimensions in `matryoshka_dims`.
inner_train_loss = losses.MultipleNegativesRankingLoss(model)
train_loss = losses.MatryoshkaLoss(model, inner_train_loss, matryoshka_dims=matryoshka_dims)

### Set up an evaluator to track STS performance alongside the evaluation loss

In [11]:
# STS Benchmark validation split, used to monitor the model during training.
stsb_eval_dataset = load_dataset("sentence-transformers/stsb", split="validation")
# One cosine-similarity evaluator per Matryoshka dimension, with `truncate_dim`
# set to that dimension and the dimension included in the evaluator name.
evaluators = [
    EmbeddingSimilarityEvaluator(
        sentences1=stsb_eval_dataset["sentence1"],
        sentences2=stsb_eval_dataset["sentence2"],
        scores=stsb_eval_dataset["score"],
        main_similarity=SimilarityFunction.COSINE,
        name=f"sts-dev-{truncate_to}",
        truncate_dim=truncate_to,
    )
    for truncate_to in matryoshka_dims
]

Downloading readme:   0%|          | 0.00/1.50k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/471k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/142k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/108k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/5749 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/1500 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1379 [00:00<?, ? examples/s]

In [12]:
# Chain the per-dimension evaluators. The overall score is taken from the first
# entry of the scores list, i.e. the evaluator built for matryoshka_dims[0].
dev_evaluator = SequentialEvaluator(evaluators, main_score_function=lambda scores: scores[0])

### Define the training args ⚙️

In [13]:
# Training configuration: batch sizes and epoch count come from the config cell;
# evaluation, checkpointing and logging all happen every 30 steps.
args = SentenceTransformerTrainingArguments(
    # Required parameter:
    output_dir=output_dir,
    # Optional training parameters:
    num_train_epochs=num_train_epochs,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    warmup_ratio=0.1,
    fp16=True,  # Set to False if you get an error that your GPU can't run on FP16
    bf16=False,  # Set to True if you have a GPU that supports BF16
    batch_sampler=BatchSamplers.NO_DUPLICATES,  # MultipleNegativesRankingLoss benefits from no duplicate samples in a batch
    # Optional tracking/debugging parameters:
    eval_strategy="steps",
    eval_steps=30,
    save_strategy="steps",
    save_steps=30,
    save_total_limit=2,
    logging_steps=30,
    # Built from the config values so the run name cannot drift from the actual
    # batch size / epoch count (it was hard-coded to "matryoshka-nli_128_bs_1e").
    run_name=f"matryoshka-nli_{batch_size}_bs_{num_train_epochs}e",  # Will be used in W&B if `wandb` is installed
)

### Create the Trainer and run it 🏋️‍♀️

In [14]:
# Wire the model, training arguments, train/eval datasets, Matryoshka loss and
# the STS dev evaluator into a single trainer.
trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    loss=train_loss,
    evaluator=dev_evaluator,
)

In [15]:
%%time

# Run training; the %%time cell magic reports CPU and wall-clock time for the cell.
trainer.train()

<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

 ··········


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


Step,Training Loss,Validation Loss,Sts-dev-768 Pearson Cosine,Sts-dev-768 Spearman Cosine,Sts-dev-768 Pearson Manhattan,Sts-dev-768 Spearman Manhattan,Sts-dev-768 Pearson Euclidean,Sts-dev-768 Spearman Euclidean,Sts-dev-768 Pearson Dot,Sts-dev-768 Spearman Dot,Sts-dev-768 Pearson Max,Sts-dev-768 Spearman Max,Sts-dev-512 Pearson Cosine,Sts-dev-512 Spearman Cosine,Sts-dev-512 Pearson Manhattan,Sts-dev-512 Spearman Manhattan,Sts-dev-512 Pearson Euclidean,Sts-dev-512 Spearman Euclidean,Sts-dev-512 Pearson Dot,Sts-dev-512 Spearman Dot,Sts-dev-512 Pearson Max,Sts-dev-512 Spearman Max,Sts-dev-256 Pearson Cosine,Sts-dev-256 Spearman Cosine,Sts-dev-256 Pearson Manhattan,Sts-dev-256 Spearman Manhattan,Sts-dev-256 Pearson Euclidean,Sts-dev-256 Spearman Euclidean,Sts-dev-256 Pearson Dot,Sts-dev-256 Spearman Dot,Sts-dev-256 Pearson Max,Sts-dev-256 Spearman Max,Sts-dev-128 Pearson Cosine,Sts-dev-128 Spearman Cosine,Sts-dev-128 Pearson Manhattan,Sts-dev-128 Spearman Manhattan,Sts-dev-128 Pearson Euclidean,Sts-dev-128 Spearman Euclidean,Sts-dev-128 Pearson Dot,Sts-dev-128 Spearman Dot,Sts-dev-128 Pearson Max,Sts-dev-128 Spearman Max,Sts-dev-64 Pearson Cosine,Sts-dev-64 Spearman Cosine,Sts-dev-64 Pearson Manhattan,Sts-dev-64 Spearman Manhattan,Sts-dev-64 Pearson Euclidean,Sts-dev-64 Spearman Euclidean,Sts-dev-64 Pearson Dot,Sts-dev-64 Spearman Dot,Sts-dev-64 Pearson Max,Sts-dev-64 Spearman Max,Sequential Score
30,15.8875,6.108927,0.799077,0.807647,0.79807,0.797459,0.799269,0.798511,0.56837,0.585154,0.799269,0.807647,0.809985,0.814285,0.798351,0.797482,0.800138,0.798986,0.653775,0.667887,0.809985,0.814285,0.806134,0.812255,0.795961,0.79594,0.797647,0.797025,0.644396,0.663916,0.806134,0.812255,0.791713,0.803571,0.790397,0.791489,0.789728,0.790727,0.614007,0.629209,0.791713,0.803571,0.785611,0.801032,0.779371,0.783714,0.778577,0.783511,0.586901,0.607513,0.785611,0.801032,0.799077
60,7.4874,5.018856,0.817016,0.825598,0.808516,0.809327,0.809008,0.809883,0.578785,0.605082,0.817016,0.825598,0.821714,0.827731,0.808578,0.809039,0.809639,0.809734,0.636018,0.659442,0.821714,0.827731,0.817892,0.825704,0.807214,0.808086,0.807556,0.808237,0.631591,0.653092,0.817892,0.825704,0.808365,0.819527,0.801712,0.80441,0.800451,0.803252,0.605492,0.632383,0.808365,0.819527,0.796902,0.813825,0.790742,0.795316,0.78866,0.793532,0.536988,0.552889,0.796902,0.813825,0.817016


Computing widget examples:   0%|          | 0/5 [00:00<?, ?example/s]

Computing widget examples:   0%|          | 0/5 [00:00<?, ?example/s]

CPU times: user 1min 17s, sys: 8.27 s, total: 1min 26s
Wall time: 11min 32s


TrainOutput(global_step=79, training_loss=10.388897183575208, metrics={'train_runtime': 691.7537, 'train_samples_per_second': 14.456, 'train_steps_per_second': 0.114, 'total_flos': 0.0, 'train_loss': 10.388897183575208, 'epoch': 1.0})

### Evaluate on the STS Benchmark test dataset 🧪

In [16]:
# STS Benchmark test split, used for the final evaluation.
test_dataset = load_dataset("sentence-transformers/stsb", split="test")
# Same per-dimension setup as during training, but named "sts-test-<dim>".
evaluators = [
    EmbeddingSimilarityEvaluator(
        sentences1=test_dataset["sentence1"],
        sentences2=test_dataset["sentence2"],
        scores=test_dataset["score"],
        main_similarity=SimilarityFunction.COSINE,
        name=f"sts-test-{truncate_to}",
        truncate_dim=truncate_to,
    )
    for truncate_to in matryoshka_dims
]

In [17]:
# Chain the per-dimension test evaluators into a single evaluator.
test_evaluator = SequentialEvaluator(evaluators)

In [18]:
# Evaluate the trained model on the STS test split; the result is a dict of
# metrics whose keys are prefixed with each evaluator's name.
test_evaluator(model)

{'sts-test-768_pearson_cosine': 0.7830303669378891,
 'sts-test-768_spearman_cosine': 0.7773625997426432,
 'sts-test-768_pearson_manhattan': 0.7760379804905847,
 'sts-test-768_spearman_manhattan': 0.7571500188418279,
 'sts-test-768_pearson_euclidean': 0.776793987384272,
 'sts-test-768_spearman_euclidean': 0.7576769993000992,
 'sts-test-768_pearson_dot': 0.5696917713192656,
 'sts-test-768_spearman_dot': 0.5537799075128554,
 'sts-test-768_pearson_max': 0.7830303669378891,
 'sts-test-768_spearman_max': 0.7773625997426432,
 'sts-test-512_pearson_cosine': 0.7907973459486692,
 'sts-test-512_spearman_cosine': 0.7782644020369065,
 'sts-test-512_pearson_manhattan': 0.7764511872615909,
 'sts-test-512_spearman_manhattan': 0.7566408579053339,
 'sts-test-512_pearson_euclidean': 0.7782953766500842,
 'sts-test-512_spearman_euclidean': 0.7586949913092054,
 'sts-test-512_pearson_dot': 0.6258186284156914,
 'sts-test-512_spearman_dot': 0.6181773438089058,
 'sts-test-512_pearson_max': 0.7907973459486692,
 

### Save the model locally

In [19]:
# Save the trained model into a "final" subdirectory of the training output directory.
final_output_dir = f"{output_dir}/final"
model.save(final_output_dir)

Computing widget examples:   0%|          | 0/5 [00:00<?, ?example/s]

### Push to the Hugging Face Hub 🤗
You may need an access token. Create one here: https://huggingface.co/settings/tokens

In [21]:
import os

# Read the Hugging Face token from Colab's secret store when running on Colab.
# Outside Colab `google.colab` cannot be imported, so fall back to the HF_TOKEN
# environment variable (None when it is not set) instead of failing the cell.
try:
    from google.colab import userdata
except ImportError:
    HF_TOKEN = os.environ.get("HF_TOKEN")
else:
    HF_TOKEN = userdata.get('HF_TOKEN')

In [22]:
# Name the Hub repo after the base model. Only the part after the last "/" is used,
# so an org-prefixed model id (e.g. "org/model") is not turned into a repo id under
# that organisation's namespace; for ids without "/" the name is unchanged.
model.push_to_hub(f"{model_name.split('/')[-1]}-nli-matryoshka", token=HF_TOKEN)

model.safetensors:   0%|          | 0.00/328M [00:00<?, ?B/s]

'https://huggingface.co/eagle0504/distilroberta-base-nli-matryoshka/commit/00b880aab93090634aab808ab4f16c5ae859d563'