In [None]:
%%capture --no-stderr
# Install the LangGraph + Gemini dependencies used below (output suppressed by %%capture).
# NOTE(review): `-U` with no version pins installs whatever is latest; pin versions for a reproducible run.
%pip install --quiet -U langgraph langchain-google-genai langgraph_sdk langgraph-prebuilt

In [None]:
import os, getpass

def _set_env(var: str):
    """Make sure environment variable ``var`` is populated, prompting if needed.

    A non-empty existing value is left untouched. Otherwise the value is read
    with ``getpass`` (so it is not echoed) and stored in ``os.environ``.
    """
    already_set = bool(os.environ.get(var))
    if already_set:
        return
    os.environ[var] = getpass.getpass(f"{var}: ")

# Prompt for the Gemini API key only if GOOGLE_API_KEY is not already set in the environment.
_set_env("GOOGLE_API_KEY")

In [None]:
from langchain_google_genai import ChatGoogleGenerativeAI

# Gemini chat model that backs the assistant node. A temperature of 0.0 keeps the
# arithmetic demo's replies as repeatable as the API allows.
GEMINI_MODEL = "gemini-2.0-flash"
llm = ChatGoogleGenerativeAI(model=GEMINI_MODEL, temperature=0.0)

In [None]:
def add(x: float, y: float) -> float:
    """Add 'x' and 'y'.

    Args:
        x: First addend.
        y: Second addend.

    Returns:
        The sum x + y.
    """
    # This docstring is sent to the model as the tool description, so the Args
    # section doubles as per-parameter documentation in the tool schema.
    return x + y

 
def subtract(x: float, y: float) -> float:
    """Subtract 'y' from 'x'.

    Args:
        x: The number to subtract from (minuend).
        y: The number to subtract (subtrahend).

    Returns:
        The difference x - y.
    """
    # Operand order matters here, so the tool description spells out which
    # argument is subtracted from which.
    return x - y


def multiply(x: float, y: float) -> float:
    """Multiply 'x' and 'y'.

    Args:
        x: First factor.
        y: Second factor.

    Returns:
        The product x * y.
    """
    return x * y


def divide(x: float, y: float) -> float:
    """Divide 'x' by 'y'.

    Args:
        x: The dividend.
        y: The divisor; must be non-zero.

    Returns:
        The quotient x / y.

    Raises:
        ZeroDivisionError: If 'y' is zero.
    """
    # Operand order matters, so the tool description states the direction
    # explicitly ("x by y") instead of the ambiguous "x and y".
    return x / y


def exponentiate(x: float, y: float) -> float:
    """Raise 'x' to the power of 'y'.

    Args:
        x: The base.
        y: The exponent.

    Returns:
        The power x ** y.
    """
    # NOTE(review): in Python a negative base with a non-integer exponent
    # (e.g. (-8) ** (1/3)) yields a complex number, not a float as annotated.
    return x ** y

# Expose the five arithmetic functions to the model as callable tools; their names,
# signatures and docstrings form the tool schemas the model sees.
llm_with_tools = llm.bind_tools([add, subtract, multiply, divide, exponentiate])

In [None]:
from IPython.display import Image, display
from langgraph.graph import StateGraph, START, END
from langgraph.graph import MessagesState
from langgraph.prebuilt import ToolNode
from langgraph.prebuilt import tools_condition

from langchain_core.messages import SystemMessage, HumanMessage

# System prompt prepended to every assistant call. The few-shot examples pin the
# answer format: the exact result rounded (not truncated) to two decimal places.
# The arithmetic itself is delegated to the bound tools.
sys_message = SystemMessage(
    content=(
        "You are a helpful assistant tasked with performing arithmetic on a set of inputs.\n"
        "Use the provided tools to compute results, then report the final answer "
        "rounded to two decimal places.\n"
        "\n"
        "Examples:\n"
        "User: 2 divide 3\n"
        "AI: The answer is 0.67\n"
        "\n"
        "User: 20 divide -7\n"
        "AI: The answer is -2.86\n"
        "\n"
        "User: 20 multiply 7\n"
        "AI: The answer is 140.00\n"
    )
)


def assistant(state: MessagesState):
    """Graph node: ask the tool-enabled LLM to respond to the conversation so far.

    The system prompt is prepended to the thread's messages. The model's reply is
    returned under ``messages``, and the ``MessagesState`` reducer merges it into
    the thread.
    """
    conversation = [sys_message, *state["messages"]]
    reply = llm_with_tools.invoke(conversation)
    return {"messages": reply}


from langgraph.checkpoint.memory import MemorySaver

# Tool-calling loop:
#   START -> assistant
#   assistant -> tools  (when the last AI message requests a tool; otherwise END, via tools_condition)
#   tools -> assistant  (so the model can read the tool results)
builder = StateGraph(MessagesState)
builder.add_node("assistant", assistant)
builder.add_node("tools", ToolNode([add, subtract, multiply, divide, exponentiate]))

builder.add_edge(START, "assistant")
builder.add_conditional_edges("assistant", tools_condition)
builder.add_edge("tools", "assistant")

# The checkpointer saves state per thread_id after each step; that saved history
# is what get_state_history / replay / forking operate on below.
memory = MemorySaver()
graph = builder.compile(checkpointer=memory)

# Rendering the PNG uses a remote Mermaid renderer by default, so it is left
# commented out; uncomment to visualise the graph.
#display(Image(graph.get_graph().draw_mermaid_png()))

In [None]:
config = {"configurable": {"thread_id": "1"}}

# NOTE: re-running this cell appends another turn to thread "1", which changes
# the checkpoint history inspected in the cells below.
messages = graph.invoke(
    {"messages": [HumanMessage(content="Multiply 2 and 3")]},
    config,
)

for msg in messages["messages"]:
    msg.pretty_print()

In [None]:
# Latest checkpoint (a StateSnapshot) for thread "1": its values, next node(s) and config.
graph.get_state(config)

In [None]:
# Materialise the thread's checkpoint history (ordered newest first).
all_states = list(graph.get_state_history(config))

In [None]:
# Number of checkpoints saved for this thread so far.
len(all_states)

In [None]:
# Most recent checkpoint (history is ordered newest first).
all_states[0]

In [None]:
# Second-oldest checkpoint: expected to hold just the user's message, before the assistant node runs.
all_states[-2]

## Replay Graph

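Because we pass `to_replay.config`, which identifies that checkpoint by its `checkpoint_id`, together with `None` as the input, the graph resumes from that specific snapshot instead of starting a new run.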

In [None]:
# Replay from the second-oldest checkpoint (history is ordered newest first).
to_replay = all_states[-2]

In [None]:
# Full StateSnapshot for the replay point.
to_replay

In [None]:
# State values (the message list) stored at this checkpoint.
to_replay.values

In [None]:
# Config identifying this checkpoint (thread_id plus checkpoint_id); streaming with it replays from here.
to_replay.config

In [None]:
# Node(s) scheduled to run next when execution resumes from this checkpoint.
to_replay.next

In [None]:
# None as input means "resume from the checkpoint in the config" rather than start fresh;
# stream_mode="values" yields the full state after each step, so print its newest message.
for snapshot in graph.stream(None, to_replay.config, stream_mode="values"):
    latest_message = snapshot["messages"][-1]
    latest_message.pretty_print()

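## Forking

To fork, we write a modified copy of a past checkpoint with `graph.update_state`. Reusing the original message's `id` makes the `add_messages` reducer replace that message rather than append a new one, and the returned config points at the new, forked checkpoint.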

In [None]:
# all_states is the list captured earlier and is not re-read, so the replay above does not change this count.
len(all_states)

In [None]:
# Fork from the same checkpoint that was replayed above, and inspect its messages.
to_fork = all_states[-2]
to_fork.values["messages"]

In [None]:
# id of the original HumanMessage; reusing it in update_state replaces that message instead of appending one.
to_fork.values["messages"][0].id

In [None]:
# NOTE(review): duplicate of the earlier len(all_states) cells; the value is unchanged.
len(all_states)

In [None]:
# Config of the checkpoint being forked.
to_fork.config

In [None]:
# Write a forked checkpoint whose first message asks a different question. The id
# matches the original HumanMessage, so add_messages swaps it in place.
fork_config = graph.update_state(
    to_fork.config,
    {
        "messages": [
            HumanMessage(
                content="Multiply 5 and 3",
                id=to_fork.values["messages"][0].id,
            )
        ]
    },
)

In [None]:
# Config returned by update_state: it points at the newly written (forked) checkpoint.
fork_config

In [None]:
# Messages from the pre-fork snapshot list: newest, oldest and second-oldest checkpoints.
for idx in (0, -1, -2):
    print(all_states[idx].values["messages"])

In [None]:
# Re-read the thread's history now that the replay and update_state have added checkpoints.
all_states_new = list(graph.get_state_history(config))
for idx in (0, -1, -2):
    print(all_states_new[idx].values["messages"])

In [None]:
# Messages of the newest checkpoint in the pre-fork list, shown with rich display.
all_states[0].values["messages"]

In [None]:
# Compare with len(all_states): the replay and the fork are expected to have added checkpoints to this thread.
len(all_states_new)

In [None]:
# Latest checkpoint of thread "1"; presumably the forked one, since update_state made the most recent write.
graph.get_state(config)

In [None]:
# Run the forked branch: resuming from fork_config, the assistant now answers
# "Multiply 5 and 3" instead of the original question.
for snapshot in graph.stream(None, fork_config, stream_mode="values"):
    newest = snapshot['messages'][-1]
    newest.pretty_print()