# In [None]:
import json
import os
import time
import traceback

import mlflow
import yaml
from databricks.sdk import WorkspaceClient
from databricks.sdk.errors import NotFound
from databricks.sdk.service.serving import (
    EndpointCoreConfigInput,
    Route,
    ServedEntityInput,
    TrafficConfig,
)
from mlflow.tracking import MlflowClient

print("=" * 80)
print("üöÄ MULTI-MODEL SERVING (GENERAL VARIABLE CONTROLLED)")
print("=" * 80)

# --------------------------------------------------
# 1. LOAD PIPELINE CONFIG
# --------------------------------------------------
def _load_yaml_config(path: str) -> dict:
    """
    Load a YAML config file and return its top-level mapping.

    Prints a success marker on success. On any failure the traceback is
    printed and the original exception is re-raised unchanged.

    Args:
        path: Path of the YAML file to read.

    Returns:
        The parsed top-level mapping.

    Raises:
        OSError: if the file cannot be opened.
        yaml.YAMLError: if the file is not valid YAML.
        ValueError: if the document is empty or its top level is not a
            mapping (all later cells index the result like a dict).
    """
    try:
        with open(path, "r", encoding="utf-8") as f:
            loaded = yaml.safe_load(f)
        # safe_load returns None for an empty file; catch that here instead of
        # failing later with an opaque "'NoneType' is not subscriptable".
        if not isinstance(loaded, dict):
            raise ValueError(
                f"{path} must contain a YAML mapping at the top level, "
                f"got {type(loaded).__name__}"
            )
    except Exception:
        print(f"‚ùå Failed to load {path}")
        traceback.print_exc()
        raise
    print(f"‚úÖ {path} loaded")
    return loaded


pipeline_cfg = _load_yaml_config("pipeline_config.yml")
experiments_cfg = _load_yaml_config("experiments_config.yml")


# --------------------------------------------------
# 2. SERVING CONFIG AS A PLAIN PYTHON VARIABLE (NO GIT/ENV/WIDGET)
# --------------------------------------------------
# One entry per model type to serve, keyed by a model type that must exist in
# experiments_config.yml. Fields:
#   version - registered model version to serve
#   traffic - percentage of requests routed to it; active entries must sum to 100
#   active  - False leaves the entry here but excludes it from the endpoint
# An optional "model_name" may be given; when absent it is derived from
# pipeline_config.yml during validation.
# To serve another model, add e.g.
#   "xgboost": dict(version="5", traffic=40, active=True),
# and rebalance the traffic values so the active ones still total 100.
MODEL_SERVING_CONFIG = {
    "random_forest": dict(version="2", traffic=100, active=True),
}

print(
    "\n‚úÖ MODEL_SERVING_CONFIG loaded from general variable:",
    json.dumps(MODEL_SERVING_CONFIG, indent=2),
    sep="\n",
)


# --------------------------------------------------
# 3Ô∏è‚É£ VALIDATE AND AUTO-COMPLETE MODEL NAMES
# --------------------------------------------------
def get_full_model_name(model_type: str) -> str:
    """
    Generate full Unity Catalog model name from model_type
    """
    UC_CATALOG = pipeline_cfg["models"]["catalog"]
    UC_SCHEMA = pipeline_cfg["models"]["schema"]
    BASE_NAME = pipeline_cfg["models"]["base_name"]
    NAMING_FMT = pipeline_cfg["models"]["naming"]["format"]

    return NAMING_FMT.format(
        catalog=UC_CATALOG,
        schema=UC_SCHEMA,
        base_name=BASE_NAME,
        model_type=model_type
    )

# Model types declared in experiments_config.yml; every served model must be one of them.
available_models = list(experiments_cfg.get("models", {}).keys())

for model_type, cfg in MODEL_SERVING_CONFIG.items():
    # Reject model types the experiments config does not know about.
    if model_type not in available_models:
        raise ValueError(
            f"‚ùå Invalid model_type '{model_type}' in MODEL_SERVING_CONFIG\n"
            f"   Available models: {available_models}"
        )

    # Missing or empty model_name: derive it from pipeline_config.yml.
    if not cfg.get("model_name"):
        cfg["model_name"] = get_full_model_name(model_type)
        print(f"   ‚ÑπÔ∏è  Auto-generated model_name for '{model_type}': {cfg['model_name']}")

    # Fewer than three dot-separated parts is not a catalog.schema.model path,
    # so the name is replaced by the derived one.
    if len(cfg["model_name"].split(".")) < 3:
        cfg["model_name"] = get_full_model_name(model_type)
        print(f"   ‚ÑπÔ∏è  Converted to full UC path: {cfg['model_name']}")

    # Report the first required key (in this order) the entry lacks.
    missing_field = next(
        (key for key in ("version", "traffic", "active") if key not in cfg),
        None,
    )
    if missing_field is not None:
        raise ValueError(f"‚ùå Missing '{missing_field}' for model '{model_type}'")

print("\n‚úÖ Configuration validated and completed")


# --------------------------------------------------
# 4. INIT CLIENTS
# --------------------------------------------------
# Point MLflow tracking at the Databricks workspace and the model registry at
# Unity Catalog before building the client that queries the registry.
mlflow.set_tracking_uri("databricks")
mlflow.set_registry_uri("databricks-uc")
client = MlflowClient()

# Databricks SDK client, used below to manage serving endpoints.
w = WorkspaceClient()


# --------------------------------------------------
# 5. VERIFY MODELS EXIST IN REGISTRY
# --------------------------------------------------
# For every active entry, confirm the configured version is registered under
# the model name; abort with a specific message otherwise.
print("\nüîç Verifying models in Unity Catalog...")

for model_type, cfg in MODEL_SERVING_CONFIG.items():

    if not cfg.get("active", False):
        print(f"   ‚è≠Ô∏è  Skipping inactive: {model_type}")
        continue

    model_name = cfg["model_name"]
    version = str(cfg["version"])

    # Guard only the registry call. The version check below must not be inside
    # this try, otherwise its own ValueError would be re-wrapped as
    # "model not found".
    try:
        model_versions = list(client.search_model_versions(f"name='{model_name}'"))
    except Exception as e:
        raise ValueError(
            f"‚ùå Model not found in registry: {model_name}\n"
            f"   Error: {e}\n"
            f"   Make sure the model is registered first!"
        ) from e

    # No versions at all means the model is not registered (or has no versions).
    if not model_versions:
        raise ValueError(
            f"‚ùå Model not found in registry: {model_name}\n"
            f"   Make sure the model is registered first!"
        )

    # Normalise to str so the comparison does not depend on the version's type.
    available_versions = [str(mv.version) for mv in model_versions]
    if version not in available_versions:
        raise ValueError(
            f"‚ùå Version {version} not found for {model_name}\n"
            f"   Available versions: {available_versions}"
        )

    print(f"   ‚úÖ Found: {model_name} v{version}")


# --------------------------------------------------
# 6. BUILD SERVED ENTITIES + TRAFFIC ROUTES
# --------------------------------------------------
# One ServedEntityInput and one Route per active model; the route refers to
# the served entity by the same generated name.
served_entities = []
traffic_config_routes = []

print("\nüìä Building traffic distribution...")

active_models = [
    (mtype, entry)
    for mtype, entry in MODEL_SERVING_CONFIG.items()
    if entry.get("active", False)
]

for model_type, cfg in active_models:
    entity_label = f"{model_type}-v{cfg['version']}"  # must be unique per endpoint

    served_entities.append(
        ServedEntityInput(
            name=entity_label,
            entity_name=cfg["model_name"],  # UC model name
            entity_version=str(cfg["version"]),
            workload_size=pipeline_cfg["serving"]["workload_size"],
            scale_to_zero_enabled=pipeline_cfg["serving"]["scale_to_zero_enabled"],
        )
    )
    traffic_config_routes.append(
        Route(served_model_name=entity_label, traffic_percentage=int(cfg["traffic"]))
    )

    print(f"   ‚Ä¢ {cfg['model_name']} v{cfg['version']} ‚Üí {cfg['traffic']}%")

if not served_entities:
    raise ValueError("‚ùå No active models found in MODEL_SERVING_CONFIG")

# The split across active models must be exactly 100%.
total_traffic = sum(int(entry["traffic"]) for _, entry in active_models)
if total_traffic != 100:
    raise ValueError(
        f"‚ùå Traffic percentages must add up to 100%\n"
        f"   Current total: {total_traffic}%"
    )

traffic_config = TrafficConfig(routes=traffic_config_routes)


# --------------------------------------------------
# 7. ENDPOINT NAME (SINGLE PROD ENDPOINT)
# --------------------------------------------------
# All active models are served from one endpoint named "<base_name>-prod".
BASE_NAME = pipeline_cfg["models"]["base_name"]
ENDPOINT_NAME = "{}-prod".format(BASE_NAME)

print(f"\nüöÄ Serving Endpoint: {ENDPOINT_NAME}")


# --------------------------------------------------
# 8. CREATE OR UPDATE ENDPOINT
# --------------------------------------------------
def endpoint_exists(name: str) -> bool:
    """
    Return True if a serving endpoint named ``name`` exists.

    Only the SDK's ``NotFound`` error is taken to mean "does not exist". Any
    other failure (authentication, permissions, network) propagates instead of
    being misread as a missing endpoint and followed by a create call.
    """
    try:
        w.serving_endpoints.get(name=name)
    except NotFound:
        return False
    return True


# Update the endpoint in place if it exists, otherwise create it. The return
# values are not waited on here; readiness is polled in section 9 below.
try:
    if endpoint_exists(ENDPOINT_NAME):
        print("üîÑ Updating existing endpoint...")
        w.serving_endpoints.update_config(
            name=ENDPOINT_NAME,
            served_entities=served_entities,
            traffic_config=traffic_config,
        )
    else:
        print("‚ûï Creating new endpoint...")
        w.serving_endpoints.create(
            name=ENDPOINT_NAME,
            config=EndpointCoreConfigInput(
                served_entities=served_entities,
                traffic_config=traffic_config,
            ),
        )
except Exception:
    print("‚ùå Failed to deploy serving endpoint")
    traceback.print_exc()
    raise


# --------------------------------------------------
# 9. WAIT UNTIL READY
# --------------------------------------------------
def _state_value(member):
    """Return ``member.value`` for an SDK enum member; anything else (incl. None) passes through."""
    return getattr(member, "value", member)


print("\n‚è≥ Waiting for endpoint to be READY...")
start = time.monotonic()  # monotonic: elapsed time is immune to wall-clock jumps
timeout = pipeline_cfg["serving"]["deployment_timeout"]
poll_interval = pipeline_cfg["serving"]["status_check_interval"]

while time.monotonic() - start < timeout:
    ep = w.serving_endpoints.get(name=ENDPOINT_NAME)

    # Compare exact enum values: a substring test on str(ep.state) would also
    # match "NOT_READY" and exit while the endpoint is still down.
    ready = _state_value(getattr(ep.state, "ready", None))
    config_update = _state_value(getattr(ep.state, "config_update", None))

    if ready == "READY" and config_update == "NOT_UPDATING":
        print("‚úÖ Endpoint is READY")
        break

    # Treat a failed/cancelled config update as the final outcome of this
    # deployment instead of polling until the timeout expires.
    if config_update in ("UPDATE_FAILED", "UPDATE_CANCELED"):
        raise RuntimeError(
            f"Endpoint {ENDPOINT_NAME} config update finished with state {config_update}"
        )

    time.sleep(poll_interval)
else:
    raise TimeoutError("‚ùå Endpoint did not become READY in time")


# --------------------------------------------------
# 10. SUCCESS SUMMARY
# --------------------------------------------------
# Assemble the whole report, then print it in one call; output is the
# endpoint name followed by one line per traffic route.
_rule = "=" * 80
summary_lines = [
    "",
    _rule,
    "üéâ SERVING DEPLOYMENT SUCCESSFUL",
    _rule,
    f"üìç Endpoint Name: {ENDPOINT_NAME}",
    "\nüìä Traffic Distribution:",
]
summary_lines.extend(
    f"   ‚Ä¢ {r.served_model_name} ‚Üí {r.traffic_percentage}%"
    for r in traffic_config_routes
)
summary_lines.append(_rule)
print("\n".join(summary_lines))
