In [None]:
import argparse
import sys
from datetime import timedelta

from loguru import logger
from pyspark.dbutils import DBUtils
from pyspark.sql import SparkSession

from honeywell.serving.model_serving import ModelServing

In [None]:

# ---------------------------------------------------------
# Args
# ---------------------------------------------------------
def _build_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the serving-endpoint deployment task.

    Returns:
        An ``argparse.ArgumentParser`` that accepts ``--model_name`` and
        ``--env`` (both required) plus the optional ``--upstream_task_key``,
        ``--wait_for_ready`` and ``--timeout_sec`` flags.
    """
    cli = argparse.ArgumentParser(
        description="Deploy or update Databricks model serving endpoint"
    )

    # Required inputs: which model to serve and the target environment.
    cli.add_argument(
        "--model_name",
        required=True,
        type=str,
        help="Fully-qualified UC model name (e.g. catalog.schema.model_name)",
    )
    cli.add_argument(
        "--env",
        required=True,
        type=str,
        choices=["dev", "staging", "prod"],
        help="Deployment environment",
    )

    # Optional: name of the job task whose task values carry model_version.
    cli.add_argument(
        "--upstream_task_key",
        default="check_model_version",
        type=str,
        help="Databricks Jobs taskKey that produced model_version",
    )

    # Optional: block until the endpoint is usable, bounded by a timeout.
    cli.add_argument(
        "--wait_for_ready",
        action="store_true",
        help="Wait until serving endpoint becomes READY",
    )
    cli.add_argument(
        "--timeout_sec",
        default=900,
        type=int,
        help="Timeout (seconds) to wait for endpoint readiness",
    )

    return cli


parser = _build_arg_parser()
args = parser.parse_args()


In [None]:

# ---------------------------------------------------------
# Spark + DBUtils
# ---------------------------------------------------------
# Obtain the SparkSession via the builder's getOrCreate().
spark = SparkSession.builder.getOrCreate()
# DBUtils handle built from that session; used later for dbutils.jobs.taskValues.
dbutils = DBUtils(spark)

In [None]:

# ---------------------------------------------------------
# Load model_version from upstream task
# ---------------------------------------------------------
logger.info("Fetching model version from upstream task...")

# Per the Databricks Utilities reference, taskValues.get() does not return
# None on failure: without `default` it raises ValueError when the task/key
# cannot be found, and without `debugValue` it raises TypeError when run
# outside of a job. Convert both into a clear log line and a non-zero exit
# instead of an unexplained traceback.
try:
    model_version = dbutils.jobs.taskValues.get(
        taskKey=args.upstream_task_key,
        key="model_version",
    )
except (ValueError, TypeError) as lookup_error:
    logger.error(
        "Could not read model_version from upstream task {}: {}",
        args.upstream_task_key,
        lookup_error,
    )
    sys.exit(1)

# Still reject an empty/falsy value that the upstream task may have set.
if not model_version:
    logger.error("‚ùå No model_version found from upstream task: {}", args.upstream_task_key)
    sys.exit(1)

logger.info("‚úÖ Model version to deploy: {}", model_version)

In [None]:

# ---------------------------------------------------------
# Clean serving endpoint name (Databricks-safe)
# ---------------------------------------------------------
def make_endpoint_name(model_name: str, env: str) -> str:
    """
    Generate a Databricks-compliant serving endpoint name.

    The name has the form ``<model>-serving-<env>``, where ``<model>`` is
    ``model_name`` lower-cased with every ``.`` and ``_`` replaced by ``-``.

    When the full name would exceed 63 characters, only the ``<model>`` part
    is shortened so the ``-serving-<env>`` suffix is preserved. Cutting the
    tail of the whole string instead would drop the environment and make the
    dev/staging/prod names of a long model collide.

    NOTE(review): a second definition of ``make_endpoint_name`` appears in a
    later cell and rebinds this name on a top-to-bottom run; keep the two in
    sync or delete one.

    Args:
        model_name: Model name to derive the endpoint name from.
        env: Deployment environment appended as the final name segment.

    Returns:
        The endpoint name, at most 63 characters long.
    """
    # Databricks max length = 63 chars
    max_length = 63

    clean = model_name.lower().replace(".", "-").replace("_", "-")
    suffix = f"-serving-{env}"
    endpoint = f"{clean}{suffix}"

    if len(endpoint) > max_length:
        logger.warning(
            "Endpoint name too long ({} chars). Truncating to 63.",
            len(endpoint),
        )
        # Shorten only the model part; strip a dangling "-" left by the cut
        # so the result does not contain "--serving".
        keep = max(max_length - len(suffix), 0)
        shortened = clean[:keep].rstrip("-")
        # Final slice keeps the length bound even for an oversized suffix.
        endpoint = f"{shortened}{suffix}"[:max_length]

    return endpoint


# Derive the serving endpoint name from the CLI inputs and record it in the log.
endpoint_name = make_endpoint_name(model_name=args.model_name, env=args.env)
logger.info("Using endpoint name: {}", endpoint_name)

In [None]:

# ---------------------------------------------------------
# Clean serving endpoint name (Databricks-safe)
# ---------------------------------------------------------
def make_endpoint_name(model_name: str, env: str) -> str:
    """
    Generate a Databricks-compliant serving endpoint name.

    The name has the form ``<model>-serving-<env>``, where ``<model>`` is
    ``model_name`` lower-cased with every ``.`` and ``_`` replaced by ``-``.

    When the full name would exceed 63 characters, only the ``<model>`` part
    is shortened so the ``-serving-<env>`` suffix is preserved. Cutting the
    tail of the whole string instead would drop the environment and make the
    dev/staging/prod names of a long model collide.

    NOTE(review): this cell repeats the ``make_endpoint_name`` cell that
    appears earlier in the notebook; being later, this definition is the one
    in effect after a top-to-bottom run. Consider deleting one copy.

    Args:
        model_name: Model name to derive the endpoint name from.
        env: Deployment environment appended as the final name segment.

    Returns:
        The endpoint name, at most 63 characters long.
    """
    # Databricks max length = 63 chars
    max_length = 63

    clean = model_name.lower().replace(".", "-").replace("_", "-")
    suffix = f"-serving-{env}"
    endpoint = f"{clean}{suffix}"

    if len(endpoint) > max_length:
        logger.warning(
            "Endpoint name too long ({} chars). Truncating to 63.",
            len(endpoint),
        )
        # Shorten only the model part; strip a dangling "-" left by the cut
        # so the result does not contain "--serving".
        keep = max(max_length - len(suffix), 0)
        shortened = clean[:keep].rstrip("-")
        # Final slice keeps the length bound even for an oversized suffix.
        endpoint = f"{shortened}{suffix}"[:max_length]

    return endpoint


# Recompute the endpoint name from the CLI inputs and log it.
# NOTE(review): this repeats the assignment made in the earlier, identical cell.
endpoint_name = make_endpoint_name(model_name=args.model_name, env=args.env)
logger.info("Using endpoint name: {}", endpoint_name)

In [None]:

# ---------------------------------------------------------
# Guardrails for production safety
# ---------------------------------------------------------
# Defense-in-depth for production: never deploy without a concrete version.
# NOTE(review): the model_version fetch cell earlier already exits when the
# value is falsy, so this branch is only reachable if that check is removed.
if args.env == "prod":
    if not model_version:
        logger.error("‚ùå Refusing to deploy to PROD without a model_version.")
        sys.exit(1)

# Record the deployment target in the job log.
logger.info("Environment: {}", args.env)
logger.info("Model name: {}", args.model_name)

In [None]:

# ---------------------------------------------------------
# Initialize Serving Manager
# ---------------------------------------------------------
# Bind the serving manager to the UC model name given on the command line and
# to the cleaned endpoint name computed above.
model_serving = ModelServing(endpoint_name=endpoint_name, model_name=args.model_name)

In [None]:

# ---------------------------------------------------------
# Deploy or update endpoint
# ---------------------------------------------------------
logger.info("Starting deployment/update of serving endpoint...")

# Hand the version fetched from the upstream task to the serving manager.
# NOTE(review): presumably this creates the endpoint or updates its config and
# returns once the request is accepted — confirm in ModelServing.
model_serving.deploy_or_update_serving_endpoint(version=model_version)

# Only the API call has returned here; waiting for readiness is a separate,
# optional step controlled by --wait_for_ready.
logger.info("‚úÖ Deployment/update API call completed.")


In [None]:


# ---------------------------------------------------------
# Optional: wait until endpoint is READY (CI-safe)
# ---------------------------------------------------------
if args.wait_for_ready:
    logger.info("Waiting for serving endpoint to become READY...")

    # NOTE(review): assumes model_serving.workspace is a databricks.sdk
    # WorkspaceClient — confirm in ModelServing. The SDK's serving-endpoints
    # waiter is `wait_get_serving_endpoint_not_updating` (there is no
    # `wait_get_serving_endpoint`), and its `timeout` must be a timedelta
    # because the SDK calls `timeout.total_seconds()`; a bare int would fail.
    # Per the SDK docs it raises on timeout or when the config update ends in
    # a failed/cancelled state, so the job fails instead of logging READY.
    model_serving.workspace.serving_endpoints.wait_get_serving_endpoint_not_updating(
        name=endpoint_name,
        timeout=timedelta(seconds=args.timeout_sec),
    )

    logger.info("üöÄ Serving endpoint is READY.")



In [None]:

# ---------------------------------------------------------
# Print serving endpoint URL
# ---------------------------------------------------------
# Read the workspace host from the Spark conf.
# NOTE(review): presumably a bare hostname without scheme, since "https://" is
# prepended below — verify.
workspace_url = spark.conf.get("spark.databricks.workspaceUrl")

# Invocation URL for the endpoint that was just deployed/updated.
serving_url = f"https://{workspace_url}/serving-endpoints/{endpoint_name}/invocations"

# These lines run whether or not --wait_for_ready was passed.
logger.info("üî• Model successfully deployed!")
logger.info("üöÄ Serving Endpoint URL:\n{}", serving_url)