In [None]:
# enable reloading
%load_ext autoreload
# all the modules should be reloaded before executing the code
%autoreload 2


In [None]:
from pathlib import Path
from typing import Annotated, Literal

import rootutils
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
from langchain_mcp_adapters.client import MultiServerMCPClient
from langgraph.graph import END, START, StateGraph
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode
from langgraph.types import Command, interrupt
from loguru import logger
from pydantic import BaseModel

rootutils.setup_root(search_from=str(Path.cwd().parent), indicator=[".git", "pyproject.toml"], pythonpath=True)
from src.agent.my_mcps import mcp_config


In [None]:
"""
Utility functions for context engineering notebooks.
"""

from rich.console import Console
from rich.panel import Panel
import json

# Shared Rich console; `format_messages` prints its panels through it.
console = Console()


def format_message_content(message):
    """Render a message's ``content`` as a single display string.

    ``content`` may be a plain string or a list of content blocks. In the list form:

    - bare string items are kept as-is;
    - ``{"type": "text"}`` blocks contribute their ``text``;
    - ``{"type": "tool_use"}`` blocks become a tool-call header plus JSON-formatted arguments;
    - any other item is skipped.

    Any other ``content`` type falls back to ``str(content)``.

    Args:
        message: Object exposing a ``content`` attribute (e.g. a LangChain message).

    Returns:
        The rendered content; list parts are joined with newlines.
    """
    content = message.content
    if isinstance(content, str):
        return content
    if not isinstance(content, list):
        return str(content)

    parts = []
    for item in content:
        # Content lists may mix bare strings with dict blocks; `.get` only works on dicts.
        if isinstance(item, str):
            parts.append(item)
            continue
        if not isinstance(item, dict):
            continue
        item_type = item.get("type")
        if item_type == "text":
            parts.append(item.get("text", ""))
        elif item_type == "tool_use":
            parts.append(f"\n🔧 Tool Call: {item.get('name', '<unknown>')}")
            # default=str keeps non-JSON-serializable argument values from raising.
            parts.append(f"   Args: {json.dumps(item.get('input', {}), indent=2, default=str)}")
    return "\n".join(parts)


def format_messages(messages):
    """Print each message as a Rich panel styled by its role.

    The role is the class name with ``"Message"`` and ``"Chunk"`` removed. For example,
    ``HumanMessage`` -> ``Human``, ``AIMessage``/``AIMessageChunk`` -> ``AI`` and
    ``ToolMessage`` -> ``Tool``. Roles without a dedicated style get a white ``📝 <role>`` panel.

    Message text is rendered literally (as ``rich.text.Text``), so brackets in LLM or tool
    output are not interpreted as Rich markup.

    Args:
        messages: Iterable of message objects accepted by ``format_message_content``.
    """
    from rich.text import Text

    styles = {
        "Human": ("🧑 Human", "blue"),
        # `AIMessage` reduces to "AI" (upper-case I), so the key must use that exact spelling.
        "AI": ("🤖 Assistant", "green"),
        "Tool": ("🔧 Tool Output", "yellow"),
    }
    for m in messages:
        msg_type = m.__class__.__name__.replace("Message", "").replace("Chunk", "")
        content = format_message_content(m)
        title, border_style = styles.get(msg_type, (f"📝 {msg_type}", "white"))
        console.print(Panel(Text(content), title=title, border_style=border_style))


In [None]:
from langgraph.graph import MessagesState


class States(MessagesState):
    """Graph state: the running conversation between the agent and the user.

    All fields (notably ``messages``) are inherited from ``MessagesState``; nothing is added here.
    """


# Names of tools whose calls are routed to human review before they may execute.
protected_tools: list[str] = [
    "create_directory",
    "edit_file",
    "write_file",
]

In [None]:
client = MultiServerMCPClient(connections=mcp_config["mcpServers"])
tools = await client.get_tools()

In [None]:
from langchain_ollama import ChatOllama
from langchain_perplexity import ChatPerplexity
from langchain_google_genai import ChatGoogleGenerativeAI

# Active model: local Ollama qwen3:8b with the MCP tools bound.
# Alternatives previously tried (swap in as needed):
#   ChatGoogleGenerativeAI(model="gemini-2.0-flash").bind_tools(tools)
#   ChatPerplexity(model="sonar-pro", temperature=0)
#   ChatOpenAI(model="gpt-4.1-mini-2025-04-14", temperature=0.1).bind_tools(tools)  # langchain_openai
base_llm = ChatOllama(model="qwen3:8b", temperature=0)
llm = base_llm.bind_tools(tools)

# Smoke test: one round-trip to confirm the model endpoint responds.
llm.invoke("hii how are you ? ")

In [None]:
from langgraph.checkpoint.memory import MemorySaver


def human_tool_review_node(
    state: States,
) -> Command[Literal["tools", "assistant_node"]]:
    """Pause the graph so a human can approve or reject the pending tool calls.

    The node interrupts with a payload that contains:

    - a prompt;
    - ``tool_call``: the call under review (the first call whose name is in ``protected_tools``,
      falling back to the last call);
    - ``tool_calls``: every pending call.

    The resume value is expected to be a dict such as ``{"action": "accept", "data": ...}``; a
    non-dict resume value is treated as the action itself.

    - ``"accept"``: go to ``tools``, which executes all pending calls.
    - anything else: no call is executed. Each pending call is answered with a ``ToolMessage``
      carrying the reviewer feedback (``data``), and control returns to ``assistant_node``.

    Raises:
        ValueError: if the last message is not an ``AIMessage`` with tool calls.
    """
    from langchain_core.messages import ToolMessage

    print("[INFO] human_tool_review_node called")
    last_message = state["messages"][-1]

    # Ensure we have a valid AI message with tool calls
    if not isinstance(last_message, AIMessage) or not last_message.tool_calls:
        msg = "human_tool_review_node called without valid tool calls"
        logger.error(msg)
        raise ValueError(msg)

    tool_calls = last_message.tool_calls
    # `router` sends us here when *any* call is protected, so show that call rather than
    # blindly the last one, which may be an unprotected tool.
    tool_call = next((tc for tc in tool_calls if tc["name"] in protected_tools), tool_calls[-1])

    # Stop graph execution and wait for human input
    human_review = interrupt(
        {
            "message": "Your input is required for the following tool:",
            "tool_call": tool_call,
            "tool_calls": tool_calls,
        },
    )
    if not isinstance(human_review, dict):
        human_review = {"action": human_review}
    review_action = human_review.get("action")
    review_data = human_review.get("data")

    if review_action == "accept":
        return Command(goto="tools")

    feedback = str(review_data) if review_data else "No feedback provided."
    # Each tool call on the AIMessage must be answered by a ToolMessage with the matching
    # tool_call_id; tool-calling chat APIs reject histories with unanswered calls.
    rejections = [
        ToolMessage(
            content=f"Tool call rejected by human reviewer. Feedback: {feedback}",
            tool_call_id=tc.get("id") or "",
            name=tc["name"],
        )
        for tc in tool_calls
    ]
    return Command(goto="assistant_node", update={"messages": rejections})


# System prompt for the assistant. Every path it mentions must sit under the approved
# /projects/workspace/ root.
ASSISTANT_SYSTEM_PROMPT = (
    "You are a helpful assistant. You have access to the local filesystem but only within an approved "
    "directory. The approved directory is /projects/workspace and all paths must begin with "
    "/projects/workspace/. You must use /projects/workspace/generated_example directory. if directory "
    "does not exists then create it and then give a good name of the <file_name>.md file (for example "
    "sw_design.md) and save the generated report in /projects/workspace/generated_example directory."
)


def assistant_node(state: States) -> States:
    """Invoke the tool-bound LLM on the conversation, prefixed with the filesystem system prompt.

    Args:
        state: Current graph state; only ``state["messages"]`` is read.

    Returns:
        A partial update ``{"messages": [response]}``. The ``messages`` channel inherited from
        ``MessagesState`` merges it into the history, so the input state is not mutated.
    """
    print("[INFO] assistant_node called")
    response = llm.invoke([SystemMessage(content=ASSISTANT_SYSTEM_PROMPT), *state["messages"]])
    return {"messages": [response]}


def router(state: States) -> str:
    """Pick the node that follows ``assistant_node``.

    Returns ``"human_tool_review_node"`` when any pending tool call targets a name in
    ``protected_tools``, ``"tools"`` for other tool calls, and ``END`` when the last message
    is not an ``AIMessage`` or carries no tool calls.
    """
    print("[INFO] router called")
    newest = state["messages"][-1]
    pending_calls = newest.tool_calls if isinstance(newest, AIMessage) else []
    if not pending_calls:
        return END
    needs_review = any(call["name"] in protected_tools for call in pending_calls)
    return "human_tool_review_node" if needs_review else "tools"


# Wiring: START -> assistant_node -> (tools | human_tool_review_node | END); tools -> assistant_node.
# human_tool_review_node chooses its successor itself via Command(goto=...), so it has no static edges.
builder = StateGraph(States)

for _node_name, _node_action in (
    ("assistant_node", assistant_node),
    ("human_tool_review_node", human_tool_review_node),
    ("tools", ToolNode(tools)),
):
    builder.add_node(_node_name, _node_action)

builder.add_edge(START, "assistant_node")
builder.add_conditional_edges(
    "assistant_node",
    router,
    ["tools", "human_tool_review_node", END],
)
builder.add_edge("tools", "assistant_node")

# A checkpointer is needed so the run can be paused by `interrupt` and resumed per thread.
graph = builder.compile(checkpointer=MemorySaver())
graph

In [None]:
# Single-turn request that should make the agent write a report through the filesystem tools.
_request_text = (
    "Generate a report on the project planning process. I don't know where to start, i want to create "
    "simple chatbot using langgraph. i am testing that you can use filesystem or not. simply generate a "
    "report without asking further question."
)
_input = {"messages": [HumanMessage(content=_request_text)]}

In [None]:
# Thread
from langchain_core.messages import AIMessageChunk

thread = {"configurable": {"thread_id": "1"}}


async def stream_message_tokens():
    """Stream token-level output of ``graph`` for ``_input`` on ``thread``.

    The function is an async generator. It yields ``(text, node_name)`` tuples, where
    ``node_name`` is the emitting node (``metadata["langgraph_node"]``). For each
    ``AIMessageChunk`` it yields:

    - ``"\\n\\n"`` when the chunk reports ``finish_reason == "tool_calls"``;
    - for each tool-call chunk, a ``< TOOL CALL: name >`` header (when a name is present)
      followed by the streamed argument fragment (when present);
    - otherwise, the chunk's raw ``content``.

    Events from the ``"updates"`` stream are currently ignored.
    """
    async for stream_mode, chunk in graph.astream(_input, thread, stream_mode=["updates", "messages"]):
        if stream_mode != "messages":
            # TODO: surface "updates" events (node transitions) as well.
            continue
        message, metadata = chunk
        node_name = metadata["langgraph_node"]
        if not isinstance(message, AIMessageChunk):
            continue

        finish_reason = (message.response_metadata or {}).get("finish_reason", "")
        if finish_reason == "tool_calls":
            yield "\n\n", node_name

        if not message.tool_call_chunks:
            yield message.content, node_name
            continue

        for tool_chunk in message.tool_call_chunks:
            # Build the output from this chunk only, so a name-only or args-only chunk never
            # re-emits (or fails on) text left over from an earlier chunk.
            parts = []
            tool_name = tool_chunk.get("name")
            args = tool_chunk.get("args")
            if tool_name:
                parts.append(f"\n\n< TOOL CALL: {tool_name} >\n\n")
            if args:
                parts.append(args)
            if parts:
                yield "".join(parts), node_name


# Backward-compatible name; a later cell in this notebook redefines `test2`.
test2 = stream_message_tokens

In [None]:
from platform import node


async def test2():
    """Stream ``graph`` node updates for ``_input`` on ``thread`` as display strings.

    Yields ``"assistant_node: <content>"`` for every assistant reply. When the graph pauses
    (an ``__interrupt__`` update), it yields one string describing each pending tool-call
    review: message, tool name and capitalised ``Key: value`` argument lines. Updates from
    other nodes are skipped.
    """
    async for update in graph.astream(_input, thread, stream_mode="updates"):
        node_name = next(iter(update))
        payload = update[node_name]

        if node_name == "assistant_node":
            yield f"{node_name}: {payload['messages'][-1].content}"
            continue

        if node_name != "__interrupt__":
            continue

        rendered = []
        for pending in payload:
            review = pending.value
            tool_call = review.get("tool_call")
            arg_lines = "\n".join(
                f"{arg_name.capitalize()}: {arg_value}" for arg_name, arg_value in tool_call.get("args").items()
            )
            rendered.append(
                f"\n{review.get('message')}\n\n< TOOL CALL: tool_name: {tool_call.get('name')} >\ntool_arg: {arg_lines}"
            )
        yield "\n\n".join(rendered)


In [None]:
async for response in test2():
    # if isinstance(response, dict):
    print(response)
    # keys = list(response.keys())
    # graph_name = keys[0]
    # print(graph_name)
    # message = response[graph_name]["messages"][-1]
    # print(message.content)


In [None]:
# `subgraph` is not defined anywhere in this notebook; the only similar name is the local
# `subgraph_name` inside the token-streaming generator. Evaluating it raised NameError and
# broke Restart & Run All, so the expression has been removed.
# TODO: restore with the intended object if a subgraph is added later.


In [None]:
# Inspect the latest checkpointed state for this conversation thread.
graph.get_state(thread)


In [None]:
# Shows the last string yielded by `test2()` in the streaming loop (left over from that loop).
response


In [None]:
# new_state = graph.get_state(thread).values
# for m in new_state["messages"]:
#     m.pretty_print()
# async for event in graph.astream(_input, thread, stream_mode="values"):
#     event["messages"][-1].pretty_print()


In [None]:
# graph.update_state(
#     thread,
#     {"messages": [HumanMessage(content="accept")]},
# )
# _input = {"messages": [HumanMessage(content="accept")]}

# new_state = graph.get_state(thread).values
# for m in new_state["messages"]:
#     m.pretty_print()
async for event in graph.astream(Command(resume={"action": "accept", "data": ""}), thread, stream_mode="values"):
    event["messages"][-1].pretty_print()

In [None]:
async for event in graph.astream(None, thread, stream_mode="values"):
    event["messages"][-1].pretty_print()


In [None]:
# from langchain.chat_models import init_chat_model

# model_shell = init_chat_model(
#     configurable_fields=("model", "max_tokens"),
# )

# report_generator_config = {
#     "model": "ollama:qwen3:8b",
# }
# report_generator_model = model_shell.with_config(report_generator_config)
# report_generator_model.invoke("hello world")