In [None]:
%pip install -q transformers

Choosing the Embedding Model

In [None]:
from sentence_transformers import SentenceTransformer

# Base embedding model shared by the evaluation cells below
# (fetched by its Hub id the first time it is loaded).
model = SentenceTransformer(model_name_or_path="abhinand/MedEmbed-base-v0.1")

In [None]:
from sentence_transformers import util

# Quick sanity check: related sentences (the diabetes group, the smoking pair) are
# expected to score higher with each other than with unrelated sentences.
sentences = [
    "What are the symptoms of diabetes?",
    "Diabetes symptoms include increased thirst and frequent urination.",
    "Hypertension is often called the silent killer.",
    "Increased blood sugar is a common indicator of diabetes.",
    "Smoking causes cancer.",
    "Smoking kills.",
]
embeddings = model.encode(sentences)
# Pairwise cosine similarities; entry [i, j] compares sentences[i] with sentences[j].
similarities = util.cos_sim(embeddings, embeddings)
print("Cosine Similarity Matrix:")
print(similarities)

Cosine Similarity Matrix:
tensor([[1.0000, 0.8373, 0.5017, 0.7146, 0.4323, 0.4423],
        [0.8373, 1.0000, 0.5188, 0.7441, 0.4777, 0.4419],
        [0.5017, 0.5188, 1.0000, 0.6368, 0.5931, 0.6668],
        [0.7146, 0.7441, 0.6368, 1.0000, 0.6027, 0.5534],
        [0.4323, 0.4777, 0.5931, 0.6027, 1.0000, 0.7779],
        [0.4423, 0.4419, 0.6668, 0.5534, 0.7779, 1.0000]])


Evaluating the Embedding Model with Mean Reciprocal Rank (MRR)

In [None]:
import numpy as np

# Hand-labelled toy benchmark: ground_truth[i] lists the indices into `candidates`
# that are relevant to queries[i].
queries = [
    "What are the symptoms of diabetes?",
    "How is hypertension diagnosed?",
    "What are the effects of smoking?",
]
candidates = [
    "Diabetes symptoms include increased thirst and frequent urination.",
    "Hypertension is often called the silent killer.",
    "Blood pressure measurements help diagnose hypertension.",
    "Increased blood sugar is a common indicator of diabetes.",
    "Smoking causes lung cancer.",
    "Smoking is often called the silent killer.",
]
ground_truth = [
    [0, 3],
    [1, 2],
    [4, 5],
]
# Adjacent string literals without a comma are silently concatenated by Python,
# which would misalign queries and labels; fail loudly instead.
assert len(queries) == len(ground_truth), "need exactly one ground-truth list per query"

query_embeddings = model.encode(queries)
candidate_embeddings = model.encode(candidates)
similarity_matrix = util.cos_sim(query_embeddings, candidate_embeddings).numpy()

# Mean Reciprocal Rank: 1/rank of the first relevant candidate, 0 if none is found.
mrr_scores = []
for query_similarities, relevant in zip(similarity_matrix, ground_truth):
    ranked_indices = np.argsort(query_similarities)[::-1]  # most similar first
    reciprocal_rank = 0
    for rank, idx in enumerate(ranked_indices, start=1):
        if idx in relevant:
            reciprocal_rank = 1 / rank
            break
    mrr_scores.append(reciprocal_rank)
mrr = np.mean(mrr_scores)
print(f"MRR: {mrr:.4f}")

MRR: 1.0000


Ground-Truth Cross-Evaluation: Threshold-Derived vs. Reference-Based Labels

In [None]:
# Instead of hand-labelling, every candidate whose cosine similarity to the query
# reaches `similarity_threshold` is treated as relevant.
# Caveat: the ranking below is built from the very same similarities, so the
# top-ranked candidate is "relevant" whenever any candidate clears the threshold.
# Each query's reciprocal rank is therefore 1 (or 0 if nothing clears it): this
# measures self-consistency, not retrieval quality.
queries = ["What are the symptoms of diabetes?"]
candidates = [
    "Diabetes symptoms include increased thirst and frequent urination.",
    "Hypertension is often called the silent killer.",
    "Blood pressure measurements help diagnose hypertension.",
    "Increased blood sugar is a common indicator of diabetes.",
    "Smoking causes lung cancer.",
    "Smoking is often called the silent killer.",
]
query_embeddings = model.encode(queries)
candidate_embeddings = model.encode(candidates)
similarity_matrix = util.cos_sim(query_embeddings, candidate_embeddings).numpy()

similarity_threshold = 0.6
ground_truth = [
    [idx for idx, similarity in enumerate(row) if similarity >= similarity_threshold]
    for row in similarity_matrix
]
print("Automatically Derived Ground Truth:", ground_truth)

mrr_scores = []
for row, relevant in zip(similarity_matrix, ground_truth):
    ranking = np.argsort(row)[::-1]
    # 1-based position of the first relevant candidate, or None if there is none
    first_hit = next(
        (position for position, candidate_idx in enumerate(ranking, start=1) if candidate_idx in relevant),
        None,
    )
    mrr_scores.append(0 if first_hit is None else 1 / first_hit)

mrr = np.mean(mrr_scores)
print(f"MRR: {mrr:.4f}")

Automatically Derived Ground Truth: [[0, 3], [1, 2], [4, 5]]
MRR: 1.0000


In [None]:
# NOTE(review): this cell repeats the preceding threshold cell verbatim (same query,
# candidates and threshold) — consider deleting one of the two.
# Labels = candidates with cosine similarity >= threshold; MRR is then computed
# against those labels using the same similarities for ranking.
queries = [
    "What are the symptoms of diabetes?",
]
candidates = [
    "Diabetes symptoms include increased thirst and frequent urination.",
    "Hypertension is often called the silent killer.",
    "Blood pressure measurements help diagnose hypertension.",
    "Increased blood sugar is a common indicator of diabetes.",
    "Smoking causes lung cancer.",
    "Smoking is often called the silent killer.",
]
query_embeddings = model.encode(queries)
candidate_embeddings = model.encode(candidates)
similarity_matrix = util.cos_sim(query_embeddings, candidate_embeddings).numpy()

similarity_threshold = 0.6
ground_truth = []
for row in similarity_matrix:
    ground_truth.append(
        [candidate_idx for candidate_idx in range(len(row)) if row[candidate_idx] >= similarity_threshold]
    )
print("Automatically Derived Ground Truth:", ground_truth)

mrr_scores = []
for query_no in range(len(similarity_matrix)):
    order = np.argsort(similarity_matrix[query_no])[::-1]
    # 1-based ranks of every relevant candidate; the first one determines the score
    hit_positions = [pos for pos, cand in enumerate(order, start=1) if cand in ground_truth[query_no]]
    mrr_scores.append(1 / hit_positions[0] if hit_positions else 0)

mrr = np.mean(mrr_scores)
print(f"MRR: {mrr:.4f}")

Automatically Derived Ground Truth: [[0, 3]]
MRR: 1.0000


In [None]:
queries = [
    "Define Hypertension?",
]
candidates = [
    "Diabetes symptoms include increased thirst and frequent urination.",
    "Hypertension is often called the silent killer.",
    "Blood pressure measurements help diagnose hypertension.",
    "Increased blood sugar is a common indicator of diabetes.",
    "Smoking causes lung cancer.",
    "Smoking is often called the silent killer.",
]
# reference[i] is the expected answer for queries[i]. A ranked candidate counts as a
# hit when its cosine similarity to that reference exceeds `reference_match_threshold`.
# NOTE(review): this reference is a diabetes sentence while the query asks about
# hypertension — confirm the pairing is intentional (it appears to be what pulls
# the reference-based MRR down).
reference = [
    "Diabetes symptoms include increased thirst and frequent urination",
]
assert len(reference) == len(queries), "need exactly one reference answer per query"
reference_match_threshold = 0.9

query_embeddings = model.encode(queries)
candidate_embeddings = model.encode(candidates)
ground_truth_embeddings = model.encode(reference)
similarity_matrix = util.cos_sim(query_embeddings, candidate_embeddings).numpy()
# candidate_reference_similarity[c, q] = cos(candidates[c], reference[q]), computed once
candidate_reference_similarity = util.cos_sim(candidate_embeddings, ground_truth_embeddings).numpy()

# --- MRR against the reference answer ---
mrr_scores = []
for i, query_similarities in enumerate(similarity_matrix):
    ranked_indices = np.argsort(query_similarities)[::-1]  # most similar first
    reciprocal_rank = 0
    for rank, idx in enumerate(ranked_indices, start=1):
        if candidate_reference_similarity[idx, i] > reference_match_threshold:
            reciprocal_rank = 1 / rank
            break
    mrr_scores.append(reciprocal_rank)
mrr = np.mean(mrr_scores)
print(f"MRR (reference-based): {mrr:.4f}")

# --- MRR against threshold-derived labels, for comparison ---
# These labels come from the same similarities that are ranked, so the top-ranked
# candidate is relevant whenever any candidate clears the threshold.
similarity_threshold = 0.6
ground_truth = [
    [idx for idx, similarity in enumerate(query_similarities) if similarity >= similarity_threshold]
    for query_similarities in similarity_matrix
]
print("Automatically Derived Ground Truth:", ground_truth)
mrr_scores = []
for query_similarities, relevant in zip(similarity_matrix, ground_truth):
    ranked_indices = np.argsort(query_similarities)[::-1]
    reciprocal_rank = 0
    for rank, idx in enumerate(ranked_indices, start=1):
        if idx in relevant:
            reciprocal_rank = 1 / rank
            break
    mrr_scores.append(reciprocal_rank)
mrr = np.mean(mrr_scores)
print(f"MRR (threshold-derived): {mrr:.4f}")

MRR: 0.2000
Automatically Derived Ground Truth: [[1, 2]]
MRR: 1.0000


In [None]:
queries = [
    "What are the procedure for autopsy?",
]
candidates = [
    "3-month-old baby died suddenly at night while asleep.",
    "His mother noticed that he had died only after she awoke in the morning. No cause of death was determined based on the autopsy.",
]
# reference[i] is the expected answer for queries[i]. A ranked candidate counts as a
# hit when its cosine similarity to that reference exceeds `reference_match_threshold`.
# NOTE(review): this reference (infant sleep position) does not answer the autopsy
# query — confirm the pairing is intentional.
reference = [
    "Placing the infant in a supine position on a firm mattress while sleeping",
]
assert len(reference) == len(queries), "need exactly one reference answer per query"
reference_match_threshold = 0.9

query_embeddings = model.encode(queries)
candidate_embeddings = model.encode(candidates)
ground_truth_embeddings = model.encode(reference)
similarity_matrix = util.cos_sim(query_embeddings, candidate_embeddings).numpy()
# candidate_reference_similarity[c, q] = cos(candidates[c], reference[q]), computed once
candidate_reference_similarity = util.cos_sim(candidate_embeddings, ground_truth_embeddings).numpy()

# --- MRR against the reference answer ---
mrr_scores = []
for i, query_similarities in enumerate(similarity_matrix):
    ranked_indices = np.argsort(query_similarities)[::-1]  # most similar first
    reciprocal_rank = 0
    for rank, idx in enumerate(ranked_indices, start=1):
        if candidate_reference_similarity[idx, i] > reference_match_threshold:
            reciprocal_rank = 1 / rank
            break
    mrr_scores.append(reciprocal_rank)
mrr = np.mean(mrr_scores)
# Report this score before `mrr` is reused below; otherwise the reference-based
# result is computed and then silently discarded.
print(f"MRR (reference-based): {mrr:.4f}")

# --- MRR against threshold-derived labels, for comparison ---
# These labels come from the same similarities that are ranked, so the top-ranked
# candidate is relevant whenever any candidate clears the threshold.
similarity_threshold = 0.6
ground_truth = [
    [idx for idx, similarity in enumerate(query_similarities) if similarity >= similarity_threshold]
    for query_similarities in similarity_matrix
]
print("Automatically Derived Ground Truth:", ground_truth)
mrr_scores = []
for query_similarities, relevant in zip(similarity_matrix, ground_truth):
    ranked_indices = np.argsort(query_similarities)[::-1]
    reciprocal_rank = 0
    for rank, idx in enumerate(ranked_indices, start=1):
        if idx in relevant:
            reciprocal_rank = 1 / rank
            break
    mrr_scores.append(reciprocal_rank)
mrr = np.mean(mrr_scores)
print(f"MRR (threshold-derived): {mrr:.4f}")

Automatically Derived Ground Truth: [[1]]
MRR: 1.0000


Evaluation with NDCG and Recall

In [None]:

queries = [
    "What are the symptoms of diabetes?",
]
candidates = [
    "Diabetes symptoms include increased thirst and frequent urination.",
    "Hypertension is often called the silent killer.",
    "Blood pressure measurements help diagnose hypertension.",
    "Increased blood sugar is a common indicator of diabetes.",
    "Smoking causes lung cancer.",
    "Smoking is often called the silent killer.",
]
query_embeddings = model.encode(queries)
candidate_embeddings = model.encode(candidates)
similarity_matrix = util.cos_sim(query_embeddings, candidate_embeddings).numpy()

# NOTE: the labels come from the same similarities that are ranked below, so these
# scores measure self-consistency rather than retrieval quality.
similarity_threshold = 0.6
ground_truth = [
    [idx for idx, similarity in enumerate(query_similarities) if similarity >= similarity_threshold]
    for query_similarities in similarity_matrix
]
print("Automatically Derived Ground Truth:", ground_truth)

# Cut-off for MRR@k / NDCG@k / Recall@k. Without a cut-off, recall is always 1.0,
# because every relevant candidate appears somewhere in the full ranking.
top_k = 3

mrr_scores = []
ndcg_scores = []
recall_scores = []
for query_similarities, relevant in zip(similarity_matrix, ground_truth):
    top_ranked = np.argsort(query_similarities)[::-1][:top_k]
    # hits[r] is True when the candidate at rank r + 1 is relevant (binary relevance)
    hits = [idx in relevant for idx in top_ranked]

    # MRR@k: reciprocal rank of the first relevant result within the top k (0 if none).
    mrr_scores.append(next((1 / rank for rank, hit in enumerate(hits, start=1) if hit), 0))

    # NDCG@k: DCG of the actual ranking divided by the DCG of a perfect ranking,
    # which places min(|relevant|, k) relevant items at the top.
    dcg = sum(1 / np.log2(rank + 1) for rank, hit in enumerate(hits, start=1) if hit)
    ideal_dcg = sum(1 / np.log2(rank + 1) for rank in range(1, min(len(relevant), top_k) + 1))
    ndcg_scores.append(dcg / ideal_dcg if ideal_dcg > 0 else 0)

    # Recall@k: fraction of the relevant candidates that appear in the top k.
    recall_scores.append(sum(hits) / len(relevant) if relevant else 0)

mrr = np.mean(mrr_scores)
ndcg = np.mean(ndcg_scores)
recall = np.mean(recall_scores)

print(f"MRR@{top_k}: {mrr:.4f}")
print(f"NDCG@{top_k}: {ndcg:.4f}")
print(f"Recall@{top_k}: {recall:.4f}")


Automatically Derived Ground Truth: [[0, 3]]
MRR: 1.0000
NDCG: 1.0000
Recall: 1.0000


Fine-Tuning with a Cosine Triplet Margin Loss (`MytryoshkaLoss`, "Matryoshka" training)

In [None]:
!pip install dataset

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from sentence_transformers import SentenceTransformer, InputExample
from tqdm import tqdm

class MytryoshkaLoss(nn.Module):
    """Cosine triplet margin loss.

    For each row: ``relu(margin + cos(query, negative) - cos(query, positive))``,
    averaged over the batch. The loss is zero once every positive is at least
    ``margin`` more cosine-similar to its query than the matching negative.

    Despite the class name, this does not implement Matryoshka representation
    learning (no nested / truncated embedding dimensions are involved).

    Args:
        margin: required cosine-similarity gap between positive and negative.
    """

    def __init__(self, margin: float = 0.2):
        super().__init__()
        self.margin = margin

    def forward(self, query_embeddings, positive_embeddings, negative_embeddings):
        # F.cosine_similarity compares row i of each input (dim=1 by default)
        sim_pos = F.cosine_similarity(query_embeddings, positive_embeddings)
        sim_neg = F.cosine_similarity(query_embeddings, negative_embeddings)
        per_triplet = F.relu(self.margin + sim_neg - sim_pos)
        return per_triplet.mean()

def collate_fn(batch):
    """Turn a batch of triplet examples into (query, positive, negative) embeddings.

    The embeddings come from a forward pass through the global ``model`` with
    autograd enabled, so a loss computed on them back-propagates into the model's
    parameters. ``model.encode`` cannot be used for this: it runs under
    ``torch.no_grad()`` and returns detached tensors. Calling ``requires_grad_()``
    on its output only creates new leaf tensors, so ``optimizer.step()`` would
    never change the model.

    Because the forward pass happens inside the collate function, the DataLoader
    must run in the main process (``num_workers=0``, the default).

    Args:
        batch: sequence of objects whose ``texts`` attribute is
            ``[query, positive, negative]`` (e.g. ``InputExample``).

    Returns:
        Tuple ``(query_emb, pos_emb, neg_emb)`` of tensors, one row per example.
    """
    queries = [item.texts[0] for item in batch]
    positives = [item.texts[1] for item in batch]
    negatives = [item.texts[2] for item in batch]

    def embed(texts):
        # Tokenize, move tensor features to the model's device, and run the
        # model's forward pass (which, unlike `encode`, keeps the autograd graph).
        features = model.tokenize(texts)
        features = {
            key: value.to(model.device) if isinstance(value, torch.Tensor) else value
            for key, value in features.items()
        }
        return model(features)["sentence_embedding"]

    return embed(queries), embed(positives), embed(negatives)

from sentence_transformers import util

# Start from a fresh copy of the base model, so this cell does not depend on what
# earlier cells did to `model`.
model = SentenceTransformer("abhinand/MedEmbed-base-v0.1")
torch.manual_seed(42)  # reproducible DataLoader shuffling and dropout

queries = ["What are the symptoms of diabetes?"]
candidates = [
    "Diabetes symptoms include increased thirst and frequent urination.",
    "Hypertension is often called the silent killer.",
    "Blood pressure measurements help diagnose hypertension.",
    "Increased blood sugar is a common indicator of diabetes.",
    "Smoking causes lung cancer.",
    "Smoking is often called the silent killer.",
]

# Derive labels with the same *cosine* threshold the evaluation cells use (a raw
# dot product would not be on the scale the 0.6 threshold assumes).
# NOTE: labels come from the model's own similarities, so training mainly
# reinforces rankings the model already produces.
query_embeddings = model.encode(queries, convert_to_tensor=True, convert_to_numpy=False)
candidate_embeddings = model.encode(candidates, convert_to_tensor=True, convert_to_numpy=False)
similarity_matrix = util.cos_sim(query_embeddings, candidate_embeddings).cpu().numpy()
similarity_threshold = 0.6
ground_truth = [
    [idx for idx, sim in enumerate(row) if sim >= similarity_threshold]
    for row in similarity_matrix
]
print("Automatically Derived Ground Truth:", ground_truth)

# Every (relevant, non-relevant) candidate pair for a query becomes one triplet.
train_examples = []
for query, relevant in zip(queries, ground_truth):
    positives = [candidates[idx] for idx in relevant]
    negatives = [candidates[idx] for idx in range(len(candidates)) if idx not in relevant]
    for pos in positives:
        for neg in negatives:
            train_examples.append(InputExample(texts=[query, pos, neg]))
if not train_examples:
    raise ValueError("No (positive, negative) pairs found; adjust similarity_threshold.")

# collate_fn calls the model, so keep the default num_workers=0 (main process).
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=8, collate_fn=collate_fn)

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
loss_function = MytryoshkaLoss(margin=0.2)

model.train()
epochs = 1

for epoch in range(epochs):
    print(f"Epoch {epoch + 1}")
    epoch_loss = 0.0
    for query_emb, pos_emb, neg_emb in tqdm(train_dataloader, desc="Training"):
        loss = loss_function(query_emb, pos_emb, neg_emb)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    # Mean batch loss, comparable across different numbers of batches.
    print(f"Epoch {epoch + 1} Loss: {epoch_loss / len(train_dataloader):.4f}")

# If the embeddings fed to the loss were detached from the model, no parameter
# receives a gradient and the optimizer silently leaves the weights unchanged.
if not any(param.grad is not None for param in model.parameters()):
    print("Warning: no gradients reached the model parameters; the model was not updated.")

model.eval()
model.save("my_model_output")


Automatically Derived Ground Truth: [[0, 3]]
Epoch 1


Training: 100%|██████████| 1/1 [00:01<00:00,  1.32s/it]


Epoch 1 Loss: 0.0000


Evaluating the Fine-Tuned Model with MRR, NDCG, and Recall

In [None]:
import torch
import numpy as np

# Reload the weights that the training cell saved to "my_model_output".
model = SentenceTransformer(model_name_or_path="my_model_output")

def evaluate_model(queries, candidates, ground_truth, model, top_k=None):
    """Score how well ``model`` ranks ``candidates`` for each query.

    Args:
        queries: list of query strings.
        candidates: list of candidate strings; all of them are ranked for every query.
        ground_truth: ``ground_truth[i]`` lists the indices into ``candidates``
            that are relevant to ``queries[i]`` (binary relevance).
        model: object with a SentenceTransformer-style ``encode`` that returns a
            2-D tensor (one embedding per row).
        top_k: optional cut-off. When given, all three metrics only consider the
            ``top_k`` highest-ranked candidates (MRR@k, NDCG@k, Recall@k).
            ``None`` (default) scores the full ranking; note that full-ranking
            recall is 1.0 whenever every labelled index is a valid candidate.

    Returns:
        ``(mrr, ndcg, recall)``, each averaged over the queries.

    Raises:
        ValueError: if ``ground_truth`` does not have exactly one entry per query.
    """
    query_embeddings = model.encode(queries, convert_to_tensor=True, convert_to_numpy=False)
    candidate_embeddings = model.encode(candidates, convert_to_tensor=True, convert_to_numpy=False)

    # Rank by cosine similarity, like the rest of the notebook (the 0.6 labelling
    # threshold is a cosine threshold). A raw dot product would also reward
    # candidates whose embeddings merely have a larger norm.
    query_embeddings = torch.nn.functional.normalize(query_embeddings, p=2, dim=-1)
    candidate_embeddings = torch.nn.functional.normalize(candidate_embeddings, p=2, dim=-1)
    similarity_matrix = torch.mm(query_embeddings, candidate_embeddings.T).cpu().numpy()

    if len(ground_truth) != len(similarity_matrix):
        raise ValueError(
            f"ground_truth has {len(ground_truth)} entries but there are {len(similarity_matrix)} queries"
        )

    mrr_scores, ndcg_scores, recall_scores = [], [], []
    for query_similarities, relevant in zip(similarity_matrix, ground_truth):
        relevant = {int(idx) for idx in relevant}
        # Most similar first; slicing with None keeps the full ranking.
        ranked_indices = np.argsort(query_similarities)[::-1][:top_k]
        # hits[r] is True when the candidate at rank r + 1 is relevant
        hits = [int(idx) in relevant for idx in ranked_indices]

        # Reciprocal rank of the first relevant result (0 if none is retrieved).
        mrr_scores.append(next((1 / rank for rank, hit in enumerate(hits, start=1) if hit), 0))

        # NDCG: actual DCG over the DCG of a perfect ranking, which puts all
        # (at most top_k) relevant candidates first.
        dcg = sum(1 / np.log2(rank + 1) for rank, hit in enumerate(hits, start=1) if hit)
        ideal_hits = len(relevant) if top_k is None else min(len(relevant), top_k)
        ideal_dcg = sum(1 / np.log2(rank + 1) for rank in range(1, ideal_hits + 1))
        ndcg_scores.append(dcg / ideal_dcg if ideal_dcg > 0 else 0)

        # Fraction of the relevant candidates that were retrieved.
        recall_scores.append(sum(hits) / len(relevant) if relevant else 0)

    return np.mean(mrr_scores), np.mean(ndcg_scores), np.mean(recall_scores)

# `queries`, `candidates` and `ground_truth` come from the training cell; the labels
# were derived from the base model's similarities, so this shows how well the
# fine-tuned model agrees with them.
mrr, ndcg, recall = evaluate_model(queries, candidates, ground_truth, model)

for metric_name, metric_value in (("MRR", mrr), ("NDCG", ndcg), ("Recall", recall)):
    print(f"{metric_name}: {metric_value:.4f}")


MRR: 1.0000
NDCG: 1.0000
Recall: 1.0000
