# Tool-calling Agent

This notebook was auto-generated by an AI Playground export and then modified with additional features for evaluation.

This notebook uses [Mosaic AI Agent Framework](https://docs.databricks.com/generative-ai/agent-framework/build-genai-apps.html) to recreate your agent from the AI Playground. It demonstrates how to develop, manually test, evaluate, log, and deploy a tool-calling agent in LangGraph.

The agent code implements [MLflow's ChatAgent](https://mlflow.org/docs/latest/python_api/mlflow.pyfunc.html#mlflow.pyfunc.ChatAgent) interface, a Databricks-recommended open-source standard that simplifies authoring multi-turn conversational agents and is fully compatible with Mosaic AI Agent Framework functionality.

**_NOTE:_** This notebook uses LangChain, but Mosaic AI Agent Framework is compatible with any agent authoring framework, including LlamaIndex or pure Python agents written with the OpenAI SDK.

## Prerequisites

- You can skip every `TODO` in this notebook; the example already works as written.
- The `TODO`s mark places for later expansion, such as adding different tools.

In [0]:
%pip install -qqqq -U -r requirements.txt

dbutils.library.restartPython()

## Define the agent in code
Below we define our agent code in a single cell, enabling us to easily write it to a local Python file for subsequent logging and deployment using the `%%writefile` magic command.

For more examples of tools to add to your agent, see [docs](https://docs.databricks.com/generative-ai/agent-framework/agent-tool.html).
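The agent defined in this section follows a simple loop: the model is called, and if its reply contains tool calls, the tools run and the model is called again with their results. The following is a minimal sketch of that routing in plain Python, not the notebook's LangGraph implementation. `fake_model`, the `add` tool, and the `max_steps` cap are hypothetical stand-ins for the LLM endpoint, the Unity Catalog functions, and a safety limit.

```python
from typing import Callable


# Hypothetical stand-in for the LLM: a real agent would call a serving endpoint.
def fake_model(messages: list[dict]) -> dict:
    """Request a tool on the first turn, then answer using the tool result."""
    if messages[-1]["role"] == "tool":
        return {"role": "assistant", "content": f"Result: {messages[-1]['content']}"}
    return {
        "role": "assistant",
        "content": "",
        "tool_calls": [{"name": "add", "args": {"a": 2, "b": 3}}],
    }


# Hypothetical tool registry: a real agent would expose UC functions here.
TOOLS: dict[str, Callable[..., object]] = {"add": lambda a, b: a + b}


def should_continue(last_message: dict) -> str:
    """Route to the tools if the model asked for any, otherwise stop."""
    return "continue" if last_message.get("tool_calls") else "end"


def run_agent(messages: list[dict], max_steps: int = 5) -> list[dict]:
    """Alternate between the model and the tools until no tool is requested."""
    history = list(messages)
    for _ in range(max_steps):
        reply = fake_model(history)
        history.append(reply)
        if should_continue(reply) == "end":
            break
        for call in reply["tool_calls"]:
            result = TOOLS[call["name"]](**call["args"])
            history.append({"role": "tool", "content": str(result)})
    return history


history = run_agent([{"role": "user", "content": "What is 2 + 3?"}])
print(history[-1]["content"])
```

The same two-node shape (an agent node, a tools node, and a conditional edge between them) is what the graph in `agent.py` encodes.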

In [0]:
import yaml

with open("config.yaml", "r") as f:
    config = yaml.safe_load(f)

source_catalog_name = config["source_catalog_name"]
source_schema_name = config["source_schema_name"]
target_catalog_name = config["target_catalog_name"]
target_schema_name = config["target_schema_name"]
service_date = config["service_date"]
claim_id = config["claim_id"]
patient_id = config["patient_id"]
diagnosis_code = config["diagnosis_code"]
ndc_code = config["ndc_code"]
model_uc_name = config["model_uc_name"]
alias = config["alias"]
endpoint_name = config["endpoint_name"]
experiment_path = config['experiment_path']

In [0]:
%%writefile agent.py
from typing import Any, Generator, Optional, Sequence, Union

import mlflow
from databricks_langchain import (
    ChatDatabricks,
    VectorSearchRetrieverTool,
    DatabricksFunctionClient,
    UCFunctionToolkit,
    set_uc_function_client,
)
from langchain_core.language_models import LanguageModelLike
from langchain_core.runnables import RunnableConfig, RunnableLambda
from langchain_core.tools import BaseTool
from langgraph.graph import END, StateGraph
from langgraph.graph.graph import CompiledGraph
from langgraph.graph.state import CompiledStateGraph
from langgraph.prebuilt.tool_node import ToolNode
from mlflow.langchain.chat_agent_langgraph import ChatAgentState, ChatAgentToolNode
from mlflow.pyfunc import ChatAgent
from mlflow.types.agent import (
    ChatAgentChunk,
    ChatAgentMessage,
    ChatAgentResponse,
    ChatContext,
)

mlflow.langchain.autolog()

client = DatabricksFunctionClient()
set_uc_function_client(client)

############################################
# Define your LLM endpoint and system prompt
############################################

LLM_ENDPOINT_NAME = "databricks-claude-3-7-sonnet"
# LLM_ENDPOINT_NAME = "databricks-meta-llama-3-3-70b-instruct"
llm = ChatDatabricks(endpoint=LLM_ENDPOINT_NAME)

system_prompt = """You are a healthcare assistant specialized in analyzing real-world healthcare data. You will be asked questions about patient enrollment, medical claims, pharmacy claims, diagnoses, and procedures from HealthVerity's healthcare dataset. You should only answer questions relevant to this topic and should politely decline to answer any off topic questions. Be concise and clear - no need to repeat the question.

Use the tools at your disposal to answer the user's question. If you don't know the answer, say so. If the tools fail to execute, say so, and say why if you can. If it isn't clear which tool should be used, ask the user and summarize the tools that you can use.

Available tools include:
- get_patient_enrollment: Get patient demographics and enrollment information
- get_medical_claims: Get medical claims for a patient on a specific service date
- get_patient_diagnoses: Get all diagnosis codes for a patient
- get_pharmacy_claims: Get pharmacy claims and medication history for a patient
- get_patient_procedures: Get procedure codes and details for a patient

If there is a request including a DATE, please always return the full date when possible. 
"""

###############################################################################
## Define tools for your agent, enabling it to retrieve data or take actions
## beyond text generation
## To create and see usage examples of more tools, see
## https://docs.databricks.com/generative-ai/agent-framework/agent-tool.html
###############################################################################
tools = []

# You can use UDFs in Unity Catalog as agent tools
# HealthVerity clinical assistant tools
uc_tool_names = [f"{target_catalog_name}.{target_schema_name}.*"]
uc_toolkit = UCFunctionToolkit(function_names=uc_tool_names)
tools.extend(uc_toolkit.tools)

# # (Optional) Use Databricks vector search indexes as tools
# # See https://docs.databricks.com/generative-ai/agent-framework/unstructured-retrieval-tools.html
# # for details
#
# # TODO: Add vector search indexes as tools or delete this block
# vector_search_tools = [
#         VectorSearchRetrieverTool(
#         index_name="",
#         # filters="..."
#     )
# ]
# tools.extend(vector_search_tools)


#####################
## Define agent logic
#####################

def create_tool_calling_agent(
    model: LanguageModelLike,
    tools: Union[Sequence[BaseTool], ToolNode],
    system_prompt: Optional[str] = None,
) -> CompiledGraph:
    model = model.bind_tools(tools)

    # Define the function that determines which node to go to
    def should_continue(state: ChatAgentState):
        messages = state["messages"]
        last_message = messages[-1]
        # If there are function calls, continue. else, end
        if last_message.get("tool_calls"):
            return "continue"
        else:
            return "end"

    if system_prompt:
        preprocessor = RunnableLambda(
            lambda state: [{"role": "system", "content": system_prompt}]
            + state["messages"]
        )
    else:
        preprocessor = RunnableLambda(lambda state: state["messages"])
    model_runnable = preprocessor | model

    @mlflow.trace(name="agent_call_model")
    def call_model(
        state: ChatAgentState,
        config: RunnableConfig,
    ):
        response = model_runnable.invoke(state, config)
        return {"messages": [response]}

    # Create a custom tool node with tracing
    @mlflow.trace(name="agent_tool_execution")
    def call_tools(state: ChatAgentState, config: RunnableConfig):
        tool_node = ChatAgentToolNode(tools)
        return tool_node.invoke(state, config)

    workflow = StateGraph(ChatAgentState)

    workflow.add_node("agent", RunnableLambda(call_model))
    workflow.add_node("tools", RunnableLambda(call_tools))

    workflow.set_entry_point("agent")
    workflow.add_conditional_edges(
        "agent",
        should_continue,
        {
            "continue": "tools",
            "end": END,
        },
    )
    workflow.add_edge("tools", "agent")

    return workflow.compile()


class LangGraphChatAgent(ChatAgent):
    def __init__(self, agent: CompiledStateGraph):
        self.agent = agent

    @mlflow.trace(name="agent_predict")
    def predict(
        self,
        messages: list[ChatAgentMessage],
        context: Optional[ChatContext] = None,
        custom_inputs: Optional[dict[str, Any]] = None,
    ) -> ChatAgentResponse:
        request = {"messages": self._convert_messages_to_dict(messages)}

        messages = []
        for event in self.agent.stream(request, stream_mode="updates"):
            for node_data in event.values():
                messages.extend(
                    ChatAgentMessage(**msg) for msg in node_data.get("messages", [])
                )
        return ChatAgentResponse(messages=messages)

    @mlflow.trace(name="agent_predict_stream")
    def predict_stream(
        self,
        messages: list[ChatAgentMessage],
        context: Optional[ChatContext] = None,
        custom_inputs: Optional[dict[str, Any]] = None,
    ) -> Generator[ChatAgentChunk, None, None]:
        request = {"messages": self._convert_messages_to_dict(messages)}
        for event in self.agent.stream(request, stream_mode="updates"):
            for node_data in event.values():
                yield from (
                    ChatAgentChunk(**{"delta": msg})
                    for msg in node_data.get("messages", [])
                )


# Create the agent object, and specify it as the agent object to use when
# loading the agent back for inference via mlflow.models.set_model()
agent = create_tool_calling_agent(llm, tools, system_prompt)
AGENT = LangGraphChatAgent(agent)
mlflow.models.set_model(AGENT)

In [0]:
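# agent.py references target_catalog_name and target_schema_name but does not
# define them, so prepend both as literal assignments at the top of the file.
added_var_content = f"""target_catalog_name = '{target_catalog_name}'\n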
target_schema_name = '{target_schema_name}'\n
"""
with open("agent.py", "r") as f:
    existing_content = f.read()

with open("agent.py", "w") as f:
    f.write(added_var_content + existing_content)

## Test the agent

Interact with the agent to test its output. Because `agent.py` calls `mlflow.langchain.autolog()`, you can view the trace for each step the agent takes.

Replace this placeholder input with an appropriate domain-specific example for your agent.
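The test cells pass the agent a dictionary with a `messages` list of role/content pairs. As a rough sketch, a small helper can build that payload and catch obvious mistakes before the agent is called. `build_request`, `VALID_ROLES`, and the placeholder patient ID are hypothetical names for illustration, and the set of roles is an assumption rather than a statement of the ChatAgent schema.

```python
from typing import Optional

# Assumption: these are the roles a chat-style request may carry.
VALID_ROLES = {"system", "user", "assistant", "tool"}


def build_request(question: str, history: Optional[list[dict]] = None) -> dict:
    """Return a {"messages": [...]} payload ending with the new user question."""
    if not question.strip():
        raise ValueError("question must not be empty")
    messages = list(history or [])  # copy so the caller's list is not mutated
    for message in messages:
        if message.get("role") not in VALID_ROLES:
            raise ValueError(f"unexpected role: {message.get('role')!r}")
    messages.append({"role": "user", "content": question})
    return {"messages": messages}


# Placeholder value; the notebook reads the real one from config.yaml.
patient_id = "example-patient-id"
request = build_request(
    f"What enrollment information do you have for patient {patient_id}?"
)
print(request)
```

Passing earlier messages through `history` gives a multi-turn request without changing the payload shape.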

In [0]:
dbutils.library.restartPython()

In [0]:
import yaml

with open("config.yaml", "r") as f:
    config = yaml.safe_load(f)

source_catalog_name = config["source_catalog_name"]
source_schema_name = config["source_schema_name"]
target_catalog_name = config["target_catalog_name"]
target_schema_name = config["target_schema_name"]
service_date = config["service_date"]
claim_id = config["claim_id"]
patient_id = config["patient_id"]
diagnosis_code = config["diagnosis_code"]
ndc_code = config["ndc_code"]
model_uc_name = config["model_uc_name"]
alias = config["alias"]
endpoint_name = config["endpoint_name"]
experiment_path = config['experiment_path']

Note:
**The MLflow experiment must be created in your workspace folder (use an absolute path, as in the cell below). A repo folder won't work.**
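A quick sanity check on the configured path can catch this early. The sketch below is only a heuristic: it assumes repo folders live under a top-level `/Repos` (or `/Workspace/Repos`) directory, which may not hold in every workspace, and `is_workspace_experiment_path` and the sample paths are hypothetical.

```python
def is_workspace_experiment_path(path: str) -> bool:
    """Heuristic: accept absolute paths that are not under a Repos folder."""
    if not path.startswith("/"):
        return False  # relative paths are rejected
    parts = [part for part in path.split("/") if part]
    # Assumption: repo folders sit under /Repos or /Workspace/Repos.
    if parts[:1] == ["Repos"] or parts[:2] == ["Workspace", "Repos"]:
        return False
    return bool(parts)


# Hypothetical example paths, not taken from config.yaml.
for candidate in [
    "/Users/someone@example.com/healthcare_agent",
    "/Repos/someone@example.com/project/experiment",
    "experiments/agent",
]:
    print(candidate, is_workspace_experiment_path(candidate))
```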

In [0]:
import mlflow

# Plan A: use the default artifact location on the managed MLflow tracking server
experiment_info = mlflow.set_experiment(experiment_path)

# Plan B: create the experiment yourself with mlflow.create_experiment(),
# which returns the new experiment's ID
# experiment_id = mlflow.create_experiment(
#     "production-models",
#     artifact_location="s3://my-bucket/experiments/",
#     tags={"team": "data-science", "environment": "prod"},
# )

In [0]:
experiment_info

In [0]:
from agent import AGENT

AGENT.predict({"messages": [{"role": "user", "content": f"What is the healthcare journey for patient {patient_id}?"}]})

In [0]:
for event in AGENT.predict_stream(
    {"messages": [{"role": "user", "content": f"What is the latest medication patient {patient_id} has been prescribed? What is its diagnosis code?"}]}
):
    print(event, "-----------\n")

### Log the `agent` as an MLflow model
Determine which Databricks resources to specify for automatic authentication passthrough at deployment time.
- **TODO**: If your Unity Catalog Function queries a [vector search index](https://docs.databricks.com/generative-ai/agent-framework/unstructured-retrieval-tools.html) or leverages [external functions](https://docs.databricks.com/generative-ai/agent-framework/external-connection-tools.html), you need to include the dependent vector search index and UC connection objects, respectively, as resources. See [docs](https://docs.databricks.com/generative-ai/agent-framework/log-agent.html#specify-resources-for-automatic-authentication-passthrough) for more details.

Log the agent as code from the `agent.py` file. See [MLflow - Models from Code](https://mlflow.org/docs/latest/models.html#models-from-code).

In [0]:
from agent import tools, LLM_ENDPOINT_NAME

In [0]:
tools

In [0]:
LLM_ENDPOINT_NAME

In [0]:
# Determine Databricks resources to specify for automatic auth passthrough at deployment time
import mlflow
from agent import tools, LLM_ENDPOINT_NAME
from databricks_langchain import VectorSearchRetrieverTool
from mlflow.models.resources import DatabricksFunction, DatabricksServingEndpoint
from unitycatalog.ai.langchain.toolkit import UnityCatalogTool

resources = [DatabricksServingEndpoint(endpoint_name=LLM_ENDPOINT_NAME)]
for tool in tools:
    if isinstance(tool, VectorSearchRetrieverTool):
        resources.extend(tool.resources)
    elif isinstance(tool, UnityCatalogTool):
        # TODO: If the UC function includes dependencies like external connection or vector search, please include them manually.
        # See the TODO in the markdown above for more information.
        resources.append(DatabricksFunction(function_name=tool.uc_function_name))

input_example = {
    "messages": [
        {
            "role": "user",
            "content": f"What is the healthcare journey for patient {patient_id}?"
        }
    ]
}

with mlflow.start_run():
    logged_agent_info = mlflow.pyfunc.log_model(
        artifact_path="agent",
        python_model="agent.py",
        input_example=input_example,
        resources=resources,
        pip_requirements="requirements.txt"
    )

## Perform pre-deployment validation of the agent
Before registering and deploying the agent, we perform pre-deployment checks via the [mlflow.models.predict()](https://mlflow.org/docs/latest/python_api/mlflow.models.html#mlflow.models.predict) API. See the [documentation](https://docs.databricks.com/machine-learning/model-serving/model-serving-debug.html#validate-inputs) for details.

In [0]:
mlflow.models.predict(
    model_uri=f"runs:/{logged_agent_info.run_id}/agent",
    input_data={"messages": [{"role": "user", "content": f"What enrollment information do you have for patient {patient_id}?"}]},
    env_manager="uv",
)

## Register the model to Unity Catalog

The MLflow model is registered to Unity Catalog under the name stored in `model_uc_name`, which is read from `config.yaml`. Update that value to change the `catalog`, `schema`, or model name.
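Unity Catalog model names have three levels: catalog, schema, and model. As a small illustration, the check below splits a name into those parts and rejects anything else before registration is attempted. `split_uc_name` and the sample name are hypothetical, and the sketch ignores backtick-quoted identifiers that contain dots.

```python
def split_uc_name(full_name: str) -> tuple[str, str, str]:
    """Split 'catalog.schema.model' into its three parts or raise ValueError."""
    parts = full_name.split(".")
    if len(parts) != 3 or any(not part.strip() for part in parts):
        raise ValueError(
            f"expected a three-level name 'catalog.schema.model', got {full_name!r}"
        )
    catalog, schema, model = parts
    return catalog, schema, model


# Hypothetical name for illustration; the notebook reads model_uc_name from config.yaml.
print(split_uc_name("main.healthcare.claims_agent"))
```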

In [0]:
# %pip install mlflow[databricks]

In [0]:
mlflow.set_registry_uri("databricks-uc")

# register the model to UC
uc_registered_model_info = mlflow.register_model(
    model_uri=logged_agent_info.model_uri, name=model_uc_name
)

In [0]:
from mlflow import MlflowClient
client = MlflowClient()

client.set_registered_model_alias(model_uc_name, "Champion", uc_registered_model_info.version)

## Deploy the agent

In [0]:
from databricks import agents
agents.deploy(model_uc_name, uc_registered_model_info.version, tags = {"RemoveAfter": "10-31-2025"})

## Next steps

After your agent is deployed, you can chat with it in AI Playground to perform additional checks, share it with SMEs in your organization for feedback, or embed it in a production application. See the [docs](https://docs.databricks.com/generative-ai/deploy-agent.html) for details.

## We can also work with a Genie room with optimized Text-2-SQL on UC tables augmented with AI-generated table metadata/comments!

In [0]:
from IPython.display import HTML, display

# Create HTML link to the Genie room
html_link = '<a href="https://e2-demo-field-eng.cloud.databricks.com/genie/rooms/01f0574f6d4f11558af894a950b8bf19" target="_blank">Go to Genie Room</a>'
display(HTML(html_link))

## We can also do Key Information Extraction using Agent Bricks on all of our clinical notes. 

In [0]:
# Create HTML link to Agent Bricks (reuses HTML and display imported in the previous cell)
html_link = '<a href="https://e2-demo-field-eng.cloud.databricks.com/ml/bricks/kie/use/9a577aad-4adc-4a33-9dbb-88883e6c66b8?o=1444828305810485">Go to Agent Bricks</a>'
display(HTML(html_link))