<a href="https://colab.research.google.com/github/zohebk/youtube_video_code/blob/master/hedge_fund_agent_team_v1_0.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

This notebook provides a tutorial on how to use multi-agents with LangGraph.

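Specifically, we use the **supervisor** pattern, where 1 supervisor agent (the portfolio manager) routes work to 3 analyst agents:
1. fundamental analyst (does financial search over income statements, balance sheets, and cash flow statements)
2. technical analyst (does financial search over stock prices)
3. sentiment analyst (does web searching for news, plus options chain and insider trade lookups)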

This code will be a part of an evolving series.

If you have any questions, please message me on X at [virattt](https://twitter.com/virattt).

# Setup

In [None]:
%%capture --no-stderr
%pip install -U langgraph langchain langchain_openai langchain_experimental langsmith pandas

In [None]:
import getpass
import os


def _set_if_undefined(var: str):
    """Prompt for environment variable ``var`` unless it already has a non-empty value.

    The answer is read with ``getpass.getpass`` and stored in ``os.environ``
    under the same name.
    """
    if os.environ.get(var):
        # Already configured; nothing to ask for.
        return
    os.environ[var] = getpass.getpass(f"Please provide your {var}")


# Credentials the notebook needs; each one is only prompted for when it is unset.
for _var in (
    "OPENAI_API_KEY",              # Get from https://platform.openai.com
    "FINANCIAL_DATASETS_API_KEY",  # Get from https://financialdatasets.ai
    "TAVILY_API_KEY",              # Get from https://tavily.com
):
    _set_if_undefined(_var)

Please provide your OPENAI_API_KEY··········
Please provide your FINANCIAL_DATASETS_API_KEY··········
Please provide your TAVILY_API_KEY··········


# Define agent tools

In [None]:
from langchain_core.tools import tool
from typing import List, Dict, Optional

from typing import List, Dict, Optional, Union
import requests
import os
from typing import Dict, Union
from pydantic import BaseModel, Field
import requests
from langchain_core.tools import tool

# Argument schema for the `get_income_statements` tool (passed as its `args_schema`).
# The Field descriptions are runtime text, so documentation is kept in `#` comments.
class GetIncomeStatementsInput(BaseModel):
    # Required: stock ticker symbol to look up.
    ticker: str = Field(..., description="The ticker of the stock.")
    # Reporting period; defaults to "ttm".
    period: str = Field(default="ttm", description="The period of the income statements. Valid values are 'ttm', 'quarterly' or 'annual'.")
    # Upper bound on the number of statements requested; defaults to 10.
    limit: int = Field(default=10, description="The maximum number of income statements to return. Default is 10.")

# Tool: fetch income statements from the Financial Datasets API.
#
# Args:
#     ticker: Stock ticker symbol to look up.
#     period: Reporting period ('ttm', 'quarterly' or 'annual').
#     limit: Maximum number of statements to request.
# Returns:
#     The decoded JSON body of the response, or a fallback dict with an empty
#     "income_statements" list and an "error" message if the request or the
#     JSON decoding fails.
# Raises:
#     ValueError: If FINANCIAL_DATASETS_API_KEY is not set.
#
# NOTE(review): `@tool` is expected to use the function docstring as the tool
# description shown to the model, so the docstring text is left unchanged and
# the documentation lives in these comments.
@tool("get_income_statements", args_schema=GetIncomeStatementsInput, return_direct=True)
def get_income_statements(ticker: str, period: str = "ttm", limit: int = 10) -> Union[Dict, str]:
    """
    Get income statements for a ticker with specified period and limit.
    """
    api_key = os.environ.get("FINANCIAL_DATASETS_API_KEY")
    if not api_key:
        raise ValueError("Missing FINANCIAL_DATASETS_API_KEY.")

    url = 'https://api.financialdatasets.ai/financials/income-statements'
    # Let requests build the query string so the values are URL-encoded
    # (the arguments come from the model and are not guaranteed to be URL-safe).
    params = {'ticker': ticker, 'period': period, 'limit': limit}

    try:
        # Bounded wait so a stalled connection cannot hang the whole agent run.
        response = requests.get(url, headers={'X-API-Key': api_key}, params=params, timeout=30)
        return response.json()
    except Exception as e:
        # Deliberate catch-all at the tool boundary: report the failure as data
        # instead of raising into the agent loop.
        return {"ticker": ticker, "income_statements": [], "error": str(e)}

# Argument schema for the `get_balance_sheets` tool (passed as its `args_schema`).
# The Field descriptions are runtime text, so documentation is kept in `#` comments.
class GetBalanceSheetsInput(BaseModel):
    # Required: stock ticker symbol to look up.
    ticker: str = Field(..., description="The ticker of the stock.")
    # Reporting period; defaults to "ttm".
    period: str = Field(default="ttm", description="The period of the balance sheets. Valid values are 'ttm', 'quarterly' or 'annual'.")
    # Upper bound on the number of balance sheets requested; defaults to 10.
    limit: int = Field(default=10, description="The maximum number of balance sheets to return. Default is 10.")

# Tool: fetch balance sheets from the Financial Datasets API.
#
# Args:
#     ticker: Stock ticker symbol to look up.
#     period: Reporting period ('ttm', 'quarterly' or 'annual').
#     limit: Maximum number of balance sheets to request.
# Returns:
#     The decoded JSON body of the response, or a fallback dict with an empty
#     "balance_sheets" list and an "error" message if the request or the JSON
#     decoding fails.
# Raises:
#     ValueError: If FINANCIAL_DATASETS_API_KEY is not set.
#
# NOTE(review): `@tool` is expected to use the function docstring as the tool
# description shown to the model, so the docstring text is left unchanged and
# the documentation lives in these comments.
@tool("get_balance_sheets", args_schema=GetBalanceSheetsInput, return_direct=True)
def get_balance_sheets(ticker: str, period: str = "ttm", limit: int = 10) -> Union[Dict, str]:
    """
    Get balance sheets for a ticker with specified period and limit.
    """
    api_key = os.environ.get("FINANCIAL_DATASETS_API_KEY")
    if not api_key:
        raise ValueError("Missing FINANCIAL_DATASETS_API_KEY.")

    url = 'https://api.financialdatasets.ai/financials/balance-sheets'
    # Let requests build the query string so the values are URL-encoded
    # (the arguments come from the model and are not guaranteed to be URL-safe).
    params = {'ticker': ticker, 'period': period, 'limit': limit}

    try:
        # Bounded wait so a stalled connection cannot hang the whole agent run.
        response = requests.get(url, headers={'X-API-Key': api_key}, params=params, timeout=30)
        return response.json()
    except Exception as e:
        # Deliberate catch-all at the tool boundary: report the failure as data
        # instead of raising into the agent loop.
        return {"ticker": ticker, "balance_sheets": [], "error": str(e)}

# Argument schema for the `get_cash_flow_statements` tool (passed as its `args_schema`).
# The Field descriptions are runtime text, so documentation is kept in `#` comments.
class GetCashFlowStatementsInput(BaseModel):
    # Required: stock ticker symbol to look up.
    ticker: str = Field(..., description="The ticker of the stock.")
    # Reporting period; defaults to "ttm".
    period: str = Field(default="ttm", description="The period of the cash flow statements. Valid values are 'ttm', 'quarterly' or 'annual'.")
    # Upper bound on the number of statements requested; defaults to 10.
    limit: int = Field(default=10, description="The maximum number of cash flow statements to return. Default is 10.")

# Tool: fetch cash flow statements from the Financial Datasets API.
#
# Args:
#     ticker: Stock ticker symbol to look up.
#     period: Reporting period ('ttm', 'quarterly' or 'annual').
#     limit: Maximum number of statements to request.
# Returns:
#     The decoded JSON body of the response, or a fallback dict with an empty
#     "cash_flow_statements" list and an "error" message if the request or the
#     JSON decoding fails.
# Raises:
#     ValueError: If FINANCIAL_DATASETS_API_KEY is not set.
#
# NOTE(review): `@tool` is expected to use the function docstring as the tool
# description shown to the model, so the docstring text is left unchanged and
# the documentation lives in these comments.
@tool("get_cash_flow_statements", args_schema=GetCashFlowStatementsInput, return_direct=True)
def get_cash_flow_statements(ticker: str, period: str = "ttm", limit: int = 10) -> Union[Dict, str]:
    """
    Get cash flow statements for a ticker with specified period and limit.
    """
    api_key = os.environ.get("FINANCIAL_DATASETS_API_KEY")
    if not api_key:
        raise ValueError("Missing FINANCIAL_DATASETS_API_KEY.")

    url = 'https://api.financialdatasets.ai/financials/cash-flow-statements'
    # Let requests build the query string so the values are URL-encoded
    # (the arguments come from the model and are not guaranteed to be URL-safe).
    params = {'ticker': ticker, 'period': period, 'limit': limit}

    try:
        # Bounded wait so a stalled connection cannot hang the whole agent run.
        response = requests.get(url, headers={'X-API-Key': api_key}, params=params, timeout=30)
        return response.json()
    except Exception as e:
        # Deliberate catch-all at the tool boundary: report the failure as data
        # instead of raising into the agent loop.
        return {"ticker": ticker, "cash_flow_statements": [], "error": str(e)}

# Argument schema for the `get_stock_prices` tool (passed as its `args_schema`).
# The Field descriptions are runtime text that describes each argument to the
# model, so documentation is kept in `#` comments rather than a class docstring.
class GetPricesInput(BaseModel):
    # Required: stock ticker symbol to look up.
    ticker: str = Field(..., description="The ticker of the stock.")
    # Required: inclusive window bounds, as a date string or millisecond timestamp.
    start_date: str = Field(..., description="The start of the price time window. Either a date with the format YYYY-MM-DD or a millisecond timestamp.")
    end_date: str = Field(..., description="The end of the aggregate time window. Either a date with the format YYYY-MM-DD or a millisecond timestamp.")
    # Bar size unit; defaults to "day". (Fixed the unbalanced quote around 'second'.)
    interval: str = Field(default="day", description="The time interval of the prices. Valid values are 'second', 'minute', 'day', 'week', 'month', 'quarter', 'year'.")
    # Number of `interval` units per bar; defaults to 1.
    interval_multiplier: int = Field(default=1, description="The multiplier for the interval. For example, if interval is 'day' and interval_multiplier is 1, the prices will be daily. If interval is 'minute' and interval_multiplier is 5, the prices will be every 5 minutes.")
    # Upper bound on the number of price rows requested; defaults to 5000.
    limit: int = Field(default=5000, description="The maximum number of prices to return. The default is 5000 and the maximum is 50000.")

# Tool: fetch historical prices from the Financial Datasets API.
#
# Args:
#     ticker: Stock ticker symbol to look up.
#     start_date: Start of the window (YYYY-MM-DD or millisecond timestamp).
#     end_date: End of the window (YYYY-MM-DD or millisecond timestamp).
#     interval: Bar size unit. Defaults to "day", matching the default that
#         GetPricesInput already declares, so direct calls and schema-validated
#         calls behave the same.
#     interval_multiplier: Number of `interval` units per bar.
#     limit: Maximum number of price rows to request.
# Returns:
#     The decoded JSON body of the response, or a fallback dict with an empty
#     "prices" list and an "error" message if the request or the JSON decoding
#     fails.
# Raises:
#     ValueError: If FINANCIAL_DATASETS_API_KEY is not set.
#
# NOTE(review): `@tool` is expected to use the function docstring as the tool
# description shown to the model, so the docstring text is left unchanged and
# the documentation lives in these comments.
@tool("get_stock_prices", args_schema=GetPricesInput, return_direct=True)
def get_stock_prices(ticker: str, start_date: str, end_date: str, interval: str = "day", interval_multiplier: int = 1, limit: int = 5000) -> Union[Dict, str]:
    """
    Get prices for a ticker over a given date range and interval.
    """
    api_key = os.environ.get("FINANCIAL_DATASETS_API_KEY")
    if not api_key:
        raise ValueError("Missing FINANCIAL_DATASETS_API_KEY.")

    url = "https://api.financialdatasets.ai/prices"
    # Let requests build the query string so the values are URL-encoded
    # (the arguments come from the model and are not guaranteed to be URL-safe).
    params = {
        'ticker': ticker,
        'start_date': start_date,
        'end_date': end_date,
        'interval': interval,
        'interval_multiplier': interval_multiplier,
        'limit': limit,
    }

    try:
        # Bounded wait so a stalled connection cannot hang the whole agent run.
        response = requests.get(url, headers={'X-API-Key': api_key}, params=params, timeout=30)
        return response.json()
    except Exception as e:
        # Deliberate catch-all at the tool boundary: report the failure as data
        # instead of raising into the agent loop.
        return {"ticker": ticker, "prices": [], "error": str(e)}

# Argument schema for the `get_options_chain` tool (passed as its `args_schema`).
# The Field descriptions are runtime text, so documentation is kept in `#` comments.
class GetOptionsChainInput(BaseModel):
    # Required: stock ticker symbol to look up.
    ticker: str = Field(..., description="The ticker of the stock.")
    # Upper bound on the number of option contracts requested; defaults to 10.
    limit: int = Field(default=10, description="The maximum number of options to return. Default is 10.")
    # Optional filters; None means "do not filter on this attribute".
    strike_price: Optional[float] = Field(default=None, description="Optional filter for specific strike price.")
    option_type: Optional[str] = Field(default=None, description="Optional filter for option type. Valid values are 'call' or 'put'.")

# Tool: fetch options chain data from the Financial Datasets API.
#
# Args:
#     ticker: Stock ticker symbol to look up.
#     limit: Maximum number of option contracts to request.
#     strike_price: Optional strike filter; omitted from the request when None.
#     option_type: Optional 'call'/'put' filter; omitted from the request when None.
# Returns:
#     The decoded JSON body of the response, or a fallback dict with an empty
#     "options_chain" list and an "error" message if the request or the JSON
#     decoding fails.
# Raises:
#     ValueError: If FINANCIAL_DATASETS_API_KEY is not set.
#
# NOTE(review): `@tool` is expected to use the function docstring as the tool
# description shown to the model, so the docstring text is left unchanged and
# the documentation lives in these comments.
@tool("get_options_chain", args_schema=GetOptionsChainInput, return_direct=True)
def get_options_chain(
    ticker: str,
    limit: int = 10,
    strike_price: Optional[float] = None,
    option_type: Optional[str] = None
) -> Union[Dict, str]:
    """
    Get options chain data for a ticker with optional filters for strike price and option type.
    """
    api_key = os.environ.get("FINANCIAL_DATASETS_API_KEY")
    if not api_key:
        raise ValueError("Missing FINANCIAL_DATASETS_API_KEY.")

    params = {
        'ticker': ticker,
        'limit': limit
    }

    # Only send the filters the caller actually supplied.
    if strike_price is not None:
        params['strike_price'] = strike_price
    if option_type is not None:
        params['option_type'] = option_type

    url = 'https://api.financialdatasets.ai/options/chain'

    try:
        # Bounded wait so a stalled connection cannot hang the whole agent run.
        response = requests.get(url, headers={'X-API-Key': api_key}, params=params, timeout=30)
        return response.json()
    except Exception as e:
        # Deliberate catch-all at the tool boundary: report the failure as data
        # instead of raising into the agent loop.
        return {"ticker": ticker, "options_chain": [], "error": str(e)}

# Argument schema for the `get_insider_trades` tool (passed as its `args_schema`).
# The Field descriptions are runtime text, so documentation is kept in `#` comments.
class GetInsiderTradesInput(BaseModel):
    # Required: stock ticker symbol to look up.
    ticker: str = Field(..., description="The ticker of the stock.")
    # Upper bound on the number of transactions requested; defaults to 10.
    limit: int = Field(default=10, description="The maximum number of insider transactions to return. Default is 10.")

# Tool: fetch insider transactions from the Financial Datasets API.
#
# Args:
#     ticker: Stock ticker symbol to look up.
#     limit: Maximum number of transactions to request.
# Returns:
#     The decoded JSON body of the response, or a fallback dict with an empty
#     "insider_transactions" list and an "error" message if the request or the
#     JSON decoding fails.
# Raises:
#     ValueError: If FINANCIAL_DATASETS_API_KEY is not set.
#
# NOTE(review): `@tool` is expected to use the function docstring as the tool
# description shown to the model, so the docstring text is left unchanged and
# the documentation lives in these comments.
@tool("get_insider_trades", args_schema=GetInsiderTradesInput, return_direct=True)
def get_insider_trades(ticker: str, limit: int = 10) -> Union[Dict, str]:
    """
    Get insider trading transactions for a ticker.
    """
    api_key = os.environ.get("FINANCIAL_DATASETS_API_KEY")
    if not api_key:
        raise ValueError("Missing FINANCIAL_DATASETS_API_KEY.")

    url = 'https://api.financialdatasets.ai/insider-transactions'
    # Let requests build the query string so the values are URL-encoded
    # (the arguments come from the model and are not guaranteed to be URL-safe).
    params = {'ticker': ticker, 'limit': limit}

    try:
        # Bounded wait so a stalled connection cannot hang the whole agent run.
        response = requests.get(url, headers={'X-API-Key': api_key}, params=params, timeout=30)
        return response.json()
    except Exception as e:
        # Deliberate catch-all at the tool boundary: report the failure as data
        # instead of raising into the agent loop.
        return {"ticker": ticker, "insider_transactions": [], "error": str(e)}

In [None]:
# News tool
from typing import Annotated

from langchain_community.tools.tavily_search import TavilySearchResults

# Tavily search tool used for news lookups, configured with max_results=5.
# NOTE(review): presumably reads TAVILY_API_KEY from the environment (prompted
# for in the setup cell) — confirm against the langchain_community docs.
get_news_tool = TavilySearchResults(max_results=5)

In [None]:
# Group tools by analyst
# Each list is handed to `create_react_agent` for the analyst of the same name.
# Financial-statement tools for the fundamental analyst.
fundamental_tools = [get_income_statements, get_balance_sheets, get_cash_flow_statements]
# Price-history tool for the technical analyst.
technical_tools = [get_stock_prices]
# Options, insider-trade and news tools for the sentiment analyst.
sentiment_tools = [get_options_chain, get_insider_trades, get_news_tool]

# Helper functions

In [None]:
from langchain_core.messages import HumanMessage

def agent_node(state, agent, name):
    """Run ``agent`` on ``state`` and return its last message labelled with ``name``.

    Only the content of the final message in the agent's output is kept; it is
    wrapped in a ``HumanMessage`` whose ``name`` identifies the producing agent.
    """
    agent_output = agent.invoke(state)
    final_message = agent_output["messages"][-1]
    report = HumanMessage(content=final_message.content, name=name)
    return {"messages": [report]}

# Create LangGraph

In [None]:
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_openai import ChatOpenAI
from pydantic import BaseModel
from typing import Literal, Sequence
from typing_extensions import TypedDict
import functools
import operator
from langchain_core.messages import BaseMessage, HumanMessage
from langgraph.graph import END, StateGraph, START
from langgraph.prebuilt import create_react_agent

# Define team members
# These names are used as graph node names and as the routing options
# offered to the supervisor.
members = ["fundamental_analyst", "technical_analyst", "sentiment_analyst"]

# Structured-output schema for the supervisor's routing decision.
# Documented with `#` comments only, so the generated schema is not altered.
class RouteResponse(BaseModel):
    # Either one of the analyst names in `members`, or "FINISH" to stop routing.
    next: Literal["FINISH", "fundamental_analyst", "technical_analyst", "sentiment_analyst"]

# Supervisor prompt (routing)
# `{members}` is a template placeholder that is filled in when the routing
# prompt is built (via `.partial(...)`), not a Python f-string field.
system_prompt = (
    "You are a portfolio manager supervising a hedge fund team with the following analysts:"
    " {members}. Each analyst has specific expertise:"
    "\n- fundamental_analyst: Analyzes financial statements and company health"
    "\n- technical_analyst: Analyzes price patterns and market trends"
    "\n- sentiment_analyst: Analyzes insider trading activity, options flow, and the news"
    "\nGiven the user request, determine which analyst should act next."
    " Each analyst will analyze one ticker and provide their findings."
    " When all necessary analysis is complete, respond with FINISH."
)

def agent_node(state, agent, name):
    """Invoke ``agent`` with ``state`` and repackage its final message.

    This re-declares the helper from the "Helper functions" cell with the same
    behavior: the content of the last message in the agent's result is wrapped
    in a ``HumanMessage`` tagged with ``name``.
    """
    outcome = agent.invoke(state)
    last_content = outcome["messages"][-1].content
    tagged = HumanMessage(content=last_content, name=name)
    return {"messages": [tagged]}

# Create the routing prompt template
# Layout: supervisor instructions, then the conversation so far, then a final
# system message asking for the next actor.
routing_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system_prompt),
        # Replaced at invoke time by the "messages" entry of the input state.
        MessagesPlaceholder(variable_name="messages"),
        (
            "system",
            "Given the conversation above, who should act next?"
            " Or should we FINISH? Select one of: {options}",
        ),
    ]
# Pre-fill the two template variables that never change between calls.
).partial(options=str(["FINISH"] + members), members=", ".join(members))

# Create the summary prompt template
# Layout: portfolio-manager instructions, then the conversation so far
# (including the analyst reports), then a closing request for the summary.
summary_prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "You are a portfolio manager responsible for synthesizing analysis from your team of analysts. "
            "Review all the analysts' reports and provide a comprehensive summary including:\n"
            "1. Key financial metrics and their implications\n"
            "2. Technical analysis insights\n"
            "3. Market sentiment and news impact\n"
            "4. Overall investment recommendation\n"
            "Make sure to highlight any discrepancies or conflicting signals between different analyses."
        ),
        # Replaced at invoke time by the "messages" entry of the input state.
        MessagesPlaceholder(variable_name="messages"),
        (
            "human",
            "Based on all the analyst reports above, provide a comprehensive summary and investment recommendation."
        ),
    ]
)

# Initialize LLM
# One chat model instance is shared by the supervisor, the final summary and
# every analyst agent created below.
llm = ChatOpenAI(model="gpt-4")

def supervisor_agent(state):
    """Ask the LLM which node should act next and return that routing decision.

    Args:
        state: The current graph state; its "messages" entry is fed to the
            routing prompt.

    Returns:
        A partial state update ``{"next": <node name>}``. The model's
        ``"FINISH"`` answer is translated to ``"final_summary"``, the node that
        writes the closing report; any other answer is passed through.
    """
    supervisor_chain = routing_prompt | llm.with_structured_output(RouteResponse)
    result = supervisor_chain.invoke(state)

    # Only "next" is returned. The "messages" channel in AgentState is combined
    # with operator.add, so echoing state["messages"] back here would append a
    # second copy of the whole conversation before final_summary runs. The
    # summary node already receives the full state without that.
    next_node = "final_summary" if result.next == "FINISH" else result.next
    return {"next": next_node}

def final_summary_agent(state):
    """Synthesize the analyst reports in ``state`` into one closing message.

    Returns a state update holding a single ``HumanMessage`` named
    "portfolio_manager" and ``next`` set to "END".
    """
    summary = (summary_prompt | llm).invoke(state)
    closing_message = HumanMessage(content=summary.content, name="portfolio_manager")
    return {"messages": [closing_message], "next": "END"}

# The agent state
class AgentState(TypedDict):
    """State shared by every node of the graph."""

    # Conversation history. The operator.add annotation is the reducer, so the
    # messages a node returns are appended to the existing sequence.
    messages: Annotated[Sequence[BaseMessage], operator.add]
    # Routing target written by the supervisor and read by the conditional edge.
    next: str

# Create the workflow
workflow = StateGraph(AgentState)

# Create the analysts
# Each analyst is a prebuilt ReAct agent over its own tool group; functools.partial
# binds the agent and its display name so the node only takes `state`.
fundamental_analyst = create_react_agent(llm, tools=fundamental_tools)
fundamental_analyst_node = functools.partial(agent_node, agent=fundamental_analyst, name="fundamental_analyst")

technical_analyst = create_react_agent(llm, tools=technical_tools)
technical_analyst_node = functools.partial(agent_node, agent=technical_analyst, name="technical_analyst")

sentiment_analyst = create_react_agent(llm, tools=sentiment_tools)
sentiment_analyst_node = functools.partial(agent_node, agent=sentiment_analyst, name="sentiment_analyst")

# Add nodes
# Node names must match the entries of `members` and the values the supervisor
# writes into state["next"].
workflow.add_node("fundamental_analyst", fundamental_analyst_node)
workflow.add_node("technical_analyst", technical_analyst_node)
workflow.add_node("sentiment_analyst", sentiment_analyst_node)
workflow.add_node("supervisor", supervisor_agent)
workflow.add_node("final_summary", final_summary_agent)

# Connect edges
# Every analyst hands control back to the supervisor after it runs.
for member in members:
    workflow.add_edge(member, "supervisor")

# Add conditional edges from supervisor
# Maps each possible state["next"] value to the node of the same name.
conditional_map = {k: k for k in members}
conditional_map["final_summary"] = "final_summary"

workflow.add_conditional_edges(
    "supervisor",
    # The routing key is whatever the supervisor stored under "next".
    lambda x: x["next"],
    conditional_map
)

# Add entry point and final edges
# The run starts at the supervisor and terminates after the final summary.
workflow.add_edge(START, "supervisor")
workflow.add_edge("final_summary", END)

# Compile the graph
graph = workflow.compile()

# Run the Hedge Fund team

In [None]:
from typing import Dict, Any
import json
import re
from langchain_core.messages import HumanMessage
from rich.console import Console
from rich.panel import Panel
from rich.text import Text
from rich.rule import Rule

# Shared rich console used by all the printing helpers below.
console = Console()

def format_bold_text(content: str) -> Text:
    """Build a rich ``Text`` in which every ``**marked**`` span is rendered bold.

    The ``**`` markers themselves are dropped from the output.
    """
    # Splitting on a pattern with one capture group interleaves the pieces:
    # even indexes are plain text, odd indexes are the captured bold spans.
    pieces = re.split(r'\*\*(.*?)\*\*', content)

    rendered = Text()
    for index, piece in enumerate(pieces):
        if index % 2:
            rendered.append(piece, style="bold")
        else:
            rendered.append(piece)
    return rendered

def format_message_content(content: str) -> Union[str, Text]:
    """Format message content for display.

    Args:
        content: The raw message content.

    Returns:
        ``content`` re-serialized as JSON with 2-space indentation when it
        parses as JSON; otherwise a rich ``Text`` with bold spans when it
        contains ``**`` markers; otherwise ``content`` unchanged.
    """
    try:
        # Try to parse as JSON for prettier formatting
        data = json.loads(content)
    except (TypeError, ValueError):
        # ValueError covers json.JSONDecodeError (not JSON); TypeError covers
        # content that is not str/bytes. A bare `except:` would also swallow
        # KeyboardInterrupt and SystemExit, so only these two are caught.
        if '**' in content:
            return format_bold_text(content)
        return content
    return json.dumps(data, indent=2)

def format_agent_message(message: HumanMessage) -> Union[str, Text]:
    """Return the display form of ``message`` by formatting its ``content``."""
    raw_content = message.content
    return format_message_content(raw_content)

def get_agent_title(agent: str, message: HumanMessage) -> str:
    """Return a human-readable panel title for an agent's message.

    Args:
        agent: The node name, used as the fallback title source.
        message: The message being displayed; its ``name`` is preferred when
            it is usable.

    Returns:
        ``message.name`` with underscores replaced by spaces and title-cased,
        or the same transformation of ``agent`` when the message has no name,
        an empty name, or a name that is not string-like.
    """
    base_title = agent.replace('_', ' ').title()

    name = getattr(message, 'name', None)
    if not name:
        # Missing, None or empty name: an empty title would render a blank
        # panel heading, so fall back to the node name.
        return base_title

    try:
        return name.replace('_', ' ').title()
    except (AttributeError, TypeError):
        # Name is not string-like; catch only what that can raise instead of
        # using a bare `except:`.
        return base_title

def print_step(step: Dict[str, Any]) -> None:
    """Render one streamed graph step to the console.

    ``step`` maps a node name to the update that node produced. An update with
    a ``next`` key is shown as a supervision panel; an update with a
    ``messages`` key has its first message shown either as the final
    investment summary (node "final_summary") or as a regular analyst report.
    """
    for agent, data in step.items():
        # Routing decisions
        if 'next' in data:
            next_agent = data['next']
            if next_agent == "final_summary":
                target_label, target_style = "FINAL SUMMARY", "bold yellow"
            elif next_agent == "END":
                target_label, target_style = "END", "bold red"
            else:
                target_label, target_style = f"{next_agent}", "bold green"

            text = Text()
            text.append("Portfolio Manager ", style="bold magenta")
            text.append("assigns next task to ", style="white")
            text.append(target_label, style=target_style)

            supervision_panel = Panel(
                text,
                title="[bold blue]Supervision Step",
                border_style="blue"
            )
            console.print(supervision_panel)

        # Nothing more to show unless the update carries messages
        if 'messages' not in data:
            continue

        message = data['messages'][0]
        formatted_content = format_agent_message(message)

        if agent != "final_summary":
            # Regular analyst reports
            title = get_agent_title(agent, message)
            report_panel = Panel(
                formatted_content,
                title=f"[bold blue]{title} Report",
                border_style="green"
            )
            console.print(report_panel)
            continue

        # Final summary, framed by two yellow rules
        console.print(Rule(style="yellow", title="Portfolio Analysis"))
        summary_panel = Panel(
            formatted_content,
            title="[bold yellow]Investment Summary and Recommendation",
            border_style="yellow",
            padding=(1, 2)
        )
        console.print(summary_panel)
        console.print(Rule(style="yellow"))

def stream_agent_execution(graph, input_data: Dict, config: Dict) -> None:
    """Run ``graph`` in streaming mode and pretty print each step as it arrives.

    Steps containing an ``"__end__"`` key are skipped; every other step is
    rendered with ``print_step`` followed by a blank line.
    """
    console.print("\n[bold blue]Starting Agent Execution...[/bold blue]\n")

    for update in graph.stream(input_data, config):
        if "__end__" in update:
            continue
        print_step(update)
        console.print("\n")

    console.print("[bold blue]Analysis Complete[/bold blue]\n")

In [None]:
# Initial state: a single user question placed in the "messages" channel.
input_data = {
    "messages": [HumanMessage(content="What is the latest news and revenue for AAPL?")]
}
# NOTE(review): recursion_limit presumably caps the number of graph steps so a
# supervisor that never answers FINISH cannot loop forever — confirm in the
# LangGraph docs.
config = {"recursion_limit": 10}
stream_agent_execution(graph, input_data, config)