In [None]:
# 🚀 MULTI-MODEL SERVING ENDPOINT DEPLOYMENT (DRIVEN BY THE MODEL_SERVING_CONFIG GIT VARIABLE)

import json
import os
import re
import time
import traceback

import mlflow
import yaml
from databricks.sdk import WorkspaceClient
from databricks.sdk.errors import NotFound
from databricks.sdk.service.serving import (
    EndpointCoreConfigInput,
    Route,
    ServedEntityInput,
    TrafficConfig,
)
from mlflow.tracking import MlflowClient

# Title banner: a rule line, the title, and a closing rule line.
print("=" * 80, "üöÄ MULTI-MODEL SERVING (GIT VARIABLE CONTROLLED)", "=" * 80, sep="\n")

# --------------------------------------------------
# 1. LOAD PIPELINE CONFIG
# --------------------------------------------------
# The rest of this script reads pipeline_cfg["serving"] (workload_size,
# scale_to_zero_enabled, deployment_timeout, status_check_interval) and
# pipeline_cfg["models"]["base_name"], so both sections are required.
_REQUIRED_PIPELINE_SECTIONS = ("serving", "models")

try:
    with open("pipeline_config.yml", "r", encoding="utf-8") as f:
        pipeline_cfg = yaml.safe_load(f)

    # safe_load returns None for an empty file and a scalar/list for other
    # documents; fail here with a clear message instead of a TypeError later.
    if not isinstance(pipeline_cfg, dict):
        raise ValueError(
            "pipeline_config.yml must contain a mapping, "
            f"got {type(pipeline_cfg).__name__}"
        )
    _missing_sections = [s for s in _REQUIRED_PIPELINE_SECTIONS if s not in pipeline_cfg]
    if _missing_sections:
        raise ValueError(
            f"pipeline_config.yml is missing required section(s): {_missing_sections}"
        )

    print("‚úÖ pipeline_config.yml loaded")
except Exception:
    print("‚ùå Failed to load pipeline_config.yml")
    traceback.print_exc()
    # Bare raise re-raises the active exception without adding a frame.
    raise

# --------------------------------------------------
# 2. LOAD MODEL SERVING CONFIG FROM GIT VARIABLE
# --------------------------------------------------
# MODEL_SERVING_CONFIG must be a JSON object. Each value is later read for the
# keys "active", "model_name", "version" and "traffic"; the top-level keys are
# presumably free-form model labels (e.g. "champion") — confirm with pipeline owners.
MODEL_SERVING_CONFIG_RAW = os.getenv("MODEL_SERVING_CONFIG")

# A whitespace-only value is treated exactly like an unset variable.
if not MODEL_SERVING_CONFIG_RAW or not MODEL_SERVING_CONFIG_RAW.strip():
    raise ValueError("‚ùå MODEL_SERVING_CONFIG Git variable not set")

try:
    MODEL_SERVING_CONFIG = json.loads(MODEL_SERVING_CONFIG_RAW)
except json.JSONDecodeError as err:
    raise ValueError(f"MODEL_SERVING_CONFIG is not valid JSON: {err}") from err

# The build step iterates MODEL_SERVING_CONFIG.items(), so a JSON array or
# scalar would fail there with an unhelpful AttributeError.
if not isinstance(MODEL_SERVING_CONFIG, dict):
    raise ValueError(
        "MODEL_SERVING_CONFIG must be a JSON object mapping model keys to settings, "
        f"got {type(MODEL_SERVING_CONFIG).__name__}"
    )

print("\nüì¶ MODEL_SERVING_CONFIG:")
print(json.dumps(MODEL_SERVING_CONFIG, indent=2))

# --------------------------------------------------
# 3. INIT CLIENTS
# --------------------------------------------------
# Databricks SDK client used for every serving-endpoint API call below.
w = WorkspaceClient()

# Point MLflow at the workspace tracking server and at the Unity Catalog
# ("databricks-uc") model registry before any MLflow client is built.
mlflow.set_tracking_uri("databricks")
mlflow.set_registry_uri("databricks-uc")

# Passing None for both URIs is the same as the constructor defaults: the
# client picks up the URIs configured just above.
# NOTE(review): `client` is not referenced elsewhere in this cell; kept in case
# later cells rely on it.
client = MlflowClient(tracking_uri=None, registry_uri=None)

# --------------------------------------------------
# 4. BUILD SERVED ENTITIES AND TRAFFIC SPLIT
# --------------------------------------------------
# ServedEntityInput has no traffic-percentage field. Model Serving takes the
# split as a separate TrafficConfig whose routes reference each served entity
# by its endpoint-unique name, so entities and routes are built in lockstep.


def _served_entity_name(model_key):
    """Turn a MODEL_SERVING_CONFIG key into a served-entity name.

    Served-entity names may only contain letters, digits, '-' and '_'; every
    other character is replaced with '-'.
    """
    return re.sub(r"[^A-Za-z0-9_-]", "-", str(model_key))


served_entities = []
traffic_routes = []

for model_key, cfg in MODEL_SERVING_CONFIG.items():

    if not isinstance(cfg, dict):
        raise ValueError(
            f"Config for model '{model_key}' must be a JSON object, got {type(cfg).__name__}"
        )

    if not cfg.get("active", False):
        print(f"‚è≠Ô∏è Skipping inactive model: {model_key}")
        continue

    traffic = int(cfg["traffic"])
    if not 0 <= traffic <= 100:
        raise ValueError(
            f"Traffic for model '{model_key}' must be between 0 and 100, got {traffic}"
        )

    served_name = _served_entity_name(model_key)
    served_entities.append(
        ServedEntityInput(
            name=served_name,
            entity_name=cfg["model_name"],
            entity_version=str(cfg["version"]),
            workload_size=pipeline_cfg["serving"]["workload_size"],
            scale_to_zero_enabled=pipeline_cfg["serving"]["scale_to_zero_enabled"],
        )
    )
    traffic_routes.append(
        Route(served_model_name=served_name, traffic_percentage=traffic)
    )

if not served_entities:
    raise ValueError("‚ùå No active models found in MODEL_SERVING_CONFIG")

# Keys that differ only in characters replaced by _served_entity_name would
# collide and make the routes ambiguous.
served_names = [entity.name for entity in served_entities]
if len(set(served_names)) != len(served_names):
    raise ValueError(f"Served entity names must be unique, got {served_names}")

# Model Serving rejects a traffic config whose routes do not total 100%.
total_traffic = sum(route.traffic_percentage for route in traffic_routes)
if total_traffic != 100:
    raise ValueError(
        f"Traffic of active models must sum to 100, got {total_traffic}"
    )

traffic_config = TrafficConfig(routes=traffic_routes)

# --------------------------------------------------
# 5. ENDPOINT NAME (SINGLE PROD ENDPOINT)
# --------------------------------------------------
BASE_NAME = pipeline_cfg["models"]["base_name"]
ENDPOINT_NAME = f"{BASE_NAME}-prod"

print(f"\nüöÄ Serving Endpoint: {ENDPOINT_NAME}")

# --------------------------------------------------
# 6. CREATE OR UPDATE ENDPOINT
# --------------------------------------------------


def endpoint_exists(name):
    """Return True if a serving endpoint called *name* exists in the workspace.

    Only the SDK's NotFound error means "does not exist". Any other failure
    (authentication, permissions, network) propagates, so the script does not
    fall through to a create call that would then fail with a misleading error.
    """
    try:
        w.serving_endpoints.get(name=name)
    except NotFound:
        return False
    return True


try:
    if endpoint_exists(ENDPOINT_NAME):
        print("üîÑ Updating existing endpoint...")
        w.serving_endpoints.update_config(
            name=ENDPOINT_NAME,
            served_entities=served_entities,
            traffic_config=traffic_config,
        )
    else:
        print("‚ûï Creating new endpoint...")
        w.serving_endpoints.create(
            name=ENDPOINT_NAME,
            config=EndpointCoreConfigInput(
                served_entities=served_entities,
                traffic_config=traffic_config,
            ),
        )
except Exception:
    print("‚ùå Failed to deploy serving endpoint")
    traceback.print_exc()
    raise

# --------------------------------------------------
# 7. WAIT UNTIL READY
# --------------------------------------------------
# Terminal config-update states that will never become NOT_UPDATING on their own.
_FAILED_CONFIG_UPDATES = {"UPDATE_FAILED", "UPDATE_CANCELED"}


def _enum_value(member):
    """Return the raw value of an SDK enum member, or *member* itself if it has none."""
    return getattr(member, "value", member)


print("\n‚è≥ Waiting for endpoint to be READY...")
# monotonic() is immune to wall-clock adjustments during the wait.
start = time.monotonic()
timeout = pipeline_cfg["serving"]["deployment_timeout"]
poll_interval = pipeline_cfg["serving"]["status_check_interval"]

while time.monotonic() - start < timeout:
    ep = w.serving_endpoints.get(name=ENDPOINT_NAME)
    ready = _enum_value(ep.state.ready) if ep.state else None
    config_update = _enum_value(ep.state.config_update) if ep.state else None

    # Exact comparison: a substring test on str(ep.state) also matches "NOT_READY".
    if ready == "READY" and config_update == "NOT_UPDATING":
        print("‚úÖ Endpoint is READY")
        break

    if config_update in _FAILED_CONFIG_UPDATES:
        raise RuntimeError(
            f"Endpoint {ENDPOINT_NAME} config update ended in {config_update} (state: {ep.state})"
        )

    time.sleep(poll_interval)
else:
    raise TimeoutError("‚ùå Endpoint did not become READY in time")

# --------------------------------------------------
# 8. SUCCESS
# --------------------------------------------------
print("\nüéâ SERVING DEPLOYMENT SUCCESSFUL")
print("üìä Traffic Distribution:")
# served_entities and traffic_routes were appended together, so they pair up 1:1.
for s, route in zip(served_entities, traffic_routes):
    print(f"   ‚Ä¢ {s.entity_name} v{s.entity_version} ‚Üí {route.traffic_percentage}%")

