<a href="https://colab.research.google.com/gist/virattt/ba0b660cdcaf4161ca1e6e5d8b5de4f8/langgraph-financial-agent.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

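This notebook includes my code for creating a financial agent using [LangGraph](https://github.com/langchain-ai/langgraph). **Note:** the code cells below have since been adapted from that financial agent into a multi-agent graph ("Sensitive", "Researcher" and "Ego" agents) driven by a local OpenAI-compatible model server, so the tool and API-key lists that follow describe the original version.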

The agent has two tools:

1. Extract a ticker from a user query.
2. Given a ticker, get its latest price using [Polygon](https://polygon.io/).

You will need two things to run the code:

1. OpenAI API key ([link](https://platform.openai.com/account/api-keys))
2. Polygon API key ([link](https://polygon.io/))

I've tried to make the code as easy as possible to read and run.  If you have any questions, please feel free to message me on [X](https://twitter.com/virattt)!

## Step 0 - Install dependencies

In [1]:
# !pip install langgraph
# !pip install -U langchain langchain_openai langchainhub

In [2]:
import os

# Fall back to a placeholder OpenAI key only when none is configured;
# an existing OPENAI_API_KEY is left untouched.
if "OPENAI_API_KEY" not in os.environ:
    os.environ["OPENAI_API_KEY"] = "sk-wtv"

## Create agents

In [3]:
from langchain.tools.render import format_tool_to_openai_function
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder


def create_agent(llm, tools, system_message: str):
    """Build a prompt-plus-LLM runnable for one collaborating agent.

    Args:
        llm: Chat model; its ``bind_functions`` method exposes ``tools`` to it.
        tools: LangChain tools the agent may call; their ``name`` attributes
            are listed in the system prompt.
        system_message: Agent-specific instructions appended to the shared
            system prompt.

    Returns:
        The runnable ``prompt | llm.bind_functions(...)``; its input must provide
        a ``messages`` list for the ``MessagesPlaceholder``.

    NOTE(review): the prompt starts with a "system" message. The
    Mistral-7B-Instruct model served locally appears to reject system messages
    and non-alternating user/assistant turns, which is the likely cause of the
    400 "Conversation roles must alternate" error at the end of this notebook.
    """
    openai_functions = list(map(format_tool_to_openai_function, tools))
    tool_names = ", ".join(t.name for t in tools)

    system_template = (
        "You are a helpful AI assistant, collaborating with other assistants."
        " Use the provided tools to progress towards answering the question."
        " If you are unable to fully answer, that's OK, another assistant with different tools "
        " will help where you left off. Execute what you can to make progress."
        " If you or any of the other assistants have the final answer or deliverable,"
        " prefix your response with FINAL ANSWER so the team knows to stop."
        " You have access to the following tools: {tool_names}.\n{system_message}"
    )
    # Fill both template slots up front so only `messages` remains as an input.
    prompt = ChatPromptTemplate.from_messages(
        [("system", system_template), MessagesPlaceholder(variable_name="messages")]
    ).partial(system_message=system_message, tool_names=tool_names)

    return prompt | llm.bind_functions(openai_functions)


## Define LLM

In [4]:
from langchain_openai.chat_models import ChatOpenAI


# Drive every agent with Mistral-7B-Instruct served behind a local
# OpenAI-compatible endpoint (e.g. a vLLM server on port 8000). The API key is
# passed explicitly as a dummy value, so the OPENAI_API_KEY environment variable
# set earlier is not the key this client sends.
llm = ChatOpenAI(
    model_name="mistralai/Mistral-7B-Instruct-v0.2",
    openai_api_base="http://localhost:8000/v1",
    openai_api_key="wtv",
    # Stop generation at the first blank line.
    # NOTE(review): this may also truncate multi-paragraph answers.
    model_kwargs={"stop": ["\n\n"]},
)

## Create Graph

##### States

In [5]:
from typing import TypedDict, Annotated, Sequence
import operator
from langchain_core.messages import BaseMessage


class AgentState(TypedDict):
    """Shared state passed between the graph's nodes.

    ``messages`` is the conversation so far. Its ``operator.add`` annotation is
    the LangGraph reducer, so each node's returned messages are concatenated onto
    the existing list rather than replacing it.

    ``sender`` is the name of the agent node that produced the latest update.
    The ``call_tool`` node does not set it, so the graph reads it to route tool
    results back to the agent that requested the tool.
    """

    messages: Annotated[Sequence[BaseMessage], operator.add]
    # Must be declared so LangGraph keeps the "sender" value that the agent nodes
    # return; without it the call_tool routing has nothing to read.
    sender: str

##### Nodes

In [6]:
from functools import partial

from langchain_core.messages import (
    FunctionMessage,
    HumanMessage,
)

from agents.tools.plant import sense

# Helper function to create a node for a given agent
def agent_node(state, agent, name):
    """Invoke ``agent`` on the graph state and turn its reply into a state update.

    Args:
        state: Current graph state, passed straight to ``agent.invoke``.
        agent: Runnable built by ``create_agent``.
        name: Label attached to the reply and recorded as ``sender``.

    Returns:
        ``{"messages": [reply], "sender": name}``. The reply is appended to the
        shared message list, and ``sender`` lets the tool node route back here.
    """
    reply = agent.invoke(state)
    if not isinstance(reply, FunctionMessage):
        # Re-wrap the model output as a HumanMessage attributed to this agent.
        # Its content and additional_kwargs (including any function_call the
        # router inspects) are kept.
        # NOTE(review): this produces consecutive "user" turns, which strict chat
        # templates (e.g. Mistral-Instruct's) may reject.
        reply = HumanMessage(**reply.dict(exclude={"type", "name"}), name=name)
    update = {"messages": [reply]}
    # The workflow is fixed, so recording who spoke last is enough to know
    # where to return after a tool call.
    update["sender"] = name
    return update

def sensitive_node(state):
    """Unfinished node that reads the plant's sensors.

    The sensor readings are not yet turned into a message, so this node
    currently adds no messages to the state. The next cell rebinds the name
    ``sensitive_node`` to ``partial(agent_node, ...)``, and that binding is the
    one the graph uses.

    Args:
        state: Current graph state (not used yet).

    Returns:
        ``{"messages": []}``.

    NOTE(review): ``sense`` is also handed to ``create_agent`` as a tool. If it
    is a LangChain tool object, calling it may need ``sense.invoke(...)`` rather
    than a bare call. Confirm.
    """
    sense()  # TODO: convert the sensor readings into a message for the state
    return {"messages": []}

In [7]:
from agents.tools import search_online, read_news
from agents.tools.persona import summarize_activity
from agents.tools.plant import sense


# Plant agent: simulates the plant's mood from its sensor readings.
sensitive_agent = create_agent(
    llm=llm,
    tools=[sense],
    system_message="You must simulate a plant's mood based on it's sensor readings",
)
# NOTE: this rebinds the `sensitive_node` defined in the previous cell; the graph
# cell below uses this agent-based version.
sensitive_node = partial(agent_node, agent=sensitive_agent, name="Sensitive")

# Research agent: equipped with the online-search and news-reading tools.
researcher_agent = create_agent(
    llm=llm,
    tools=[search_online, read_news],
    system_message="You must do your best to look for truth.",
)
research_node = partial(agent_node, agent=researcher_agent, name="Researcher")

# Persona agent: plays a bored, polite, curious person.
ego_agent = create_agent(
    llm=llm,
    tools=[summarize_activity],
    system_message="You must act as a bored, polite, curious person would.",
)
ego_node = partial(agent_node, agent=ego_agent, name="Ego")

## Graph Tools

In [8]:
import json

from langgraph.prebuilt.tool_executor import ToolExecutor, ToolInvocation

from agents.tools import TOOLS


# Dispatches the ToolInvocation built in tool_node to the tool in TOOLS with the
# requested name.
# NOTE(review): presumably TOOLS covers every tool bound to the agents above
# (sense, search_online, read_news, summarize_activity). Confirm.
tool_executor = ToolExecutor(TOOLS)


def _parse_function_call(message):
    """Extract ``(tool_name, tool_input)`` from a message's OpenAI-style function call.

    The call's ``arguments`` field is a JSON string. A missing or empty value
    (which models may emit for tools without parameters) is treated as ``{}``.
    A lone ``__arg1`` argument is unwrapped so that single-input tools receive
    the bare value.
    """
    call = message.additional_kwargs["function_call"]
    tool_input = json.loads(call.get("arguments") or "{}")
    # We can pass single-arg inputs by value
    if isinstance(tool_input, dict) and len(tool_input) == 1 and "__arg1" in tool_input:
        tool_input = tool_input["__arg1"]
    return call["name"], tool_input


def tool_node(state):
    """Run the tool requested by the latest message and report its result.

    Args:
        state: Graph state whose last message carries a ``function_call`` in
            ``additional_kwargs``. ``router`` only sends such states here.

    Returns:
        ``{"messages": [FunctionMessage]}`` whose content is
        ``"<tool> response: <result>"``. It is returned as a list because the
        ``messages`` reducer appends it to the existing conversation.
    """
    tool_name, tool_input = _parse_function_call(state["messages"][-1])
    action = ToolInvocation(tool=tool_name, tool_input=tool_input)
    # We call the tool_executor and get back a response
    response = tool_executor.invoke(action)
    function_message = FunctionMessage(
        content=f"{tool_name} response: {str(response)}", name=action.tool
    )
    return {"messages": [function_message]}

## Define Edge logic

In [9]:
# Any agent can decide to end
def router(state):
    """Choose the outgoing edge after an agent node runs.

    Returns:
        ``"call_tool"`` if the latest message carries a ``function_call``;
        otherwise ``"end"`` if its content contains ``FINAL ANSWER``;
        otherwise ``"continue"``. A pending tool call takes priority.
    """
    latest = state["messages"][-1]
    wants_tool = "function_call" in latest.additional_kwargs
    if wants_tool:
        return "call_tool"
    return "end" if "FINAL ANSWER" in latest.content else "continue"

## Assemble and compile the graph

In [10]:
from enum import Enum

from langgraph.graph import StateGraph, END

# Define a new graph over the shared AgentState
workflow = StateGraph(AgentState)


class Node(Enum):
    """Keys of the agent nodes in the graph.

    Each value is the lower-cased ``name`` that its agent node records as
    ``sender`` ("Sensitive", "Researcher", "Ego"). The ``call_tool`` routing
    below relies on that correspondence.
    """

    SENSITIVE = 'sensitive'
    RESEARCHER = 'researcher'
    EGO = 'ego'


workflow.add_node(Node.SENSITIVE.value, sensitive_node)
workflow.add_node(Node.RESEARCHER.value, research_node)
workflow.add_node(Node.EGO.value, ego_node)
workflow.add_node('call_tool', tool_node)

# This means that this node is the first one called
workflow.set_entry_point(Node.SENSITIVE.value)

# After each agent: run a tool, stop on FINAL ANSWER, or hand off to the next agent.
# NOTE(review): no edge leads *to* the researcher, so it is unreachable from the
# entry point. Confirm whether that is intended.
workflow.add_conditional_edges(
    Node.RESEARCHER.value,
    router,
    {"continue": Node.EGO.value, "call_tool": "call_tool", "end": END},
)
workflow.add_conditional_edges(
    Node.SENSITIVE.value,
    router,
    {"continue": Node.EGO.value, "call_tool": "call_tool", "end": END},
)
workflow.add_conditional_edges(
    Node.EGO.value,
    router,
    {"continue": END, "call_tool": "call_tool", "end": END},
)
workflow.add_conditional_edges(
    "call_tool",
    # Each agent node records its name in 'sender' and the tool node does not,
    # so this routes back to the agent that invoked the tool. The agent names
    # are capitalised ("Sensitive", "Researcher", "Ego") while the node keys are
    # lower-case, so the name is normalised before the lookup.
    lambda state: state["sender"].lower(),
    {node.value: node.value for node in Node},
)

# Finally, compile it into a LangChain Runnable, so it can be used like any
# other runnable.
app = workflow.compile()

## Run!

In [13]:
from langchain_core.messages import HumanMessage, SystemMessage

# Seed the conversation with a system prompt and the user's question, then run
# the graph with a generous recursion limit for agent/tool hand-offs.
# NOTE(review): Mistral-7B-Instruct's chat template appears to reject a leading
# system message and non-alternating roles. That is likely why this call fails
# with the 400 "Conversation roles must alternate" error shown below.
state = dict(
    messages=[
        SystemMessage(content="You are a helpful assistant."),
        HumanMessage(
            content="Tell me a interesting fact about the progressive rock band King Crimson"
        ),
    ]
)
app.invoke(state, {"recursion_limit": 150})

BadRequestError: Error code: 400 - {'object': 'error', 'message': 'Conversation roles must alternate user/assistant/user/assistant/...', 'type': 'invalid_request_error', 'param': None, 'code': None}