# Tool-calling Agent

This is an auto-generated notebook created by an AI Playground export.

This notebook uses [Mosaic AI Agent Framework](https://learn.microsoft.com/azure/databricks/generative-ai/agent-framework/build-genai-apps) to recreate your agent from the AI Playground. It demonstrates how to develop, manually test, evaluate, log, and deploy a tool-calling agent in LangGraph that stores conversation history in a Lakebase (Databricks-managed PostgreSQL) database through LangGraph's Postgres checkpointer.

The agent code implements [MLflow's ChatAgent](https://mlflow.org/docs/latest/python_api/mlflow.pyfunc.html#mlflow.pyfunc.ChatAgent) interface, a Databricks-recommended open-source standard that simplifies authoring multi-turn conversational agents, and is fully compatible with Mosaic AI Agent Framework functionality.

**_NOTE:_** This notebook uses LangChain and LangGraph, but Mosaic AI Agent Framework is compatible with any agent authoring framework, including LlamaIndex or pure Python agents written with the OpenAI SDK.

## Prerequisites

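- Address all `TODO`s in this notebook.
- A Lakebase database instance named `stateful-agent-backend` (the name is set in `get_db_uri()` in `agent.py`), with the LangGraph checkpoint tables already created, for example by calling `PostgresSaver.setup()` once against it.
- A Databricks secret scope holding the tenant ID, client ID, and client secret of an Azure service principal that can generate credentials for, and connect to, that instance.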

In [0]:
%pip install -U -qqqq mlflow-skinny[databricks]==3.1.3 langgraph==0.3.4 langgraph-checkpoint-postgres databricks-langchain databricks-agents uv
dbutils.library.restartPython()

## Define the agent in code
Below we define our agent code in a single cell, enabling us to easily write it to a local Python file for subsequent logging and deployment using the `%%writefile` magic command.
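
Writing the cell to `agent.py` matters because the later cells import the agent from that file (`from agent import AGENT`, after `dbutils.library.restartPython()` gives a fresh interpreter) and log it by path (`python_model="agent.py"`). The stdlib-only sketch below shows the same write-then-import idea outside a notebook; the module name `agent_sketch` and its contents are hypothetical placeholders, not the real agent.

```python
import importlib.util
import sys
import tempfile
from pathlib import Path

# Hypothetical module source standing in for the body of a `%%writefile agent.py` cell
AGENT_SOURCE = '''
GREETING = "hello"

def predict(text):
    return f"{GREETING}, {text}"
'''


def write_and_import(source: str, module_name: str):
    """Write `source` to <tmpdir>/<module_name>.py, then import it as a module."""
    path = Path(tempfile.mkdtemp()) / f"{module_name}.py"
    path.write_text(source)
    spec = importlib.util.spec_from_file_location(module_name, str(path))
    module = importlib.util.module_from_spec(spec)
    sys.modules[module_name] = module  # register before executing, as a normal import would
    spec.loader.exec_module(module)
    return module, path


agent_module, agent_path = write_and_import(AGENT_SOURCE, "agent_sketch")
print(agent_module.predict("world"))  # hello, world
```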

For more examples of tools to add to your agent, see [docs](https://learn.microsoft.com/azure/databricks/generative-ai/agent-framework/agent-tool).
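
The graph that `create_tool_calling_agent` builds in `agent.py` alternates between an `agent` node (the LLM) and a `tools` node until the model's latest message has no `tool_calls`. A plain-Python sketch of that control flow, with a hypothetical `fake_model` and `add` tool standing in for `ChatDatabricks` and Unity Catalog functions (no LangGraph involved):

```python
from typing import Callable


def add(a: int, b: int) -> int:
    return a + b


# Hypothetical tool registry standing in for UC function tools
TOOLS: dict[str, Callable[..., int]] = {"add": add}


def fake_model(messages: list[dict]) -> dict:
    """Request a tool call first; answer once a tool result is present."""
    if messages[-1]["role"] == "tool":
        return {"role": "assistant", "content": f"The sum is {messages[-1]['content']}."}
    return {
        "role": "assistant",
        "content": "",
        "tool_calls": [{"id": "call-1", "name": "add", "args": {"a": 2, "b": 3}}],
    }


def should_continue(messages: list[dict]) -> str:
    # Same rule as agent.py: keep looping while the model asks for tools
    return "continue" if messages[-1].get("tool_calls") else "end"


def run_agent(messages: list[dict], max_steps: int = 10) -> list[dict]:
    messages = list(messages)
    for _ in range(max_steps):
        messages.append(fake_model(messages))  # "agent" node
        if should_continue(messages) == "end":
            return messages
        for call in messages[-1]["tool_calls"]:  # "tools" node, then back to "agent"
            result = TOOLS[call["name"]](**call["args"])
            messages.append({"role": "tool", "content": str(result), "tool_call_id": call["id"]})
    raise RuntimeError("agent did not finish within max_steps")


transcript = run_agent([{"role": "user", "content": "What is 2 + 3?"}])
print(transcript[-1]["content"])  # The sum is 5.
```

The `max_steps` guard is a choice of this sketch; LangGraph applies its own recursion limit to the real graph.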

In [0]:
%%writefile agent.py
from typing import Any, Generator, Optional, Sequence, Union, Dict

import mlflow
from databricks_langchain import (
    ChatDatabricks,
    VectorSearchRetrieverTool,
    DatabricksFunctionClient,
    UCFunctionToolkit,
    set_uc_function_client,
)
from langchain_core.language_models import LanguageModelLike
from langchain_core.runnables import RunnableConfig, RunnableLambda
from langchain_core.tools import BaseTool
from langgraph.graph import END, StateGraph
from langgraph.graph.graph import CompiledGraph
from langgraph.graph.state import CompiledStateGraph
from langgraph.prebuilt.tool_node import ToolNode
from mlflow.langchain.chat_agent_langgraph import ChatAgentState, ChatAgentToolNode
from mlflow.pyfunc import ChatAgent
from mlflow.types.agent import (
    ChatAgentChunk,
    ChatAgentMessage,
    ChatAgentResponse,
    ChatContext,
)
import os
import logging

mlflow.langchain.autolog()

################################################################
#### Lakebase Part
###############################################################

import uuid
from langgraph.checkpoint.postgres import PostgresSaver
from databricks.sdk import WorkspaceClient

def get_db_uri() -> str:
    """Return a Lakebase (PostgreSQL) connection URI that embeds a fresh database credential.

    Authenticates to the workspace as the Azure service principal whose
    credentials are read from environment variables.
    """
    w = WorkspaceClient(
        host=os.getenv("HOST_URL"),
        azure_tenant_id=os.getenv("AZURE_TENANT_ID"),
        azure_client_id=os.getenv("AZURE_CLIENT_ID"),
        azure_client_secret=os.getenv("AZURE_CLIENT_SECRET"),
        auth_type="azure-client-secret",
    )

    instance_name = "stateful-agent-backend"

    cred = w.database.generate_database_credential(
        request_id=str(uuid.uuid4()),
        instance_names=[instance_name],
    )
    instance = w.database.get_database_instance(name=instance_name)

    # Use single quotes inside the f-string: reusing double quotes is a
    # SyntaxError before Python 3.12. The host comes from the instance
    # instead of being hard-coded.
    return (
        f"postgresql://{os.getenv('AZURE_CLIENT_ID')}:{cred.token}"
        f"@{instance.read_write_dns}:5432/"
        "databricks_postgres?sslmode=require"
    )

# Lakebase credentials are short-lived (about one hour), so predict() and
# predict_stream() call get_db_uri() per request instead of creating one
# checkpointer when this module is imported.

################################################################

client = DatabricksFunctionClient()
set_uc_function_client(client)

############################################
# Define your LLM endpoint and system prompt
############################################
LLM_ENDPOINT_NAME = "databricks-claude-3-7-sonnet"
llm = ChatDatabricks(endpoint=LLM_ENDPOINT_NAME)

system_prompt = """"""

###############################################################################
## Define tools for your agent, enabling it to retrieve data or take actions
## beyond text generation
## To create and see usage examples of more tools, see
## https://learn.microsoft.com/azure/databricks/generative-ai/agent-framework/agent-tool
###############################################################################
tools = []

# You can use UDFs in Unity Catalog as agent tools
uc_tool_names = []
uc_toolkit = UCFunctionToolkit(function_names=uc_tool_names)
tools.extend(uc_toolkit.tools)

# # (Optional) Use Databricks vector search indexes as tools
# # See https://learn.microsoft.com/azure/databricks/generative-ai/agent-framework/unstructured-retrieval-tools
# # for details
#
# # TODO: Add vector search indexes as tools or delete this block
# vector_search_tools = [
#     VectorSearchRetrieverTool(
#         index_name="",
#         # filters="..."
#     )
# ]
# tools.extend(vector_search_tools)

#####################
## Define agent logic
#####################

def create_tool_calling_agent(
    model: LanguageModelLike,
    tools: Union[Sequence[BaseTool], ToolNode],
    system_prompt: Optional[str] = None,
    checkpointer: Optional[PostgresSaver] = None,
) -> CompiledGraph:
    model = model.bind_tools(tools)

    # Define the function that determines which node to go to
    def should_continue(state: ChatAgentState):
        messages = state["messages"]
        last_message = messages[-1]
        # If there are function calls, continue. else, end
        if last_message.get("tool_calls"):
            return "continue"
        else:
            return "end"

    if system_prompt:
        preprocessor = RunnableLambda(
            lambda state: [{"role": "system", "content": system_prompt}]
            + state["messages"]
        )
    else:
        preprocessor = RunnableLambda(lambda state: state["messages"])
    model_runnable = preprocessor | model

    def call_model(
        state: ChatAgentState,
        config: RunnableConfig,
    ):
        response = model_runnable.invoke(state, config)

        return {"messages": [response]}

    workflow = StateGraph(ChatAgentState)

    workflow.add_node("agent", RunnableLambda(call_model))
    workflow.add_node("tools", ChatAgentToolNode(tools))

    workflow.set_entry_point("agent")
    workflow.add_conditional_edges(
        "agent",
        should_continue,
        {
            "continue": "tools",
            "end": END,
        },
    )
    workflow.add_edge("tools", "agent")
    if checkpointer:
        return workflow.compile(checkpointer=checkpointer)
    else:
        return workflow.compile()


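class LangGraphChatAgent(ChatAgent):
    def __init__(self):
        # The compiled graph is built per request (see predict), because each
        # request opens a new Postgres connection with a fresh short-lived credential.
        self.agent = None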

    def _make_config(
        self,
        thread_id: str,
        configurable_kwargs: Optional[dict] = None,
        metadata_kwargs: Optional[dict] = None,
    ) -> dict:
        # LangGraph's checkpointer looks up conversation state by configurable["thread_id"];
        # copy the inputs so the caller's dicts are not mutated
        configurable = {**(configurable_kwargs or {}), "thread_id": thread_id}
        metadata = dict(metadata_kwargs or {})
        return {"configurable": configurable, "metadata": metadata}
    
    def handle_custom_inputs_as_config(self, custom_inputs: Optional[Dict[str, Any]] = None):
        # Copy so that popping thread_id does not mutate the caller's dict
        custom_inputs = dict(custom_inputs or {})

        if "thread_id" not in custom_inputs:
            logging.warning("No thread_id in custom_inputs; starting a new thread.")
            custom_inputs["thread_id"] = str(uuid.uuid4())
        else:
            logging.info(f"Using thread_id {custom_inputs['thread_id']}")

        thread_id = custom_inputs.pop("thread_id")
        # The remaining custom inputs (e.g. user_id) are attached as run metadata
        return self._make_config(thread_id=thread_id, metadata_kwargs=custom_inputs)

    def get_history_of_thread(self, thread_id):
        """Return the messages stored for `thread_id`, or [] if the thread is unknown."""
        with PostgresSaver.from_conn_string(get_db_uri()) as checkpointer:
            agent = create_tool_calling_agent(llm, tools, system_prompt, checkpointer)
            config = self.handle_custom_inputs_as_config(custom_inputs={"thread_id": thread_id})
            # get_state_history yields checkpoints newest first, so history[0] is the latest state
            history = list(agent.get_state_history(config))
            previous_messages = history[0].values["messages"] if history else []

        return previous_messages

    def predict(
        self,
        messages: list[ChatAgentMessage],
        context: Optional[ChatContext] = None,
        custom_inputs: Optional[Dict[str, Any]] = None,
    ) -> ChatAgentResponse:
        # Open a new connection per request: the database credential is short-lived
        with PostgresSaver.from_conn_string(get_db_uri()) as checkpointer:
            agent = create_tool_calling_agent(llm, tools, system_prompt, checkpointer)
            request = {"messages": self._convert_messages_to_dict(messages)}
            logging.info("custom_inputs = %s", custom_inputs)
            config = self.handle_custom_inputs_as_config(custom_inputs=custom_inputs)

            response_messages = []
            for event in agent.stream(request, config, stream_mode="updates"):
                for node_data in event.values():
                    response_messages.extend(
                        ChatAgentMessage(**msg) for msg in node_data.get("messages", [])
                    )
            return ChatAgentResponse(messages=response_messages)

    def predict_stream(
        self,
        messages: list[ChatAgentMessage],
        context: Optional[ChatContext] = None,
        custom_inputs: Optional[Dict[str, Any]] = None,
    ) -> Generator[ChatAgentChunk, None, None]:
        # The connection stays open until the caller has consumed the whole stream
        with PostgresSaver.from_conn_string(get_db_uri()) as checkpointer:
            agent = create_tool_calling_agent(llm, tools, system_prompt, checkpointer)
            request = {"messages": self._convert_messages_to_dict(messages)}
            logging.info("custom_inputs = %s", custom_inputs)
            config = self.handle_custom_inputs_as_config(custom_inputs=custom_inputs)

            for event in agent.stream(request, config, stream_mode="updates"):
                for node_data in event.values():
                    yield from (
                        ChatAgentChunk(delta=msg) for msg in node_data.get("messages", [])
                    )

# Create the agent object, and specify it as the agent object to use when
# loading the agent back for inference via mlflow.models.set_model()
AGENT = LangGraphChatAgent()
mlflow.models.set_model(AGENT)
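
`get_db_uri()` generates a new Lakebase credential on every call, and `predict`/`predict_stream` call it once per request because the token is short-lived. If that per-request call becomes a latency concern, one option is to cache the token and refresh it shortly before it expires. A minimal sketch, assuming a one-hour lifetime and a five-minute safety margin (both assumptions), with a hypothetical `fetch` callable standing in for `generate_database_credential`; it is not thread-safe as written:

```python
import time
from typing import Callable, Optional


class CachedCredential:
    """Reuse a short-lived token until shortly before it expires.

    `fetch` is a hypothetical zero-argument callable that returns a fresh token.
    The lifetime and margin defaults are assumptions, not documented values.
    """

    def __init__(
        self,
        fetch: Callable[[], str],
        lifetime_s: float = 3600,
        margin_s: float = 300,
        clock: Callable[[], float] = time.monotonic,
    ):
        self._fetch = fetch
        self._lifetime_s = lifetime_s
        self._margin_s = margin_s
        self._clock = clock
        self._token: Optional[str] = None
        self._expires_at = 0.0

    def get(self) -> str:
        now = self._clock()
        # Refresh when there is no token yet or it is inside the safety margin
        if self._token is None or now >= self._expires_at - self._margin_s:
            self._token = self._fetch()
            self._expires_at = now + self._lifetime_s
        return self._token


# Usage with a fake fetcher and a controllable clock
calls = []


def fake_fetch() -> str:
    calls.append(1)
    return f"token-{len(calls)}"


fake_now = [0.0]
cred = CachedCredential(fake_fetch, clock=lambda: fake_now[0])
first = cred.get()   # fetches token-1
fake_now[0] = 1800   # 30 minutes later: cached token is reused
second = cred.get()
fake_now[0] = 3400   # inside the 5-minute margin: a new token is fetched
third = cred.get()
print(first, second, third)  # token-1 token-1 token-2
```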

## Test the agent

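Interact with the agent to test its output. Because `agent.py` calls `mlflow.langchain.autolog()`, you can view the trace for each step the agent takes.

Replace the placeholder inputs below with domain-specific examples for your agent. Calls that pass the same `thread_id` in `custom_inputs` continue one conversation, whose history the agent stores in Lakebase through the Postgres checkpointer.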
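
The test cells pass a `thread_id` (and a `user_id`) in `custom_inputs`. `handle_custom_inputs_as_config` turns them into a LangGraph run config: `thread_id` goes under `configurable`, which is the key the checkpointer uses to load and save a conversation, and the remaining keys become run metadata. A standalone sketch of that mapping; it copies the input dict so the caller's `custom_inputs` is left unchanged, and the e-mail address is a placeholder:

```python
import uuid
from typing import Any, Optional


def custom_inputs_to_config(custom_inputs: Optional[dict[str, Any]] = None) -> dict:
    """Map custom_inputs to a LangGraph-style config dict."""
    remaining = dict(custom_inputs or {})  # copy: don't mutate the caller's dict
    thread_id = remaining.pop("thread_id") if "thread_id" in remaining else str(uuid.uuid4())
    return {"configurable": {"thread_id": thread_id}, "metadata": remaining}


thread_id = str(uuid.uuid4())
request = {
    "messages": [{"role": "user", "content": "Hello!"}],
    "custom_inputs": {"user_id": "someone@example.com", "thread_id": thread_id},
}
config = custom_inputs_to_config(request["custom_inputs"])
print(config["configurable"]["thread_id"] == thread_id)  # True
print(config["metadata"])  # {'user_id': 'someone@example.com'}
```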

In [0]:
dbutils.library.restartPython()

### MODIFY THESE

In [0]:
import os
from dbruntime.databricks_repl_context import get_context

# TODO: set WORKSPACE_URL manually if it cannot be inferred from the current notebook
WORKSPACE_URL = None
if WORKSPACE_URL is None:
  workspace_url_hostname = get_context().workspaceUrl
  assert workspace_url_hostname is not None, "Unable to look up current workspace URL. This can happen if running against serverless compute. Manually set WORKSPACE_URL yourself above, or run this notebook against classic compute"
  WORKSPACE_URL = f"https://{workspace_url_hostname}"

# TODO: set secret_scope_name and the secret key names that hold the Azure service principal credentials
secret_scope_name = "ryuta"
secret_key_azure_tenant_id = "azure_tenant_id"
secret_key_azure_client_id = "azure_client_id"
secret_key_azure_client_secret = "azure_client_secret"

os.environ["HOST_URL"] = WORKSPACE_URL
os.environ["AZURE_TENANT_ID"] = dbutils.secrets.get(scope=secret_scope_name, key=secret_key_azure_tenant_id)
os.environ["AZURE_CLIENT_ID"] = dbutils.secrets.get(scope=secret_scope_name, key=secret_key_azure_client_id)
os.environ["AZURE_CLIENT_SECRET"] = dbutils.secrets.get(scope=secret_scope_name, key=secret_key_azure_client_secret)

In [0]:
import uuid
from agent import AGENT

thread_id = str(uuid.uuid4())

AGENT.predict(
  {
    "messages": [{"role": "user", "content": "Hello!"}],
    "custom_inputs" : {
      'user_id': 'ryuta.yoshimatsu@databricks.com',
      'thread_id': thread_id,
    }
  }
)

In [0]:
AGENT.predict(
  {
    "messages": [{"role": "user", "content": "What are some best places to visit in Japan in September?"}],
    "custom_inputs" : {
      'user_id': 'ryuta.yoshimatsu@databricks.com',
      'thread_id': thread_id,
    }
  }
)

In [0]:
AGENT.predict(
  {
    "messages": [{"role": "user", "content": "What else do you recommend?"}],
    "custom_inputs" : {
      'user_id': 'ryuta.yoshimatsu@databricks.com',
      'thread_id': thread_id,
    }
  }
)

In [0]:
AGENT.predict(
  {
    "messages": [{"role": "user", "content": "Give me more recommendations."}],
    "custom_inputs" : {
      'user_id': 'ryuta.yoshimatsu@databricks.com',
    }
  }
)
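
The last cell omits `thread_id`, so `handle_custom_inputs_as_config` generates a new one and the agent starts a conversation with no access to the previous three turns. The toy class below is not `PostgresSaver`; it only illustrates, in memory, how keying stored history by `thread_id` produces that behavior:

```python
from collections import defaultdict


class InMemoryThreads:
    """Toy stand-in for a checkpointer: stores message history per thread_id."""

    def __init__(self):
        self._history: dict[str, list[dict]] = defaultdict(list)

    def send(self, thread_id: str, content: str) -> list[dict]:
        history = self._history[thread_id]
        history.append({"role": "user", "content": content})
        # A real agent would call the LLM with the full `history` at this point
        history.append({"role": "assistant", "content": f"(reply to: {content})"})
        return list(history)


threads = InMemoryThreads()
threads.send("thread-a", "Hello!")
threads.send("thread-a", "What are some best places to visit in Japan in September?")
shared = threads.send("thread-a", "What else do you recommend?")
fresh = threads.send("thread-b", "Give me more recommendations.")
print(len(shared), len(fresh))  # 6 2
```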

### Log the `agent` as an MLflow model
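Determine the Databricks resources to specify for automatic auth passthrough at deployment time. In this notebook, the agent connects to Lakebase with the Azure service principal credentials supplied as environment variables at deployment time rather than through passthrough.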
- **TODO**: If your Unity Catalog Function queries a [vector search index](https://learn.microsoft.com/azure/databricks/generative-ai/agent-framework/unstructured-retrieval-tools) or leverages [external functions](https://learn.microsoft.com/azure/databricks/generative-ai/agent-framework/external-connection-tools), you need to include the dependent vector search index and UC connection objects, respectively, as resources. See [docs](https://learn.microsoft.com/azure/databricks/generative-ai/agent-framework/log-agent#specify-resources-for-automatic-authentication-passthrough) for more details.

Log the agent as code from the `agent.py` file. See [MLflow - Models from Code](https://mlflow.org/docs/latest/models.html#models-from-code).
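
Pinning each dependency to the version installed in this notebook lets Model Serving rebuild a matching environment. A standard-library sketch of that pinning step with `importlib.metadata.version`, where a requirement with extras such as `psycopg[binary]` is looked up by its base distribution name; `get_version` is injectable only so the example can run with hypothetical version numbers:

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Callable


def pinned_requirements(
    requirements: list[str],
    get_version: Callable[[str], str] = version,
) -> list[str]:
    """Return "name==installed_version" pins; "name[extra]" is looked up as "name".

    Raises PackageNotFoundError if a distribution is not installed, so a missing
    dependency fails at logging time rather than at serving time.
    """
    pins = []
    for requirement in requirements:
        distribution = requirement.split("[", 1)[0]  # "psycopg[binary]" -> "psycopg"
        pins.append(f"{requirement}=={get_version(distribution)}")
    return pins


def fake_version(name: str) -> str:
    # Hypothetical installed versions, for illustration only
    versions = {"langgraph": "0.3.4", "psycopg": "3.2.0"}
    if name not in versions:
        raise PackageNotFoundError(name)
    return versions[name]


print(pinned_requirements(["langgraph", "psycopg[binary]"], get_version=fake_version))
# ['langgraph==0.3.4', 'psycopg[binary]==3.2.0']
```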

In [0]:
# Determine Databricks resources to specify for automatic auth passthrough at deployment time
import mlflow
from agent import LLM_ENDPOINT_NAME, tools
from databricks_langchain import VectorSearchRetrieverTool
from mlflow.models.resources import DatabricksFunction, DatabricksServingEndpoint
from importlib.metadata import version
from unitycatalog.ai.langchain.toolkit import UnityCatalogTool

resources = [DatabricksServingEndpoint(endpoint_name=LLM_ENDPOINT_NAME)]

for tool in tools:
    if isinstance(tool, VectorSearchRetrieverTool):
        resources.extend(tool.resources)
    elif isinstance(tool, UnityCatalogTool):
        # TODO: If the UC function includes dependencies like external connection or vector search, please include them manually.
        # See the TODO in the markdown above for more information.
        resources.append(DatabricksFunction(function_name=tool.uc_function_name))

input_example = {
    "messages": [{"role": "user", "content": "Hello!"}],
    "custom_inputs": {
        "user_id": "ryuta.yoshimatsu@databricks.com",
    },
}

with mlflow.start_run():
    logged_agent_info = mlflow.pyfunc.log_model(
        name="agent",
        python_model="agent.py",
        input_example=input_example,
        resources=resources,
        pip_requirements=[
            f"databricks-connect=={version('databricks-connect')}",
            f"mlflow=={version('mlflow')}",
            f"databricks-langchain=={version('databricks-langchain')}",
            f"langgraph=={version('langgraph')}",
            f"pydantic=={version('pydantic')}",
            f"langgraph-checkpoint-postgres=={version('langgraph-checkpoint-postgres')}",
            # psycopg is not preinstalled on serving endpoints, so pin it explicitly for the Postgres checkpointer
            f"psycopg[binary]=={version('psycopg')}",
        ],
    )

## Evaluate the agent with [Agent Evaluation](https://learn.microsoft.com/azure/databricks/mlflow3/genai/eval-monitor/)

You can edit the requests or expected responses in your evaluation dataset and run evaluation as you iterate on your agent, using MLflow to track the computed quality metrics.

Evaluate your agent with one of our [predefined LLM scorers](https://learn.microsoft.com/azure/databricks/mlflow3/genai/eval-monitor/predefined-judge-scorers), or try adding [custom metrics](https://learn.microsoft.com/azure/databricks/mlflow3/genai/eval-monitor/custom-scorers).
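
When `mlflow.genai.evaluate` is given a `predict_fn`, it calls that function with each row's `inputs` dict unpacked as keyword arguments, so the parameter names must match the keys under `inputs` (here `messages` and `custom_inputs`). The sketch below imitates only that calling convention in plain Python (`call_like_evaluate` is hypothetical, not an MLflow function) to show why a function that accepts only `messages` fails on these rows:

```python
from typing import Any, Callable, Optional


def call_like_evaluate(predict_fn: Callable[..., Any], dataset: list[dict]) -> list[Any]:
    # Imitates the calling convention only: each row's "inputs" are passed as kwargs
    return [predict_fn(**row["inputs"]) for row in dataset]


eval_dataset = [
    {
        "inputs": {
            "messages": [{"role": "user", "content": "hello!"}],
            "custom_inputs": {"user_id": "someone@example.com"},
        }
    }
]


def predict_messages_only(messages):
    return {"echo": messages[-1]["content"]}


def predict_with_custom_inputs(messages, custom_inputs: Optional[dict] = None):
    # Stand-in for AGENT.predict({"messages": messages, "custom_inputs": custom_inputs})
    return {"echo": messages[-1]["content"], "custom_inputs": custom_inputs}


outputs = call_like_evaluate(predict_with_custom_inputs, eval_dataset)
print(outputs[0]["custom_inputs"])  # {'user_id': 'someone@example.com'}

try:
    call_like_evaluate(predict_messages_only, eval_dataset)
except TypeError as exc:
    print(f"TypeError: {exc}")  # the message names the unexpected keyword 'custom_inputs'
```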

In [0]:
# import mlflow
# from agent import AGENT
# from mlflow.genai.scorers import RelevanceToQuery, Safety, RetrievalRelevance, RetrievalGroundedness

# eval_dataset = [
#     {
#         "inputs": {
#             "messages": [{"role": "user", "content": "hello!"}],
#             "custom_inputs": {
#                 "user_id": "peter.hawkins@databricks.com",
#             },
#         },
#         # Optional ground truth for scorers that need it:
#         # "expectations": {"expected_response": "..."},
#     }
# ]

# # predict_fn receives each row's "inputs" as keyword arguments,
# # so it must accept both `messages` and `custom_inputs`
# eval_results = mlflow.genai.evaluate(
#     data=eval_dataset,
#     predict_fn=lambda messages, custom_inputs=None: AGENT.predict(
#         {"messages": messages, "custom_inputs": custom_inputs}
#     ),
#     scorers=[RelevanceToQuery(), Safety()],  # add more scorers here if they're applicable
# )

# Review the evaluation results in the MLflow UI (see console output)

## Register the model to Unity Catalog

Update the `catalog`, `schema`, and `model_name` below to register the MLflow model to Unity Catalog.
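
Unity Catalog models are addressed by a three-level name, `<catalog>.<schema>.<model_name>`, which the next cell builds with an f-string. A hypothetical helper that performs the same join but rejects empty parts and parts containing a dot (a conservative assumption, since a dot would make the three levels ambiguous):

```python
def uc_model_name(catalog: str, schema: str, model_name: str) -> str:
    """Join catalog, schema, and model into a three-level Unity Catalog name."""
    parts = (catalog, schema, model_name)
    for part in parts:
        # Conservative check assumed by this sketch, not a documented UC rule set
        if not part or "." in part:
            raise ValueError(f"invalid Unity Catalog name part: {part!r}")
    return ".".join(parts)


print(uc_model_name("ryuta", "agents", "stateful_agent"))  # ryuta.agents.stateful_agent
```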

In [0]:
mlflow.set_registry_uri("databricks-uc")

# TODO: define the catalog, schema, and model name for your UC model
catalog = "ryuta"
schema = "agents"
model_name = "stateful_agent"

UC_MODEL_NAME = f"{catalog}.{schema}.{model_name}"

# register the model to UC
uc_registered_model_info = mlflow.register_model(
    model_uri=logged_agent_info.model_uri, name=UC_MODEL_NAME
)

## Deploy the agent

In [0]:
from databricks import agents

agents.deploy(
    UC_MODEL_NAME,
    uc_registered_model_info.version,
    tags={"endpointSource": "playground"},
    environment_vars={
        "HOST_URL": WORKSPACE_URL,
        "AZURE_TENANT_ID": f"{{{{secrets/{secret_scope_name}/{secret_key_azure_tenant_id}}}}}",
        "AZURE_CLIENT_ID": f"{{{{secrets/{secret_scope_name}/{secret_key_azure_client_id}}}}}",
        "AZURE_CLIENT_SECRET": f"{{{{secrets/{secret_scope_name}/{secret_key_azure_client_secret}}}}}",
    },
)
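
The environment variables above use Databricks secret references of the form `{{secrets/<scope>/<key>}}`, which Model Serving resolves to the secret's value in the endpoint's environment, so the credentials are not written into the deployment config in plain text. In an f-string each literal brace must be doubled, which is why the cell writes four opening and four closing literal braces around the reference. A minimal sketch with a hypothetical scope name:

```python
def secret_ref(scope: str, key: str) -> str:
    # "{{" and "}}" are escaped braces in an f-string, so "{{{{" renders as "{{"
    # and "}}}}" renders as "}}"; {scope} and {key} are substituted normally.
    return f"{{{{secrets/{scope}/{key}}}}}"


# Hypothetical scope name; the notebook uses its own secret_scope_name
env = {"AZURE_CLIENT_ID": secret_ref("my-scope", "azure_client_id")}
print(env["AZURE_CLIENT_ID"])  # {{secrets/my-scope/azure_client_id}}
```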

## Next steps

After your agent is deployed, you can chat with it in AI Playground to perform additional checks, share it with SMEs in your organization for feedback, or embed it in a production application. See [docs](https://learn.microsoft.com/azure/databricks/generative-ai/deploy-agent) for details.