In [None]:
# 🎯 MODEL REGISTRATION SCRIPT - MULTI MODEL

import mlflow
from mlflow.tracking import MlflowClient
import sys
import yaml
import json
import requests
from typing import Dict, Optional, List
from datetime import datetime
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, StringType, DoubleType, TimestampType, BooleanType
import pandas as pd

print("=" * 80)
print("üéØ MODEL REGISTRATION SYSTEM - MULTI MODEL + AUTOMATED")
print("=" * 80)

# ---------------------- LOAD CONFIG FILES ----------------------
# Both files are opened relative to the current working directory. Any failure
# (missing file, unreadable file, invalid YAML) prints the error and exits with
# status 1, because every later step reads pipeline_cfg.
try:
    # Decode explicitly as UTF-8 so parsing does not depend on the platform's
    # default text encoding.
    with open("pipeline_config.yml", "r", encoding="utf-8") as f:
        pipeline_cfg = yaml.safe_load(f)

    # Experiments config is loaded alongside for reference.
    with open("experiments_config.yml", "r", encoding="utf-8") as f:
        experiments_cfg = yaml.safe_load(f)

    print("‚úÖ Configuration files loaded\n")
except Exception as e:
    print(f"‚ùå Failed to load config: {e}")
    sys.exit(1)


# ---------------------- CONFIGURATION CLASS ----------------------
class Config:
    """Registration settings for one model type, read from ``pipeline_cfg``.

    Attributes set by ``__init__``:
        MODEL_NAME: registry name built from ``models.naming.format`` with the
            ``catalog``, ``schema``, ``base_name`` and ``model_type`` fields.
        MODEL_TYPE: the model type passed in.
        EXPERIMENT_NAME, ARTIFACT_PATH: from the ``experiment`` section.
        PRIMARY_METRIC: from ``metrics.classification.primary_metric``.
        TOLERANCE, METRICS_TO_COMPARE, DUPLICATE_CHECK_ENABLED: from
            ``registry.duplicate_detection``.
        REGISTRATION_LOG_TABLE: from ``tables.registration_log``.
        SLACK_WEBHOOK: webhook URL from the secret store, or ``None``.
    """

    def __init__(self, model_type: str):
        """
        Initialize config for a specific model type.

        Args:
            model_type: e.g., 'random_forest', 'xgboost'

        Raises:
            KeyError: if a required key is missing from ``pipeline_cfg``.
        """
        models_cfg = pipeline_cfg["models"]
        experiment_cfg = pipeline_cfg["experiment"]
        dedup_cfg = pipeline_cfg["registry"]["duplicate_detection"]

        # Generate the registry model name dynamically from the naming format
        self.MODEL_NAME = models_cfg["naming"]["format"].format(
            catalog=models_cfg["catalog"],
            schema=models_cfg["schema"],
            base_name=models_cfg["base_name"],
            model_type=model_type,
        )

        self.MODEL_TYPE = model_type
        self.EXPERIMENT_NAME = experiment_cfg["name"]
        self.ARTIFACT_PATH = experiment_cfg["artifact_path"]
        self.PRIMARY_METRIC = pipeline_cfg["metrics"]["classification"]["primary_metric"]

        self.TOLERANCE = dedup_cfg["tolerance"]
        self.METRICS_TO_COMPARE = dedup_cfg["metrics_to_compare"]
        self.DUPLICATE_CHECK_ENABLED = dedup_cfg["enabled"]

        self.REGISTRATION_LOG_TABLE = pipeline_cfg["tables"]["registration_log"]

        # Slack webhook (optional). `dbutils` is not defined in this file; it is
        # presumably the Databricks notebook global - confirm. When it is
        # missing (NameError) or the secret cannot be read, the webhook stays
        # None. `except Exception` keeps that best-effort behaviour without
        # also swallowing KeyboardInterrupt / SystemExit.
        self.SLACK_WEBHOOK = None
        try:
            self.SLACK_WEBHOOK = dbutils.secrets.get("shared-scope", "SLACK_WEBHOOK_URL")
            print(f"   üîê Slack webhook loaded for {model_type}")
        except Exception:
            pass  # Slack is optional


# ---------------------- SLACK NOTIFIER ----------------------
class SlackNotifier:
    """Best-effort sender of status messages to a Slack webhook.

    With no webhook URL configured, ``send`` does nothing. Delivery problems
    never raise; they are reported on stdout instead.
    """

    # Prefix placed in front of the message text for each severity level.
    _LEVEL_PREFIX = {"info": "‚ÑπÔ∏è", "success": "‚úÖ", "warning": "‚ö†Ô∏è", "error": "‚ùå"}

    def __init__(self, webhook_url: Optional[str]):
        """Store the webhook URL; ``None`` or an empty string disables sending."""
        self.webhook_url = webhook_url

    def send(self, message: str, level: str = "info"):
        """POST ``message`` to the webhook, prefixed according to ``level``.

        Args:
            message: text to send.
            level: one of "info", "success", "warning", "error"; any other
                value falls back to the "info" prefix.

        Returns:
            None in every case.
        """
        if not self.webhook_url:
            return
        emoji = self._LEVEL_PREFIX.get(level, "‚ÑπÔ∏è")
        payload = {"text": f"{emoji} {message}"}

        try:
            response = requests.post(self.webhook_url, json=payload, timeout=5)
        except Exception as exc:
            # Notifications must never break the pipeline. Only the exception
            # type is printed: requests' error messages can embed the request
            # URL, and the webhook URL is a secret.
            print(f"   Slack notification failed: {type(exc).__name__}")
            return

        if not response.ok:
            # A non-2xx answer means the message was not accepted.
            print(f"   Slack notification rejected: HTTP {response.status_code}")


# ---------------------- INIT SPARK + MLFLOW ----------------------
# Module-level handles used by the functions below: `spark`, `client`, `experiment`.
spark = (
    SparkSession.builder
    .appName("ModelRegistrationMultiModel")
    .getOrCreate()
)

# Tracking goes to "databricks", the model registry to "databricks-uc".
# The URIs are set before the client is created.
mlflow.set_tracking_uri("databricks")
mlflow.set_registry_uri("databricks-uc")
client = MlflowClient()

# Fail fast when the configured experiment cannot be found.
experiment = mlflow.get_experiment_by_name(pipeline_cfg["experiment"]["name"])
if experiment is None:
    raise Exception(f"Experiment '{pipeline_cfg['experiment']['name']}' not found!")


# ---------------------- TABLE SCHEMA ----------------------
def get_table_schema():
    """Return the fixed column layout of the registration log table.

    Every column is nullable. A new StructType is built on each call.
    """
    # (column name, Spark type class) in table order
    columns = [
        ("timestamp", TimestampType),
        ("run_id", StringType),
        ("run_name", StringType),
        ("model_type", StringType),
        ("model_name", StringType),
        ("primary_metric", StringType),
        ("primary_metric_value", DoubleType),
        ("metrics_json", StringType),
        ("params_json", StringType),
        ("registered", BooleanType),
        ("registered_version", StringType),
        ("reason", StringType),
    ]
    fields = [StructField(name, spark_type(), True) for name, spark_type in columns]
    return StructType(fields)


# ---------------------- TABLE CREATION ----------------------
def ensure_table_exists(table_name: str):
    """Create the registration log Delta table when it cannot be described.

    Runs ``DESCRIBE TABLE`` on ``table_name``; if that fails for any reason the
    table is created empty with the schema from ``get_table_schema()``.

    Args:
        table_name: name of the log table.

    Raises:
        Whatever the table creation raises, if creating it fails too.
    """
    try:
        spark.sql(f"DESCRIBE TABLE {table_name}")
    except Exception:
        # `except Exception` rather than a bare `except`, so Ctrl-C and
        # SystemExit are not mistaken for "table does not exist".
        print(f"   üÜï Creating Delta table: {table_name}")
        schema = get_table_schema()
        empty_df = spark.createDataFrame([], schema)
        empty_df.write.format("delta").option("overwriteSchema", "true").saveAsTable(table_name)
        print(f"   ‚úÖ Table created: {table_name}")
    else:
        print(f"   ‚úÖ Table exists: {table_name}")


# ---------------------- FETCH RUNS FOR MODEL TYPE ----------------------
def get_runs_for_model(config: Config) -> List[Dict]:
    """
    Collect the experiment's runs that belong to one model type.

    At most 500 runs are requested, ordered by the primary metric descending.
    A run belongs to the model type when ``config.MODEL_TYPE`` occurs as a
    substring of its run name.

    Returns:
        One dict per matching run with the keys ``run_id``, ``run_name``,
        ``metrics`` (only the metrics listed in ``config.METRICS_TO_COMPARE``
        that the run logged), ``params``, ``primary_metric`` (``None`` when the
        run did not log it) and ``model_uri``.
    """
    print(f"   üìç Fetching runs for {config.MODEL_TYPE}...")

    candidates = client.search_runs(
        [experiment.experiment_id],
        order_by=[f"metrics.{config.PRIMARY_METRIC} DESC"],
        max_results=500
    )

    # Keep only runs whose name mentions the model type
    filtered_runs = []
    for run in candidates:
        name = run.info.run_name or ""
        if config.MODEL_TYPE in name:
            filtered_runs.append(run)

    print(f"   üîç Found {len(filtered_runs)} runs for {config.MODEL_TYPE}")

    summaries = []
    for run in filtered_runs:
        compared = {}
        for metric in config.METRICS_TO_COMPARE:
            if metric in run.data.metrics:
                compared[metric] = run.data.metrics.get(metric)

        summaries.append({
            "run_id": run.info.run_id,
            "run_name": run.info.run_name or "unnamed_run",
            "metrics": compared,
            "params": run.data.params,
            "primary_metric": run.data.metrics.get(config.PRIMARY_METRIC),
            "model_uri": f"runs:/{run.info.run_id}/{config.ARTIFACT_PATH}"
        })
    return summaries


# ---------------------- DUPLICATE CHECK ----------------------
def is_duplicate_model(new_model: Dict, config: Config) -> bool:
    """Check whether an equivalent model version is already registered.

    A registered version counts as a duplicate when every metric in
    ``config.METRICS_TO_COMPARE`` differs by at most ``config.TOLERANCE``
    (a missing metric is treated as 0 on either side) and the run parameters
    are equal.

    Args:
        new_model: run summary with at least ``metrics`` and ``params``.
        config: settings for the model type being registered.

    Returns:
        True when a matching version exists. False when none matches, when
        duplicate detection is disabled, or when the registry lookup fails.
    """
    if not config.DUPLICATE_CHECK_ENABLED:
        return False

    try:
        versions = client.search_model_versions(f"name='{config.MODEL_NAME}'")
    except Exception:
        # Best effort: a failed lookup is treated as "nothing to compare
        # against". Presumably this covers a model that has never been
        # registered - confirm against the registry's behaviour.
        return False

    for version in versions:
        try:
            run = client.get_run(version.run_id)

            metrics_match = all(
                abs((run.data.metrics.get(m) or 0) - (new_model["metrics"].get(m) or 0)) <= config.TOLERANCE
                for m in config.METRICS_TO_COMPARE
            )

            if metrics_match and run.data.params == new_model["params"]:
                print(f"      ‚ö†Ô∏è  Duplicate detected ‚Üí Matches v{version.version}")
                return True
        except Exception as exc:
            # Skip versions that cannot be compared, but say so instead of
            # hiding the problem.
            print(f"      Could not compare against v{version.version}: {exc}")
            continue

    return False


# ---------------------- CHECK IF ALREADY LOGGED ----------------------
def is_already_logged(run_id: str, table_name: str) -> bool:
    """Check if this run_id is already in the registration log.

    Args:
        run_id: MLflow run id to look for.
        table_name: name of the registration log table.

    Returns:
        True when at least one row has this ``run_id``. False when there is
        none, or when the table cannot be read (a warning is printed).
    """
    # Function-scope import: the module only imports SparkSession and types.
    from pyspark.sql import functions as F

    try:
        # A column comparison passes run_id as a value instead of splicing it
        # into SQL text, so quotes in it cannot change the query.
        matches = (
            spark.table(table_name)
            .where(F.col("run_id") == run_id)
            .limit(1)
            .count()
        )
    except Exception as exc:
        print(f"      Could not read {table_name}; treating run as not logged: {exc}")
        return False

    return matches > 0


# ---------------------- REGISTER MODEL ----------------------
def register_model(model: Dict, config: Config):
    """Register model to Unity Catalog Model Registry.

    Args:
        model: run summary with ``model_uri``, ``run_id``, ``run_name`` and
            ``primary_metric`` (may be ``None``).
        config: settings for the model type being registered.

    Returns:
        The new model version, or ``None`` when the model is a duplicate.

    Raises:
        Whatever ``mlflow.register_model`` raises. Tagging failures are
        reported but do not raise.
    """
    if is_duplicate_model(model, config):
        return None

    print(f"      üîÑ Registering to: {config.MODEL_NAME}")
    reg = mlflow.register_model(model["model_uri"], config.MODEL_NAME)
    version = reg.version

    # Build all tag values before any registry call, so a bad value cannot
    # abort tagging half-way through.
    tags = {
        "run_id": model["run_id"],
        "run_name": model["run_name"],
        "model_type": config.MODEL_TYPE,
        "primary_metric": config.PRIMARY_METRIC,
    }
    metric_value = model["primary_metric"]
    if metric_value is not None:
        # round(None, 4) would raise; a run without the metric simply gets no
        # primary_metric_value tag.
        tags["primary_metric_value"] = str(round(metric_value, 4))
    tags["registered_timestamp"] = datetime.now().isoformat()

    # Set version tags
    try:
        for key, value in tags.items():
            client.set_model_version_tag(config.MODEL_NAME, version, key, value)
    except Exception as e:
        print(f"      ‚ö†Ô∏è  Failed to set tags: {e}")

    print(f"      ‚úÖ Registered as version: {version}")
    return version


# ---------------------- LOG DECISION ----------------------
def log_decision(model: Dict, config: Config, registered: bool, version: Optional[int], reason: str):
    """Append one registration decision to the Delta log table.

    Nothing is written when the run id is already present in the table.
    A missing (or zero) primary metric is stored as 0.0 and a missing version
    as "N/A". A failed write is reported on stdout and not raised.
    """
    # Guard: one row per run id
    if is_already_logged(model["run_id"], config.REGISTRATION_LOG_TABLE):
        print(f"      ‚ÑπÔ∏è  Already logged in table, skipping duplicate entry")
        return

    if version is None:
        version_str = "N/A"
    else:
        version_str = str(version)

    # Key order follows the table schema
    row = {"timestamp": datetime.now()}
    row["run_id"] = model["run_id"]
    row["run_name"] = model["run_name"]
    row["model_type"] = config.MODEL_TYPE
    row["model_name"] = config.MODEL_NAME
    row["primary_metric"] = config.PRIMARY_METRIC
    if model["primary_metric"]:
        row["primary_metric_value"] = float(model["primary_metric"])
    else:
        row["primary_metric_value"] = 0.0
    row["metrics_json"] = json.dumps(model["metrics"])
    row["params_json"] = json.dumps(model["params"])
    row["registered"] = registered
    row["registered_version"] = version_str
    row["reason"] = reason

    spark_df = spark.createDataFrame(pd.DataFrame([row]), schema=get_table_schema())

    try:
        writer = spark_df.write.format("delta").mode("append")
        writer.saveAsTable(config.REGISTRATION_LOG_TABLE)
        print(f"      üìÑ Logged to: {config.REGISTRATION_LOG_TABLE}")
    except Exception as e:
        print(f"      ‚ö†Ô∏è  Failed to log: {e}")


# ---------------------- PROCESS SINGLE MODEL TYPE ----------------------
def process_model_type(model_type: str, slack: SlackNotifier) -> Dict:
    """
    Process registration for a single model type.

    Builds the Config for ``model_type``, makes sure the log table exists,
    then walks the matching runs: runs already in the log are skipped, the
    rest are registered (or rejected as duplicates) and logged.

    Args:
        model_type: e.g., 'random_forest', 'xgboost'
        slack: notifier for status messages. If it has no webhook URL and the
            Config loaded one, a notifier using the Config's webhook is used
            instead.

    Returns:
        dict with the counts ``registered``, ``skipped`` (duplicates) and
        ``total`` (runs found for the model type).
    """
    print(f"\n{'='*80}")
    print(f"üöÄ PROCESSING MODEL TYPE: {model_type.upper()}")
    print(f"{'='*80}\n")

    # Create config for this model type
    config = Config(model_type)

    # The webhook is only known once Config has read it. Without this step a
    # notifier created with no URL would silently drop every message.
    notifier = slack
    if not slack.webhook_url and config.SLACK_WEBHOOK:
        notifier = SlackNotifier(config.SLACK_WEBHOOK)

    print(f"üì¶ Model Registry: {config.MODEL_NAME}")
    print(f"üîç Primary Metric: {config.PRIMARY_METRIC}")
    print(f"üõ°Ô∏è  Duplicate Check: {'ENABLED' if config.DUPLICATE_CHECK_ENABLED else 'DISABLED'}\n")

    # Ensure table exists
    ensure_table_exists(config.REGISTRATION_LOG_TABLE)

    # Get runs for this model type
    runs = get_runs_for_model(config)

    if not runs:
        print(f"   ‚ö†Ô∏è  No runs found for {model_type}")
        return {"registered": 0, "skipped": 0, "total": 0}

    registered_count = 0
    skipped_count = 0
    processed_run_ids = set()

    # Process each run
    for idx, model in enumerate(runs, start=1):

        if model['run_id'] in processed_run_ids:
            continue

        processed_run_ids.add(model['run_id'])

        # A run that never logged the primary metric carries None here;
        # formatting None with ':.4f' would raise and abort the model type.
        metric_value = model['primary_metric']
        metric_text = f"{metric_value:.4f}" if metric_value is not None else "N/A"

        print(f"\n   [{idx}/{len(runs)}] Processing: {model['run_name']}")
        print(f"      Primary Metric ({config.PRIMARY_METRIC}): {metric_text}")

        # Check if already logged
        if is_already_logged(model['run_id'], config.REGISTRATION_LOG_TABLE):
            print(f"      ‚è≠Ô∏è  Skipped ‚Äî Already processed earlier")
            continue

        # Try to register; a falsy version means the run was a duplicate
        version = register_model(model, config)

        if version:
            log_decision(model, config, True, version, "‚úî Registered successfully")
            notifier.send(f"‚úÖ Registered: {config.MODEL_NAME} v{version} ({model['run_name']})", "success")
            registered_count += 1
        else:
            log_decision(model, config, False, None, "‚ö† Duplicate - Skipped")
            notifier.send(f"‚ö†Ô∏è Duplicate skipped: {model['run_name']}", "warning")
            skipped_count += 1

    print(f"\n{'='*80}")
    print(f"‚úÖ {model_type.upper()} REGISTRATION COMPLETE")
    print(f"{'='*80}")
    print(f"   ‚úÖ Registered: {registered_count}")
    print(f"   ‚ö†Ô∏è  Skipped: {skipped_count}")
    print(f"   üìä Total: {len(runs)}")

    return {
        "registered": registered_count,
        "skipped": skipped_count,
        "total": len(runs)
    }


# ---------------------- MAIN ----------------------
def main():
    """
    Main registration pipeline:
    1. Read the enabled model types from pipeline_cfg["models"]["enabled"]
       (a comma-separated string or a list)
    2. For each model type, process all runs
    3. Print a per-model and an overall summary

    An error while processing one model type is recorded in the summary and
    does not stop the remaining model types.
    """
    print("\nüöÄ Starting Multi-Model Registration Pipeline...\n")

    MODEL_TYPES_RAW = pipeline_cfg["models"]["enabled"]
    if isinstance(MODEL_TYPES_RAW, str):
        candidates = MODEL_TYPES_RAW.split(",")
    else:
        # Also accept a YAML list of model types
        candidates = list(MODEL_TYPES_RAW)

    # Drop blank entries (e.g. from a trailing comma). An empty model type is
    # a substring of every run name and would therefore match all runs.
    MODEL_TYPES = [str(m).strip() for m in candidates if str(m).strip()]

    print(f"üìã Models to register: {MODEL_TYPES}\n")

    # The notifier is created without a webhook here; with no webhook,
    # send() is a no-op.
    slack = SlackNotifier(None)

    # Track overall stats
    total_stats = {
        "registered": 0,
        "skipped": 0,
        "total": 0
    }

    results_by_model = {}

    # Process each model type
    for model_type in MODEL_TYPES:
        try:
            stats = process_model_type(model_type, slack)
            results_by_model[model_type] = stats

            # Update totals
            total_stats["registered"] += stats["registered"]
            total_stats["skipped"] += stats["skipped"]
            total_stats["total"] += stats["total"]

        except Exception as e:
            print(f"\n‚ùå Error processing {model_type}: {e}")
            results_by_model[model_type] = {"error": str(e)}
            continue

    # ---------------------- FINAL SUMMARY ----------------------
    print("\n" + "="*80)
    print("üéâ ALL MODELS REGISTRATION COMPLETED")
    print("="*80)

    print("\nüìä Summary by Model Type:")
    print("-" * 80)
    for model_type, stats in results_by_model.items():
        if "error" in stats:
            print(f"   ‚ùå {model_type}: {stats['error']}")
        else:
            print(f"   {model_type}:")
            print(f"      ‚úÖ Registered: {stats['registered']}")
            print(f"      ‚ö†Ô∏è  Skipped: {stats['skipped']}")
            print(f"      üìä Total: {stats['total']}")

    print("\n" + "="*80)
    print("üìà Overall Statistics:")
    print("="*80)
    print(f"   ‚úÖ Total Registered: {total_stats['registered']}")
    print(f"   ‚ö†Ô∏è  Total Skipped: {total_stats['skipped']}")
    print(f"   üìä Total Processed: {total_stats['total']}")
    print(f"   üì¶ Registration Log: {pipeline_cfg['tables']['registration_log']}")
    print("="*80)


if __name__ == "__main__":
    main()