In [0]:
import mlflow
from mlflow.deployments import get_deploy_client
from mlflow.tracking import MlflowClient

import os
import json


# Run the shared `reference` notebook inline. Nothing in this notebook assigns
# ENV_VARS or MODELS_NAME, so they are expected to be defined by that notebook.
# NOTE(review): Databricks documents that `%run` must be alone in its cell; confirm
# this cell behaves as intended or move `%run reference` into a cell of its own.
%run reference
# NOTE(review): this echoes the whole ENV_VARS object into the job output; confirm
# it holds no secrets before leaving it in.
print(ENV_VARS, MODELS_NAME)

In [0]:
# Model versions published as task values by the two upstream registration tasks of this job.
MODEL_VERSION_1 = dbutils.jobs.taskValues.get(
    taskKey="register_model_task_1",
    key="registered_model_version_1",
)
MODEL_VERSION_2 = dbutils.jobs.taskValues.get(
    taskKey="register_model_task_2",
    key="registered_model_version_2",
)

# Workspace host name; used further down to build the endpoint invocation URL.
databricks_host = spark.conf.get("spark.databricks.workspaceUrl")

# Echo the inputs so the job log records exactly which model versions are being deployed.
print(f"Received registered_model_name_1: {MODELS_NAME['MODEL1']}")
print(f"Received registered_model_version_1: {MODEL_VERSION_1}")
print(f"Received registered_model_name_2: {MODELS_NAME['MODEL2']}")
print(f"Received registered_model_version_2: {MODEL_VERSION_2}")

In [0]:
def _served_entity(entity_name, entity_version):
    """Build one ``served_entities`` entry for a registered model version.

    Workload size and the scale-to-zero flag come from the notebook-global ENV_VARS,
    so every served entity on the endpoint shares the same settings.
    """
    return {
        "entity_name": entity_name,
        "entity_version": entity_version,
        "workload_size": ENV_VARS['WORKLOAD_SIZE'],
        # The flag is stored as text: only the exact string "True" enables scale-to-zero.
        "scale_to_zero_enabled": ENV_VARS['SCALE_TO_ZERO_ENABLED'] == "True",
    }


def _is_missing_endpoint_error(error):
    """Return True if ``error`` carries an HTTP response saying the endpoint does not exist.

    Errors from the deployment client are expected to expose a requests-style
    ``response`` (with ``status_code`` and ``json()``). Anything else, including a
    response whose body is not JSON, is treated as "not a missing-endpoint error".
    """
    response = getattr(error, "response", None)
    if response is None or not hasattr(response, "json"):
        return False
    if getattr(response, "status_code", None) == 404:
        return True
    try:
        body = response.json()
    except ValueError:
        # Non-JSON error body (e.g. a gateway HTML page): cannot classify it as missing.
        return False
    return isinstance(body, dict) and body.get("error_code") == "RESOURCE_DOES_NOT_EXIST"


def create_or_update_endpoint(model_name_1, model_version_1, model_name_2, model_version_2, endpoint_name,
                              inference_schema="click_api"):
    """Point serving endpoint ``endpoint_name`` at two registered model versions.

    The endpoint's config is updated in place first. Only if that update fails because
    the endpoint does not exist yet is the inference-table schema created and the
    endpoint created with the same config. Any other update failure is printed and
    re-raised instead of being masked by a create attempt.

    Traffic is split between the two served entities using
    ENV_VARS['MODEL1_TRAFFIC_PERCENTAGE'] and ENV_VARS['MODEL2_TRAFFIC_PERCENTAGE'].
    Auto-capture (inference tables) writes to ENV_VARS['CATALOG'].<inference_schema>
    using ``endpoint_name`` as the table-name prefix.

    Args:
        model_name_1: Name of the first registered model, used as-is as the entity name.
        model_version_1: Version of the first model to serve.
        model_name_2: Short name of the second model; qualified as CATALOG.SCHEMA.<name>.
        model_version_2: Version of the second model to serve.
        endpoint_name: Name of the serving endpoint to create or update.
        inference_schema: Schema inside ENV_VARS['CATALOG'] receiving the auto-capture
            tables. Defaults to "click_api".

    Raises:
        Exception: Any update failure other than "endpoint does not exist" (printed,
            then re-raised unchanged). Errors from the create path propagate as-is.
        AssertionError: If the newly created endpoint's pending config does not list
            the served entities under the expected "<model>-<version>" names.
    """
    print(f"Creating serving endpoint for {model_name_1}, {model_name_2}...")
    client = get_deploy_client("databricks")

    # No explicit served-entity "name" is set. The traffic routes and the post-create
    # assertions assume the API names entities "<model name>-<version>".
    served_entities = [
        _served_entity(model_name_1, model_version_1),
        # NOTE(review): only the second model is qualified with CATALOG.SCHEMA; confirm that
        # MODELS_NAME['MODEL1'] is already the full name the registry expects.
        _served_entity(f"{ENV_VARS['CATALOG']}.{ENV_VARS['SCHEMA']}.{model_name_2}", model_version_2),
    ]
    auto_capture_config = {
        "catalog_name": ENV_VARS['CATALOG'],
        "schema_name": inference_schema,
        "table_name_prefix": endpoint_name,
        "enabled": True,
    }
    traffic_config = {
        "routes": [
            {
                "served_model_name": f"{model_name_1}-{model_version_1}",
                "traffic_percentage": ENV_VARS['MODEL1_TRAFFIC_PERCENTAGE'],
            },
            {
                "served_model_name": f"{model_name_2}-{model_version_2}",
                "traffic_percentage": ENV_VARS['MODEL2_TRAFFIC_PERCENTAGE'],
            },
        ]
    }
    endpoint_config = {
        "served_entities": served_entities,
        "auto_capture_config": auto_capture_config,
        "traffic_config": traffic_config,
    }

    try:
        # Update first: after the initial deployment the endpoint normally exists.
        client.update_endpoint(endpoint=endpoint_name, config=endpoint_config)
    except Exception as e:
        if not _is_missing_endpoint_error(e):
            # Conflicts (e.g. an update already in progress), invalid configs, auth errors...
            # are genuine failures; surface them instead of trying to create a duplicate.
            print(f"ERROR: {e}")
            raise
    else:
        print(f"Serving endpoint {endpoint_name} updated successfully using version {model_version_1}, {model_version_2} of ML model [{model_name_1}, {model_name_2}]")
        return

    # The endpoint does not exist yet. Auto-capture writes into CATALOG.<inference_schema>,
    # so make sure that schema exists before creating the endpoint.
    spark.sql(f"CREATE SCHEMA IF NOT EXISTS {ENV_VARS['CATALOG']}.{inference_schema}")
    endpoint = client.create_endpoint(name=endpoint_name, config=endpoint_config)

    # Ensure that the endpoint is being created using the registered model versions.
    assert endpoint["pending_config"]["served_entities"][0]["name"] == f"{model_name_1}-{model_version_1}"
    assert endpoint["pending_config"]["served_entities"][1]["name"] == f"{model_name_2}-{model_version_2}"
    print(f"Serving endpoint {endpoint_name} made successfully for combined ML Models {model_name_1}-{model_version_1}, {model_name_2}-{model_version_2}")


# Deploy both registered model versions behind the configured serving endpoint.
create_or_update_endpoint(
    model_name_1=MODELS_NAME['MODEL1'],
    model_version_1=MODEL_VERSION_1,
    model_name_2=MODELS_NAME['MODEL2'],
    model_version_2=MODEL_VERSION_2,
    endpoint_name=ENV_VARS['ENDPOINT_NAME'],
)

# Invocation URL of the endpoint, echoed so it appears in the job output.
model_url = f"https://{databricks_host}/serving-endpoints/{ENV_VARS['ENDPOINT_NAME']}/invocations"
print(model_url)
