In [0]:
import requests, base64, json, re
from databricks.sdk import WorkspaceClient 

# Notebook parameters: the Unity Catalog catalog and schema to work in.
dbutils.widgets.text("catalog_name", "", "Catalog (required)")
dbutils.widgets.text("schema_name", "", "Schema")
CATALOG_NAME = dbutils.widgets.get("catalog_name").strip()
SCHEMA_NAME = dbutils.widgets.get("schema_name").strip()


def _decode_secret_url(raw_value: str) -> str:
    """Base64-decode a secret value and return it as a URL with no trailing slash.

    Args:
        raw_value: Base64-encoded secret payload.

    Returns:
        The UTF-8 decoded text, stripped of surrounding whitespace and of any
        trailing "/" characters.
    """
    return base64.b64decode(raw_value).decode("utf-8").strip().rstrip("/")


w = WorkspaceClient()
# Both URLs live in the "shscreds" secret scope and are base64-decoded before use.
CP_URL_RAW = w.secrets.get_secret(scope="shscreds", key="cpurl").value
CP_URL = _decode_secret_url(CP_URL_RAW)
DP_URL_RAW = w.secrets.get_secret(scope="shscreds", key="dpurl").value
# Fix: decode the "dpurl" payload. This previously decoded CP_URL_RAW again,
# which made DP_URL silently identical to CP_URL.
DP_URL = _decode_secret_url(DP_URL_RAW)

In [0]:
%pip install -U -qqqq backoff databricks-openai openai-agents uv databricks-agents mlflow-skinny[databricks]
# Restart the Python process so the freshly installed packages are the ones imported below.
# NOTE(review): a restart normally discards Python variables defined by earlier cells, which
# would explain why later cells re-read the widgets instead of reusing earlier names — confirm.
dbutils.library.restartPython()

In [0]:
import json

# 1. Get Variables
# Read the Unity Catalog location of the UDF from the notebook widgets.
CATALOG = dbutils.widgets.get("catalog_name").strip()
SCHEMA = dbutils.widgets.get("schema_name").strip()

# Fail fast on blank widgets: an empty value would otherwise produce a malformed
# three-part name such as "..py_listshssqlraw" that only fails much later.
_missing_widgets = [
    widget for widget, value in (("catalog_name", CATALOG), ("schema_name", SCHEMA)) if not value
]
if _missing_widgets:
    raise ValueError(
        "Set the following notebook widget(s) before running: " + ", ".join(_missing_widgets)
    )

# Fully qualified UDF name, and the same name with dots replaced by double underscores
# for use as the tool name exposed to the LLM.
FUNC_NAME = f"{CATALOG}.{SCHEMA}.py_listshssqlraw"
TOOL_NAME_CLEAN = FUNC_NAME.replace('.', '__')

# 2. Generate the Agent Code
# Source text of agent.py. It carries two placeholder tokens that are substituted
# before the file is written: one for the fully qualified UDF name and one for the
# tool name shown to the LLM. The text is a plain (non-f) string, so braces inside
# it are literal and backslash escapes are resolved once when this cell runs.
agent_code_template = """
import json
import os
import warnings
from typing import Any, Callable, Generator, Optional, List
from uuid import uuid4

import mlflow
from databricks.sdk import WorkspaceClient
from mlflow.entities import SpanType
from mlflow.pyfunc import ResponsesAgent
from mlflow.types.responses import (
    ResponsesAgentRequest,
    ResponsesAgentResponse,
    ResponsesAgentStreamEvent,
    output_to_responses_items_stream,
    to_chat_completions_input,
)
from openai import OpenAI
from pydantic import BaseModel
from unitycatalog.ai.core.base import get_uc_function_client
from unitycatalog.ai.core.databricks import DatabricksFunctionClient

# Configuration
LLM_ENDPOINT_NAME = "databricks-claude-opus-4-6"
MAX_TOKENS = 125000
# The prompt names the tool exactly as it is registered in the tool spec below.
SYSTEM_PROMPT = \"\"\"You are a senior spark performance engineer. Use the REPLACE_TOOL_NAME_CLEAN tool (argument: cluster_id) to retrieve the cluster's Spark SQL execution plan JSON blob for a Spark application/job. Please analyze the execution plan and provide a deep, detailed analysis for the job run, including any bottlenecks found and performance recommendations/optimizations.\"\"\"

# Tool Helper Classes
class ToolInfo(BaseModel):
    name: str
    spec: dict
    exec_fn: Callable

class ToolCallingAgent(ResponsesAgent):
    def __init__(self, llm_endpoint: str):
        self.llm_endpoint = llm_endpoint

        # Initialize standard client for LLM
        self.workspace_client = WorkspaceClient()
        self.model_serving_client: OpenAI = (
            self.workspace_client.serving_endpoints.get_open_ai_client()
        )
        self._tools_dict = self._load_tools()

    def _load_tools(self) -> dict[str, ToolInfo]:
        print("Initializing Tools via Manual Spec...")

        # 1. DEFINE THE TOOL MANUALLY
        tool_name = "REPLACE_TOOL_NAME_CLEAN"

        tool_spec = {
            "type": "function",
            "function": {
                "name": tool_name,
                "description": "Retrieves Spark SQL execution plan metrics for a given cluster.",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "cluster_id": {
                            "type": "string",
                            "description": "The Cluster ID to analyze"
                        }
                        # HIDDEN: pat_token AND cp_url are omitted here
                    },
                    "required": ["cluster_id"]
                }
            }
        }

        # 2. DEFINE EXECUTION WITH EXPLICIT CLIENT
        def exec_fn(**kwargs):
            udf_name = "REPLACE_FUNC_NAME"
            print(f"Executing {udf_name}...")

            # A. Retrieve Secrets/Env Vars
            pat_token = os.environ.get("USER_PAT_OVERRIDE")
            cp_url = os.environ.get("USER_CP_URL_OVERRIDE")

            if not pat_token:
                return json.dumps({"error": "USER_PAT_OVERRIDE not found."})
            if not cp_url:
                return json.dumps({"error": "USER_CP_URL_OVERRIDE not found."})

            # B. Inject tokens into arguments
            # The UDF now expects these 3 args: cluster_id, pat_token, cp_url
            kwargs["pat_token"] = pat_token
            kwargs["cp_url"] = cp_url

            # C. Create Client
            # We explicitly pass the host (cp_url) and token to ensure the client connects correctly
            client = DatabricksFunctionClient(
                client=WorkspaceClient(host=cp_url, token=pat_token)
            )

            try:
                # D. Execute
                res = client.execute_function(udf_name, kwargs)
                if res.error:
                    return f"Error executing function: {res.error}"

                # --- UX FIX: Compact & Prefix Output ---
                result_val = res.value
                if not isinstance(result_val, str):
                    result_val = json.dumps(result_val)

                compact_json = result_val.replace('\\n', '').replace('\\r', '')
                final_output = f"SPARK_METRICS_DATA: {compact_json}"

                return final_output

            except Exception as e:
                return f"Exception during execution: {str(e)}"

        return {
            tool_name: ToolInfo(name=tool_name, spec=tool_spec, exec_fn=exec_fn)
        }

    def get_tool_specs(self) -> list[dict]:
        return [tool_info.spec for tool_info in self._tools_dict.values()]

    @mlflow.trace(span_type=SpanType.TOOL)
    def execute_tool(self, tool_name: str, args: dict) -> Any:
        # The model can ask for a tool that was never registered; hand that back as a
        # tool result instead of raising KeyError and aborting the whole request.
        tool_info = self._tools_dict.get(tool_name)
        if tool_info is None:
            return json.dumps({"error": f"Unknown tool: {tool_name}"})
        return tool_info.exec_fn(**args)

    def call_llm(self, messages: list[dict[str, Any]]) -> Generator[dict[str, Any], None, None]:
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", message="PydanticSerializationUnexpectedValue")
            for chunk in self.model_serving_client.chat.completions.create(
                model=self.llm_endpoint,
                messages=to_chat_completions_input(messages),
                tools=self.get_tool_specs(),
                stream=True,
                max_tokens=MAX_TOKENS
            ):
                chunk_dict = chunk.to_dict()
                if len(chunk_dict.get("choices", [])) > 0:
                    yield chunk_dict

    def handle_tool_call(self, tool_call: dict[str, Any], messages: list[dict[str, Any]]) -> ResponsesAgentStreamEvent:
        try:
            args = json.loads(tool_call.get("arguments"))
        except Exception:
            args = {}
        result = str(self.execute_tool(tool_name=tool_call["name"], args=args))
        tool_call_output = self.create_function_call_output_item(tool_call["call_id"], result)
        messages.append(tool_call_output)
        return ResponsesAgentStreamEvent(type="response.output_item.done", item=tool_call_output)

    def call_and_run_tools(self, messages: list[dict[str, Any]], max_iter: int = 10) -> Generator[ResponsesAgentStreamEvent, None, None]:
        for _ in range(max_iter):
            last_msg = messages[-1]
            if last_msg.get("role", None) == "assistant":
                return
            elif last_msg.get("type", None) == "function_call":
                yield self.handle_tool_call(last_msg, messages)
            else:
                yield from output_to_responses_items_stream(
                    chunks=self.call_llm(messages), aggregator=messages
                )
        yield ResponsesAgentStreamEvent(
            type="response.output_item.done",
            item=self.create_text_output_item("Max iterations reached. Stopping.", str(uuid4())),
        )

    def predict(self, request: ResponsesAgentRequest) -> ResponsesAgentResponse:
        session_id = request.custom_inputs.get("session_id") if request.custom_inputs else None
        if not session_id and request.context:
            session_id = request.context.conversation_id
        if session_id:
            mlflow.update_current_trace(metadata={"mlflow.trace.session": session_id})

        outputs = [event.item for event in self.predict_stream(request) if event.type == "response.output_item.done"]
        return ResponsesAgentResponse(output=outputs, custom_outputs=request.custom_inputs)

    def predict_stream(self, request: ResponsesAgentRequest) -> Generator[ResponsesAgentStreamEvent, None, None]:
        session_id = request.custom_inputs.get("session_id") if request.custom_inputs else None
        if not session_id and request.context:
            session_id = request.context.conversation_id
        if session_id:
            mlflow.update_current_trace(metadata={"mlflow.trace.session": session_id})

        messages = to_chat_completions_input([i.model_dump() for i in request.input])
        if SYSTEM_PROMPT:
            messages.insert(0, {"role": "system", "content": SYSTEM_PROMPT})
        yield from self.call_and_run_tools(messages=messages)

# Log the model using MLflow
AGENT = ToolCallingAgent(llm_endpoint=LLM_ENDPOINT_NAME)
mlflow.models.set_model(AGENT)
"""

# 3. Fill in the placeholders
# Each placeholder token in the template is swapped for its concrete value, in order.
final_agent_code = agent_code_template
for placeholder, replacement in (
    ("REPLACE_FUNC_NAME", FUNC_NAME),
    ("REPLACE_TOOL_NAME_CLEAN", TOOL_NAME_CLEAN),
):
    final_agent_code = final_agent_code.replace(placeholder, replacement)

# 4. Save the rendered source as agent.py in the current working directory
with open("agent.py", "w") as agent_file:
    agent_file.write(final_agent_code)

print(f"✅ Generated agent.py with CP_URL injection for: {FUNC_NAME}")

In [0]:
import mlflow
from mlflow.models.resources import DatabricksFunction, DatabricksServingEndpoint

# 1. Declare the model's dependencies
# The LLM serving endpoint and the Unity Catalog function are listed explicitly
# so they are recorded as resources of the logged model.
llm_endpoint_resource = DatabricksServingEndpoint(endpoint_name="databricks-claude-opus-4-6")
uc_function_resource = DatabricksFunction(function_name=f"{CATALOG}.{SCHEMA}.py_listshssqlraw")
resources = [llm_endpoint_resource, uc_function_resource]

# 2. Lightweight input example
# A plain greeting, chosen so the agent is not expected to invoke the heavy tool
# while the model is being logged.
greeting_message = {"role": "user", "content": "Hello! Are you ready to analyze Spark logs?"}
input_example = {
    "input": [greeting_message],
    "custom_inputs": {"session_id": "fast-logging-test"},
}

# Packages recorded as the logged model's pip requirements.
agent_pip_requirements = [
    "databricks-openai",
    "databricks-sdk",
    "pydantic",
    "mlflow",
    "backoff",
    "unitycatalog-ai",
]

print("Logging model (fast mode)...")

# 3. Log agent.py as a pyfunc model inside a fresh MLflow run
with mlflow.start_run():
    logged_agent_info = mlflow.pyfunc.log_model(
        name="agent",
        python_model="agent.py",
        input_example=input_example,
        pip_requirements=agent_pip_requirements,
        resources=resources,
    )

print(f"✅ Model logged successfully at: {logged_agent_info.model_uri}")

In [0]:
# Smoke test: load the freshly logged model into this session and send it one request.
# (Per the original author's note, loading it here sidesteps auth issues seen in the
# isolated 'uv' environment.)
run_id = logged_agent_info.run_id
print(f"Testing model from run: {run_id}")
loaded_model = mlflow.pyfunc.load_model(f"runs:/{run_id}/agent")

# A short greeting is enough to confirm the model loads and responds.
smoke_test_request = {
    "input": [{"role": "user", "content": "Hello! Are you ready?"}],
    "custom_inputs": {"session_id": "final-validation"},
}
response = loaded_model.predict(smoke_test_request)

print("\nResponse from Agent:", response, sep="\n")

In [0]:
# Register the logged agent in Unity Catalog.
mlflow.set_registry_uri("databricks-uc")

# The destination is <catalog>.<schema>.<model>; catalog and schema come from the widgets.
CATALOG = dbutils.widgets.get("catalog_name").strip()
SCHEMA = dbutils.widgets.get("schema_name").strip()
model_name = "spark_observability_agent_dbx"
UC_MODEL_NAME = ".".join((CATALOG, SCHEMA, model_name))

# Create a new registered version from the model logged earlier in the notebook.
uc_registered_model_info = mlflow.register_model(
    name=UC_MODEL_NAME,
    model_uri=logged_agent_info.model_uri,
)

print(f"✅ Model registered: {UC_MODEL_NAME} (Version {uc_registered_model_info.version})")

In [0]:
import mlflow
import os

# 1. SETUP CONFIG
# Rebuild the fully qualified Unity Catalog model name from the notebook widgets.
CATALOG = dbutils.widgets.get("catalog_name").strip()
SCHEMA = dbutils.widgets.get("schema_name").strip()
model_name = "spark_observability_agent_dbx"
UC_MODEL_FULL_NAME = f"{CATALOG}.{SCHEMA}.{model_name}"

# 2. CRITICAL STEP: INJECT THE SECRETS LOCALLY
# agent.py reads USER_PAT_OVERRIDE and USER_CP_URL_OVERRIDE from os.environ, so both
# are fetched from the "shscreds" secret scope and exported for this session.
try:
    dev_pat = dbutils.secrets.get(scope="shscreds", key="pat")
    cpurl = dbutils.secrets.get(scope="shscreds", key="cpurl")

    # Set the environment variables so agent.py can read them
    os.environ["USER_PAT_OVERRIDE"] = dev_pat
    os.environ["USER_CP_URL_OVERRIDE"] = cpurl
    print("✅ Successfully injected USER_PAT_OVERRIDE into local environment.")
except Exception as e:
    # Deliberately best-effort: the run continues and the tool reports the missing
    # variables itself if they were not set.
    print(f"⚠️ Warning: Could not fetch secret. Agent execution might fail if UDF requires auth.\nError: {e}")

# 3. LOAD MODEL
# Fix: load the version that was just registered rather than a hard-coded "1",
# which silently tested a stale version after any re-registration.
model_version = uc_registered_model_info.version
model_uri = f"models:/{UC_MODEL_FULL_NAME}/{model_version}"
print(f"Loading model from UC: {model_uri}")

uc_agent = mlflow.pyfunc.load_model(model_uri)

# 4. PREDICT
# When the agent runs its tool, the credentials are read from os.environ.
print("Running prediction...")
response = uc_agent.predict({
    "input": [{"role": "user", "content": "Analyze performance for cluster 0205-013230-2hrzofqr"}],
    "custom_inputs": {"session_id": "uc-test-local-01"}
})

print("Success!")
# `response` is left in scope for later cells to inspect or render.

In [0]:
from IPython.display import display, Markdown

# Collect the assistant's text from the agent response and render it as Markdown.
# Only a dict response carrying an "output" list is parsed; anything else leaves
# the text empty.
final_markdown_text = ""

if isinstance(response, dict) and "output" in response:
    # Keep only message items authored by the assistant.
    assistant_messages = (
        item
        for item in response["output"]
        if item.get("type") == "message" and item.get("role") == "assistant"
    )
    # Concatenate every "output_text" part, in order of appearance.
    final_markdown_text += "".join(
        part.get("text", "")
        for message in assistant_messages
        for part in message.get("content", [])
        if part.get("type") == "output_text"
    )

# Print a banner, then show the text with rich Markdown rendering.
banner = "=" * 50
print(banner)
print("FORMATTED AGENT RESPONSE:")
print(banner)
display(Markdown(final_markdown_text))

In [0]:
from databricks import agents

# Secret references in {{secrets/<scope>/<key>}} form. They are passed as environment
# variable values so that, per the original author's note, Model Serving fetches the
# secrets and exposes them inside the serving container.
SECRET_PATH = "{{secrets/shscreds/pat}}"
SECRET_CP_URL = "{{secrets/shscreds/cpurl}}"

model_version_to_deploy = uc_registered_model_info.version
print(f"Deploying model version {model_version_to_deploy}")

# Names match the variables agent.py reads from os.environ.
serving_environment_vars = {
    "USER_PAT_OVERRIDE": SECRET_PATH,
    "USER_CP_URL_OVERRIDE": SECRET_CP_URL,
}

agents.deploy(
    UC_MODEL_NAME,
    model_version_to_deploy,
    tags={"endpointSource": "playground"},
    environment_vars=serving_environment_vars,
)