# In [None]:
from fastapi import FastAPI, Request
from langchain_core.runnables import RunnableConfig
from langchain_core.runnables import RunnableLambda
from langgraph.graph import StateGraph, END
from pydantic import BaseModel

from tools.llm import llm_chat_tool
from tools.memory import save_recall_memory, search_recall_memories
from tools.rag import rag_tool
from tools.search import search_tool


# In [None]:
def load_memories(state: State, config: RunnableConfig) -> State:
    """Load recall memories relevant to the most recent turns of the conversation.

    Args:
        state (schemas.State): The current state of the conversation.
        config (RunnableConfig): The runtime configuration for the agent;
            forwarded unchanged to the memory search tool.

    Returns:
        State: A partial state update holding the retrieved memories under
        the ``recall_memories`` key.
    """
    # Only the last three messages are used as the search query.
    # NOTE(review): no token limit is applied to the query string; add
    # truncation here if long messages become a problem for the search tool.
    recent_messages = state["messages"][-3:]
    conv_str = get_buffer_string(recent_messages)
    recall_memories = search_recall_memories.invoke(conv_str, config)

    # Memories are kept under their own key rather than "messages" so they are
    # not mixed into the chat history.
    # Assumes State declares a `recall_memories` field - TODO confirm.
    return {
        "recall_memories": recall_memories,
    }


def agent(state: State, config: RunnableConfig = None) -> State:
    """Process the current state and generate a response using the LLM.

    Args:
        state (schemas.State): The current state of the conversation. The
            ``messages`` entry is required; ``recall_memories`` is optional.
        config (RunnableConfig): The runtime configuration for the agent. It
            is annotated exactly like ``load_memories`` and defaults to
            ``None`` so existing ``agent(state)`` calls keep working.

    Returns:
        schemas.State: The updated state with the agent's response.
    """
    from datetime import datetime, timezone

    bound = prompt | llm_with_tools

    # Recalled memories live under their own key; a missing or empty entry
    # simply yields an empty <recall_memory> section.
    memories = state.get("recall_memories") or []
    recall_str = (
        "<recall_memory>\n"
        + "\n".join(str(memory) for memory in memories)
        + "\n</recall_memory>"
    )
    prediction = bound.invoke(
        {
            "messages": state["messages"],
            "recall_memories": recall_str,
        }
    )

    # Save the response to the long-term memory. Replies without text (for
    # example pure tool-call messages) are skipped.
    content = prediction.content
    if isinstance(content, str) and content.strip():
        configurable = (config or {}).get("configurable", {})
        # Assumes `vectorstore_recall` exposes the LangChain
        # VectorStore.add_documents API - TODO confirm against tools.rag.
        vectorstore_recall.add_documents(
            [
                Document(
                    page_content=content,
                    metadata={
                        "role": "assistant",
                        "source": "agent_response",
                        "timestamp": datetime.now(timezone.utc).isoformat(),
                        "user_id": configurable.get("user_id"),
                        "thread_id": configurable.get("thread_id"),
                    },
                )
            ]
        )

    return {
        "messages": [prediction],
    }


def route_tools(state: State):
    """Pick the next graph step by inspecting the most recent message.

    Args:
        state (schemas.State): The current state of the conversation.

    Returns:
        Literal["tools", "__end__"]: ``"tools"`` when the last message carries
        tool calls, otherwise ``END``.
    """
    last_message = state["messages"][-1]
    return "tools" if last_message.tool_calls else END

# In [None]:
def build_agent():
    """Assemble and compile the conversation graph.

    The graph runs ``load_memories`` first, then ``agent``; ``route_tools``
    sends the flow either to the ``tools`` node (which loops back to
    ``agent``) or to ``END``.

    Returns:
        The compiled graph, with an in-memory checkpointer attached. Callers
        are expected to pass a ``thread_id`` under ``config["configurable"]``
        when invoking it, since a checkpointer stores state per thread.
    """
    builder = StateGraph(State)

    builder.add_node("load_memories", load_memories)
    builder.add_node("agent", agent)
    builder.add_node("tools", ToolNode(tools))

    builder.add_edge(START, "load_memories")
    builder.add_edge("load_memories", "agent")
    builder.add_conditional_edges("agent", route_tools, ["tools", END])
    builder.add_edge("tools", "agent")

    # Return the checkpointed graph itself; compiling a second time without
    # the checkpointer would drop conversation persistence.
    memory = InMemorySaver()
    return builder.compile(checkpointer=memory)

# In [None]:
# Name of the collection that backs the recall memory.
collection_name = "recall_memory"

# Store used for recall memories, built by the project's `rag_tool` helper.
# `client_qd` and `embeddings` must be defined before this statement runs.
vectorstore_recall = rag_tool(
    client_qd=client_qd,
    collection_name=collection_name,
    embeddings=embeddings,
)

# In [None]:
from fastapi import FastAPI, Request
from pydantic import BaseModel
# from langgraph_agent import build_agent

# Web application object on which the routes below are registered.
app = FastAPI()
# Graph returned by build_agent(), built once at import time and reused by
# every request.
# NOTE(review): when this runs in the same module as `def agent(...)`, this
# assignment rebinds the name `agent`; a later build_agent() call would then
# register this object instead of the node function - consider renaming.
agent = build_agent()

class UserInput(BaseModel):
    """Request body accepted by the ``/chat`` endpoint."""

    # Text of the user's chat message.
    message: str
    # Session identifier intended for multi-user memory.
    # NOTE(review): the field is required as declared, although the original
    # note called it optional - confirm which is intended.
    session_id: str  # for multi-user memory

@app.post("/chat")
def chat(user_input: UserInput):
    """Run one chat turn through the agent graph and return its reply.

    Args:
        user_input (UserInput): The user's message and session identifier.

    Returns:
        dict: ``{"response": <text>}`` with the content of the last message
        produced by the graph, or a fallback string when there is none.
    """
    from langchain_core.messages import HumanMessage

    fallback = "Sorry, something went wrong."

    # The session id selects the conversation thread. It is also passed as
    # the user id because it is the only identity the request carries.
    # NOTE(review): confirm that the memory tools read `user_id` from config.
    config = {
        "configurable": {
            "thread_id": user_input.session_id,
            "user_id": user_input.session_id,
        }
    }

    # The graph nodes read state["messages"], so the user text is supplied as
    # a message rather than under an "input" key.
    result = agent.invoke(
        {"messages": [HumanMessage(content=user_input.message)]},
        config,
    )

    messages = result.get("messages") or []
    if not messages:
        return {"response": fallback}

    reply = getattr(messages[-1], "content", None)
    return {"response": reply or fallback}