In [None]:
%%capture --no-stderr
%pip install --quiet -U langchain_core langgraph langchain_google_genai

In [None]:
import os, getpass

def _set_env(var: str):
    """Ensure the environment variable `var` holds a non-empty value.

    If `var` is unset or empty, the value is read interactively with
    `getpass.getpass` (prompt: "<var>: ") and stored in `os.environ`.
    An existing non-empty value is left untouched.
    """
    # Guard clause: nothing to do when a non-empty value is already present.
    if os.environ.get(var):
        return
    os.environ[var] = getpass.getpass(f"{var}: ")

# Make sure GOOGLE_API_KEY is set (prompting for it if it is missing or empty)
# before the Gemini chat model is created in a later cell.
# NOTE(review): presumably ChatGoogleGenerativeAI reads this variable — confirm.
_set_env("GOOGLE_API_KEY")

In [None]:
from langchain_google_genai import ChatGoogleGenerativeAI

# Chat model used by the rest of the notebook: Gemini 2.0 Flash with
# temperature 0.0 (presumably chosen to keep the arithmetic answers repeatable).
llm = ChatGoogleGenerativeAI(
    model="gemini-2.0-flash",
    temperature=0.0,
)

In [None]:
 
def add(x: float, y: float) -> float:
    """Add 'x' and 'y'."""
    # The docstring above is intentionally left untouched: this function is
    # passed to `llm.bind_tools` / `ToolNode` elsewhere in the notebook, where
    # it presumably serves as the tool description shown to the model.
    total = x + y
    return total

 
def subtract(x: float, y: float) -> float:
    """Subtract 'y' from 'x' (computes x - y).

    The wording spells out the operand order on purpose: this function is
    passed to `llm.bind_tools`, so the docstring is what the model reads when
    deciding which number goes in which argument.

    Args:
        x: The number to subtract from (minuend).
        y: The number to subtract (subtrahend).

    Returns:
        The difference x - y.
    """
    return x - y


def multiply(x: float, y: float) -> float:
    """Multiply 'x' and 'y'."""
    # The docstring above is intentionally left untouched: this function is
    # passed to `llm.bind_tools` / `ToolNode` elsewhere in the notebook, where
    # it presumably serves as the tool description shown to the model.
    product = x * y
    return product


def divide(x: float, y: float) -> float:
    """Divide 'x' by 'y' (computes x / y).

    The wording spells out the operand order on purpose: this function is
    passed to `llm.bind_tools`, so the docstring is what the model reads when
    deciding which number is the dividend and which is the divisor.

    Args:
        x: The dividend (numerator).
        y: The divisor (denominator); must not be zero.

    Returns:
        The quotient x / y.

    Raises:
        ZeroDivisionError: If 'y' is zero.
    """
    return x / y


def exponentiate(x: float, y: float) -> float:
    """Raise 'x' to the power of 'y'."""
    # The docstring above is intentionally left untouched: this function is
    # passed to `llm.bind_tools` / `ToolNode` elsewhere in the notebook, where
    # it presumably serves as the tool description shown to the model.
    # Two-argument pow() is the built-in spelling of the ** operator.
    return pow(x, y)

# Bind the five arithmetic functions to the chat model so it can request them
# as tool calls; the same five functions are given to the ToolNode later on.
llm_with_tools = llm.bind_tools(
    [
        add,
        subtract,
        multiply,
        divide,
        exponentiate,
    ]
)

In [None]:
from IPython.display import Image, display
from langgraph.graph import StateGraph, START, END
from langgraph.graph import MessagesState
from langgraph.prebuilt import ToolNode
from langgraph.prebuilt import tools_condition
from langchain_core.messages import SystemMessage, HumanMessage

# System prompt prepended to every model call made by the `assistant` node.
# Between the two "----" separators it contains three User/AI example
# exchanges whose answers are written with two decimal places
# (2 / 3 -> 0.66, 20 / -7 -> -2.85, 20 * 7 -> 140.00).
# NOTE(review): the prompt text reads "an helpful"; it is deliberately left
# unchanged here because editing the literal changes what the model receives.
sys_message = SystemMessage(
    """You are an helpful assistant tasked with performing arithmetic on a set of inputs.
    
    ----
    User: 2 divide 3
    AI: The answer is 0.66
    
    User: 20 divide -7
    AI: The answer is -2.85
    
    User: 20 multiply 7
    AI: The answer is 140.00    
    
    ----    
    
    """
) 


def assistant(state: MessagesState):
    """Graph node that asks the tool-enabled model for the next message.

    The system prompt (`sys_message`) is placed in front of the conversation
    stored under `state["messages"]`, the combined list is sent to
    `llm_with_tools`, and the model's reply is returned under the "messages"
    key as this node's state update.
    """
    conversation = [sys_message] + state["messages"]
    reply = llm_with_tools.invoke(conversation)
    return {"messages": reply}


# Assemble the agent graph over the message-list state.
builder = StateGraph(MessagesState)

# Two nodes: the model call, and a ToolNode holding the arithmetic functions.
builder.add_node(node="assistant", action=assistant)
builder.add_node(
    node="tools",
    action=ToolNode([add, subtract, multiply, divide, exponentiate]),
)

# Entry point -> assistant; `tools_condition` picks the assistant's successor;
# whatever the tools node produces is handed back to the assistant.
builder.add_edge(start_key=START, end_key="assistant")
builder.add_conditional_edges(source="assistant", path=tools_condition)
builder.add_edge(start_key="tools", end_key="assistant")


from langgraph.checkpoint.memory import MemorySaver

# Checkpointer that the compiled graph uses to persist per-thread state.
memory = MemorySaver()

# Compile with a breakpoint in front of the "tools" node, so a run stops
# before any tool is executed and can be inspected or resumed afterwards.
graph = builder.compile(
    checkpointer=memory,
    interrupt_before=["tools"],
)

# Render the compiled graph as a Mermaid PNG inside the notebook.
display(
    Image(
        graph.get_graph(xray=True).draw_mermaid_png()
    )
)

In [None]:
# First run on thread "1": stream the state values and pretty-print the newest
# message of each emitted state. The graph was compiled with
# interrupt_before=["tools"], so the run is expected to stop before the tool
# node if the model asks for a tool call.
initial_input = {"messages": "Multiply 2 by 3"}
thread = {"configurable": {"thread_id": "1"}}

for event in graph.stream(initial_input, thread, stream_mode="values"):
    event["messages"][-1].pretty_print()
    


In [None]:
# Fetch the saved state snapshot for this thread and show, as the cell output,
# which node(s) are scheduled to run next.
# NOTE(review): presumably ('tools',) here, given interrupt_before=["tools"]
# and a model reply that requested a tool call — confirm when running.
state = graph.get_state(thread)
state.next

In [None]:
# Stream the same thread again with None as the input (no new user message)
# and pretty-print the newest message of each emitted state.
# NOTE(review): per the LangGraph breakpoint docs, a None input resumes the
# thread from its last checkpoint, i.e. continues past the "tools" breakpoint.
for event in graph.stream(None, thread, stream_mode="values"):
    event["messages"][-1].pretty_print()
    

In [None]:
# Human-in-the-loop run on a fresh thread ("3"): stream until the graph stops,
# then ask the user whether the pending tool call may be executed.
initial_input = {
    "messages": "Multiply 2 by 3"
}

thread = {
    "configurable":{
        "thread_id": "3"
    }
}

for event in graph.stream(initial_input, thread, stream_mode="values"):
    event["messages"][-1].pretty_print()

# Only ask for approval when the run is actually paused with something left to
# execute; if the model answered without requesting a tool, `next` is empty
# and there is nothing to approve.
if graph.get_state(thread).next:
    user_approval = input("Do you want to call the tool? (yes/no): ")

    # Ignore surrounding whitespace and letter case so "yes ", " Yes" and
    # "YES" are all accepted; anything else is treated as a refusal.
    if user_approval.strip().lower() == "yes":
        # A None input adds no new message; the thread continues from its
        # saved checkpoint.
        for event in graph.stream(None, thread, stream_mode="values"):
            event["messages"][-1].pretty_print()
    else:
        print("user cancelled the operation")
else:
    print("no tool call is pending; nothing to approve")