In [None]:
from langchain.llms import HuggingFacePipeline
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from langchain.prompts import PromptTemplate
#from langchain_huggingface import HuggingFacePipeline
# 1. Load the tokenizer and the causal-LM weights from the local checkpoint directory.
model_path = "models/Qwen3-1.7B"
tokenizer = AutoTokenizer.from_pretrained(
    model_path,
    trust_remote_code=True,
)
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    trust_remote_code=True,
    device_map="auto",
)

# 2. Build the text-generation pipeline (sampling on, low temperature, up to 3000 new tokens).
pipe = pipeline(
    task="text-generation",
    tokenizer=tokenizer,
    model=model,
    do_sample=True,
    temperature=0.2,
    max_new_tokens=3000,
)

# 3. Wrap the pipeline so LangChain can use it as an LLM.
llm = HuggingFacePipeline(pipeline=pipe)

# Prompt layout: a system turn, a user turn carrying {context} and {question},
# and an opened assistant turn for the model to complete.
template = """<|im_start|>system
你是一个法律AI助手，请根据提供的上下文信息回答问题。请直接回答问题，不要解释推理过程。<|im_end|>
<|im_start|>user
上下文：{context}

问题：{question}<|im_end|>
<|im_start|>assistant
"""
prompt = PromptTemplate.from_template(template)

In [None]:
from langchain.vectorstores import FAISS
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.chains import LLMChain
# The embedding model must be the same one that was used when the index was saved.
EMBED_PATH = "models/BAAI--bge-small-zh-v1.5"
embedding_model = HuggingFaceEmbeddings(model_name=EMBED_PATH)

# Load the vector database from disk.
# NOTE(review): allow_dangerous_deserialization=True opts in to deserializing the
# saved index files; only point this at an index folder you created or trust.
vectorstore = FAISS.load_local(
    folder_path="faiss_lawdb",
    embeddings=embedding_model,
    allow_dangerous_deserialization=True,
)

# Chain that fills the prompt template and sends it to the local LLM.
qa_chain = LLMChain(llm=llm, prompt=prompt)

In [None]:
def build_reference(docs: list) -> str:
    """Format retrieved documents as a plain-text citation block.

    Each document contributes a header line followed by its text:
    ``《law_type》 第X编 <title> 第Y章 <title> 第Z节 <title>``. Only the structural
    levels (编 / 章 / 节) that actually carry a "序号" in the metadata are
    printed, so a law without a given level no longer renders as "第编" or "第章".

    Args:
        docs: Objects exposing ``metadata`` (a dict, possibly with "law_type"
            and a nested "structure" dict) and ``page_content`` (the text).
            ``None`` is treated as an empty list.

    Returns:
        The formatted references, separated from each other by one blank line;
        an empty string when there are no documents.
    """
    references = []
    for doc in docs or []:
        meta = doc.metadata or {}
        law_type = meta.get("law_type", "未知法律")
        # `or {}` also covers metadata that stores an explicit None.
        structure = meta.get("structure") or {}

        header_parts = [f"《{law_type}》"]
        for level in ("编", "章", "节"):
            node = structure.get(level) or {}
            number = node.get("序号", "")
            if number in (None, ""):
                # This level is absent for the document: leave it out entirely.
                continue
            title = node.get("标题", "") or ""
            header_parts.append(f"第{number}{level} {title}".rstrip())

        header = " ".join(header_parts)
        references.append(f"{header}\n{doc.page_content}")

    # One blank line between consecutive references.
    return "\n\n".join(references)

def extract_final_answer(response_text: str,docs) -> str:
    """Turn the raw LLM output into the chat message shown to the user.

    Steps:
      1. Keep only the text after the last ``<|im_start|>assistant`` marker
         (the whole text when the marker is absent).
      2. Remove the model's reasoning: complete ``<think>...</think>`` blocks,
         anything up to a stray ``</think>``, and an unclosed ``<think>`` block
         (e.g. when generation stopped before the closing tag).
      3. Append the cited statutes, built by ``build_reference(docs)``, inside a
         ``<div class='reference-box'>``.

    Args:
        response_text: Raw text returned by the chain; ``None`` is treated as "".
        docs: Retrieved documents, passed through to ``build_reference``.

    Returns:
        The cleaned answer followed by the HTML reference box.
    """
    import re

    # rpartition yields the text after the LAST marker, or the full text if none.
    raw_an = (response_text or "").rpartition("<|im_start|>assistant")[2].strip()

    # Drop every complete reasoning block.
    cleaned = re.sub(r"<think>.*?</think>", "", raw_an, flags=re.DOTALL)
    # A closing tag without an opener: the reasoning is everything before it.
    if "</think>" in cleaned:
        cleaned = cleaned.rpartition("</think>")[2]
    # An opener without a closing tag: the reasoning runs to the end of the text.
    if "<think>" in cleaned:
        cleaned = cleaned.partition("<think>")[0]
    raw_an_without_think = cleaned.strip()

    reference = build_reference(docs)
    # Answer text followed by the HTML-wrapped reference statutes.
    answer = (
        f"{raw_an_without_think}<br><div class='reference-box'>参考条文：\n{reference}</div>"
    )
    return answer
import gradio as gr
# Stylesheet passed to gr.Blocks(css=...); it styles the "reference-box" div that
# extract_final_answer wraps around the cited statutes.
custom_css = """
.reference-box {
    background-color: #f0f4ff;  /* 淡蓝色背景 */
    border-left: 4px solid #3b82f6;
    padding: 8px;
    margin-top: 6px;
    border-radius: 4px;
    font-size: 14px;
    color: #333;
}
"""
def gr_rag_chatbot(history, query):
    """Answer one user question with retrieval-augmented generation.

    Retrieves the 3 most similar documents from ``vectorstore``, joins their
    text as the prompt context, runs ``qa_chain`` and post-processes the output
    with ``extract_final_answer``.

    Args:
        history: List of ``(user_message, bot_message)`` pairs; ``None`` is
            treated as an empty list.
        query: The text typed by the user.

    Returns:
        ``(history, "")`` — the updated conversation and an empty string used
        to clear the input box. For an empty or whitespace-only query the
        history is returned unchanged and no retrieval or generation is run.
    """
    if history is None:
        history = []

    question = (query or "").strip()
    if not question:
        # Nothing to ask: skip the vector search and the (slow) LLM call.
        return history, ""

    docs = vectorstore.similarity_search(question, k=3)
    context = "\n".join(doc.page_content for doc in docs)
    response = qa_chain.invoke({"context": context, "question": question})
    answer = extract_final_answer(response["text"], docs)

    # Append in place as well as returning the list, so a caller that keeps a
    # reference to the same list object still sees the new turn.
    history.append((f"👤 {question}", f"🤖 {answer}"))
    return history, ""  # conversation to display & cleared input box
# Chat UI: a chatbot pane, a multi-line input with a send button, and a
# per-session State holding the conversation history.
with gr.Blocks(css=custom_css) as demo:
    gr.Markdown("# 📚 法律助手 Demo")
    chatbot = gr.Chatbot(label="对话记录", height=600)
    with gr.Row():
        msg_input = gr.Textbox(
            placeholder="请输入你的法律问题，例如：婚姻无效的情形有哪些？",
            show_label=False,
            lines=3,
            scale=4
        )
        send_btn = gr.Button("发送", scale=1)
    state = gr.State([])  # conversation history for the current session

    def _respond(history, query):
        """Run gr_rag_chatbot and return the history for both State and Chatbot."""
        updated_history, cleared_input = gr_rag_chatbot(history, query)
        return updated_history, updated_history, cleared_input

    # The State is listed as an output too, so the stored history is updated
    # explicitly instead of depending on in-place mutation of the input value.
    send_btn.click(
        fn=_respond,
        inputs=[state, msg_input],
        outputs=[state, chatbot, msg_input]
    )
# NOTE(review): server_name="0.0.0.0" makes the app listen on all network
# interfaces; confirm that exposing the demo beyond localhost is intended.
demo.launch(server_name="0.0.0.0", server_port=7860)