## Setup

In [None]:
%%capture
!pip install -q accelerate peft bitsandbytes transformers trl sentencepiece
!pip install -q sentence-transformers mteb datasets

In [1]:
from datasets import load_dataset, Dataset
from sentence_transformers import SentenceTransformer, models
from sentence_transformers import losses
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator
from sentence_transformers.training_args import SentenceTransformerTrainingArguments
from sentence_transformers.trainer import SentenceTransformerTrainer
from sentence_transformers.cross_encoder import CrossEncoder
from sentence_transformers.datasets import DenoisingAutoEncoderDataset
from mteb import MTEB

import gc
import torch
import random
from tqdm import tqdm
import pandas as pd
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


## Dataset

In [2]:
# Load MNLI dataset from GLUE
# 0 = entailment, 1 = neutral, 2 = contradiction
train_dataset = load_dataset("glue", "mnli", split="train").select(range(50_000))
train_dataset = train_dataset.remove_columns("idx")

In [3]:
train_dataset[2]

{'premise': 'One of our number will carry out your instructions minutely.',
 'hypothesis': 'A member of my team will execute your orders with immense precision.',
 'label': 0}

## Model

In [4]:
device = torch.device("mps")
# Use a base model
embedding_model = SentenceTransformer("bert-base-uncased", device=device)
# embedding_model=embedding_model.to(device)

No sentence-transformers model found with name bert-base-uncased. Creating a new one with mean pooling.


In [5]:
embedding_model

SentenceTransformer(
  (0): Transformer({'max_seq_length': 512, 'do_lower_case': False}) with Transformer model: BertModel 
  (1): Pooling({'word_embedding_dimension': 768, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False, 'pooling_mode_weightedmean_tokens': False, 'pooling_mode_lasttoken': False, 'include_prompt': True})
)

## Loss Function

In [6]:
# Define the loss function. In soft-max loss, we will also need to explicitly set the number of labels.
train_loss = losses.SoftmaxLoss(
    model=embedding_model, 
    sentence_embedding_dimension=embedding_model.get_sentence_embedding_dimension(), 
    num_labels=3
)

## Evaluation

In [7]:
# Create an embedding similarity evaluator for stsb
val_sts = load_dataset("glue", "stsb", split="validation")
evaluator = EmbeddingSimilarityEvaluator(
    sentences1=val_sts["sentence1"],
    sentences2=val_sts["sentence2"],
    scores=[score/5 for score in val_sts["label"]],
    main_similarity="cosine",
)

## Training

In [None]:
# Define the training arguments
training_args = SentenceTransformerTrainingArguments(
    output_dir="base_embedding_model",
    num_train_epochs=1,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    warmup_steps=100,
    fp16=False,
    eval_steps=100,
    logging_steps=100, 
)

In [9]:
# Train embedding model
trainer = SentenceTransformerTrainer(
    model=embedding_model, 
    args=training_args, 
    train_dataset=train_dataset, 
    loss=train_loss,
    evaluator=evaluator,
)

In [None]:
trainer.train()

  0%|          | 0/6250 [00:00<?, ?it/s]Column 'hypothesis' is at index 1, whereas a column with this name is usually expected at index 0. Note that the column order can be important for some losses, e.g. MultipleNegativesRankingLoss will always consider the first column as the anchor and the second as the positive, regardless of the dataset column names. Consider renaming the columns to match the expected order, e.g.:
dataset = dataset.select_columns(['hypothesis', 'entailment', 'contradiction'])
  2%|▏         | 100/6250 [00:34<30:32,  3.36it/s]

{'loss': 1.0957, 'grad_norm': 4.804290771484375, 'learning_rate': 5e-05, 'epoch': 0.02}


  3%|▎         | 200/6250 [01:03<28:43,  3.51it/s]

{'loss': 1.013, 'grad_norm': 3.2441935539245605, 'learning_rate': 4.9186991869918704e-05, 'epoch': 0.03}


  5%|▍         | 300/6250 [01:32<27:26,  3.61it/s]

{'loss': 0.9379, 'grad_norm': 4.932455539703369, 'learning_rate': 4.8373983739837406e-05, 'epoch': 0.05}


  6%|▋         | 400/6250 [02:01<30:15,  3.22it/s]

{'loss': 0.8927, 'grad_norm': 9.76833438873291, 'learning_rate': 4.75609756097561e-05, 'epoch': 0.06}


  8%|▊         | 500/6250 [02:36<37:36,  2.55it/s]

{'loss': 0.9041, 'grad_norm': 11.443574905395508, 'learning_rate': 4.6747967479674795e-05, 'epoch': 0.08}


 10%|▉         | 600/6250 [03:11<35:38,  2.64it/s]  

{'loss': 0.819, 'grad_norm': 7.497768402099609, 'learning_rate': 4.59349593495935e-05, 'epoch': 0.1}


 11%|█         | 700/6250 [03:45<31:07,  2.97it/s]

{'loss': 0.8565, 'grad_norm': 9.638748168945312, 'learning_rate': 4.51219512195122e-05, 'epoch': 0.11}


 13%|█▎        | 800/6250 [04:17<28:34,  3.18it/s]

{'loss': 0.8708, 'grad_norm': 8.749395370483398, 'learning_rate': 4.43089430894309e-05, 'epoch': 0.13}


 14%|█▍        | 900/6250 [04:51<28:20,  3.15it/s]

{'loss': 0.8422, 'grad_norm': 4.671433448791504, 'learning_rate': 4.3495934959349595e-05, 'epoch': 0.14}


 16%|█▌        | 1000/6250 [05:24<32:32,  2.69it/s]

{'loss': 0.8232, 'grad_norm': 6.3639044761657715, 'learning_rate': 4.26829268292683e-05, 'epoch': 0.16}


 16%|█▌        | 1002/6250 [05:26<54:46,  1.60it/s]  