# Multi-agent network
Adapted from the original [LangGraph multi-agent notebook example](https://github.com/langchain-ai/langgraph/blob/main/docs/docs/tutorials/multi_agent/multi-agent-collaboration.ipynb)

A single agent can usually operate effectively with a handful of tools within a single domain, but even with powerful models like `gpt-4`, it becomes less effective as the number of tools grows.

One way to approach complicated tasks is through a "divide-and-conquer" approach: create a specialized agent for each task or domain and route tasks to the correct "expert". This is an example of a [multi-agent network](https://langchain-ai.github.io/langgraph/concepts/multi_agent/#network) architecture.

This notebook (inspired by the paper [AutoGen: Enabling Next-Gen LLM Applications via Multi-Agent Conversation](https://arxiv.org/abs/2308.08155), by Wu et al.) shows one way to do this using LangGraph.

In particular, this notebook replaces the standard search tool with two researchers: one for qualitative research using Cortex Search over Federal Reserve minutes, and one for quantitative research using Cortex Analyst over structured SEC filings data.


In [None]:
%%capture --no-stderr
# pip install -U langchain_community langchain_openai langchain_experimental matplotlib langgraph pygraphviz

#### This adaptation uses Snowflake Cortex Search and Cortex Analyst in place of the Tavily web search tool from the original notebook.

In [None]:
import os

# OpenAI API key for the LLM; Snowflake credentials are set below
os.environ["OPENAI_API_KEY"] = "sk-proj-..."

os.environ["TRULENS_OTEL_TRACING"] = (
    "1"  # to enable OTEL tracing -> note the Snowsight UI experience for now is limited to PuPr customers, not yet supported for OSS.
)

os.environ["SNOWFLAKE_ACCOUNT"] = "..."
os.environ["SNOWFLAKE_USER"] = "..."
os.environ["SNOWFLAKE_USER_PASSWORD"] = "..."
os.environ["SNOWFLAKE_DATABASE"] = "AGENTS_DB"
os.environ["SNOWFLAKE_SCHEMA"] = "NOTEBOOKS"
os.environ["SNOWFLAKE_ROLE"] = "CORTEX_USER_ROLE"
os.environ["SNOWFLAKE_WAREHOUSE"] = "..."
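Before running the later cells, it can help to confirm that every setting is present. The following is a minimal sketch, assuming the variable names used in the cell above; `missing_settings` is a hypothetical helper written for illustration and is not part of any library.

```python
import os

# Names taken from the configuration cell above.
REQUIRED_SETTINGS = [
    "OPENAI_API_KEY",
    "SNOWFLAKE_ACCOUNT",
    "SNOWFLAKE_USER",
    "SNOWFLAKE_USER_PASSWORD",
    "SNOWFLAKE_DATABASE",
    "SNOWFLAKE_SCHEMA",
    "SNOWFLAKE_ROLE",
    "SNOWFLAKE_WAREHOUSE",
]


def missing_settings(env=None):
    """Return the required settings that are unset or empty (hypothetical helper)."""
    env = os.environ if env is None else env
    return [name for name in REQUIRED_SETTINGS if not env.get(name)]
```

Calling `missing_settings()` with no arguments checks the current process environment. Note that a placeholder value such as `"..."` still counts as set, so this only catches settings that are absent or empty.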

## Define tools

We will also define some tools that our agents will use later.

## Create graph

Now that we've defined our tools and made some helper functions, we will create the individual agents below and tell them how to talk to each other using LangGraph.
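To make the hand-off idea concrete, here is a library-free sketch of the convention this notebook relies on: each node produces a message, and control passes to a default next node unless the message contains `FINAL ANSWER`. The `run_nodes` loop, the lambda nodes, and the `END` string are illustrative assumptions; the real graph is built with LangGraph, which handles scheduling itself.

```python
END = "__end__"  # stand-in for the graph's terminal marker in this sketch


def get_next_node(last_content: str, goto: str) -> str:
    """Stop if any agent declared the work done, otherwise continue to `goto`."""
    return END if "FINAL ANSWER" in last_content else goto


def run_nodes(nodes, start, query, max_steps=10):
    """Run a chain of nodes. `nodes` maps a name to (handler, default_next)."""
    messages = [query]
    visited = []
    current = start
    while current != END and len(visited) < max_steps:
        handler, default_next = nodes[current]
        reply = handler(messages)  # each handler sees the shared history
        messages.append(reply)
        visited.append(current)
        current = get_next_node(reply, default_next)
    return visited, messages


# Hypothetical nodes that return canned strings instead of calling an LLM.
example_nodes = {
    "researcher": (lambda m: "notes on " + m[0], "chart_generator"),
    "chart_generator": (lambda m: "chart saved", "chart_summarizer"),
    "chart_summarizer": (lambda m: "FINAL ANSWER: summary", "researcher"),
}
visited, messages = run_nodes(example_nodes, "researcher", "rates in 2023")
```

The `max_steps` cap plays the same role as a recursion limit: it bounds the run even if no node ever emits `FINAL ANSWER`.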

### Define Agent Nodes

We now need to define the nodes.

First, we'll create a utility to create a system prompt for each agent.

In [None]:
import json
from typing import Annotated, Literal

from IPython.display import Image
from IPython.display import display
from langchain.load.dump import dumps
from langchain_core.messages import BaseMessage
from langchain_core.messages import HumanMessage
from langchain_core.messages import ToolMessage
from langchain_core.tools import tool
from langchain_experimental.utilities import PythonREPL
from langchain_openai import ChatOpenAI
from langgraph.graph import END
from langgraph.graph import MessagesState
from langgraph.graph import StateGraph
from langgraph.prebuilt import create_react_agent
from langgraph.types import Command
from trulens.core.otel.instrument import instrument
from trulens.otel.semconv.trace import BASE_SCOPE
from trulens.otel.semconv.trace import SpanAttributes


def build_graph():
    def make_system_prompt(suffix: str) -> str:
        return (
            "You are a helpful AI assistant, collaborating with other assistants."
            " Use the provided tools to progress towards answering the question."
            " If you are unable to fully answer, that's OK, another assistant with different tools"
            " will help where you left off. Execute what you can to make progress."
            " If you or any of the other assistants have the final answer or deliverable,"
            " prefix your response with FINAL ANSWER so the team knows to stop."
            f"\n{suffix}"
        )

    # Warning: This executes code locally, which can be unsafe when not sandboxed

    repl = PythonREPL()

    @tool
    @instrument(
        span_type="PYTHON_REPL_TOOL",
        attributes={
            f"{BASE_SCOPE}.python_tool_input_code": "code",
        },
    )
    def python_repl_tool(
        code: Annotated[
            str, "The python code to execute to generate your chart."
        ],
    ):
        """Use this to execute python code. If you want to see the output of a value,
        you should print it out with `print(...)`. This is visible to the user."""

        try:
            result = repl.run(code)
        except BaseException as e:
            return f"Failed to execute. Error: {repr(e)}"
        result_str = (
            f"Successfully executed:\n```python\n{code}\n```\nStdout: {result}"
        )
        return (
            result_str
            + "\n\nIf you have completed all tasks, respond with FINAL ANSWER."
        )

    llm = ChatOpenAI(
        model="gpt-4o",
    )

    def get_next_node(last_message: BaseMessage, goto: str):
        if "FINAL ANSWER" in last_message.content:
            # Any agent decided the work is done
            return END
        return goto

    @instrument(
        span_type="QUALITATIVE_RESEARCH_NODE",
        attributes=lambda ret, exception, *args, **kwargs: {
            f"{BASE_SCOPE}.qualitative_research_node_input": args[0][
                "messages"
            ][-1].content,
            f"{BASE_SCOPE}.qualitative_research_node_response": ret.update[
                "messages"
            ][-1].content
            if hasattr(ret, "update")
            else json.dumps(ret, indent=4, sort_keys=True),
            f"{BASE_SCOPE}.tool_messages": [
                dumps(message)
                for message in ret.update["messages"]
                if isinstance(message, ToolMessage)
            ]
            if hasattr(ret, "update")
            else "No tool call",
        },
    )
    @instrument(
        span_type=SpanAttributes.SpanType.RETRIEVAL,
        attributes=lambda ret, exception, *args, **kwargs: {
            SpanAttributes.RETRIEVAL.QUERY_TEXT: args[0]["messages"][
                -1
            ].content,
            SpanAttributes.RETRIEVAL.RETRIEVED_CONTEXTS: [
                ret.update["messages"][-1].content
            ]
            if hasattr(ret, "update")
            else [json.dumps(ret, indent=4, sort_keys=True)],
        },
    )
    def qualitative_research_node(
        state: MessagesState,
    ) -> Command[Literal["chart_generator"]]:
        # Extract the user query from the state.
        user_query = state["messages"][0].content if state["messages"] else ""

        # Use Snowflake environment variables already set in the notebook.
        CONNECTION_PARAMETERS = {
            "account": os.environ["SNOWFLAKE_ACCOUNT"],
            "user": os.environ["SNOWFLAKE_USER"],
            "password": os.environ["SNOWFLAKE_USER_PASSWORD"],
            "role": os.environ["SNOWFLAKE_ROLE"],
            "database": "CORTEX_SEARCH_TUTORIAL_DB",
            "warehouse": os.environ["SNOWFLAKE_WAREHOUSE"],
            "schema": "PUBLIC",
        }
        from snowflake.snowpark import Session

        session = Session.builder.configs(CONNECTION_PARAMETERS).create()
        from snowflake.core import Root

        root = Root(session)

        # Replace the placeholders with your actual Cortex search service details.
        my_service = (
            root.databases["CORTEX_SEARCH_TUTORIAL_DB"]
            .schemas["PUBLIC"]
            .cortex_search_services["FOMC_SEARCH_SERVICE"]
        )

        # Execute the Cortex search call.
        resp = my_service.search(
            query=user_query,
            columns=["chunk"],
            limit=20,
        )
        result_json = resp.to_json()

        # Create a human message with the Cortex search result.
        response_message = HumanMessage(
            content=f"Cortex Search result: {result_json}",
            name="qualitative_researcher",
        )

        # Return a command to pass control to the chart generator node.
        return Command(
            update={"messages": state["messages"] + [response_message]},
            goto="chart_generator",
        )

    # Chart generator agent and node
    # NOTE: THIS PERFORMS ARBITRARY CODE EXECUTION, WHICH CAN BE UNSAFE WHEN NOT SANDBOXED
    chart_agent = create_react_agent(
        llm,
        [python_repl_tool],
        prompt=make_system_prompt(
            "You can only generate charts. The generated chart should be saved in the local directory './langgraph_saved_images', and this PATH should be sent to your colleague. You are working with a chart summarizer colleague."
        ),
    )

    @instrument(
        span_type="CHART_GENERATOR_NODE",
        attributes=lambda ret, exception, *args, **kwargs: {
            f"{BASE_SCOPE}.chart_node_input": args[0]["messages"][-1].content,
            f"{BASE_SCOPE}.chart_node_response": ret.update["messages"][
                -1
            ].content
            if hasattr(ret, "update")
            else json.dumps(ret, indent=4, sort_keys=True),
        },
    )
    def chart_node(
        state: MessagesState,
    ) -> Command[Literal["chart_summarizer"]]:
        result = chart_agent.invoke(state)
        goto = get_next_node(result["messages"][-1], "chart_summarizer")
        # wrap in a human message, as not all providers allow
        # AI message at the last position of the input messages list
        result["messages"][-1] = HumanMessage(
            content=result["messages"][-1].content, name="chart_generator"
        )
        return Command(
            update={
                # share internal message history of chart agent with other agents
                "messages": result["messages"],
            },
            goto=goto,
        )

    # Build the image captioning agent.
    # If you have any specific image processing tools (e.g., for extracting chart images),
    # you can add them in the tools list. For now, we leave it empty.
    chart_summary_agent = create_react_agent(
        llm,
        tools=[],  # Add image processing tools if available/needed.
        prompt=make_system_prompt(
            "You can only generate image captions. You are working with a researcher colleague and a chart generator colleague. "
            + "Your task is to generate a concise summary for the provided chart image saved at a local PATH, where the PATH should be and only be provided by your chart generator colleague. The summary should be no more than 3 sentences."
        ),
    )

    @instrument(
        span_type="CHART_SUMMARY_NODE",
        attributes=lambda ret, exception, *args, **kwargs: {
            f"{BASE_SCOPE}.summary_node_input": args[0]["messages"][-1].content,
            f"{BASE_SCOPE}.summary_node_output": ret.update["messages"][
                -1
            ].content
            if hasattr(ret, "update")
            else "NO SUMMARY GENERATED",
        },
    )
    def chart_summary_node(
        state: MessagesState,
    ) -> Command[Literal["qualitative_researcher", END]]:
        result = chart_summary_agent.invoke(state)
        # Determine the next node based on the content of the last message
        goto = get_next_node(result["messages"][-1], "qualitative_researcher")
        # Wrap the output message in a HumanMessage to maintain consistency in the conversation flow.
        result["messages"][-1] = HumanMessage(
            content=result["messages"][-1].content, name="chart_summarizer"
        )
        return Command(
            update={"messages": result["messages"]},
            goto=goto,
        )

    import requests

    @instrument(
        span_type="QUANTITATIVE_RESEARCH_NODE",
        attributes=lambda ret, exception, *args, **kwargs: {
            f"{BASE_SCOPE}.quantitative_research_node_input": args[0][
                "messages"
            ][-1].content,
            f"{BASE_SCOPE}.quantitative_research_node_response": ret.update[
                "messages"
            ][-1].content
            if hasattr(ret, "update")
            else json.dumps(ret, indent=4, sort_keys=True),
            f"{BASE_SCOPE}.tool_messages": [
                dumps(message)
                for message in ret.update["messages"]
                if isinstance(message, ToolMessage)
            ]
            if hasattr(ret, "update")
            else "No tool call",
        },
    )
    @instrument(
        span_type=SpanAttributes.SpanType.RETRIEVAL,
        attributes=lambda ret, exception, *args, **kwargs: {
            SpanAttributes.RETRIEVAL.QUERY_TEXT: args[0]["messages"][
                -1
            ].content,
            SpanAttributes.RETRIEVAL.RETRIEVED_CONTEXTS: [
                ret.update["messages"][-1].content
            ]
            if hasattr(ret, "update")
            else [json.dumps(ret, indent=4, sort_keys=True)],
        },
    )
    def quantitative_research_node(
        state: MessagesState,
    ) -> Command[Literal["qualitative_researcher", END]]:
        # Extract the user query from the state.
        user_query = state["messages"][0].content if state["messages"] else ""
        # Prepare the payload for Cortex Analyst API.
        payload = {
            "messages": [
                {
                    "role": "user",
                    "content": [{"type": "text", "text": user_query}],
                }
            ],
            # Stage paths take the form @<database>.<schema>.<stage>/<file>.
            "semantic_model_file": "@agents_db.notebooks.semantic_models/sec_filings.yaml",
        }
        # Define the API endpoint and headers (update these values as needed).
        api_url = os.environ.get(
            "CORTEX_ANALYST_URL",
            "http://localhost:8000/api/v2/cortex/analyst/message",
        )
        headers = {
            "Authorization": os.environ.get(
                "CORTEX_ANALYST_AUTH_TOKEN", "Bearer YOUR_TOKEN_HERE"
            ),
            "Content-Type": "application/json",
        }
        try:
            response = requests.post(api_url, json=payload, headers=headers)
            if response.status_code != 200:
                result_text = f"Failed Cortex Analyst API call with status code {response.status_code}: {response.text}"
            else:
                data = response.json()
                contents = data.get("message", {}).get("content", [])
                result_text = " ".join([
                    c.get("text", c.get("statement", "")) for c in contents
                ])
        except Exception as e:
            result_text = f"Exception during API call: {str(e)}"
        # Wrap the Cortex Analyst result in a human message for the next node.

        result_message = HumanMessage(
            content=result_text, name="quantitative_researcher"
        )
        goto = (
            "qualitative_researcher"
            if "FINAL ANSWER" not in result_message.content
            else END
        )
        return Command(
            update={"messages": state["messages"] + [result_message]}, goto=goto
        )

    # Build the workflow graph.
    workflow = StateGraph(MessagesState)
    workflow.add_node("qualitative_researcher", qualitative_research_node)
    workflow.add_node("quantitative_researcher", quantitative_research_node)
    workflow.add_node("chart_generator", chart_node)
    workflow.add_node("chart_summarizer", chart_summary_node)

    # Define transitions between nodes
    workflow.set_entry_point("qualitative_researcher")
    workflow.add_edge("qualitative_researcher", "quantitative_researcher")
    workflow.add_edge("quantitative_researcher", "chart_generator")
    workflow.add_edge("chart_generator", "chart_summarizer")
    workflow.add_edge("chart_summarizer", END)
    graph = workflow.compile()

    return graph
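The quantitative research node flattens the Cortex Analyst reply into one string by reading the `text` or `statement` field of each content item. The sketch below isolates that step so it can be checked without a network call. The `sample_response` payload is made up for illustration, and `extract_analyst_text` is a hypothetical helper; unlike the node above, it also skips empty items.

```python
def extract_analyst_text(data: dict) -> str:
    """Join the text and SQL statement parts of a response body (hypothetical helper)."""
    contents = data.get("message", {}).get("content", [])
    parts = [c.get("text", c.get("statement", "")) for c in contents]
    # Drop empty parts so unknown item types do not add stray spaces.
    return " ".join(part for part in parts if part)


# Made-up payload shaped the way the node expects; not real API output.
sample_response = {
    "message": {
        "content": [
            {"type": "text", "text": "Total market value by asset class."},
            {"type": "sql", "statement": "SELECT 1"},
            {"type": "suggestions"},
        ]
    }
}
flattened = extract_analyst_text(sample_response)
```

Keeping the parsing in a small pure function makes it easy to adjust if the response shape differs in your account.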

In [None]:
import datetime
import time

from snowflake.snowpark import Session
from trulens.apps.app import TruApp
from trulens.connectors.snowflake import SnowflakeConnector
from trulens.core.run import Run
from trulens.core.run import RunConfig

# Snowflake account for trulens


snowflake_connection_parameters = {
    "account": os.environ["SNOWFLAKE_ACCOUNT"],
    "user": os.environ["SNOWFLAKE_USER"],
    "password": os.environ["SNOWFLAKE_USER_PASSWORD"],
    "database": os.environ["SNOWFLAKE_DATABASE"],
    "schema": os.environ["SNOWFLAKE_SCHEMA"],
    "role": os.environ["SNOWFLAKE_ROLE"],
    "warehouse": os.environ["SNOWFLAKE_WAREHOUSE"],
}
snowpark_session_trulens = Session.builder.configs(
    snowflake_connection_parameters
).create()


trulens_sf_connector = SnowflakeConnector(
    snowpark_session=snowpark_session_trulens
)


class TruAgent:
    def __init__(self, max_results=3):
        self.graph = build_graph()

    @instrument(
        span_type=SpanAttributes.SpanType.RECORD_ROOT,
        attributes={
            SpanAttributes.RECORD_ROOT.INPUT: "query",
            SpanAttributes.RECORD_ROOT.OUTPUT: "return",
        },
    )
    def invoke_agent_graph(self, query: str) -> str:
        events = self.graph.stream(
            {
                "messages": [("user", query)],
            },
            # Maximum number of steps to take in the graph
            {"recursion_limit": 10},
        )

        messages = []  # stays empty if the graph yields no events

        for event in events:
            messages = list(event.values())[0]["messages"]
        return (
            messages[-1].content
            if messages and hasattr(messages[-1], "content")
            else ""
        )


tru_agent = TruAgent(max_results=1)

APP_NAME = "Financial Research Agent"

tru_agent_app = TruApp(
    tru_agent,
    app_name=APP_NAME,
    app_version="updated graph - search and analyst - lower recursion limit",
    connector=trulens_sf_connector,
    main_method=tru_agent.invoke_agent_graph,
)

st_1 = datetime.datetime.fromtimestamp(time.time()).strftime(
    "%Y-%m-%d %H:%M:%S"
)

run_config = RunConfig(
    run_name="Multi-agent demo run - search and analyst graph - lower recursion limit",
    description="this is a run with access to cortex search and cortex analyst, along with chart and summarization capabilities",
    dataset_name="Research test dataset",
    source_type="DATAFRAME",
    label="langgraph demo",
    dataset_spec={
        "RECORD_ROOT.INPUT": "query",
    },
)

run: Run = tru_agent_app.add_run(run_config)
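The loop in `invoke_agent_graph` keeps the message list from the most recent streamed event and returns the content of its last entry. Here is a standalone sketch of that pattern, assuming each event maps a node name to an update dict with a `messages` key. `Message` is a hypothetical stand-in for the LangChain message classes, and `last_message_content` is an illustrative helper.

```python
from dataclasses import dataclass


@dataclass
class Message:
    """Hypothetical stand-in for a chat message with a `content` field."""

    content: str


def last_message_content(events) -> str:
    """Return the content of the final message in the last event, or ''."""
    messages = []  # stays empty if the stream yields nothing
    for event in events:
        # Each event maps one node name to that node's state update.
        update = next(iter(event.values()))
        messages = update["messages"]
    return messages[-1].content if messages else ""


example_events = [
    {"qualitative_researcher": {"messages": [Message("search result")]}},
    {"chart_summarizer": {"messages": [Message("search result"), Message("summary")]}},
]
final_text = last_message_content(example_events)
```

Initializing `messages` before the loop means an empty stream yields an empty string rather than an unbound-variable error.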

In [None]:
display(Image(tru_agent.graph.get_graph().draw_png()))

In [None]:
import pandas as pd

user_queries = [
    "In 2023, how did the fed funds rate fluctuate? What were the key drivers?",
    "What were the core drivers of inflation in 2023? How did the federal reserve respond?",
    "What is the total market value of securities by asset class according to SEC filings? Create a bar chart that best illustrates this data. Once the chart is generated, summarize the chart, and finish.",
    "Who are the top 10 filing managers by number of holdings in the most recent reporting quarter according to SEC filings? Create a bar chart that best illustrates this data. Once the chart is generated, summarize the chart, and finish.",
]

user_queries_df = pd.DataFrame(user_queries, columns=["query"])

run.start(
    input_df=user_queries_df
)  # note: if you use MFA, expect the Duo prompt to appear several times; approve each one.

In [None]:
import time

while run.get_status() == "INVOCATION_IN_PROGRESS":
    time.sleep(3)
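The polling loop above waits indefinitely if the run never leaves the in-progress state. A hedged variant with a deadline is sketched below. `wait_while` is a hypothetical helper, and the default timeout is an arbitrary assumption; only the `INVOCATION_IN_PROGRESS` status string comes from the cell above.

```python
import time


def wait_while(
    get_status,
    busy_status="INVOCATION_IN_PROGRESS",
    poll_seconds=3.0,
    timeout_seconds=600.0,
    sleep=time.sleep,
):
    """Poll `get_status` until it stops returning `busy_status` or time runs out."""
    deadline = time.monotonic() + timeout_seconds
    status = get_status()
    while status == busy_status:
        if time.monotonic() >= deadline:
            raise TimeoutError(f"still {busy_status} after {timeout_seconds} seconds")
        sleep(poll_seconds)
        status = get_status()
    return status
```

With the objects from this notebook, the call would be `wait_while(run.get_status)`; it returns whatever status the run reports once invocation is no longer in progress.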

### Compute metrics:
Here, we compute context relevance to assess the retrieval quality of the Cortex Search and Cortex Analyst research steps, and groundedness to evaluate whether the text summaries generated from the charts are supported by the content of those charts.

In [None]:
run.compute_metrics(["groundedness", "context_relevance"])