# In [None]:
# 🎯 MODEL REGISTRATION SCRIPT - MULTI MODEL (FIXED)

import mlflow
from mlflow.tracking import MlflowClient
import sys
import yaml
import os
import json
import requests
from typing import Dict, Optional, List
from datetime import datetime
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, StringType, DoubleType, TimestampType, BooleanType
import pandas as pd

# Start-of-run banner: a rule, the title, and a closing rule, one per line.
print("\n".join(["=" * 80, "üéØ MODEL REGISTRATION SYSTEM - MULTI MODEL + AUTOMATED", "=" * 80]))

# ---------------------- LOAD CONFIG FILES ----------------------
def _load_yaml_config(path: str) -> dict:
    """
    Read ``path`` as UTF-8 YAML and return its top-level mapping.

    Raises:
        OSError: if the file cannot be opened.
        yaml.YAMLError: if the content is not valid YAML.
        ValueError: if the document is empty or not a mapping (or not valid
            UTF-8). Without this check such a file would load as ``None`` and
            only fail later with a confusing TypeError on the first lookup.
    """
    with open(path, "r", encoding="utf-8") as fh:
        data = yaml.safe_load(fh)
    if not isinstance(data, dict):
        raise ValueError(
            f"{path} must contain a YAML mapping at the top level (got {type(data).__name__})"
        )
    return data


try:
    pipeline_cfg = _load_yaml_config("pipeline_config.yml")
    experiments_cfg = _load_yaml_config("experiments_config.yml")
    print("‚úÖ Configuration files loaded\n")
except (OSError, yaml.YAMLError, ValueError) as e:
    # Nothing below can run without both configs, so stop the script here.
    print(f"‚ùå Failed to load config: {e}")
    sys.exit(1)


# Declare the notebook/job parameters. ``dbutils`` only exists on a Databricks
# runtime. Elsewhere the resulting NameError is ignored, and
# get_models_to_register() then reads MODELS_TO_TRAIN from the environment.
try:
    dbutils.widgets.text("MODELS_TO_TRAIN", "", "Models to Register")
    dbutils.widgets.text("environment", "development", "Environment")
except Exception:
    # Best-effort only. Use ``except Exception`` rather than a bare ``except``
    # so that KeyboardInterrupt/SystemExit are not swallowed.
    pass

# Abbreviations appended to the base experiment name for each model type.
# Keep this table identical to the training script's copy so both sides
# resolve the same per-model experiment.
MODEL_SHORT_NAMES = dict(
    random_forest="RF",
    xgboost="XGB",
    logistic_regression="LR",
    gradient_boosting="GB",
    decision_tree="DT",
    svm="SVM",
    naive_bayes="NB",
    knn="KNN",
)

def get_model_short_name(model_type):
    """
    Return the abbreviation used in ``model_type``'s experiment name.

    A type listed in MODEL_SHORT_NAMES uses its table entry. Any other type is
    abbreviated to the upper-cased first letter of each non-empty
    underscore-separated word (e.g. "extra_trees" -> "ET"). This mirrors the
    training script's rule.
    """
    try:
        return MODEL_SHORT_NAMES[model_type]
    except KeyError:
        pass
    return "".join(part[0].upper() for part in model_type.split("_") if part)


def get_models_to_register():
    """
    Resolve which model types to register.

    The selection is read from the ``MODELS_TO_TRAIN`` Databricks widget. The
    ``MODELS_TO_TRAIN`` environment variable is used instead when the widget is
    unavailable or left blank; the widget is declared with an empty default, so
    a blank widget must not hide an environment value.

    Accepted values:
    - ``"all"`` (case-insensitive);
    - a comma-separated list of keys under ``experiments_cfg["models"]``.

    Duplicate names are dropped, keeping first-seen order.

    Returns:
        List of model type names to process.

    Raises:
        ValueError: if no models are configured, the selection is empty or a
            placeholder ("none"/"null"/"undefined"), or it names unknown models.
    """
    # ``or {}`` also covers ``models:`` present but null in the YAML.
    available_models = list((experiments_cfg.get("models") or {}).keys())

    if not available_models:
        raise ValueError("‚ùå No models defined in experiments_config.yml")

    value = ""
    try:
        value = dbutils.widgets.get("MODELS_TO_TRAIN") or ""
        print(f"üìå MODELS_TO_TRAIN from Widget: '{value}'")
    except Exception:
        # Not on Databricks, or the widget is not defined: use the environment.
        pass

    if not value.strip():
        value = os.getenv("MODELS_TO_TRAIN", "")
        print(f"üìå MODELS_TO_TRAIN from ENV: '{value}'")

    value = value.strip()

    if not value or value.lower() in ["none", "null", "undefined"]:
        raise ValueError(
            f"‚ùå MODELS_TO_TRAIN is not set!\n"
            f"   Available models: {available_models}\n"
            f"   Current value: '{value}'"
        )

    if value.lower() == "all":
        print(f"‚úÖ Registering ALL models: {available_models}")
        return available_models

    # dict.fromkeys de-duplicates while preserving order, so a model listed
    # twice (e.g. "xgboost,xgboost") is not processed twice.
    models = list(dict.fromkeys(m.strip() for m in value.split(",") if m.strip()))

    if not models:
        raise ValueError(f"‚ùå No valid models found in MODELS_TO_TRAIN='{value}'")

    invalid_models = [m for m in models if m not in available_models]

    if invalid_models:
        raise ValueError(
            f"‚ùå Invalid model names: {invalid_models}\n"
            f"   Available: {available_models}"
        )

    print(f"‚úÖ Models to register: {models}")
    return models


# ---------------------- CONFIGURATION CLASS ----------------------
class Config:
    """Settings for registering one model type, resolved from ``pipeline_cfg``."""

    def __init__(self, model_type: str):
        """
        Resolve registry, experiment, metric, duplicate-detection and logging
        settings for ``model_type``.

        Args:
            model_type: model key from experiments_config.yml, e.g. 'random_forest', 'xgboost'.

        Raises:
            KeyError: if a required key is missing from pipeline_config.yml, or
                the naming format uses a placeholder other than catalog/schema/
                base_name/model_type.
        """
        models_cfg = pipeline_cfg["models"]
        experiment_cfg = pipeline_cfg["experiment"]
        duplicate_cfg = pipeline_cfg["registry"]["duplicate_detection"]

        # Registry model name built from the configured format string.
        self.MODEL_NAME = models_cfg["naming"]["format"].format(
            catalog=models_cfg["catalog"],
            schema=models_cfg["schema"],
            base_name=models_cfg["base_name"],
            model_type=model_type,
        )

        self.MODEL_TYPE = model_type

        # Runs are searched in a model-specific experiment named
        # "<base experiment name>_<short name>". The short name uses the same
        # rule as the training script.
        self.EXPERIMENT_NAME = f"{experiment_cfg['name']}_{get_model_short_name(model_type)}"

        self.ARTIFACT_PATH = experiment_cfg["artifact_path"]
        self.PRIMARY_METRIC = pipeline_cfg["metrics"]["classification"]["primary_metric"]

        self.TOLERANCE = duplicate_cfg["tolerance"]
        self.METRICS_TO_COMPARE = duplicate_cfg["metrics_to_compare"]
        self.DUPLICATE_CHECK_ENABLED = duplicate_cfg["enabled"]

        self.REGISTRATION_LOG_TABLE = pipeline_cfg["tables"]["registration_log"]

        # Optional Slack webhook, read from the Databricks secret scope.
        self.SLACK_WEBHOOK = None
        try:
            self.SLACK_WEBHOOK = dbutils.secrets.get("shared-scope", "SLACK_WEBHOOK_URL")
            print(f"   üîê Slack webhook loaded for {model_type}")
        except Exception:
            # The scope/key is missing, or there is no ``dbutils`` outside
            # Databricks: leave the webhook unset. Not a bare ``except``, so
            # KeyboardInterrupt still propagates.
            pass


# ---------------------- SLACK NOTIFIER ----------------------
class SlackNotifier:
    """Best-effort Slack incoming-webhook notifier; a no-op when no URL is set."""

    # Prefix emoji per message level; unknown levels fall back to "info".
    _LEVEL_EMOJI = {"info": "‚ÑπÔ∏è", "success": "‚úÖ", "warning": "‚ö†Ô∏è", "error": "‚ùå"}

    def __init__(self, webhook_url: Optional[str]):
        self.webhook_url = webhook_url

    def send(self, message: str, level: str = "info"):
        """
        Post ``message``, prefixed with the emoji for ``level``, to the webhook.

        Does nothing when ``webhook_url`` is empty/None. Request errors and
        non-2xx responses are printed, never raised, so a notification problem
        cannot break the registration pipeline.
        """
        if not self.webhook_url:
            return
        emoji = self._LEVEL_EMOJI.get(level, "‚ÑπÔ∏è")
        payload = {"text": f"{emoji} {message}"}

        try:
            response = requests.post(self.webhook_url, json=payload, timeout=5)
        except Exception as e:
            # Deliberately broad (best-effort), but reported instead of silently dropped.
            print(f"   WARNING: Slack notification failed: {e}")
            return
        if not response.ok:
            print(f"   WARNING: Slack notification rejected: HTTP {response.status_code}")


# ---------------------- INIT SPARK + MLFLOW ----------------------
# Module-wide Spark session and MLflow client used by every function below.
# The tracking URI ("databricks") and registry URI ("databricks-uc") are set
# before MlflowClient() is constructed, so the client is created after both
# URIs are in place.
spark = SparkSession.builder.appName("ModelRegistrationMultiModel").getOrCreate()
mlflow.set_tracking_uri("databricks")
mlflow.set_registry_uri("databricks-uc")
client = MlflowClient()


# ---------------------- TABLE SCHEMA ----------------------
def get_table_schema():
    """
    Build the fixed Spark schema of the registration log table.

    Every column is nullable.
    """
    columns = (
        ("timestamp", TimestampType()),
        ("run_id", StringType()),
        ("run_name", StringType()),
        ("model_type", StringType()),
        ("model_name", StringType()),
        ("primary_metric", StringType()),
        ("primary_metric_value", DoubleType()),
        ("metrics_json", StringType()),
        ("params_json", StringType()),
        ("registered", BooleanType()),
        ("registered_version", StringType()),
        ("reason", StringType()),
    )
    return StructType([StructField(name, data_type, True) for name, data_type in columns])


# ---------------------- TABLE CREATION ----------------------
def ensure_table_exists(table_name: str):
    """
    Create ``table_name`` as an empty Delta table with the registration-log
    schema when ``DESCRIBE TABLE`` on it fails.

    Args:
        table_name: table identifier passed verbatim to Spark SQL / saveAsTable.

    Raises:
        Whatever ``saveAsTable`` raises if the table cannot be created.
    """
    try:
        spark.sql(f"DESCRIBE TABLE {table_name}")
    except Exception as e:
        print(f"   üÜï Creating Delta table: {table_name}")
        # Show why DESCRIBE failed. A permission or connectivity error is then
        # visible instead of being silently treated as "table missing".
        print(f"      (DESCRIBE TABLE failed with {type(e).__name__})")
        empty_df = spark.createDataFrame([], get_table_schema())
        empty_df.write.format("delta").option("overwriteSchema", "true").saveAsTable(table_name)
        print(f"   ‚úÖ Table created: {table_name}")
    else:
        print(f"   ‚úÖ Table exists: {table_name}")


# ---------------------- FETCH RUNS FOR MODEL TYPE ----------------------
def get_runs_for_model(config: Config) -> List[Dict]:
    """
    Fetch candidate runs for ``config.MODEL_TYPE`` from its model-specific experiment.

    Selection rules:
    - only FINISHED runs are considered, so failed or still-running runs
      (whose model artifact may be missing or incomplete) are never registered;
    - at most 500 runs are fetched, ordered by the primary metric descending;
    - only runs whose name contains the model type are kept.

    Returns:
        One dict per run with these keys:
        - ``run_id``, ``run_name``, ``params``, ``model_uri``;
        - ``metrics``: only those in ``config.METRICS_TO_COMPARE`` that the run
          logged;
        - ``primary_metric``: None if the run did not log it.
        An empty list is returned when the experiment is missing or the lookup
        fails.
    """
    print(f"   üìç Searching experiment: {config.EXPERIMENT_NAME}")

    try:
        experiment = mlflow.get_experiment_by_name(config.EXPERIMENT_NAME)
    except Exception as e:
        print(f"   ‚ùå Error getting experiment: {e}")
        return []

    if experiment is None:
        print(f"   ‚ö†Ô∏è  Experiment not found: {config.EXPERIMENT_NAME}")
        return []

    print(f"   üî¨ Experiment ID: {experiment.experiment_id}")

    runs = client.search_runs(
        [experiment.experiment_id],
        filter_string="attributes.status = 'FINISHED'",
        order_by=[f"metrics.{config.PRIMARY_METRIC} DESC"],
        max_results=500,
    )

    # Extra safety check: the experiment is already model-specific, but the
    # run name must also mention the model type.
    filtered_runs = [
        run for run in runs
        if config.MODEL_TYPE in (run.info.run_name or "")
    ]

    print(f"   üîç Found {len(filtered_runs)} runs for {config.MODEL_TYPE}")

    return [{
        "run_id": run.info.run_id,
        "run_name": run.info.run_name or "unnamed_run",
        "metrics": {m: run.data.metrics.get(m) for m in config.METRICS_TO_COMPARE if m in run.data.metrics},
        "params": run.data.params,
        "primary_metric": run.data.metrics.get(config.PRIMARY_METRIC),
        "model_uri": f"runs:/{run.info.run_id}/{config.ARTIFACT_PATH}"
    } for run in filtered_runs]


# ---------------------- DUPLICATE CHECK ----------------------
def _compare_metrics(existing_metrics: Dict, new_metrics: Dict, metric_names: List[str], tolerance: float) -> Dict:
    """Return a per-metric comparison (existing, new, diff, within_tolerance).

    A metric missing on either side counts as 0.
    """
    comparison = {}
    for metric_name in metric_names:
        existing_value = existing_metrics.get(metric_name, 0)
        new_value = new_metrics.get(metric_name, 0)
        difference = abs(existing_value - new_value)
        comparison[metric_name] = {
            "existing": existing_value,
            "new": new_value,
            "diff": difference,
            # Single source of truth for the match decision. A NaN difference
            # compares False here, so a NaN metric is never a "match".
            "within_tolerance": difference <= tolerance,
        }
    return comparison


def is_duplicate_model(new_model: Dict, config: Config) -> bool:
    """
    Return True if ``new_model`` duplicates an already-registered version of
    ``config.MODEL_NAME``, in which case registration should be skipped.

    A version counts as a duplicate when either of these holds:
    - it was registered from the same run;
    - every metric in ``config.METRICS_TO_COMPARE`` is within
      ``config.TOLERANCE`` of the new model's value (missing metrics count as
      0) AND all run parameters are exactly equal.

    Returns False when detection is disabled, no versions exist, or the
    version list cannot be fetched. A version whose run cannot be read is
    skipped with a warning.
    """
    if not config.DUPLICATE_CHECK_ENABLED:
        print(f"      ‚ÑπÔ∏è  Duplicate detection: DISABLED")
        return False

    try:
        versions = client.search_model_versions(f"name='{config.MODEL_NAME}'")
    except Exception as e:
        print(f"      ‚ö†Ô∏è  Could not fetch versions: {e}")
        return False

    if not versions:
        print(f"      ‚ÑπÔ∏è  No existing versions found - will register")
        return False

    print(f"      üîç Checking against {len(versions)} existing version(s)...")

    for version in versions:
        # Same run already registered: a duplicate by definition. This also
        # saves a get_run call and does not depend on the metric values.
        if version.run_id == new_model["run_id"]:
            print(f"      ‚ö†Ô∏è  DUPLICATE DETECTED ‚Üí Matches existing v{version.version}")
            print("         Same run is already registered")
            return True

        try:
            run = client.get_run(version.run_id)

            metrics_differences = _compare_metrics(
                run.data.metrics, new_model["metrics"], config.METRICS_TO_COMPARE, config.TOLERANCE
            )
            all_metrics_match = all(d["within_tolerance"] for d in metrics_differences.values())
            params_match = run.data.params == new_model["params"]

            if all_metrics_match and params_match:
                print(f"      ‚ö†Ô∏è  DUPLICATE DETECTED ‚Üí Matches existing v{version.version}")
                print(f"         Metrics comparison (tolerance: {config.TOLERANCE}):")
                for metric_name, diff_info in metrics_differences.items():
                    status = "‚úì" if diff_info["within_tolerance"] else "‚úó"
                    print(f"           {status} {metric_name}: {diff_info['existing']:.4f} vs {diff_info['new']:.4f} (diff: {diff_info['diff']:.4f})")
                print(f"         ‚úì Parameters: Exact match")
                return True

            # Explain why this version is NOT a duplicate (debugging aid).
            if not all_metrics_match and not params_match:
                print(f"      ‚úì NOT duplicate of v{version.version} (metrics AND params differ)")
            elif not all_metrics_match:
                print(f"      ‚úì NOT duplicate of v{version.version} (metrics differ beyond tolerance)")
            else:
                print(f"      ‚úì NOT duplicate of v{version.version} (parameters differ)")

        except Exception as e:
            print(f"      ‚ö†Ô∏è  Error checking v{version.version}: {e}")
            continue

    print(f"      ‚úÖ No duplicates found - will register as new version")
    return False


# ---------------------- CHECK IF ALREADY LOGGED ----------------------
def is_already_logged(run_id: str, table_name: str) -> bool:
    """
    Return True if ``run_id`` already has a row in the registration log ``table_name``.

    The lookup uses a DataFrame column filter instead of interpolating
    ``run_id`` into SQL text, so a quote in the value cannot break or alter the
    query. If the table cannot be read, a warning is printed and False is
    returned, i.e. the run is treated as not logged.
    """
    try:
        log_df = spark.table(table_name)
        return len(log_df.filter(log_df["run_id"] == run_id).take(1)) > 0
    except Exception as e:
        print(f"      WARNING: could not query {table_name}: {e}")
        return False


# ---------------------- REGISTER MODEL ----------------------
def register_model(model: Dict, config: Config):
    """
    Register ``model["model_uri"]`` as a new version of ``config.MODEL_NAME`` and tag it.

    Duplicate detection is the caller's responsibility and happens before this call.

    Tags written:
    - run_id, run_name, model_type, primary_metric;
    - primary_metric_value, rounded to 4 decimals, or "N/A" if the run did not
      log it;
    - registered_timestamp;
    - one ``metric_<name>`` tag per compared metric the run logged.

    Returns:
        The new version as returned by ``mlflow.register_model``, or None if the
        registration itself failed. A tagging failure is printed but does not
        fail or undo the registration.
    """
    print(f"      üîÑ Registering to: {config.MODEL_NAME}")

    try:
        version = mlflow.register_model(model["model_uri"], config.MODEL_NAME).version
    except Exception as e:
        print(f"      ‚ùå Registration failed: {e}")
        return None
    print(f"      ‚úÖ Registered as version: {version}")

    primary_value = model["primary_metric"]
    tags = {
        "run_id": model["run_id"],
        "run_name": model["run_name"],
        "model_type": config.MODEL_TYPE,
        "primary_metric": config.PRIMARY_METRIC,
        # A run that never logged the primary metric carries None here.
        # round(None) would raise and abort every remaining tag.
        "primary_metric_value": str(round(primary_value, 4)) if primary_value is not None else "N/A",
        "registered_timestamp": datetime.now().isoformat(),
    }
    for metric_name in config.METRICS_TO_COMPARE:
        if metric_name in model["metrics"]:
            tags[f"metric_{metric_name}"] = str(round(model["metrics"][metric_name], 4))

    try:
        for tag_key, tag_value in tags.items():
            client.set_model_version_tag(config.MODEL_NAME, version, tag_key, tag_value)
        print(f"      ‚úÖ Tags added successfully")
    except Exception as e:
        print(f"      ‚ö†Ô∏è  Failed to set tags: {e}")

    return version


# ---------------------- LOG DECISION ----------------------
def log_decision(model: Dict, config: Config, registered: bool, version: Optional[int], reason: str):
    """
    Append one row describing the registration decision for ``model`` to the log table.

    The write is skipped if the run is already logged. A missing (or zero)
    primary metric is stored as 0.0. Any failure while building or writing the
    row is printed and swallowed, so logging can never abort the registration
    loop.

    Args:
        model: candidate dict with run_id, run_name, primary_metric, metrics and params.
        config: settings for the model type being processed.
        registered: whether a new model version was created.
        version: the registered version, or None (stored as "N/A").
        reason: human-readable explanation stored with the row.
    """
    if is_already_logged(model["run_id"], config.REGISTRATION_LOG_TABLE):
        print(f"      ‚ÑπÔ∏è  Already logged in table, skipping duplicate entry")
        return

    try:
        primary_value = model["primary_metric"]
        record = {
            "timestamp": datetime.now(),
            "run_id": model["run_id"],
            "run_name": model["run_name"],
            "model_type": config.MODEL_TYPE,
            "model_name": config.MODEL_NAME,
            "primary_metric": config.PRIMARY_METRIC,
            "primary_metric_value": float(primary_value) if primary_value else 0.0,
            "metrics_json": json.dumps(model["metrics"]),
            "params_json": json.dumps(model["params"]),
            "registered": registered,
            "registered_version": str(version) if version is not None else "N/A",
            "reason": reason,
        }

        schema = get_table_schema()
        # Build the row as a tuple in schema field order. This avoids a pandas
        # round-trip and its dtype coercions.
        row = tuple(record[field.name] for field in schema.fields)
        spark.createDataFrame([row], schema=schema).write \
            .format("delta") \
            .mode("append") \
            .option("mergeSchema", "true") \
            .saveAsTable(config.REGISTRATION_LOG_TABLE)
        print(f"      üìÑ Logged to: {config.REGISTRATION_LOG_TABLE}")
    except Exception as e:
        print(f"      ‚ö†Ô∏è  Failed to log: {e}")


# ---------------------- PROCESS SINGLE MODEL TYPE ----------------------
def _format_metric(value) -> str:
    """Format a metric with 4 decimals for display, or "N/A" when it was not logged."""
    return "N/A" if value is None else f"{value:.4f}"


def process_model_type(model_type: str, slack: SlackNotifier) -> Dict:
    """
    Register every new, non-duplicate run of one model type.

    Candidate runs are visited best primary metric first. For each run:
    - skip it if it is already in the registration log;
    - skip it if it duplicates an existing registered version;
    - otherwise register it.

    Each skip/register decision is written to the log table and sent through
    ``slack``. Runs already in the log are only counted as skipped.

    Returns:
        {"registered": int, "skipped": int, "total": int}. ``skipped`` also
        counts runs whose registration failed.
    """
    print(f"\n{'='*80}")
    print(f"üöÄ PROCESSING MODEL TYPE: {model_type.upper()}")
    print(f"{'='*80}\n")

    config = Config(model_type)

    print(f"üì¶ Model Registry Name: {config.MODEL_NAME}")
    print(f"üî¨ Experiment Name: {config.EXPERIMENT_NAME}")
    print(f"üîç Primary Metric: {config.PRIMARY_METRIC}")
    print(f"üõ°Ô∏è  Duplicate Check: {'ENABLED' if config.DUPLICATE_CHECK_ENABLED else 'DISABLED'}\n")

    ensure_table_exists(config.REGISTRATION_LOG_TABLE)

    runs = get_runs_for_model(config)

    if not runs:
        print(f"   ‚ö†Ô∏è  No runs found for {model_type}")
        return {"registered": 0, "skipped": 0, "total": 0}

    registered_count = 0
    skipped_count = 0
    processed_run_ids = set()

    for idx, model in enumerate(runs, start=1):
        # Guard against the same run appearing twice in the candidate list.
        if model['run_id'] in processed_run_ids:
            continue
        processed_run_ids.add(model['run_id'])

        print(f"\n   [{idx}/{len(runs)}] Processing: {model['run_name']}")
        print(f"      Run ID: {model['run_id']}")
        # primary_metric is None for runs that never logged it. A bare
        # "{:.4f}" on None raises TypeError and would abort the whole model type.
        print(f"      Primary Metric ({config.PRIMARY_METRIC}): {_format_metric(model['primary_metric'])}")

        print(f"      Metrics to compare:")
        for metric_name in config.METRICS_TO_COMPARE:
            print(f"         ‚Ä¢ {metric_name}: {_format_metric(model['metrics'].get(metric_name))}")

        if is_already_logged(model['run_id'], config.REGISTRATION_LOG_TABLE):
            print(f"      ‚è≠Ô∏è  Skipped ‚Äî Already logged in registration table")
            skipped_count += 1
            continue

        # Skip runs that duplicate a version already in the registry.
        if is_duplicate_model(model, config):
            log_decision(model, config, False, None, "‚ö† Duplicate metrics+params - Skipped")
            slack.send(f"‚ö†Ô∏è Duplicate skipped: {model['run_name']}", "warning")
            skipped_count += 1
            continue

        version = register_model(model, config)

        if version is not None:
            log_decision(model, config, True, version, "‚úî Registered successfully")
            slack.send(f"‚úÖ Registered: {config.MODEL_NAME} v{version} ({model['run_name']})", "success")
            registered_count += 1
        else:
            # register_model returns None when mlflow.register_model raised.
            log_decision(model, config, False, None, "‚ö† Registration failed")
            slack.send(f"‚ö†Ô∏è Registration failed: {model['run_name']}", "warning")
            skipped_count += 1

    print(f"\n{'='*80}")
    print(f"‚úÖ {model_type.upper()} REGISTRATION COMPLETE")
    print(f"{'='*80}")
    print(f"   ‚úÖ Registered: {registered_count}")
    print(f"   ‚ö†Ô∏è  Skipped: {skipped_count}")
    print(f"   üìä Total: {len(runs)}")

    return {
        "registered": registered_count,
        "skipped": skipped_count,
        "total": len(runs)
    }


# ---------------------- MAIN ----------------------
def _print_summary(results_by_model: Dict, total_stats: Dict):
    """Print per-model-type results (or their error) and the overall totals."""
    print("\n" + "="*80)
    print("üéâ ALL MODELS REGISTRATION COMPLETED")
    print("="*80)

    print("\nüìä Summary by Model Type:")
    print("-" * 80)
    for model_type, stats in results_by_model.items():
        if "error" in stats:
            print(f"   ‚ùå {model_type}: {stats['error']}")
        else:
            print(f"   {model_type}:")
            print(f"      ‚úÖ Registered: {stats['registered']}")
            print(f"      ‚ö†Ô∏è  Skipped: {stats['skipped']}")
            print(f"      üìä Total: {stats['total']}")

    print("\n" + "="*80)
    print("üìà Overall Statistics:")
    print("="*80)
    print(f"   ‚úÖ Total Registered: {total_stats['registered']}")
    print(f"   ‚ö†Ô∏è  Total Skipped: {total_stats['skipped']}")
    print(f"   üìä Total Processed: {total_stats['total']}")
    print(f"   üì¶ Registration Log: {pipeline_cfg['tables']['registration_log']}")
    print("="*80)


def main():
    """
    Run the multi-model registration pipeline.

    1. Resolve the model types from the MODELS_TO_TRAIN widget/environment.
    2. Process each model type independently. An exception in one is recorded
       and shown in the summary without stopping the others.
    3. Print per-model and overall statistics.

    If the model selection is invalid, the run stops via
    ``dbutils.notebook.exit`` on Databricks, or via ``sys.exit(1)`` when
    ``dbutils`` is not available.
    """
    print("\nüöÄ Starting Multi-Model Registration Pipeline...\n")

    try:
        MODEL_TYPES = get_models_to_register()
        print(f"\nüìã Models to register: {MODEL_TYPES}\n")
    except ValueError as e:
        print(str(e))
        try:
            dbutils.notebook.exit("FAILED: Invalid MODELS_TO_TRAIN configuration")
        except NameError:
            # There is no dbutils outside Databricks. Still fail with a
            # non-zero status instead of crashing with an unrelated NameError.
            sys.exit(1)
        return

    # NOTE(review): notifications are disabled here, even though Config loads
    # SLACK_WEBHOOK per model type. Pass a webhook URL to enable them.
    slack = SlackNotifier(None)

    total_stats = {
        "registered": 0,
        "skipped": 0,
        "total": 0
    }

    results_by_model = {}

    for model_type in MODEL_TYPES:
        try:
            stats = process_model_type(model_type, slack)
            results_by_model[model_type] = stats
            for key in total_stats:
                total_stats[key] += stats[key]
        except Exception as e:
            print(f"\n‚ùå Error processing {model_type}: {e}")
            results_by_model[model_type] = {"error": str(e)}
            continue

    _print_summary(results_by_model, total_stats)

# Entry point: run the pipeline when this file is executed directly.
# Importing it as a module does not call main().
if __name__ == "__main__":
    main()