In [None]:
%%capture --no-stderr
%pip install --quiet -U langchain_core langgraph langchain_google_genai

In [None]:
import os, getpass

def _set_env(var: str):
    """Make sure the environment variable `var` holds a non-empty value.

    When the variable is missing or empty, the value is requested through a
    hidden `getpass` prompt labelled with the variable name and stored in
    `os.environ`. An existing non-empty value is left untouched.
    """
    if os.environ.get(var):
        # Already configured: do not prompt again.
        return
    os.environ[var] = getpass.getpass(f"{var}: ")

# Prompt for the Google API key only when it is not already set in the environment.
_set_env("GOOGLE_API_KEY")

In [None]:
from langchain_google_genai import ChatGoogleGenerativeAI

# Chat model used by both graphs built below; temperature is pinned to 0.0.
llm = ChatGoogleGenerativeAI(
    model="gemini-2.0-flash",
    temperature=0.0,
)

In [None]:
 
def add(x: float, y: float) -> float:
    """Add 'x' and 'y'."""
    # The docstring is kept verbatim: this function is handed to bind_tools()
    # below, which uses it to describe the tool.
    total = x + y
    return total

 
def subtract(x: float, y: float) -> float:
    """Subtract 'y' from 'x'.

    Args:
        x: Number to subtract from (the minuend).
        y: Number to subtract (the subtrahend).

    Returns:
        The difference ``x - y``.
    """
    # This function is exposed to the model as a tool and its docstring is the
    # tool description. Subtraction is not commutative, so the description
    # spells out the operand order instead of the ambiguous "x and y".
    return x - y


def multiply(x: float, y: float) -> float:
    """Multiply 'x' and 'y'."""
    # The docstring is kept verbatim: this function is handed to bind_tools()
    # below, which uses it to describe the tool.
    product = x * y
    return product


def divide(x: float, y: float) -> float:
    """Divide 'x' by 'y'.

    Args:
        x: Number to be divided (the dividend).
        y: Number to divide by (the divisor).

    Returns:
        The quotient ``x / y``.

    Raises:
        ZeroDivisionError: If ``y`` is zero.
    """
    # This function is exposed to the model as a tool and its docstring is the
    # tool description. Division is not commutative, so the description spells
    # out the operand order instead of the ambiguous "x and y".
    return x / y


def exponentiate(x: float, y: float) -> float:
    """Raise 'x' to the power of 'y'."""
    # Two-argument pow() is the function form of the ** operator.
    # The docstring is kept verbatim because bind_tools() below consumes it.
    return pow(x, y)

# Expose the five arithmetic functions to the chat model as callable tools.
arithmetic_tools = [add, subtract, multiply, divide, exponentiate]
llm_with_tools = llm.bind_tools(arithmetic_tools)

In [None]:
from IPython.display import Image, display
from langgraph.graph import StateGraph, START, END
from langgraph.graph import MessagesState
from langgraph.prebuilt import ToolNode
from langgraph.prebuilt import tools_condition
from langchain_core.messages import SystemMessage, HumanMessage

# System prompt that `assistant` puts in front of the conversation on every
# model call. Between the `----` rules it carries three worked User/AI
# exchanges, each answering with two decimal places. The literal (including
# its indentation) is sent to the model exactly as written.
sys_message = SystemMessage(
    """You are an helpful assistant tasked with performing arithmetic on a set of inputs.
    
    ----
    User: 2 divide 3
    AI: The answer is 0.66
    
    User: 20 divide -7
    AI: The answer is -2.85
    
    User: 20 multiply 7
    AI: The answer is 140.00    
    
    ----    
    
    """
) 


def assistant(state: MessagesState):
    """Graph node: ask the tool-enabled model for the next message.

    The system prompt is placed in front of the messages accumulated in
    `state`, and the model's reply is returned as the `messages` update.
    """
    prompt = [sys_message] + state["messages"]
    reply = llm_with_tools.invoke(prompt)
    return {"messages": reply}


builder = StateGraph(MessagesState)

# Nodes: the model call, and a node that executes whichever tool it requested.
tool_node = ToolNode([add, subtract, multiply, divide, exponentiate])
builder.add_node("assistant", assistant)
builder.add_node("tools", tool_node)

# Edges: always start at the assistant; `tools_condition` chooses the next hop
# after each assistant turn, and tool results are fed back to the assistant.
builder.add_edge(START, "assistant")
builder.add_conditional_edges(
    "assistant",
    tools_condition,
)
builder.add_edge("tools", "assistant")


from langgraph.checkpoint.memory import MemorySaver

# In-memory checkpointer: lets a thread be paused, inspected and resumed.
memory = MemorySaver()

# Compile with a breakpoint in front of every "assistant" step.
graph = builder.compile(
    checkpointer=memory,
    interrupt_before=["assistant"],
)

# Render the compiled graph as a Mermaid PNG.
graph_png = graph.get_graph(xray=True).draw_mermaid_png()
display(Image(graph_png))

In [None]:
# Request that starts the run.
initial_input = {"messages": "Multiply 2 by 3"}

# Config naming the thread whose checkpoints this run reads and writes.
thread = {"configurable": {"thread_id": "1"}}

# Stream full state values and print the newest message of each one. The graph
# was compiled with interrupt_before=["assistant"], so the run pauses there.
for event in graph.stream(initial_input, thread, stream_mode="values"):
    latest_message = event["messages"][-1]
    latest_message.pretty_print()
    


In [None]:
# Fetch the current snapshot of the paused thread; ending the cell on `state`
# displays it.
state = graph.get_state(thread)
state

In [None]:
# While the run is paused, write a corrected human message into the thread's
# state. The call is the cell's last expression, so its return value is shown.
correction = HumanMessage(content="No, actually multiply 3 by 3")
graph.update_state(thread, {"messages": [correction]})

In [None]:
# Re-read the thread's state values and print every message now stored in it.
new_state = graph.get_state(thread).values

for message in new_state["messages"]:
    message.pretty_print()

In [None]:
# Stream again with `None` as the input so the paused thread continues instead
# of starting from a new message; print the newest message of each state value.
for event in graph.stream(None, thread, stream_mode="values"):
    latest_message = event["messages"][-1]
    latest_message.pretty_print()

In [None]:
# Same resume call as the previous cell: the breakpoint sits in front of every
# "assistant" step, so another `None` stream is needed to move past it again.
for event in graph.stream(None, thread, stream_mode="values"):
    latest_message = event["messages"][-1]
    latest_message.pretty_print()

In [None]:

# Address handed to the LangGraph SDK client below (loopback, port 2024).
# NOTE(review): assumes a LangGraph server is already listening here — confirm
# before running the remaining cells.
URL = "http://127.0.0.1:2024"

from langgraph_sdk import get_client

In [None]:
# SDK client bound to the server at URL.
client = get_client(url=URL)

# Ask the server which assistants it knows about. Top-level `await` is
# notebook syntax; this line is not valid in a plain Python script.
assistants = await client.assistants.search()

In [None]:
# Display the assistant search result fetched in the previous cell.
assistants

In [None]:
# Create a thread on the server. This rebinds `thread`: in the earlier cells it
# was a local `{"configurable": ...}` config dict, from here on it is the
# server's response and is indexed with `thread["thread_id"]`.
thread = await client.threads.create()

In [None]:
from langchain_core.messages import HumanMessage

input_message = {"messages": [HumanMessage(content="Multiply 3 by 12")]}

async for chunk in client.runs.stream(
    thread["thread_id"],
    "agent",
    input = input_message,
    stream_mode="values",
    interrupt_before = ["assistant"]
):

    messages = chunk.data.get("messages", [])
    if messages:
        print(messages[-1])
    print("---"*50)

         


In [None]:
# Fetch the server-side state of the thread. This rebinds `state`, which
# earlier held the local graph's snapshot; ending the cell on it displays it.
state = await client.threads.get_state(thread["thread_id"])
state

In [None]:
# Pick the most recent message out of the fetched state and display it.
lastmessage = state["values"]["messages"][-1]
lastmessage

In [None]:
# Overwrite only the text of that message in place; every other key of the
# dict is left as the server returned it. Display the edited message.
lastmessage["content"] = "No actually multiply 3 by 3"
lastmessage

In [None]:
# Push the edited message back into the server thread's state.
# NOTE(review): presumably the edit replaces the stored message (matched by
# its id) rather than appending a second one — confirm against the state shown.
await client.threads.update_state(thread["thread_id"], {"messages": lastmessage})

In [None]:
async for chunk in client.runs.stream(
    thread["thread_id"],
    "agent",
    input = None,
    stream_mode="values",
    interrupt_before = ["assistant"]
):

    messages = chunk.data.get("messages", [])
    if messages:
        print(messages[-1])
    print("---"*50)

In [None]:
async for chunk in client.runs.stream(
    thread["thread_id"],
    "agent",
    input = None,
    stream_mode="values",
    interrupt_before = ["assistant"]
):

    messages = chunk.data.get("messages", [])
    if messages:
        print(messages[-1])
    print("---"*50)

In [None]:
## Human Interrupt

from langchain_core.runnables.graph import MermaidDrawMethod

# System prompt for the second graph; replaces the earlier `sys_message`.
sys_message = SystemMessage(
    content="You are helpful assistant tasked with performing arithmatic on a set of input"
)

def human_feedback(state: MessagesState):
    """Placeholder node that does nothing and returns None.

    It exists so the graph has a named step to pause in front of; the human's
    input is written into the state from outside, on behalf of this node.
    """
    return None

def assistant(state: MessagesState):
    """Graph node: ask the tool-enabled model for the next message.

    The system prompt is placed in front of the messages accumulated in
    `state`; the reply is returned as a one-element `messages` list.
    """
    prompt = [sys_message] + state["messages"]
    reply = llm_with_tools.invoke(prompt)
    return {"messages": [reply]}


builder = StateGraph(MessagesState)

# Nodes: the model call, the tool executor, and the no-op human checkpoint.
tool_node = ToolNode([add, subtract, multiply, divide, exponentiate])
builder.add_node("assistant", assistant)
builder.add_node("tools", tool_node)
builder.add_node("human_feedback", human_feedback)

# Edges: every pass enters through human_feedback before reaching the
# assistant; `tools_condition` chooses the hop after the assistant, and tool
# results loop back through human_feedback rather than straight to the model.
builder.add_edge(START, "human_feedback")
builder.add_edge("human_feedback", "assistant")
builder.add_conditional_edges("assistant", tools_condition)
builder.add_edge("tools", "human_feedback")

# Fresh in-memory checkpointer for the second graph.
memory = MemorySaver()

# Compile with a breakpoint in front of every "human_feedback" step.
graph = builder.compile(
    checkpointer=memory,
    interrupt_before=["human_feedback"],
)

# Render the compiled graph as a Mermaid PNG.
graph_png = graph.get_graph().draw_mermaid_png()
display(Image(graph_png))



In [None]:
# Request that starts the run
initial_input = {"messages": "Multiply 2 and 3"}

# Config naming the thread whose checkpoints this run uses
thread = {"configurable": {"thread_id": "5"}}

# First pass: stream until the breakpoint in front of human_feedback
for event in graph.stream(initial_input, thread, stream_mode="values"):
    latest_message = event["messages"][-1]
    latest_message.pretty_print()

# Ask the person at the keyboard what to add to the state
user_input = input("Tell me how you want to update the state: ")

# Record the reply as though the human_feedback node itself had produced it
feedback_update = {"messages": user_input}
graph.update_state(thread, feedback_update, as_node="human_feedback")

# Second pass: resume the paused thread (no new input)
for event in graph.stream(None, thread, stream_mode="values"):
    latest_message = event["messages"][-1]
    latest_message.pretty_print()

In [None]:
# Resume once more: tool results route back through human_feedback, where the
# graph pauses again, so another `None` stream is needed to carry on.
for event in graph.stream(None, thread, stream_mode="values"):
    latest_message = event["messages"][-1]
    latest_message.pretty_print()