In [6]:
# 安装必要的库

# 导入库
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
from sklearn.metrics import ndcg_score
from langchain.embeddings import HuggingFaceEmbeddings

# Load the first 1000 training examples of the HotpotQA "fullwiki" configuration.
dataset = load_dataset(
    "hotpot_qa",
    "fullwiki",
    split="train[:1000]",
    trust_remote_code=True,
)

# Show one raw record so the field layout is visible.
print(dataset[0])

# Build one corpus document per example: each paragraph's sentence list is
# joined into a paragraph string, then all paragraphs of that example are
# joined into a single document string.
corpus = []
for example in dataset:
    paragraph_texts = [
        ' '.join(sentence_list)
        for sentence_list in example['context']['sentences']
    ]
    corpus.append(' '.join(paragraph_texts))

# Questions and gold answers, index-aligned with `corpus`
# (entry i of each list comes from example i of the dataset).
queries = [example['question'] for example in dataset]
answers = [example['answer'] for example in dataset]

# Embedding model used for both the corpus documents and the questions.
model_name = "moka-ai/m3e-base"
embeddings = HuggingFaceEmbeddings(model_name=model_name)

# Every corpus entry is a plain string, so the whole list can be embedded
# in one call. The result is stored as a C-contiguous float32 matrix because
# that is the array layout FAISS operates on.
corpus_embeddings = np.ascontiguousarray(
    embeddings.embed_documents(corpus), dtype=np.float32
)

# The embeddings interface offers `embed_query` (one text per call); there is
# no `embed_queries` method, so each question is embedded individually and the
# vectors are stacked into a (num_queries, dim) float32 matrix.
query_embeddings = np.ascontiguousarray(
    [embeddings.embed_query(question) for question in queries],
    dtype=np.float32,
)

# Build a flat L2 index whose dimensionality matches the embedding width,
# then add every corpus vector. The vectors are coerced to contiguous
# float32 here because FAISS works on float32 data.
index = faiss.IndexFlatL2(corpus_embeddings.shape[1])
index.add(np.ascontiguousarray(corpus_embeddings, dtype=np.float32))

# Retrieve the k nearest corpus documents for every query.
# `query_embeddings` is already a NumPy array, so it is passed to FAISS
# directly; the torch-style `.cpu().detach().numpy()` chain does not exist on
# NumPy arrays and raised AttributeError.
# D holds the distances and I the corpus row indices, one row per query.
k = 10
D, I = index.search(np.ascontiguousarray(query_embeddings, dtype=np.float32), k)

# Evaluation: a retrieved document counts as a hit when the gold answer
# string occurs verbatim inside it.
recall_at_10 = []
mrr = []
ndcg_at_5 = []

# Descending pseudo-scores (k, k-1, ..., 1) that encode the retrieval order
# for ndcg_score; identical for every query, so built once outside the loop.
rank_scores = list(range(k, 0, -1))

for i, indices in enumerate(I):
    answer = answers[i]
    # One 0/1 flag per retrieved position. A negative index is FAISS padding
    # for "no result"; it is scored as a miss instead of being used as a
    # Python negative index (which would silently pick the last document).
    hits = [int(idx >= 0 and answer in corpus[idx]) for idx in indices]

    # 1 if any retrieved document contains the answer, else 0.
    recall_at_10.append(1 if any(hits) else 0)

    # Reciprocal rank of the first hit (ranks start at 1); 0 when no hit.
    if 1 in hits:
        mrr.append(1 / (hits.index(1) + 1))
    else:
        mrr.append(0)

    # k=5 truncates the ranking to the top five positions so the value is
    # really NDCG@5; without it all retrieved positions were scored.
    ndcg_at_5.append(ndcg_score([hits], [rank_scores], k=5))

# Print the metrics averaged over all queries.
print(f"Recall@10: {np.mean(recall_at_10):.2f}")
print(f"MRR: {np.mean(mrr):.2f}")
print(f"NDCG@5: {np.mean(ndcg_at_5):.2f}")


{'id': '5a7a06935542990198eaf050', 'question': "Which magazine was started first Arthur's Magazine or First for Women?", 'answer': "Arthur's Magazine", 'type': 'comparison', 'level': 'medium', 'supporting_facts': {'title': ["Arthur's Magazine", 'First for Women'], 'sent_id': [0, 0]}, 'context': {'title': ['Radio City (Indian radio station)', 'History of Albanian football', 'Echosmith', "Women's colleges in the Southern United States", 'First Arthur County Courthouse and Jail', "Arthur's Magazine", '2014–15 Ukrainian Hockey Championship', 'First for Women', 'Freeway Complex Fire', 'William Rast'], 'sentences': [["Radio City is India's first private FM radio station and was started on 3 July 2001.", ' It broadcasts on 91.1 (earlier 91.0 in most cities) megahertz from Mumbai (where it was started in 2004), Bengaluru (started first in 2001), Lucknow and New Delhi (since 2003).', ' It plays Hindi, English and regional songs.', ' It was launched in Hyderabad in March 2006, in Chennai on 7 Ju

KeyboardInterrupt: 