# 对embedding 模型进行评估

In [1]:
! pip install modelscope

Collecting modelscope
  Downloading modelscope-1.31.0-py3-none-any.whl.metadata (40 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m40.4/40.4 kB[0m [31m2.4 MB/s[0m eta [36m0:00:00[0m
Downloading modelscope-1.31.0-py3-none-any.whl (5.9 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.9/5.9 MB[0m [31m30.5 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: modelscope
Successfully installed modelscope-1.31.0


In [2]:
! modelscope download --model 'BAAI/bge-base-zh-v1.5' --local_dir './BAAI/bge-base-zh-v1.5'


 _   .-')                _ .-') _     ('-.             .-')                              _ (`-.    ('-.
( '.( OO )_             ( (  OO) )  _(  OO)           ( OO ).                           ( (OO  ) _(  OO)
 ,--.   ,--.).-'),-----. \     .'_ (,------.,--.     (_)---\_)   .-----.  .-'),-----.  _.`     \(,------.
 |   `.'   |( OO'  .-.  ',`'--..._) |  .---'|  |.-') /    _ |   '  .--./ ( OO'  .-.  '(__...--'' |  .---'
 |         |/   |  | |  ||  |  \  ' |  |    |  | OO )\  :` `.   |  |('-. /   |  | |  | |  /  | | |  |
 |  |'.'|  |\_) |  |\|  ||  |   ' |(|  '--. |  |`-' | '..`''.) /_) |OO  )\_) |  |\|  | |  |_.' |(|  '--.
 |  |   |  |  \ |  | |  ||  |   / : |  .--'(|  '---.'.-._)   \ ||  |`-'|   \ |  | |  | |  .___.' |  .--'
 |  |   |  |   `'  '-'  '|  '--'  / |  `---.|      | \       /(_'  '--'\    `'  '-'  ' |  |      |  `---.
 `--'   `--'     `-----' `-------'  `------'`------'  `-----'    `-----'      `-----'  `--'      `------'

Downloading Model from https://www.modelscope.cn to d

In [4]:
import os
import json
import time
import torch
from pprint import pprint
from sentence_transformers import SentenceTransformer
from sentence_transformers.evaluation import InformationRetrievalEvaluator
from sentence_transformers.util import cos_sim

# 加载数据集，数据应包含query,corpus,relevant_docs
# 其中query是问题，corpus 是字典，relevant_docs 也是字典，将query Id 和相关文档的id关联起来
# 数据集格式
'''
datasets={
    "corpus":[{uuid1:doc1},{uuid2:doc2},{uuid3:doc3}       #对应的文本id、文本
    ],
    "queries":[{uuid1:问题}，{uuid2:问题}，...
    ],
    "relevant_docs":[{uuid1:[uuid答案]},{uuid2:[uuid答案]},{uuid3:[uuid答案]}
    ]
}

'''


#加载用于微调的数据集，划分为train,test
from datasets import load_dataset, Dataset

def load_jsonl_to_dataset(file_path):
    """Loads a JSONL file into a Hugging Face Dataset."""
    data = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            data.append(json.loads(line))
    return Dataset.from_list(data)


# Load the dataset
jsonl_path = "/content/sample_data/final_datasets.jsonl"
dataset = load_jsonl_to_dataset(jsonl_path)
dataset



Dataset({
    features: ['context', 'question', 'answer', 'source_doc', 'groundedness_score', 'groundedness_eval', 'relevance_score', 'relevance_eval', 'standalone_score', 'standalone_eval'],
    num_rows: 572
})

In [5]:
split_dataset = dataset.train_test_split(
    test_size=0.2,  # 测试集比例（可根据需求调整，如0.1/0.3）
    seed=42         # 随机种子，固定后每次拆分结果一致
)


# 3. 查看拆分结果
print("\n拆分后数据集信息：")
print(split_dataset)


拆分后数据集信息：
DatasetDict({
    train: Dataset({
        features: ['context', 'question', 'answer', 'source_doc', 'groundedness_score', 'groundedness_eval', 'relevance_score', 'relevance_eval', 'standalone_score', 'standalone_eval'],
        num_rows: 457
    })
    test: Dataset({
        features: ['context', 'question', 'answer', 'source_doc', 'groundedness_score', 'groundedness_eval', 'relevance_score', 'relevance_eval', 'standalone_score', 'standalone_eval'],
        num_rows: 115
    })
})


In [6]:
dataset[0]

{'context': '# （二）其他直属单位兼职培训员数量要求\n\n1.各单位市公司、下属分公司的每个部门各不少于1人（1-2人）；\n\n2.各单位班组、变电站、项目部等若因工作原因，确实无法抽出人员参加紧急救护兼职培训员集中培训的，后续由所属市公司或分公司兼职培训员结合各项培训、会议等尽快为相关班组、变电站、项目部培训兼职培训员。\n\n# 三、其他要求',
 'question': '各单位市公司、下属分公司的每个部门至少需要多少名兼职培训员？',
 'answer': '各单位市公司、下属分公司的每个部门至少需要1名兼职培训员。',
 'source_doc': '培训相关要求.pdf',
 'groundedness_score': 5,
 'groundedness_eval': '上文明确指出“各单位市公司、下属分公司的每个部门各不少于1人（1-2人）”，这直接回答了问题中关于各单位市公司、下属分公司的每个部门至少需要多少名兼职培训员的要求。因此，根据给定的上下文，可以清晰、无歧义地回答该问题。',
 'relevance_score': 3,
 'relevance_eval': '该问题涉及国网公司内部的人员配置情况，特别是关于兼职培训员的数量要求，这与工会相关规章制度有一定的关联，因为培训员的职责可能包括员工培训、技能提升等方面，这些活动通常由工会组织或参与。然而，该问题更偏向于人力资源管理的具体操作层面，而非直接涉及工会规章制度的理解和应用。因此，对于了解工会相关规章制度来说，这个问题的直接有用程度有限，但仍然具有一定的间接价值。',
 'standalone_score': 4,
 'standalone_eval': '该问题虽然指定了具体的组织结构（各单位市公司、下属分公司、每个部门），但并未提及特定的上下文或文档，因此可以理解为在一般情况下询问这些单位需要配置多少名兼职培训员。尽管如此，问题的具体答案可能依赖于具体的组织政策或行业标准，但问题本身是明确的，不依赖于特定的上下文信息来理解其含义。'}

In [7]:
#将data['test']整理成对应格式进行评估
test_data = split_dataset['test']
test_data

Dataset({
    features: ['context', 'question', 'answer', 'source_doc', 'groundedness_score', 'groundedness_eval', 'relevance_score', 'relevance_eval', 'standalone_score', 'standalone_eval'],
    num_rows: 115
})

In [8]:
import uuid

corpus = {str(i): sample['context'] for i, sample in enumerate(test_data)}
queries = {}
relevant_docs = {}
for i, sample in enumerate(test_data):
  question_id = str(uuid.uuid4())
  queries[question_id] = sample["question"]
  relevant_docs[question_id] = [str(i)]


test_dataset = {
    'queries': queries,
    'corpus': corpus,
    'relevant_docs': relevant_docs,
}


test_dataset

{'queries': {'d1554ff4-66f8-4d4b-baa7-07cfbf91caef': '组织安排异地流动到国网陕西电力系统内，考核鉴定结果为优秀的考评系数是多少？',
  'f23e77db-0476-4b0c-b010-51f281689058': '各级单位在防灾减灾工作中会成立哪些组织机构？',
  'f645ae37-eab4-4ab1-9afa-6ee78436863d': '组织安排异地流动到国网陕西电力系统内，考核鉴定结果为优秀的考评系数是多少？',
  'be00a63f-aafc-443d-9cda-87cc64cde754': '薪酬分配与员工的哪些要素挂钩？',
  'bd3a32c6-cbb5-44d2-912a-55848eddf1a4': '社会保险行政部门工作人员在工伤认定过程中收受当事人财物的，将面临什么处罚？',
  '0837ca86-d18e-4634-82c2-f7a099ad6e39': '各单位每年至少需要进行几次全员轮训和紧急救护模拟演练？',
  'b1ad1f8f-ff92-4a50-9ef1-a35fa3c754de': '勤后勤服务的工作理念是什么？',
  'e2b47f69-e8da-4c5f-aacb-88c93949129f': '申报人员所在单位评价积分标准的分值范围是多少？',
  'c67fc9ed-cbbc-4d9d-bedf-fddb4d5ddb91': '正高级专业技术资格可以获得多少分？',
  'a4a31a2a-c700-4611-aeb7-d28751aca2b8': '申报正高级实习指导教师人员应具备什么级别的技能操作水平？',
  'aa62ead7-3263-47cc-b492-1e3048544f5b': '人才强企战略的主要目标是什么？',
  '5593e18f-f8fc-4a80-a00b-bac77eaa2bb5': '应急预案要求每三年至少组织几次大型综合应急演练？',
  'a8996738-07bd-44e5-90e1-35437b8cb9f4': '国网陕西省电力有限公司计划开展的消防设施和设备改进工作的目的是什么？',
  'ae5809d6-55db-4de2-8ecd-9e1d4dd0b2ae': '国家电网有限公司新型规章制度体系由几

In [9]:
# 导出评估数据集
import json

output_path="bge_eval_dataset.json"
with open(output_path, 'w', encoding='utf-8') as f:
        # ensure_ascii=False：保证中文正常显示；indent=2：格式化输出，便于阅读
        json.dump(test_dataset, f, ensure_ascii=False, indent=2)

print(f"数据集已保存至 {output_path}")


数据集已保存至 bge_eval_dataset.json


In [10]:
corpus = test_dataset['corpus']
queries = test_dataset['queries']
relevant_docs = test_dataset['relevant_docs']

# # Load a model
# 替换成自己的模型完整路径或使用huggingface modl id
model_name = "/content/BAAI/bge-base-zh-v1.5"
model = SentenceTransformer(model_name, device="cuda" if torch.cuda.is_available() else "cpu")
print("Model loaded")

s_time = time.time()

# # Evaluate the model
evaluator = InformationRetrievalEvaluator(
    queries=queries,
    corpus=corpus,
    relevant_docs=relevant_docs,
    name=f"cunstom",
    score_functions={"cosine": cos_sim}
)

# Evaluate the model
result = evaluator(model)
pprint(result)
print(f"Time cost: {time.time() - s_time:.2f}s")

#评估指标 recall, mrr,recall是召回率，计算的是每个问题的相关文档被召回概率的平均。 mrr计算的是第一个相关文档的倒数排名，不关心其他文档（当相关文档有多个的时候）

Model loaded
{'cunstom_cosine_accuracy@1': 0.8869565217391304,
 'cunstom_cosine_accuracy@10': 1.0,
 'cunstom_cosine_accuracy@3': 1.0,
 'cunstom_cosine_accuracy@5': 1.0,
 'cunstom_cosine_map@100': 0.9420289855072465,
 'cunstom_cosine_mrr@10': 0.9420289855072463,
 'cunstom_cosine_ndcg@10': 0.9571404960248477,
 'cunstom_cosine_precision@1': 0.8869565217391304,
 'cunstom_cosine_precision@10': 0.1,
 'cunstom_cosine_precision@3': 0.33333333333333337,
 'cunstom_cosine_precision@5': 0.2,
 'cunstom_cosine_recall@1': 0.8869565217391304,
 'cunstom_cosine_recall@10': 1.0,
 'cunstom_cosine_recall@3': 1.0,
 'cunstom_cosine_recall@5': 1.0}
Time cost: 2.97s


In [11]:
with open("/content/sample_data/doc_qa_dataset.json","r", encoding="utf-8") as f:
  content = f.read()

In [12]:
import json

data = json.loads(content)

corpus = data['corpus']
queries = data['queries']
relevant_docs = data['relevant_docs']

print("Data loaded successfully!")

Data loaded successfully!


In [13]:
model_name = "/content/BAAI/bge-base-zh-v1.5"
model = SentenceTransformer(model_name, device="cuda" if torch.cuda.is_available() else "cpu")
print("Model loaded")

s_time = time.time()

# # Evaluate the model
evaluator = InformationRetrievalEvaluator(
    queries=queries,
    corpus=corpus,
    relevant_docs=relevant_docs,
    name=f"cunstom",
    score_functions={"cosine": cos_sim}
)

# Evaluate the model
result = evaluator(model)
pprint(result)
print(f"Time cost: {time.time() - s_time:.2f}s")

Model loaded
{'cunstom_cosine_accuracy@1': 0.6137071651090342,
 'cunstom_cosine_accuracy@10': 0.881619937694704,
 'cunstom_cosine_accuracy@3': 0.7601246105919003,
 'cunstom_cosine_accuracy@5': 0.8130841121495327,
 'cunstom_cosine_map@100': 0.7070720844057398,
 'cunstom_cosine_mrr@10': 0.7016540572615335,
 'cunstom_cosine_ndcg@10': 0.7450975234585043,
 'cunstom_cosine_precision@1': 0.6137071651090342,
 'cunstom_cosine_precision@10': 0.0881619937694704,
 'cunstom_cosine_precision@3': 0.2533748701973001,
 'cunstom_cosine_precision@5': 0.16261682242990652,
 'cunstom_cosine_recall@1': 0.6137071651090342,
 'cunstom_cosine_recall@10': 0.881619937694704,
 'cunstom_cosine_recall@3': 0.7601246105919003,
 'cunstom_cosine_recall@5': 0.8130841121495327}
Time cost: 3.42s


In [14]:
train_data = split_dataset['train']
train_data

Dataset({
    features: ['context', 'question', 'answer', 'source_doc', 'groundedness_score', 'groundedness_eval', 'relevance_score', 'relevance_eval', 'standalone_score', 'standalone_eval'],
    num_rows: 457
})

In [15]:
import uuid

corpus = {str(i): sample['context'] for i, sample in enumerate(train_data)}
queries = {}
relevant_docs = {}
for i, sample in enumerate(train_data):
  question_id = str(uuid.uuid4())
  queries[question_id] = sample["question"]
  relevant_docs[question_id] = [str(i)]


train_dataset = {
    'queries': queries,
    'corpus': corpus,
    'relevant_docs': relevant_docs,
}

train_dataset

{'queries': {'6bf9858f-036c-4787-9b02-324572f4719a': '一次性工亡补助金的标准是多少？',
  '90afab9b-3a9a-4c7c-a88a-5980558abab0': '人员借用的基础积分为多少分？',
  '872b1e3d-1104-42cf-8234-1745eec2de63': '单一来源采购是指采购人与哪个对象进行谈判？',
  '93e50a6d-9ed6-4b8e-a8e5-18e6162f11ab': '变电运维检修管理遵循哪些原则？',
  'db4a735e-40d2-47d8-be58-d5304bb71b50': '状态检修工作的核心原则是什么？',
  'b22a9fe9-44b5-47d9-9e07-a8be4cc02a01': '年功工资每5年工龄段的工资增长额是多少？',
  '53b58221-0d85-4222-87e8-3b3fa65dfd23': '取得高级讲师（高级实习指导教师）职称后，需要具备多少项业绩成果？',
  '4453ddab-07f3-466f-8440-1d985f252103': '职工共享服务点可以提供哪些基本急救设备？',
  'e264d4cf-0cc1-4106-8caa-97446685f20f': '出差期间，公司三级正（副）职、一级协理、二级协理人员可选择哪些交通工具？',
  '35f11fa2-50c8-4b94-91b8-ebe24c0fdb99': '业绩成果的主要贡献者需要提供哪些文件来证明其主要完成者的身份？',
  '1ed46437-3b18-48b9-bc96-9e781e418a30': '国家电网公司的企业愿景是什么？',
  '1cd6ef98-f945-4a16-9ef6-aac6946ded4d': '国网陕西省电力有限公司员工制度学习手册的主要目的是什么？',
  '8dc896c7-f0ae-42c4-b4c5-70abd02cc78c': '职称申报中计算现有职称取得年限的截止时间是什么时候？',
  '51f0f239-bb7e-4e94-9eca-c18fc0e4a8b0': '员工专项考核奖是如何计算的？',
  '0cd8a6f3-f3f5-4913-9831-3b45755de2b3': '

In [16]:
# 导出训练数据集
import json

output_path="bge_train_dataset.json"
with open(output_path, 'w', encoding='utf-8') as f:
        # ensure_ascii=False：保证中文正常显示；indent=2：格式化输出，便于阅读
        json.dump(train_dataset, f, ensure_ascii=False, indent=2)

print(f"数据集已保存至 {output_path}")

数据集已保存至 bge_train_dataset.json


In [17]:
train_dataset

{'queries': {'6bf9858f-036c-4787-9b02-324572f4719a': '一次性工亡补助金的标准是多少？',
  '90afab9b-3a9a-4c7c-a88a-5980558abab0': '人员借用的基础积分为多少分？',
  '872b1e3d-1104-42cf-8234-1745eec2de63': '单一来源采购是指采购人与哪个对象进行谈判？',
  '93e50a6d-9ed6-4b8e-a8e5-18e6162f11ab': '变电运维检修管理遵循哪些原则？',
  'db4a735e-40d2-47d8-be58-d5304bb71b50': '状态检修工作的核心原则是什么？',
  'b22a9fe9-44b5-47d9-9e07-a8be4cc02a01': '年功工资每5年工龄段的工资增长额是多少？',
  '53b58221-0d85-4222-87e8-3b3fa65dfd23': '取得高级讲师（高级实习指导教师）职称后，需要具备多少项业绩成果？',
  '4453ddab-07f3-466f-8440-1d985f252103': '职工共享服务点可以提供哪些基本急救设备？',
  'e264d4cf-0cc1-4106-8caa-97446685f20f': '出差期间，公司三级正（副）职、一级协理、二级协理人员可选择哪些交通工具？',
  '35f11fa2-50c8-4b94-91b8-ebe24c0fdb99': '业绩成果的主要贡献者需要提供哪些文件来证明其主要完成者的身份？',
  '1ed46437-3b18-48b9-bc96-9e781e418a30': '国家电网公司的企业愿景是什么？',
  '1cd6ef98-f945-4a16-9ef6-aac6946ded4d': '国网陕西省电力有限公司员工制度学习手册的主要目的是什么？',
  '8dc896c7-f0ae-42c4-b4c5-70abd02cc78c': '职称申报中计算现有职称取得年限的截止时间是什么时候？',
  '51f0f239-bb7e-4e94-9eca-c18fc0e4a8b0': '员工专项考核奖是如何计算的？',
  '0cd8a6f3-f3f5-4913-9831-3b45755de2b3': '

In [18]:
#使用 sentence transformer 微调embedding模型

import os
import json
import time
import torch
from datasets import Dataset
from sentence_transformers import SentenceTransformer
from sentence_transformers.evaluation import InformationRetrievalEvaluator
from sentence_transformers.util import cos_sim
from sentence_transformers.losses import MultipleNegativesRankingLoss
from sentence_transformers import SentenceTransformerTrainingArguments
from sentence_transformers.training_args import BatchSamplers
from sentence_transformers import SentenceTransformerTrainer
start_time = time.time()

# Load eval dataset
corpus, queries, relevant_docs = test_dataset['corpus'], test_dataset['queries'], test_dataset['relevant_docs']
# Load train dataset

train_anchor, train_positive = [], []
for query_id, context_id in train_dataset['relevant_docs'].items():
   train_anchor.append(train_dataset['queries'][query_id])
   train_positive.append(train_dataset['corpus'][context_id[0]])
train_dataset_for_finetune = Dataset.from_dict({"positive": train_positive, "anchor": train_anchor})

# Load a model
model_path = '/content/BAAI/bge-base-zh-v1.5'
model_name = 'bge-base-zh-v1.5'
model = SentenceTransformer(model_path, device="cuda:0" if torch.cuda.is_available() else "cpu")
# Evaluate the model
evaluator = InformationRetrievalEvaluator(
   queries=queries,
   corpus=corpus,
   relevant_docs=relevant_docs,
   name=f"{model_name}",
   score_functions={"cosine": cos_sim}
)
train_loss = MultipleNegativesRankingLoss(model)
# Define training arguments
args = SentenceTransformerTrainingArguments(
   output_dir=f"ft_{model_name}",
   num_train_epochs=5,
   per_device_train_batch_size=2,
   gradient_accumulation_steps=2,
   per_device_eval_batch_size=4,
   warmup_ratio=0.1,
   learning_rate=2e-5,
   lr_scheduler_type="cosine",
   optim="adamw_torch_fused",
   tf32=False,
   bf16=False,
   batch_sampler=BatchSamplers.NO_DUPLICATES,
   eval_strategy="epoch",
   save_strategy="epoch",
   logging_steps=10,
   save_total_limit=3,
   load_best_model_at_end=True,
   metric_for_best_model=f"eval_{model_name}_cosine_ndcg@10"
)
# Train the model
trainer = SentenceTransformerTrainer(
   model=model,
   args=args,
   train_dataset=train_dataset_for_finetune.select_columns(["positive", "anchor"]),
   loss=train_loss,
   evaluator=evaluator
)
trainer.train()
trainer.save_model()
print(f"cost time: {time.time() - start_time:.2f}s")

'''
多负样本排序损失函数（MultipleNegativesRankingLoss）是一种适用于语义检索和信息召回任务的损失函数。它的主要优点在于不需要构造负样本，
因为该损失函数会将一个批次中的所有非正样本作为负样本，从而在最终结果的概率分布上，正样本的概率高于其他负样本。

'''

Computing widget examples:   0%|          | 0/1 [00:00<?, ?example/s]

  | |_| | '_ \/ _` / _` |  _/ -_)


<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize?ref=models
wandb: Paste an API key from your profile and hit enter:

 ··········


[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mliqing20[0m ([33mliqing20-[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


dataset = dataset.select_columns(['anchor', 'positive', 'negative'])


Epoch,Training Loss,Validation Loss,Bge-base-zh-v1.5 Cosine Accuracy@1,Bge-base-zh-v1.5 Cosine Accuracy@3,Bge-base-zh-v1.5 Cosine Accuracy@5,Bge-base-zh-v1.5 Cosine Accuracy@10,Bge-base-zh-v1.5 Cosine Precision@1,Bge-base-zh-v1.5 Cosine Precision@3,Bge-base-zh-v1.5 Cosine Precision@5,Bge-base-zh-v1.5 Cosine Precision@10,Bge-base-zh-v1.5 Cosine Recall@1,Bge-base-zh-v1.5 Cosine Recall@3,Bge-base-zh-v1.5 Cosine Recall@5,Bge-base-zh-v1.5 Cosine Recall@10,Bge-base-zh-v1.5 Cosine Ndcg@10,Bge-base-zh-v1.5 Cosine Mrr@10,Bge-base-zh-v1.5 Cosine Map@100
1,0.0135,No log,0.895652,0.991304,1.0,1.0,0.895652,0.330435,0.2,0.1,0.895652,0.991304,1.0,1.0,0.959747,0.945652,0.945652
2,0.0018,No log,0.930435,0.991304,0.991304,1.0,0.930435,0.330435,0.198261,0.1,0.930435,0.991304,0.991304,1.0,0.971582,0.961957,0.961957
3,0.0004,No log,0.913043,0.991304,0.991304,0.991304,0.913043,0.330435,0.198261,0.09913,0.913043,0.991304,0.991304,0.991304,0.962421,0.952174,0.952964
4,0.0008,No log,0.904348,0.991304,0.991304,1.0,0.904348,0.330435,0.198261,0.1,0.904348,0.991304,0.991304,1.0,0.962309,0.949275,0.949275
5,0.0007,No log,0.904348,0.991304,1.0,1.0,0.904348,0.330435,0.2,0.1,0.904348,0.991304,1.0,1.0,0.962575,0.949565,0.949565


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


cost time: 323.40s


'\n多负样本排序损失函数（MultipleNegativesRankingLoss）是一种适用于语义检索和信息召回任务的损失函数。它的主要优点在于不需要构造负样本，\n因为该损失函数会将一个批次中的所有非正样本作为负样本，从而在最终结果的概率分布上，正样本的概率高于其他负样本。\n\n'

In [19]:
corpus = test_dataset['corpus']
queries = test_dataset['queries']
relevant_docs = test_dataset['relevant_docs']

# # Load a model
# 替换成自己的模型完整路径或使用huggingface modl id
model_name = "/content/ft_bge-base-zh-v1.5"
model = SentenceTransformer(model_name, device="cuda" if torch.cuda.is_available() else "cpu")
print("Model loaded")

s_time = time.time()

# # Evaluate the model
evaluator = InformationRetrievalEvaluator(
    queries=queries,
    corpus=corpus,
    relevant_docs=relevant_docs,
    name=f"cunstom",
    score_functions={"cosine": cos_sim}
)

# Evaluate the model
result = evaluator(model)
pprint(result)
print(f"Time cost: {time.time() - s_time:.2f}s")

Model loaded
{'cunstom_cosine_accuracy@1': 0.9304347826086956,
 'cunstom_cosine_accuracy@10': 1.0,
 'cunstom_cosine_accuracy@3': 0.991304347826087,
 'cunstom_cosine_accuracy@5': 0.991304347826087,
 'cunstom_cosine_map@100': 0.9619565217391305,
 'cunstom_cosine_mrr@10': 0.9619565217391305,
 'cunstom_cosine_ndcg@10': 0.9715823752329212,
 'cunstom_cosine_precision@1': 0.9304347826086956,
 'cunstom_cosine_precision@10': 0.1,
 'cunstom_cosine_precision@3': 0.3304347826086957,
 'cunstom_cosine_precision@5': 0.1982608695652174,
 'cunstom_cosine_recall@1': 0.9304347826086956,
 'cunstom_cosine_recall@10': 1.0,
 'cunstom_cosine_recall@3': 0.991304347826087,
 'cunstom_cosine_recall@5': 0.991304347826087}
Time cost: 2.76s


微调后，recall@1': 从原来的0.887 提升到 0.930，
mrr@10也从0.94 提升到0.96