In [28]:
# Cell 1: 导入必要的库
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModel
import torch
import faiss
import numpy as np
import os
import pickle

In [29]:
# Cell 2: 加载数据集并取前1000行
def load_and_prepare_data():
    dataset = load_dataset("enelpol/rag-mini-bioasq", "text-corpus")
    # 只取test split的前1000行
    texts = dataset['test']['passage'][:1000]
    ids = dataset['test']['id'][:1000]
    return texts, ids

In [30]:
# Cell 3: 加载模型和tokenizer
def load_model_and_tokenizer():
    model_name = "ncbi/MedCPT-Query-Encoder"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name)
    # 如果有GPU则使用GPU
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    return model, tokenizer, device

In [31]:
# Cell 4: 计算embeddings的函数
def compute_embeddings(texts, model, tokenizer, device, batch_size=32):
    embeddings = []
    
    for i in range(0, len(texts), batch_size):
        batch_texts = texts[i:i + batch_size]
        
        # 对文本进行编码
        inputs = tokenizer(batch_texts, padding=True, truncation=True, 
                         max_length=512, return_tensors="pt")
        
        # 将输入移到相应设备
        inputs = {k: v.to(device) for k, v in inputs.items()}
        
        # 计算embeddings
        with torch.no_grad():
            outputs = model(**inputs)
            # 使用最后一层的[CLS]标记的输出作为文档的嵌入
            batch_embeddings = outputs.last_hidden_state[:, 0, :].cpu().numpy()
            embeddings.append(batch_embeddings)
    
    # 将所有batch的embeddings合并
    embeddings = np.vstack(embeddings)
    return embeddings

In [32]:
# Cell 5: 创建和保存FAISS索引
def create_and_save_faiss_index(embeddings, texts, ids, save_dir="faiss_index"):
    # 创建维度相同的FAISS索引
    dimension = embeddings.shape[1]
    index = faiss.IndexFlatL2(dimension)
    
    # 将向量添加到索引中
    index.add(embeddings)
    
    # 创建保存目录
    os.makedirs(save_dir, exist_ok=True)
    
    # 保存FAISS索引
    faiss.write_index(index, os.path.join(save_dir, "docs.index"))
    
    # 保存文本和ID映射
    with open(os.path.join(save_dir, "texts.pkl"), "wb") as f:
        pickle.dump({"texts": texts, "ids": ids}, f)
    
    return index

In [33]:
# Cell 6: 主函数调用所有步骤
def main():
    # 1. 加载数据
    texts, ids = load_and_prepare_data()
    print(f"Loaded {len(texts)} documents")
    
    # 2. 加载模型和tokenizer
    model, tokenizer, device = load_model_and_tokenizer()
    print(f"Model loaded and moved to {device}")
    
    # 3. 计算embeddings
    embeddings = compute_embeddings(texts, model, tokenizer, device)
    print(f"Computed embeddings with shape: {embeddings.shape}")
    
    # 4. 创建和保存FAISS索引
    index = create_and_save_faiss_index(embeddings, texts, ids)
    print(f"Created and saved FAISS index with {index.ntotal} vectors")

In [25]:
# 测试数据加载
texts, ids = load_and_prepare_data()
print(f"数据示例:\n{texts[:2]}\n")
print(f"ID示例:\n{ids[:2]}")

数据示例:
['New data on viruses isolated from patients with subacute thyroiditis de Quervain \nare reported. Characteristic morphological, cytological, some physico-chemical \nand biological features of the isolated viruses are described. A possible role \nof these viruses in human and animal health disorders is discussed. The isolated \nviruses remain unclassified so far.', "We describe an improved method for detecting deficiency of the acid hydrolase, \nalpha-1,4-glucosidase in leukocytes, the enzyme defect in glycogen storage \ndisease Type II (Pompe disease). The procedure requires smaller volumes of blood \nand less time than previous methods. The assay involves the separation of \nleukocytes by Peter's method for beta-glucosidase and a modification of Salafsky \nand Nadler's fluorometric method for alpha-glucosidase."]

ID示例:
[9797, 11906]


In [26]:
# 测试模型加载
model, tokenizer, device = load_model_and_tokenizer()
print(f"模型设备: {device}")

模型设备: cuda


In [27]:
# 测试小批量数据的embeddings计算
test_embeddings = compute_embeddings(texts[:5], model, tokenizer, device)
print(f"测试embeddings形状: {test_embeddings.shape}")

测试embeddings形状: (5, 768)


In [34]:
main()

Loaded 1000 documents
Model loaded and moved to cuda
Computed embeddings with shape: (1000, 768)
Created and saved FAISS index with 1000 vectors


In [35]:
# 测试加载和查询
def test_vectorstore():
    # 1. 加载保存的索引
    index = faiss.read_index("faiss_index/docs.index")
    
    # 2. 加载保存的文本和ID映射
    with open("faiss_index/texts.pkl", "rb") as f:
        stored_data = pickle.load(f)
        texts = stored_data["texts"]
        ids = stored_data["ids"]
    
    # 3. 加载模型来编码查询
    model, tokenizer, device = load_model_and_tokenizer()
    
    # 4. 测试查询
    test_query = "What is Pompe disease?"
    # 使用相同的方式编码查询
    query_embedding = compute_embeddings([test_query], model, tokenizer, device)
    
    # 5. 搜索最相关的3个文档
    k = 3  # 返回前3个最相似的文档
    distances, indices = index.search(query_embedding, k)
    
    print("查询:", test_query)
    print("\n最相关的文档:")
    for i in range(k):
        print(f"\n--- 相似度得分: {1 - distances[0][i]:.4f} ---")
        print(f"ID: {ids[indices[0][i]]}")
        print(f"文本: {texts[indices[0][i]][:200]}...")

# 运行测试
test_vectorstore()

查询: What is Pompe disease?

最相关的文档:

--- 相似度得分: -61.9027 ---
ID: 1294876
文本: The mucopolysaccharidosis represent a broad spectrum of disorders due to the 
deficiency of one of a group of enzymes which degrade three classes of 
mucopolysaccharides: heparan sulfate, dermatan-sul...

--- 相似度得分: -62.4059 ---
ID: 1588015
文本: A sister and brother, now aged 7 and 9 years, presented with developmental 
arrest, gait disturbance, dementia, and a progressive myoclonic epilepsy 
syndrome with hyperacusis in the second year of li...

--- 相似度得分: -69.7647 ---
ID: 2079710
文本: The inherited deficiency of galactosylceramide beta-galactosidase (E.C. 
3.2.1.46: galactocerebrosidase) activity results in globoid cell leukodystrophy 
in humans (Krabbe disease) and in mice (twitch...
