In [None]:
!pip install transformers[torch] datasets

In [2]:
!pip install -U sentence-transformers

Collecting sentence-transformers
  Downloading sentence_transformers-3.0.1-py3-none-any.whl (227 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/227.1 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m[90m━━━━━━━━━━━[0m [32m163.8/227.1 kB[0m [31m4.7 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m227.1/227.1 kB[0m [31m4.5 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: sentence-transformers
Successfully installed sentence-transformers-3.0.1


In [3]:
from sentence_transformers import SentenceTransformer
import torch

  from tqdm.autonotebook import tqdm, trange


In [4]:
from datasets import load_dataset
dataset = load_dataset("sentence-transformers/all-nli", "triplet")
train_dataset = dataset["train"].select(range(5000))
eval_dataset = dataset["dev"]
test_dataset = dataset["test"]

Downloading readme:   0%|          | 0.00/5.15k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/38.4M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/782k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/810k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/557850 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/6584 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/6609 [00:00<?, ? examples/s]

In [5]:
train_dataset

Dataset({
    features: ['anchor', 'positive', 'negative'],
    num_rows: 5000
})

In [6]:
train_dataset.to_pandas()
# Anchor: The original sentence or query.
# Positive answer: A correct or relevant response to the anchor.
# Negative answer: An incorrect or irrelevant response to the anchor

Unnamed: 0,anchor,positive,negative
0,A person on a horse jumps over a broken down a...,"A person is outdoors, on a horse.","A person is at a diner, ordering an omelette."
1,Children smiling and waving at camera,There are children present,The kids are frowning
2,A boy is jumping on skateboard in the middle o...,The boy does a skateboarding trick.,The boy skates down the sidewalk.
3,Two blond women are hugging one another.,There are women showing affection.,The women are sleeping.
4,"A few people in a restaurant setting, one of t...",The diners are at a restaurant.,The people are sitting at desks in school.
...,...,...,...
4995,The people are outside.,People on ATVs and dirt bikes are traveling al...,A woman in a pink shirt is handing a bag to th...
4996,The people are outside.,People on ATVs and dirt bikes are traveling al...,A small group of adult males enjoy a conversat...
4997,The people are outside.,People on ATVs and dirt bikes are traveling al...,Two guys and one girl are sitting at a table i...
4998,The people are outside.,People on ATVs and dirt bikes are traveling al...,People sitting on black chairs on a bus.


In [7]:
model = SentenceTransformer("BAAI/bge-large-en")

modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

README.md:   0%|          | 0.00/90.3k [00:00<?, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/52.0 [00:00<?, ?B/s]



config.json:   0%|          | 0.00/720 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.34G [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/366 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/711k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/125 [00:00<?, ?B/s]

1_Pooling/config.json:   0%|          | 0.00/191 [00:00<?, ?B/s]

In [8]:
# Setting up Training Arguments

from sentence_transformers import SentenceTransformerTrainer, SentenceTransformerTrainingArguments
from sentence_transformers.losses import MultipleNegativesRankingLoss
from sentence_transformers.training_args import BatchSamplers

In [9]:

# 3. Define a loss function
loss = MultipleNegativesRankingLoss(model)

In [10]:
args = SentenceTransformerTrainingArguments(
    # Required parameter:
    output_dir="models/mpnet-base-all-nli-triplet",
    # Optional training parameters:
    num_train_epochs=1,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    warmup_ratio=0.1,
    fp16=True,  # Set to False if you get an error that your GPU can't run on FP16
    bf16=False,  # Set to True if you have a GPU that supports BF16
    batch_sampler=BatchSamplers.NO_DUPLICATES,  # MultipleNegativesRankingLoss benefits from no duplicate samples in a batch
    # Optional tracking/debugging parameters:
    eval_strategy="steps",
    eval_steps=100,
    save_strategy="steps",
    save_steps=100,
    save_total_limit=2,
    logging_steps=100,
    run_name="mpnet-base-all-nli-triplet",  # Will be used in W&B if `wandb` is installed
)

In [11]:
trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    loss=loss
)

In [12]:
trainer.train()

Step,Training Loss,Validation Loss
100,0.6642,0.63994
200,0.1421,0.817364
300,0.3632,1.056603


Computing widget examples:   0%|          | 0/1 [00:00<?, ?example/s]

TrainOutput(global_step=313, training_loss=0.42413954079722443, metrics={'train_runtime': 520.1962, 'train_samples_per_second': 9.612, 'train_steps_per_second': 0.602, 'total_flos': 0.0, 'train_loss': 0.42413954079722443, 'epoch': 1.0})

In [13]:
from sentence_transformers.evaluation import TripletEvaluator

test_evaluator = TripletEvaluator(
    anchors=test_dataset["anchor"],
    positives=test_dataset["positive"],
    negatives=test_dataset["negative"],
    name="all-nli-test",
)

test_evaluator(model)

{'all-nli-test_cosine_accuracy': 0.8853079134513542,
 'all-nli-test_dot_accuracy': 0.11469208654864578,
 'all-nli-test_manhattan_accuracy': 0.885761839915267,
 'all-nli-test_euclidean_accuracy': 0.8853079134513542,
 'all-nli-test_max_accuracy': 0.885761839915267}

In [None]:

# 8. Save the trained model
model.save_pretrained("models/wz-mpnet-base-all-nli-triplet/final")

# 9. (Optional) Push it to the Hugging Face Hub
model.push_to_hub("mpnet-base-all-nli-triplet")