# Tool-calling Agent

This is an auto-generated notebook created by an AI Playground export.

This notebook uses [Mosaic AI Agent Framework](https://docs.databricks.com/generative-ai/agent-framework/build-genai-apps.html) to recreate your agent from the AI Playground. It demonstrates how to develop, manually test, evaluate, log, and deploy a tool-calling agent in LangGraph.

The agent code implements [MLflow's ChatAgent](https://mlflow.org/docs/latest/python_api/mlflow.pyfunc.html#mlflow.pyfunc.ChatAgent) interface, a Databricks-recommended open-source standard that simplifies authoring multi-turn conversational agents, and is fully compatible with Mosaic AI agent framework functionality.

**_NOTE:_** This notebook uses LangChain, but Mosaic AI Agent Framework is compatible with any agent authoring framework, including LlamaIndex or pure Python agents written with the OpenAI SDK.

## Prerequisites

- Address all `TODO`s in this notebook.

In [0]:
%pip install -U -qqqq mlflow langchain langgraph==0.3.4 databricks-langchain pydantic databricks-agents unitycatalog-langchain[databricks] uv langchain-openai
dbutils.library.restartPython()

In [0]:
%load_ext autoreload
%autoreload 2
import warnings

warnings.simplefilter(action="ignore")

In [0]:
dbutils.widgets.text(name="catalog", defaultValue="", label="catalog")
dbutils.widgets.text(name="schema", defaultValue="", label="schema")
dbutils.widgets.text(name="model", defaultValue="serper_agent", label="model")
dbutils.widgets.text(name="DATABRICKS_TOKEN", defaultValue="", label="DATABRICKS_TOKEN")
dbutils.widgets.text(name="secret_scope", defaultValue="", label="secret_scope")
dbutils.widgets.text(name="SERPER_API_KEY", defaultValue="", label="SERPER_API_KEY")
dbutils.widgets.text(name="OPENAI_API_KEY", defaultValue="", label="OPENAI_API_KEY")

In [0]:
catalog = dbutils.widgets.get("catalog")
schema = dbutils.widgets.get("schema")
model = dbutils.widgets.get("model")
# LLM_ENDPOINT = dbutils.widgets.get("foundation_model")
assert (
    len(catalog) > 0 and len(schema) > 0 and len(model) > 0
), "Please provide a valid catalog, schema, and model name"
three_tiered_model_name = f"{catalog}.{schema}.{model}"
print(f"{three_tiered_model_name=}")

To reuse a secret scope that already exists instead of creating a new one, replace the `existing secret scope` string in the cell below with that scope's name.
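
As a minimal sketch of the create-only-if-missing decision, the helper below takes the existing scope names and a create callback as plain arguments, so it runs without a workspace. `ensure_scope` and the sample scope names are hypothetical. In this notebook the callback could wrap `w.secrets.create_scope(scope=...)`; reading the existing names from the SDK's scope listing call is an assumption, not something the notebook does.

```python
from typing import Callable, Iterable


def ensure_scope(
    name: str,
    existing_scopes: Iterable[str],
    create_scope: Callable[[str], None],
) -> bool:
    """Create `name` only if it is not already present.

    Returns True when a scope was created, False when it already existed.
    """
    if not name:
        raise ValueError("secret scope name must not be empty")
    if name in set(existing_scopes):
        return False
    create_scope(name)
    return True


# Stand-ins for the workspace: a list of scope names and a fake "create" call.
scopes = ["team-secrets"]  # hypothetical existing scope
print(ensure_scope("team-secrets", scopes, scopes.append), scopes)
print(ensure_scope("agent-secrets", scopes, scopes.append), scopes)
```

With a check like this, rerunning the cell does not depend on comparing the widget value against a fixed string.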

In [0]:
from databricks.sdk import WorkspaceClient

w = WorkspaceClient()

DATABRICKS_HOST = w.config.host

secret_scope_name = dbutils.widgets.get("secret_scope")

# if needed create a secret scope
if secret_scope_name != "existing secret scope":
    w.secrets.create_scope(scope=secret_scope_name)
else:
    print(f"Using existing secret scope: {secret_scope_name}")

In [0]:
# Names of the secret scope and keys used in Databricks Secrets
SECRETS_SCOPE = secret_scope_name
SECRETS_KEY = "DATABRICKS_TOKEN"
OPENAI_SECRETS_KEY = "OpenAIBo"

# Store each key that was supplied through the widgets; keys left blank are skipped

if dbutils.widgets.get("DATABRICKS_TOKEN") == "":
    print("no DATABRICKS_TOKEN is provided")
else:
    w.secrets.put_secret(
        scope=SECRETS_SCOPE,
        key=SECRETS_KEY,
        string_value=dbutils.widgets.get("DATABRICKS_TOKEN"),
    )

if dbutils.widgets.get("SERPER_API_KEY") == "":
    print("no SERPER_API_KEY is provided")
else:
    w.secrets.put_secret(
        scope=SECRETS_SCOPE,
        key="SERPER_API_KEY",
        string_value=dbutils.widgets.get("SERPER_API_KEY"),
    )

if dbutils.widgets.get("OPENAI_API_KEY") == "":
    print("no OPENAI_API_KEY is provided")
else:
    w.secrets.put_secret(
        scope=SECRETS_SCOPE,
        key=OPENAI_SECRETS_KEY,
        string_value=dbutils.widgets.get("OPENAI_API_KEY"),
    )

w.secrets.put_secret(
    scope=SECRETS_SCOPE, key="DATABRICKS_HOST", string_value=DATABRICKS_HOST
)

## Define the agent in code
Below we define our agent code in a single cell, enabling us to easily write it to a local Python file for subsequent logging and deployment using the `%%writefile` magic command.

For more examples of tools to add to your agent, see [docs](https://docs.databricks.com/generative-ai/agent-framework/agent-tool.html).

In [0]:
import os

os.environ["OPENAI_API_KEY"] = dbutils.secrets.get(
    scope=SECRETS_SCOPE, key=OPENAI_SECRETS_KEY
)
os.environ["SERPER_API_KEY"] = dbutils.secrets.get(
    scope=SECRETS_SCOPE, key="SERPER_API_KEY"
)

Please replace `catalog`, `schema` with your UC catalog and schema names.

In [0]:
%sql
CREATE OR REPLACE FUNCTION catalog.schema.extract_member_id(
    question STRING COMMENT 'input patient question'
  )
  RETURNS STRING
  COMMENT 'Returns member_id given a patient question'
  RETURN AI_QUERY(
    'databricks-meta-llama-3-1-8b-instruct',
    CONCAT(
      '### Instruction:\nExtract only the member ID from the input text. '
      || 'Return only the member ID with no extra text.\n\n'
      || '### Input:\n',
      question,
      '\n\n### Response:'
    ),
    'STRING'
  )

In [0]:
%sql
CREATE OR REPLACE FUNCTION catalog.schema.extract_deductible(
    member_id STRING COMMENT 'Member id'
  )
  RETURNS TABLE
  COMMENT 'Returns member deductible details about a customer'
  RETURN
    SELECT
      mem_deductible,
      mem_ded_agg
    FROM
      catalog.schema.member_accumulators a
    WHERE
      a.member_id = member_id
    LIMIT 1;

In [0]:
%sql
-- Create a SQL function that calls the vector_search() AI Function
CREATE OR REPLACE FUNCTION catalog.schema.cpt_codes_vector_search(
    question STRING COMMENT 'The question string for searching insurance cpt codes and their descriptions'
  )
  RETURNS TABLE
  COMMENT 'Executes a search on cpt codes to retrieve text descriptions most relevant to the input question.'
  RETURN
    SELECT
      description,
      code
    FROM
      vector_search(
        index => 'catalog.schema.cpt_codes_index',
        query => question,
        num_results => 1
      );

In [0]:
# %sql
# SELECT
#   *
# FROM
#   catalog.schema.cpt_codes_vector_search(
#     'my member id is 1234. what is cost of hip replacement considering my IP address is 172.217.22.14?'
#   ) AS example

In [0]:
%sql
CREATE OR REPLACE FUNCTION catalog.schema.get_procedure_cost(
    procedure_code STRING COMMENT 'Procedure code'
  )
  RETURNS TABLE
  COMMENT 'Returns cost for a procedure code'
  RETURN
    SELECT
      cost as procedure_cost
    FROM
      catalog.schema.procedure_cost a
    WHERE
      a.procedure_code = procedure_code
    LIMIT 1;

In the model-as-code cell below, replace `catalog` and `schema` with your UC catalog and schema names.
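
A small sketch of how the three-level tool names can be built from one catalog and schema, so the placeholders only need to change in one place. `qualified_names` is a hypothetical helper, and `main` / `insurance` are made-up example values; the function names are the ones created by the SQL cells above. The validation (no empty parts, dots, or whitespace) is a conservative assumption, not a Unity Catalog rule.

```python
from typing import Iterable, List


def qualified_names(catalog: str, schema: str, names: Iterable[str]) -> List[str]:
    """Return `catalog.schema.name` for every name, rejecting malformed parts."""
    names = list(names)
    for part in (catalog, schema, *names):
        # A dot or whitespace inside a part would silently change the name's level.
        if not part or "." in part or any(ch.isspace() for ch in part):
            raise ValueError(f"invalid name part: {part!r}")
    return [f"{catalog}.{schema}.{name}" for name in names]


uc_tool_names = qualified_names(
    "main",        # hypothetical catalog
    "insurance",   # hypothetical schema
    [
        "extract_member_id",
        "extract_deductible",
        "get_procedure_cost",
        "cpt_codes_vector_search",
    ],
)
print(uc_tool_names)
```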

In [0]:
%%writefile agent.py
from typing import Any, Generator, Optional, Sequence, Union
import mlflow
from databricks_langchain import (
    ChatDatabricks,
    VectorSearchRetrieverTool,
    DatabricksFunctionClient,
    UCFunctionToolkit,
    set_uc_function_client,
)
from langchain_community.utilities import GoogleSerperAPIWrapper
from langchain_core.language_models import LanguageModelLike
from langchain_core.runnables import RunnableConfig, RunnableLambda
from langchain_core.tools import BaseTool, Tool
from langgraph.graph import END, StateGraph
from langgraph.graph.graph import CompiledGraph
from langgraph.graph.state import CompiledStateGraph
from langgraph.prebuilt.tool_node import ToolNode
from mlflow.langchain.chat_agent_langgraph import ChatAgentState, ChatAgentToolNode
from mlflow.pyfunc import ChatAgent
from mlflow.types.agent import (
    ChatAgentChunk,
    ChatAgentMessage,
    ChatAgentResponse,
    ChatContext,
)
from pydantic import BaseModel, Field

# testing
from langchain_openai import ChatOpenAI

mlflow.langchain.autolog()

client = DatabricksFunctionClient()
set_uc_function_client(client)

############################################
# Define your Schema definition for returning structured output
############################################
class ResponseFormatter(BaseModel):
    """Always use this tool to structure your response to the user."""

    answer: str = Field(
        description="The answer to the user's question without doctor's location data"
    )
    first_doctor: str = Field(description="The first recommended doctor's name")
    first_doctor_lattitude: str = Field(
        description="The first recommended doctor's latitude and direction"
    )
    first_doctor_longitude: str = Field(
        description="The first recommended doctor's longitude and direction"
    )
    # second_doctor: str = Field(description="The second recommended doctor's name")
    # second_doctor_lattitude: str = Field(
    #     description="The second recommended doctor's latitude and direction"
    # )
    # second_doctor_longitude: str = Field(
    #     description="The second recommended doctor's longitude and direction"
    # )


############################################
# Define your LLM endpoint and system prompt
############################################
# LLM_ENDPOINT_NAME = "databricks-claude-3-7-sonnet"
# llm = ChatDatabricks(endpoint=LLM_ENDPOINT_NAME, temperature=0.1)

LLM_ENDPOINT_NAME = "gpt-4.1-mini-2025-04-14"
llm = ChatOpenAI(model_name=LLM_ENDPOINT_NAME, temperature=0.1)

system_prompt = """
    You are a healthcare policy Q&A agent.
    You are given a task and you must complete it.
    Use the following routine to support the customer.
    # Routine:
    1. Use the extract_member_id tool to extract member id.
    2. Use member id from step 1 as input for the extract_deductible tool to get the member deductible and member deductible aggregate.
    3. Use the cpt_codes_vector_search tool to get the most similar code and description given the original question.
    4. Use code from step 3 as input for the get_procedure_cost tool to get the procedure cost.
    5. If you are provided an IP address convert the IP address into a location to provide location specific recommendations using the Intermediate_Answer tool.
    6. For the recommendations, please provide the latitude and longitude of the location of the 1 top recommended doctor related to the original question.
    7. Do not mention the IP address in your response.
    8. Following the ResponseFormatter summarize the member id, member deductible, member deductible aggregate, code, procedure cost in the answer
    9. Associate the 1 top recommended doctor with their own location related latitude and longitude along with direction and be as concise as possible in the other fields in the ResponseFormatter. 
    10. Use the output from the ResponseFormatter as the final answer.
    You can use the following tools to complete the task:
    {tools}"""

###############################################################################
## Define tools for your agent, enabling it to retrieve data or take actions
## beyond text generation
## To create and see usage examples of more tools, see
## https://docs.databricks.com/generative-ai/agent-framework/agent-tool.html
###############################################################################
search = GoogleSerperAPIWrapper()
tools = [
    ResponseFormatter,
    Tool(
        name="Intermediate_Answer",
        func=search.run,
        description="useful for when you need to ask with search",
    ),
]

# You can use UDFs in Unity Catalog as agent tools
uc_tool_catalog = "catalog"
uc_tool_schema = "schema"
uc_tool_names = [
    f"{uc_tool_catalog}.{uc_tool_schema}.extract_member_id",
    f"{uc_tool_catalog}.{uc_tool_schema}.extract_deductible",
    f"{uc_tool_catalog}.{uc_tool_schema}.get_procedure_cost",
    f"{uc_tool_catalog}.{uc_tool_schema}.cpt_codes_vector_search",
]
uc_toolkit = UCFunctionToolkit(function_names=uc_tool_names)
tools.extend(uc_toolkit.tools)

# vector_search_index_tools = [
#     VectorSearchRetrieverTool(
#         index_name="catalog.schema.cpt_codes_index",
#         num_results=1,
#         tool_name="cpt_codes_retriever",
#         tool_description="Retrieves information about cpt codes",
#         query_type="ANN",
#     )
# ]
# tools.extend(vector_search_index_tools)

# # (Optional) Use Databricks vector search indexes as tools
# # See https://docs.databricks.com/generative-ai/agent-framework/unstructured-retrieval-tools.html
# # for details
#
# # TODO: Add vector search indexes as tools or delete this block
# vector_search_tools = [
#         VectorSearchRetrieverTool(
#         index_name="",
#         # filters="..."
#     )
# ]
# tools.extend(vector_search_tools)


#####################
## Define agent logic
#####################


def create_tool_calling_agent(
    model: LanguageModelLike,
    tools: Union[Sequence[BaseTool], ToolNode],
    system_prompt: Optional[str] = None,
) -> CompiledGraph:
    # parallel_tool_calls parameter is currently only supported by OpenAI and Anthropic.
    model = model.bind_tools(tools, parallel_tool_calls=True)

    # Define the function that determines which node to go to
    def should_continue(state: ChatAgentState):
        messages = state["messages"]
        last_message = messages[-1]
        # If there are function calls, continue. else, end
        if last_message.get("tool_calls"):
            return "continue"
        else:
            return "end"

    if system_prompt:
        preprocessor = RunnableLambda(
            lambda state: [{"role": "system", "content": system_prompt}]
            + state["messages"]
        )
    else:
        preprocessor = RunnableLambda(lambda state: state["messages"])
    model_runnable = preprocessor | model

    def call_model(
        state: ChatAgentState,
        config: RunnableConfig,
    ):
        response = model_runnable.invoke(state, config)

        return {"messages": [response]}

    workflow = StateGraph(ChatAgentState)

    workflow.add_node("agent", RunnableLambda(call_model))
    workflow.add_node("tools", ChatAgentToolNode(tools))

    workflow.set_entry_point("agent")
    workflow.add_conditional_edges(
        "agent",
        should_continue,
        {
            "continue": "tools",
            "end": END,
        },
    )
    workflow.add_edge("tools", "agent")

    return workflow.compile()


class LangGraphChatAgent(ChatAgent):
    def __init__(self, agent: CompiledStateGraph):
        self.agent = agent

    def predict(
        self,
        messages: list[ChatAgentMessage],
        context: Optional[ChatContext] = None,
        custom_inputs: Optional[dict[str, Any]] = None,
    ) -> ChatAgentResponse:
        request = {"messages": self._convert_messages_to_dict(messages)}

        messages = []
        for event in self.agent.stream(request, stream_mode="updates"):
            for node_data in event.values():
                messages.extend(
                    ChatAgentMessage(**msg) for msg in node_data.get("messages", [])
                )
        return ChatAgentResponse(messages=messages)

    def predict_stream(
        self,
        messages: list[ChatAgentMessage],
        context: Optional[ChatContext] = None,
        custom_inputs: Optional[dict[str, Any]] = None,
    ) -> Generator[ChatAgentChunk, None, None]:
        request = {"messages": self._convert_messages_to_dict(messages)}
        for event in self.agent.stream(request, stream_mode="updates"):
            for node_data in event.values():
                yield from (
                    ChatAgentChunk(**{"delta": msg}) for msg in node_data["messages"]
                )


# Create the agent object, and specify it as the agent object to use when
# loading the agent back for inference via mlflow.models.set_model()
agent = create_tool_calling_agent(llm, tools, system_prompt)
AGENT = LangGraphChatAgent(agent)
mlflow.models.set_model(AGENT)
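
The graph in `agent.py` alternates between an `agent` node and a `tools` node until the last message carries no tool calls. The sketch below mirrors that control flow in plain Python, with no LangGraph dependency. Everything in it is a stand-in: `run_agent_loop`, `fake_model`, the simplified message shape, the sample tool result, and the `max_steps` guard are assumptions for illustration, not part of the agent above.

```python
from typing import Any, Callable, Dict, List

Message = Dict[str, Any]


def run_agent_loop(
    model: Callable[[List[Message]], Message],
    tools: Dict[str, Callable[..., Any]],
    messages: List[Message],
    max_steps: int = 10,
) -> List[Message]:
    """Call the model, run any requested tools, and repeat until it stops asking."""
    history = list(messages)
    for _ in range(max_steps):
        reply = model(history)
        history.append(reply)
        tool_calls = reply.get("tool_calls") or []
        if not tool_calls:  # the "end" branch of should_continue
            return history
        for call in tool_calls:  # the "continue" branch: execute each tool
            result = tools[call["name"]](**call["args"])
            history.append(
                {"role": "tool", "name": call["name"], "content": str(result)}
            )
    raise RuntimeError("agent did not finish within max_steps")


def fake_model(history: List[Message]) -> Message:
    """Stub model: request one tool call, then answer from the tool result."""
    if not any(m["role"] == "tool" for m in history):
        return {
            "role": "assistant",
            "content": "",
            "tool_calls": [
                {"name": "get_procedure_cost", "args": {"procedure_code": "00000"}}
            ],
        }
    return {"role": "assistant", "content": f"The cost is {history[-1]['content']}."}


stub_tools = {"get_procedure_cost": lambda procedure_code: 100}  # made-up result
final_history = run_agent_loop(
    fake_model, stub_tools, [{"role": "user", "content": "cost?"}]
)
print(final_history[-1]["content"])
```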

## Test the agent

Interact with the agent to test its output. Because `agent.py` calls `mlflow.langchain.autolog()`, you can view the trace for each step the agent takes.

Replace this placeholder input with an appropriate domain-specific example for your agent.
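
The same single-turn payload is reused for testing, logging, evaluation, and querying the endpoint. A small helper keeps those copies in sync when you swap in your own example. `make_request` is a hypothetical name, not part of MLflow or this notebook; it only builds the `{"messages": [...]}` dictionary that the cells in this notebook pass around.

```python
from typing import Any, Dict, List, Optional


def make_request(
    question: str, history: Optional[List[Dict[str, str]]] = None
) -> Dict[str, Any]:
    """Build a chat request with optional prior turns followed by the user question."""
    question = question.strip()
    if not question:
        raise ValueError("question must not be empty")
    messages = list(history or [])  # copy so the caller's list is not mutated
    messages.append({"role": "user", "content": question})
    return {"messages": messages}


example_request = make_request(
    "my member id is 1234. what is cost of hip replacement "
    "considering my IP address is 172.217.22.14?"
)
print(example_request)
```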

In [0]:
# dbutils.library.restartPython()

In [0]:
from agent import AGENT

AGENT.predict(
    {
        "messages": [
            {
                "role": "user",
                "content": "my member id is 1234. what is cost of hip replacement considering my IP address is 172.217.22.14?",
            }
        ]
    }
)

In [0]:
# for event in AGENT.predict_stream(
#     {"messages": [{"role": "user", "content": "What tables are related to 'supply chain'?"}]}
# ):
#     print(event, "-----------\n")

In [0]:
from IPython.display import Image, display

graph = AGENT.agent
display(Image(graph.get_graph().draw_mermaid_png()))

### Log the `agent` as an MLflow model
Determine the Databricks resources to specify for automatic auth passthrough at deployment time.
- **TODO**: If your Unity Catalog tool queries a [vector search index](https://docs.databricks.com/generative-ai/agent-framework/unstructured-retrieval-tools.html) or leverages [external functions](https://docs.databricks.com/generative-ai/agent-framework/external-connection-tools.html), you need to include the dependent vector search index and UC connection objects, respectively, as resources. See [docs](https://docs.databricks.com/generative-ai/agent-framework/log-agent.html#specify-resources-for-automatic-authentication-passthrough) for more details.

Log the agent as code from the `agent.py` file. See [MLflow - Models from Code](https://mlflow.org/docs/latest/models.html#models-from-code).

In [0]:
# Determine Databricks resources to specify for automatic auth passthrough at deployment time
import mlflow
from agent import tools, LLM_ENDPOINT_NAME
from databricks_langchain import VectorSearchRetrieverTool
from mlflow.models.resources import (
    DatabricksFunction,
    DatabricksServingEndpoint,
    DatabricksTable,
    DatabricksVectorSearchIndex,
)
from unitycatalog.ai.langchain.toolkit import UnityCatalogTool

# TODO: Manually include underlying resources if needed. See the TODO in the markdown above for more information.
uc_tool_catalog = "catalog"
uc_tool_schema = "schema"
resources = [
    # DatabricksServingEndpoint(endpoint_name=LLM_ENDPOINT_NAME),
    DatabricksVectorSearchIndex(
        index_name=f"{uc_tool_catalog}.{uc_tool_schema}.cpt_codes_index"
    ),
]
tables = ["cpt_codes", "member_accumulators", "procedure_cost"]
for table in tables:
    resources.append(
        DatabricksTable(table_name=f"{uc_tool_catalog}.{uc_tool_schema}.{table}")
    )
for tool in tools:
    if isinstance(tool, VectorSearchRetrieverTool):
        resources.extend(tool.resources)
    elif isinstance(tool, UnityCatalogTool):
        resources.append(DatabricksFunction(function_name=tool.uc_function_name))

input_example = {
    "messages": [
        {
            "role": "user",
            "content": "my member id is 1234. what is cost of hip replacement considering my IP address is 172.217.22.14?",
        }
    ]
}

with mlflow.start_run():
    logged_agent_info = mlflow.pyfunc.log_model(
        artifact_path="agent",
        python_model="agent.py",
        input_example=input_example,
        resources=resources,
        extra_pip_requirements=[
            "databricks-connect",
        ],
    )

## Evaluate the agent with [Agent Evaluation](https://docs.databricks.com/generative-ai/agent-evaluation/index.html)

You can edit the requests or expected responses in your evaluation dataset and rerun the evaluation as you iterate on your agent, leveraging MLflow to track the computed quality metrics.

To evaluate your tool calls, try adding [custom metrics](https://docs.databricks.com/generative-ai/agent-evaluation/custom-metrics.html#evaluating-tool-calls).
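
Before wiring up a custom metric, it can help to see the underlying check on its own. The sketch below is plain Python over a list of message dictionaries and does not use the Agent Evaluation metric API. `called_tools`, `tools_called_in_order`, and the sample messages are hypothetical, and the short tool names are an assumption (the names in a real trace may be prefixed with the catalog and schema).

```python
from typing import Any, Dict, List


def called_tools(messages: List[Dict[str, Any]]) -> List[str]:
    """Return the names of all tools requested, in the order they were requested."""
    names = []
    for message in messages:
        for call in message.get("tool_calls") or []:
            name = (call.get("function") or {}).get("name")
            if name:
                names.append(name)
    return names


def tools_called_in_order(messages: List[Dict[str, Any]], expected: List[str]) -> bool:
    """True if `expected` appears as a subsequence of the tools actually called."""
    remaining = iter(called_tools(messages))
    # `name in remaining` advances the iterator, which enforces the ordering.
    return all(name in remaining for name in expected)


sample_messages = [  # made-up conversation for illustration only
    {"role": "user", "content": "my member id is 1234. what is cost of hip replacement?"},
    {"role": "assistant", "content": "",
     "tool_calls": [{"function": {"name": "extract_member_id", "arguments": "{}"}}]},
    {"role": "tool", "content": "1234"},
    {"role": "assistant", "content": "",
     "tool_calls": [{"function": {"name": "get_procedure_cost", "arguments": "{}"}}]},
    {"role": "assistant", "content": "done"},
]
print(called_tools(sample_messages))
print(tools_called_in_order(sample_messages, ["extract_member_id", "get_procedure_cost"]))
```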

In [0]:
import pandas as pd

eval_examples = [
    {
        "request": {
            "messages": [
                {
                    "role": "user",
                    "content": "my member id is 1234. what is cost of hip replacement considering my IP address is 172.217.22.14?",
                }
            ]
        },
        "expected_response": None,
    }
]

eval_dataset = pd.DataFrame(eval_examples)
display(eval_dataset)

In [0]:
import mlflow

with mlflow.start_run(run_id=logged_agent_info.run_id):
    eval_results = mlflow.evaluate(
        f"runs:/{logged_agent_info.run_id}/agent",
        data=eval_dataset,  # Your evaluation dataset
        model_type="databricks-agent",  # Enable Mosaic AI Agent Evaluation
    )

In [0]:
# Review the evaluation results in the MLflow UI (see console output), or access them in place:
eval_results.tables["eval_results"]

## Perform pre-deployment validation of the agent
Before registering and deploying the agent, we perform pre-deployment checks via the [mlflow.models.predict()](https://mlflow.org/docs/latest/python_api/mlflow.models.html#mlflow.models.predict) API. See the [documentation](https://docs.databricks.com/machine-learning/model-serving/model-serving-debug.html#validate-inputs) for details.

In [0]:
mlflow.models.predict(
    model_uri=f"runs:/{logged_agent_info.run_id}/agent",
    input_data={
        "messages": [
            {
                "role": "user",
                "content": "my member id is 1234. what is cost of hip replacement considering my IP address is 172.217.22.14?",
            }
        ]
    },
    env_manager="uv",
)

## Register the model to Unity Catalog

Update the `catalog`, `schema`, and `model` values used below to register the MLflow model to Unity Catalog.

In [0]:
mlflow.set_registry_uri("databricks-uc")

# TODO: define the catalog, schema, and model name for your UC model
UC_MODEL_NAME = f"{catalog}.{schema}.{model}"

# register the model to UC
uc_registered_model_info = mlflow.register_model(
    model_uri=logged_agent_info.model_uri,
    name=UC_MODEL_NAME,
)

## Deploy the agent

In [0]:
from databricks import agents

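# Secret references must point at the scope and keys written earlier in this notebook,
# so build them from SECRETS_SCOPE instead of hard-coding a scope name.
secret_ref_prefix = "{{secrets/" + SECRETS_SCOPE + "/"

agents.deploy(
    UC_MODEL_NAME,
    uc_registered_model_info.version,
    environment_vars={
        "DATABRICKS_TOKEN": secret_ref_prefix + SECRETS_KEY + "}}",
        "DATABRICKS_HOST": secret_ref_prefix + "DATABRICKS_HOST}}",
        "SERPER_API_KEY": secret_ref_prefix + "SERPER_API_KEY}}",
        "OPENAI_API_KEY": secret_ref_prefix + OPENAI_SECRETS_KEY + "}}",
    },
    tags={"endpointSource": "playground"},
)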

In [0]:
from databricks import agents

# Note that `users` can list individual users or groups.
agents.set_permissions(
    model_name=UC_MODEL_NAME,
    users=["users"],
    permission_level=agents.PermissionLevel.CAN_QUERY,
)

## Next steps

After your agent is deployed, you can chat with it in AI Playground to perform additional checks, share it with SMEs in your organization for feedback, or embed it in a production application. See the [docs](https://docs.databricks.com/generative-ai/deploy-agent.html) for details.

In [0]:
import time
from databricks.sdk.service.serving import EndpointStateReady, EndpointStateConfigUpdate
from databricks.sdk import WorkspaceClient

endpoint_name: str = f"agents_{catalog}-{schema}-{model}"
print("\nWaiting for endpoint to deploy.  This can take 10 - 20 minutes.", end="")
w = WorkspaceClient()
while (
    w.serving_endpoints.get(endpoint_name).state.ready == EndpointStateReady.NOT_READY
    or w.serving_endpoints.get(endpoint_name).state.config_update
    == EndpointStateConfigUpdate.IN_PROGRESS
):
    print(".", end="")
    time.sleep(30)
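
The loop above polls until the endpoint is ready but never gives up. A hedged sketch of the same polling pattern with a timeout follows. `wait_until` and its parameters are hypothetical, and the readiness check is passed in as a callable so the example runs without a workspace; in this notebook that callable would wrap the `w.serving_endpoints.get(endpoint_name)` state checks from the cell above.

```python
import time
from typing import Callable


def wait_until(
    is_ready: Callable[[], bool],
    timeout_s: float = 1200.0,
    interval_s: float = 30.0,
    sleep: Callable[[float], None] = time.sleep,
    clock: Callable[[], float] = time.monotonic,
) -> bool:
    """Poll `is_ready` until it returns True or `timeout_s` elapses.

    Returns True if the condition was met, False on timeout.
    """
    deadline = clock() + timeout_s
    while True:
        if is_ready():
            return True
        if clock() >= deadline:
            return False
        sleep(interval_s)


# Stand-in readiness check: not ready twice, then ready.
states = iter([False, False, True])
print(wait_until(lambda: next(states), timeout_s=5.0, interval_s=0.0))
```

Returning `False` on timeout lets the caller decide whether to raise, retry, or inspect the endpoint state manually.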

In [0]:
from databricks.sdk import WorkspaceClient

w = WorkspaceClient()
endpoint_name: str = f"agents_{catalog}-{schema}-{model}"
response = w.serving_endpoints.query(
    name=endpoint_name,
    dataframe_records=[
        {
            "messages": [
                {
                    "role": "user",
                    "content": "my member id is 1234. what is cost of hip replacement considering my IP address is 172.217.22.14?",
                }
            ]
        }
    ],
    temperature=0.1,
)

In [0]:
[
    message
    for message in response.predictions.get("messages")
    if message.get("tool_calls") is not None
    and message.get("tool_calls")[0].get("function") is not None
    and message.get("tool_calls")[0].get("function").get("name") == "ResponseFormatter"
][0]

In [0]:
response.predictions.get("messages")[-3]

In [0]:
import json

json.loads(
    response.predictions.get("messages")[-3]
    .get("tool_calls")[0]
    .get("function")
    .get("arguments")
)
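
The fixed `[-3]` index above assumes the `ResponseFormatter` call is always third from the end of the message list. A sketch that looks the call up by tool name instead is shown below. `find_tool_arguments` is a hypothetical helper and the sample messages and field values are made up; only the `tool_calls` / `function` / `name` / `arguments` layout follows what the cells above read from the response.

```python
import json
from typing import Any, Dict, List, Optional


def find_tool_arguments(
    messages: List[Dict[str, Any]], tool_name: str
) -> Optional[Dict[str, Any]]:
    """Return the parsed arguments of the most recent call to `tool_name`, if any."""
    for message in reversed(messages):
        for call in message.get("tool_calls") or []:
            function = call.get("function") or {}
            if function.get("name") == tool_name:
                arguments = function.get("arguments")
                # Arguments usually arrive as a JSON string; pass dicts through as-is.
                return json.loads(arguments) if isinstance(arguments, str) else arguments
    return None


sample_messages = [  # made-up response for illustration only
    {"role": "user", "content": "what is cost of hip replacement?"},
    {
        "role": "assistant",
        "content": "",
        "tool_calls": [
            {
                "function": {
                    "name": "ResponseFormatter",
                    "arguments": json.dumps(
                        {"answer": "example answer", "first_doctor": "Dr. Example"}
                    ),
                }
            }
        ],
    },
    {"role": "tool", "content": "ok"},
    {"role": "assistant", "content": "example answer"},
]
print(find_tool_arguments(sample_messages, "ResponseFormatter"))
```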

In [0]:
response.predictions.get('messages')[-1]

## AI Playground integration with agent endpoint