In [5]:
from typing import TypedDict, Dict, List, Optional, Literal
from langgraph.graph import StateGraph, END
from langchain.prompts import ChatPromptTemplate
from langchain_core.messages import HumanMessage
from py2neo import Graph as Py2NeoGraph
from pydantic import BaseModel
import json

# External imports from your modules
from graph_chain import build_graph_qa_chain
from retriever_weaviate import retriever_weaviate
from determine_database import final_answer
from retriever_answer_async import partial_answer_async

In [7]:
from prompts import (
    determine_prompt,
    recall_prompt,
)

In [8]:
class AgentState(TypedDict):
    """Shared state dictionary passed between the nodes of the agent graph.

    Every node function in this notebook receives this mapping, updates some
    of its keys in place, and returns it.
    """
    question: str  # user question; fed to the prompts and the QA chains
    question_type: Optional[str]  # set by determine_question_type; drives the first routing step
    parsed_output: Dict  # dict dump of the QuestionAnalysis result
    step: int  # used to label context_history entries ("step N from graph/text")
    database: Optional[str]  # routing target: "Literature Graph Database" selects graph retrieval, anything else text retrieval
    strategy: Optional[str]  # strategy text passed to the retrievers and to the recall/final prompts
    context_history: Dict[str, str]  # retrieval results serialized to strings, keyed by step and source
    retrieved_chunks: List[Dict]  # raw result of the most recent retrieval
    sufficient: bool  # set by recall_decision_node; True routes to final_answer
    llm: object  # chat model; nodes call .with_structured_output(...) on it
    answer: Optional[str]  # answer written by the terminal nodes
    record: Optional[str]  # forwarded to final_answer; NOTE(review): presumably a prior conversation record - confirm


# Structured-output schema for the question-analysis step; it is handed to
# with_structured_output() in determine_question_type.
# NOTE(review): documented with comments rather than a class docstring on
# purpose - pydantic copies a class docstring into the generated schema
# description, which may change what is sent to the model.
class QuestionAnalysis(BaseModel):
    # Copied into state["question_type"]; "Mixed-Type" selects the iterative retrieval loop.
    question_type: Literal["Knowledge-Type", "Entity-Type", "Mixed-Type"]
    # Used as the initial state["database"] when the question is not "Mixed-Type".
    database_to_call: Literal["Literature Text Database", "Literature Graph Database", "Both"]
    # Used as the initial state["database"] when the question is "Mixed-Type".
    first_database_to_call: str
    # Not read by any code visible in this notebook; kept only inside parsed_output.
    methods: Optional[str] = None
    # Stored as state["strategy"].
    call_strategy: str


# Structured-output schema for the recall step; it is handed to
# with_structured_output() in recall_decision_node.
# NOTE(review): documented with comments rather than a class docstring on
# purpose - pydantic copies a class docstring into the generated schema
# description, which may change what is sent to the model.
class RecallDecision(BaseModel):
    # Copied into state["sufficient"]; True ends the retrieval loop.
    sufficient: bool
    # Only applied to state["database"] when sufficient is False.
    next_database: Optional[str] = None
    # Not read by any code visible in this notebook.
    reason: Optional[str] = None
    # Only applied to state["strategy"] when sufficient is False.
    strategy: Optional[str] = None

In [10]:
import os

# Connection settings for the external services used by the agent nodes.
#
# Secrets are read from the environment instead of being committed in the
# notebook. Any key or password that was previously hard-coded here must be
# treated as compromised and rotated.
# A missing secret resolves to "" so this cell still runs; the first call to
# the corresponding service will then fail with an authentication error.
huggingface_key = os.environ.get("HUGGINGFACE_API_KEY", "")
wcd_url = os.environ.get(
    "WCD_URL",
    "https://yk7x0nnmqzuvgr8h2wu2sa.c0.asia-southeast1.gcp.weaviate.cloud",
)
wcd_api_key = os.environ.get("WCD_API_KEY", "")

# Neo4j connection; endpoint and user name are not secret and keep their
# previous values as defaults.
uri = os.environ.get("NEO4J_URI", "neo4j+s://c0abcb56.databases.neo4j.io")
username = os.environ.get("NEO4J_USERNAME", "neo4j")
password = os.environ.get("NEO4J_PASSWORD", "")

In [11]:
from langchain_openai import ChatOpenAI
import os
# The API key must already be present in the environment; it is never written
# into the notebook. ChatOpenAI reads it from OPENAI_API_KEY. Any key that was
# previously hard-coded here must be treated as compromised and rotated.
if not os.environ.get("OPENAI_API_KEY"):
    raise RuntimeError(
        "OPENAI_API_KEY is not set; export the API key for the endpoint "
        "configured in base_url before running this cell."
    )

# Initialise the chat model against the OpenAI-compatible endpoint below.
llm = ChatOpenAI(model="qwen-turbo-1101",
                base_url="https://dashscope.aliyuncs.com/compatible-mode/v1")

In [12]:
def determine_question_type(state: AgentState) -> AgentState:
    """Classify the question and seed the routing fields of the state.

    Formats ``determine_prompt`` with the question, asks the model for a
    ``QuestionAnalysis`` and stores the outcome in ``question_type``,
    ``parsed_output``, ``database`` and ``strategy``.

    Args:
        state: Agent state; reads ``llm`` and ``question``.

    Returns:
        The same state object, updated in place.
    """
    prompt_template = ChatPromptTemplate.from_template(determine_prompt)
    model = state['llm'].with_structured_output(QuestionAnalysis)
    result = model.invoke(prompt_template.format_messages(question=state['question']))
    state["question_type"] = result.question_type
    # pydantic v2 deprecates .dict() in favour of .model_dump(); fall back to
    # .dict() only when running against pydantic v1.
    dump = getattr(result, "model_dump", None) or result.dict
    state["parsed_output"] = dump()
    # Mixed questions start with the database the model picked first; the other
    # types use the single database the model named.
    state["database"] = result.first_database_to_call if result.question_type == "Mixed-Type" else result.database_to_call
    state["strategy"] = result.call_strategy
    return state


def handle_entity_type(state: AgentState) -> AgentState:
    """Answer an entity-type question through the graph QA chain.

    Opens a Neo4j connection with the module-level ``uri``, ``username`` and
    ``password`` settings, builds the QA chain with the module-level
    ``entity_cache`` and stores the chain result in ``state["answer"]``.

    Args:
        state: Agent state; reads ``llm`` and ``question``.

    Returns:
        The same state object, updated in place.
    """
    # Use the configured user name instead of a hard-coded "neo4j" so the
    # connection honours the notebook's settings.
    graph = Py2NeoGraph(uri, auth=(username, password))
    graph_qa_chain = build_graph_qa_chain(state['llm'], graph, entity_cache)
    result = graph_qa_chain.invoke({"question": state["question"]})
    state["answer"] = result
    return state


async def handle_knowledge_type(state: AgentState) -> AgentState:
    """Answer a knowledge-type question via the asynchronous text pipeline.

    Awaits ``partial_answer_async`` and keeps the first two of its four
    results: the retrieved items go to ``state["retrieved_chunks"]`` and the
    final response to ``state["answer"]``.
    """
    retrieved, response, _third, _fourth = await partial_answer_async(
        state["llm"],
        state["question"],
        wcd_url,
        wcd_api_key,
        huggingface_key,
    )
    state.update({"answer": response, "retrieved_chunks": retrieved})
    return state


def graph_retrieval_node(state: AgentState) -> AgentState:
    """Run one retrieval round against the graph database.

    Builds the graph QA chain, queries it with the current strategy and
    records the stringified result in ``context_history`` under
    ``"step <step> from graph"``. The raw result is stored in
    ``retrieved_chunks``.

    Args:
        state: Agent state; reads ``llm``, ``strategy`` and ``step``.

    Returns:
        The same state object, updated in place.
    """
    # Use the configured user name instead of a hard-coded "neo4j" so the
    # connection honours the notebook's settings.
    graph = Py2NeoGraph(uri, auth=(username, password))
    graph_qa_chain = build_graph_qa_chain(state['llm'], graph, entity_cache)
    # A separator is required here: without it the strategy text is glued
    # directly onto the word "searched" in the question sent to the chain.
    result = graph_qa_chain.invoke({"question": f"the entities needs to be searched: {state['strategy']}"})
    state["context_history"][f"step {state['step']} from graph"] = str(result)
    state["retrieved_chunks"] = result
    return state

def text_retrieval_node(state: AgentState) -> AgentState:
    """Run one retrieval round against the text database.

    Fetches up to four items for the current strategy, records their
    ``chunk``/``title`` pairs as JSON in ``context_history`` under
    ``"step <step> from text"`` and keeps the raw items in
    ``retrieved_chunks``.
    """
    hits = retriever_weaviate(
        state['strategy'], wcd_url, wcd_api_key, huggingface_key, limit=4
    )

    # Only the chunk text and its title are kept for the history entry.
    summaries = []
    for hit in hits:
        summaries.append({'chunk': hit['chunk'], 'title': hit['title']})

    serialized = json.dumps(summaries)
    history = state["context_history"]
    history[f"step {state['step']} from text"] = serialized
    state["retrieved_chunks"] = hits
    return state


def recall_decision_node(state: AgentState) -> AgentState:
    """Decide whether the collected context is enough to answer.

    Formats ``recall_prompt`` with the question, the JSON-serialized
    ``context_history`` and the current strategy, and asks the model for a
    ``RecallDecision``. When more context is needed, the next database and
    strategy are stored and ``step`` is advanced for the next round.

    Args:
        state: Agent state; reads ``llm``, ``question``, ``context_history``,
            ``strategy`` and ``step``.

    Returns:
        The same state object, updated in place.
    """
    prompt_template = ChatPromptTemplate.from_template(recall_prompt)
    model = state['llm'].with_structured_output(RecallDecision)
    context_str = json.dumps(state["context_history"], indent=2)
    result = model.invoke(
        prompt_template.format_messages(
            question=state["question"],
            context=context_str,
            strategy=state["strategy"]
        )
    )
    state["sufficient"] = result.sufficient
    if not result.sufficient:
        state["database"] = result.next_database
        # RecallDecision.strategy is optional; keep the previous strategy when
        # the model omits it so the retrievers never receive None.
        state["strategy"] = result.strategy or state["strategy"]
        # Advance the step counter so the next retrieval writes a new
        # "step N from ..." key instead of overwriting the previous round.
        state["step"] = state["step"] + 1
    return state


def final_answer_node(state: AgentState) -> AgentState:
    """Produce the final answer from the accumulated context.

    Delegates to ``final_answer`` with the model, the question, the context
    history, the current strategy and the optional record (empty string when
    the key is absent), then stores the result in ``state["answer"]``.
    """
    call_args = (
        state["llm"],
        state["question"],
        state["context_history"],
        state["strategy"],
        state.get("record", ""),
    )
    state["answer"] = final_answer(*call_args)
    return state


def route_by_question_type(state: AgentState) -> str:
    """Pick the node that follows question classification.

    Entity and knowledge questions go to their dedicated handlers. Any other
    type starts the retrieval loop: graph retrieval when ``database`` is
    "Literature Graph Database", text retrieval otherwise.
    """
    if state["question_type"] == "Entity-Type":
        return "handle_entity_type"
    if state["question_type"] == "Knowledge-Type":
        return "handle_knowledge_type"

    # Every remaining type enters the iterative retrieval loop.
    if state["database"] == "Literature Graph Database":
        return "graph_retrieval"
    return "text_retrieval"


def route_after_recall(state: AgentState) -> str:
    """Pick the node that follows the recall decision.

    A truthy ``sufficient`` ends the loop at ``final_answer``. Otherwise the
    next round is graph retrieval when ``database`` is
    "Literature Graph Database", and text retrieval for anything else.
    """
    if state["sufficient"]:
        return "final_answer"
    if state["database"] == "Literature Graph Database":
        return "graph_retrieval"
    return "text_retrieval"

In [15]:
# Entity cache read as a module-level global by the graph-QA nodes.
file_path = "entity_cache.json"
with open(file_path, 'r', encoding='utf-8') as cache_file:
    entity_cache = json.load(cache_file)

# Assemble the agent workflow: register the nodes, wire the edges, compile.
builder = StateGraph(AgentState)

for node_name, node_fn in (
    ("determine_question_type", determine_question_type),
    ("handle_entity_type", handle_entity_type),
    ("handle_knowledge_type", handle_knowledge_type),
    ("graph_retrieval", graph_retrieval_node),
    ("text_retrieval", text_retrieval_node),
    ("recall_decision", recall_decision_node),
    ("final_answer", final_answer_node),
):
    builder.add_node(node_name, node_fn)

builder.set_entry_point("determine_question_type")
builder.add_conditional_edges("determine_question_type", route_by_question_type)

# The two direct handlers finish the run immediately.
for terminal_node in ("handle_entity_type", "handle_knowledge_type"):
    builder.add_edge(terminal_node, END)

# Both retrieval nodes feed the recall decision, which loops or finishes.
for retrieval_node in ("graph_retrieval", "text_retrieval"):
    builder.add_edge(retrieval_node, "recall_decision")

builder.add_conditional_edges("recall_decision", route_after_recall)
builder.add_edge("final_answer", END)

graph = builder.compile()


async def agent(llm, question: str, record: str = ""):
    """Run the compiled agent graph for a single question.

    Args:
        llm: Chat model stored in the state for the nodes to use.
        question: The user question.
        record: Optional record string forwarded to the final-answer step.

    Returns:
        The final state returned by ``graph.ainvoke``.
    """
    # Inputs supplied by the caller.
    initial_state: AgentState = {"question": question}
    # Fields the nodes fill in as the run progresses.
    initial_state.update(
        {
            "parsed_output": {},
            "step": 1,
            "database": None,
            "strategy": None,
            "context_history": {},
            "retrieved_chunks": [],
            "sufficient": False,
            "llm": llm,
            "record": record,
            "answer": None,
            "question_type": None,
        }
    )
    return await graph.ainvoke(initial_state)

In [16]:
!pip install supabase

Collecting supabase
  Downloading supabase-2.17.0-py3-none-any.whl.metadata (11 kB)
Collecting gotrue==2.12.3 (from supabase)
  Downloading gotrue-2.12.3-py3-none-any.whl.metadata (6.5 kB)
Collecting postgrest==1.1.1 (from supabase)
  Downloading postgrest-1.1.1-py3-none-any.whl.metadata (3.5 kB)
Collecting realtime==2.6.0 (from supabase)
  Downloading realtime-2.6.0-py3-none-any.whl.metadata (6.6 kB)
Collecting storage3==0.12.0 (from supabase)
  Downloading storage3-0.12.0-py3-none-any.whl.metadata (1.9 kB)
Collecting supafunc==0.10.1 (from supabase)
  Downloading supafunc-0.10.1-py3-none-any.whl.metadata (1.2 kB)
Collecting pyjwt<3.0.0,>=2.10.1 (from gotrue==2.12.3->supabase)
  Downloading PyJWT-2.10.1-py3-none-any.whl.metadata (4.0 kB)
Collecting strenum<0.5.0,>=0.4.9 (from postgrest==1.1.1->supabase)
  Downloading StrEnum-0.4.15-py3-none-any.whl.metadata (5.3 kB)
Collecting pydantic<3,>=1.10 (from gotrue==2.12.3->supabase)
  Downloading pydantic-2.11.7-py3-none-any.whl.metadata (67

  You can safely remove it manually.
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
langchain 0.3.20 requires langsmith<0.4,>=0.1.17, but you have langsmith 0.4.8 which is incompatible.
langchain-community 0.3.5 requires langsmith<0.2.0,>=0.1.125, but you have langsmith 0.4.8 which is incompatible.
tabled-pdf 0.1.4 requires scikit-learn<2.0.0,>=1.5.2, but you have scikit-learn 1.2.2 which is incompatible.
zhipuai 2.1.5.20230904 requires pyjwt<2.9.0,>=2.8.0, but you have pyjwt 2.10.1 which is incompatible.
  return process_handler(cmd, _system_body)
  return process_handler(cmd, _system_body)
  return process_handler(cmd, _system_body)
