In [1]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import json
import time
from datasets import load_dataset
from tqdm import tqdm
import re
import faiss
import numpy as np
from sentence_transformers import SentenceTransformer

# 设备设置
device = "cuda" if torch.cuda.is_available() else "cpu"

# 加载 GLM-4V 模型
def load_model(model_path):
    print(f"Loading model from: {model_path}")  # 调试信息
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    print("Tokenizer loaded successfully.")  # 调试信息
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
        trust_remote_code=True
    ).to(device).eval()
    print("Model loaded successfully.")  # 调试信息
    return model, tokenizer

# 加载文本嵌入模型
def load_text_embedding_model(model_path):
    print(f"Loading text embedding model from: {model_path}")  # 调试信息
    model = SentenceTransformer(model_path)
    print("Text embedding model loaded successfully.")  # 调试信息
    return model

# 加载文本检索系统
def load_text_retrieval_system(index_file, texts_file):
    """
    加载文本检索系统（FAISS 索引和文本数据）。
    """
    print(f"Loading text retrieval system from: {index_file} and {texts_file}")  # 调试信息
    index = faiss.read_index(index_file)
    with open(texts_file, "r", encoding="utf-8") as f:
        texts = json.load(f)
    print("Text retrieval system loaded successfully.")  # 调试信息
    return index, texts

# 检索文本
def retrieve_texts(query, index, texts, k=3):
    """
    检索与查询相关的文本。
    """
    query_embedding = text_embedding_model.encode([query], convert_to_tensor=False)
    distances, indices = index.search(query_embedding.astype("float32"), k)
    retrieved_texts = [texts[i] for i in indices[0]]
    return retrieved_texts

# 加载 MMLU-Pro 数据集
def load_mmlu_pro():
    """ 加载 MMLU-Pro 数据集 """
    dataset = load_dataset("/root/autodl-tmp/MMLU-Pro")
    test_df, val_df = dataset["test"], dataset["validation"]
    return test_df, val_df

# 预处理数据
def preprocess(test_df):
    res_df = []
    for each in test_df:
        options = [opt for opt in each["options"] if opt != "N/A"]
        each["options"] = options
        res_df.append(each)
    return res_df

# 筛选 Health 类别的数据
def filter_health_category(test_df):
    return [item for item in test_df if item["category"] == "health"]

# 改进输入格式
def format_query(question, options):
    """ 格式化问题输入，确保模型能清晰理解 """
    option_str = "\n".join([f"{chr(65 + i)}. {opt}" for i, opt in enumerate(options)])
    query = (
        f"Question: {question}\n"
        f"Options:\n{option_str}\n"
        f"Please choose the correct answer from the options (A, B, C, etc.). Answer in this format: Correct answer: "
    )
    return query

# 生成描述（集成 RAG）
def generate_description_with_rag(model, tokenizer, query, index, texts):
    """
    生成描述，集成 RAG 功能。
    """
    if not query.strip():
        return "错误：请输入问题。"

    # 检索相关文档
    retrieved_texts = retrieve_texts(query, index, texts)
    context = "Retrieved documents:\n" + "\n".join(retrieved_texts) + "\n\nQuestion: " + query

    # 处理文本输入
    inputs = tokenizer.apply_chat_template(
        [{"role": "user", "content": context}],
        add_generation_prompt=True,
        tokenize=True,
        return_tensors="pt",
        return_dict=True
    ).to(device)

    # 生成回答
    gen_kwargs = {"max_length": 2000, "do_sample": False, "top_k": 1}
    with torch.no_grad():
        start_time = time.time()
        outputs = model.generate(**inputs, **gen_kwargs)
        inference_time = time.time() - start_time
        outputs = outputs[:, inputs['input_ids'].shape[1]:]
        description = tokenizer.decode(outputs[0], skip_special_tokens=True)

    return description, retrieved_texts, inference_time

# 提取模型答案（改进版）
def extract_answer(text):
    """ 从模型输出中提取答案，支持更多格式 """
    patterns = [
        r"answer is \(?([A-Z])\)?",  # answer is A
        r"Correct answer: ([A-Z])",  # Correct answer: A
        r"([A-Z]) is correct",       # A is correct
        r"\bOption ([A-Z])\b"        # Option A
    ]

    for pattern in patterns:
        match = re.search(pattern, text)
        if match:
            return match.group(1)

    return None  # 未能匹配到答案

# 测试模型（集成 RAG）
def test_model_with_rag(model, tokenizer, test_df, index, texts):
    results = []
    total_time = 0.0
    correct = 0
    retrieval_success = 0  # 检索成功的次数

    for item in tqdm(test_df, desc="Testing"):
        question = item["question"]
        options = item["options"]
        answer = item["answer"]

        # 格式化输入
        query = format_query(question, options)

        # 生成描述（集成 RAG）
        description, retrieved_texts, inference_time = generate_description_with_rag(model, tokenizer, query, index, texts)
        total_time += inference_time

        # 提取模型答案
        model_answer = extract_answer(description)

        # 检查是否正确
        is_correct = model_answer == answer
        if is_correct:
            correct += 1

        # 检查检索是否成功
        if retrieved_texts:  # 如果有检索到的文本，则认为检索成功
            retrieval_success += 1

        # 保存结果
        results.append({
            "question": question,
            "options": options,
            "answer": answer,
            "model_answer": model_answer,
            "is_correct": is_correct,
            "inference_time": inference_time,
            "retrieved_texts": retrieved_texts
        })

    # 计算准确率和平均推理时间
    accuracy = correct / len(test_df)
    avg_inference_time = total_time / len(test_df)
    retrieval_success_rate = retrieval_success / len(test_df)

    return results, accuracy, avg_inference_time, retrieval_success_rate

# 保存测试结果
def save_results(results, output_path):
    with open(output_path, "w") as f:
        json.dump(results, f, indent=4)

# 主函数
def main_jupyter(model_path="/root/autodl-tmp/glm-4v-9b", output_path="/root/autodl-tmp/health_test_results_with_rag.json"):
    # 加载模型
    model, tokenizer = load_model(model_path)

    # 加载文本嵌入模型
    global text_embedding_model
    text_embedding_model = load_text_embedding_model("/root/autodl-tmp/all-MiniLM-L6-v2")

    # 加载文本检索系统
    index_file = "/root/autodl-tmp/text_index.faiss"
    texts_file = "/root/autodl-tmp/text_texts.json"
    index, texts = load_text_retrieval_system(index_file, texts_file)

    # 加载 MMLU-Pro 数据集
    test_df, val_df = load_mmlu_pro()
    test_df = preprocess(test_df)

    # 筛选 Health 类别的数据
    health_test_df = filter_health_category(test_df)
    print(f"Health 类别数据量: {len(health_test_df)}")

    # 测试模型（集成 RAG）
    results, accuracy, avg_inference_time, retrieval_success_rate = test_model_with_rag(model, tokenizer, health_test_df, index, texts)

    # 保存结果
    save_results(results, output_path)

    # 打印结果
    print(f"测试完成！\nHealth 类别准确率: {accuracy * 100:.2f}%\n平均推理时间: {avg_inference_time:.4f} 秒")
    print(f"检索成功率: {retrieval_success_rate * 100:.2f}%")

# 在 Jupyter Notebook 中直接调用
main_jupyter()

Loading model from: /root/autodl-tmp/glm-4v-9b
Tokenizer loaded successfully.


Loading checkpoint shards:   0%|          | 0/15 [00:00<?, ?it/s]

Model loaded successfully.
Loading text embedding model from: /root/autodl-tmp/all-MiniLM-L6-v2
Text embedding model loaded successfully.
Loading text retrieval system from: /root/autodl-tmp/text_index.faiss and /root/autodl-tmp/text_texts.json
Text retrieval system loaded successfully.
Health 类别数据量: 818


Testing: 100%|██████████| 818/818 [04:10<00:00,  3.26it/s]


测试完成！
Health 类别准确率: 35.45%
平均推理时间: 0.2601 秒
检索成功率: 100.00%
