In [19]:
import pandas as pd
import matplotlib.pyplot as plt

In [1]:
model_names = [
        "thang1943/bkcare-embed-v2",
        "thang1943/vietnamese-bi-encoder-v2",
        "thang1943/vietnamese-sbert-v2",
]


In [2]:
import torch
from sentence_transformers import SentenceTransformer
from sentence_transformers.evaluation import (
    InformationRetrievalEvaluator,
    SequentialEvaluator,
)
from sentence_transformers.util import cos_sim
from datasets import load_dataset, concatenate_datasets
from sentence_transformers import SentenceTransformerModelCardData, SentenceTransformer
from sentence_transformers.losses import MatryoshkaLoss, MultipleNegativesRankingLoss
from sentence_transformers import SentenceTransformerTrainingArguments
from sentence_transformers.training_args import BatchSamplers

from sentence_transformers import SentenceTransformerTrainer
import json

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
from datasets import load_dataset

dataset = load_dataset("meandyou200175/dataset_full_fixed", split="train")
dataset = dataset.rename_column("pos", "positive")
dataset = dataset.rename_column("neg", "negative")

def flatten_columns(example):
    example['pos'] = example['pos'][0] if isinstance(example['pos'], list) else example['pos']
    example['neg'] = example['neg'][0] if isinstance(example['neg'], list) else example['neg']
    return example

dataset = dataset.train_test_split(test_size=0.1)

train_dataset = dataset['train']

eval_dataset = dataset['test']

train_dataset = train_dataset.train_test_split(test_size= 1 / 9)['train']
test_dataset = train_dataset.train_test_split(test_size= 1 / 9)['test']

In [4]:
corpus = {}
queries = {}
relevant_docs = {}

for idx, data in enumerate(eval_dataset):
    query_id = f"{2*idx}"
    positive_id = f"{2*idx+1}"

    
    # Add to corpus
    corpus[positive_id] = data['positive']

    
    # Add to queries
    queries[query_id] = data['query']
    
    # Map relevant docs
    relevant_docs[query_id] = [positive_id]

In [5]:
ir_evaluator = InformationRetrievalEvaluator(
    queries=queries,
    corpus=corpus,
    relevant_docs=relevant_docs,
    name=f"dim_{768}",
    truncate_dim=768,  # Truncate the embeddings to a certain dimension
    score_functions={"cosine": cos_sim},
)

In [6]:
scores = []

In [7]:
for model_id in model_names:
    model = SentenceTransformer(
        model_id,
        device="cuda" if torch.cuda.is_available() else "cpu",
        model_kwargs={"attn_implementation": "sdpa"},
    )

    result = ir_evaluator(model)
    scores.append(result)

In [8]:
print(scores)

[{'dim_768_cosine_accuracy@1': 0.595872899926954, 'dim_768_cosine_accuracy@3': 0.7342951059167275, 'dim_768_cosine_accuracy@5': 0.7870708546384222, 'dim_768_cosine_accuracy@10': 0.8489773557341125, 'dim_768_cosine_precision@1': 0.595872899926954, 'dim_768_cosine_precision@3': 0.24476503530557583, 'dim_768_cosine_precision@5': 0.15741417092768445, 'dim_768_cosine_precision@10': 0.08489773557341124, 'dim_768_cosine_recall@1': 0.595872899926954, 'dim_768_cosine_recall@3': 0.7342951059167275, 'dim_768_cosine_recall@5': 0.7870708546384222, 'dim_768_cosine_recall@10': 0.8489773557341125, 'dim_768_cosine_ndcg@10': 0.7196881790090148, 'dim_768_cosine_mrr@10': 0.6786520835507334, 'dim_768_cosine_map@100': 0.684264196034238}, {'dim_768_cosine_accuracy@1': 0.633308984660336, 'dim_768_cosine_accuracy@3': 0.7529218407596786, 'dim_768_cosine_accuracy@5': 0.7985756026296567, 'dim_768_cosine_accuracy@10': 0.8601168736303871, 'dim_768_cosine_precision@1': 0.633308984660336, 'dim_768_cosine_precision@3'

In [14]:
for model,score in zip(model_names, scores):
    score['model_name'] = model

In [15]:
scores

[{'dim_768_cosine_accuracy@1': 0.595872899926954,
  'dim_768_cosine_accuracy@3': 0.7342951059167275,
  'dim_768_cosine_accuracy@5': 0.7870708546384222,
  'dim_768_cosine_accuracy@10': 0.8489773557341125,
  'dim_768_cosine_precision@1': 0.595872899926954,
  'dim_768_cosine_precision@3': 0.24476503530557583,
  'dim_768_cosine_precision@5': 0.15741417092768445,
  'dim_768_cosine_precision@10': 0.08489773557341124,
  'dim_768_cosine_recall@1': 0.595872899926954,
  'dim_768_cosine_recall@3': 0.7342951059167275,
  'dim_768_cosine_recall@5': 0.7870708546384222,
  'dim_768_cosine_recall@10': 0.8489773557341125,
  'dim_768_cosine_ndcg@10': 0.7196881790090148,
  'dim_768_cosine_mrr@10': 0.6786520835507334,
  'dim_768_cosine_map@100': 0.684264196034238,
  'model_name': 'thang1943/bkcare-embed-v2'},
 {'dim_768_cosine_accuracy@1': 0.633308984660336,
  'dim_768_cosine_accuracy@3': 0.7529218407596786,
  'dim_768_cosine_accuracy@5': 0.7985756026296567,
  'dim_768_cosine_accuracy@10': 0.860116873630387

In [None]:

df = pd.DataFrame(scores)

df.set_index('model_name', inplace=True)

def highlight_max(s):
    is_max = s == s.max()
    return ['font-weight: bold' if v else '' for v in is_max]

df_styled = df.style.apply(highlight_max, axis=0)

df_styled

Unnamed: 0_level_0,dim_768_cosine_accuracy@1,dim_768_cosine_accuracy@3,dim_768_cosine_accuracy@5,dim_768_cosine_accuracy@10,dim_768_cosine_precision@1,dim_768_cosine_precision@3,dim_768_cosine_precision@5,dim_768_cosine_precision@10,dim_768_cosine_recall@1,dim_768_cosine_recall@3,dim_768_cosine_recall@5,dim_768_cosine_recall@10,dim_768_cosine_ndcg@10,dim_768_cosine_mrr@10,dim_768_cosine_map@100
model_name,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1
thang1943/bkcare-embed-v2,0.595873,0.734295,0.787071,0.848977,0.595873,0.244765,0.157414,0.084898,0.595873,0.734295,0.787071,0.848977,0.719688,0.678652,0.684264
thang1943/vietnamese-bi-encoder-v2,0.633309,0.752922,0.798576,0.860117,0.633309,0.250974,0.159715,0.086012,0.633309,0.752922,0.798576,0.860117,0.742172,0.705004,0.710242
thang1943/vietnamese-sbert-v2,0.613039,0.74653,0.794375,0.850621,0.613039,0.248843,0.158875,0.085062,0.613039,0.74653,0.794375,0.850621,0.729462,0.690974,0.69666


: 