# Insurance Agent with UC Tools

* Framework: [OpenAI Agents SDK](https://openai.github.io/openai-agents-python/)

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from databricks.connect import DatabricksSession
from dotenv import load_dotenv
import os
import mlflow
import warnings

# Import from the shared helper package, which lives two levels above this notebook
import sys
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), os.pardir, os.pardir)))
from shared.helper import *

# Detect the runtime and load the shared YAML configuration.
# NOTE(review): both helpers are not imported by name in this notebook; presumably they
# come from the `from shared.helper import *` star import — confirm.
in_workspace = is_running_in_databricks()
global_config = get_global_config('../../global_config/databricks_config.yaml')

if not in_workspace:
    # Local development: connect to Databricks remotely and read credentials from a .env file.
    print("running locally ...")
    spark = DatabricksSession.builder.getOrCreate()
    load_dotenv('../../.env')

    # The .env file is expected to define `host`, `token` and `OPENAI_API_KEY`.
    DATABRICKS_HOST = os.getenv('host')
    DATABRICKS_TOKEN = os.getenv('token')
    OPENAI_API_KEY = os.getenv('OPENAI_API_KEY')

    # Local MLflow setup: tracking and registry both point at a server on localhost:5000.
    mlflow.set_tracking_uri("http://localhost:5000")
    mlflow.set_registry_uri("http://localhost:5000")
    mlflow.set_experiment("open_agent_sdk_experiments")
else:
    # Databricks workspace: take host/token from the notebook context and the OpenAI key
    # from a secret scope.
    print("running in workspace ...")
    DATABRICKS_HOST = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiUrl().get()
    DATABRICKS_TOKEN = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiToken().getOrElse(None)
    # NOTE(review): the secret scope name and the experiment path below are specific to one
    # user; adjust both before running in another workspace.
    os.environ["OPENAI_API_KEY"] = dbutils.secrets.get(scope="databricks_token_qyu", key="OpenAi")
    mlflow.set_tracking_uri("databricks")
    mlflow.set_registry_uri("databricks-uc")
    mlflow.set_experiment("/Users/q.yu@databricks.com/ML_experiments/insurance_operator_openai_agent_v2")

print(f"host: {DATABRICKS_HOST}")
# Silence UserWarning noise for the rest of the notebook.
warnings.filterwarnings("ignore", category=UserWarning)
# Turn on MLflow's OpenAI autologging for the calls made below.
mlflow.openai.autolog()

INFO:helper:{'catalog': 'fins_genai', 'schema_agents': 'agents', 'schema_insurance_agents': 'insurance_agent', 'llm_endpoint_llama3p3_70B': 'databricks/databricks-meta-llama-3-3-70b-instruct', 'llm_endpoint_llama4_maverick': 'databricks/databricks-llama-4-maverick', 'llm_endpoint_sonnet_3p7': 'databricks/databricks-claude-3-7-sonnet'}
INFO:databricks.sdk:loading DEFAULT profile from ~/.databrickscfg: host, serverless_compute_id, connection_timeout_seconds, token


running locally ...
host: adb-984752964297111.11.azuredatabricks.net


## Define Tools

In [3]:
from pydantic import BaseModel

class UserInfo(BaseModel):
    """Per-conversation user context shared with the agents and their tools.

    An instance is passed to ``Runner.run(context=...)`` and reaches tools through
    ``RunContextWrapper[UserInfo]``. Every field is optional and defaults to ``None``.
    """

    # Customer identifier (not read by any tool in this notebook).
    cust_id: str | None = None
    # Policy number used by the claims lookup tool.
    policy_no: str | None = None
    # Identifier of the chat conversation.
    conversation_id: str | None = None
    # Identifier of the end user.
    user_id: str | None = None

In [4]:
from unitycatalog.ai.core.databricks import (
    DatabricksFunctionClient,
    FunctionExecutionResult,
)
from agents import function_tool, RunContextWrapper

# Unity Catalog location of the agent's functions, read from the shared config.
# The tools below build fully qualified names as "{catalog}.{schema}.<function>".
catalog = global_config.get("catalog")
schema = global_config.get("schema_insurance_agents")

# The docstring below is LLM-facing: `function_tool` uses it as the tool description,
# so keep it short and written for the model. Implementation notes belong in comments.
@function_tool
def search_claims_details_by_policy_no(wrapper: RunContextWrapper[UserInfo]) -> FunctionExecutionResult:
    """Retrieve the claim details recorded against the customer's insurance policy.

    The policy number is read from the conversation context, so this tool takes no arguments.
    """
    print("[DEBUG]: the 'search_claims_details_by_policy_no' tool was called")
    client = DatabricksFunctionClient()
    # Runs the Unity Catalog function of the same name in the configured catalog/schema.
    return client.execute_function(
        function_name=f"{catalog}.{schema}.search_claims_details_by_policy_no",
        parameters={"input_policy_no": wrapper.context.policy_no},
    )


# The docstring below is LLM-facing: `function_tool` uses it for the tool and argument
# descriptions, so keep it short and written for the model.
@function_tool
def policy_docs_vector_search(query: str) -> FunctionExecutionResult:
    """Search the insurance policy documents for content relevant to a question.

    Args:
        query: The customer's question or search phrase about their policy.
    """
    print("[DEBUG]: the 'policy_docs_vector_search' tool was called")
    client = DatabricksFunctionClient()
    # Runs the Unity Catalog function of the same name in the configured catalog/schema.
    return client.execute_function(
        function_name=f"{catalog}.{schema}.policy_docs_vector_search",
        parameters={"query": query},
    )

## Create Agents

In [5]:
# Optional: use a custom model hosted outside of OpenAI

from openai import AsyncOpenAI
from agents import OpenAIChatCompletionsModel, set_tracing_disabled

# To use this model, replace 'gpt-4o' in the Agent definitions with
# OpenAIChatCompletionsModel(model=MODEL, openai_client=client).
MODEL = "databricks-claude-3-7-sonnet"

# DATABRICKS_HOST is a bare hostname when read from the local .env file, while the
# workspace branch reads it from the notebook context's `apiUrl()`, which presumably
# already includes the scheme (TODO confirm). Only prepend "https://" when it is missing
# so the base URL never becomes "https://https://...".
_databricks_host = str(DATABRICKS_HOST).rstrip("/")
if not _databricks_host.startswith(("http://", "https://")):
    _databricks_host = f"https://{_databricks_host}"

# OpenAI-compatible async client pointed at the Databricks serving endpoints.
client = AsyncOpenAI(
    base_url=f"{_databricks_host}/serving-endpoints",
    api_key=DATABRICKS_TOKEN,
)

The `instructions` field is the prompt for the agent. We’ve prepended the `RECOMMENDED_PROMPT_PREFIX` string from the OpenAI Agents SDK. This is an optional “hand-off” prompt, and we found it to work fairly well. The `tools` field is where you designate the list of tools that the agent can use to complete the task specified in the instructions. The `model` field is where you specify the LLM that powers the agent. It works natively with all OpenAI models, provided that you have access to them.

In [6]:
from agents.extensions.handoff_prompt import RECOMMENDED_PROMPT_PREFIX

# Display the prefix text that is prepended to every agent's instructions below.
RECOMMENDED_PROMPT_PREFIX

'# System context\nYou are part of a multi-agent system called the Agents SDK, designed to make agent coordination and execution easy. Agents uses two primary abstraction: **Agents** and **Handoffs**. An agent encompasses instructions and tools and can hand off a conversation to another agent when appropriate. Handoffs are achieved by calling a handoff function, generally named `transfer_to_<agent_name>`. Transfers between agents are handled seamlessly in the background; do not mention or draw attention to these transfers in your conversation with the user.\n'

In [7]:
from agents import Agent, Runner, set_tracing_disabled

# You can turn off trace by setting this to True
set_tracing_disabled(disabled=False)

# Specialist agent: looks up claim details through the Unity Catalog claims tool.
claims_detail_retrieval_agent = Agent[UserInfo](
    name="Claims Details Retrieval Agent",
    instructions=(
        f"{RECOMMENDED_PROMPT_PREFIX}"
        "You are a claims details retrieval agent. "
        "If you are speaking to a customer, you probably were transferred to you from the triage agent. "
        "Use the following routine to support the customer. \n"
        "# Routine: \n"
        "1. Identify the last question asked by the customer. \n"
        "2. Use the search tools to retrieve data about a claim. Do not rely on your own knowledge. \n"
        "3. If you cannot answer the question, transfer back to the triage agent. \n"
    ),
    tools=[
        search_claims_details_by_policy_no,
    ],
    model="gpt-4o",
    # model=OpenAIChatCompletionsModel(model=MODEL, openai_client=client),
)

# Specialist agent: answers policy questions through the policy-document search tool.
policy_qa_agent = Agent[UserInfo](
    name="Policy Q&A Agent",
    instructions=(
        f"{RECOMMENDED_PROMPT_PREFIX}"
        "You are an insurance policy Q&A agent. "
        "If you are speaking to a customer, you probably were transferred to you from the triage agent. "
        "Use the following routine to support the customer.\n"
        "# Routine: \n"
        "1. Identify the last question asked by the customer. \n"
        "2. Use the search tools to answer the question about their policy. Do not rely on your own knowledge. \n"
        "3. If you cannot answer the question, transfer back to the triage agent. \n"
    ),
    tools=[policy_docs_vector_search],
    model="gpt-4o",
    # model=OpenAIChatCompletionsModel(model=MODEL, openai_client=client),
)

# Entry-point agent: has no tools of its own and routes the customer to a specialist.
triage_agent = Agent[UserInfo](
    name="Triage agent",
    instructions=(
        f"{RECOMMENDED_PROMPT_PREFIX}"
        "You are a helpful triaging agent. "
        "You can use your tools to delegate questions to other appropriate agents. "
        "If the customer does not have anymore questions, wish them a goodbye and a good rest of their day. "
    ),
    # handoffs=[customer_verifier_agent, claims_detail_retrieval_agent, policy_qa_agent],
    handoffs=[claims_detail_retrieval_agent, policy_qa_agent],
    model="gpt-4o",
    # model=OpenAIChatCompletionsModel(model=MODEL, openai_client=client),
)

# Both specialists are told to "transfer back to the triage agent" when they cannot answer,
# which is only possible if the triage agent is one of their handoffs. It has to be wired
# here, after `triage_agent` exists, because the agents reference each other.
claims_detail_retrieval_agent.handoffs.append(triage_agent)
policy_qa_agent.handoffs.append(triage_agent)


## Experiment with a Chat Loop and MLflow

Sample Conversation:
* "hi, id like to check on my existing claims"
* "here's the policy number: 102070455"
* "sure. id like to ask a different question. does my policy cover towing and labor costs?"


In [9]:
# Input some user data as context.
# The keyword must be spelled `conversation_id`: pydantic models ignore unknown keyword
# arguments by default, so a misspelled name silently leaves the field as None.
user_info = UserInfo(cust_id="7852", policy_no="102070455", conversation_id="123", user_id="123")
# Placeholder only: the chat loop below overwrites `user_input` with what the user types.
user_input = "[USER]: I'like to check on my existing claims"

In [14]:
# Start a chat span
with mlflow.start_span(name="insurance_agent", span_type="AGENT") as span:
    print("[AGENT] Hello! How may I assist you?")
    result = None
    while True:
        user_input = input("[USER]: ")
        if user_input.lower() == "exit":
            print("[AGENT]: Bye!")
            break
        if not user_input:
            continue
        try:
            if result:
                user_input = result.to_input_list() + [{"role": "user", "content": user_input}]
            result = await Runner.run(
                starting_agent=triage_agent, input=user_input, context=user_info
            )
            print("\n[AGENT]:", result.final_output)
        except Exception as e:
            print(f"\nError occurred: {str(e)}")

[AGENT] Hello! How may I assist you?


INFO:httpx:HTTP Request: POST https://api.openai.com/v1/responses "HTTP/1.1 200 OK"
INFO:httpx:HTTP Request: POST https://api.openai.com/v1/responses "HTTP/1.1 200 OK"



[AGENT]: To check an existing claim, please provide your policy number.


INFO:httpx:HTTP Request: POST https://api.openai.com/v1/responses "HTTP/1.1 200 OK"



[AGENT]: Thank you! Let me retrieve the details for policy number 102070455. One moment, please.


INFO:httpx:HTTP Request: POST https://api.openai.com/v1/responses "HTTP/1.1 200 OK"
INFO:httpx:HTTP Request: POST https://api.openai.com/v1/responses "HTTP/1.1 200 OK"
INFO:databricks.sdk:loading DEFAULT profile from ~/.databrickscfg: host, serverless_compute_id, connection_timeout_seconds, token


[DEBUG]: the 'search_claims_details_by_policy_no' tool was called


INFO:unitycatalog.ai.core.databricks:Using databricks connect to execute functions with serverless compute.
INFO:unitycatalog.ai.core.utils.retry_utils:Successfully re-acquired connection to a serverless instance.
INFO:httpx:HTTP Request: POST https://api.openai.com/v1/responses "HTTP/1.1 200 OK"



[AGENT]: Here are the claim details for policy number 102070455:

1. **Claim Number:** f82bf191-fb74-483c-81c6-8ae9c5cdd778
   - **Claim Date:** 2017-01-09
   - **Incident Date:** 2017-01-06
   - **Incident Type:** Single Vehicle Collision
   - **Collision Type:** Side Collision
   - **Incident Severity:** Minor Damage
   - **Driver Age:** 28
   - **Driver Insured Relationship:** Husband
   - **Claim Amounts:** 
     - Injury: $15,500
     - Property: $7,750
     - Vehicle: $62,000
     - Total: $85,250

2. **Claim Number:** 1c1187c0-d6fb-44ca-9b89-19d6b37ed015
   - **Claim Date:** 2017-01-09
   - **Incident Date:** 2017-01-07
   - **Incident Type:** Single Vehicle Collision
   - **Collision Type:** Front Collision
   - **Incident Severity:** Minor Damage
   - **Driver Age:** 43
   - **Driver Insured Relationship:** Other relative
   - **Claim Amounts:** 
     - Injury: $7,530
     - Property: $15,060
     - Vehicle: $60,240
     - Total: $82,830

If you need further details, feel fre

INFO:httpx:HTTP Request: POST https://api.openai.com/v1/responses "HTTP/1.1 200 OK"
INFO:httpx:HTTP Request: POST https://api.openai.com/v1/responses "HTTP/1.1 200 OK"



[AGENT]: Your policy number is 102070455. If you have any more questions about your policy, feel free to ask!
[AGENT]: Bye!


# Write the agent to a file

In [8]:
%%writefile insurance_chat_agent.py
from typing import Any, List, Optional, Dict, Generator
from mlflow.pyfunc import ChatAgent
from mlflow.entities import SpanType
from mlflow.types.agent import (
    ChatAgentChunk,
    ChatAgentMessage,
    ChatAgentResponse,
    ChatContext,
)
import mlflow
from uuid import uuid4
import asyncio
from pydantic import BaseModel
from unitycatalog.ai.core.databricks import (
    DatabricksFunctionClient,
    FunctionExecutionResult,
)
from agents import function_tool, RunContextWrapper
from agents import Agent, Runner, set_tracing_disabled
from agents.extensions.handoff_prompt import RECOMMENDED_PROMPT_PREFIX
import logging

# Emit INFO-level log records; the tools below report each invocation through `logger`.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("insurance_chat_agent")

# NOTE(review): this module does not set OPENAI_API_KEY itself, so the key is presumably
# expected to be in the environment already. Re-enabling the line below would also need
# `import os`, which this file does not import.
# os.environ["OPENAI_API_KEY"] = "{{secrets/databricks_token_qyu/OpenAi}}"
# Turn on MLflow's OpenAI autologging at import time.
mlflow.openai.autolog()

class UserInfo(BaseModel):
    """Per-request user context handed to the agents and their tools.

    An instance is passed to ``Runner.run(context=...)`` and reaches tools through
    ``RunContextWrapper[UserInfo]``. Every field is optional and defaults to ``None``.
    """

    # Customer identifier (not populated or read anywhere in this module).
    cust_id: str | None = None
    # Policy number; written by the claims lookup tool from its `policy_no` argument.
    policy_no: str | None = None
    # Conversation identifier copied from the request's ChatContext, when present.
    conversation_id: str | None = None
    # User identifier copied from the request's ChatContext, when present.
    user_id: str | None = None


# The docstring below is LLM-facing: `function_tool` uses it for the tool and argument
# descriptions, so keep it short and written for the model.
@function_tool
def search_claims_details_by_policy_no(wrapper: RunContextWrapper[UserInfo], policy_no: str) -> FunctionExecutionResult:
    """Retrieve the claim details recorded against an insurance policy.

    Args:
        policy_no: The customer's insurance policy number.
    """
    logger.info("The 'search_claims_details_by_policy_no' tool was called")
    # Remember the policy number on the run context before using it for the lookup.
    wrapper.context.policy_no = policy_no
    client = DatabricksFunctionClient()
    # NOTE(review): the catalog is hardcoded to "ai" here; confirm it matches the catalog
    # that actually holds this function in the target workspace.
    return client.execute_function(
        function_name="ai.insurance_agent.search_claims_details_by_policy_no",
        parameters={"input_policy_no": wrapper.context.policy_no},
    )


# The docstring below is LLM-facing: `function_tool` uses it for the tool and argument
# descriptions, so keep it short and written for the model.
@function_tool
def policy_docs_vector_search(query: str) -> FunctionExecutionResult:
    """Search the insurance policy documents for content relevant to a question.

    Args:
        query: The customer's question or search phrase about their policy.
    """
    logger.info("The 'policy_docs_vector_search' tool was called")
    client = DatabricksFunctionClient()
    # NOTE(review): the catalog is hardcoded to "ai" here; confirm it matches the catalog
    # that actually holds this function in the target workspace.
    return client.execute_function(
        function_name="ai.insurance_agent.policy_docs_vector_search",
        parameters={"query": query},
    )

# Keep the Agents SDK tracing enabled (pass True to switch it off).
set_tracing_disabled(disabled=False)

# Specialist agent: looks up claim details through the Unity Catalog claims tool.
claims_detail_retrieval_agent = Agent[UserInfo](
    name="Claims Details Retrieval Agent",
    instructions=(
        f"{RECOMMENDED_PROMPT_PREFIX}"
        "You are a claims details retrieval agent. "
        "If you are speaking to a customer, you probably were transferred to you from the triage agent. "
        "Use the following routine to support the customer. \n"
        "# Routine: \n"
        "1. Identify the last question asked by the customer. \n"
        "2. Use the search tools to retrieve data about a claim. Do not rely on your own knowledge. \n"
        "3. If you cannot answer the question, transfer back to the triage agent. \n"
    ),
    tools=[
        search_claims_details_by_policy_no,
    ],
    model="gpt-4o",
)

# Specialist agent: answers policy questions through the policy-document search tool.
policy_qa_agent = Agent[UserInfo](
    name="Policy Q&A Agent",
    instructions=(
        f"{RECOMMENDED_PROMPT_PREFIX}"
        "You are an insurance policy Q&A agent. "
        "If you are speaking to a customer, you probably were transferred to you from the triage agent. "
        "Use the following routine to support the customer.\n"
        "# Routine: \n"
        "1. Identify the last question asked by the customer. \n"
        "2. Use the search tools to answer the question about their policy. Do not rely on your own knowledge. \n"
        "3. If you cannot answer the question, transfer back to the triage agent. \n"
    ),
    tools=[policy_docs_vector_search],
    model="gpt-4o",
)

# Entry-point agent: has no tools of its own and routes the customer to a specialist.
triage_agent = Agent[UserInfo](
    name="Triage agent",
    instructions=(
        f"{RECOMMENDED_PROMPT_PREFIX}"
        "You are a helpful triaging agent. "
        "You can use your tools to delegate questions to other appropriate agents. "
        "If the customer does not have anymore questions, wish them a goodbye and a good rest of their day. "
    ),
    handoffs=[claims_detail_retrieval_agent, policy_qa_agent],
    model="gpt-4o",
)

# Both specialists are told to "transfer back to the triage agent" when they cannot answer,
# which is only possible if the triage agent is one of their handoffs. It has to be wired
# here, after `triage_agent` exists, because the agents reference each other.
claims_detail_retrieval_agent.handoffs.append(triage_agent)
policy_qa_agent.handoffs.append(triage_agent)

class InsuranceChatAgent(ChatAgent):
    """MLflow ``ChatAgent`` adapter around the OpenAI Agents SDK agent graph.

    Each request is answered by running ``starting_agent`` through ``Runner.run``.
    Multi-turn history is kept in this object's memory, keyed by the
    ``conversation_id`` found on the request's ``ChatContext``. Requests without a
    conversation id are never stored, so unrelated callers cannot see each other's
    history.
    """

    def __init__(self, starting_agent: Agent):
        """Remember the entry-point agent and start with no stored conversations.

        Args:
            starting_agent: Agent every turn is started from.
        """
        self.starting_agent = starting_agent
        # conversation_id -> {"current_agent": Agent, "conversation_history": list | None}
        self.conversation_state = {}

    def _get_or_create_conversation_state(self, conversation_id: Optional[str]) -> Dict[str, Any]:
        """Return the state dict for a conversation, creating it on first use.

        Args:
            conversation_id: Key of the conversation, or ``None`` when the request
                carried no conversation id.

        Returns:
            A dict with the keys ``"current_agent"`` and ``"conversation_history"``.
            For ``None`` the dict is a throwaway that is not stored on the instance.
        """
        if conversation_id is None:
            # Without an id there is nothing to tell callers apart; storing under a shared
            # ``None`` key would mix the histories of unrelated requests.
            return {"current_agent": self.starting_agent, "conversation_history": None}
        if conversation_id not in self.conversation_state:
            self.conversation_state[conversation_id] = {
                "current_agent": self.starting_agent,
                "conversation_history": None,
            }
        return self.conversation_state[conversation_id]

    def _get_latest_user_message(self, messages: List[ChatAgentMessage]) -> str:
        """Return the content of the most recent ``user`` message.

        Args:
            messages: Messages of the request, oldest first.

        Returns:
            The content of the last message whose role is ``"user"``, or ``""`` when
            there is no such message (or its content is empty).
        """
        for message in reversed(messages):
            if message.role == "user":
                return message.content or ""
        # Only give up after every message has been inspected.
        return ""

    @staticmethod
    def _messages_to_input_items(messages: List[ChatAgentMessage]) -> List[Dict[str, str]]:
        """Convert the plain-text messages of a request into Runner input items."""
        return [
            {"role": message.role, "content": message.content}
            for message in messages
            if message.role in ("user", "assistant", "system") and message.content
        ]

    def _create_user_context(
            self,
            context: Optional[ChatContext] = None,
            custom_inputs: Optional[Dict[str, Any]] = None
        ) -> UserInfo:
        """Build the ``UserInfo`` run context from the MLflow request context.

        Args:
            context: Request context; ``conversation_id`` and ``user_id`` are copied
                from it when present and truthy.
            custom_inputs: Accepted for signature parity with ``predict``; currently
                not read.

        Returns:
            A ``UserInfo`` whose remaining fields keep their ``None`` defaults.
        """
        user_info = UserInfo()

        if context:
            conversation_id = getattr(context, "conversation_id", None)
            if conversation_id:
                user_info.conversation_id = conversation_id

            user_id = getattr(context, "user_id", None)
            if user_id:
                user_info.user_id = user_id

        return user_info

    @mlflow.trace(name="insurance_chat_agent", span_type=SpanType.AGENT)
    def predict(
        self,
        messages: list[ChatAgentMessage],
        context: Optional[ChatContext] = None,
        custom_inputs: Optional[Dict[str, Any]] = None
    ) -> ChatAgentResponse:
        """Answer the latest user message and return the agent's final output.

        Args:
            messages: Messages of the request, oldest first.
            context: Optional request context carrying ``conversation_id``/``user_id``.
            custom_inputs: Optional extra inputs; forwarded to ``_create_user_context``.

        Returns:
            A ``ChatAgentResponse`` holding one assistant message.

        Raises:
            Exception: Whatever ``Runner.run`` raises is propagated unchanged; the
                stored conversation history is left untouched in that case.
        """
        user_info = self._create_user_context(context, custom_inputs)

        # Get the state for this conversation
        state = self._get_or_create_conversation_state(user_info.conversation_id)
        conversation_history = state["conversation_history"]

        # Get the latest user message
        latest_message = self._get_latest_user_message(messages)

        # Prepare the input for the agent
        if conversation_history is None:
            # Nothing stored yet. If the caller supplied earlier turns, replay them so
            # that context is not lost; a single message is passed as plain text.
            supplied_history = self._messages_to_input_items(messages)
            agent_input = supplied_history if len(supplied_history) > 1 else latest_message
        else:
            # Build a new list instead of appending in place, so a failed run does not
            # leave a dangling user message in the stored history.
            agent_input = conversation_history + [{
                "role": "user",
                "content": latest_message
            }]

        # Run the async Runner to completion on a dedicated event loop
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            # Every turn starts at the entry-point agent; the agent that answered is
            # recorded in the state below but is not used for routing.
            result = loop.run_until_complete(
                Runner.run(
                    starting_agent=self.starting_agent,
                    input=agent_input,
                    context=user_info,
                )
            )
            # Update the state for the next turn
            # Store the updated conversation history from the result
            state["conversation_history"] = result.to_input_list()

            # Update the current agent based on which agent was last used
            if hasattr(result, "last_agent") and result.last_agent:
                state["current_agent"] = result.last_agent

        finally:
            loop.close()

        # Convert the result to ChatAgentResponse format:
        return ChatAgentResponse(
            messages=[
                ChatAgentMessage(
                    role="assistant",
                    content=result.final_output,
                    id=str(uuid4())
                )
            ]
        )

    @mlflow.trace(name="insurance_chat_agent_stream", span_type=SpanType.AGENT)
    def predict_stream(
        self,
        messages: list[ChatAgentMessage],
        context: Optional[ChatContext] = None,
        custom_inputs: Optional[Dict[str, Any]] = None
    ) -> Generator[ChatAgentChunk, None, None]:
        """Yield the ``predict`` answer as chunks, one per response message.

        The answer is computed in full by ``predict`` first; nothing is streamed
        incrementally.
        """
        response = self.predict(messages, context, custom_inputs)

        # Yield it as a single chunk
        for message in response.messages:
            yield ChatAgentChunk(delta=message)

# Module-level agent instance; every conversation enters through the triage agent.
AGENT = InsuranceChatAgent(starting_agent=triage_agent)
# Register the instance as the model MLflow uses when this file is logged as code.
mlflow.models.set_model(AGENT)

Writing insurance_chat_agent.py


# Load the Agent, Run Validation, Log and Register the Agent Model

In [10]:
from insurance_chat_agent import AGENT
import os
import nest_asyncio

# Patch asyncio so the agent can drive its own event loop inside the notebook's running loop.
nest_asyncio.apply()

# Smoke test: first turn of conversation "123", asking about claims with a policy number.
first_turn_request = {
    "messages": [
        {
            "role": "user",
            "content": "hi, id like to check on my existing claims and my policy number: 102070455",
        },
    ],
    "context": {"conversation_id": "123", "user_id": "123"},
}
AGENT.predict(first_turn_request)

INFO:httpx:HTTP Request: POST https://api.openai.com/v1/responses "HTTP/1.1 200 OK"
INFO:httpx:HTTP Request: POST https://api.openai.com/v1/responses "HTTP/1.1 200 OK"
INFO:insurance_chat_agent:The 'search_claims_details_by_policy_no' tool was called
INFO:databricks.sdk:loading DEFAULT profile from ~/.databrickscfg: host, serverless_compute_id, connection_timeout_seconds, token
INFO:httpx:HTTP Request: POST https://api.openai.com/v1/responses "HTTP/1.1 200 OK"


ChatAgentResponse(messages=[ChatAgentMessage(role='assistant', content='I’m unable to retrieve your claim details at the moment due to a system error. Please try again later or contact customer support for immediate assistance.', name=None, id='c6e1724e-b9de-4628-a0d1-8343a0c68b00', tool_calls=None, tool_call_id=None, attachments=None)], finish_reason=None, custom_outputs=None, usage=None)

In [None]:
# Follow-up turn in the same conversation ("123"), this time a policy coverage question.
follow_up_request = dict(
    messages=[dict(role="user", content="does my policy cover towing and labor costs?")],
    context=dict(conversation_id="123", user_id="123"),
)
AGENT.predict(follow_up_request)

## Log the agent

In [None]:
import mlflow
import os
from mlflow.models.resources import (
    DatabricksFunction,
    DatabricksServingEndpoint,
    DatabricksVectorSearchIndex)
from unitycatalog.ai.openai.toolkit import UCFunctionToolkit
import nest_asyncio

# Patch asyncio so nested event loops work inside the notebook.
nest_asyncio.apply()
# NOTE(review): `dbutils` only exists inside a Databricks workspace and "my_secret_scope"
# looks like a placeholder — confirm the scope name, and how this cell should behave locally.
os.environ["OPENAI_API_KEY"] = dbutils.secrets.get(
    scope="my_secret_scope", key="OpenAi"
)

# Databricks resources declared as dependencies of the logged model. The two functions
# match the names called by the tools in insurance_chat_agent.py; the vector search index
# and embedding endpoint are presumably what `policy_docs_vector_search` relies on — confirm.
resources = [
    DatabricksVectorSearchIndex(
        index_name="ai.agents.policy_docs_chunked_files_vs_index"
    ),
    DatabricksServingEndpoint(endpoint_name="databricks-bge-large-en"),
    DatabricksFunction(
        function_name="ai.insurance_agent.search_claims_details_by_policy_no"
    ),
    DatabricksFunction(
        function_name="ai.insurance_agent.policy_docs_vector_search"
    )
]

# NOTE(review): `my_databricks_account` is not defined anywhere in the visible notebook;
# it must be set (e.g. to the workspace user name) before this cell can run.
mlflow.set_experiment(f"/Users/{my_databricks_account}/ML_experiments/insurance_chat_agent")
mlflow.openai.autolog()

# Log the model using the "models from code" approach
with mlflow.start_run():
    logged_model_info = mlflow.pyfunc.log_model(
        artifact_path="insurance_chat_agent",
        # Path to the file written by the %%writefile cell, resolved against the working directory.
        python_model=os.path.join(os.getcwd(), "insurance_chat_agent.py"),
        # Example request stored with the model; same shape as the validation calls above.
        input_example={
            "messages": [
                {
                    "role": "user",
                    "content": "hi, id like to check on my existing claims?",
                }
            ],
            "context": {"conversation_id": "123", "user_id": "123"},
        },
        # NOTE(review): only unitycatalog-openai is version-pinned; consider pinning the rest.
        pip_requirements=[
            "mlflow",
            "openai-agents",
            "unitycatalog-openai[databricks]==0.2.0",
            "pydantic",
        ],
        resources=resources
    )