In [None]:
import os, torch, warnings, re, gc
from transformers import AutoTokenizer, LlamaForCausalLM
from peft import PeftModel
import pandas as pd
from tqdm import tqdm
from transformers import BitsAndBytesConfig

# Configure the CUDA caching allocator and silence the Hugging Face hub symlink
# warning through environment variables, then mute Python warnings globally.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"
os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1"
warnings.filterwarnings("ignore")

# Answers are generated with sampling (do_sample=True), so seed torch's RNG to
# make a Restart & Run All reproduce the same outputs.
SEED = 42
torch.manual_seed(SEED)

# Prefer the GPU when one is visible to torch; otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print("使用设备:", device)

def clear_memory():
    """Run a Python garbage collection, then empty torch's CUDA cache if CUDA is available."""
    gc.collect()
    if not torch.cuda.is_available():
        # Nothing GPU-side to release on a CPU-only machine.
        return
    torch.cuda.empty_cache()

# Identifier of the base LLaMA checkpoint and the directory holding the PEFT adapter.
base_model_path = "huggyllama/llama-7b"
adapter_path = "D:\\model\\ShenNong-TCM-LLM" # NOTE(review): absolute local Windows path; adjust per machine

# Quantization settings passed to from_pretrained below: 4-bit NF4 weights,
# double quantization enabled, float16 as the compute dtype.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    llm_int8_skip_modules=["lm_head"]  # presumably keeps lm_head out of quantization; confirm in the transformers docs
)

# The tokenizer is loaded from the base checkpoint (not from the adapter directory).
# NOTE(review): trust_remote_code=True is likely unnecessary for a stock LLaMA tokenizer; confirm.
tokenizer = AutoTokenizer.from_pretrained(base_model_path, trust_remote_code=True)
if tokenizer.pad_token is None:
    # No pad token defined: reuse the end-of-sequence token for padding.
    tokenizer.pad_token = tokenizer.eos_token

# Load the quantized base model; device placement is delegated to device_map="auto".
base_model = LlamaForCausalLM.from_pretrained(
    base_model_path,
    quantization_config=bnb_config,
    device_map="auto",
    low_cpu_mem_usage=True
)
clear_memory()

# Wrap the base model with the adapter weights and switch to evaluation mode.
model = PeftModel.from_pretrained(base_model, adapter_path).eval()
clear_memory()

# First run of option letters in the model output. A letter counts as an option
# only if it does not start or continue an English word: it must not follow an
# ASCII letter and must not be followed by a lowercase letter. Letters in the run
# may be adjacent ("ABC") or separated by whitespace / 、 , ， ; ； /.
_OPTION_RUN_RE = re.compile(
    r"(?<![A-Za-z])[A-Z](?![a-z])(?:[\s、,，;；/]*[A-Z](?![a-z]))*"
)


def _extract_option_letters(raw_answer: str) -> str:
    """Return the first run of option letters joined by "、", or "解析失败" if none is found."""
    run = _OPTION_RUN_RE.search(raw_answer)
    if run is None:
        return "解析失败"
    return "、".join(re.findall(r"[A-Z]", run.group(0)))


def shennong_predict(prompt: str, question: str, candidate_answers: str) -> str:
    """Ask the model one multiple-choice question and return the chosen option letters.

    Args:
        prompt: Instruction text placed before the question.
        question: The question stem.
        candidate_answers: The answer options, already formatted as one string.

    Returns:
        The selected option letters joined by "、" (for example "A" or "A、C"),
        or "解析失败" when no option letter can be found in the generated text.
        Multi-letter answers written as "ABC" or "A, C" are recognised as well,
        so multi-select outputs are not cut down to their first letter.
    """
    input_text = f"{prompt}\n题目：{question}\n选项：{candidate_answers}\n答案："

    inputs = tokenizer(input_text, return_tensors="pt").to(model.device)

    # Low-temperature nucleus sampling; output is stochastic unless torch's RNG is seeded.
    outputs = model.generate(
        input_ids=inputs.input_ids,
        attention_mask=inputs.attention_mask,
        max_new_tokens=64,
        temperature=0.3,
        do_sample=True,
        top_p=0.95,
        pad_token_id=tokenizer.eos_token_id,
    )

    # The returned sequence starts with the prompt tokens; keep only what was generated.
    prompt_length = len(inputs.input_ids[0])
    generated_ids = outputs[0][prompt_length:]
    raw_answer = tokenizer.decode(generated_ids, skip_special_tokens=True)

    return _extract_option_letters(raw_answer)

# %% 3. Run inference over every question in the Excel workbook
INPUT_XLSX = "合理用药数据集v4.1-测试用.xlsx"
OUTPUT_XLSX = "神农模型结果_new.xlsx"

# Short instruction prompts keep the tokenized input small.
# The keys double as the worksheet names to process, in this order.
PROMPTS = {
    "单选": "以下是关于中药处方审核的单选题，请根据规则选择正确的选项。仅输出选项即可。",
    "多选": "以下是关于中药处方审核的多选题，请根据规则选择所有正确的选项。仅输出选项即可。"
}

# The context manager closes the input workbook's file handle once we are done,
# instead of leaving it open for the lifetime of the kernel.
with pd.ExcelFile(INPUT_XLSX) as xls:
    available_sheets = []
    for sheet in PROMPTS:
        if sheet in xls.sheet_names:
            available_sheets.append(sheet)
        else:
            print(f"⚠️ 找不到工作表：{sheet}，跳过")

    if not available_sheets:
        # Do not open a writer when there is nothing to write: an Excel writer
        # that receives no sheet typically fails when it is closed.
        print("⚠️ 未找到任何可处理的工作表，未写入输出文件")
    else:
        with pd.ExcelWriter(OUTPUT_XLSX) as writer:
            for sheet in available_sheets:
                df = pd.read_excel(xls, sheet_name=sheet)
                prompt = PROMPTS[sheet]
                model_answers = []

                # Questions are answered one at a time, so progress is tracked per row.
                for _, row in tqdm(df.iterrows(), total=len(df), desc=f"处理{sheet}"):
                    question = str(row["Question"])
                    candidate_answers = str(row["Candidate answers"])
                    model_answers.append(
                        shennong_predict(prompt, question, candidate_answers)
                    )
                    clear_memory()  # release memory after every question

                df["模型答案"] = model_answers
                df.to_excel(writer, sheet_name=sheet, index=False)
                del model_answers
                clear_memory()

        print("✅ 推理完成，结果已写入", OUTPUT_XLSX)

In [None]:
import pandas as pd
import re

# Workbook to read raw model answers from, and the workbook the cleaned copy is written to.
INPUT_FILE = "神农模型结果_new.xlsx"
OUTPUT_FILE = "神农模型结果_new_清洗后.xlsx"

def clean_answer(ans):
    """Normalize a raw model answer to space-separated option letters.

    Values that ``pd.isna`` reports as missing become ``""``. Anything else is
    converted to a string and upper-cased; every character in ``A``-``Z`` is kept
    in order of appearance and the kept letters are joined with single spaces
    (so a lone letter comes back unchanged and text without letters gives ``""``).
    """
    if pd.isna(ans):
        return ""
    upper_text = str(ans).upper()
    picked = [ch for ch in upper_text if "A" <= ch <= "Z"]
    return " ".join(picked)

# Read every sheet of the inference output, normalize the "模型答案" column where
# it exists, and write all sheets to the cleaned workbook.
# Both workbooks are opened with `with`, so the input file handle is closed when
# the block ends and does not keep the inference output file open.
with pd.ExcelFile(INPUT_FILE) as xls, pd.ExcelWriter(OUTPUT_FILE) as writer:
    for sheet in xls.sheet_names:
        df = pd.read_excel(xls, sheet_name=sheet)
        if "模型答案" in df.columns:
            df["模型答案"] = df["模型答案"].apply(clean_answer)
        df.to_excel(writer, sheet_name=sheet, index=False)

print("清洗完成 →", OUTPUT_FILE)