# Multi-agent network with Snowflake tools for querying unstructured and structured data

Adapted from the original [Langgraph multi-agent notebook example](https://github.com/langchain-ai/langgraph/blob/main/docs/docs/tutorials/multi_agent/multi-agent-collaboration.ipynb)

A single agent can usually operate effectively using a handful of tools within a single domain, but even with powerful models like `gpt-4`, it can be less effective when given many tools.

One way to approach complicated tasks is through a "divide-and-conquer" approach: create a specialized agent for each task or domain and route tasks to the correct "expert". This is an example of a [multi-agent network](https://langchain-ai.github.io/langgraph/concepts/multi_agent/#network) architecture.

This notebook (inspired by the paper [AutoGen: Enabling Next-Gen LLM Applications via Multi-Agent Conversation](https://arxiv.org/abs/2308.08155), by Wu et al.) shows one way to do this using LangGraph.

This notebook extends the multi-agent-collaboration notebook, showing how access to more tools—particularly tools over private data—can enhance the capabilities of a data agent.

We will gradually build up the agent with more tools: we start with web search, then add document search via Cortex Search, and finally replace document search with a Cortex Agent that can both search documents and query Snowflake tables with SQL via Cortex Analyst.


In [1]:
%%capture --no-stderr
# Dependency install cell: uncomment the pip line below to install the packages this notebook uses
# (pygraphviz is needed for the graph rendering further down).
# NOTE(review): inside a notebook, `%pip install ...` installs into the running kernel's environment.
# pip install -U langchain_community langchain_openai langchain_experimental matplotlib langgraph pygraphviz

In [2]:
# Name under which this app is registered with TruLens (passed as `app_name` to TruApp).
APP_NAME: str = "Finance Research Agent with Data 101"  # set this app name for your use case

## Set keys

In [3]:
import os

# API keys and Snowflake connection settings used throughout this notebook.
# Values that are already exported in the environment take precedence
# (os.environ.setdefault), so real secrets can be exported before starting
# Jupyter instead of being pasted here. The literals below are only
# placeholders / fallbacks: replace the "..." values if you do not export them.
_NOTEBOOK_ENV_DEFAULTS = {
    # need both API keys
    "OPENAI_API_KEY": "sk-proj-...",
    "TAVILY_API_KEY": "tvly-dev-...",  # get a free one at https://app.tavily.com/home
    "SNOWFLAKE_ACCOUNT": "SFDEVREL_ENTERPRISE",
    "SNOWFLAKE_USER": "JREINI",
    "SNOWFLAKE_USER_PASSWORD": "...",
    "SNOWFLAKE_DATABASE": "AGENTS_DB",
    "SNOWFLAKE_SCHEMA": "NOTEBOOKS",
    "SNOWFLAKE_ROLE": "CORTEX_USER_ROLE",
    "SNOWFLAKE_WAREHOUSE": "CONTAINER_RUNTIME_WH",
    # NOTE(review): SNOWFLAKE_JWT is not read by any code visible in this
    # notebook section (the Snowpark session authenticates with the password).
    "SNOWFLAKE_JWT": "...",
}
for _name, _value in _NOTEBOOK_ENV_DEFAULTS.items():
    os.environ.setdefault(_name, _value)

# Tracing is a notebook setting, not a secret, so it is always forced on.
# Enables OTEL tracing -> note the Snowsight UI experience for now is limited
# to PuPr customers, not yet supported for OSS.
os.environ["TRULENS_OTEL_TRACING"] = "1"

## Import libraries

In [4]:
import datetime
import json
import time
import datetime
from typing import Annotated, Literal
import os, textwrap
from IPython.display import Image
from IPython.display import display
import uuid
import re
from langchain_core.documents import Document
from langchain_core.vectorstores import InMemoryVectorStore
from langchain_openai import OpenAIEmbeddings
from langchain.load.dump import dumps
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_core.tools import StructuredTool
from langgraph.prebuilt import ToolNode, tools_condition
from langchain_core.messages import BaseMessage
from langchain_core.messages import HumanMessage
from langchain_core.messages import ToolMessage
from langchain_core.tools import tool
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain
from langchain_experimental.utilities import PythonREPL
from langchain_openai import ChatOpenAI
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import MessagesState
from langgraph.graph import StateGraph
from langgraph.prebuilt import create_react_agent
from langgraph.types import Command
from snowflake.snowpark import Session
from trulens.apps.app import TruApp
from trulens.connectors.snowflake import SnowflakeConnector
from trulens.core.otel.instrument import instrument
from trulens.core.run import Run
from trulens.core.run import RunConfig
from trulens.otel.semconv.trace import BASE_SCOPE
from trulens.otel.semconv.trace import SpanAttributes
from trulens.providers.openai import OpenAI

## Create TruLens/Snowflake Connection

In [None]:
# Snowpark session for TruLens. The resulting connector is handed to TruApp
# as its `connector`, i.e. where TruLens stores this app's records and runs.
# Each connection parameter is read from the environment variable it is paired with.
snowflake_connection_parameters = {
    param: os.environ[env_var]
    for param, env_var in (
        ("account", "SNOWFLAKE_ACCOUNT"),
        ("user", "SNOWFLAKE_USER"),
        ("password", "SNOWFLAKE_USER_PASSWORD"),
        ("database", "SNOWFLAKE_DATABASE"),
        ("schema", "SNOWFLAKE_SCHEMA"),
        ("role", "SNOWFLAKE_ROLE"),
        ("warehouse", "SNOWFLAKE_WAREHOUSE"),
    )
}

_session_builder = Session.builder.configs(snowflake_connection_parameters)
snowpark_session_trulens = _session_builder.create()

trulens_sf_connector = SnowflakeConnector(snowpark_session=snowpark_session_trulens)

### Define the agent with web search and charting tools

In [6]:
def build_graph(search_max_results: int = 5, max_revisions: int = 3):
    """Build and compile the web-search + charting multi-agent graph.

    The graph starts with ``START -> select_tools -> research_agent``. After
    that, every agent node chooses its successor itself by returning
    ``Command(goto=...)``:

    * ``research_agent``   -> ``chart_generator`` (END on "FINAL ANSWER")
    * ``chart_generator``  -> ``chart_summarizer`` when a chart PNG was saved,
      otherwise back to ``research_agent``
    * ``chart_summarizer`` -> END when the reflection step accepts the summary
      (or it contains "FINAL ANSWER"), otherwise back to ``research_agent``

    Args:
        search_max_results: ``max_results`` for the Tavily web-search tool.
        max_revisions: how many times ``chart_generator``/``chart_summarizer``
            may send the work back to ``research_agent``. Once exhausted the
            graph ends instead, so a chart that never succeeds cannot cycle
            until the caller's recursion limit is hit.

    Returns:
        The compiled LangGraph graph.
    """

    class AgentState(MessagesState):
        # Ids (keys of ``tool_registry``) picked by ``select_tools``. The key
        # must be declared in the state schema: LangGraph drops update keys the
        # schema does not declare, so with plain ``MessagesState`` the
        # selection would never reach ``research_agent_node``.
        selected_tools: list[str]
        # Number of times the workflow has been routed back to research_agent.
        revisions: int

    def make_system_prompt(suffix: str) -> str:
        """Shared collaboration preamble followed by a role-specific ``suffix``."""
        return (
            "You are a helpful AI assistant, collaborating with other assistants."
            " Use the provided tools to progress towards answering the question."
            " If you are unable to fully answer, that's OK, another assistant with different tools "
            " will help where you left off. Execute what you can to make progress."
            " If you or any of the other assistants have the final answer or deliverable,"
            " prefix your response with FINAL ANSWER so the team knows to stop."
            f"\n{suffix}"
        )

    tavily_tool = TavilySearchResults(max_results=search_max_results)

    # Each tool is registered under a random id; select_tools returns these ids.
    tool_registry = {str(uuid.uuid4()): tavily_tool}

    # Index tool descriptions in a vector store for semantic tool retrieval
    tool_documents = [
        Document(
            page_content=registered_tool.description,
            id=tid,
            metadata={"tool_name": registered_tool.name},
        )
        for tid, registered_tool in tool_registry.items()
    ]
    vector_store = InMemoryVectorStore(embedding=OpenAIEmbeddings())
    vector_store.add_documents(tool_documents)

    @instrument(
        span_type="SELECT_TOOLS",
        attributes=lambda ret, exc, *args, **kw: {
            # ---- state as JSON-text (OTLP needs a scalar) -----------------
            f"{BASE_SCOPE}.select_tools_input_state":
                json.dumps(                                      # turns dict -> str
                    {
                        **{k: v for k, v in args[0].items() if k != "messages"},
                        "messages": [
                            {"type": m.__class__.__name__, "content": m.content}
                            if hasattr(m, "content")             # BaseMessage subclasses
                            else m                               # already JSON-friendly
                            for m in args[0].get("messages", [])
                        ],
                    }
                ),

            # ---- selected tool IDs as a simple comma-separated string -----
            f"{BASE_SCOPE}.selected_tool_ids":
                ", ".join(ret.get("selected_tools", [])) if isinstance(ret, dict) else "",
        },
    )
    def select_tools(state: AgentState):
        """
        Select tools for the latest message by semantic similarity.

        The last message's text is matched against the indexed tool
        descriptions; the ids of the closest tools are written to the
        ``selected_tools`` state key.
        """
        messages = state["messages"]
        last_message = messages[-1]
        last_user_query = (
            last_message["content"] if isinstance(last_message, dict) else last_message.content
        )

        print(last_user_query)

        docs = vector_store.similarity_search(last_user_query)
        return {"selected_tools": [doc.id for doc in docs]}

    # Warning: This executes code locally, which can be unsafe when not sandboxed
    repl = PythonREPL()

    llm = ChatOpenAI(model="gpt-4o")

    def ensure_chart_dir():
        """Create (if needed) and return the directory charts are saved to."""
        target_dir = "./langgraph_saved_images_snowflaketools/v1"
        os.makedirs(target_dir, exist_ok=True)
        return target_dir

    @tool
    @instrument(
        span_type="PYTHON_REPL_TOOL",
        attributes={
            f"{BASE_SCOPE}.python_tool_input_code": "code",
        },
    )
    def python_repl_tool(code: str):
        """
        Run arbitrary Python, grab the CURRENT matplotlib figure (if any),
        save it to ./langgraph_saved_images_snowflaketools/v1/chart_<uuid>.png,
        and return a first-line `CHART_PATH=…` followed by the code's stdout
        (or the error it raised).
        """
        import matplotlib
        matplotlib.use("Agg")  # headless safety
        import matplotlib.pyplot as plt

        # Run the code. PythonREPL.run returns the captured stdout (or the
        # exception text); returning it lets the agent see and fix errors.
        output = repl.run(code)

        fig = plt.gcf()
        chart_path = ""
        if fig.axes:  # something was plotted
            chart_path = os.path.join(ensure_chart_dir(), f"chart_{uuid.uuid4().hex}.png")
            fig.savefig(chart_path, format="png")
        # Close unconditionally: plt.gcf() creates an empty figure when none
        # exists, which would otherwise stay open between calls.
        plt.close(fig)

        # The first line is parsed by extract_chart_path, so it must stay first.
        return f"CHART_PATH={chart_path if chart_path else 'NONE'}\n{output}"

    def get_next_node(last_message: BaseMessage, goto: str):
        """Return END if the message says "FINAL ANSWER", otherwise ``goto``."""
        if "FINAL ANSWER" in last_message.content:
            # Any agent decided the work is done
            return END
        return goto

    # Research agent and node
    research_prompt = make_system_prompt(
        "You can only do research. You are working with both a chart generator and a chart summarizer colleagues."
    )
    research_agent = create_react_agent(llm, tools=[tavily_tool], prompt=research_prompt)

    @instrument(
        span_type="RESEARCH_NODE",
        attributes=lambda ret, exception, *args, **kwargs: {
            f"{BASE_SCOPE}.research_node_input": args[0]["messages"][-1].content,
            f"{BASE_SCOPE}.research_node_response": ret.update["messages"][-1].content
            if hasattr(ret, "update")
            else json.dumps(ret, indent=4, sort_keys=True),
            f"{BASE_SCOPE}.tool_messages": [
                dumps(message)
                for message in ret.update["messages"]
                if isinstance(message, ToolMessage)
            ]
            if hasattr(ret, "update")
            else "No tool call",
        },
    )
    @instrument(
        span_type=SpanAttributes.SpanType.RETRIEVAL,
        attributes=lambda ret, exception, *args, **kwargs: {
            SpanAttributes.RETRIEVAL.QUERY_TEXT: args[0]["messages"][-1].content,
            SpanAttributes.RETRIEVAL.RETRIEVED_CONTEXTS: [
                ret.update["messages"][-1].content
            ]
            if hasattr(ret, "update")
            else [json.dumps(ret, indent=4, sort_keys=True)],
        },
    )
    def research_agent_node(state: AgentState) -> Command[Literal["chart_generator", END]]:
        """
        Run the research ReAct agent, then hand off to the chart generator.

        If ``select_tools`` chose tools (ids found in ``tool_registry``), a
        ReAct agent restricted to those tools is built for this call;
        otherwise the default research agent (Tavily only) is used. The
        agent's final message is re-labelled as a HumanMessage named
        "researcher". Routes to END if that message contains "FINAL ANSWER",
        otherwise to ``chart_generator``.
        """
        selected_ids = state.get("selected_tools", [])
        # Unknown ids are ignored rather than raising KeyError.
        selected_tools = [tool_registry[tid] for tid in selected_ids if tid in tool_registry]
        if selected_tools:
            # Tools must be passed to create_react_agent (not just bound to the
            # LLM): it also builds the tool-executing node for the tool calls.
            agent = create_react_agent(llm, tools=selected_tools, prompt=research_prompt)
        else:
            agent = research_agent
        result = agent.invoke(state)
        goto = get_next_node(result["messages"][-1], "chart_generator")
        result["messages"][-1] = HumanMessage(content=result["messages"][-1].content, name="researcher")
        return Command(update={"messages": result["messages"]}, goto=goto)

    # Chart generator agent and node
    # NOTE: THIS PERFORMS ARBITRARY CODE EXECUTION, WHICH CAN BE UNSAFE WHEN NOT SANDBOXED
    chart_agent = create_react_agent(
        llm,
        [python_repl_tool],
        prompt=make_system_prompt(
            "You can only generate charts. The generated chart should be save at a local directory at current directory PATH './langgraph_saved_images_snowflaketools/v1' , and this PATH should be sent to your colleague. You are working with a chart summarizer colleague."
        ),
    )

    def extract_chart_path(text: str) -> str | None:
        """
        Returns the first CHART_PATH=… found in `text`, else None.
        """
        m = re.search(r"^CHART_PATH=(.+)$", text, flags=re.MULTILINE)
        return m.group(1).strip() if m else None

    def route_back(state: AgentState, messages: list) -> Command:
        """Send work back to research_agent, or end once max_revisions is used up."""
        revisions = state.get("revisions", 0)
        if revisions >= max_revisions:
            return Command(update={"messages": messages}, goto=END)
        return Command(
            update={"messages": messages, "revisions": revisions + 1},
            goto="research_agent",
        )

    @instrument(
        span_type="CHART_GENERATOR_NODE",
        attributes=lambda ret, exception, *args, **kwargs: {
            f"{BASE_SCOPE}.chart_node_input": args[0]["messages"][-1].content,
            f"{BASE_SCOPE}.chart_node_response": (
                ret.update["messages"][-1].content if ret and hasattr(ret, "update") and ret.update else "No update response"
            ),
        },
    )
    def chart_node(
        state: AgentState,
    ) -> Command[Literal["chart_summarizer", "research_agent", END]]:
        """
        Run the chart-generation agent and route on whether a chart was saved.

        The chart path is taken from the ``CHART_PATH=`` line of the
        python_repl_tool output produced in this run. If none was saved, a
        failure note is added and the work goes back to ``research_agent``
        (bounded by ``max_revisions``). Otherwise a ``CHART_PATH=`` note and a
        summary request are added and the chart summarizer runs next.
        """
        # 1. let the agent run
        result = chart_agent.invoke(state)

        # 2. Only inspect messages produced by this run: result["messages"]
        #    starts with the incoming conversation, which may hold tool output
        #    (and chart paths) from earlier rounds.
        new_messages = result["messages"][len(state["messages"]):]
        chart_path = None
        for msg in new_messages:
            if isinstance(msg, ToolMessage):
                chart_path = extract_chart_path(msg.content)
                if chart_path and chart_path.upper() != "NONE":
                    break

        print(f"Chart Path: {chart_path!r}")

        # 3. If no valid chart path, return to research with an error message
        if not chart_path or chart_path.upper() == "NONE":
            result["messages"].append(
                HumanMessage(
                    content="Failed to generate chart. Please try again with different parameters.",
                    name="chart_generator",
                )
            )
            return route_back(state, result["messages"])

        # 4. Add chart path to messages for downstream nodes
        result["messages"].append(
            HumanMessage(
                content=f"CHART_PATH={chart_path}",
                name="chart_generator",
            )
        )

        # 5. Add summary prompt
        summary_prompt = "Please summarise the chart in ≤ 3 sentences."
        result["messages"].append(
            HumanMessage(name="chart_generator", content=summary_prompt)
        )

        return Command(update={"messages": result["messages"]}, goto="chart_summarizer")

    # Chart summarizer agent. It has no image tools, so the summary is written
    # from the conversation (research data, plotting code, CHART_PATH).
    chart_summary_agent = create_react_agent(
        llm,
        tools=[],  # Add image processing tools if available/needed.
        prompt=make_system_prompt(
            "You can only summarise charts. Using the research data, the plotting"
            " code and the CHART_PATH in the conversation, summarise the chart in"
            " at most 3 sentences. Do NOT generate new charts or code."
            " You are working with a researcher and a chart generator colleague."
        ),
    )

    reflection_prompt_template = PromptTemplate(
        input_variables=["user_query", "chart_summary"],
        template="""\
        You are an AI assistant tasked with reflecting on the quality of a chart summary. The user has asked the following question:
        "{user_query}"

        You are given the following chart summary:
        "{chart_summary}"

        Your task is to evaluate how well the chart summary answers the user's question. Consider the following:
        - Does the summary capture the **key insights** and trends from the chart, even if in a more general form?
        - Does it provide **adequate context** to address the user's query, even if it's not exhaustive?
        - If the summary provides some context but could benefit from more details, consider it sufficient for now unless significant details are missing.

        If the summary **generally** addresses the question, respond with 'Task complete'. If the summary **lacks significant** details or clarity, then respond with 'Refine summary'. Avoid being overly critical unless the summary completely misses key elements necessary to answer the query.

        Please provide your answer in a **concise and encouraging** manner.
        """,
    )

    # Create the chain using the prompt template and the LLM (ChatOpenAI)
    reflection_chain = reflection_prompt_template | llm

    @instrument(
        span_type="CHART_SUMMARY_REFLECTION",
        attributes=lambda ret, exception, *args, **kwargs: {
            f"{BASE_SCOPE}.chart_summary_reflection_input_user_query": args[0],
            f"{BASE_SCOPE}.chart_summary_reflection_input_chart_summary": args[1],
            f"{BASE_SCOPE}.chart_summary_reflection_response": ret,
        },
    )
    def perform_reflection(user_query: str, chart_summary: str) -> str:
        """
        Ask the LLM whether ``chart_summary`` answers ``user_query``; the reply
        is expected to contain 'Task complete' or 'Refine summary'.
        """
        reflection_result = reflection_chain.invoke({"user_query": user_query, "chart_summary": chart_summary})
        return reflection_result.content

    @instrument(
        span_type="CHART_SUMMARY_NODE",
        attributes=lambda ret, exception, *args, **kwargs: {
            f"{BASE_SCOPE}.summary_node_input": args[0]["messages"][-1].content,
            f"{BASE_SCOPE}.summary_node_output": (
                ret.update["messages"][-1].content
                if hasattr(ret, "update") else "NO SUMMARY GENERATED"
            ),
        },
    )
    def chart_summary_node(
        state: AgentState,
    ) -> Command[Literal["research_agent", END]]:
        """
        Summarise the generated chart and decide whether the task is done.

        ``perform_reflection`` judges the summary against the user's original
        query: "Task complete" (or "FINAL ANSWER") ends the graph; anything
        else sends the work back to ``research_agent`` (bounded by
        ``max_revisions``). A missing chart or summary also routes back.
        """
        # 1. Most recent valid chart path from the tool output.
        chart_path = None
        for msg in reversed(state["messages"]):
            if isinstance(msg, ToolMessage) and "CHART_PATH=" in msg.content:
                chart_path = extract_chart_path(msg.content)
                if chart_path and chart_path.upper() != "NONE":
                    break

        print(f"Chart Path in Chart Summary Node: {chart_path!r}")

        # 2. If no valid chart path, return to research with an error message
        if not chart_path or chart_path.upper() == "NONE":
            return route_back(
                state,
                state["messages"] + [
                    HumanMessage(
                        content="No valid chart was generated. Please try again.",
                        name="chart_summarizer",
                    )
                ],
            )

        # 3. Run the summarizer
        result = chart_summary_agent.invoke(state)
        if not result or "messages" not in result:
            return route_back(
                state,
                state["messages"] + [
                    HumanMessage(
                        content="Failed to generate chart summary. Please try again.",
                        name="chart_summarizer",
                    )
                ],
            )

        # 4. Reflect against the user's original query: the first message is
        #    the query the graph was invoked with (the second-to-last message
        #    here is chart_node's "CHART_PATH=..." note, not the query).
        user_query = state["messages"][0].content
        chart_summary = result["messages"][-1].content
        reflection = perform_reflection(user_query, chart_summary)

        result["messages"][-1] = HumanMessage(name="chart_summarizer", content=chart_summary)

        # 5. Finish, or loop back for another round.
        if "Task complete" in reflection or "FINAL ANSWER" in reflection:
            return Command(update={"messages": result["messages"]}, goto=END)
        return route_back(state, result["messages"])

    workflow = StateGraph(AgentState)
    workflow.add_node("select_tools", select_tools)
    workflow.add_node("research_agent", research_agent_node)
    workflow.add_node("chart_generator", chart_node)
    workflow.add_node("chart_summarizer", chart_summary_node)

    # Begin with tool selection, then research. All later transitions come
    # from each node's Command(goto=...); static edges out of those nodes
    # would run in addition to the Command target, so none are added.
    workflow.add_edge(START, "select_tools")
    workflow.add_edge("select_tools", "research_agent")

    return workflow.compile()

## Register the agent and create a run

In [None]:
class TruAgent:
    """Wraps the compiled multi-agent graph so TruLens can record its runs.

    ``invoke_agent_graph`` is instrumented as the RECORD_ROOT span: its
    ``query`` argument is recorded as the input and its return value as the
    output.
    """

    def __init__(self):
        self.graph = build_graph()

    @instrument(
        span_type=SpanAttributes.SpanType.RECORD_ROOT,
        attributes={
            SpanAttributes.RECORD_ROOT.INPUT: "query",
            SpanAttributes.RECORD_ROOT.OUTPUT: "return",
        },
    )
    def invoke_agent_graph(self, query: str) -> str:
        """Run the graph on ``query`` and return the final message text.

        The graph is streamed node by node. The result is the content of the
        last message in the most recent node update that carried messages,
        or ``""`` if no node produced any messages.
        """
        state = {"messages": [{"role": "user", "content": query}]}
        events = self.graph.stream(
            state,
            # Maximum number of steps to take in the graph
            {"recursion_limit": 150},
        )

        # Initialised up front so a stream without message updates returns ""
        # instead of failing on an unbound local at the return statement.
        messages = []
        for event in events:
            if not event:
                continue
            # Each streamed event maps a node name to that node's update.
            _node_name, payload = next(iter(event.items()))
            if not payload or not isinstance(payload, dict):  # skip empty payloads
                continue
            node_messages = payload.get("messages")
            if node_messages:
                messages = node_messages

        return (
            messages[-1].content
            if messages and hasattr(messages[-1], "content")
            else ""
        )


# Register the agent with TruLens: every call to invoke_agent_graph becomes a
# record of this app/version, stored through the Snowflake connector.
tru_agent = TruAgent()

tru_agent_app = TruApp(
    tru_agent,
    app_name=APP_NAME,
    app_version="web search",
    connector=trulens_sf_connector,
    main_method=tru_agent.invoke_agent_graph,
)

# Local wall-clock time, formatted to the second.
# NOTE(review): st_1 is not read by any code visible in this section.
st_1 = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")

# The run reads its inputs from a DataFrame whose "query" column feeds the
# RECORD_ROOT input of each record.
run_config = RunConfig(
    run_name="Multi-agent demo run - web search and charting",
    description="this is a run with access to web search and charting capabilities",
    label="langgraph demo",
    source_type="DATAFRAME",
    dataset_name="Research test dataset",
    dataset_spec={"RECORD_ROOT.INPUT": "query"},
)

run: Run = tru_agent_app.add_run(run_config)

## Display the agent's graph

In [None]:
# Render the compiled graph's topology as a PNG (needs pygraphviz; see the install cell).
graph_png = tru_agent.graph.get_graph().draw_png()
display(Image(graph_png))

## Start the run

This runs the agent in batch over the queries in `user_queries_df`.

In [None]:
import pandas as pd

# Batch inputs for the run. The "query" column name matches the
# dataset_spec mapping {"RECORD_ROOT.INPUT": "query"} in the run config.
user_queries = [
    "In 2023, how did the fed funds rate fluctuate? What were the key drivers? Create a line chart that best illustrates this data, including a caption with the key drivers.",
    "What is the total market value of securities by asset class according to SEC filings? Create a bar chart that best illustrates this data.",
    "Who are the top 10 filing managers by number of holdings in the most recent reporting quarter according to SEC filings?",
]

user_queries_df = pd.DataFrame({"query": user_queries})

run.start(input_df=user_queries_df)

## Compute metrics

In [None]:
import time

# Wait for the batch invocation to finish before computing metrics.
# The wait is bounded so a run stuck in INVOCATION_IN_PROGRESS cannot hang the
# notebook forever; raise MAX_WAIT_S for very large input sets.
POLL_INTERVAL_S = 3
MAX_WAIT_S = 30 * 60

_deadline = time.monotonic() + MAX_WAIT_S
while run.get_status() == "INVOCATION_IN_PROGRESS":
    if time.monotonic() >= _deadline:
        raise TimeoutError(
            f"Run still INVOCATION_IN_PROGRESS after {MAX_WAIT_S} s; not computing metrics."
        )
    time.sleep(POLL_INTERVAL_S)

run.compute_metrics(["groundedness", "context_relevance"])

Web search alone is not as precise as it could be with access to the private FOMC meeting-minutes data. Let's supplement web search with document search.

## Add Cortex Search to the agent

In [11]:
from snowflake.snowpark import Session
from snowflake.core import Root
from pydantic import BaseModel

class CortexSearchArgs(BaseModel):
    # Argument schema for CortexSearchTool: a single free-text query supplied
    # by the LLM when it calls the tool.
    query: str  # text passed to the Cortex Search service's search()


# --- Define a new Cortex Search Tool to perform document search via Cortex ---
class CortexSearchTool(StructuredTool):
    """LangChain tool that searches documents with a Snowflake Cortex Search service.

    The service location and result limit are fields, so the tool can point at
    other services; the defaults target the FOMC tutorial service.
    """

    name: str = "CortexSearch"
    description: str = "Searches documents using Cortex Search via Snowflake."
    args_schema: type[BaseModel] = CortexSearchArgs
    session: Session
    # Location of the Cortex Search service to query.
    database: str = "CORTEX_SEARCH_TUTORIAL_DB"
    schema_name: str = "PUBLIC"
    service_name: str = "FOMC_SEARCH_SERVICE"
    # Maximum number of results returned per query.
    limit: int = 10

    # LangChain tools implement `_run`: `BaseTool.run`/`invoke` validate the
    # input against `args_schema` and call `_run`. Overriding `run` itself
    # breaks `invoke`, which is how agents execute tools, because `invoke`
    # passes callback/config keyword arguments to `run`.
    def _run(self, query: str, run_manager=None) -> str:
        """
        Executes a search query using the Cortex Search service in Snowflake.

        Args:
            query (str): The search query string.
            run_manager: LangChain callback manager (passed by BaseTool.run); unused.

        Returns:
            str: A JSON string containing the search results ("chunk" column),
            limited to `self.limit` entries.
        """
        root = Root(self.session)
        service = (
            root.databases[self.database]
            .schemas[self.schema_name]
            .cortex_search_services[self.service_name]
        )
        resp = service.search(query=query, columns=["chunk"], limit=self.limit)
        return resp.to_json()

def build_graph_with_search(search_max_results: int = 5):
    def make_system_prompt(suffix: str) -> str:
        return (
            "You are a helpful AI assistant, collaborating with other assistants."
            " Use the provided tools to progress towards answering the question."
            " If you are unable to fully answer, that's OK, another assistant with different tools "
            " will help where you left off. Execute what you can to make progress."
            " If you or any of the other assistants have the final answer or deliverable,"
            " prefix your response with FINAL ANSWER so the team knows to stop."
            f"\n{suffix}"
        )

    tavily_tool = TavilySearchResults(max_results=search_max_results)

    # Create document search tool using Cortex Search (uses your Snowflake session)
    cortex_search_tool = CortexSearchTool(session=snowpark_session_trulens)

    # The tool registry now includes both the web and document search tools.
    tool_registry = {
        str(uuid.uuid4()): tavily_tool,
        str(uuid.uuid4()): cortex_search_tool,
    }

    # Index tool descriptions in a vector store for semantic tool retrieval
    tool_documents = [
        Document(
            page_content=tavily_tool.description,
            id=tid,
            metadata={"tool_name": tavily_tool.name},
        )
        for tid, tavily_tool in tool_registry.items()
    ]
    vector_store = InMemoryVectorStore(embedding=OpenAIEmbeddings())
    vector_store.add_documents(tool_documents)

    @instrument(
        span_type="SELECT_TOOLS",
        attributes=lambda ret, exc, *args, **kw: {
            # ---- state as JSON-text (OTLP needs a scalar) -----------------
            f"{BASE_SCOPE}.select_tools_input_state":
                json.dumps(                                      # ← turns dict → str
                    {
                        **{k: v for k, v in args[0].items() if k != "messages"},
                        "messages": [
                            {"type": m.__class__.__name__, "content": m.content}
                            if hasattr(m, "content")             # BaseMessage subclasses
                            else m                               # already JSON-friendly
                            for m in args[0].get("messages", [])
                        ],
                    }
                ),

            # ---- selected tool IDs as a simple comma-separated string -----
            f"{BASE_SCOPE}.selected_tool_ids":
                ", ".join(ret.get("selected_tools", [])) if isinstance(ret, dict) else "",
        },
    )
    def select_tools(state: MessagesState):
        """
        This tool is used to select tools based on the user query.
        The tool uses the vector store to find the most semantically similar tool to the user query.
        The tool then returns the selected tool's ID."""
        messages = state["messages"]

        last_user_query = messages[-1]["content"] if isinstance(messages[-1], dict) \
                        else messages[-1].content
        
        print(last_user_query)

        docs = vector_store.similarity_search(last_user_query)
        return {"selected_tools": [doc.id for doc in docs]}


    # Warning: This executes code locally, which can be unsafe when not sandboxed

    repl = PythonREPL()

    llm = ChatOpenAI(model="gpt-4o")

    def ensure_chart_dir():
        target_dir = "./langgraph_saved_images_snowflaketools/v2"
        os.makedirs(target_dir, exist_ok=True)
        return target_dir

    
    @tool
    @instrument(
        span_type="PYTHON_REPL_TOOL",
        attributes={
            f"{BASE_SCOPE}.python_tool_input_code": "code",
        },
    )
    def python_repl_tool(code: str):
        """
        Run arbitrary Python, grab the CURRENT matplotlib figure (if any),
        save it to ./langgraph_saved_images_snowflaketools/v2/chart_<uuid>.png,
        and return a first-line `CHART_PATH=…`.
        """
        import matplotlib
        matplotlib.use("Agg")                 # headless safety
        import matplotlib.pyplot as plt
        import uuid, os, textwrap, io, contextlib, sys

        # ------------------ run user code & capture stdout ------------------
        repl.run(code)

        # ------------------ locate a figure (if generated) ------------------
        fig = plt.gcf()
        has_axes = bool(fig.axes)            # True if something was plotted

        # ------------------ always save if we have a figure -----------------
        chart_path = ""
        if has_axes:
            target_dir = "./langgraph_saved_images_snowflaketools/v2"
            os.makedirs(target_dir, exist_ok=True)
            chart_path = os.path.join(
                target_dir, f"chart_{uuid.uuid4().hex}.png"
            )
            fig.savefig(chart_path, format="png")
            plt.close(fig)

        # ------------------ tool result (1st line = CHART_PATH) -------------
        return (
            f"CHART_PATH={chart_path if chart_path else 'NONE'}\n")

        

    def get_next_node(last_message: BaseMessage, goto: str):
        if "FINAL ANSWER" in last_message.content:
            # Any agent decided the work is done
            return END
        return goto

    # Research agent and node
    research_agent = create_react_agent(
        llm,
        tools=[tavily_tool],
        prompt=make_system_prompt(
            "You can only do research. You are working with both a chart generator and a chart summarizer colleagues."
        ),
    )


    @instrument(
        span_type="RESEARCH_NODE",
        attributes=lambda ret, exception, *args, **kwargs: {
            f"{BASE_SCOPE}.research_node_input": args[0]["messages"][
                -1
            ].content,
            f"{BASE_SCOPE}.research_node_response": ret.update["messages"][
                -1
            ].content
            if hasattr(ret, "update")
            else json.dumps(ret, indent=4, sort_keys=True),
            f"{BASE_SCOPE}.tool_messages": [
                dumps(message)
                for message in ret.update["messages"]
                if isinstance(message, ToolMessage)
            ]
            if hasattr(ret, "update")
            else "No tool call",
        },
    )
    @instrument(
        span_type=SpanAttributes.SpanType.RETRIEVAL,
        attributes=lambda ret, exception, *args, **kwargs: {
            SpanAttributes.RETRIEVAL.QUERY_TEXT: args[0]["messages"][
                -1
            ].content,
            SpanAttributes.RETRIEVAL.RETRIEVED_CONTEXTS: [
                ret.update["messages"][-1].content
            ]
            if hasattr(ret, "update")
            else [json.dumps(ret, indent=4, sort_keys=True)],
        },
    )
    def research_agent_node(state: MessagesState) -> Command[Literal["chart_generator"]]:
        """
        This function represents the research agent node in the workflow.
        It selects tools based on the user's query and invokes the research agent
        to perform the research task. If tools are selected, they are bound to the
        LLM, and a new research agent is created with the selected tools. Otherwise,
        the default research agent is used.

        The function determines the next node in the workflow based on the content
        of the last message. If the message contains "FINAL ANSWER," the workflow
        ends. Otherwise, it transitions to the chart generator node.
        """
        selected_ids = state.get("selected_tools", [])
        if selected_ids:
            selected_tools = [tool_registry[tid] for tid in selected_ids]
            # Bind the selected tools to the LLM
            bound_llm = llm.bind_tools(selected_tools)
            # Create a research agent using the bound LLM and same prompt
            bound_research_agent = create_react_agent(
                bound_llm,
                tools=[],  # tools are already bound
                prompt=research_agent.prompt,
            )
            result = bound_research_agent.invoke(state)
        else:
            result = research_agent.invoke(state)
        goto = get_next_node(result["messages"][-1], "chart_generator")
        result["messages"][-1] = HumanMessage(content=result["messages"][-1].content, name="researcher")
        return Command(update={"messages": result["messages"]}, goto=goto)

    # Chart generator agent and node
    # NOTE: THIS PERFORMS ARBITRARY CODE EXECUTION, WHICH CAN BE UNSAFE WHEN NOT SANDBOXED
    chart_agent = create_react_agent(
        llm,
        [python_repl_tool],
        prompt=make_system_prompt(
            "You can only generate charts. The generated chart should be save at a local directory at current directory PATH './langgraph_saved_images_snowflaketools/v1' , and this PATH should be sent to your colleague. You are working with a chart summarizer colleague."
        ),
    )

    def extract_chart_path(text: str) -> str | None:
        """
        Returns the first CHART_PATH=… found in `text`, else None.
        """
        m = re.search(r"^CHART_PATH=(.+)$", text, flags=re.MULTILINE)
        return m.group(1).strip() if m else None

    @instrument(
        span_type="CHART_GENERATOR_NODE",
        attributes=lambda ret, exception, *args, **kwargs: {
            f"{BASE_SCOPE}.chart_node_input": args[0]["messages"][-1].content,
            f"{BASE_SCOPE}.chart_node_response": (
                ret.update["messages"][-1].content if ret and hasattr(ret, "update") and ret.update else "No update response"
            ),
        },
    )
    def chart_node(state: MessagesState) -> Command[Literal["chart_summarizer"]]:
        """
        This function represents the chart generation node in the workflow.
        It invokes the chart generation agent to create a chart based on the provided state.
        The generated chart is saved to a specified directory, and its path is extracted from the tool messages.
        A summary prompt is then prepared to send to the chart summarizer agent.
        If the chart path is not found, the workflow ends; otherwise, it transitions to the chart summarizer node.
        """
        # 1. let the agent run
        result = chart_agent.invoke(state)

        # 2. try to grab a path from any ToolMessage
        chart_path = None
        for msg in result["messages"]:
            if isinstance(msg, ToolMessage):
                chart_path = extract_chart_path(msg.content)
                if chart_path and chart_path.upper() != "NONE":
                    break

        print(f"Chart Path: {chart_path!r}")

        # 3. If no valid chart path, return to researcher with error message
        if not chart_path or chart_path.upper() == "NONE":
            result["messages"].append(
                HumanMessage(
                    content="Failed to generate chart. Please try again with different parameters.",
                    name="chart_generator"
                )
            )
            return Command(update={"messages": result["messages"]}, goto="researcher")
            
        # 4. Add chart path to messages for downstream nodes
        result["messages"].append(
            HumanMessage(
                content=f"CHART_PATH={chart_path}",
                name="chart_generator"
            )
        )

        # 5. Add summary prompt
        summary_prompt = (
            "Please summarise the chart in ≤ 3 sentences."
        )
        result["messages"].append(
            HumanMessage(name="chart_generator", content=summary_prompt)
        )

        return Command(update={"messages": result["messages"]}, goto="chart_summarizer")


    # Build the image captioning agent.
    # If you have any specific image processing tools (e.g., for extracting chart images),
    # you can add them in the tools list. For now, we leave it empty.
    chart_summary_agent = create_react_agent(
        llm,
        tools=[],  # Add image processing tools if available/needed.
        prompt=make_system_prompt(
                    """You can only generate charts with Python.
        ALWAYS:
        1. Save the figure as PNG to './langgraph_saved_images_snowflaketools/v2'.
        2. Do NOT display the image inline.
        3. End your reply with `CHART_PATH=<absolute-or-relative-path>`.
        You are working with a summariser colleague."""
        ),
    )

    reflection_prompt_template = PromptTemplate(
        input_variables=["user_query", "chart_summary"],
        template="""\
        You are an AI assistant tasked with reflecting on the quality of a chart summary. The user has asked the following question:
        "{user_query}"

        You are given the following chart summary:
        "{chart_summary}"

        Your task is to evaluate how well the chart summary answers the user's question. Consider the following:
        - Does the summary capture the **key insights** and trends from the chart, even if in a more general form?
        - Does it provide **adequate context** to address the user's query, even if it's not exhaustive?
        - If the summary provides some context but could benefit from more details, consider it sufficient for now unless significant details are missing.

        If the summary **generally** addresses the question, respond with 'Task complete'. If the summary **lacks significant** details or clarity, then respond with 'Refine summary'. Avoid being overly critical unless the summary completely misses key elements necessary to answer the query.

        Please provide your answer in a **concise and encouraging** manner.
        """
            )

    # Create the chain using the prompt template and the LLM (ChatOpenAI)
    reflection_chain = (
        reflection_prompt_template | llm
    )

    @instrument(
        span_type="CHART_SUMMARY_REFLECTION",
        attributes=lambda ret, exception, *args, **kwargs: {
            f"{BASE_SCOPE}.chart_summary_reflection_input_user_query": args[0],
            f"{BASE_SCOPE}.chart_summary_reflection_input_chart_summary": args[1],
            f"{BASE_SCOPE}.chart_summary_reflection_response": ret,
        },
)
    def perform_reflection(user_query: str, chart_summary: str) -> str:
        """
        This function uses an LLM to reflect on the quality of a chart summary
        and determine if the task is complete or requires further refinement.
        """
        # Call the chain with the user query and chart summary
        reflection_result = reflection_chain.invoke({"user_query": user_query, "chart_summary": chart_summary})
        return reflection_result.content


    @instrument(
        span_type="CHART_SUMMARY_NODE",
        attributes=lambda ret, exception, *args, **kwargs: {
            f"{BASE_SCOPE}.summary_node_input": args[0]["messages"][-1].content,
            f"{BASE_SCOPE}.summary_node_output": (
                ret.update["messages"][-1].content
                if hasattr(ret, "update") else "NO SUMMARY GENERATED"
            ),
        },
    )
    def chart_summary_node(state: MessagesState) -> Command[Literal[END]]:
        """
        This function represents the chart summarizer node in the workflow graph.
        It uses the chart summary agent to generate a concise summary for the chart image
        provided by the chart generator node. The summary is limited to three sentences
        and is based on the chart image saved at the specified local path.
        """
        # 1. Extract chart path from messages
        chart_path = None
        for msg in state["messages"]:
            if isinstance(msg, ToolMessage) and "CHART_PATH=" in msg.content:
                chart_path = extract_chart_path(msg.content)
                if chart_path and chart_path.upper() != "NONE":
                    break

        print(f"Chart Path in Chart Summary Node: {chart_path!r}")

        # 2. If no valid chart path, return to researcher with error message
        if not chart_path or chart_path.upper() == "NONE":
            return Command(
                update={
                    "messages": state["messages"] + [
                        HumanMessage(
                            content="No valid chart was generated. Please try again.",
                            name="chart_summarizer"
                        )
                    ]
                },
                goto="researcher"
            )

        # 3. Run the summarizer
        result = chart_summary_agent.invoke(state)
        if not result or "messages" not in result:
            return Command(
                update={
                    "messages": state["messages"] + [
                        HumanMessage(
                            content="Failed to generate chart summary. Please try again.",
                            name="chart_summarizer"
                        )
                    ]
                },
                goto="researcher"
            )

        # 4. Add reflection
        user_query = state["messages"][-2].content
        chart_summary = result["messages"][-1].content
        reflection = perform_reflection(user_query, chart_summary)

        # 5. Determine next node
        goto = END if "Task complete" in reflection or "FINAL ANSWER" in reflection else "researcher"
        result["messages"][-1] = HumanMessage(name="chart_summarizer", content=chart_summary)

        return Command(update={"messages": result["messages"]}, goto=goto)



    workflow = StateGraph(MessagesState)
    workflow.add_node("select_tools", select_tools)
    workflow.add_node("research_agent", research_agent_node)
    workflow.add_node("chart_generator", chart_node)
    workflow.add_node("chart_summarizer", chart_summary_node)

    # Update transitions: begin with tool selection then go to research agent.
    workflow.add_edge(START, "select_tools")
    workflow.add_edge("select_tools", "research_agent")
    workflow.add_edge("research_agent", "chart_generator")
    workflow.add_edge("chart_generator", "chart_summarizer")
    # workflow.add_edge("chart_summarizer", "select_tools")
    workflow.add_edge("chart_summarizer", END)
    graph = workflow.compile()

    return graph

In [None]:
import datetime

class TruAgent:
    """TruLens-instrumented wrapper around the web + Cortex Search graph."""

    def __init__(self):
        self.graph = build_graph_with_search()

    @instrument(
        span_type=SpanAttributes.SpanType.RECORD_ROOT,
        attributes={
            SpanAttributes.RECORD_ROOT.INPUT: "query",
            SpanAttributes.RECORD_ROOT.OUTPUT: "return",
        },
    )
    def invoke_agent_graph(self, query: str) -> str:
        """Run the graph on ``query`` and return the content of the final message.

        Each streamed event's first value is treated as that node's state update
        (presumably LangGraph's default "updates" stream mode). The last update
        carrying a non-empty ``messages`` list wins; updates without messages,
        such as ``select_tools``' ``{"selected_tools": ...}``, are skipped.

        Returns:
            The last message's ``content``, or ``""`` if no update carried
            messages (including an empty stream).
        """
        events = self.graph.stream(
            {
                "messages": [("user", query)],
            },
            # Maximum number of steps to take in the graph
            {"recursion_limit": 150},
        )

        messages = []
        for event in events:
            payload = next(iter(event.values()), None)
            if isinstance(payload, dict) and payload.get("messages"):
                messages = payload["messages"]

        return getattr(messages[-1], "content", "") if messages else ""

# Build the instrumented agent; TruAgent.__init__ builds the web + Cortex Search graph.
tru_agent = TruAgent()

# Register the agent with TruLens; invoke_agent_graph is the traced entry point.
tru_agent_app = TruApp(
    tru_agent,
    app_name=APP_NAME,
    app_version="doc and web search",
    connector=trulens_sf_connector,
    main_method=tru_agent.invoke_agent_graph,
)

# Formatted timestamp for this cell (`time` comes from the earlier `import time` cell).
# NOTE(review): st_1 is not read anywhere in the visible notebook.
st_1 = datetime.datetime.fromtimestamp(time.time()).strftime(
    "%Y-%m-%d %H:%M:%S"
)

# Run configuration: the DataFrame's "query" column feeds RECORD_ROOT.INPUT.
run_config = RunConfig(
    run_name="Multi-agent demo run - document and web search",
    description="this is a run with access to cortex search and tavily + qualitative caption",
    dataset_name="Research test dataset",
    source_type="DATAFRAME",
    label="langgraph demo",
    dataset_spec={
        "RECORD_ROOT.INPUT": "query",
    },
)

run: Run = tru_agent_app.add_run(run_config)

In [None]:
# Render the compiled graph as a PNG (draw_png needs pygraphviz; see the pip cell).
display(Image(tru_agent.graph.get_graph().draw_png()))

In [None]:
# Invoke the app on every row of user_queries_df as part of `run`.
run.start(
    input_df=user_queries_df
)  # note: if you use MFA, you will need to authenticate with the Duo prompt many times.

In [None]:
import time

# Poll every 3 seconds until TruLens has finished invoking the app on all inputs.
# NOTE(review): there is no timeout; if the status never leaves
# "INVOCATION_IN_PROGRESS" this loop never exits.
while run.get_status() == "INVOCATION_IN_PROGRESS":
    time.sleep(3)

# Compute the groundedness and context-relevance metrics for the finished run.
run.compute_metrics(["groundedness", "context_relevance"])

### Use a Cortex Agent to query structured SEC data without complicating the graph

In [16]:
def build_graph_with_agent():
    def make_system_prompt(suffix: str) -> str:
        return (
            "You are a helpful AI assistant, collaborating with other assistants."
            " Use the provided tools to progress towards answering the question."
            " If you are unable to fully answer, that's OK, another assistant with different tools "
            " will help where you left off. Execute what you can to make progress."
            " If you or any of the other assistants have the final answer or deliverable,"
            " prefix your response with FINAL ANSWER so the team knows to stop."
            f"\n{suffix}"
        )

    tavily_tool = TavilySearchResults(max_results=3)

    # Warning: This executes code locally, which can be unsafe when not sandboxed
    repl = PythonREPL()

    @tool(name="python_repl_tool")
    @instrument(
        span_type="PYTHON_REPL_TOOL",
        attributes={
            f"{BASE_SCOPE}.python_tool_input_code": "code",
        },
    )
    def python_repl_tool(code: str):
        """
        Executes code and saves the chart image even if stdout does not return a figure.
        """
        try:
            chart_uuid = str(uuid.uuid4())
            target_dir = "./langgraph_saved_images_snowflaketools/v1"
            os.makedirs(target_dir, exist_ok=True)
            chart_path = f"{target_dir}/chart_{chart_uuid}.png"

            # Execute user code
            exec_globals = {}
            exec_locals = {}
            exec(code, exec_globals, exec_locals)

            # After user code, forcibly save current matplotlib figure if exists
            import matplotlib.pyplot as plt
            fig = plt.gcf()  # get current figure
            if fig:
                fig.savefig(chart_path)
                print(f"Chart saved at {chart_path}")
            else:
                print("⚠️ No figure generated.")

            return f"✅ Code executed.\n🔗 CHART_PATH={chart_path}"

        except Exception as e:
            return f"❌ Error executing code: {e!r}"


    @instrument(
        span_type="WEB_RESEARCH_NODE",
        attributes=lambda ret, exception, *args, **kwargs: {
            f"{BASE_SCOPE}.research_node_input": args[0]["messages"][
                -1
            ].content,
            f"{BASE_SCOPE}.research_node_response": ret.update["messages"][
                -1
            ].content
            if hasattr(ret, "update")
            else json.dumps(ret, indent=4, sort_keys=True),
            f"{BASE_SCOPE}.tool_messages": [
                dumps(message)
                for message in ret.update["messages"]
                if isinstance(message, ToolMessage)
            ]
            if hasattr(ret, "update")
            else "No tool call",
        },
    )
    @instrument(
        span_type=SpanAttributes.SpanType.RETRIEVAL,
        attributes=lambda ret, exception, *args, **kwargs: {
            SpanAttributes.RETRIEVAL.QUERY_TEXT: args[0]["messages"][
                -1
            ].content,
            SpanAttributes.RETRIEVAL.RETRIEVED_CONTEXTS: [
                ret.update["messages"][-1].content
            ]
            if hasattr(ret, "update")
            else [json.dumps(ret, indent=4, sort_keys=True)],
        },
    )
    def web_research_node(
        state: MessagesState,
    ) -> Command[Literal["chart_generator"]]:
        result = research_agent.invoke(state)
        goto = get_next_node(result["messages"][-1], "chart_generator")
        # wrap in a human message, as not all providers allow
        # AI message at the last position of the input messages list
        result["messages"][-1] = HumanMessage(
            content=result["messages"][-1].content, name="researcher"
        )
        return Command(
            update={
                # share internal message history of research agent with other agents
                "messages": result["messages"],
            },
            goto=goto,
        )

    def get_next_node(last_message: BaseMessage, goto: str):
        if "FINAL ANSWER" in last_message.content:
            # Any agent decided the work is done
            return END
        return goto

    chart_agent = create_react_agent(
        llm,
        [python_repl_tool],
        prompt=make_system_prompt(
            "You can only generate charts. The generated chart should be save at a local directory at current directory PATH './langgraph_saved_images_snowflaketools/v3' , and this PATH should be sent to your colleague. You are working with a chart summarizer colleague."
        ),
    )

    @instrument(
        span_type="CHART_GENERATOR_NODE",
        attributes=lambda ret, exception, *args, **kwargs: {
            f"{BASE_SCOPE}.chart_node_input": args[0]["messages"][-1].content,
            f"{BASE_SCOPE}.chart_node_response": ret.update["messages"][
                -1
            ].content
            if hasattr(ret, "update")
            else json.dumps(ret, indent=4, sort_keys=True),
        },
    )
    def chart_node(
        state: MessagesState,
    ) -> Command[Literal["chart_summarizer"]]:
        result = chart_agent.invoke(state)
        goto = get_next_node(result["messages"][-1], "chart_summarizer")
        # wrap in a human message, as not all providers allow
        # AI message at the last position of the input messages list
        result["messages"][-1] = HumanMessage(
            content=result["messages"][-1].content, name="chart_generator"
        )
        return Command(
            update={
                # share internal message history of chart agent with other agents
                "messages": result["messages"],
            },
            goto=goto,
        )

    # Build the image captioning agent.
    # If you have any specific image processing tools (e.g., for extracting chart images),
    # you can add them in the tools list. For now, we leave it empty.
    chart_summary_agent = create_react_agent(
        llm,
        tools=[],  # Add image processing tools if available/needed.
        prompt=make_system_prompt(
            "You can only generate image captions. You are working with a researcher colleague and a chart generator colleague. "
            + "Your task is to generate a concise summary for the provided chart image saved at a local PATH, where the PATH should be and only be provided by your chart generator colleague. The summary should be no more than 3 sentences."
        ),
    )

    @instrument(
        span_type="CHART_SUMMARY_NODE",
        attributes=lambda ret, exception, *args, **kwargs: {
            f"{BASE_SCOPE}.summary_node_input": args[0]["messages"][-1].content,
            f"{BASE_SCOPE}.summary_node_output": ret.update["messages"][
                -1
            ].content
            if hasattr(ret, "update")
            else "NO SUMMARY GENERATED",
        },
    )
    def chart_summary_node(
        state: MessagesState,
    ) -> Command[Literal["researcher", END]]:
        result = chart_summary_agent.invoke(state)
        # Determine the next node based on the content of the last message
        goto = get_next_node(result["messages"][-1], "researcher")
        # Wrap the output message in a HumanMessage to maintain consistency in the conversation flow.
        result["messages"][-1] = HumanMessage(
            content=result["messages"][-1].content, name="chart_summarizer"
        )
        return Command(
            update={"messages": result["messages"]},
            goto=goto,
        )

    import requests

    @instrument(
        span_type="RESEARCH_NODE",
        attributes=lambda ret, exception, *args, **kwargs: {
            f"{BASE_SCOPE}.research_node_input": args[0]["messages"][
                -1
            ].content,
            f"{BASE_SCOPE}.research_node_response": ret.update["messages"][
                -1
            ].content
            if hasattr(ret, "update")
            else json.dumps(ret, indent=4, sort_keys=True),
            f"{BASE_SCOPE}.tool_messages": [
                dumps(message)
                for message in ret.update["messages"]
                if isinstance(message, ToolMessage)
            ]
            if hasattr(ret, "update")
            else "No tool call",
        },
    )
    @instrument(
        span_type=SpanAttributes.SpanType.RETRIEVAL,
        attributes=lambda ret, exception, *args, **kwargs: {
            SpanAttributes.RETRIEVAL.QUERY_TEXT: args[0]["messages"][
                -1
            ].content,
            SpanAttributes.RETRIEVAL.RETRIEVED_CONTEXTS: [
                ret.update["messages"][-1].content
            ]
            if hasattr(ret, "update")
            else [json.dumps(ret, indent=4, sort_keys=True)],
        },
    )
    def research_node(
        state: MessagesState,
    ) -> Command[Literal["researcher", END]]:
        # Extract the user query from the state.
        user_query = state["messages"][0].content if state["messages"] else ""
        # Prepare the payload for the Cortex Agents API.
        payload = {
            "model": "gpt-4.1-mini",
            "response_instruction": "You are a helpful AI assistant, collaborate with colleagues if needed.",
            "experimental": {},
            "tools": [
                {
                    "tool_spec": {
                        "type": "cortex_analyst_text_to_sql",
                        "name": "Analyst1",
                    }
                },
                {"tool_spec": {"type": "cortex_search", "name": "Search1"}},
            ],
            "tool_resources": {
                "Analyst1": {
                    "semantic_model_file": "@agents_db.notebooks.semantic_models.sec_filings.yaml"
                },
                "Search1": {
                    "name": "CORTEX_SEARCH_TUTORIAL_DB.PUBLIC.FOMC_SEARCH_SERVICE"
                },
            },
            "tool_choice": {"type": "auto"},
            "messages": [
                {
                    "role": "user",
                    "content": [{"type": "text", "text": user_query}],
                }
            ],
        }
        # Define the API endpoint and headers.
        api_url = "http://SFDEVREL-SFDEVREL_ENTERPRISE.snowflakecomputing.com/api/v2/cortex/agent:run"
        headers = {
            "Authorization": os.environ.get("SNOWFLAKE_JWT"),
            "Content-Type": "application/json",
        }
        try:
            response = requests.post(api_url, json=payload, headers=headers)
            if response.status_code != 200:
                result_text = f"Failed Cortex Agents API call with status code {response.status_code}: {response.text}"
            else:
                data = response.json()
                # Aggregate the text content from the response's delta message.
                contents = data.get("delta", {}).get("content", [])
                result_parts = [
                    chunk.get("text", "")
                    for chunk in contents
                    if chunk.get("type") == "text"
                ]
                result_text = " ".join(result_parts).strip() or str(data)
        except Exception as e:
            result_text = f"Exception during API call: {str(e)}"
        from langchain_core.messages import HumanMessage

        result_message = HumanMessage(content=result_text, name="researcher")
        goto = (
            "researcher"
            if "FINAL ANSWER" not in result_message.content
            else END
        )
        return Command(
            update={"messages": state["messages"] + [result_message]},
            goto=goto,
        )

    # Build the workflow graph
    workflow = StateGraph(MessagesState)

    # Add all nodes
    workflow.add_node("researcher", research_node)
    workflow.add_node("web_researcher", web_research_node)
    workflow.add_node("chart_generator", chart_node)
    workflow.add_node("chart_summarizer", chart_summary_node)

    # Entry point
    workflow.set_entry_point("researcher")

    # Parallel or conditional edge from researcher
    workflow.add_edge(
        "researcher", "web_researcher"
    )  # optional: could be conditional
    workflow.add_edge("researcher", "chart_generator")

    # Web researcher supplements data before chart
    workflow.add_edge("web_researcher", "chart_generator")

    # After chart is generated, summarize it
    workflow.add_edge("chart_generator", "chart_summarizer")

    # End of flow
    workflow.add_edge("chart_summarizer", END)

    graph = workflow.compile()
    return graph

In [None]:
class TruAgent:
    """Thin wrapper that exposes the LangGraph workflow to TruLens tracing.

    The compiled graph is built once, in ``__init__``, via
    ``build_graph_with_agent()`` and reused for every query passed to
    :meth:`invoke_agent_graph`.
    """

    def __init__(self):
        self.graph = build_graph_with_agent()

    @instrument(
        span_type=SpanAttributes.SpanType.RECORD_ROOT,
        attributes={
            SpanAttributes.RECORD_ROOT.INPUT: "query",
            SpanAttributes.RECORD_ROOT.OUTPUT: "return",
        },
    )
    def invoke_agent_graph(self, query: str) -> str:
        """Stream ``query`` through the agent graph and return the final text.

        Args:
            query: The user's question; it is sent as the initial
                ``("user", query)`` message of the graph state.

        Returns:
            The ``content`` of the last message in the most recent streamed
            update that carried a ``"messages"`` list, or ``""`` when no such
            update was produced (or the last message has no ``content``).
        """
        events = self.graph.stream(
            {
                "messages": [("user", query)],
            },
            # Maximum number of steps to take in the graph
            {"recursion_limit": 150},
        )

        # Bound before the loop so an empty stream returns "" instead of
        # raising UnboundLocalError at the return statement.
        messages = []
        for event in events:
            # Take the first (presumably only) node update in this event,
            # mirroring the original `list(event.values())[0]`.
            update = next(iter(event.values()), None)
            # Skip updates that carry no message list (e.g. a node that
            # returned nothing) rather than crashing mid-stream.
            if not update or "messages" not in update:
                continue
            messages = update["messages"]

        return (
            messages[-1].content
            if messages and hasattr(messages[-1], "content")
            else ""
        )


tru_agent = TruAgent()

# Register the agent with TruLens; invoke_agent_graph is the traced entry point.
tru_agent_app = TruApp(
    tru_agent,
    main_method=tru_agent.invoke_agent_graph,
    app_name=APP_NAME,
    app_version="cortex agent + web search",
    connector=trulens_sf_connector,
)

# Local wall-clock timestamp at second resolution.
st_1 = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")

# Dataframe-sourced run; dataset_spec presumably maps the input dataframe's
# "query" column onto the RECORD_ROOT input attribute.
run_config = RunConfig(
    run_name="Multi-agent demo run - cortex agent + web search + charting",
    description="this is a run with access to cortex agent, with internally uses cortex search and analyst as tools",
    label="langgraph demo",
    source_type="DATAFRAME",
    dataset_name="Research test dataset",
    dataset_spec={"RECORD_ROOT.INPUT": "query"},
)

run: Run = tru_agent_app.add_run(run_config)

In [None]:
import pandas as pd

# Evaluation prompts: each one asks the agent to research a question,
# produce a chart, and then summarize that chart.
user_queries = [
    "In 2023, how did the fed funds rate fluctuate? What were the key drivers? Create a line chart that best illustrates this data, including a caption with the key drivers. Once the chart is generated, summarize the chart, and finish.",
    "What is the total market value of securities by asset class according to SEC filings? Create a bar chart that best illustrates this data. Once the chart is generated, summarize the chart, and finish.",
    "Who are the top 10 filing managers by number of holdings in the most recent reporting quarter according to SEC filings? Create a bar chart that best illustrates this data. Once the chart is generated, summarize the chart, and finish.",
]

# Single "query" column, one row per prompt.
user_queries_df = pd.DataFrame({"query": user_queries})

# note: if you use MFA, you will need to authenticate with the Duo prompt many times.
run.start(input_df=user_queries_df)

In [None]:
import time

# Poll until the run leaves the invocation phase, then compute metrics.
# The wait is bounded so a stalled invocation cannot hang the notebook forever.
POLL_INTERVAL_SECONDS = 3
MAX_WAIT_SECONDS = 60 * 60  # generous upper bound; raise it for larger datasets

_deadline = time.monotonic() + MAX_WAIT_SECONDS
while run.get_status() == "INVOCATION_IN_PROGRESS":
    if time.monotonic() >= _deadline:
        raise TimeoutError(
            f"Run still INVOCATION_IN_PROGRESS after {MAX_WAIT_SECONDS} seconds"
        )
    time.sleep(POLL_INTERVAL_SECONDS)

run.compute_metrics(["groundedness", "context_relevance"])