# In [None]:
# 🎯 CREDIT RISK MODEL EVALUATION + AUTO REGISTRATION (NO DUPLICATES)

import mlflow
from mlflow.tracking import MlflowClient
import pandas as pd
import json
from datetime import datetime
from pyspark.sql import SparkSession

# CONFIG

# MLflow experiment whose runs are searched for the best candidate.
EXPERIMENT_NAME = "/Shared/CreditRisk_ML_Experiments"

# Registered-model name that new versions are created under.
MODEL_NAME = "workspace.ml_credit_risk.credit_risk_model_random_forest"

# Table that receives one appended row per registration decision.
EVALUATION_LOG = "workspace.ml_credit_risk.model_registration_log"

# Metric key used both to rank runs and to compare against registered versions.
PRIMARY_METRIC = "test_f1"

# Absolute difference in PRIMARY_METRIC below which two runs are treated as
# having the same score (an absolute gap, not a relative percentage).
DUPLICATE_TOLERANCE = 0.001

# Startup banner: one print call, one line of output per argument.
print(
    "\nüìå CONFIGURATION LOADED",
    f"Experiment: {EXPERIMENT_NAME}",
    f"Duplicate Detection Metric Tolerance: {DUPLICATE_TOLERANCE}",
    "=" * 70,
    sep="\n",
)

# INIT

# Reuse the active Spark session if one exists, otherwise start one.
spark = (
    SparkSession.builder
    .appName("CreditRiskModelRegistration")
    .getOrCreate()
)

# Point MLflow at the tracking server and model registry *before* the client
# is built, so the client picks up both URIs.
mlflow.set_tracking_uri("databricks")
mlflow.set_registry_uri("databricks-uc")
client = MlflowClient()

# Resolve the experiment once up front; nothing below can work without it.
experiment = mlflow.get_experiment_by_name(EXPERIMENT_NAME)
if experiment is None:
    raise Exception(f"‚ùå Experiment '{EXPERIMENT_NAME}' not found!")

# STEP 1: Get Best Latest Trained Run

def get_latest_best_run():
    """Return a summary of the experiment run with the highest PRIMARY_METRIC.

    The run is selected purely by sorting on ``metrics.<PRIMARY_METRIC>`` in
    descending order and taking the first result; despite the function name,
    run recency is not part of the selection.

    Returns:
        dict with the keys:
            ``run_id``    - id of the selected run
            ``run_name``  - run name as stored on the run info
            ``metric``    - value of PRIMARY_METRIC, or ``None`` if the
                            selected run did not log that metric
            ``params``    - the run's logged parameters
            ``model_uri`` - ``runs:/<run_id>/ml_model``

    Raises:
        IndexError: if the experiment contains no runs at all.
    """
    runs = client.search_runs(
        [experiment.experiment_id],
        order_by=[f"metrics.{PRIMARY_METRIC} DESC"],
        max_results=1,
    )

    # An empty experiment used to surface as a bare "list index out of range";
    # keep the exception type but say what actually went wrong.
    if not runs:
        raise IndexError(
            f"No runs found in experiment '{EXPERIMENT_NAME}' "
            f"(id={experiment.experiment_id}); nothing to evaluate."
        )

    best_run = runs[0]
    run_id = best_run.info.run_id

    return {
        "run_id": run_id,
        "run_name": best_run.info.run_name,
        "metric": best_run.data.metrics.get(PRIMARY_METRIC),
        "params": best_run.data.params,
        # NOTE(review): assumes training logged the model under the artifact
        # path "ml_model" - confirm against the training code.
        "model_uri": f"runs:/{run_id}/ml_model",
    }

# STEP 2: Check Duplicate in Registry

def is_duplicate_model(new_model):
    """Tell whether an equivalent model is already registered under MODEL_NAME.

    A registered version counts as a duplicate of ``new_model`` when either

    * it was created from the very same run (matching ``run_id``), or
    * its source run logged identical params AND a PRIMARY_METRIC value whose
      absolute difference from the candidate's is below DUPLICATE_TOLERANCE.

    Args:
        new_model: dict describing the candidate run. Reads ``"metric"`` and
            ``"params"``; ``"run_id"`` is used when present.

    Returns:
        True if a duplicate version exists, otherwise False. When the metric
        is missing on either side, the metric comparison cannot be made and
        that version is not considered a duplicate.
    """
    candidate_run_id = new_model.get("run_id")
    candidate_metric = new_model["metric"]
    candidate_params = new_model["params"]

    versions = client.search_model_versions(f"name='{MODEL_NAME}'")

    # Only versions that point back at a tracking run can be compared.
    linked_run_ids = [v.run_id for v in versions if v.run_id]

    # Same run already registered: certain duplicate, and no lookups needed.
    if candidate_run_id is not None and candidate_run_id in linked_run_ids:
        return True

    # Without a candidate metric no tolerance comparison is possible, so
    # avoid fetching runs we could never match against.
    if candidate_metric is None:
        return False

    for run_id in linked_run_ids:
        run = client.get_run(run_id)
        existing_metric = run.data.metrics.get(PRIMARY_METRIC)

        # The registered run never logged the metric; subtracting None would
        # raise TypeError, so skip it instead.
        if existing_metric is None:
            continue

        metric_match = abs(existing_metric - candidate_metric) < DUPLICATE_TOLERANCE
        if metric_match and run.data.params == candidate_params:
            return True

    return False

# STEP 3: Log Decision

def log_registration(model, registered, reason):
    """Append one registration-decision row to the EVALUATION_LOG Delta table.

    Args:
        model: dict describing the evaluated run; reads ``run_id``,
            ``run_name``, ``metric``, ``params`` and ``model_uri``.
        registered: whether a new model version was created for this run.
        reason: human-readable explanation of the decision.

    The row is built with an explicit schema rather than inferred from the
    data: ``metric`` and ``run_name`` may legitimately be ``None``, and a
    single-row frame whose column holds only ``None`` gives Spark nothing to
    infer a column type from.
    """
    # Column names are back-quoted so that `timestamp` is always read as an
    # identifier, never as the type keyword.
    schema = (
        "`timestamp` timestamp, `run_id` string, `run_name` string, "
        "`metric` double, `params_json` string, `model_uri` string, "
        "`registered` boolean, `reason` string"
    )

    metric = model["metric"]
    row = (
        datetime.now(),
        model["run_id"],
        model["run_name"],
        # Coerce to a plain float so the value always fits the double column.
        None if metric is None else float(metric),
        json.dumps(model["params"]),
        model["model_uri"],
        bool(registered),
        reason,
    )

    (
        spark.createDataFrame([row], schema=schema)
        .write.format("delta")
        .mode("append")
        .saveAsTable(EVALUATION_LOG)
    )

# MAIN EXECUTION

print("\nüöÄ Checking latest model run...")

latest_model = get_latest_best_run()

# Register the candidate only when no equivalent version exists in the
# registry; otherwise record that it was skipped.
if not is_duplicate_model(latest_model):
    client.create_model_version(
        name=MODEL_NAME,
        source=latest_model["model_uri"],
        run_id=latest_model["run_id"],
    )
    status, reason = True, "‚úî Model successfully registered (unique combination)."
else:
    status, reason = (
        False,
        "‚ö† Duplicate model detected ‚Äî same params & same metric. NOT registered.",
    )

# Persist the decision (registered or skipped) to the log table.
log_registration(latest_model, status, reason)

# Final summary for the notebook/job output.
print(
    "\nüìç PROCESS COMPLETE",
    f"‚û° REGISTERED: {status}",
    f"‚û° REASON: {reason}",
    "=" * 70,
    sep="\n",
)
