# Install required modules

In [1]:
!pip install transformers[torch] sentence-transformers datasets

[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.2[0m[39;49m -> [0m[32;49m25.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython -m pip install --upgrade pip[0m


In [2]:
from sentence_transformers import SentenceTransformer
import torch

# Load Dataset

In [3]:
from datasets import load_dataset

dataset = load_dataset("bebeyondo/medical-triplet")
train_dataset = dataset["train"]
test_dataset = dataset["test"]
eval_dataset = dataset["val"]

# Remove the metadata columns from all datasets
columns_to_remove = ['pairs_unique_id', 'idx']

# Remove from train dataset
train_dataset = train_dataset.remove_columns(columns_to_remove)

# Remove from test dataset  
test_dataset = test_dataset.remove_columns(columns_to_remove)

# Remove from eval dataset
eval_dataset = eval_dataset.remove_columns(columns_to_remove)

print(train_dataset)
print(test_dataset)
print(eval_dataset)

Dataset({
    features: ['anchor', 'positive', 'negative'],
    num_rows: 52838
})
Dataset({
    features: ['anchor', 'positive', 'negative'],
    num_rows: 17613
})
Dataset({
    features: ['anchor', 'positive', 'negative'],
    num_rows: 17613
})


In [4]:
train_dataset.to_pandas()

Unnamed: 0,anchor,positive,negative
0,sign (Additional file 5 (Short axial section o...,enlargement and ventricular septum deviation o...,"deviation of left ventricle, showing “D” sign ..."
1,endocrinologist’s opinion who advised for eval...,were raised beyond the normal range which was ...,of gestation with singleton pregnancy was refe...
2,check valve.\nWe performed bursectomy and ATFL...,check valve.\nWe performed bursectomy and ATFL...,"a fluctuant mass 5 × 8 cm in size, over the an..."
3,male was presented with his persistent back pa...,used home oxygen therapy of 2 l/min on occasio...,He had suffered from chronic dyspnea on effort...
4,"Ltd., South Africa) which were subsequently in...",slightly oversized left fascio-cutaneous radia...,bridge was removed for repair. All implants we...
...,...,...,...
52833,59.9 IU/mL (normal <40). Blood and urine cultu...,The patient was discharged with good diabetic ...,days and again improved with reduction in her ...
52834,"measures proved ineffective, the Acute Pain Me...",touch from the toes to the ankle. There was dr...,on the soles of her feet in order to promote v...
52835,who presented to our Emergency Department 6 ho...,to our Emergency Department 6 hours after he h...,a small suburban home. He was an active tobacc...
52836,CLL/SLL via lymph node biopsy that was diagnos...,arterial blood gas (ABG) was performed demonst...,with very dark urine. In light of his elevated...


# Load the pretrained model

In [5]:
model = SentenceTransformer("all-mpnet-base-v2")

In [6]:
model

SentenceTransformer(
  (0): Transformer({'max_seq_length': 384, 'do_lower_case': False, 'architecture': 'MPNetModel'})
  (1): Pooling({'word_embedding_dimension': 768, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False, 'pooling_mode_weightedmean_tokens': False, 'pooling_mode_lasttoken': False, 'include_prompt': True})
  (2): Normalize()
)

# Setting the training arguments

In [7]:
from sentence_transformers import SentenceTransformerTrainer, SentenceTransformerTrainingArguments
from sentence_transformers.losses import TripletLoss
from sentence_transformers.training_args import BatchSamplers

In [9]:
# there 3 parameters model, scale (opt), similarity function(opt)
loss = TripletLoss(model=model)

# def get_dot_product_similarity(a, b):
#     return torch.matmul(a, b.T)  # Matrix multiplication to handle 2D tensors

# For PAIRS training we use MultipleNegativesRankingLoss to compute the probability of all positive entries being the highest
# loss = MultipleNegativesRankingLoss(
#     model,
    # scale=1, # only for dot product, check how the base model was trained.
    # similarity_fct=get_dot_product_similarity #comment out to default to cosine
# )

# # For TRIPLETS training. Initialize the TripletLoss if training on anchor, positive, negative triplets
# loss = TripletLoss(model)

#Other option
# Initialize CoSENTLoss (a variant of cosine loss that enforces the relative similarity ranking of pairs in a batch):

In [14]:
args = SentenceTransformerTrainingArguments(
    output_dir="models/ft-all-mpnet-base-v2",
    num_train_epochs=3,
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
    warmup_ratio=0.1,
    fp16=False, #set it according to your GPU hardware
    bf16=True, # #set it according to your GPU hardware
    batch_sampler=BatchSamplers.NO_DUPLICATES, # MultipleNegativesRankingLoss benefits from no duplicate samples in a batch
    learning_rate=1e-5,
    eval_strategy="steps",
    eval_steps=500,
    save_strategy="steps",
    save_steps=500,
    logging_steps=500,
    run_name="ft-all-mpnet-base-v2" # will be useful in W&B if installed and tracked
)

In [15]:
trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    loss=loss
)

In [None]:
trainer.train()

Step,Training Loss,Validation Loss


In [13]:
from sentence_transformers.evaluation import TripletEvaluator

test_evaluator = TripletEvaluator(
    anchors=test_dataset["anchor"],
    positives=test_dataset["positive"],
    negatives=test_dataset["negative"],
    name="ft-all-mpnet-base-v2-test",
)

test_evaluator(model)

{'ft-all-mpnet-base-v2-test_cosine_accuracy': 0.1668086051940918}

In [None]:
# Save the trained model
model.save_pretrained("models/ft-all-mpnet-base-v2/final")

# (Optional) Push it to the Hugging Face Hub
model.push_to_hub("ft-all-mpnet-base-v2-triplet")