In [0]:
import requests
import json
import mlflow
import time


# --- define and get notebook widgets for configuration ---
# Each spec is (widget key, default value, label shown in the notebook UI).
_widget_specs = (
    ("serving_endpoint_name", "p300", "Name of the Serving Endpoint"),
    ("registered_model_name", "P300-Classifier", "Workspace Registered Model Name"),
)
for _widget_key, _widget_default, _widget_label in _widget_specs:
    dbutils.widgets.text(_widget_key, _widget_default, _widget_label)

# read the (possibly user-overridden) widget values back
serving_endpoint_name = dbutils.widgets.get("serving_endpoint_name")
registered_model_name = dbutils.widgets.get("registered_model_name")

# names of the served entities on the endpoint; the champion entity is the one retargeted
# on promotion. NOTE(review): the challenger name is not used in this notebook's cells —
# presumably kept for symmetry or for another notebook; confirm before removing.
champion_entity_name, challenger_entity_name = "champion", "challenger"

In [0]:
# --- configure API clients ---
# REST calls authenticate with the running notebook's own context (workspace API URL + token).
# The headers dict carries the token: never print or display it.
_notebook_context = dbutils.notebook.entry_point.getDbutils().notebook().getContext()
databricks_host = _notebook_context.apiUrl().get()
databricks_token = _notebook_context.apiToken().get()
headers = {
    "Authorization": "Bearer " + databricks_token,
    "Content-Type": "application/json",
}

# point MLflow at the workspace model registry before constructing the client
mlflow.set_registry_uri("databricks")
mlflow_client = mlflow.tracking.MlflowClient()

In [0]:
# --- step 1: identify challenger and current champion models ---
# Challenger = latest version in 'Staging'; champion = latest version in 'Production' (may not
# exist yet). Both stay module-level: the endpoint-update cell reads `challenger_to_promote`.
challenger_to_promote = None
current_champion = None

staging_versions = mlflow_client.get_latest_versions(name=registered_model_name, stages=["Staging"])
if len(staging_versions) == 0:
    # nothing staged -> end the notebook run with an explanatory exit value
    dbutils.notebook.exit(f"no model version found in 'Staging' for '{registered_model_name}'. nothing to promote.")
challenger_to_promote = staging_versions[0]
print(f"found challenger to promote: Version {challenger_to_promote.version} (from Run ID: {challenger_to_promote.run_id})")

production_versions = mlflow_client.get_latest_versions(name=registered_model_name, stages=["Production"])
if not production_versions:
    print("no model is currently in 'Production' stage. proceeding with first promotion.")
else:
    current_champion = production_versions[0]
    print(f"found current champion: Version {current_champion.version} (from Run ID: {current_champion.run_id})")


# --- step 2: check if promotion is redundant ---
# "Same model" is judged by the originating training run, not by the version number.
if current_champion and current_champion.run_id == challenger_to_promote.run_id:
    dbutils.notebook.exit(f"promotion skipped. the model from run ID '{challenger_to_promote.run_id}' is already the champion.")


# --- step 3: transition the model stage in the MLflow registry ---
# archive_existing_versions=True moves the previous Production version out of that stage.
print(f"promoting version {challenger_to_promote.version} to the 'Production' stage...")
mlflow_client.transition_model_version_stage(
    name=registered_model_name,
    version=challenger_to_promote.version,
    stage="Production",
    archive_existing_versions=True,
)
print("model stage successfully transitioned in the registry.")

In [0]:
# --- step 4: update the serving endpoint to serve the new champion ---
# The registry transition in step 3 has already been committed when this cell runs, so every
# failure below RAISES (failing the job run) instead of calling dbutils.notebook.exit, which
# would end the run as if it had succeeded while the endpoint still serves the old champion.
REQUEST_TIMEOUT_SECONDS = 30  # `requests` waits forever by default; bound every API call


def _raise_for_status_with_body(http_response):
    """Raise requests.HTTPError for a 4xx/5xx response, including the API's error body.

    The Databricks REST API explains failures in the response body, which a plain
    raise_for_status() leaves out of the exception message.
    """
    try:
        http_response.raise_for_status()
    except requests.HTTPError as err:
        raise requests.HTTPError(
            f"{err} | response body: {http_response.text}", response=http_response
        ) from err


print(f"updating serving endpoint '{serving_endpoint_name}' to serve the new champion...")

# get the current endpoint configuration
endpoint_get_url = f"{databricks_host}/api/2.0/serving-endpoints/{serving_endpoint_name}"
response = requests.get(endpoint_get_url, headers=headers, timeout=REQUEST_TIMEOUT_SECONDS)
_raise_for_status_with_body(response)

endpoint_config = response.json()
current_config = endpoint_config.get("config", {})
served_entities = current_config.get("served_entities", [])
traffic_config = current_config.get("traffic_config", {})

# find the champion entity and point it at the promoted version
found_champion_entity = False
for entity in served_entities:
    if entity.get("name") == champion_entity_name:
        # a version number is only meaningful for the model that was promoted above
        if entity.get("entity_name") != registered_model_name:
            raise RuntimeError(
                f"served entity '{champion_entity_name}' serves '{entity.get('entity_name')}', "
                f"not '{registered_model_name}'; refusing to set it to version {challenger_to_promote.version}."
            )
        entity["entity_version"] = challenger_to_promote.version
        found_champion_entity = True
        print(f"champion entity will be updated to serve version {challenger_to_promote.version}.")
        break

if not found_champion_entity:
    served_entity_names = [served.get("name") for served in served_entities]
    raise RuntimeError(
        f"could not find a served entity named '{champion_entity_name}' to update "
        f"(endpoint serves: {served_entity_names}). version {challenger_to_promote.version} "
        f"is already in 'Production' in the registry; the endpoint was NOT updated."
    )

# construct the update payload and send the request.
# NOTE(review): served_entities are echoed back from the GET response; if the config API
# rejects read-only fields it reports (e.g. `state`), strip them here — confirm against the API.
update_payload = {"served_entities": served_entities}
if traffic_config:
    # only resend routing when the endpoint actually has an explicit traffic config
    update_payload["traffic_config"] = traffic_config
endpoint_config_url = f"{databricks_host}/api/2.0/serving-endpoints/{serving_endpoint_name}/config"

update_response = requests.put(
    endpoint_config_url, headers=headers, json=update_payload, timeout=REQUEST_TIMEOUT_SECONDS
)
_raise_for_status_with_body(update_response)

# the config update is applied asynchronously; acceptance is not the same as "serving".
print("endpoint update request accepted. promotion process complete.")