In [0]:
import requests
import json
import mlflow


# --- fixed names of the two served entities behind the endpoint ---
champion_entity_name = "champion"
challenger_entity_name = "challenger"

# --- declare the notebook widgets: (widget name, default value, label) ---
for _widget_name, _default_value, _label in (
    ("serving_endpoint_name", "p300", "Name of the Serving Endpoint"),
    ("registered_model_name", "P300-Classifier", "Workspace Registered Model Name"),
):
    dbutils.widgets.text(_widget_name, _default_value, _label)

# --- read the widget values selected for this run ---
serving_endpoint_name = dbutils.widgets.get("serving_endpoint_name")
registered_model_name = dbutils.widgets.get("registered_model_name")

In [0]:
# --- REST credentials taken from the running notebook's context ---
notebook_context = dbutils.notebook.entry_point.getDbutils().notebook().getContext()
databricks_host = notebook_context.apiUrl().get()
databricks_token = notebook_context.apiToken().get()
headers = {
    "Authorization": f"Bearer {databricks_token}",
    "Content-Type": "application/json",
}

# MLflow client pointed at the "databricks" registry URI
mlflow.set_registry_uri("databricks")
mlflow_client = mlflow.tracking.MlflowClient()

# --- look up the model version currently in the 'Staging' stage ---
new_challenger_version_info = None
staging_versions = mlflow_client.get_latest_versions(
    name=registered_model_name,
    stages=["Staging"],
)
if not staging_versions:
    # nothing to deploy: leave the notebook with an explanatory message
    dbutils.notebook.exit(f"no model version found in 'Staging' for '{registered_model_name}'.")

# the first returned entry becomes the new challenger candidate
new_challenger_version_info = staging_versions[0]
new_challenger_version_num = new_challenger_version_info.version
display(f"found latest 'Staging' model: Version {new_challenger_version_num}")

In [0]:
    # NOTE(review): every line of this cell carries a 4-space indent; this presumably
    # relies on the notebook kernel stripping a uniform leading indent — confirm,
    # since plain Python rejects an indented top-level statement.

    # --- payload for first-time creation of the serving endpoint ---
    # Both served entities are seeded with the same Staging version.
    seeded_entities = [
        {
            "name": served_name,
            "entity_name": registered_model_name,
            "entity_version": new_challenger_version_num,
            "workload_size": "Small",
            "scale_to_zero_enabled": True,
        }
        for served_name in (champion_entity_name, challenger_entity_name)
    ]

    # champion receives all traffic, challenger none
    seeded_routes = [
        {"served_entity_name": champion_entity_name, "traffic_percentage": 100},
        {"served_entity_name": challenger_entity_name, "traffic_percentage": 0},
    ]

    endpoint_creation_payload = {
        "name": serving_endpoint_name,
        "config": {
            "served_entities": seeded_entities,
            "traffic_config": {"routes": seeded_routes},
        },
    }

In [0]:
# --- check if the serving endpoint already exists ---
# 200 -> endpoint exists (update path), 404 -> endpoint missing (create path).
endpoint_url = f"{databricks_host}/api/2.0/serving-endpoints/{serving_endpoint_name}"
# the timeout bounds the wait so a stalled connection cannot hang the job indefinitely
response = requests.get(endpoint_url, headers=headers, timeout=60)

# Any other status (auth failure, throttling, server error, ...) matches neither the
# create branch nor the update branch, so without this guard the notebook would
# finish without creating or updating anything.
if response.status_code not in (200, 404):
    raise RuntimeError(
        f"unexpected status {response.status_code} while looking up serving endpoint "
        f"'{serving_endpoint_name}': {response.text}"
    )

In [0]:
# --- case 1: endpoint does not exist. create it for the first time. ---
# Runs only when the existence check returned 404; the request body is the
# creation payload built earlier in the notebook.
if response.status_code == 404:
    print(f"serving endpoint '{serving_endpoint_name}' not found. creating new endpoint.")
    print(f"seeding endpoint: '{champion_entity_name}' and '{challenger_entity_name}' will both serve version {new_challenger_version_info.version}.")

    creation_url = f"{databricks_host}/api/2.0/serving-endpoints"
    # the timeout bounds the wait so a stalled connection cannot hang the job indefinitely
    creation_response = requests.post(
        creation_url,
        headers=headers,
        data=json.dumps(endpoint_creation_payload),
        timeout=60,
    )
    # fail the notebook run on any 4xx/5xx answer
    creation_response.raise_for_status()
    print("endpoint creation request accepted. deployment is in progress.")

In [0]:
if response.status_code == 200:
    # --- case 2: endpoint exists. update it. ---
    # Repoints the challenger served entity at the latest Staging version and
    # sends the whole config back; all other entities and the traffic routes
    # are left as returned by the GET.
    print(f"serving endpoint '{serving_endpoint_name}' found. proceeding with update.")

    endpoint_config = response.json()
    current_config = endpoint_config.get("config", {})
    served_entities = current_config.get("served_entities", [])
    traffic_config = current_config.get("traffic_config", {})
    found_challenger = False

    # modify the challenger entity to point to the new version
    for entity in served_entities:
        if entity.get('name') != challenger_entity_name:
            continue
        found_challenger = True

        # get details of the currently served version to find its run_id
        # (a failing registry lookup propagates and stops the notebook)
        current_served_version_num = entity.get('entity_version')
        current_served_version_info = mlflow_client.get_model_version(name=registered_model_name, version=current_served_version_num)
        current_run_id = current_served_version_info.run_id
        print(f"found challenger entity. currently serving version: {current_served_version_num} (from run ID: {current_run_id}).")

        new_challenger_run_id = new_challenger_version_info.run_id
        if current_run_id == new_challenger_run_id:
            dbutils.notebook.exit(f"challenger is already serving the model from the latest staging run ({new_challenger_run_id}). no update needed.")

        # run IDs differ: repoint the challenger at the new Staging version
        new_challenger_version_num = new_challenger_version_info.version
        print(f"updating challenger to serve version: {new_challenger_version_num}")
        entity['entity_version'] = new_challenger_version_num
        break

    # Without a challenger entity nothing was modified; sending the unchanged
    # config back would be a pointless update, so fail loudly instead.
    if not found_challenger:
        raise RuntimeError(
            f"serving endpoint '{serving_endpoint_name}' has no served entity named "
            f"'{challenger_entity_name}'; nothing to update."
        )

    # construct the update payload
    # NOTE(review): served_entities is sent back exactly as returned by the GET,
    # including any server-populated fields — confirm the config API accepts them.
    update_payload = {
        "served_entities": served_entities,
        "traffic_config": traffic_config
    }
    update_url = f"{databricks_host}/api/2.0/serving-endpoints/{serving_endpoint_name}/config"
    # the timeout bounds the wait so a stalled connection cannot hang the job indefinitely
    update_response = requests.put(
        update_url,
        headers=headers,
        data=json.dumps(update_payload),
        timeout=60,
    )

    if update_response.status_code == 200:
        print("endpoint update request accepted successfully. update is in progress.")
    else:
        # raise instead of only printing, so a rejected update fails the run
        # just like a rejected creation request does
        raise RuntimeError(
            f"error updating endpoint. Status: {update_response.status_code}, Response: {update_response.text}"
        )