# Multi-agent network with Snowflake tools for querying unstructured and structured data

Adapted from the original [Langgraph multi-agent notebook example](https://github.com/langchain-ai/langgraph/blob/main/docs/docs/tutorials/multi_agent/multi-agent-collaboration.ipynb)

A single agent can usually operate effectively using a handful of tools within a single domain, but even with powerful models like `gpt-4`, it can become less effective when it has to choose among many tools.

This notebook extends the multi-agent-collaboration notebook, showing how access to more tools, particularly tools over private data, can enhance the ability of a data agent.

We will gradually build up the agent with more tools: starting with web search, then adding document search via Cortex Search, and lastly replacing document search with a Cortex Agent that can both search documents and query Snowflake tables in SQL via Cortex Analyst.

We also make some useful improvements to the agentic flow in this notebook to handle the more complex set of tools: namely a reflection loop and safe exit. These improvements dramatically improve the efficiency of the agent.


In [None]:
%%capture --no-stderr
# pip install -U langchain_community langchain_openai langchain_experimental matplotlib langgraph pygraphviz google-search-results

In [None]:
# Name under which the agent app is registered; set this for your use case.
APP_NAME = "Finance Data and Research Agent"

## Set keys

In [None]:
import os

# All credentials and connection settings used by this notebook are supplied
# through environment variables. The "..." values are placeholders that must
# be replaced before running.
# NOTE(review): these assignments overwrite any value already present in the
# environment; avoid committing real secrets in this cell.

# need both API keys
os.environ["OPENAI_API_KEY"] = "sk-proj-..."
os.environ["SERPAPI_API_KEY"] = "..."

# Snowflake connection settings, read back when the Snowpark session is built.
os.environ["SNOWFLAKE_ACCOUNT"] = "SFDEVREL_ENTERPRISE"
os.environ["SNOWFLAKE_USER"] = "JREINI"
os.environ["SNOWFLAKE_USER_PASSWORD"] = "..."
os.environ["SNOWFLAKE_DATABASE"] = "AGENTS_DB"
os.environ["SNOWFLAKE_SCHEMA"] = "NOTEBOOKS"
os.environ["SNOWFLAKE_ROLE"] = "CORTEX_USER_ROLE"
os.environ["SNOWFLAKE_WAREHOUSE"] = "CONTAINER_RUNTIME_WH"
os.environ["SNOWFLAKE_PAT"] = "..."

os.environ["TRULENS_OTEL_TRACING"] = (
    "1"  # to enable OTEL tracing -> note the Snowsight UI experience for now is limited to PuPr customers, not yet supported for OSS.
)

## Import libraries

In [None]:
from copy import deepcopy
import datetime
import json
import os
import sys
import time
from typing import List, Literal
import uuid

from IPython.display import Image
from IPython.display import display
from langchain.load.dump import dumps
from langchain.prompts import PromptTemplate
from langchain_community.utilities import SerpAPIWrapper
from langchain_core.documents import Document
from langchain_core.messages import AIMessage
from langchain_core.messages import BaseMessage
from langchain_core.messages import HumanMessage
from langchain_core.messages import SystemMessage
from langchain_core.messages import ToolMessage
from langchain_core.tools import StructuredTool
from langchain_core.tools import Tool
from langchain_core.tools import tool
from langchain_core.vectorstores import InMemoryVectorStore
from langchain_experimental.utilities import PythonREPL
from langchain_openai import ChatOpenAI
from langchain_openai import OpenAIEmbeddings
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import MessagesState
from langgraph.graph import StateGraph
from langgraph.managed.is_last_step import RemainingSteps
from langgraph.prebuilt import create_react_agent
from langgraph.types import Command
from pydantic import BaseModel
from snowflake.snowpark import Session
from trulens.apps.app import TruApp
from trulens.connectors.snowflake import SnowflakeConnector
from trulens.core.otel.instrument import instrument
from trulens.core.run import Run
from trulens.core.run import RunConfig
from trulens.otel.semconv.trace import BASE_SCOPE
from trulens.otel.semconv.trace import SpanAttributes

## Create TruLens/Snowflake Connection

In [None]:
# Snowflake connection used by TruLens: each Snowpark option is read from the
# environment variable set in the "Set keys" cell.
snowflake_connection_parameters = {
    option: os.environ[env_name]
    for option, env_name in (
        ("account", "SNOWFLAKE_ACCOUNT"),
        ("user", "SNOWFLAKE_USER"),
        ("password", "SNOWFLAKE_USER_PASSWORD"),
        ("database", "SNOWFLAKE_DATABASE"),
        ("schema", "SNOWFLAKE_SCHEMA"),
        ("role", "SNOWFLAKE_ROLE"),
        ("warehouse", "SNOWFLAKE_WAREHOUSE"),
    )
}

# Open the Snowpark session and hand it to the TruLens Snowflake connector.
snowpark_session_trulens = Session.builder.configs(
    snowflake_connection_parameters
).create()

trulens_sf_connector = SnowflakeConnector(
    snowpark_session=snowpark_session_trulens
)

### Define the agent with web search and charting tools

In [None]:
class ToolState(MessagesState):
    # Graph state shared by every node; extends MessagesState with the fields
    # below.

    # IDs (tool-registry keys) of the tools picked by the tool-selection node.
    selected_tools: List[str]
    # Path of the chart image produced by the chart generator node.
    chart_path: str
    # Step budget the nodes compare against to bail out early.
    # NOTE(review): presumably filled in by LangGraph as a managed value —
    # confirm against the LangGraph version in use.
    remaining_steps: RemainingSteps


def build_graph():
    def canned_end_node(state: ToolState) -> Command[Literal["__end__"]]:
        """Safe-exit node: log that it ran and route straight to END."""
        print("starting CANNED END", flush=True)
        finish = Command(goto=END)
        return finish

    def make_system_prompt(suffix: str) -> str:
        """Return the shared collaboration prompt followed by ``suffix`` on a new line."""
        shared_instructions = "".join([
            "You are a helpful AI assistant, collaborating with other assistants.",
            " Use the provided tools to progress towards answering the question.",
            " If you are unable to fully answer, that's OK, another assistant with different tools ",
            " will help where you left off. Execute what you can to make progress.",
            " If you or any of the other assistants have the final answer or deliverable,",
            " prefix your response with FINAL ANSWER so the team knows to stop.",
        ])
        # The agent-specific instructions always go on their own line.
        return f"{shared_instructions}\n{suffix}"

    # Web search tool: wraps SerpAPIWrapper.run as a LangChain Tool.
    search = SerpAPIWrapper()

    search_tool = Tool(
        name="web_search",
        description="Search the web for current information, such as weather or news",
        func=search.run,
    )

    # Register the tool under a UUID. The same ID is used as the Document ID so
    # a vector-store hit can be mapped back to the tool object.
    tool_id = str(uuid.uuid4())
    tool_registry = {tool_id: search_tool}

    # One Document per registered tool. Each field is taken from the tool being
    # iterated (not from `search_tool`), so the index stays correct if more
    # tools are added to the registry. The loop names are distinct from the
    # imported `tool` decorator and the outer `tool_id`.
    tool_documents = [
        Document(
            page_content=registered_tool.name,  # text that gets embedded
            id=registered_id,  # must match the registry key
            metadata={
                "tool_name": registered_tool.name,
                "tool_description": registered_tool.description,
            },
        )
        for registered_id, registered_tool in tool_registry.items()
    ]
    vector_store = InMemoryVectorStore(embedding=OpenAIEmbeddings())
    vector_store.add_documents(tool_documents)

    @instrument(
        span_type="SELECT_TOOLS",
        attributes=lambda ret, exc, *args, **kw: {
            # ---- state as JSON-text (OTLP needs a scalar) -----------------
            f"{BASE_SCOPE}.select_tools_input_state": json.dumps(  # dict -> str
                {
                    **{k: v for k, v in args[0].items() if k != "messages"},
                    "messages": [
                        {"type": m.__class__.__name__, "content": m.content}
                        if hasattr(m, "content")  # BaseMessage subclasses
                        else m  # already JSON-friendly
                        for m in args[0].get("messages", [])
                    ],
                }
            ),
            # ---- selected tool IDs as a simple comma-separated string -----
            f"{BASE_SCOPE}.selected_tool_ids": ", ".join(
                ret.get("selected_tools", [])
            )
            if isinstance(ret, dict)
            else "",
        },
    )
    def select_tools(
        state: ToolState,
    ) -> Command[Literal["research_agent", "canned_end", "__end__"]]:
        """Pick the tools relevant to the latest message and route onward.

        Routes to ``canned_end`` when the remaining step budget is 2 or less,
        to END (with an apology message) when no tool reaches the similarity
        threshold, and otherwise to ``research_agent`` with the selected tool
        IDs stored in ``selected_tools``.

        The return annotation lists every destination this node can route to,
        including ``canned_end``, which the original annotation omitted.
        """
        print(state["remaining_steps"])
        if state["remaining_steps"] <= 2:
            print("Bailing out", flush=True)
            return Command(
                update={"messages": state["messages"]},
                goto="canned_end",
            )
        messages = state["messages"]
        last = messages[-1]
        # Messages may be plain dicts or message objects.
        query = last["content"] if isinstance(last, dict) else last.content
        print("selecting tools based on", query)

        # 1. pull top-k with their scores
        results: list[tuple[Document, float]] = (
            vector_store.similarity_search_with_score(
                query,
                k=5,  # look at top-5 candidates
            )
        )
        print("tool search results", results, flush=True)

        # 2. keep only candidates at or above the minimum similarity score
        MIN_SIMILARITY = 0.6
        filtered = [doc for doc, score in results if score >= MIN_SIMILARITY]

        # 3a. no sufficiently similar tool: end with an explanatory message
        if not filtered:
            print("no tool selected", flush=True)
            msg = HumanMessage(
                content="Sorry, I don’t have a tool that’s relevant enough to answer that.",
                name="assistant",
            )
            return Command(
                update={"messages": messages + [msg]},
                goto=END,
            )

        # 3b. otherwise select those tools and move on
        selected_ids = [doc.id for doc in filtered]
        print("tools selected", selected_ids, flush=True)
        return Command(
            update={
                "selected_tools": selected_ids,
            },
            goto="research_agent",
        )

    # Warning: This executes code locally, which can be unsafe when not sandboxed

    # REPL that python_repl_tool uses to execute the generated plotting code.
    repl = PythonREPL()

    # Chat model shared by the agents and the direct LLM calls in this graph.
    llm = ChatOpenAI(model="gpt-4o")

    @tool
    @instrument(
        span_type="PYTHON_REPL_TOOL",
        attributes={
            f"{BASE_SCOPE}.python_tool_input_code": "code",
        },
    )
    def python_repl_tool(code: str):
        """
        Run arbitrary Python, grab the CURRENT matplotlib figure (if any),
        save it to ./langgraph_saved_images_snowflaketools/v1/chart_<uuid>.png,
        and return a first-line `CHART_PATH=…`.
        """
        # NOTE(review): the docstring above is kept verbatim because @tool
        # presumably exposes it as the tool description — confirm before editing.
        import os
        import uuid

        import matplotlib

        # Select the Agg backend before pyplot is imported.
        matplotlib.use("Agg")
        import matplotlib.pyplot as plt

        # Execute the supplied code; whatever repl.run returns is not used.
        repl.run(code)

        figure = plt.gcf()
        if not figure.axes:
            # Nothing was drawn: report the sentinel instead of a file path.
            return "CHART_PATH=NONE\n"

        # Something was drawn: save it under a unique name and close the figure.
        output_dir = "./langgraph_saved_images_snowflaketools/v1"
        os.makedirs(output_dir, exist_ok=True)
        chart_file = os.path.join(output_dir, f"chart_{uuid.uuid4().hex}.png")
        print(chart_file, flush=True)
        figure.savefig(chart_file, format="png")
        plt.close(figure)
        return f"CHART_PATH={chart_file}\n"

    def get_next_node(last_message: BaseMessage, goto: str):
        """Return END when the message contains FINAL ANSWER, otherwise ``goto``."""
        # Any agent may declare the work finished by including the marker.
        work_is_done = "FINAL ANSWER" in last_message.content
        return END if work_is_done else goto

    @instrument(
        span_type="RESEARCH_NODE",
        attributes=lambda ret, exception, *args, **kwargs: {
            f"{BASE_SCOPE}.research_node_input": args[0]["messages"][
                -1
            ].content,
            f"{BASE_SCOPE}.research_node_response": ret.update["messages"][
                -1
            ].content
            if hasattr(ret, "update")
            else json.dumps(ret, indent=4, sort_keys=True),
            f"{BASE_SCOPE}.tool_messages": [
                dumps(message)
                for message in ret.update["messages"]
                if isinstance(message, ToolMessage)
            ]
            if hasattr(ret, "update")
            else "No tool call",
        },
    )
    @instrument(
        span_type=SpanAttributes.SpanType.RETRIEVAL,
        attributes=lambda ret, exception, *args, **kwargs: {
            SpanAttributes.RETRIEVAL.QUERY_TEXT: args[0]["messages"][
                -1
            ].content,
            SpanAttributes.RETRIEVAL.RETRIEVED_CONTEXTS: [
                msg.content
                for msg in ret.update["messages"]
                if isinstance(msg, ToolMessage) and msg.content
            ]
            if hasattr(ret, "update") and "messages" in ret.update
            else [],
        },
    )
    def research_agent_node(
        state: ToolState,
    ) -> Command[Literal["chart_generator", "canned_end", "__end__"]]:
        """Run a ReAct agent over the selected tools and route onward.

        Routes to ``canned_end`` when the remaining step budget is 2 or less,
        to END when the agent's last message contains FINAL ANSWER, and
        otherwise to ``chart_generator``.

        The return annotation lists every destination this node can route to;
        the original annotation listed only ``chart_generator``.
        """
        print(state["remaining_steps"])
        if state["remaining_steps"] <= 2:
            print("Bailing out", flush=True)
            return Command(
                update={"messages": state["messages"]},
                goto="canned_end",
            )
        sys.__stdout__.write("🔍 [research_agent_node] start\n")
        sys.__stdout__.flush()

        # 1) resolve the selected tool IDs to tool objects and build the agent
        selected_tools = [tool_registry[tid] for tid in state["selected_tools"]]
        bound_llm = llm.bind_tools(selected_tools)
        bound_agent = create_react_agent(
            bound_llm,
            tools=selected_tools,
            prompt=make_system_prompt("You can only do research…"),
        )

        sys.__stdout__.write("  ⏳ invoking bound_agent.invoke()\n")
        sys.__stdout__.flush()
        result = bound_agent.invoke(state)

        # 2) debug dump of every message the agent produced
        sys.__stdout__.write("  📬 raw research messages:\n")
        for m in result["messages"]:
            sys.__stdout__.write(
                f"    [{m.__class__.__name__}] {getattr(m, 'content', m)!r}\n"
            )
        sys.__stdout__.flush()

        # 3) strip out any tool_calls on those messages
        clean_messages = []
        for msg in result["messages"]:
            # deep-copy so the agent's own result objects are not mutated
            m = deepcopy(msg)
            if hasattr(m, "tool_calls"):
                m.tool_calls = []
            clean_messages.append(m)

        # 4) routing: END on FINAL ANSWER, otherwise on to the chart generator
        last = clean_messages[-1]
        goto = get_next_node(last, "chart_generator")
        sys.__stdout__.write(f"  ➡ next goto = {goto}\n\n")
        sys.__stdout__.flush()

        # 5) tag the final message as coming from the research agent
        clean_messages[-1] = HumanMessage(
            content=last.content, name="research_agent"
        )

        return Command(
            update={
                "messages": clean_messages,
                "chart_path": state.get("chart_path", ""),
            },
            goto=goto,
        )

    # Chart generator agent and node
    # NOTE: THIS PERFORMS ARBITRARY CODE EXECUTION, WHICH CAN BE UNSAFE WHEN NOT SANDBOXED
    # 1) Define the chart‐agent: it only returns JSON with a "code" field
    # The agent's only tool is python_repl_tool; chart_node later scans the
    # agent's new ToolMessages for that tool's CHART_PATH= output.
    chart_agent = create_react_agent(
        llm,
        [python_repl_tool],
        prompt=make_system_prompt(
            """You can only generate charts by returning a single JSON object, for example:
        {
        "code": "<your python plotting code here>"
        }
        —where <your python plotting code> uses matplotlib to create exactly one figure.
        Do NOT include any prose or tool‐call wrappers."""
        ),
    )

    def extract_chart_path(text: str) -> str | None:
        """
        Scan every line of tool stdout for 'CHART_PATH=' and return
        whatever follows, trimmed.  Returns None if no such line exists.
        """
        for line in text.splitlines():
            if "CHART_PATH=" in line:
                # split on the first '=', strip whitespace
                return line.split("CHART_PATH=", 1)[1].strip()
        return None

    @instrument(
        span_type="CHART_GENERATOR_NODE",
        attributes=lambda ret, exception, *args, **kwargs: {
            f"{BASE_SCOPE}.chart_node_input": args[0]["messages"][-1].content,
            f"{BASE_SCOPE}.chart_node_response": (
                ret.update["messages"][-1].content
                if ret and hasattr(ret, "update") and ret.update
                else "No update response"
            ),
        },
    )
    def chart_node(
        state: ToolState,
    ) -> Command[Literal["chart_summarizer", "research_agent", "canned_end"]]:
        """Run the chart agent once and record the path of the saved chart.

        Routes to ``canned_end`` when the remaining step budget is 2 or less,
        back to ``research_agent`` when the agent produced no usable chart
        (no ``CHART_PATH=`` tool output, or the ``NONE`` sentinel that
        python_repl_tool emits when nothing was drawn), and otherwise to
        ``chart_summarizer`` with ``chart_path`` set.

        The return annotation lists every destination this node can route to.
        """
        print(state["remaining_steps"])
        if state["remaining_steps"] <= 2:
            print("Bailing out", flush=True)
            return Command(
                update={"messages": state["messages"]},
                goto="canned_end",
            )

        # extract the current query (the latest message in state)
        current_query = state["messages"][-1].content

        # if we already generated a chart for _this_ query, skip
        # NOTE(review): "last_query" is not declared on ToolState and is never
        # part of a returned update, so this branch presumably never fires —
        # confirm before relying on it.
        if state.get("last_query") == current_query and state.get("chart_path"):
            print(
                f"⚡️ skipping chart_node, existing path = {state['chart_path']}",
                flush=True,
            )
            return Command(
                update={"messages": state["messages"]}, goto="chart_summarizer"
            )

        # it's a new query (or first run): clear any old chart_path and remember this query
        state.pop("chart_path", None)
        state["last_query"] = current_query

        # 1) Remember how many messages we had
        len_before = len(state["messages"])

        # 2) Run the agent exactly once
        agent_out = chart_agent.invoke(state)

        print(agent_out, flush=True)
        all_msgs = agent_out["messages"]

        # 3) Look at only the brand-new messages for our chart tool output
        new_segment = all_msgs[len_before:]
        tool_msgs = [
            m
            for m in new_segment
            if isinstance(m, ToolMessage) and "CHART_PATH=" in m.content
        ]

        # 4) Parse the last marker in case there are multiples
        tool_msg = None
        chart_path = None
        if tool_msgs:
            tool_msg = tool_msgs[-1]
            tool_stdout = tool_msg.content
            print(f"chart_node 🖨 tool_stdout:\n{tool_stdout}", flush=True)
            chart_path = extract_chart_path(tool_stdout)
            print(
                f"chart_node 📂 parsed chart_path = {chart_path!r}", flush=True
            )

        # 5) No marker, an empty path, or the "NONE" sentinel all mean that no
        #    chart file exists, so hand control back to the research agent
        #    instead of asking the summarizer to describe a missing image.
        #    The step budget was already checked on entry and cannot change
        #    inside this node, so it is not re-checked here.
        if not chart_path or chart_path == "NONE":
            print(
                "⚠️ chart_node: no chart was produced, retrying",
                flush=True,
            )
            return Command(
                update={"messages": state["messages"]},
                goto="research_agent",
            )

        # 6) Success: keep the prior messages, add only the chart ToolMessage,
        #    then append the CHART_PATH marker and stash the path in state.
        new_msgs = state["messages"][:] + [tool_msg]
        new_msgs.append(
            HumanMessage(
                content=f"CHART_PATH={chart_path}", name="chart_generator"
            )
        )
        return Command(
            update={"messages": new_msgs, "chart_path": chart_path},
            goto="chart_summarizer",
        )

    # Prompt used to judge a chart summary against the user's question. The
    # summarizer node looks for the phrase 'Task complete' (case-insensitively)
    # in the model's reply to decide whether to finish or retry.
    reflection_prompt_template = PromptTemplate(
        input_variables=["user_query", "chart_summary"],
        template="""\
        You are an AI assistant tasked with reflecting on the quality of a chart summary. The user has asked the following question:
        "{user_query}"

        You are given the following chart summary:
        "{chart_summary}"

        Your task is to evaluate how well the chart summary answers the user's question. Consider the following:
        - Does the summary capture the **key insights** and trends from the chart, even if in a more general form?
        - Does it provide **adequate context** to address the user's query, even if it's not exhaustive?
        - If the summary provides some context but could benefit from more details, consider it sufficient for now unless significant details are missing.

        If the summary **generally** addresses the question, respond with 'Task complete'. If the summary **lacks significant** details or clarity, then respond with specific details on how the answer should be improved and what information is needed. Avoid being overly critical unless the summary completely misses key elements necessary to answer the query.

        Please provide your answer in a **concise and encouraging** manner.
        """,
    )

    # Create the chain using the prompt template and the LLM (ChatOpenAI)
    reflection_chain = reflection_prompt_template | llm

    @instrument(
        span_type="CHART_SUMMARY_REFLECTION",
        attributes=lambda ret, exception, *args, **kwargs: {
            f"{BASE_SCOPE}.chart_summary_reflection_input_user_query": args[0],
            f"{BASE_SCOPE}.chart_summary_reflection_input_chart_summary": args[
                1
            ],
            f"{BASE_SCOPE}.chart_summary_reflection_response": ret,
        },
    )
    def perform_reflection(user_query: str, chart_summary: str) -> str:
        """
        Ask the reflection chain whether ``chart_summary`` adequately answers
        ``user_query`` and return the text of its reply.
        """
        print("doing reflection...")
        # Fill both template variables and take the reply's text content.
        chain_inputs = {
            "user_query": user_query,
            "chart_summary": chart_summary,
        }
        verdict = reflection_chain.invoke(chain_inputs).content
        print("reflection_result: ", verdict)
        return verdict

    @instrument(
        span_type="CHART_SUMMARY_NODE",
        attributes=lambda ret, exception, *args, **kwargs: {
            f"{BASE_SCOPE}.summary_node_input": args[0]["messages"][-1].content,
            f"{BASE_SCOPE}.summary_node_output": (
                ret.update["messages"][-1].content
                if hasattr(ret, "update")
                else "NO SUMMARY GENERATED"
            ),
        },
    )
    def chart_summary_node(
        state: ToolState,
    ) -> Command[Literal["select_tools", "canned_end", "__end__"]]:
        """Summarise the generated chart, reflect on the summary, and route.

        Routes to ``canned_end`` when the remaining step budget is 2 or less,
        back to ``select_tools`` when no chart path is in state or when the
        reflection does not say 'task complete', and to END otherwise.

        The return annotation lists every destination this node can route to;
        the original annotation listed only END.
        """
        print("▶️ entering chart_summary_node", flush=True)
        print(state["remaining_steps"])
        if state["remaining_steps"] <= 2:
            print("Bailing out", flush=True)
            return Command(
                update={"messages": state["messages"]},
                goto="canned_end",
            )

        # 1) find the chart_path in state
        chart_path = state.get("chart_path", "")
        print(f"  using state.chart_path = {chart_path!r}", flush=True)
        if not chart_path:
            return Command(
                update={
                    "messages": state["messages"]
                    + [
                        HumanMessage(
                            "No valid chart was generated. Please try again.",
                            name="chart_summarizer",
                        )
                    ]
                },
                goto="select_tools",
            )

        # 2) strip *everything* except human utterances
        human_history = [
            m for m in state["messages"] if isinstance(m, HumanMessage)
        ]

        # ensure our CHART_PATH marker is last
        if not human_history or not human_history[-1].content.startswith(
            "CHART_PATH="
        ):
            human_history.append(
                HumanMessage(
                    f"CHART_PATH={chart_path}", name="chart_summarizer"
                )
            )

        print(
            "  human_history:", [m.content for m in human_history], flush=True
        )

        # 3) build the chat prompt
        # NOTE(review): only the CHART_PATH text is passed to the model here;
        # the image file itself is not loaded or attached — confirm that is
        # intended.
        system = SystemMessage(
            content=make_system_prompt(
                "You are an AI assistant whose *only* job is to summarise a chart image. "
                "Input is a message CHART_PATH=… pointing at a saved PNG. "
                "Output a concise (≤3 sentences) summary of the key trends."
            )
        )

        messages_for_llm = (
            [system]
            + human_history
            + [
                HumanMessage(
                    "Please summarise the above chart in ≤3 sentences."
                )
            ]
        )

        # 4) call the LLM directly: no tools, no ReAct agent. `.invoke` is used
        #    instead of calling the model object, matching the other LLM and
        #    chain calls in this file.
        print("📝 calling ChatOpenAI directly for summary", flush=True)
        ai_msg: AIMessage = llm.invoke(messages_for_llm)
        summary = ai_msg.content
        print(f"📋 chart summary: {summary!r}", flush=True)

        # 5) reflect on the summary against the first message in state
        user_query = state["messages"][0].content
        print("🔍 reflecting on summary quality", flush=True)
        reflection = perform_reflection(user_query, summary)
        clean_ref = reflection.strip().lower()
        print(f"💡 reflection: {reflection!r}", flush=True)

        # 6) decide where to go
        if "task complete" in clean_ref:
            print("✅ done", flush=True)
            return Command(
                update={
                    "messages": state["messages"]
                    + [HumanMessage(summary, name="chart_summarizer")]
                },
                goto=END,
            )
        else:
            print("🔁 need to retry", flush=True)
            return Command(
                update={
                    "messages": state["messages"]
                    + [
                        HumanMessage(summary, name="chart_summarizer"),
                        HumanMessage(reflection, name="chart_reflection"),
                    ]
                },
                goto="select_tools",
            )

    # Assemble the graph: one node per step, plus the safe-exit node.
    workflow = StateGraph(ToolState)
    workflow.add_node("select_tools", select_tools)
    workflow.add_node("research_agent", research_agent_node)
    workflow.add_node("chart_generator", chart_node)
    workflow.add_node("chart_summarizer", chart_summary_node)
    workflow.add_node("canned_end", canned_end_node)

    # Update transitions: begin with tool selection then go to research agent.
    workflow.add_edge(START, "select_tools")

    # NOTE(review): every node below also routes via Command(goto=...). The
    # static edges presumably still fire in addition to the Command
    # destination (e.g. research_agent would run even when select_tools
    # returns goto=END) — confirm against the LangGraph Command documentation
    # and consider dropping the static edges.
    # workflow.add_edge("select_tools", END)
    workflow.add_edge("select_tools", "research_agent")
    workflow.add_edge("research_agent", "chart_generator")
    workflow.add_edge("chart_generator", "chart_summarizer")
    workflow.add_edge("chart_summarizer", END)

    workflow.add_edge("canned_end", END)

    compiled_graph = workflow.compile()

    return compiled_graph

## Register the agent and create a run

In [None]:
class TruAgent:
    """Wrapper exposing the compiled agent graph through a single entry point."""

    def __init__(self):
        # Built once up front so the graph can be inspected or displayed before
        # any query runs.
        self.graph = build_graph()

    @instrument(
        span_type=SpanAttributes.SpanType.RECORD_ROOT,
        attributes={
            SpanAttributes.RECORD_ROOT.INPUT: "query",
            SpanAttributes.RECORD_ROOT.OUTPUT: "return",
        },
    )
    def invoke_agent_graph(self, query: str) -> str:
        """Run ``query`` through a freshly built graph and return the final text.

        Args:
            query: The user question used as the initial message.

        Returns:
            The content of the last message seen in the streamed events, an
            empty string when no message was produced, or a fixed apology
            string when the run raised an exception.
        """
        try:
            # rebuild the graph for each query
            self.graph = build_graph()
            # Initialize state with proper message format
            state = {"messages": [HumanMessage(content=query)]}

            # Stream events with recursion limit
            events = self.graph.stream(
                state,
                {"recursion_limit": 15},
            )

            # Track all messages through the conversation
            all_messages = []
            for event in events:
                # Get the payload of the first entry in the event
                _, payload = next(iter(event.items()))
                if not payload:  # Skip empty payloads
                    continue

                messages = payload.get("messages")
                if not messages:
                    continue
                all_messages.extend(messages)

            # Return the last message's content if available
            return (
                all_messages[-1].content
                if all_messages and hasattr(all_messages[-1], "content")
                else ""
            )
        except Exception as exc:
            # Catch Exception rather than everything, so KeyboardInterrupt and
            # SystemExit still propagate, and report the cause instead of
            # hiding it. The best-effort fallback answer is unchanged.
            print(f"invoke_agent_graph failed: {exc!r}", flush=True)
            return "I ran into an issue, and cannot answer your question."


# Shared agent instance; constructing it builds the graph once.
tru_agent = TruAgent()

In [None]:
# Register the agent with TruLens as version "web search", using the Snowflake
# connector created above and invoke_agent_graph as the entry point.
tru_agent_app = TruApp(
    tru_agent,
    app_name=APP_NAME,
    app_version="web search",
    connector=trulens_sf_connector,
    main_method=tru_agent.invoke_agent_graph,
)

In [None]:
# Local timestamp captured just before the run is created (not used further
# in this cell).
st_1 = datetime.datetime.fromtimestamp(time.time()).strftime(
    "%Y-%m-%d %H:%M:%S"
)

# Run definition: the input comes from a DataFrame whose "query" column is
# mapped to RECORD_ROOT.INPUT.
run_config = RunConfig(
    run_name="Multi-agent demo run",
    description="this is a run with access to web search and charting capabilities",
    dataset_name="Research test dataset",
    source_type="DATAFRAME",
    label="langgraph demo",
    dataset_spec={
        "RECORD_ROOT.INPUT": "query",
    },
)

# Attach the run to the registered app.
run: Run = tru_agent_app.add_run(run_config)

## Display the agent's graph

In [None]:
# Render the compiled graph as a Mermaid PNG and show it inline.
display(Image(tru_agent.graph.get_graph().draw_mermaid_png()))

## Start the run

This runs the agent in batch using the queries in the `input_df`.

In [None]:
import pandas as pd

# Questions to run through the agent in this batch.
user_queries = [
    "In 2023, how did the fed funds rate fluctuate? What were the key drivers? Create a line chart that illustrates this data, including a caption with the key drivers.",
    "What is the total holding value reported by BlackRock Fund Advisors in their SEC filings during 2023?? Create a line chart that illustrates this data",
]

# The column name "query" matches the dataset_spec mapping in the run config.
user_queries_df = pd.DataFrame(user_queries, columns=["query"])

run.start(input_df=user_queries_df)

## Compute metrics

In [None]:
import time

# Poll every 3 seconds until the run is no longer reported as in progress.
# NOTE(review): there is no timeout, so this loops for as long as the status
# stays "INVOCATION_IN_PROGRESS".
while run.get_status() == "INVOCATION_IN_PROGRESS":
    time.sleep(3)

# Request the three evaluation metrics for the finished run.
run.compute_metrics(["groundedness", "context_relevance", "answer_relevance"])

Web search alone is not as precise as it could be if the agent had access to private minutes data. Let's supplement web search with a document search.

## Add Cortex Search to the agent

In [None]:
from snowflake.core import Root
from snowflake.snowpark import Session


class CortexSearchArgs(BaseModel):
    # Argument schema for CortexSearchTool: a single free-text search query.
    query: str


# --- Define a new Cortex Search Tool to perform document search via Cortex ---
class CortexSearchTool(StructuredTool):
    # Tool metadata and the argument schema it advertises.
    name: str = "CortexSearch"
    description: str = "Searches documents using Cortex Search via Snowflake."
    args_schema: type[BaseModel] = CortexSearchArgs
    # Snowpark session used to reach the search service.
    session: Session

    def run(self, query: str) -> str:
        """
        Query the FOMC_SEARCH_SERVICE Cortex Search service and return its
        response serialized as JSON.

        Args:
            query (str): The search query string.

        Returns:
            str: JSON text of the search response; the request asks for the
            "chunk" column and at most 10 results.
        """
        # Walk database -> schema -> service one step at a time.
        database = Root(self.session).databases["CORTEX_SEARCH_TUTORIAL_DB"]
        service = database.schemas["PUBLIC"].cortex_search_services[
            "FOMC_SEARCH_SERVICE"
        ]
        response = service.search(query=query, columns=["chunk"], limit=10)
        return response.to_json()


def build_graph_with_search():
    """
    Build and compile the multi-agent graph with web search and Cortex Search.

    The graph has five nodes: `select_tools` (semantic tool retrieval),
    `research_agent` (ReAct agent bound to the selected tools),
    `chart_generator` (runs plotting code through a Python REPL tool),
    `chart_summarizer` (summarises the chart and reflects on the summary) and
    `canned_end` (safe exit when the remaining step budget is nearly used up).

    Returns:
        The compiled graph produced by `StateGraph(ToolState).compile()`.
    """

    def canned_end_node(state: ToolState) -> Command[Literal["__end__"]]:
        """Safe-exit node: log and route straight to END."""
        print("starting CANNED END", flush=True)

        return Command(goto=END)

    def make_system_prompt(suffix: str) -> str:
        """Return the shared collaboration system prompt followed by `suffix`."""
        return (
            "You are a helpful AI assistant, collaborating with other assistants."
            " Use the provided tools to progress towards answering the question."
            " If you are unable to fully answer, that's OK, another assistant with different tools "
            " will help where you left off. Execute what you can to make progress."
            " If you or any of the other assistants have the final answer or deliverable,"
            " prefix your response with FINAL ANSWER so the team knows to stop."
            f"\n{suffix}"
        )

    search = SerpAPIWrapper()

    search_tool = Tool(
        name="web_search",
        description="Search the web for current information, such as weather or news",
        func=search.run,
    )

    # # Create document search tool using Cortex Search (uses your Snowflake session)
    cortex_search_tool = CortexSearchTool(session=snowpark_session_trulens)

    # wrap so sync-compatible
    wrapped_cortex_search_tool = Tool(
        name=cortex_search_tool.name,
        description=cortex_search_tool.description,
        func=cortex_search_tool.run,
        return_direct=False,  # set to True only if you want the agent to stop after using it
    )
    # The tool registry now includes both the web and document search tools.
    tool_registry = {
        str(uuid.uuid4()): search_tool,
        str(uuid.uuid4()): wrapped_cortex_search_tool,
    }

    # Index tool descriptions in a vector store for semantic tool retrieval
    tool_documents = [
        Document(
            # Embed each tool's OWN name; embedding `search_tool.name` for every
            # document would give all tools identical similarity scores.
            page_content=tool.name,
            id=tool_id,  # must match the registry key
            metadata={
                "tool_name": tool.name,
                "tool_description": tool.description,
            },
        )
        for tool_id, tool in tool_registry.items()
    ]
    vector_store = InMemoryVectorStore(embedding=OpenAIEmbeddings())
    vector_store.add_documents(tool_documents)

    @instrument(
        span_type="SELECT_TOOLS",
        attributes=lambda ret, exc, *args, **kw: {
            # ---- state as JSON-text (OTLP needs a scalar) -----------------
            f"{BASE_SCOPE}.select_tools_input_state": json.dumps(  # ← turns dict → str
                {
                    **{k: v for k, v in args[0].items() if k != "messages"},
                    "messages": [
                        {"type": m.__class__.__name__, "content": m.content}
                        if hasattr(m, "content")  # BaseMessage subclasses
                        else m  # already JSON-friendly
                        for m in args[0].get("messages", [])
                    ],
                }
            ),
            # ---- selected tool IDs as a simple comma-separated string -----
            f"{BASE_SCOPE}.selected_tool_ids": ", ".join(
                ret.get("selected_tools", [])
            )
            if isinstance(ret, dict)
            else "",
        },
    )
    def select_tools(
        state: ToolState,
    ) -> Command[Literal["research_agent", END]]:
        """
        Choose tools for the latest message by embedding similarity.

        Routes to `research_agent` with the matching registry ids, or to END
        with an apology message when no tool scores at least MIN_SIMILARITY.
        """
        messages = state["messages"]
        last = messages[-1]
        query = last["content"] if isinstance(last, dict) else last.content
        print("selecting tools based on", query)

        # 1. pull top-k with their scores
        results: list[tuple[Document, float]] = (
            vector_store.similarity_search_with_score(
                query,
                k=5,  # look at top-5 candidates
            )
        )
        print("tool search results", results)

        # 2. filter by minimum cosine-similarity
        MIN_SIMILARITY = 0.7
        filtered = [doc for doc, score in results if score >= MIN_SIMILARITY]

        # 3a. no sufficiently similar tool → end
        if not filtered:
            print("no tool selected")
            msg = HumanMessage(
                content="Sorry, I don’t have a tool that’s relevant enough to answer that.",
                name="assistant",
            )
            return Command(
                update={"messages": messages + [msg]},
                goto=END,
            )

        # 3b. otherwise select those tools and move on
        selected_ids = [doc.id for doc in filtered]
        print("tools selected", selected_ids)
        return Command(
            update={
                "selected_tools": selected_ids,
            },
            goto="research_agent",
        )

    # Warning: This executes code locally, which can be unsafe when not sandboxed

    repl = PythonREPL()

    llm = ChatOpenAI(model="gpt-4o")

    @tool
    @instrument(
        span_type="PYTHON_REPL_TOOL",
        attributes={
            f"{BASE_SCOPE}.python_tool_input_code": "code",
        },
    )
    def python_repl_tool(code: str):
        """
        Run arbitrary Python, grab the CURRENT matplotlib figure (if any),
        save it to ./langgraph_saved_images_snowflaketools/v1/chart_<uuid>.png,
        and return a first-line `CHART_PATH=…`.
        """
        import os
        import uuid

        import matplotlib

        matplotlib.use("Agg")
        import matplotlib.pyplot as plt

        # 1) Run the user’s code
        repl.run(code)

        # 2) Check for a figure
        fig = plt.gcf()
        if fig.axes:
            target_dir = "./langgraph_saved_images_snowflaketools/v1"
            os.makedirs(target_dir, exist_ok=True)
            path = os.path.join(target_dir, f"chart_{uuid.uuid4().hex}.png")
            print(path, flush=True)
            fig.savefig(path, format="png")
            plt.close(fig)
        else:
            path = "NONE"

        # 3) Return only the CHART_PATH line
        return f"CHART_PATH={path}\n"

    def get_next_node(last_message: BaseMessage, goto: str):
        """Return END when the message contains FINAL ANSWER, else `goto`."""
        if "FINAL ANSWER" in last_message.content:
            # Any agent decided the work is done
            return END
        return goto

    @instrument(
        span_type="RESEARCH_NODE",
        attributes=lambda ret, exception, *args, **kwargs: {
            f"{BASE_SCOPE}.research_node_input": args[0]["messages"][
                -1
            ].content,
            f"{BASE_SCOPE}.research_node_response": ret.update["messages"][
                -1
            ].content
            if hasattr(ret, "update")
            else json.dumps(ret, indent=4, sort_keys=True),
            f"{BASE_SCOPE}.tool_messages": [
                dumps(message)
                for message in ret.update["messages"]
                if isinstance(message, ToolMessage)
            ]
            if hasattr(ret, "update")
            else "No tool call",
        },
    )
    @instrument(
        span_type=SpanAttributes.SpanType.RETRIEVAL,
        attributes=lambda ret, exception, *args, **kwargs: {
            SpanAttributes.RETRIEVAL.QUERY_TEXT: args[0]["messages"][
                -1
            ].content,
            SpanAttributes.RETRIEVAL.RETRIEVED_CONTEXTS: [
                ret.update["messages"][-1].content
            ]
            if hasattr(ret, "update")
            else [json.dumps(ret, indent=4, sort_keys=True)],
        },
    )
    def research_agent_node(
        state: ToolState,
    ) -> Command[Literal["chart_generator"]]:
        """
        Run a ReAct agent restricted to the tools chosen by `select_tools`.

        The agent's messages are copied with `tool_calls` cleared, the last one
        is re-tagged as coming from `research_agent`, and control moves to
        `chart_generator` unless that last message contains FINAL ANSWER.
        """
        sys.__stdout__.write("🔍 [research_agent_node] start\n")
        sys.__stdout__.flush()

        # 1) bind & invoke as before
        selected_tools = [tool_registry[tid] for tid in state["selected_tools"]]
        bound_llm = llm.bind_tools(selected_tools)
        bound_agent = create_react_agent(
            bound_llm,
            tools=selected_tools,
            prompt=make_system_prompt("You can only do research…"),
        )

        sys.__stdout__.write("  ⏳ invoking bound_agent.invoke()\n")
        sys.__stdout__.flush()
        result = bound_agent.invoke(state)

        # 2) debug‐dump
        sys.__stdout__.write("  📬 raw research messages:\n")
        for m in result["messages"]:
            sys.__stdout__.write(
                f"    [{m.__class__.__name__}] {getattr(m, 'content', m)!r}\n"
            )
        sys.__stdout__.flush()

        # 3) strip out any tool_calls on those messages
        clean_messages = []
        for msg in result["messages"]:
            # deep‐copy so we don't mutate the original if you care
            m = deepcopy(msg)
            if hasattr(m, "tool_calls"):
                # either empty the list or delete the attr altogether
                m.tool_calls = []
            clean_messages.append(m)

        # 4) routing
        last = clean_messages[-1]
        goto = get_next_node(last, "chart_generator")
        sys.__stdout__.write(f"  ➡ next goto = {goto}\n\n")
        sys.__stdout__.flush()

        # 5) tag the final message as coming from your research agent
        clean_messages[-1] = HumanMessage(
            content=last.content, name="research_agent"
        )

        return Command(
            update={
                "messages": clean_messages,
                "chart_path": state.get("chart_path", ""),
            },
            goto=goto,
        )

    # Chart generator agent and node
    # NOTE: THIS PERFORMS ARBITRARY CODE EXECUTION, WHICH CAN BE UNSAFE WHEN NOT SANDBOXED
    # 1) Define the chart‐agent: it only returns JSON with a "code" field
    chart_agent = create_react_agent(
        llm,
        [python_repl_tool],
        prompt=make_system_prompt(
            """You can only generate charts by returning a single JSON object, for example:
        {
        "code": "<your python plotting code here>"
        }
        —where <your python plotting code> uses matplotlib to create exactly one figure.
        Do NOT include any prose or tool‐call wrappers."""
        ),
    )

    def extract_chart_path(text: str) -> str | None:
        """
        Scan every line of tool stdout for 'CHART_PATH=' and return
        whatever follows, trimmed.  Returns None if no such line exists.
        """
        for line in text.splitlines():
            if "CHART_PATH=" in line:
                # split on the first '=', strip whitespace
                return line.split("CHART_PATH=", 1)[1].strip()
        return None

    @instrument(
        span_type="CHART_GENERATOR_NODE",
        attributes=lambda ret, exception, *args, **kwargs: {
            f"{BASE_SCOPE}.chart_node_input": args[0]["messages"][-1].content,
            f"{BASE_SCOPE}.chart_node_response": (
                ret.update["messages"][-1].content
                if ret and hasattr(ret, "update") and ret.update
                else "No update response"
            ),
        },
    )
    def chart_node(state: ToolState) -> Command[Literal["chart_summarizer"]]:
        """
        Invoke the chart agent once and record the chart path it produced.

        Routes to `canned_end` when `remaining_steps` is 2 or fewer, back to
        `research_agent` when no new ToolMessage carries a CHART_PATH line, and
        otherwise to `chart_summarizer` with `chart_path` stored in state.
        """
        print(state["remaining_steps"])
        if state["remaining_steps"] <= 2:
            print("Bailing out", flush=True)
            return Command(
                update={"messages": state["messages"]},
                goto="canned_end",
            )
        # 0) If a path is already in state, skip
        # extract the current human query
        current_query = state["messages"][-1].content

        # if we already generated a chart for _this_ query, skip
        if state.get("last_query") == current_query and state.get("chart_path"):
            print(
                f"⚡️ skipping chart_node, existing path = {state['chart_path']}",
                flush=True,
            )
            return Command(
                update={"messages": state["messages"]}, goto="chart_summarizer"
            )

        # it's a new query (or first run) → clear any old chart_path and remember this query
        state.pop("chart_path", None)
        state["last_query"] = current_query

        # 1) Remember how many messages we had
        len_before = len(state["messages"])

        # 2) Run the agent exactly once
        agent_out = chart_agent.invoke(state)

        print(agent_out, flush=True)
        all_msgs = agent_out["messages"]

        # 3) Look at only the brand-new messages for our chart tool output
        new_segment = all_msgs[len_before:]
        tool_msgs = [
            m
            for m in new_segment
            if isinstance(m, ToolMessage) and "CHART_PATH=" in m.content
        ]

        if not tool_msgs:
            # If none found, trigger your retry logic
            print(
                "⚠️ chart_node: no CHART_PATH in new messages, retrying",
                flush=True,
            )
            print(state["remaining_steps"])
            if state["remaining_steps"] <= 2:
                print("Bailing out", flush=True)
                return Command(
                    update={"messages": state["messages"]},
                    goto="canned_end",
                )
            return Command(
                update={"messages": state["messages"]},
                goto="research_agent",
            )

        # 4) Parse the last one in case there are multiples
        tool_msg = tool_msgs[-1]
        tool_stdout = tool_msg.content
        print(f"chart_node 🖨 tool_stdout:\n{tool_stdout}", flush=True)

        chart_path = extract_chart_path(tool_stdout)
        print(f"chart_node 📂 parsed chart_path = {chart_path!r}", flush=True)
        # 5) Build your new messages list: include only that new ToolMessage
        new_msgs = state["messages"][:] + [tool_msg]

        # 6) Success! stash path into state and append the CHART_PATH marker
        new_msgs.append(
            HumanMessage(
                content=f"CHART_PATH={chart_path}", name="chart_generator"
            )
        )
        return Command(
            update={"messages": new_msgs, "chart_path": chart_path},
            goto="chart_summarizer",
        )

    reflection_prompt_template = PromptTemplate(
        input_variables=["user_query", "chart_summary"],
        template="""\
        You are an AI assistant tasked with reflecting on the quality of a chart summary. The user has asked the following question:
        "{user_query}"

        You are given the following chart summary:
        "{chart_summary}"

        Your task is to evaluate how well the chart summary answers the user's question. Consider the following:
        - Does the summary capture the **key insights** and trends from the chart, even if in a more general form?
        - Does it provide **adequate context** to address the user's query, even if it's not exhaustive?
        - If the summary provides some context but could benefit from more details, consider it sufficient for now unless significant details are missing.

        If the summary **generally** addresses the question, respond with 'Task complete'. If the summary **lacks significant** details or clarity, then respond with specific details on how the answer should be improved and what information is needed. Avoid being overly critical unless the summary completely misses key elements necessary to answer the query.

        Please provide your answer in a **concise and encouraging** manner.
        """,
    )

    # Create the chain using the prompt template and the LLM (ChatOpenAI)
    reflection_chain = reflection_prompt_template | llm

    @instrument(
        span_type="CHART_SUMMARY_REFLECTION",
        attributes=lambda ret, exception, *args, **kwargs: {
            f"{BASE_SCOPE}.chart_summary_reflection_input_user_query": args[0],
            f"{BASE_SCOPE}.chart_summary_reflection_input_chart_summary": args[
                1
            ],
            f"{BASE_SCOPE}.chart_summary_reflection_response": ret,
        },
    )
    def perform_reflection(user_query: str, chart_summary: str) -> str:
        """
        This function uses an LLM to reflect on the quality of a chart summary
        and determine if the task is complete or requires further refinement.
        """
        print("doing reflection...")
        # Call the chain with the user query and chart summary
        reflection_result = reflection_chain.invoke({
            "user_query": user_query,
            "chart_summary": chart_summary,
        })
        print("reflection_result: ", reflection_result.content)
        return reflection_result.content

    @instrument(
        span_type="CHART_SUMMARY_NODE",
        attributes=lambda ret, exception, *args, **kwargs: {
            f"{BASE_SCOPE}.summary_node_input": args[0]["messages"][-1].content,
            f"{BASE_SCOPE}.summary_node_output": (
                ret.update["messages"][-1].content
                if hasattr(ret, "update")
                else "NO SUMMARY GENERATED"
            ),
        },
    )
    def chart_summary_node(state: ToolState) -> Command[Literal["__end__"]]:
        """
        Summarise the generated chart and reflect on whether it answers the query.

        Routes to `canned_end` when `remaining_steps` is 2 or fewer, back to
        `select_tools` when there is no chart path or the reflection does not
        contain "task complete", and to END otherwise.
        """
        print("▶️ entering chart_summary_node", flush=True)
        print(state["remaining_steps"])
        if state["remaining_steps"] <= 2:
            print("Bailing out", flush=True)
            return Command(
                update={"messages": state["messages"]},
                goto="canned_end",
            )

        # 1) find the chart_path in state
        chart_path = state.get("chart_path", "")
        print(f"  using state.chart_path = {chart_path!r}", flush=True)
        if not chart_path:
            return Command(
                update={
                    "messages": state["messages"]
                    + [
                        HumanMessage(
                            "No valid chart was generated. Please try again.",
                            name="chart_summarizer",
                        )
                    ]
                },
                goto="select_tools",
            )

        # 2) strip *everything* except human utterances
        human_history = [
            m for m in state["messages"] if isinstance(m, HumanMessage)
        ]

        # ensure our CHART_PATH marker is last
        if not human_history or not human_history[-1].content.startswith(
            "CHART_PATH="
        ):
            human_history.append(
                HumanMessage(
                    f"CHART_PATH={chart_path}", name="chart_summarizer"
                )
            )

        print(
            "  human_history:", [m.content for m in human_history], flush=True
        )

        # 3) build your ChatCompletion prompt
        system = SystemMessage(
            content=make_system_prompt(
                "You are an AI assistant whose *only* job is to summarise a chart image. "
                "Input is a message CHART_PATH=… pointing at a saved PNG. "
                "Output a concise (≤3 sentences) summary of the key trends."
            )
        )

        messages_for_llm = (
            [system]
            + human_history
            + [
                HumanMessage(
                    "Please summarise the above chart in ≤3 sentences."
                )
            ]
        )

        # 4) call the LLM directly—no tools, no React agent
        # Use .invoke() like every other model/chain call in this function;
        # calling the chat model object directly is the deprecated form.
        print("📝 calling ChatOpenAI directly for summary", flush=True)
        ai_msg: AIMessage = llm.invoke(messages_for_llm)
        summary = ai_msg.content
        print(f"📋 chart summary: {summary!r}", flush=True)

        # 5) reflect as before
        user_query = state["messages"][0].content
        print("🔍 reflecting on summary quality", flush=True)
        reflection = perform_reflection(user_query, summary)
        clean_ref = reflection.strip().lower()
        print(f"💡 reflection: {reflection!r}", flush=True)

        # 6) decide where to go
        if "task complete" in clean_ref:
            print("✅ done", flush=True)
            return Command(
                update={
                    "messages": state["messages"]
                    + [HumanMessage(summary, name="chart_summarizer")]
                },
                goto=END,
            )
        else:
            print("🔁 need to retry", flush=True)
            return Command(
                update={
                    "messages": state["messages"]
                    + [
                        HumanMessage(summary, name="chart_summarizer"),
                        HumanMessage(reflection, name="chart_reflection"),
                    ]
                },
                goto="select_tools",
            )

    workflow = StateGraph(ToolState)
    workflow.add_node("select_tools", select_tools)
    workflow.add_node("research_agent", research_agent_node)
    workflow.add_node("chart_generator", chart_node)
    workflow.add_node("chart_summarizer", chart_summary_node)
    workflow.add_node("canned_end", canned_end_node)

    # Update transitions: begin with tool selection then go to research agent.
    workflow.add_edge(START, "select_tools")

    # NOTE(review): the static edges below coexist with the Command(goto=...)
    # routing returned by the nodes above — confirm both are intended, since a
    # static edge may still fire when a node routes elsewhere.
    # workflow.add_edge("select_tools", END)
    workflow.add_edge("select_tools", "research_agent")
    workflow.add_edge("research_agent", "chart_generator")
    workflow.add_edge("chart_generator", "chart_summarizer")
    workflow.add_edge("chart_summarizer", END)

    workflow.add_edge("canned_end", END)

    compiled_graph = workflow.compile()

    return compiled_graph

In [None]:
class TruAgent:
    """Wrapper exposing the web + Cortex Search graph as a single instrumented call."""

    def __init__(self):
        # Use the graph defined for this section (web search + Cortex Search),
        # not the earlier web-only `build_graph()`.
        self.graph = build_graph_with_search()

    @instrument(
        span_type=SpanAttributes.SpanType.RECORD_ROOT,
        attributes={
            SpanAttributes.RECORD_ROOT.INPUT: "query",
            SpanAttributes.RECORD_ROOT.OUTPUT: "return",
        },
    )
    def invoke_agent_graph(self, query: str) -> str:
        """
        Run one query through a freshly built graph and return the final text.

        Args:
            query: The user question to seed the graph state with.

        Returns:
            The content of the last message streamed by the graph, an empty
            string when no message with content was produced, or a canned
            apology string when the graph raised an exception.
        """
        try:
            # rebuild the graph for each query
            self.graph = build_graph_with_search()
            # Initialize state with proper message format
            state = {"messages": [HumanMessage(content=query)]}

            # Stream events with recursion limit
            events = self.graph.stream(
                state,
                {"recursion_limit": 15},
            )

            # Track all messages through the conversation
            all_messages = []
            for event in events:
                # Get the payload from the event
                _, payload = next(iter(event.items()))
                if not payload:  # Skip empty payloads
                    continue

                messages = payload.get("messages")
                if not messages:
                    continue
                all_messages.extend(messages)

            # Return the last message's content if available
            return (
                all_messages[-1].content
                if all_messages and hasattr(all_messages[-1], "content")
                else ""
            )
        except Exception as exc:
            # Catch Exception (not a bare except) so KeyboardInterrupt and
            # SystemExit still propagate, and surface the failure for debugging
            # instead of hiding it behind the canned reply.
            print(f"invoke_agent_graph failed: {exc!r}", flush=True)
            return "I ran into an issue, and cannot answer your question."


# Instantiate the agent and register it as a TruLens app for this version.
tru_agent = TruAgent()

tru_agent_app = TruApp(
    tru_agent,
    main_method=tru_agent.invoke_agent_graph,
    connector=trulens_sf_connector,
    app_name=APP_NAME,
    app_version="doc and web search",
)

# Local wall-clock timestamp taken just before the run is configured.
st_1 = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")

# Describe the run: inputs come from a dataframe whose "query" column feeds
# RECORD_ROOT.INPUT.
run_config = RunConfig(
    label="langgraph demo",
    source_type="DATAFRAME",
    dataset_name="Research test dataset",
    dataset_spec={
        "RECORD_ROOT.INPUT": "query",
    },
    run_name="Multi-agent demo run - document and web search",
    description="this is a run with access to cortex search and web search",
)

run: Run = tru_agent_app.add_run(run_config)

In [None]:
# Draw the agent's current graph as a Mermaid PNG and show it inline.
display(Image(tru_agent.graph.get_graph().draw_mermaid_png()))

In [None]:
# Start the run, using user_queries_df as the input dataframe.
run.start(input_df=user_queries_df)

In [None]:
import time

# Poll every 3 seconds until the run stops reporting INVOCATION_IN_PROGRESS,
# then request the three evaluation metrics.
while True:
    if run.get_status() != "INVOCATION_IN_PROGRESS":
        break
    time.sleep(3)

run.compute_metrics(["groundedness", "context_relevance", "answer_relevance"])

### Use Cortex Agent to gain access to querying structured SEC data without complicating the graph

In [None]:
from pydantic import BaseModel
import requests
from snowflake.snowpark import Session


# Argument schema referenced by CortexAgentTool.args_schema: one free-text field.
class CortexAgentArgs(BaseModel):
    query: str  # the question handed to the tool


class CortexAgentTool(StructuredTool):
    # Tool that forwards a question to the Snowflake Cortex Agents REST endpoint,
    # which is configured below with a Cortex Analyst (text-to-SQL) tool, a
    # Cortex Search tool and a SQL execution tool.
    name: str = "CortexAgent"
    description: str = "answers questions using the federal reserve meeting minutes and structured data from the SEC"
    args_schema: type[BaseModel] = CortexAgentArgs
    session: Session

    # `run` must be a method of this class (as in CortexSearchTool): defined at
    # module level it would never be reached through `cortex_agent_tool.run`
    # and would also overwrite the notebook's global `run` object.
    def run(self, query: str, **kwargs) -> str:
        """
        Send `query` to the Cortex Agents `agent:run` endpoint and return its text.

        Args:
            query: The user question, sent as a single user message.
            **kwargs: Accepted and ignored.

        Returns:
            The text chunks found under `delta.content` joined with spaces;
            the whole response pretty-printed as JSON when no text is found;
            or a "Failed Cortex Agents API call" message for a non-200 status.

        Raises:
            RuntimeError: If the SNOWFLAKE_PAT environment variable is not set.
        """
        print("calling agent")
        payload = {
            "model": "claude-3-5-sonnet",
            "response_instruction": "You are a helpful AI assistant.",
            "experimental": {},
            "tools": [
                {
                    "tool_spec": {
                        "type": "cortex_analyst_text_to_sql",
                        "name": "SEC_ANALYST",
                    }
                },
                {"tool_spec": {"type": "cortex_search", "name": "FOMC_SEARCH"}},
                {
                    "tool_spec": {
                        "type": "sql_exec",
                        "name": "sql_execution_tool",
                    }
                },
            ],
            "tool_resources": {
                "SEC_ANALYST": {
                    "semantic_model_file": "@agents_db.notebooks.semantic_models/sec_filings.yaml"
                },
                "FOMC_SEARCH": {
                    "name": "CORTEX_SEARCH_TUTORIAL_DB.PUBLIC.FOMC_SEARCH_SERVICE"
                },
            },
            "tool_choice": {"type": "auto"},
            "messages": [
                {"role": "user", "content": [{"type": "text", "text": query}]}
            ],
        }

        # HTTPS, not HTTP: the request carries a bearer token.
        api_url = "https://SFDEVREL-SFDEVREL_ENTERPRISE.snowflakecomputing.com/api/v2/cortex/agent:run"
        pat = os.getenv("SNOWFLAKE_PAT")
        if not pat:
            raise RuntimeError("Environment variable SNOWFLAKE_PAT is not set")

        headers = {
            "Authorization": f"Bearer {pat}",
            "X-Snowflake-Authorization-Token-Type": "PROGRAMMATIC_ACCESS_TOKEN",
            "Content-Type": "application/json",
        }

        response = requests.post(api_url, json=payload, headers=headers)
        print("agent response", response)

        if response.status_code != 200:
            print(response.status_code)
            print(response.text)
            return f"Failed Cortex Agents API call: {response.status_code} - {response.text}"

        # Extract content from delta
        data = response.json()
        contents = data.get("delta", {}).get("content", [])
        result_parts = [
            chunk.get("text", "")
            for chunk in contents
            if chunk.get("type") == "text"
        ]
        result_text = " ".join(result_parts).strip()

        # Fallback if content is empty
        return result_text or json.dumps(data, indent=2)


def build_graph_with_agent():
    def canned_end_node(state: ToolState) -> Command[Literal["__end__"]]:
        """Safe-exit node: announce that it ran, then route straight to END.

        ``state`` is accepted to satisfy the node signature but is not read,
        and no state update is emitted.
        """
        print("starting CANNED END", flush=True)
        stop_command = Command(goto=END)
        return stop_command

    def make_system_prompt(suffix: str) -> str:
        """Return the shared collaboration preamble followed by ``suffix``.

        The preamble tells the model it is one of several assistants and asks
        it to prefix its reply with FINAL ANSWER once the task is done; the
        role-specific ``suffix`` is appended on its own line.
        """
        preamble_parts = (
            "You are a helpful AI assistant, collaborating with other assistants.",
            " Use the provided tools to progress towards answering the question.",
            " If you are unable to fully answer, that's OK, another assistant with different tools ",
            " will help where you left off. Execute what you can to make progress.",
            " If you or any of the other assistants have the final answer or deliverable,",
            " prefix your response with FINAL ANSWER so the team knows to stop.",
        )
        preamble = "".join(preamble_parts)
        return f"{preamble}\n{suffix}"

    # --- Tools -------------------------------------------------------------
    # Web search via SerpAPI, exposed to the agent as a plain LangChain Tool.
    search = SerpAPIWrapper()

    search_tool = Tool(
        name="web_search",
        description="Search the web for current information, such as weather or news",
        func=search.run,
    )

    # Instantiate CortexAgentTool
    cortex_agent_tool = CortexAgentTool(session=snowpark_session_trulens)

    # Wrap it in a Tool so it shares the same interface as `search_tool`.
    wrapped_cortex_agent_tool = Tool(
        name=cortex_agent_tool.name,
        description=cortex_agent_tool.description,
        func=cortex_agent_tool.run,
        return_direct=False,  # set to True only if you want the agent to stop after using it
    )

    # Registry keyed by a fresh random ID per build; the IDs are what
    # `select_tools` stores in state and `research_agent_node` looks up.
    tool_registry = {
        str(uuid.uuid4()): search_tool,
        str(uuid.uuid4()): wrapped_cortex_agent_tool,  # CortexAgentTool here
    }

    # One document per tool: the tool name is the embedded text, the registry
    # ID is the document ID, and name/description travel as metadata.
    tool_documents = [
        Document(
            page_content=tool.name,
            id=tool_id,
            metadata={
                "tool_name": tool.name,
                "tool_description": tool.description,
            },
        )
        for tool_id, tool in tool_registry.items()
    ]

    llm = ChatOpenAI(model="gpt-4o")

    # Build and populate the tool index exactly once. (It used to be created
    # and filled twice, which embedded every tool document a second time and
    # threw the first store away.)
    vector_store = InMemoryVectorStore(embedding=OpenAIEmbeddings())
    vector_store.add_documents(tool_documents)

    def _select_tools_span_attributes(ret, exc, *args, **kw):
        """Build the span attributes recorded for one ``select_tools`` call."""
        state = args[0]

        # ---- state as JSON-text (OTLP needs a scalar) ---------------------
        state_snapshot = {
            **{k: v for k, v in state.items() if k != "messages"},
            "messages": [
                {"type": m.__class__.__name__, "content": m.content}
                if hasattr(m, "content")  # BaseMessage subclasses
                else m  # already JSON-friendly
                for m in state.get("messages", [])
            ],
        }

        # ---- selected tool IDs as a simple comma-separated string ---------
        # `select_tools` returns a Command, so the IDs live in its `update`
        # payload. Checking only for a plain dict return (as before) meant
        # this attribute was always recorded as an empty string.
        update = ret if isinstance(ret, dict) else getattr(ret, "update", None)
        selected_ids = (
            update.get("selected_tools", []) if isinstance(update, dict) else []
        )

        return {
            f"{BASE_SCOPE}.select_tools_input_state": json.dumps(state_snapshot),
            f"{BASE_SCOPE}.selected_tool_ids": ", ".join(selected_ids),
        }

    @instrument(
        span_type="SELECT_TOOLS",
        attributes=_select_tools_span_attributes,
    )
    def select_tools(
        state: ToolState,
    ) -> Command[Literal["research_agent", END]]:
        """Pick the tools relevant to the latest message, or end the run.

        The content of the last message is used as a similarity query against
        the tool index. Tools scoring at least ``MIN_SIMILARITY`` are written
        to ``state["selected_tools"]`` (as registry IDs) and control moves to
        ``research_agent``. If none qualify, an apology message is appended
        and the graph is routed to END.
        """
        messages = state["messages"]
        last = messages[-1]
        # Messages may be plain dicts or message objects.
        query = last["content"] if isinstance(last, dict) else last.content
        print("selecting tools based on", query)

        # 1. pull top-k with their scores
        results: list[tuple[Document, float]] = (
            vector_store.similarity_search_with_score(
                query,
                k=5,  # look at top-5 candidates
            )
        )
        print("tool search results", results)

        # 2. filter by minimum cosine-similarity
        MIN_SIMILARITY = 0.7
        filtered = [doc for doc, score in results if score >= MIN_SIMILARITY]

        # 3a. no sufficiently similar tool → end
        if not filtered:
            print("no tool selected")
            msg = HumanMessage(
                content="Sorry, I don’t have a tool that’s relevant enough to answer that.",
                name="assistant",
            )
            return Command(
                update={"messages": messages + [msg]},
                goto=END,
            )

        # 3b. otherwise select those tools and move on
        selected_ids = [doc.id for doc in filtered]
        print("tools selected", selected_ids)
        return Command(
            update={
                "selected_tools": selected_ids,
            },
            goto="research_agent",
        )

    # Warning: This executes code locally, which can be unsafe when not sandboxed

    repl = PythonREPL()

    @tool
    @instrument(
        span_type="PYTHON_REPL_TOOL",
        attributes={
            f"{BASE_SCOPE}.python_tool_input_code": "code",
        },
    )
    def python_repl_tool(code: str):
        """
        Run arbitrary Python, grab the CURRENT matplotlib figure (if any),
        save it to ./langgraph_saved_images_snowflaketools/v3/chart_<uuid>.png,
        and return a first-line `CHART_PATH=…`.
        Anything the code printed (or the error it raised) follows on the
        remaining lines.
        """
        import matplotlib

        matplotlib.use("Agg")  # headless safety
        import os
        import uuid

        import matplotlib.pyplot as plt

        # ------------------ run user code & capture stdout ------------------
        # Keep what the REPL hands back so the caller can see output and
        # errors instead of only a bare CHART_PATH=NONE.
        repl_output = repl.run(code)

        chart_path = ""
        try:
            # -------------- locate a figure (if generated) ------------------
            # Check for open figures first: plt.gcf() would otherwise create
            # (and leave behind) an empty one when the code plotted nothing.
            if plt.get_fignums():
                fig = plt.gcf()
                if fig.axes:  # True if something was plotted
                    # ---------- save the figure we found --------------------
                    target_dir = "./langgraph_saved_images_snowflaketools/v3"
                    os.makedirs(target_dir, exist_ok=True)
                    chart_path = os.path.join(
                        target_dir, f"chart_{uuid.uuid4().hex}.png"
                    )
                    fig.savefig(chart_path, format="png")
        finally:
            # Close every figure so nothing left open by this run can be
            # picked up and re-saved as the "current" chart by a later call.
            plt.close("all")

        # ------------------ tool result (1st line = CHART_PATH) -------------
        return (
            f"CHART_PATH={chart_path if chart_path else 'NONE'}\n"
            f"{repl_output or ''}"
        )

    def get_next_node(last_message: BaseMessage, goto: str):
        """Route to END once FINAL ANSWER appears in the message, else to ``goto``."""
        # Any agent decided the work is done
        work_is_done = "FINAL ANSWER" in last_message.content
        return END if work_is_done else goto

    @instrument(
        span_type="RESEARCH_NODE",
        attributes=lambda ret, exception, *args, **kwargs: {
            f"{BASE_SCOPE}.research_node_input": args[0]["messages"][
                -1
            ].content,
            f"{BASE_SCOPE}.research_node_response": ret.update["messages"][
                -1
            ].content
            if hasattr(ret, "update")
            else json.dumps(ret, indent=4, sort_keys=True),
            f"{BASE_SCOPE}.tool_messages": [
                dumps(message)
                for message in ret.update["messages"]
                if isinstance(message, ToolMessage)
            ]
            if hasattr(ret, "update")
            else "No tool call",
        },
    )
    @instrument(
        span_type=SpanAttributes.SpanType.RETRIEVAL,
        attributes=lambda ret, exception, *args, **kwargs: {
            SpanAttributes.RETRIEVAL.QUERY_TEXT: args[0]["messages"][
                -1
            ].content,
            SpanAttributes.RETRIEVAL.RETRIEVED_CONTEXTS: [
                ret.update["messages"][-1].content
            ]
            if hasattr(ret, "update")
            else [json.dumps(ret, indent=4, sort_keys=True)],
        },
    )
    def research_agent_node(
        state: ToolState,
    ) -> Command[Literal["chart_generator", END]]:
        """
        Always binds the selected tools and invokes the bound agent.
        Stops on FINAL ANSWER or moves to chart_generator.

        Reads ``state["selected_tools"]`` (registry IDs written by
        ``select_tools``) and returns the agent's full message list as the
        ``messages`` update, with the last message re-tagged as coming from
        ``research_agent``.

        The return annotation lists END as well as ``chart_generator``
        because ``get_next_node`` routes to END on FINAL ANSWER; LangGraph
        reads this annotation to learn where the node may go.
        """
        # grab (non-empty) list of selected tool IDs
        selected_ids = state["selected_tools"]

        # bind only those tools
        selected_tools = [tool_registry[tid] for tid in selected_ids]
        bound_llm = llm.bind_tools(selected_tools)
        bound_agent = create_react_agent(
            bound_llm,
            tools=selected_tools,  # already bound
            prompt=make_system_prompt(
                "You can only do research. You are working with both a chart generator and a chart summarizer colleagues."
            ),
        )

        # run it
        result = bound_agent.invoke(state)

        # decide if we’re done
        last = result["messages"][-1]
        goto = get_next_node(last, "chart_generator")

        # tag the origin of the final message
        result["messages"][-1] = HumanMessage(
            content=last.content,
            name="research_agent",
        )

        return Command(
            update={"messages": result["messages"]},
            goto=goto,
        )

    # Chart generator agent and node
    # NOTE: THIS PERFORMS ARBITRARY CODE EXECUTION, WHICH CAN BE UNSAFE WHEN NOT SANDBOXED
    # 1) Define the chart‐agent: a ReAct agent over `llm` whose only tool is
    #    `python_repl_tool`. Its system prompt is the shared collaboration
    #    preamble from `make_system_prompt` plus the chart-specific
    #    instructions below.
    # NOTE(review): the prompt asks for a bare JSON object with a "code" field
    # and no tool-call wrappers, while `chart_node` only treats a run as
    # successful when a ToolMessage containing "CHART_PATH=" was produced
    # (i.e. the tool was actually called) — confirm the wording is intended.
    chart_agent = create_react_agent(
        llm,
        [python_repl_tool],
        prompt=make_system_prompt(
            """You can only generate charts by returning a single JSON object, for example:
        {
        "code": "<your python plotting code here>"
        }
        —where <your python plotting code> uses matplotlib to create exactly one figure.
        Do NOT include any prose or tool‐call wrappers."""
        ),
    )

    def extract_chart_path(text: str) -> str | None:
        """
        Scan every line of tool stdout for 'CHART_PATH=' and return
        whatever follows, trimmed.  Returns None if no such line exists.
        """
        for line in text.splitlines():
            if "CHART_PATH=" in line:
                # split on the first '=', strip whitespace
                return line.split("CHART_PATH=", 1)[1].strip()
        return None

    @instrument(
        span_type="CHART_GENERATOR_NODE",
        attributes=lambda ret, exception, *args, **kwargs: {
            f"{BASE_SCOPE}.chart_node_input": args[0]["messages"][-1].content,
            f"{BASE_SCOPE}.chart_node_response": (
                ret.update["messages"][-1].content
                if ret and hasattr(ret, "update") and ret.update
                else "No update response"
            ),
        },
    )
    def chart_node(
        state: ToolState,
    ) -> Command[Literal["chart_summarizer", "research_agent", "canned_end"]]:
        """Run the chart agent once and route on whether a chart file was saved.

        Routing:
          * ``canned_end``       - ``state["remaining_steps"]`` is 2 or less.
          * ``chart_summarizer`` - a chart path was produced (stored in
            ``state["chart_path"]``), or one already exists for the same
            last message.
          * ``research_agent``   - the agent produced no usable chart path
            (no ``CHART_PATH=`` tool output, or the tool reported ``NONE``).

        The return annotation lists every destination so LangGraph knows
        where this node may route.
        """
        print(state["remaining_steps"])
        if state["remaining_steps"] <= 2:
            print("Bailing out", flush=True)
            return Command(
                update={"messages": state["messages"]},
                goto="canned_end",
            )

        # 0) If a path is already in state, skip
        # the "query" here is the content of the most recent message
        current_query = state["messages"][-1].content

        # if we already generated a chart for _this_ query, skip
        if state.get("last_query") == current_query and state.get("chart_path"):
            print(
                f"⚡️ skipping chart_node, existing path = {state['chart_path']}",
                flush=True,
            )
            return Command(
                update={"messages": state["messages"]}, goto="chart_summarizer"
            )

        # it's a new query (or first run) → clear any old chart_path and remember this query
        # NOTE(review): these in-place edits only shape the input handed to
        # `chart_agent` below; `last_query` is never part of a returned
        # update, so it is presumably not visible on later visits — confirm
        # against the ToolState schema.
        state.pop("chart_path", None)
        state["last_query"] = current_query

        # 1) Remember how many messages we had
        len_before = len(state["messages"])

        # 2) Run the agent exactly once
        agent_out = chart_agent.invoke(state)

        print(agent_out, flush=True)
        all_msgs = agent_out["messages"]

        # 3) Look at only the brand-new messages for our chart tool output
        new_segment = all_msgs[len_before:]
        tool_msgs = [
            m
            for m in new_segment
            if isinstance(m, ToolMessage) and "CHART_PATH=" in m.content
        ]

        if not tool_msgs:
            # The tool never reported a path: hand back to the researcher.
            # (remaining_steps was already checked on entry and cannot have
            # changed since, so no second bail-out check is needed here.)
            print(
                "⚠️ chart_node: no CHART_PATH in new messages, retrying",
                flush=True,
            )
            print(state["remaining_steps"])
            return Command(
                update={"messages": state["messages"]},
                goto="research_agent",
            )

        # 4) Parse the last one in case there are multiples
        tool_msg = tool_msgs[-1]
        tool_stdout = tool_msg.content
        print(f"chart_node 🖨 tool_stdout:\n{tool_stdout}", flush=True)

        chart_path = extract_chart_path(tool_stdout)
        print(f"chart_node 📂 parsed chart_path = {chart_path!r}", flush=True)

        # The tool reports CHART_PATH=NONE when the code drew nothing. That
        # sentinel (or an empty path) is not a chart, so it must not be
        # stored as a truthy chart_path and passed on to the summarizer.
        if not chart_path or chart_path == "NONE":
            print(
                "⚠️ chart_node: tool ran but produced no chart, retrying",
                flush=True,
            )
            return Command(
                update={"messages": state["messages"]},
                goto="research_agent",
            )

        # 5) Build your new messages list: include only that new ToolMessage
        new_msgs = state["messages"][:] + [tool_msg]

        # 6) Success! stash path into state and append the CHART_PATH marker
        new_msgs.append(
            HumanMessage(
                content=f"CHART_PATH={chart_path}", name="chart_generator"
            )
        )
        return Command(
            update={"messages": new_msgs, "chart_path": chart_path},
            goto="chart_summarizer",
        )

    # Prompt for the reflection step. It is filled with the user's question
    # (`user_query`) and the generated `chart_summary`, and asks the model to
    # answer 'Task complete' when the summary is good enough or to describe
    # what is missing otherwise. `chart_summary_node` looks for the phrase
    # "task complete" (case-insensitively) in the reply to decide whether to
    # finish or loop back.
    reflection_prompt_template = PromptTemplate(
        input_variables=["user_query", "chart_summary"],
        template="""\
        You are an AI assistant tasked with reflecting on the quality of a chart summary. The user has asked the following question:
        "{user_query}"

        You are given the following chart summary:
        "{chart_summary}"

        Your task is to evaluate how well the chart summary answers the user's question. Consider the following:
        - Does the summary capture the **key insights** and trends from the chart, even if in a more general form?
        - Does it provide **adequate context** to address the user's query, even if it's not exhaustive?
        - If the summary provides some context but could benefit from more details, consider it sufficient for now unless significant details are missing.

        If the summary **generally** addresses the question, respond with 'Task complete'. If the summary **lacks significant** details or clarity, then respond with specific details on how the answer should be improved and what information is needed. Avoid being overly critical unless the summary completely misses key elements necessary to answer the query.

        Please provide your answer in a **concise and encouraging** manner.
        """,
    )

    # Create the chain using the prompt template and the LLM (ChatOpenAI):
    # invoking it formats the template and sends the result to `llm`.
    reflection_chain = reflection_prompt_template | llm

    @instrument(
        span_type="CHART_SUMMARY_REFLECTION",
        attributes=lambda ret, exception, *args, **kwargs: {
            f"{BASE_SCOPE}.chart_summary_reflection_input_user_query": args[0],
            f"{BASE_SCOPE}.chart_summary_reflection_input_chart_summary": args[1],
            f"{BASE_SCOPE}.chart_summary_reflection_response": ret,
        },
    )
    def perform_reflection(user_query: str, chart_summary: str) -> str:
        """
        Ask the reflection chain to judge ``chart_summary`` against
        ``user_query`` and return the text of its reply.
        """
        print("doing reflection...")
        # Fill the two template variables and run prompt -> LLM in one call.
        chain_inputs = {
            "user_query": user_query,
            "chart_summary": chart_summary,
        }
        reply = reflection_chain.invoke(chain_inputs)
        print("reflection_result: ", reply.content)
        return reply.content

    @instrument(
        span_type="CHART_SUMMARY_NODE",
        attributes=lambda ret, exception, *args, **kwargs: {
            f"{BASE_SCOPE}.summary_node_input": args[0]["messages"][-1].content,
            f"{BASE_SCOPE}.summary_node_output": (
                ret.update["messages"][-1].content
                if hasattr(ret, "update")
                else "NO SUMMARY GENERATED"
            ),
        },
    )
    def chart_summary_node(
        state: ToolState,
    ) -> Command[Literal["select_tools", "canned_end", "__end__"]]:
        """Summarise the generated chart, then reflect on the summary.

        Routing:
          * ``canned_end``   - ``state["remaining_steps"]`` is 2 or less.
          * ``select_tools`` - no usable ``state["chart_path"]``, or the
            reflection step did not answer "task complete" (the summary and
            the reflection text are appended as messages for the next pass).
          * END              - the reflection step accepted the summary.

        The model is given the conversation's human messages plus a
        ``CHART_PATH=…`` marker; no image content is attached here.
        """
        print("▶️ entering chart_summary_node", flush=True)
        print(state["remaining_steps"])
        if state["remaining_steps"] <= 2:
            print("Bailing out", flush=True)
            return Command(
                update={"messages": state["messages"]},
                goto="canned_end",
            )

        # 1) find the chart_path in state
        chart_path = state.get("chart_path", "")
        print(f"  using state.chart_path = {chart_path!r}", flush=True)
        # "NONE" is the sentinel the chart tool emits when nothing was drawn;
        # treat it like a missing path rather than summarising a chart that
        # does not exist.
        if not chart_path or chart_path == "NONE":
            return Command(
                update={
                    "messages": state["messages"]
                    + [
                        HumanMessage(
                            "No valid chart was generated. Please try again.",
                            name="chart_summarizer",
                        )
                    ]
                },
                goto="select_tools",
            )

        # 2) strip *everything* except human utterances
        human_history = [
            m for m in state["messages"] if isinstance(m, HumanMessage)
        ]

        # ensure our CHART_PATH marker is last
        if not human_history or not human_history[-1].content.startswith(
            "CHART_PATH="
        ):
            human_history.append(
                HumanMessage(
                    f"CHART_PATH={chart_path}", name="chart_summarizer"
                )
            )

        print(
            "  human_history:", [m.content for m in human_history], flush=True
        )

        # 3) build your ChatCompletion prompt
        system = SystemMessage(
            content=make_system_prompt(
                "You are an AI assistant whose *only* job is to summarise a chart image. "
                "Input is a message CHART_PATH=… pointing at a saved PNG. "
                "Output a concise (≤3 sentences) summary of the key trends."
            )
        )

        messages_for_llm = (
            [system]
            + human_history
            + [
                HumanMessage(
                    "Please summarise the above chart in ≤3 sentences."
                )
            ]
        )

        # 4) call the LLM directly—no tools, no React agent
        print("📝 calling ChatOpenAI directly for summary", flush=True)
        # Use .invoke(): calling a chat model object directly is deprecated
        # in LangChain.
        ai_msg: AIMessage = llm.invoke(messages_for_llm)
        summary = ai_msg.content
        print(f"📋 chart summary: {summary!r}", flush=True)

        # 5) reflect as before (the first message is taken as the user's question)
        user_query = state["messages"][0].content
        print("🔍 reflecting on summary quality", flush=True)
        reflection = perform_reflection(user_query, summary)
        clean_ref = reflection.strip().lower()
        print(f"💡 reflection: {reflection!r}", flush=True)

        # 6) decide where to go
        if "task complete" in clean_ref:
            print("✅ done", flush=True)
            return Command(
                update={
                    "messages": state["messages"]
                    + [HumanMessage(summary, name="chart_summarizer")]
                },
                goto=END,
            )

        print("🔁 need to retry", flush=True)
        return Command(
            update={
                "messages": state["messages"]
                + [
                    HumanMessage(summary, name="chart_summarizer"),
                    HumanMessage(reflection, name="chart_reflection"),
                ]
            },
            goto="select_tools",
        )

    workflow = StateGraph(ToolState)
    workflow.add_node("select_tools", select_tools)
    workflow.add_node("research_agent", research_agent_node)
    workflow.add_node("chart_generator", chart_node)
    workflow.add_node("chart_summarizer", chart_summary_node)
    workflow.add_node("canned_end", canned_end_node)

    # Update transitions: begin with tool selection then go to research agent.
    workflow.add_edge(START, "select_tools")

    # workflow.add_edge("select_tools", END)
    workflow.add_edge("select_tools", "research_agent")
    workflow.add_edge("research_agent", "chart_generator")
    workflow.add_edge("chart_generator", "chart_summarizer")
    workflow.add_edge("chart_summarizer", END)

    workflow.add_edge("canned_end", END)

    compiled_graph = workflow.compile()

    return compiled_graph

In [None]:
class TruAgent:
    """Thin wrapper that runs the multi-agent graph for a single query.

    The wrapped graph is the one from ``build_graph_with_agent`` (web search
    plus the Cortex Agent tool), matching the "doc, sql and web search" app
    version this class is registered under.
    """

    def __init__(self):
        self.graph = build_graph_with_agent()

    @instrument(
        span_type=SpanAttributes.SpanType.RECORD_ROOT,
        attributes={
            SpanAttributes.RECORD_ROOT.INPUT: "query",
            SpanAttributes.RECORD_ROOT.OUTPUT: "return",
        },
    )
    def invoke_agent_graph(self, query: str) -> str:
        """Run ``query`` through a freshly built graph and return the answer.

        Returns the content of the last message emitted by any node, an
        empty string if no node emitted messages, or a fixed apology string
        if the run raised.
        """
        try:
            # rebuild the graph for each query
            self.graph = build_graph_with_agent()
            # Initialize state with proper message format
            state = {"messages": [HumanMessage(content=query)]}

            # Stream events with recursion limit
            events = self.graph.stream(
                state,
                {"recursion_limit": 15},
            )

            # Track all messages through the conversation
            all_messages = []
            for event in events:
                # Get the payload from the event (first entry only)
                _, payload = next(iter(event.items()))
                if not payload:  # Skip empty payloads
                    continue

                messages = payload.get("messages")
                if not messages:
                    continue
                all_messages.extend(messages)

            # Return the last message's content if available
            return (
                all_messages[-1].content
                if all_messages and hasattr(all_messages[-1], "content")
                else ""
            )
        except Exception as exc:
            # Best-effort: keep the evaluation run going, but surface the
            # cause. Catching Exception (not a bare except) lets
            # KeyboardInterrupt / SystemExit propagate.
            print(f"invoke_agent_graph failed: {exc!r}", flush=True)
            return "I ran into an issue, and cannot answer your question."


# Instantiate the agent wrapper (this builds the graph once in __init__).
tru_agent = TruAgent()

# Register the agent with TruLens under this app name/version, using
# `invoke_agent_graph` as the entry point that gets invoked per input.
tru_agent_app = TruApp(
    tru_agent,
    app_name=APP_NAME,
    app_version="doc, sql and web search",
    connector=trulens_sf_connector,
    main_method=tru_agent.invoke_agent_graph,
)

# Current local time as a "YYYY-MM-DD HH:MM:SS" string.
# NOTE(review): `st_1` is not used in this cell — presumably kept as a
# start-time marker for later cells; confirm.
st_1 = datetime.datetime.fromtimestamp(time.time()).strftime(
    "%Y-%m-%d %H:%M:%S"
)

# Describe the evaluation run: inputs come from a DataFrame whose "query"
# column is mapped to the record root input.
run_config = RunConfig(
    run_name="Multi-agent demo run - document, sql and web search",
    description="this is a run with access to cortex agent (search and analyst) and web search",
    dataset_name="Research test dataset",
    source_type="DATAFRAME",
    label="langgraph demo",
    dataset_spec={
        "RECORD_ROOT.INPUT": "query",
    },
)

# Attach the run to the app; the returned handle is used by the cells below.
run: Run = tru_agent_app.add_run(run_config)

In [None]:
# Start the run over the queries in `user_queries_df` (defined earlier in the
# notebook).
run.start(input_df=user_queries_df)

In [None]:
import time

# Poll every 3 seconds until the run is no longer reported as
# INVOCATION_IN_PROGRESS.
# NOTE(review): there is no timeout — a run stuck in this status would block
# the cell indefinitely.
while run.get_status() == "INVOCATION_IN_PROGRESS":
    time.sleep(3)

# Request the three evaluation metrics for the finished run.
run.compute_metrics(["groundedness", "context_relevance", "answer_relevance"])