# Deploy RAG Agents

This notebook:
1. Logs agents to MLflow
2. Registers models in Unity Catalog
3. Deploys them as Model Serving endpoints using `agents.deploy()`

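**Prerequisites:**
- Run `01_parse_pdfs.ipynb` and `02_create_indexes.ipynb` first
- Vector Search indexes must be ONLINE
- Agent config files (`config_agent_a.yml`, `config_agent_b.yml`, `config_agent_c.yml`) must be present in `../agents/`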

In [None]:
%pip install mlflow>=3.1.3 databricks-agents databricks-langchain langgraph pyyaml
dbutils.library.restartPython()

In [None]:
# Notebook parameters. Widgets let the same notebook run interactively
# (edit the widget values) or as a job (values passed as job parameters).
dbutils.widgets.text("catalog", "your_catalog", "Unity Catalog")
dbutils.widgets.text("schema", "rag_agents", "Schema name")

CATALOG, SCHEMA = (dbutils.widgets.get(widget) for widget in ("catalog", "schema"))


def _running_as_job() -> bool:
    """Return True when the Spark conf carries a non-empty job id."""
    try:
        job_id = spark.conf.get("spark.databricks.job.id")
    except Exception:
        # Conf key missing or unreadable: treat as an interactive session.
        return False
    return bool(job_id)


is_job = _running_as_job()

# Fail fast: a job must never run against the placeholder catalog name.
if CATALOG == "your_catalog" and is_job:
    raise ValueError("catalog parameter required for job execution")

# One entry per agent: (config_file, model_name, endpoint_name)
AGENTS = [
    ("config_agent_a.yml", "agent_a_rag_model", "agent-a-rag"),
    ("config_agent_b.yml", "agent_b_rag_model", "agent-b-rag"),
    ("config_agent_c.yml", "agent_c_rag_model", "agent-c-rag"),
]

In [None]:
# Imports used by the cells below.
# NOTE(review): the following cell repeats these imports and the
# set_registry_uri call verbatim; one of the two copies could be removed.
import yaml
from pathlib import Path

import mlflow
from databricks import agents
from databricks_langchain import ChatDatabricks, VectorSearchRetrieverTool
from langgraph.prebuilt import create_react_agent
from mlflow.models import infer_signature  # NOTE(review): appears unused in this notebook

# Use the Unity Catalog model registry so the three-level names used below
# (catalog.schema.model) are registered in Unity Catalog.
mlflow.set_registry_uri("databricks-uc")

In [None]:
# NOTE(review): these imports and the registry call duplicate the previous
# cell verbatim; one copy could be removed.
import yaml
from pathlib import Path

import mlflow
from databricks import agents
from databricks_langchain import ChatDatabricks, VectorSearchRetrieverTool
from langgraph.prebuilt import create_react_agent
from mlflow.models import infer_signature

# Register models in Unity Catalog rather than the workspace registry.
mlflow.set_registry_uri("databricks-uc")


def load_config(config_name: str, catalog: str, schema: str) -> dict:
    """Load an agent YAML config and retarget its vector index to ``catalog.schema``.

    The file is looked up first in ``../agents`` (relative to the current
    working directory) and then in ``/Workspace/Repos/agents``.

    Args:
        config_name: File name of the YAML config, e.g. ``"config_agent_a.yml"``.
        catalog: Unity Catalog catalog that hosts the vector search index.
        schema: Schema within ``catalog`` that hosts the index.

    Returns:
        The parsed config mapping. If it has a non-empty ``vector_search_index``,
        that entry is rewritten to ``f"{catalog}.{schema}.<index>"``, where
        ``<index>`` is the last dot-separated component of the configured name
        (so bare or partially-qualified names are fully qualified too).

    Raises:
        FileNotFoundError: If the config exists at none of the candidate paths.
        ValueError: If the YAML document is empty or not a mapping.
    """
    candidates = [
        Path("../agents") / config_name,
        # NOTE(review): Repos checkouts usually live at
        # /Workspace/Repos/<user>/<repo>/...; confirm this fallback matches.
        Path("/Workspace/Repos") / "agents" / config_name,
    ]
    config_path = next((path for path in candidates if path.exists()), None)
    if config_path is None:
        tried = ", ".join(str(path) for path in candidates)
        raise FileNotFoundError(f"Agent config {config_name!r} not found; tried: {tried}")

    with config_path.open(encoding="utf-8") as f:
        # safe_load: configs are plain data and must not build arbitrary objects.
        config = yaml.safe_load(f)

    # An empty file parses to None; anything but a mapping is unusable below.
    if not isinstance(config, dict):
        raise ValueError(
            f"Agent config {config_path} must be a YAML mapping, "
            f"got {type(config).__name__}"
        )

    index_name = config.get("vector_search_index")
    if index_name:
        short_name = str(index_name).rsplit(".", 1)[-1]
        config["vector_search_index"] = f"{catalog}.{schema}.{short_name}"

    return config


def create_agent_from_config(config: dict):
    """Build a LangGraph ReAct RAG agent from an agent config mapping.

    Args:
        config: Mapping as returned by ``load_config``. Required keys:
            ``vector_search_index`` (fully-qualified index name) and
            ``agent_name`` (used in the retriever tool description).
            Optional keys: ``llm_endpoint`` (default
            ``"databricks-meta-llama-3-3-70b-instruct"``), ``system_prompt``
            (default ``"You are a helpful assistant."``) and ``num_results``
            (chunks retrieved per query, default 5).

    Returns:
        The agent graph produced by ``create_react_agent``.

    Raises:
        KeyError: If a required key is missing from ``config``.
    """
    retriever_tool = VectorSearchRetrieverTool(
        index_name=config["vector_search_index"],
        num_results=config.get("num_results", 5),
        # Columns must exist in the index; presumably they mirror the chunk
        # table built by 02_create_indexes -- confirm if that schema changes.
        columns=["content", "source", "chunk_id"],
        filters={},
        text_column="content",
        tool_name="search_documents",
        tool_description=f"Search the {config['agent_name']} knowledge base.",
    )

    llm = ChatDatabricks(
        endpoint=config.get("llm_endpoint", "databricks-meta-llama-3-3-70b-instruct"),
        temperature=0.1,
        max_tokens=1024,
    )

    # `prompt` replaces the deprecated `state_modifier` argument of
    # create_react_agent; the %pip cell installs an unpinned langgraph.
    return create_react_agent(
        model=llm,
        tools=[retriever_tool],
        prompt=config.get("system_prompt", "You are a helpful assistant."),
    )

## Log and Register Agents

In [None]:
def log_and_register_agent(config_file: str, model_name: str) -> str:
    """Log one agent to MLflow and register it in Unity Catalog.

    Args:
        config_file: Agent YAML config file name (resolved by ``load_config``).
        model_name: Model name within ``CATALOG.SCHEMA``.

    Returns:
        The fully-qualified registered model name ``CATALOG.SCHEMA.model_name``.
    """
    config = load_config(config_file, CATALOG, SCHEMA)
    agent = create_agent_from_config(config)

    model_full_name = f"{CATALOG}.{SCHEMA}.{model_name}"
    print(f"Logging agent: {config['agent_name']} -> {model_full_name}")

    # Chat-style request stored with the model as its input example.
    input_example = {"messages": [{"role": "user", "content": "Hello, how can you help?"}]}

    with mlflow.start_run(run_name=f"deploy_{model_name}"):
        # NOTE(review): MLflow documents LangGraph graphs as loggable via
        # "models from code" (lc_model=<path to a .py file>); logging the
        # in-memory graph object may be rejected -- confirm with the installed MLflow.
        model_info = mlflow.langchain.log_model(
            lc_model=agent,
            name="agent",  # MLflow 3 name for the deprecated `artifact_path`
            input_example=input_example,
            registered_model_name=model_full_name,
        )

    # Each run registers a new version; report it so deployments can target it.
    version = getattr(model_info, "registered_model_version", None)
    print(f"  Registered: {model_full_name} (version {version})")
    return model_full_name

In [None]:
# Log and register every agent. Each entry maps model_name to either the
# fully-qualified model name or an "Error: ..." string; failures don't stop the loop.
registered_models = {}
for config_file, model_name, _endpoint in AGENTS:
    try:
        registered_models[model_name] = log_and_register_agent(config_file, model_name)
    except Exception as err:
        print(f"Error registering {model_name}: {err}")
        registered_models[model_name] = f"Error: {err}"

print("\n=== Registered Models ===")
for registered_name, outcome in registered_models.items():
    print(f"{registered_name}: {outcome}")

## Deploy Serving Endpoints

In [None]:
from databricks.sdk import WorkspaceClient

# Workspace client shared by the deploy, verify and test cells.
w = WorkspaceClient()


def _latest_model_version(model_full_name: str) -> int:
    """Return the highest version number registered for ``model_full_name``.

    Raises:
        ValueError: If the model has no registered versions.
    """
    versions = mlflow.MlflowClient().search_model_versions(f"name='{model_full_name}'")
    if not versions:
        raise ValueError(f"No registered versions found for {model_full_name}")
    # Versions are reported as strings; compare them numerically.
    return max(int(v.version) for v in versions)


def deploy_agent_endpoint(model_name: str, endpoint_name: str, model_version=None):
    """Deploy a registered agent model to a serving endpoint via ``agents.deploy()``.

    ``agents.deploy`` creates the endpoint or updates an existing one, so the
    call can be repeated after registering a new model version.

    Args:
        model_name: Model name within ``CATALOG.SCHEMA``.
        endpoint_name: Serving endpoint to create or update.
        model_version: Version to deploy. Defaults to the latest registered
            version, so re-running the notebook rolls out the newest model.

    Returns:
        The object returned by ``agents.deploy``.
    """
    from databricks.sdk.errors import NotFound

    model_full_name = f"{CATALOG}.{SCHEMA}.{model_name}"
    if model_version is None:
        model_version = _latest_model_version(model_full_name)

    print(f"Deploying: {model_full_name} (version {model_version}) -> {endpoint_name}")

    # Informational pre-check only; agents.deploy handles create vs. update.
    try:
        existing = w.serving_endpoints.get(endpoint_name)
    except NotFound:
        print("  Endpoint does not exist yet; it will be created.")
    except Exception as err:
        # Best effort: report the lookup failure but still attempt the deploy.
        print(f"  Could not check existing endpoint ({err}); continuing.")
    else:
        state = existing.state.ready if existing.state else "UNKNOWN"
        print(f"  Endpoint exists. Status: {state}")

    deployment = agents.deploy(
        model_name=model_full_name,
        model_version=model_version,
        endpoint_name=endpoint_name,
        scale_to_zero=True,  # Cost optimization
    )

    print(f"  Endpoint: {endpoint_name}")
    # agents.deploy returns a Deployment object, not a dict, so read its
    # attributes defensively instead of calling .get().
    query_url = getattr(deployment, "query_endpoint", None)
    if query_url:
        print(f"  Query URL: {query_url}")
    print("  Status: submitted (see 'Verify Endpoints' for readiness)")
    return deployment

In [None]:
# Deploy each agent whose model registered successfully in the previous cell;
# unregistered models (missing or "Error..." entries) are skipped.
deployments = {}
for _cfg, model_name, endpoint_name in AGENTS:
    registration = registered_models.get(model_name, "Error")
    if registration.startswith("Error"):
        deployments[endpoint_name] = "skipped (model not registered)"
        continue
    try:
        deployment = deploy_agent_endpoint(model_name, endpoint_name)
    except Exception as err:
        print(f"Error deploying {endpoint_name}: {err}")
        deployments[endpoint_name] = f"Error: {err}"
    else:
        deployments[endpoint_name] = "deployed"

print("\n=== Deployments ===")
for deployed_endpoint, outcome in deployments.items():
    print(f"{deployed_endpoint}: {outcome}")

## Verify Endpoints

In [None]:
def check_endpoint_status(endpoint_name: str):
    """Print and return the readiness of a serving endpoint.

    Returns the endpoint's ``state.ready`` value, ``"UNKNOWN"`` when the
    endpoint reports no state, or ``"ERROR"`` when the lookup fails.
    """
    try:
        info = w.serving_endpoints.get(endpoint_name)
        if info.state:
            readiness = info.state.ready
        else:
            readiness = "UNKNOWN"
    except Exception as err:
        print(f"{endpoint_name}: Error - {err}")
        return "ERROR"
    print(f"{endpoint_name}: {readiness}")
    return readiness


print("=== Endpoint Status ===")
for *_unused, endpoint_name in AGENTS:
    check_endpoint_status(endpoint_name)

## Test Endpoint

In [None]:
# Smoke-test the first agent's endpoint through the OpenAI-compatible client.
_, _, test_endpoint = AGENTS[0]
test_message = "What information can you help me with?"

try:
    openai_client = w.serving_endpoints.get_open_ai_client()
    chat_messages = [{"role": "user", "content": test_message}]
    response = openai_client.chat.completions.create(
        model=test_endpoint,
        messages=chat_messages,
        max_tokens=256,
    )
    print(f"Endpoint: {test_endpoint}")
    print(f"Query: {test_message}")
    first_choice = response.choices[0]
    print(f"Response: {first_choice.message.content}")
except Exception as err:
    print(f"Test failed: {err}")