<a href="https://colab.research.google.com/github/satvik314/ai_experiments/blob/main/langraph.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [3]:
pip install -qU langgraph langchain langchain_openai tavily-python

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m803.6/803.6 kB[0m [31m10.8 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.6/1.6 MB[0m [31m32.9 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m225.1/225.1 kB[0m [31m20.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.0/2.0 MB[0m [31m52.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m49.4/49.4 kB[0m [31m5.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m75.9/75.9 kB[0m [31m9.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m76.9/76.9 kB[0m [31m8.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m58.3/58.3 kB[0m [31m7.3 MB/s[0m eta [36m0:00:00[0m
[?25h[31mERROR: pip's dependency r

In [4]:
import os
from getpass import getpass

def _set_env(var: str, prompt: str) -> None:
  """Prompt for a secret and store it in ``os.environ[var]`` unless it is already set.

  Skipping variables that already hold a value keeps this cell idempotent.
  Re-running it, or running the notebook where the keys are pre-configured,
  does not prompt again or overwrite the existing value.
  """
  if not os.environ.get(var):
    os.environ[var] = getpass(prompt)

_set_env('OPENAI_API_KEY', "openai: ")
_set_env('TAVILY_API_KEY', "tavily: ")
_set_env('LANGCHAIN_API_KEY', "langchain: ")
# Turn on LangChain (LangSmith) tracing for runs in this session.
os.environ['LANGCHAIN_TRACING_V2'] = "true"

openai: ··········
tavily: ··········
langchain: ··········


In [6]:
# Tool setup: one Tavily web-search tool (max_results=1 keeps a single result),
# plus the executor LangGraph uses to run whichever tool the model requests.
from langchain_community.tools.tavily_search import TavilySearchResults
from langgraph.prebuilt import ToolExecutor

search_tool = TavilySearchResults(max_results=1)
tools = [search_tool]

# call_tool hands a ToolInvocation (tool name + input) to this executor.
tool_executor = ToolExecutor(tools)


In [7]:
# Set up the chat model and expose the tools to it as OpenAI "functions".
from langchain_openai import ChatOpenAI
# `langchain.tools.render.format_tool_to_openai_function` is deprecated; the
# maintained converter lives in langchain_core.
from langchain_core.utils.function_calling import convert_to_openai_function

# temperature=0 for (near-)deterministic answers; streaming=True enables token streaming.
model = ChatOpenAI(temperature=0, streaming=True)

# One OpenAI function schema per tool defined in the tools cell.
functions = [convert_to_openai_function(t) for t in tools]

# Bind the schemas so every call advertises them. When the model wants a tool, it
# replies with `additional_kwargs["function_call"]`, which should_continue and
# call_tool read.
model = model.bind_functions(functions)


In [8]:
#define agent state

from typing import TypedDict, Annotated, Sequence
import operator
from langchain_core.messages import BaseMessage

class AgentState(TypedDict):
  """State passed between graph nodes: the running conversation.

  ``messages`` uses ``operator.add`` as its reducer. Each node returns only its
  new messages, and these are appended to the existing sequence rather than
  replacing it.
  """
  messages: Annotated[Sequence[BaseMessage], operator.add]

In [9]:
#define the nodes

from langgraph.prebuilt import ToolInvocation
import json
from langchain_core.messages import FunctionMessage

def should_continue(state):
  """Routing function evaluated after the agent node.

  Returns "continue" when the most recent message carries a ``function_call``
  in its ``additional_kwargs``, meaning the model wants a tool run next.
  Otherwise returns "end".
  """
  latest = state['messages'][-1]
  wants_tool = "function_call" in latest.additional_kwargs
  return "continue" if wants_tool else "end"

def call_model(state):
  """Agent node: send the conversation so far to the function-bound chat model.

  Returns a partial state update holding only the model's reply. The
  ``operator.add`` reducer on ``AgentState.messages`` appends it to the history.
  """
  reply = model.invoke(state['messages'])
  return {"messages": [reply]}

def call_tool(state):
  """Action node: execute the tool the model asked for.

  Reads the ``function_call`` (tool name plus JSON-encoded arguments) from the
  latest message and runs it through ``tool_executor``. The stringified result
  is returned as a ``FunctionMessage`` named after the tool.
  """
  requested = state['messages'][-1].additional_kwargs['function_call']
  tool_name = requested['name']
  # The model sends its arguments as a JSON string; decode before invoking the tool.
  tool_args = json.loads(requested['arguments'])

  action = ToolInvocation(tool=tool_name, tool_input=tool_args)
  result = tool_executor.invoke(action)
  return {"messages": [FunctionMessage(content=str(result), name=action.tool)]}

In [11]:
from langgraph.graph import StateGraph, END

# Build a graph whose state follows the AgentState schema.
workflow = StateGraph(AgentState)

# The two nodes we cycle between: the LLM ("agent") and the tool runner ("action").
workflow.add_node("agent", call_model)
workflow.add_node("action", call_tool)

# Every run starts at the `agent` node.
workflow.set_entry_point("agent")

# After `agent` runs, `should_continue` inspects the latest message. The string it
# returns is looked up in the mapping below to choose the next node.
workflow.add_conditional_edges(
    "agent",
    should_continue,
    {
        # The model requested a function call -> execute it in the `action` node.
        "continue": "action",
        # No function call -> the model answered directly, so finish.
        # END is LangGraph's special terminal node.
        "end": END,
    },
)

# After the `action` node runs a tool, always hand its result back to `agent`.
workflow.add_edge("action", "agent")

# Compile into a LangChain Runnable, so `app` can be used like any other runnable
# (e.g. `.invoke`).
app = workflow.compile()

In [13]:
from langchain_core.messages import HumanMessage

# Run the agent end-to-end on a single question. The cell displays the final
# state, which includes every message exchanged along the way.
question = HumanMessage(content="what is the weather in Koramangala?")
inputs = {"messages": [question]}
app.invoke(inputs)

{'messages': [HumanMessage(content='what is the weather in Koramangala?'),
  AIMessage(content='', additional_kwargs={'function_call': {'arguments': '{\n  "query": "weather in Koramangala"\n}', 'name': 'tavily_search_results_json'}}),
  FunctionMessage(content="[{'url': 'https://weather.com/en-IN/weather/tenday/l/Bangalore+Karnataka?canonicalCityId=73dd3db2099ff26d4daa3c9346a09dd27d681b36ce0b94b63f14ace60adc209e', 'content': 'recents Special Forecasts 10-Day Weather-Koramangala 5th Block, Karnataka Today Fri 26 | Day  Today Fri 26 | Day Partly cloudy. High 29°C. Winds E at 10 to 15 km/h. Fri 26 | Night  Fri 26 | Night Generally clear. Hazy. Low 17°C. Winds E at 10 to 15 km/h. Sat 27 Sat 27 | Day  Thu 01 Thu 01 | Day Partly cloudy. High 31°C. Winds SSW and variable. Thu 01 | Night10-Day Weather - Koramangala 5th Block, Karnataka As of 17:11 IST Tonight --/ 16° 3% | Night 16° 3% E 12 km/h Generally clear. Hazy. Low 16°C. Winds E at 10 to 15 km/h. Humidity 69% UV Index 0...'}]", name='tav