In [0]:
pip install snowflake 

In [0]:
pip install mlflow

In [0]:
import mlflow
from mlflow.tracking import MlflowClient
import os
import snowflake.connector

In [0]:
# Set to Databricks MLflow
# Runs/metrics are read from the Databricks workspace tracking server, and the
# model registry is Unity Catalog ("databricks-uc"), which is why the model
# name configured below is a three-part <catalog>.<schema>.<model> name.
mlflow.set_tracking_uri("databricks")
mlflow.set_registry_uri("databricks-uc")

In [0]:
# Environment settings
# Fully qualified Unity Catalog model name: <catalog>.<schema>.<model>.
model_name = "workspace.default.CreditCardFraudModel"
# Metrics compared between challenger and champion. `better_than` treats a
# strictly higher value as a win for every entry in this list.
METRICS_TO_COMPARE = [
    'Accuracy',
    'Precision',
    'Recall',
    'F1 Score',
    'Matthews Corrcoef',
]


In [0]:
# Snowflake details

# Load connection settings from Databricks notebook widgets (the notebook/job
# parameters) -- not from environment variables.
# NOTE(review): SNOWFLAKE_PASSWORD arrives as a plain widget value, which is
# visible in job configuration and run history; prefer
# dbutils.secrets.get(scope=..., key=...) once a secret scope is available.

account = dbutils.widgets.get("SNOWFLAKE_ACCOUNT")
user = dbutils.widgets.get("SNOWFLAKE_USER")
password = dbutils.widgets.get("SNOWFLAKE_PASSWORD")
warehouse = dbutils.widgets.get("SNOWFLAKE_WAREHOUSE")
database = dbutils.widgets.get("SNOWFLAKE_DATABASE")
schema = dbutils.widgets.get("SNOWFLAKE_SCHEMA")
email = dbutils.widgets.get("DATABRICKS_EMAIL")


# Optional: mirror the settings into environment variables for any code that
# reads them from os.environ.

os.environ["SNOWFLAKE_USER"] = user
os.environ["SNOWFLAKE_PASSWORD"] = password
os.environ["SNOWFLAKE_ACCOUNT"] = account
os.environ["SNOWFLAKE_WAREHOUSE"] = warehouse
os.environ["SNOWFLAKE_DATABASE"] = database
os.environ["SNOWFLAKE_SCHEMA"] = schema
os.environ["DATABRICKS_EMAIL"] = email
# Kept for backward compatibility: the Snowflake schema used to be exported
# only under this (misleading) name.
os.environ["DATABRICKS_SCHEMA"] = schema

In [0]:
def copy_reference_table():
    """Snapshot the source credit-card table into the reference table in Snowflake.

    Executes ``CREATE OR REPLACE TABLE CREDITCARD_REFERENCE.PUBLIC.CREDITCARD_REFERENCE
    AS SELECT * FROM CREDITCARD.PUBLIC.CREDITCARD``, replacing any previous
    reference copy. Connects with the module-level ``user``, ``password``,
    ``account`` and ``warehouse`` values read earlier in the notebook.

    Raises:
        Whatever ``snowflake.connector`` raises on connection or SQL failure;
        the cursor and connection are closed before the error propagates.
    """
    print("\n📤 Copying reference dataset in Snowflake...")
    sql = """
        CREATE OR REPLACE TABLE CREDITCARD_REFERENCE.PUBLIC.CREDITCARD_REFERENCE AS
        SELECT * FROM CREDITCARD.PUBLIC.CREDITCARD
    """
    conn = snowflake.connector.connect(
        user=user,
        password=password,
        account=account,
        warehouse=warehouse,
        database='CREDITCARD_REFERENCE',
        schema='PUBLIC'
    )
    # Always release the cursor and session, even if the CREATE fails, so a
    # failed copy does not leak an open Snowflake connection.
    try:
        cur = conn.cursor()
        try:
            cur.execute(sql)
        finally:
            cur.close()
    finally:
        conn.close()
    print("✅ Reference table copied successfully.")


In [0]:
def get_model_versions(client, model_name, catalog_schema="workspace.default"):
    """Return all registered versions of a Unity Catalog model.

    Args:
        client: An ``MlflowClient`` (anything exposing ``search_model_versions``).
        model_name: Either a bare model name (e.g. ``"CreditCardFraudModel"``),
            which is qualified with ``catalog_schema``, or an already fully
            qualified three-part name (e.g.
            ``"workspace.default.CreditCardFraudModel"``), which is used as-is.
        catalog_schema: ``<catalog>.<schema>`` prefix applied to bare names.

    Returns:
        The result of ``client.search_model_versions`` for that model name.
    """
    # The notebook's global ``model_name`` is already fully qualified; blindly
    # prefixing it would search for "workspace.default.workspace.default....".
    if model_name.count(".") == 2:
        full_name = model_name
    else:
        full_name = f"{catalog_schema}.{model_name}"
    return client.search_model_versions(f"name='{full_name}'")

In [0]:
def get_model_version_by_tag(client, model_name, tag_key, tag_value):
    """Find the first version of ``model_name`` whose tag ``tag_key`` equals ``tag_value``.

    Versions are examined in the order returned by ``search_model_versions``.
    Each one is re-fetched with ``get_model_version`` and its ``tags`` are read
    from that full record rather than from the search result. Fetching stops
    at the first match.

    Returns:
        The matching full model-version object, or ``None`` if no version matches.
    """
    # Lazily fetch full records: the search runs now, per-version fetches run
    # only until a match is found.
    full_versions = (
        client.get_model_version(name=model_name, version=summary.version)
        for summary in client.search_model_versions(f"name='{model_name}'")
    )
    return next(
        (candidate for candidate in full_versions if candidate.tags.get(tag_key) == tag_value),
        None,
    )

In [0]:
def get_model_version_metrics(client, model_name, version):
    """Return the metrics of the MLflow run that produced a model version.

    Looks up the model version to find its ``run_id``, fetches that run, and
    returns the run's ``data.metrics`` mapping unchanged.
    """
    source_run_id = client.get_model_version(name=model_name, version=version).run_id
    return client.get_run(source_run_id).data.metrics

In [0]:
def better_than(challenger_metrics, champion_metrics, metrics=None):
    """Decide whether the challenger beats the champion on a strict majority of metrics.

    Every metric is treated as higher-is-better, and ties do not count as wins.
    A metric missing (``None``) from either mapping is skipped but still counts
    toward the majority threshold. Missing metrics therefore make promotion
    harder, never easier.

    Args:
        challenger_metrics: Mapping of metric name -> value for the challenger.
        champion_metrics: Mapping of metric name -> value for the champion.
        metrics: Iterable of metric names to compare. Defaults to the notebook's
            ``METRICS_TO_COMPARE``.

    Returns:
        ``True`` if the challenger is strictly better on more than half of
        ``metrics``, else ``False`` (including when ``metrics`` is empty).
    """
    # Materialize so len() works for any iterable, and default lazily so the
    # function is usable without the notebook-level constant.
    metric_names = list(METRICS_TO_COMPARE if metrics is None else metrics)
    wins = 0
    for metric in metric_names:
        c_val = challenger_metrics.get(metric)
        champ_val = champion_metrics.get(metric)
        if c_val is None or champ_val is None:
            continue
        if c_val > champ_val:
            wins += 1
    return wins > len(metric_names) / 2

In [0]:
def _promote_to_champion(client, version):
    """Tag ``version`` of the global ``model_name`` as status=production, role=champion."""
    client.set_model_version_tag(model_name, version, "status", "production")
    client.set_model_version_tag(model_name, version, "role", "champion")


def _print_metrics_comparison(challenger_metrics, champion_metrics):
    """Print a side-by-side table of METRICS_TO_COMPARE ('N/A' when a metric is absent)."""
    print("\n📊 Metrics Comparison:")
    print(f"{'Metric':<20} {'Challenger':<15} {'Champion':<15}")
    print("-" * 50)
    for metric in METRICS_TO_COMPARE:
        c_val = challenger_metrics.get(metric, 'N/A')
        champ_val = champion_metrics.get(metric, 'N/A')
        print(f"{metric:<20} {str(c_val):<15} {str(champ_val):<15}")


def main():
    """Run champion/challenger selection for the registered ``model_name``.

    1. Find the version tagged role=challenger; stop if there is none.
    2. Find the version tagged role=champion. If there is none, promote the
       challenger directly and refresh the Snowflake reference table.
    3. Otherwise print both versions' run metrics and promote the challenger
       only if ``better_than`` says it wins on a majority of metrics. The old
       champion is tagged status=archived / role=archived, and the Snowflake
       reference table is refreshed.
    """
    client = MlflowClient()

    print("🚀 Starting Databricks-compatible Champion Selection...")

    challenger = get_model_version_by_tag(client, model_name, "role", "challenger")
    if not challenger:
        print("❌ No challenger model found.")
        return

    print(f"ℹ️ Challenger found: version {challenger.version}, run {challenger.run_id}")

    champion = get_model_version_by_tag(client, model_name, "role", "champion")
    if not champion:
        print("⚠️ No champion found. Promoting challenger directly.")
        _promote_to_champion(client, challenger.version)
        copy_reference_table()
        print(f"✅ Challenger version {challenger.version} promoted to production.")
        return

    print(f"ℹ️ Champion found: version {champion.version}, run {champion.run_id}")

    challenger_metrics = get_model_version_metrics(client, model_name, challenger.version)
    champion_metrics = get_model_version_metrics(client, model_name, champion.version)
    _print_metrics_comparison(challenger_metrics, champion_metrics)

    if not better_than(challenger_metrics, champion_metrics):
        print("\n⚠️ Challenger did NOT outperform champion. No changes made.")
        return

    print("\n🚀 Challenger outperforms champion. Promoting challenger...")
    # Archive the old champion before promoting, as the original flow did.
    client.set_model_version_tag(model_name, champion.version, "status", "archived")
    client.set_model_version_tag(model_name, champion.version, "role", "archived")
    _promote_to_champion(client, challenger.version)

    copy_reference_table()

    print(f"✅ Challenger v{challenger.version} promoted to Champion (Production).")

if __name__ == "__main__":
    main()
