# In [None]:
# 🎯 MODEL REGISTRATION SCRIPT

import mlflow
from mlflow.tracking import MlflowClient
import sys
import yaml
import json
import requests
from typing import Dict, Optional, List
from datetime import datetime
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, StringType, DoubleType, TimestampType, BooleanType
import pandas as pd

print("=" * 80)
print("üéØ MODEL REGISTRATION SYSTEM - AUTOMATED & DUPLICATE SAFE")
print("=" * 80)

# ---------------------- LOAD PIPELINE CONFIG ----------------------
# Fail fast: every later stage reads its settings from `pipeline_cfg`, so there
# is nothing useful to do without it.
try:
    # Explicit encoding so the YAML is read identically on every platform.
    with open("pipeline_config.yml", "r", encoding="utf-8") as f:
        pipeline_cfg = yaml.safe_load(f)
    # safe_load returns None for an empty file (and a scalar/list for a
    # malformed one); reject that here instead of failing later with an
    # obscure TypeError on pipeline_cfg["model"].
    if not isinstance(pipeline_cfg, dict):
        raise ValueError(
            f"expected a mapping at the top level, got {type(pipeline_cfg).__name__}"
        )
    print("‚úÖ pipeline_config.yml loaded")
except Exception as e:
    print(f"‚ùå Failed to load config: {e}")
    sys.exit(1)


class Config:
    """Flattened view of the pipeline configuration plus the optional Slack secret.

    Attributes:
        MODEL_NAME: ``<catalog>.<schema>.<base_name>_<type>`` built from the
            ``model`` section.
        EXPERIMENT_NAME: ``experiment.name``.
        ARTIFACT_PATH: ``experiment.artifact_path``.
        PRIMARY_METRIC: ``metrics.classification.primary_metric``.
        TOLERANCE: ``registry.duplicate_detection.tolerance``.
        METRICS_TO_COMPARE: ``registry.duplicate_detection.metrics_to_compare``.
        DUPLICATE_CHECK_ENABLED: ``registry.duplicate_detection.enabled``.
        REGISTRATION_LOG_TABLE: ``tables.registration_log``.
        SLACK_WEBHOOK: secret ``SLACK_WEBHOOK_URL`` from scope ``shared-scope``,
            or None when it cannot be read.
    """

    def __init__(self, cfg: Optional[Dict] = None):
        """Read all settings from *cfg*.

        Args:
            cfg: parsed pipeline configuration mapping. Defaults to the
                module-level ``pipeline_cfg`` loaded from ``pipeline_config.yml``.

        Raises:
            KeyError: if a required configuration key is missing.
        """
        if cfg is None:
            cfg = pipeline_cfg

        model_cfg = cfg["model"]
        dup_cfg = cfg["registry"]["duplicate_detection"]

        self.MODEL_NAME = (
            f"{model_cfg['catalog']}.{model_cfg['schema']}."
            f"{model_cfg['base_name']}_{model_cfg['type']}"
        )
        self.EXPERIMENT_NAME = cfg["experiment"]["name"]
        self.ARTIFACT_PATH = cfg["experiment"]["artifact_path"]
        self.PRIMARY_METRIC = cfg["metrics"]["classification"]["primary_metric"]

        self.TOLERANCE = dup_cfg["tolerance"]
        self.METRICS_TO_COMPARE = dup_cfg["metrics_to_compare"]
        self.DUPLICATE_CHECK_ENABLED = dup_cfg["enabled"]

        self.REGISTRATION_LOG_TABLE = cfg["tables"]["registration_log"]

        self.SLACK_WEBHOOK = None
        try:
            # `dbutils` is only defined inside a Databricks runtime; elsewhere the
            # NameError is handled below exactly like a missing secret.
            self.SLACK_WEBHOOK = dbutils.secrets.get("shared-scope", "SLACK_WEBHOOK_URL")
            print("üîê Slack webhook loaded")
        except Exception:
            # Exception (not a bare except) so Ctrl-C / SystemExit still propagate.
            print("‚ö† Slack webhook NOT configured (Optional)")


config = Config()

# One print call, one line per item: echo the effective registry settings.
print(
    f"\nüìå Model Registry: {config.MODEL_NAME}",
    f"üìå Duplicate Logic: {'ENABLED' if config.DUPLICATE_CHECK_ENABLED else 'DISABLED'}",
    f"üìå Primary Metric: {config.PRIMARY_METRIC}",
    "=" * 80,
    sep="\n",
)


# ---------------------- SLACK NOTIFIER ----------------------
class SlackNotifier:
    """Best-effort sender of one-line messages to a Slack incoming webhook.

    With no webhook URL configured, every ``send`` is a no-op, so callers can
    notify unconditionally.
    """

    def __init__(self, webhook_url: Optional[str]):
        self.webhook_url = webhook_url

    def send(self, message: str, level: str = "info"):
        """Post *message*, prefixed with an emoji chosen by *level*.

        Args:
            message: text to post.
            level: one of "info", "success", "warning", "error"; anything
                else uses the "info" emoji.

        Delivery problems (network errors, timeouts, non-2xx responses) are
        reported on stdout and never raised, so notification trouble cannot
        abort the pipeline.
        """
        if not self.webhook_url:
            return
        emoji = {"info": "‚ÑπÔ∏è", "success": "‚úÖ", "warning": "‚ö†Ô∏è", "error": "‚ùå"}.get(level, "‚ÑπÔ∏è")
        payload = {"text": f"{emoji} {message}"}

        try:
            response = requests.post(self.webhook_url, json=payload, timeout=5)
            # A revoked or mistyped webhook shows up as an HTTP error status,
            # not as an exception from post().
            response.raise_for_status()
        except requests.RequestException as exc:
            # Print only the exception type: the message can embed the full
            # webhook URL, which is a secret.
            print(f"   ‚ö† Slack notification failed ({type(exc).__name__})")


slack = SlackNotifier(config.SLACK_WEBHOOK)


# ---------------------- INIT SPARK + MLFLOW ----------------------
spark = SparkSession.builder.appName("ModelRegistration").getOrCreate()
# Runs are read through the "databricks" tracking URI; model versions go
# through the "databricks-uc" registry URI.
mlflow.set_tracking_uri("databricks")
mlflow.set_registry_uri("databricks-uc")
client = MlflowClient()

# Every run search below is scoped to this experiment, so stop early if it is missing.
experiment = mlflow.get_experiment_by_name(config.EXPERIMENT_NAME)
if experiment is None:
    # RuntimeError rather than the bare Exception base class. It is still an
    # Exception subclass, so any existing `except Exception` keeps working.
    raise RuntimeError(f"Experiment '{config.EXPERIMENT_NAME}' not found!")


# ---------------------- TABLE SCHEMA (FIXED) ----------------------
def get_table_schema():
    """Return the fixed Spark schema of the registration log table.

    Every column is declared nullable.
    """
    columns = (
        ("timestamp", TimestampType()),
        ("run_id", StringType()),
        ("run_name", StringType()),
        ("model_name", StringType()),
        ("primary_metric", StringType()),
        ("primary_metric_value", DoubleType()),
        ("metrics_json", StringType()),
        ("params_json", StringType()),
        ("registered", BooleanType()),
        ("registered_version", StringType()),
        ("reason", StringType()),
    )
    return StructType([StructField(col_name, col_type, True) for col_name, col_type in columns])


# ---------------------- TABLE CREATION ----------------------
def ensure_table_exists():
    """Create the registration log Delta table when it cannot be described.

    A failing ``DESCRIBE TABLE`` is treated as "table missing". The table is
    then created empty with the schema from ``get_table_schema()``.
    """
    table = config.REGISTRATION_LOG_TABLE
    try:
        spark.sql(f"DESCRIBE TABLE {table}")
    except Exception as exc:
        # Exception, not a bare except, so interrupts still propagate. Show
        # why DESCRIBE failed, so a permission/catalog problem is not mistaken
        # for a missing table if the create below also fails.
        print(f"   (DESCRIBE TABLE {table} failed: {type(exc).__name__})")
        print(f"üÜï Creating Delta table: {table}")
        empty_df = spark.createDataFrame([], get_table_schema())
        empty_df.write.format("delta").option("overwriteSchema", "true").saveAsTable(table)
        print(f"‚úÖ Table created: {table}")
    else:
        print(f"üìå Table exists: {table}")


# ---------------------- FETCH RUNS ----------------------
def get_all_runs(page_size: int = 500) -> List[Dict]:
    """Fetch every run of the configured experiment, best primary metric first.

    Uses MLflow's default run view for ``search_runs``.

    Args:
        page_size: runs requested per ``search_runs`` call. All pages are
            fetched, so this only affects the number of round trips.

    Returns:
        One dict per run with keys ``run_id``, ``run_name``, ``metrics`` (only
        the METRICS_TO_COMPARE keys the run actually logged), ``params``,
        ``primary_metric`` (None if the run never logged it) and ``model_uri``.
    """
    print("\nüìç Fetching experiment runs...")

    runs = []
    page_token = None
    # A single search_runs call returns at most `max_results` runs; follow the
    # page token so experiments larger than one page are fully covered.
    while True:
        page = client.search_runs(
            [experiment.experiment_id],
            order_by=[f"metrics.{config.PRIMARY_METRIC} DESC"],
            max_results=page_size,
            page_token=page_token,
        )
        runs.extend(page)
        page_token = page.token
        if not page_token:
            break

    compared = config.METRICS_TO_COMPARE
    return [{
        "run_id": run.info.run_id,
        "run_name": run.info.run_name or "unnamed_run",
        "metrics": {m: run.data.metrics[m] for m in compared if m in run.data.metrics},
        "params": run.data.params,
        "primary_metric": run.data.metrics.get(config.PRIMARY_METRIC),
        "model_uri": f"runs:/{run.info.run_id}/{config.ARTIFACT_PATH}",
    } for run in runs]


# ---------------------- DUPLICATE CHECK ----------------------
def _metrics_match(existing: Dict, candidate: Dict, names: List[str], tolerance: float) -> bool:
    """Return True when every metric in *names* agrees between the two dicts.

    Absent from both sides counts as agreement. Absent from only one side is a
    mismatch; treating a missing value as 0 would wrongly equal a real 0.0.
    Present on both sides means agreement iff ``|a - b| <= tolerance``, so a
    NaN on either side is a mismatch.
    """
    for name in names:
        old_value = existing.get(name)
        new_value = candidate.get(name)
        if old_value is None and new_value is None:
            continue
        if old_value is None or new_value is None:
            return False
        # `not (... <= tol)` rather than `> tol` so NaN comparisons count as a mismatch.
        if not abs(old_value - new_value) <= tolerance:
            return False
    return True


def is_duplicate_model(new_model: Dict) -> bool:
    """Return True if an existing version of MODEL_NAME matches *new_model*.

    A version matches when its source run's params equal ``new_model["params"]``
    exactly and every metric in METRICS_TO_COMPARE agrees within TOLERANCE
    (see ``_metrics_match``).

    Always False when duplicate detection is disabled or the version list
    cannot be fetched. Versions whose source run cannot be read are skipped.
    """
    if not config.DUPLICATE_CHECK_ENABLED:
        return False

    try:
        versions = client.search_model_versions(f"name='{config.MODEL_NAME}'")
    except Exception as exc:
        # The registered model may not exist yet; either way there is nothing to compare.
        print(f"   ‚ö† Could not list versions of {config.MODEL_NAME}: {exc}")
        return False

    for version in versions:
        try:
            run = client.get_run(version.run_id)
        except Exception as exc:
            print(f"   ‚ö† Skipping v{version.version}: source run unavailable ({exc})")
            continue

        if run.data.params == new_model["params"] and _metrics_match(
            run.data.metrics, new_model["metrics"], config.METRICS_TO_COMPARE, config.TOLERANCE
        ):
            print(f"   ‚ö† Duplicate detected ‚Üí Matches v{version.version}")
            return True

    return False


# ---------------------- CHECK IF ALREADY LOGGED ----------------------
def is_already_logged(run_id: str) -> bool:
    """Return True if *run_id* already has a row in the registration log table.

    If the lookup itself fails, the problem is printed and False is returned,
    so the caller treats the run as not yet processed.
    """
    try:
        log_df = spark.table(config.REGISTRATION_LOG_TABLE)
        # Column comparison instead of an f-string WHERE clause: the run id is
        # bound as a literal value, so quoting/injection cannot break the query.
        # take(1) stops at the first hit instead of counting all matches.
        matches = log_df.filter(log_df["run_id"] == run_id).take(1)
    except Exception as exc:
        print(f"   ‚ö† Could not query {config.REGISTRATION_LOG_TABLE}: {exc}")
        return False
    return len(matches) > 0


# ---------------------- REGISTER MODEL ----------------------
def register_model(model: Dict):
    """Register *model* under MODEL_NAME unless it duplicates an existing version.

    Args:
        model: a run dict as produced by ``get_all_runs``.

    Returns:
        The version returned by ``mlflow.register_model``, or None when the
        model was judged a duplicate and nothing was registered.

    Raises:
        Whatever ``mlflow.register_model`` raises. Tag failures are only printed.
    """
    if is_duplicate_model(model):
        return None

    print("   üîÑ Registering model...")
    reg = mlflow.register_model(model["model_uri"], config.MODEL_NAME)
    version = reg.version

    metric_value = model["primary_metric"]
    tags = {
        "run_id": model["run_id"],
        "run_name": model["run_name"],
        "primary_metric": config.PRIMARY_METRIC,
        # round(None, 4) would raise for runs that never logged the metric.
        "primary_metric_value": "N/A" if metric_value is None else str(round(metric_value, 4)),
        "registered_timestamp": datetime.now().isoformat(),
    }
    # Set each tag on its own, so one failure does not drop the remaining tags.
    for key, value in tags.items():
        try:
            client.set_model_version_tag(config.MODEL_NAME, version, key, value)
        except Exception as e:
            print(f"   ‚ö† Failed to set tag '{key}': {e}")

    print(f"   ‚úÖ Registered as version: {version}")
    return version


# ---------------------- LOG DECISION ----------------------
def log_decision(model, registered, version, reason):
    """Append one row describing the registration decision for *model*.

    Args:
        model: a run dict as produced by ``get_all_runs``.
        registered: whether a new model version was created.
        version: the registered version, or None (stored as "N/A").
        reason: free-text explanation stored in the ``reason`` column.

    Nothing is written if the run already has a row. A missing primary metric
    is stored as NULL; the column is nullable, and 0.0 would look like a real
    score. Failures building or writing the row are printed, not raised.
    """
    if is_already_logged(model["run_id"]):
        print("   ‚ÑπÔ∏è Already logged in table, skipping duplicate entry")
        return

    metric_value = model["primary_metric"]
    record = {
        "timestamp": datetime.now(),
        "run_id": model["run_id"],
        "run_name": model["run_name"],
        "model_name": config.MODEL_NAME,
        "primary_metric": config.PRIMARY_METRIC,
        "primary_metric_value": None if metric_value is None else float(metric_value),
        "metrics_json": json.dumps(model["metrics"]),
        "params_json": json.dumps(model["params"]),
        "registered": registered,
        "registered_version": "N/A" if version is None else str(version),
        "reason": reason,
    }

    try:
        schema = get_table_schema()
        # Build the single row in schema column order; no pandas round-trip needed.
        row = tuple(record[field.name] for field in schema.fields)
        spark_df = spark.createDataFrame([row], schema=schema)
        spark_df.write.format("delta").mode("append").saveAsTable(config.REGISTRATION_LOG_TABLE)
        print(f"   üìÑ Logged to: {config.REGISTRATION_LOG_TABLE}")
    except Exception as e:
        print(f"   ‚ö† Failed to log: {e}")


# ---------------------- MAIN ----------------------
def _format_metric(value: Optional[float]) -> str:
    """Render a metric for console output; a run without the metric shows "N/A"."""
    return "N/A" if value is None else f"{value:.4f}"


def main():
    """Register every not-yet-logged run of the experiment and log each decision.

    For each run, in the order returned by ``get_all_runs``:
    - runs already in the log table are skipped;
    - otherwise ``register_model`` is called, and the outcome (registered or
      duplicate) is logged and posted to Slack;
    - if registration raises, the run is reported and counted as failed but
      NOT logged, so it is attempted again on the next execution.
    """
    print("\nüöÄ Starting Model Registration Pipeline...\n")

    ensure_table_exists()

    runs = get_all_runs()
    if not runs:
        print("‚ùå No runs found!")
        return

    registered_count = 0
    skipped_count = 0
    already_logged_count = 0
    failed_count = 0
    processed_run_ids = set()
    total = len(runs)

    for idx, model in enumerate(runs, start=1):
        run_id = model["run_id"]
        if run_id in processed_run_ids:
            continue
        processed_run_ids.add(run_id)

        print(f"\n{'='*80}")
        print(f"[{idx}/{total}] Processing: {model['run_name']}")
        # The metric may be None (run never logged it); ':.4f' on None would raise.
        print(f"   Primary Metric ({config.PRIMARY_METRIC}): {_format_metric(model['primary_metric'])}")
        print(f"{'='*80}")

        if is_already_logged(run_id):
            print(f"   ‚è≠ Skipped ‚Äî Already processed earlier (Run ID: {run_id})")
            already_logged_count += 1
            continue

        try:
            version = register_model(model)
        except Exception as exc:
            # One bad run must not abort the whole batch. The run is not written
            # to the log table, so the next execution retries it.
            print(f"   ‚ùå Registration failed: {exc}")
            slack.send(f"‚ùå Registration failed: {model['run_name']}", "error")
            failed_count += 1
            continue

        # register_model returns None only for a duplicate.
        if version is not None:
            log_decision(model, True, version, "‚úî Registered successfully")
            slack.send(f"‚úÖ Registered: {config.MODEL_NAME} v{version}", "success")
            registered_count += 1
        else:
            log_decision(model, False, None, "‚ö† Duplicate - Skipped")
            slack.send(f"‚ö† Duplicate skipped: {model['run_name']}", "warning")
            skipped_count += 1

    print("\n" + "="*80)
    print("üéâ REGISTRATION PIPELINE COMPLETED")
    print("="*80)
    print(f"‚úÖ Models Registered: {registered_count}")
    print(f"‚ö† Duplicates Skipped: {skipped_count}")
    print(f"‚è≠ Already Logged: {already_logged_count}")
    print(f"‚ùå Failed: {failed_count}")
    print(f"üìä Total Processed: {total}")
    print(f"üì¶ Model Name: {config.MODEL_NAME}")
    print("="*80)

    slack.send(
        f"üìÅ Registration complete: {registered_count} registered, {skipped_count} skipped"
        f", {failed_count} failed",
        "info",
    )


if __name__ == "__main__":
    main()
