In [None]:
# 🎯 MODEL REGISTRATION SCRIPT - NEW WORKFLOW (FIXED)
# Purpose: Register best models with duplicate detection
# Compatible with: train.py → evaluation.py → THIS SCRIPT
# Config-driven from pipeline_config.yml


import mlflow
from mlflow.tracking import MlflowClient
import sys
import yaml
import json
import traceback
import requests
from typing import Dict, Optional, List
from datetime import datetime
from pyspark.sql import SparkSession
import pandas as pd

print("=" * 80)
print("üéØ MODEL REGISTRATION SYSTEM (NEW WORKFLOW)")
print("=" * 80)

# Load the pipeline configuration. The rest of the script reads its settings
# from the `pipeline_cfg` mapping produced here; any failure aborts the run.

print("\nüìã Step 1: Loading configuration from pipeline_config.yml...")

try:
    # Explicit UTF-8: without it the file is decoded with the platform's
    # locale encoding, which can corrupt non-ASCII values in the YAML.
    with open("pipeline_config.yml", "r", encoding="utf-8") as f:
        pipeline_cfg = yaml.safe_load(f)

    # safe_load returns None for an empty file and a list/scalar for a
    # non-mapping document. Reject those here so the failure names the real
    # cause instead of surfacing later as "'NoneType' is not subscriptable".
    if not isinstance(pipeline_cfg, dict):
        raise ValueError(
            "pipeline_config.yml must contain a top-level mapping, "
            f"got {type(pipeline_cfg).__name__}"
        )

    print("‚úÖ Configuration loaded successfully!")

except FileNotFoundError:
    print("‚ùå ERROR: pipeline_config.yml not found!")
    print("üí° Please ensure pipeline_config.yml is in the notebook directory")
    sys.exit(1)
except Exception as e:
    # Covers YAML syntax errors, permission problems and the shape check above.
    print(f"‚ùå ERROR loading configuration: {e}")
    traceback.print_exc()
    sys.exit(1)
 
# ✅ EXTRACT CONFIGURATION VALUES

class Config:
    """Settings holder populated from the module-level ``pipeline_cfg`` mapping.

    Each setting is copied onto the instance as an upper-case attribute when
    the object is built, so a missing key raises ``KeyError`` right away.
    Construction also prints a short summary of the resolved values.
    """

    def __init__(self):
        # Model identity: the registry name is <catalog>.<schema>.<base>_<type>.
        model_section = pipeline_cfg["model"]
        model_type = model_section["type"]
        catalog = model_section["catalog"]
        schema = model_section["schema"]
        base_name = model_section["base_name"]

        self.MODEL_NAME = f"{catalog}.{schema}.{base_name}_{model_type}"
        self.MODEL_TYPE = model_type

        # Experiment tracking.
        experiment_section = pipeline_cfg["experiment"]
        self.EXPERIMENT_NAME = experiment_section["name"]
        self.ARTIFACT_PATH = experiment_section["artifact_path"]

        # Metrics and duplicate detection.
        self.PRIMARY_METRIC = pipeline_cfg["metrics"]["classification"]["primary_metric"]
        duplicate_section = pipeline_cfg["registry"]["duplicate_detection"]
        self.METRICS_TO_COMPARE = duplicate_section["metrics_to_compare"]
        self.TOLERANCE = duplicate_section["tolerance"]
        self.DUPLICATE_CHECK_ENABLED = duplicate_section["enabled"]

        # Registry settings.
        self.REGISTRY_MODE = pipeline_cfg["registry"]["mode"]

        # Aliases.
        alias_section = pipeline_cfg["aliases"]
        self.STAGING_ALIAS = alias_section["staging"]
        self.PRODUCTION_ALIAS = alias_section["production"]
        self.BEST_ALIAS = alias_section["best"]

        # Tables.
        self.EVALUATION_LOG_TABLE = pipeline_cfg["tables"]["evaluation_log"]

        # Slack: SLACK_ENABLED must be set before the webhook lookup reads it.
        self.SLACK_ENABLED = pipeline_cfg["notifications"]["enabled"]
        self.SLACK_WEBHOOK_URL = self._get_slack_webhook()

        self._print_summary()

    def _print_summary(self) -> None:
        """Print the resolved settings, one per line."""
        duplicate_state = 'ENABLED' if self.DUPLICATE_CHECK_ENABLED else 'DISABLED'
        # Slack counts as enabled only when a webhook URL was actually found.
        slack_state = 'ENABLED' if self.SLACK_WEBHOOK_URL else 'DISABLED'

        print(f"\nüìä Configuration Summary:")
        print(f"   Model Type: {self.MODEL_TYPE.upper()}")
        print(f"   Model Name: {self.MODEL_NAME}")
        print(f"   Experiment: {self.EXPERIMENT_NAME}")
        print(f"   Primary Metric: {self.PRIMARY_METRIC}")
        print(f"   Duplicate Detection: {duplicate_state}")
        print(f"   Tolerance: {self.TOLERANCE}")
        print(f"   Registry Mode: {self.REGISTRY_MODE}")
        print(f"   Slack Notifications: {slack_state}")

    def _get_slack_webhook(self) -> Optional[str]:
        """Look up the Slack webhook URL in a fixed list of secret scopes.

        Returns:
            The first non-blank ``SLACK_WEBHOOK_URL`` secret found, or ``None``
            when notifications are disabled, no scope yields a usable value,
            or the lookup fails.
        """
        if not self.SLACK_ENABLED:
            return None

        try:
            for scope in ("shared-scope", "dev-scope", "prod-scope", "ml-scope"):
                try:
                    candidate = dbutils.secrets.get(scope, "SLACK_WEBHOOK_URL")
                    if not (candidate and candidate.strip()):
                        continue
                    print(f"   ‚úÖ Slack webhook found in scope '{scope}'")
                    return candidate
                except Exception:
                    # Any failure for this scope (including `dbutils` not
                    # being defined) just moves on to the next one.
                    continue

            print("   ‚ÑπÔ∏è  No Slack webhook found in secrets")
        except Exception as e:
            print(f"   ‚ö†Ô∏è  Could not access secrets: {e}")

        return None

# Build the module-level settings object. Constructing Config reads
# `pipeline_cfg`, looks up the Slack webhook when notifications are enabled,
# and prints a configuration summary.
config = Config()

# Closing rule under the configuration summary.
print("=" * 80)
 
# 📢 SLACK NOTIFICATION HELPER

class SlackNotifier:
    """Posts pipeline status messages to a Slack incoming webhook.

    Without a usable webhook URL the notifier is disabled: every message is
    echoed to stdout instead and reported as not delivered.
    """

    def __init__(self, webhook_url: Optional[str]):
        self.webhook_url = webhook_url
        # None, empty and whitespace-only URLs all disable the notifier.
        self.enabled = bool(webhook_url and webhook_url.strip())

    @staticmethod
    def _now_stamp() -> str:
        """Current local time formatted for the "Timestamp" field."""
        return datetime.now().strftime("%Y-%m-%d %H:%M:%S")

    def send(self, message: str, level: str = "info", extra_fields: Optional[Dict] = None) -> bool:
        """Post one message to Slack; never raises on delivery problems.

        Args:
            message: Headline text, rendered in bold.
            level: Selects the leading icon (info, success, warning, error,
                trophy, rocket); unknown levels use the info icon.
            extra_fields: Optional label -> value pairs, appended as a
                bulleted list after a blank line.

        Returns:
            True only when Slack answered with HTTP 200. False when the
            notifier is disabled, Slack returned another status, or the
            request raised.
        """
        if not self.enabled:
            print(f"üì¢ [SLACK DISABLED] {message}")
            return False

        level_icons = {
            "info": "‚ÑπÔ∏è",
            "success": "‚úÖ",
            "warning": "‚ö†Ô∏è",
            "error": "‚ùå",
            "trophy": "üèÜ",
            "rocket": "üöÄ"
        }
        icon = level_icons.get(level, '‚ÑπÔ∏è')

        # Headline first; the empty entry yields the blank line that separates
        # it from the bullet list.
        text_lines = [f"{icon} *{message}*"]
        if extra_fields:
            text_lines.append("")
            text_lines.extend(
                f"‚Ä¢ *{label}:* {content}" for label, content in extra_fields.items()
            )

        body = {
            "text": "\n".join(text_lines),
            "username": "ML Pipeline Bot",
            "icon_emoji": ":robot_face:"
        }

        try:
            reply = requests.post(
                self.webhook_url,
                json=body,
                timeout=5
            )

            if reply.status_code != 200:
                print(f"‚ö†Ô∏è  Slack error: {reply.status_code} - {reply.text}")
                return False

            print(f"üì¢ Slack notification sent successfully")
            return True

        except Exception as e:
            print(f"‚ùå Slack notification failed: {e}")
            return False

    def send_registration_success(self, model_name: str, version: int, metrics: Dict) -> bool:
        """Announce a successful registration, listing every non-None metric."""
        details = {
            "Model": model_name,
            "Version": f"v{version}",
            "Timestamp": self._now_stamp()
        }
        # Metrics are rendered with four decimals; missing ones are left out.
        details.update({
            name: f"{value:.4f}"
            for name, value in metrics.items()
            if value is not None
        })

        return self.send("Model Registration Successful", level="success", extra_fields=details)

    def send_registration_skipped(self, model_name: str, reason: str) -> bool:
        """Announce that registration was skipped, with the reason."""
        details = {
            "Model": model_name,
            "Reason": reason,
            "Timestamp": self._now_stamp()
        }
        return self.send("Model Registration Skipped", level="warning", extra_fields=details)

    def send_error(self, error_message: str, details: Optional[str] = None) -> bool:
        """Announce a pipeline error; ``details`` is added only when non-empty."""
        fields = {
            "Error": error_message,
            "Timestamp": self._now_stamp()
        }
        if details:
            fields["Details"] = details

        return self.send("Registration Pipeline Error", level="error", extra_fields=fields)
 
# ✅ INITIALIZE MLFLOW & SPARK

print("\nüîß Step 2: Initializing MLflow and Spark...")

try:
    spark = SparkSession.builder.appName("ModelRegistration").getOrCreate()
    mlflow.set_tracking_uri("databricks")
    mlflow.set_registry_uri("databricks-uc")
    client = MlflowClient()
    
    # Get experiment
    experiment = mlflow.get_experiment_by_name(config.EXPERIMENT_NAME)
    if experiment is None:
        raise Exception(f"Experiment '{config.EXPERIMENT_NAME}' not found!")
    
    print("‚úÖ MLflow and Spark initialized successfully")
    print(f"   Experiment ID: {experiment.experiment_id}")

except Exception as e:
    print(f"‚ùå Initialization failed: {e}")
    traceback.print_exc()
    sys.exit(1)

# Initialize Slack notifier
slack = SlackNotifier(config.SLACK_WEBHOOK_URL)

# Send startup notification
slack.send(
    "Model Registration Pipeline Started",
    level="info",
    extra_fields={
        "Model Type": config.MODEL_TYPE.upper(),
        "Experiment": config.EXPERIMENT_NAME
    }
)
 
# 📋 STEP 1: GET BEST RUN FROM LATEST TRAINING

def get_best_run_from_training() -> Optional[Dict]:
    """Return details of the experiment run with the highest primary metric.

    Searches all runs of the module-level ``experiment`` (not only the most
    recent training session), ordered by ``config.PRIMARY_METRIC`` descending,
    and describes the top one.

    Returns:
        A dict with the keys ``run_id``, ``run_name``, ``primary_metric``,
        ``all_metrics`` (one entry per name in ``config.METRICS_TO_COMPARE``,
        ``None`` where the run did not log it), ``params``, ``model_uri`` and
        ``timestamp`` (run start converted to a naive local ``datetime``).
        ``None`` when the experiment has no runs, when the top run did not
        log the primary metric, or when the lookup raised.
    """
    print(f"\n{'='*70}")
    print("üìã STEP 1: Finding Best Run from Training")
    print(f"{'='*70}")

    try:
        print(f"üîç Searching for best run by {config.PRIMARY_METRIC}...")

        # NOTE(review): no filter on run status is applied here — confirm that
        # unfinished or failed runs cannot carry the best metric value.
        runs = client.search_runs(
            [experiment.experiment_id],
            order_by=[f"metrics.{config.PRIMARY_METRIC} DESC"],
            max_results=1
        )

        if not runs:
            print("‚ùå No runs found in experiment!")
            return None

        best_run = runs[0]
        logged_metrics = best_run.data.metrics

        # A run that never logged the primary metric cannot be ranked or
        # reported. Say so explicitly instead of failing later while
        # formatting None as a float.
        primary_value = logged_metrics.get(config.PRIMARY_METRIC)
        if primary_value is None:
            print(
                f"‚ùå Run {best_run.info.run_id} has no "
                f"'{config.PRIMARY_METRIC}' metric logged - nothing to register"
            )
            return None

        # Metrics used later for duplicate detection; missing ones stay None.
        all_metrics = {
            metric: logged_metrics.get(metric)
            for metric in config.METRICS_TO_COMPARE
        }

        run_info = {
            "run_id": best_run.info.run_id,
            "run_name": best_run.info.run_name,
            "primary_metric": primary_value,
            "all_metrics": all_metrics,
            "params": best_run.data.params,
            "model_uri": f"runs:/{best_run.info.run_id}/{config.ARTIFACT_PATH}",
            # start_time is divided by 1000 to get the seconds expected by
            # fromtimestamp.
            "timestamp": datetime.fromtimestamp(best_run.info.start_time / 1000)
        }

        print(f"‚úÖ Best run found:")
        print(f"   Run ID: {run_info['run_id']}")
        print(f"   Run Name: {run_info['run_name']}")
        print(f"   {config.PRIMARY_METRIC}: {run_info['primary_metric']:.4f}")
        print(f"   Model URI: {run_info['model_uri']}")

        print(f"\nüìä All Comparison Metrics:")
        for metric_name, metric_value in all_metrics.items():
            if metric_value is not None:
                print(f"   {metric_name}: {metric_value:.4f}")

        return run_info

    except Exception as e:
        print(f"‚ùå Failed to get best run: {e}")
        traceback.print_exc()
        return None
 
# 🔍 STEP 2: CHECK FOR DUPLICATE MODELS

def is_duplicate_model(new_model: Dict) -> bool:
    """Tell whether an already-registered version has the same metrics.

    Each registered version of ``config.MODEL_NAME`` is compared with
    ``new_model['all_metrics']`` on the metrics in
    ``config.METRICS_TO_COMPARE``. Metrics missing on either side are
    ignored. A version is a duplicate when at least one metric could be
    compared and every compared metric differs by less than
    ``config.TOLERANCE``.

    Returns:
        True for the first duplicate found (a Slack warning is sent too).
        False when detection is disabled, the registry lookup fails, no
        versions exist, or no version matches.
    """
    divider = "=" * 70
    print(f"\n{divider}")
    print("üìã STEP 2: Duplicate Detection")
    print(divider)

    if not config.DUPLICATE_CHECK_ENABLED:
        print("‚ÑπÔ∏è  Duplicate detection disabled in config")
        return False

    print(f"üîé Checking for duplicates in registry...")
    print(f"   Model: {config.MODEL_NAME}")

    # A failed lookup is treated as "model not registered yet".
    try:
        registered_versions = list(
            client.search_model_versions(f"name='{config.MODEL_NAME}'")
        )
    except Exception as e:
        print(f"‚ÑπÔ∏è  Model not found in registry (first registration): {e}")
        return False

    if len(registered_versions) == 0:
        print("‚ÑπÔ∏è  No existing versions found (first registration)")
        return False

    print(f"üìä Found {len(registered_versions)} existing version(s), checking metrics...")

    for candidate in registered_versions:
        try:
            source_run = client.get_run(candidate.run_id)

            # One entry per metric present on both sides.
            comparisons = []
            for metric_name in config.METRICS_TO_COMPARE:
                registered_value = source_run.data.metrics.get(metric_name)
                incoming_value = new_model['all_metrics'].get(metric_name)

                if registered_value is None or incoming_value is None:
                    continue

                gap = abs(registered_value - incoming_value)
                comparisons.append({
                    'metric': metric_name,
                    'existing': registered_value,
                    'new': incoming_value,
                    'diff': gap,
                    'matches': gap < config.TOLERANCE
                })

            # Not a duplicate unless something was compared and all of it matched.
            if not comparisons or not all(entry['matches'] for entry in comparisons):
                continue

            print(f"\n‚ö†Ô∏è  DUPLICATE FOUND!")
            print(f"   Version: v{candidate.version}")
            print(f"   Run ID: {candidate.run_id}")
            print(f"\n   Metric Comparison:")
            for entry in comparisons:
                print(f"      {entry['metric']}: diff={entry['diff']:.6f} (tolerance={config.TOLERANCE})")

            slack.send(
                "Duplicate Model Detected",
                level="warning",
                extra_fields={
                    "Model": config.MODEL_NAME,
                    "Existing Version": f"v{candidate.version}",
                    "Reason": "Same metrics within tolerance",
                    "Tolerance": str(config.TOLERANCE)
                }
            )

            return True

        except Exception as e:
            # A version that cannot be inspected is skipped, not fatal.
            print(f"‚ö†Ô∏è  Error checking version {candidate.version}: {e}")
            continue

    print("‚úÖ No duplicate found - model is unique")
    return False
 
# 🚀 STEP 3: REGISTER MODEL TO UNITY CATALOG

def register_model(run_info: Dict) -> Optional[object]:
    """Register the run's model under ``config.MODEL_NAME`` unless it is a duplicate.

    Args:
        run_info: Run description; this function reads ``model_uri``,
            ``run_id`` and ``all_metrics`` from it.

    Returns:
        The object returned by ``mlflow.register_model`` (its ``version``
        attribute is used by this function) on success. ``None`` when the
        model was skipped as a duplicate or when registration raised; the
        two cases are distinguished only by the Slack message and console
        output.
    """
    print(f"\n{'='*70}")
    print("üìã STEP 3: Model Registration")
    print(f"{'='*70}")

    # Duplicate check comes first so that no new version is created for a
    # model whose metrics are already in the registry.
    if is_duplicate_model(run_info):
        print("\n‚ö†Ô∏è  Registration SKIPPED: Duplicate model detected")
        print("   Model with same metrics already exists in registry")

        slack.send_registration_skipped(
            config.MODEL_NAME,
            "Duplicate model - same metrics already registered"
        )

        return None

    print(f"\n‚è≥ Registering model to Unity Catalog...")
    print(f"   Model Name: {config.MODEL_NAME}")
    print(f"   Source URI: {run_info['model_uri']}")

    # Only the registration call is guarded here: a failure after this point
    # (e.g. while notifying) must not be reported as a failed registration,
    # because the version already exists in the registry.
    try:
        new_version = mlflow.register_model(
            run_info['model_uri'],
            config.MODEL_NAME
        )
    except Exception as e:
        print(f"‚ùå Registration failed: {e}")
        traceback.print_exc()

        slack.send_error(
            "Model registration failed",
            details=str(e)
        )

        return None

    print(f"\n‚úÖ MODEL REGISTERED SUCCESSFULLY!")
    print(f"   Model: {config.MODEL_NAME}")
    print(f"   Version: v{new_version.version}")
    print(f"   Run ID: {run_info['run_id']}")

    # Best-effort notification; the registered version is returned regardless.
    try:
        slack.send_registration_success(
            config.MODEL_NAME,
            new_version.version,
            run_info['all_metrics']
        )
    except Exception as e:
        print(f"‚ö†Ô∏è  Success notification failed: {e}")

    return new_version
 
# 🏷️  STEP 4: ADD METADATA TAGS

def add_metadata_tags(version_number: int, run_info: Dict) -> bool:
    """Attach provenance and metric tags to a registered model version.

    Tags are written one at a time so that a single rejected tag does not
    prevent the others from being set.

    Args:
        version_number: Version of ``config.MODEL_NAME`` to tag.
        run_info: Run description; this function reads ``run_id``,
            ``run_name``, ``primary_metric``, ``all_metrics`` and
            ``timestamp`` from it.

    Returns:
        True only when every tag was written. False when at least one tag
        could not be set or when building the tag set raised.
    """
    print(f"\n{'='*70}")
    print("üìã STEP 4: Adding Metadata Tags")
    print(f"{'='*70}")

    try:
        tags = {
            "model_type": config.MODEL_TYPE,
            "registered_from": "new_registration_pipeline",
            "registration_timestamp": datetime.now().isoformat(),
            "source_run_id": run_info['run_id'],
            "source_run_name": run_info['run_name'],
            "primary_metric": config.PRIMARY_METRIC,
            "primary_metric_value": f"{run_info['primary_metric']:.6f}",
            "artifact_path": config.ARTIFACT_PATH,
            "training_timestamp": run_info['timestamp'].isoformat()
        }

        # One "metric_<name>" tag per comparison metric the run actually logged.
        tags.update({
            f"metric_{metric_name}": f"{metric_value:.6f}"
            for metric_name, metric_value in run_info['all_metrics'].items()
            if metric_value is not None
        })

        print(f"   Adding {len(tags)} metadata tags...")

        # Keep going after a failed tag, but remember it so the outcome is
        # reported truthfully instead of always claiming success.
        failed_keys = []
        for key, value in tags.items():
            try:
                client.set_model_version_tag(
                    config.MODEL_NAME,
                    version_number,
                    key,
                    str(value)
                )
            except Exception as e:
                print(f"   ‚ö†Ô∏è  Failed to set tag '{key}': {e}")
                failed_keys.append(key)

        if failed_keys:
            print(
                f"   ‚ö†Ô∏è  {len(failed_keys)} of {len(tags)} tag(s) could not be set: "
                f"{', '.join(failed_keys)}"
            )
            return False

        print(f"   ‚úÖ Metadata tags added successfully")
        return True

    except Exception as e:
        print(f"‚ùå Failed to add tags: {e}")
        traceback.print_exc()
        return False
 
# 📝 STEP 5: LOG REGISTRATION DECISION

def log_registration_decision(run_info: Dict, registered: bool, version: Optional[int], reason: str) -> None:
    """Append one row describing this registration outcome to the log table.

    The row is built as a single-record pandas frame, converted with
    ``spark.createDataFrame`` and appended in Delta format to
    ``config.EVALUATION_LOG_TABLE``. Failures are printed with a traceback
    and otherwise ignored, so logging never stops the pipeline.

    Args:
        run_info: Run description; this function reads ``run_id``,
            ``run_name``, ``primary_metric``, ``all_metrics``, ``params``
            and ``model_uri`` from it.
        registered: Whether a new model version was created.
        version: The new version, if any; falsy values are stored as None.
        reason: Explanation recorded next to the decision.
    """
    divider = "=" * 70
    print(f"\n{divider}")
    print("üìã STEP 5: Logging Decision")
    print(divider)

    try:
        # Keyword order fixes the column order of the resulting frame.
        record = dict(
            timestamp=datetime.now(),
            run_id=run_info["run_id"],
            run_name=run_info["run_name"],
            model_name=config.MODEL_NAME,
            model_type=config.MODEL_TYPE,
            primary_metric=config.PRIMARY_METRIC,
            primary_metric_value=run_info["primary_metric"],
            # Nested structures are stored as JSON text.
            all_metrics_json=json.dumps(run_info["all_metrics"]),
            params_json=json.dumps(run_info["params"]),
            model_uri=run_info["model_uri"],
            registered=registered,
            # NOTE(review): this column is None whenever nothing was
            # registered - confirm the Spark conversion handles an all-null
            # column for the target table.
            registered_version=version or None,
            reason=reason,
        )

        single_row = pd.DataFrame([record])

        print(f"   Logging to: {config.EVALUATION_LOG_TABLE}")
        delta_writer = spark.createDataFrame(single_row).write.format("delta")
        delta_writer.mode("append").saveAsTable(config.EVALUATION_LOG_TABLE)

        print(f"   ‚úÖ Decision logged successfully")

    except Exception as e:
        print(f"‚ö†Ô∏è  Failed to log decision: {e}")
        traceback.print_exc()
 
# 📊 STEP 6: DISPLAY SUMMARY

def display_summary(run_info: Dict, version_number: Optional[int], registered: bool) -> None:
    """Print the closing report for the pipeline run.

    Always shows the source run. When ``registered`` is true and
    ``version_number`` is truthy it then shows the registered version and
    the follow-up steps; in every other case it shows the "skipped" notice.

    Args:
        run_info: Run description; ``run_id``, ``run_name`` and
            ``primary_metric`` are read.
        version_number: Newly created version, if any.
        registered: Whether a new version was created.
    """
    rule = "=" * 80

    print(f"\n{rule}")
    if not registered:
        print("‚ö†Ô∏è‚ö†Ô∏è  MODEL REGISTRATION SKIPPED ‚ö†Ô∏è‚ö†Ô∏è")
    else:
        print("‚úÖ‚úÖ MODEL REGISTRATION COMPLETE ‚úÖ‚úÖ")
    print(rule)

    # Source run details are shown for both outcomes.
    print(f"\nüìä Source Model:")
    print(f"   Model Type: {config.MODEL_TYPE.upper()}")
    print(f"   Run ID: {run_info['run_id']}")
    print(f"   Run Name: {run_info['run_name']}")
    print(f"   {config.PRIMARY_METRIC}: {run_info['primary_metric']:.4f}")

    has_new_version = bool(registered and version_number)
    if not has_new_version:
        print(f"\n‚è≠Ô∏è  Registration skipped (duplicate detected)")
        print(f"   Check existing versions in: {config.MODEL_NAME}")
    else:
        print(f"\nüèÜ Registered Model:")
        print(f"   Registry: {config.MODEL_NAME}")
        print(f"   Version: v{version_number}")

        print(f"\nüìå Next Steps:")
        follow_ups = (
            "Verify model in Unity Catalog",
            "Run UAT/validation tests",
            "Promote to production if tests pass",
        )
        for position, action in enumerate(follow_ups, start=1):
            print(f"   {position}. {action}")

    print(rule)
 
# 🎬 MAIN EXECUTION

def main():
    """Run the registration pipeline end to end.

    Steps: pick the best run, register it unless it is a duplicate, tag the
    new version, log the decision, print the summary, send the closing Slack
    message and publish task values for a Databricks workflow.

    Returns early, without logging or notifying, when no run is found. Any
    unexpected exception is reported to Slack, logged to the decision table
    when a run had already been selected, and ends the process with exit
    status 1.
    """
    registered = False
    version_number = None
    reason = ""
    # Bound before the try block so the failure handler can always inspect it.
    run_info = None

    try:
        # Step 1: Get best run from training
        run_info = get_best_run_from_training()
        if not run_info:
            print("\n‚ùå No run found to register")
            return

        # Step 2 & 3: Check duplicates and register
        new_version = register_model(run_info)

        if new_version:
            registered = True
            version_number = new_version.version
            reason = "Successfully registered (unique model)"

            # Step 4: Add metadata tags
            add_metadata_tags(version_number, run_info)
        else:
            # register_model returns None both for a detected duplicate and
            # for a failed registration, so the recorded reason must not
            # claim one specific cause.
            registered = False
            reason = "Not registered - duplicate detected or registration failed"

        # Step 5: Log decision
        log_registration_decision(run_info, registered, version_number, reason)

        # Step 6: Display summary
        display_summary(run_info, version_number, registered)

        # Send final Slack summary
        if registered and version_number:
            slack.send(
                "Registration Pipeline Completed Successfully",
                level="trophy",
                extra_fields={
                    "Model": config.MODEL_NAME,
                    "Version": f"v{version_number}",
                    "Run Name": run_info['run_name'],
                    f"{config.PRIMARY_METRIC}": f"{run_info['primary_metric']:.4f}",
                    "Status": "Ready for UAT/Production"
                }
            )
        elif not registered:
            slack.send(
                "Registration Pipeline Completed - No Registration",
                level="info",
                extra_fields={
                    "Model": config.MODEL_NAME,
                    "Reason": reason,
                    "Run Name": run_info['run_name']
                }
            )

        # Save task values for workflow. `dbutils` is not defined outside
        # Databricks, which surfaces here as an ordinary exception; catching
        # Exception (not a bare except) lets SystemExit/KeyboardInterrupt
        # propagate.
        try:
            dbutils.jobs.taskValues.set(key="model_type", value=config.MODEL_TYPE)
            dbutils.jobs.taskValues.set(key="model_name", value=config.MODEL_NAME)
            dbutils.jobs.taskValues.set(key="registered", value=registered)
            if version_number:
                dbutils.jobs.taskValues.set(key="model_version", value=version_number)
            print("\n‚úÖ Task values saved for workflow")
        except Exception:
            print("\n‚ÑπÔ∏è  Not running in Databricks workflow - skipping task values")

        print("\nüéâ Registration pipeline completed!")

    except Exception as e:
        print(f"\n‚ùå Registration pipeline failed: {e}")
        traceback.print_exc()

        # Send critical error notification
        slack.send_error(
            "Registration Pipeline Failed",
            details=f"{type(e).__name__}: {str(e)}"
        )

        # Best-effort failure record; only possible once a run was selected.
        if run_info:
            try:
                log_registration_decision(
                    run_info,
                    False,
                    None,
                    f"Registration failed: {str(e)}"
                )
            except Exception as log_error:
                # Must not mask the original failure or prevent the exit below.
                print(f"‚ö†Ô∏è  Could not log failure: {log_error}")

        sys.exit(1)

# Execute: call main() only when this file is run as the entry point
# (__name__ == "__main__").
if __name__ == "__main__":
    main()