# In [None]:
# 🎯 MODEL REGISTRATION SCRIPT - FIXED & CONFIG DRIVEN

import mlflow
from mlflow.tracking import MlflowClient
import sys
import yaml
import json
import requests
from typing import Dict, Optional, List
from datetime import datetime
from pyspark.sql import SparkSession
import pandas as pd

# Startup banner: a title line framed by two 80-character rules.
print(
    "=" * 80,
    "üéØ MODEL REGISTRATION SYSTEM - AUTOMATED & DUPLICATE SAFE",
    "=" * 80,
    sep="\n",
)

# ---------------------- LOAD PIPELINE CONFIG ----------------------

# Load the pipeline configuration; any failure here is fatal for the script.
try:
    with open("pipeline_config.yml", "r", encoding="utf-8") as f:
        pipeline_cfg = yaml.safe_load(f)
    # safe_load returns None for an empty file (and a scalar/list for a non-mapping
    # document). Every later lookup indexes pipeline_cfg as a dict, so reject anything
    # else here with a clear message instead of a confusing TypeError later.
    if not isinstance(pipeline_cfg, dict):
        raise ValueError(
            "pipeline_config.yml must contain a YAML mapping at the top level, "
            f"got {type(pipeline_cfg).__name__}"
        )
    print("‚úÖ pipeline_config.yml loaded")
except Exception as e:
    print(f"‚ùå Failed to load config: {e}")
    sys.exit(1)


class Config:
    """Flattened attribute view of the settings loaded into ``pipeline_cfg``.

    Reads the module-level ``pipeline_cfg`` dict; a missing section or key raises
    ``KeyError``. Attributes:

    - ``MODEL_NAME``: three-part registry name
      ``<catalog>.<schema>.<base_name>_<type>`` built from the ``model`` section.
    - ``EXPERIMENT_NAME`` / ``ARTIFACT_PATH``: MLflow experiment name and the run
      artifact path used to build ``runs:/`` model URIs.
    - ``PRIMARY_METRIC``: ranking metric from ``metrics.classification``.
    - ``TOLERANCE`` / ``METRICS_TO_COMPARE`` / ``DUPLICATE_CHECK_ENABLED``: settings
      from ``registry.duplicate_detection``.
    - ``REGISTRATION_LOG_TABLE``: table name from ``tables.registration_log``.
    - ``SLACK_WEBHOOK``: webhook URL read via ``dbutils.secrets``, or ``None``.
    """

    def __init__(self):
        model_cfg = pipeline_cfg["model"]
        model_type = model_cfg["type"]
        uc_catalog = model_cfg["catalog"]
        uc_schema = model_cfg["schema"]
        base_name = model_cfg["base_name"]

        self.MODEL_NAME = f"{uc_catalog}.{uc_schema}.{base_name}_{model_type}"
        self.EXPERIMENT_NAME = pipeline_cfg["experiment"]["name"]
        self.ARTIFACT_PATH = pipeline_cfg["experiment"]["artifact_path"]

        self.PRIMARY_METRIC = pipeline_cfg["metrics"]["classification"]["primary_metric"]

        # Duplicate check settings
        dup_cfg = pipeline_cfg["registry"]["duplicate_detection"]
        self.TOLERANCE = dup_cfg["tolerance"]
        self.METRICS_TO_COMPARE = dup_cfg["metrics_to_compare"]
        self.DUPLICATE_CHECK_ENABLED = dup_cfg["enabled"]

        self.REGISTRATION_LOG_TABLE = pipeline_cfg["tables"]["registration_log"]

        # Slack config (optional)
        self.SLACK_WEBHOOK = None
        try:
            self.SLACK_WEBHOOK = dbutils.secrets.get("shared-scope", "SLACK_WEBHOOK_URL")
            print("üîê Slack webhook loaded")
        except Exception:
            # Best effort: the secret may be missing, or `dbutils` may not exist at all
            # outside a Databricks runtime (NameError). A bare `except:` would also
            # swallow KeyboardInterrupt/SystemExit, so only catch Exception.
            print("‚ö† Slack webhook NOT configured (Optional)")


config = Config()

# Echo the effective settings, followed by a closing 80-character rule.
print(
    f"\nüìå Model Registry: {config.MODEL_NAME}",
    f"üìå Duplicate Logic: {'ENABLED' if config.DUPLICATE_CHECK_ENABLED else 'DISABLED'}",
    f"üìå Primary Metric: {config.PRIMARY_METRIC}",
    "=" * 80,
    sep="\n",
)


# ---------------------- SLACK NOTIFIER (Optional) ----------------------

class SlackNotifier:
    """Best-effort Slack notifications through an incoming-webhook URL.

    With no webhook URL (``None`` or empty), every ``send`` call is a no-op, so
    callers can notify unconditionally. Delivery failures are printed, never raised.
    """

    # Emoji prefix per severity level; unknown levels fall back to the "info" emoji.
    _LEVEL_EMOJI = {"info":"‚ÑπÔ∏è","success":"‚úÖ","warning":"‚ö†Ô∏è","error":"‚ùå"}

    def __init__(self, webhook_url: Optional[str]):
        self.webhook_url = webhook_url

    def send(self, message: str, level: str = "info"):
        """Post ``message`` prefixed with the emoji for ``level``; returns None."""
        if not self.webhook_url:
            return

        emoji = self._LEVEL_EMOJI.get(level, "‚ÑπÔ∏è")
        payload = {"text": f"{emoji} {message}"}

        try:
            response = requests.post(self.webhook_url, json=payload, timeout=5)
            # A revoked/invalid webhook or a malformed payload comes back as a non-2xx
            # HTTP status, not a transport error; without this check it went unnoticed.
            response.raise_for_status()
        except Exception as e:
            print(f"‚ö† Slack send failed: {e}")

slack = SlackNotifier(config.SLACK_WEBHOOK)


# ---------------------- INIT MLFLOW + SPARK ----------------------

spark = SparkSession.builder.appName("ModelRegistration").getOrCreate()
# Tracking goes to the Databricks workspace; the registry URI targets Unity Catalog.
mlflow.set_tracking_uri("databricks")
mlflow.set_registry_uri("databricks-uc")
# Created after the URIs are set so the client uses them.
client = MlflowClient()

experiment = mlflow.get_experiment_by_name(config.EXPERIMENT_NAME)
if experiment is None:
    # Fatal: nothing can be registered without the experiment. Notify before aborting
    # so an unattended job does not fail silently.
    slack.send(f"Model registration aborted: experiment '{config.EXPERIMENT_NAME}' not found", "error")
    raise RuntimeError(f"Experiment '{config.EXPERIMENT_NAME}' not found!")


# ---------------------- AUTO TABLE CREATOR (NEW) ----------------------

def ensure_table_exists(table_name: str, df_schema, spark):
    """Create an empty Delta table ``table_name`` with ``df_schema`` if it is absent.

    Existence is probed with ``DESCRIBE TABLE``; any exception from that probe is
    treated as "table missing" and triggers creation.
    """
    try:
        spark.sql(f"DESCRIBE TABLE {table_name}")
    except Exception:
        print(f"üÜï Creating new Delta table: {table_name}")
        empty_df = spark.createDataFrame([], df_schema)
        writer = empty_df.write.format("delta").option("overwriteSchema", "true")
        writer.saveAsTable(table_name)
        print(f"‚úÖ Table created: {table_name}")
    else:
        print(f"üìå Table exists: {table_name}")


# ---------------------- FETCH ALL RUNS ----------------------

def get_all_runs() -> List[Dict]:
    """Return every run of the configured experiment, best primary metric first.

    Pages through ``client.search_runs`` until no continuation token is returned,
    so experiments with more runs than one page (500) are fully covered. The old
    single call silently dropped everything past the first 500 runs.

    Each returned dict has keys:
    ``run_id``, ``run_name`` (``"unnamed_run"`` if unset),
    ``metrics`` (only those ``config.METRICS_TO_COMPARE`` the run logged),
    ``params``, ``primary_metric`` (``None`` if the run never logged it) and
    ``model_uri`` (``runs:/<run_id>/<config.ARTIFACT_PATH>``).
    """
    print("\nüìç Fetching ALL experiment runs...")

    runs = []
    page_token = None
    while True:
        page = client.search_runs(
            [experiment.experiment_id],
            order_by=[f"metrics.{config.PRIMARY_METRIC} DESC"],
            max_results=500,
            page_token=page_token,
        )
        runs.extend(page)
        # search_runs returns a PagedList; an empty/None token means the last page.
        page_token = page.token
        if not page_token:
            break

    processed_runs = []

    for run in runs:
        run_metrics = run.data.metrics
        processed_runs.append({
            "run_id": run.info.run_id,
            "run_name": run.info.run_name or "unnamed_run",
            "metrics": {m: run_metrics[m] for m in config.METRICS_TO_COMPARE if m in run_metrics},
            "params": run.data.params,
            "primary_metric": run_metrics.get(config.PRIMARY_METRIC),
            "model_uri": f"runs:/{run.info.run_id}/{config.ARTIFACT_PATH}",
        })

    print(f"‚úÖ Found {len(processed_runs)} runs in experiment")
    return processed_runs


# ---------------------- DUPLICATE DETECTION ----------------------

def is_duplicate_model(new_model: Dict) -> bool:
    """Return True if an existing version of ``config.MODEL_NAME`` matches ``new_model``.

    A registered version counts as a duplicate when either:
    - it was created from the same run (``run_id`` equal), or
    - its source run has identical params, and no metric in
      ``config.METRICS_TO_COMPARE`` that both runs logged differs by more than
      ``config.TOLERANCE``. Metrics missing from either run are not compared.

    Returns False when duplicate detection is disabled, or when the registry
    cannot be searched, so the caller proceeds with registration.
    """
    if not config.DUPLICATE_CHECK_ENABLED:
        return False

    try:
        versions = client.search_model_versions(f"name='{config.MODEL_NAME}'")
    except Exception as e:
        # Best effort: a failed search (e.g. model not created yet) must not block
        # registration. Report it instead of swallowing everything with a bare except.
        print(f"   ‚ö† Could not search existing versions: {e}")
        return False

    new_metrics = new_model["metrics"]

    for version in versions:
        # A version built from this very run is a duplicate by definition; skip the
        # get_run round-trip and metric comparison.
        if version.run_id == new_model["run_id"]:
            print(f"   ‚ö† Duplicate detected ‚Üí Matches version v{version.version}")
            return True

        try:
            run = client.get_run(version.run_id)
        except Exception as e:
            print(f"   ‚ö† Error checking version {version.version}: {e}")
            continue

        old_metrics = run.data.metrics
        # `not any(diff > tol)` (rather than `all(diff <= tol)`) keeps the original
        # treatment of NaN metrics, which never count as a mismatch.
        metric_match = not any(
            abs(old_metrics.get(m) - new_metrics.get(m)) > config.TOLERANCE
            for m in config.METRICS_TO_COMPARE
            if old_metrics.get(m) is not None and new_metrics.get(m) is not None
        )
        params_match = new_model["params"] == run.data.params

        if metric_match and params_match:
            print(f"   ‚ö† Duplicate detected ‚Üí Matches version v{version.version}")
            return True

    return False


# ---------------------- REGISTER MODEL ----------------------

def register_model(model: Dict):
    """Register ``model`` under ``config.MODEL_NAME`` and annotate the new version.

    Args:
        model: a run dict as produced by ``get_all_runs`` (keys ``run_id``,
            ``run_name``, ``metrics``, ``params``, ``primary_metric``, ``model_uri``).

    Returns:
        The new model version, or ``None`` if ``is_duplicate_model`` flagged the
        run and registration was skipped.

    Errors from ``mlflow.register_model`` propagate. Failures while setting tags
    or the description are printed and leave the registered version in place.
    """
    if is_duplicate_model(model):
        slack.send(f"‚ö† Duplicate skipped: {model['run_name']}", "warning")
        return None

    print("   üîÑ Registering model...")
    new_version = mlflow.register_model(model["model_uri"], config.MODEL_NAME)
    version = new_version.version

    print(f"   ‚úÖ Registered as version: {version}")
    slack.send(f"‚úÖ Registered: {config.MODEL_NAME} v{version}", "success")

    # Runs that never logged the primary metric carry None. round()/":.4f" raise on
    # it, which previously aborted all tags and the description, so use "N/A".
    primary_value = model["primary_metric"]
    primary_tag = "N/A" if primary_value is None else str(round(primary_value, 4))
    primary_text = "N/A" if primary_value is None else f"{primary_value:.4f}"
    # Capture the time once so the tag and the description report the same moment.
    registered_at = datetime.now()

    try:
        tags = {
            "run_name": model["run_name"],
            "run_id": model["run_id"],
            "primary_metric": config.PRIMARY_METRIC,
            "primary_metric_value": primary_tag,
            "artifact_path": config.ARTIFACT_PATH,
            "registered_timestamp": registered_at.isoformat(),
            "registered_by": "Automated Pipeline",
            "params": json.dumps(model["params"]),
            "metrics": json.dumps({k: round(v, 4) for k, v in model["metrics"].items() if v is not None}),
        }

        for key, value in tags.items():
            client.set_model_version_tag(config.MODEL_NAME, version, key, value)

        metric_lines = "\n".join(
            f"  ‚Ä¢ {k}: {v:.4f}" for k, v in model["metrics"].items() if v is not None
        )
        client.update_model_version(
            name=config.MODEL_NAME,
            version=version,
            description=(
                f"üì¶ Automated Model Registration\n\n"
                f"üîπ **Run Name:** {model['run_name']}\n"
                f"üîπ **Primary Metric ({config.PRIMARY_METRIC}):** {primary_text}\n"
                f"üîπ **Run ID:** {model['run_id']}\n"
                f"üìÖ **Registered:** {registered_at.strftime('%Y-%m-%d %H:%M:%S')}\n\n"
                f"**Metrics:**\n"
                + metric_lines
            ),
        )

        print("   üè∑ Metadata and tags updated")

    except Exception as e:
        print(f"   ‚ö† Failed to set tags: {e}")

    return version


# ---------------------- LOG DECISION TO TABLE (UPDATED - NO DUPLICATES) ----------------------

from delta.tables import DeltaTable

def log_decision(model, registered, version, reason):
    """Upsert one registration decision into ``config.REGISTRATION_LOG_TABLE``.

    Args:
        model: run dict from ``get_all_runs``.
        registered: whether the run was registered.
        version: the new model version, or ``None`` for a skipped run.
        reason: human-readable explanation stored in the ``reason`` column.

    Rows are keyed on (run_id, model_name, registered_version); a matching row is
    updated instead of appended. Skipped runs have a NULL version. The key uses
    null-safe equality (``<=>``): with plain ``=`` NULL never matches, so every
    re-run appended another "skipped" row.

    The DataFrame is built with an explicit schema, because Spark cannot reliably
    infer column types from a row whose version/metric values are ``None``.

    Never raises; failures are printed.
    """
    try:
        # Local import: only this function needs the Spark type classes.
        from pyspark.sql.types import (
            BooleanType,
            DoubleType,
            StringType,
            StructField,
            StructType,
            TimestampType,
        )

        # Same column order as the original pandas-built row. registered_version is
        # stored as text. NOTE(review): assumes any existing table also holds it as
        # STRING — confirm against the live table schema.
        schema = StructType([
            StructField("timestamp", TimestampType(), True),
            StructField("run_id", StringType(), True),
            StructField("run_name", StringType(), True),
            StructField("model_name", StringType(), True),
            StructField("primary_metric", StringType(), True),
            StructField("primary_metric_value", DoubleType(), True),
            StructField("metrics_json", StringType(), True),
            StructField("params_json", StringType(), True),
            StructField("registered", BooleanType(), True),
            StructField("registered_version", StringType(), True),
            StructField("reason", StringType(), True),
        ])

        primary_value = model["primary_metric"]
        row = (
            datetime.now(),
            model["run_id"],
            model["run_name"],
            config.MODEL_NAME,
            config.PRIMARY_METRIC,
            None if primary_value is None else float(primary_value),
            json.dumps(model["metrics"]),
            json.dumps(model["params"]),
            registered,
            None if version is None else str(version),
            reason,
        )

        spark_df = spark.createDataFrame([row], schema)

        ensure_table_exists(config.REGISTRATION_LOG_TABLE, spark_df.schema, spark)

        try:
            delta_table = DeltaTable.forName(spark, config.REGISTRATION_LOG_TABLE)

            delta_table.alias("t").merge(
                spark_df.alias("s"),
                """
                t.run_id = s.run_id
                AND t.model_name = s.model_name
                AND t.registered_version <=> s.registered_version
                """
            ) \
            .whenMatchedUpdateAll() \
            .whenNotMatchedInsertAll() \
            .execute()

            print(f"üîÅ MERGE successful ‚Üí No duplicate row added.")

        except Exception:
            # Fallback when the MERGE cannot run: append the row instead.
            spark_df.write.format("delta").mode("append").saveAsTable(config.REGISTRATION_LOG_TABLE)

        print(f"   üìÑ Logged to: {config.REGISTRATION_LOG_TABLE}")

    except Exception as e:
        print(f"   ‚ö† Failed to log decision: {e}")


# ---------------------- MAIN EXECUTION ----------------------

def main():
    """Evaluate every experiment run for registration, then print and send a summary.

    Runs are processed in the order ``get_all_runs`` returns them. Each run is
    either registered (``register_model`` returns a version) or skipped as a
    duplicate (it returns ``None``). Every decision is recorded with ``log_decision``.
    """
    print("\nüöÄ Starting Model Registration Pipeline...\n")

    runs = get_all_runs()

    if not runs:
        print("‚ùå No runs found in experiment!")
        return

    registered_count = 0
    skipped_count = 0
    total = len(runs)
    separator = "=" * 80

    for idx, model in enumerate(runs, start=1):
        # Runs that never logged the primary metric carry None. ":.4f" raises on None,
        # which previously crashed the whole pipeline, so print a placeholder.
        primary_value = model["primary_metric"]
        primary_text = "N/A" if primary_value is None else f"{primary_value:.4f}"

        print(f"\n{separator}")
        print(f"[{idx}/{total}] Processing: {model['run_name']}")
        print(f"   Primary Metric ({config.PRIMARY_METRIC}): {primary_text}")
        print(separator)

        version = register_model(model)

        # register_model returns None only for a skipped duplicate; compare against
        # None rather than relying on the truthiness of the version value.
        if version is not None:
            log_decision(model, True, version, "‚úî Registered successfully")
            registered_count += 1
        else:
            log_decision(model, False, None, "‚ö† Duplicate - Skipped")
            skipped_count += 1

    print("\n" + separator)
    print("üéâ REGISTRATION PIPELINE COMPLETED")
    print(separator)
    print(f"‚úÖ Models Registered: {registered_count}")
    print(f"‚ö† Duplicates Skipped: {skipped_count}")
    print(f"üìä Total Processed: {total}")
    print(f"üì¶ Model Name: {config.MODEL_NAME}")
    print(separator)

    slack.send(f"üìÅ Registration complete: {registered_count} registered, {skipped_count} skipped", "info")


if __name__ == "__main__":
    # Run the registration pipeline only when executed directly (script or notebook),
    # not when this module is imported.
    main()
