# In [None]:
# =============================================================================
# 🧪 UAT MODEL INFERENCE - NEW WORKFLOW (CONFIG-DRIVEN)
# =============================================================================
# Purpose: Validate staging model performance on UAT data
# Compatible with: train.py → model_registration.py → uat_staging.py → THIS SCRIPT
# Prerequisites: Run uat_staging.py first to promote model to @Staging
# =============================================================================

import json
import math
import sys
import traceback
import warnings
from datetime import datetime
from typing import Any, Dict, Optional, Tuple

import mlflow
import numpy as np
import pandas as pd
import requests
import yaml
from mlflow.tracking import MlflowClient
from pyspark.ml.linalg import VectorUDT, Vectors
from pyspark.sql import SparkSession
from sklearn.metrics import (
    accuracy_score, 
    precision_score, 
    recall_score, 
    f1_score,
    roc_auc_score,
    confusion_matrix
)

warnings.filterwarnings("ignore")

print("=" * 80)
print("üß™ UAT MODEL INFERENCE (NEW WORKFLOW)")
print("=" * 80)

# Load the pipeline configuration into the module-level `pipeline_cfg`.
# Later code reads its settings from this mapping, so any failure here
# terminates the script with exit code 1.

print("\nüìã Step 1: Loading configuration from pipeline_config.yml...")

try:
    # Explicit encoding so parsing does not depend on the platform locale.
    with open("pipeline_config.yml", "r", encoding="utf-8") as f:
        pipeline_cfg = yaml.safe_load(f)

    # An empty or non-mapping document would otherwise only fail later, with
    # a cryptic TypeError on the first `pipeline_cfg[...]` lookup.
    if not isinstance(pipeline_cfg, dict):
        raise ValueError(
            "pipeline_config.yml must contain a top-level mapping, "
            f"got {type(pipeline_cfg).__name__}"
        )

    print(f"‚úÖ Configuration loaded successfully!")

except FileNotFoundError:
    print("‚ùå ERROR: pipeline_config.yml not found!")
    sys.exit(1)
except Exception as e:
    # Covers YAML syntax errors and the mapping check above.
    print(f"‚ùå ERROR loading configuration: {e}")
    traceback.print_exc()
    sys.exit(1)
 
# ✅ CONFIGURATION CLASS

class Config:
    """Settings holder populated from the module-level ``pipeline_cfg`` mapping.

    Instantiating the class reads every required key (a missing key raises
    ``KeyError``), resolves the Slack webhook, and prints a short summary.
    """

    def __init__(self):
        # Model identity: <catalog>.<schema>.<base_name>_<type>
        model_section = pipeline_cfg["model"]
        model_type = model_section["type"]
        catalog = model_section["catalog"]
        schema = model_section["schema"]
        base_name = model_section["base_name"]

        self.MODEL_NAME = f"{catalog}.{schema}.{base_name}_{model_type}"
        self.MODEL_TYPE = model_type

        # Registry aliases
        alias_section = pipeline_cfg["aliases"]
        self.STAGING_ALIAS = alias_section["staging"]
        self.PRODUCTION_ALIAS = alias_section["production"]

        # UAT input is the preprocessed table; the label column name is fixed
        self.UAT_INPUT_TABLE = pipeline_cfg["data"]["preprocessed_table"]
        self.LABEL_COL = "label"

        # Classification metric settings
        classification = pipeline_cfg["metrics"]["classification"]
        self.PRIMARY_METRIC = classification["primary_metric"]
        self.DIRECTION = classification["direction"]
        self.TRACKED_METRICS = classification["tracked_metrics"]

        # Pass/fail thresholds and the table that receives UAT results
        self.UAT_THRESHOLDS = pipeline_cfg["uat"]["classification_thresholds"]
        self.UAT_RESULTS_TABLE = pipeline_cfg["tables"]["uat_results"]

        # Slack: the webhook lookup reads SLACK_ENABLED, so set that first
        self.SLACK_ENABLED = pipeline_cfg["notifications"]["enabled"]
        self.SLACK_WEBHOOK_URL = self._get_slack_webhook()

        slack_state = 'ENABLED' if self.SLACK_WEBHOOK_URL else 'DISABLED'
        print("\nüìä Configuration Summary:")
        print(f"   Model Type: {self.MODEL_TYPE.upper()}")
        print(f"   Model Name: {self.MODEL_NAME}")
        print(f"   Staging Alias: @{self.STAGING_ALIAS}")
        print(f"   UAT Input: {self.UAT_INPUT_TABLE}")
        print(f"   Primary Metric: {self.PRIMARY_METRIC}")
        print(f"   Slack: {slack_state}")

    def _get_slack_webhook(self) -> Optional[str]:
        """Return the first non-blank Slack webhook found in the secret scopes.

        Returns ``None`` when notifications are disabled, when no scope yields
        a usable value, or when the lookup itself fails. A scope whose lookup
        raises is skipped rather than treated as fatal.
        """
        if not self.SLACK_ENABLED:
            return None

        try:
            for scope_name in ("shared-scope", "dev-scope", "prod-scope", "ml-scope"):
                try:
                    candidate = dbutils.secrets.get(scope_name, "SLACK_WEBHOOK_URL")
                    if candidate and candidate.strip():
                        print(f"   ‚úÖ Slack webhook found in scope '{scope_name}'")
                        return candidate
                except Exception:
                    # Scope missing/inaccessible (or dbutils undefined): try the next one
                    continue

            print("   ‚ÑπÔ∏è  No Slack webhook found in secrets")
            return None

        except Exception as lookup_error:
            print(f"   ‚ö†Ô∏è  Could not access secrets: {lookup_error}")
            return None

# Build the shared configuration object once at import time; the helper
# classes and pipeline step functions in this file read their settings from it.
config = Config()

print("=" * 80)
 
# ✅ SLACK NOTIFICATION HELPER

class SlackNotifier:
    """Posts formatted status messages to a Slack webhook when one is configured."""

    def __init__(self, webhook_url: Optional[str]):
        self.webhook_url = webhook_url
        # None, empty, and whitespace-only URLs all disable the notifier
        if webhook_url is None:
            self.enabled = False
        else:
            self.enabled = webhook_url.strip() != ""

    def send(self, message: str, level: str = "info", extra_fields: Optional[Dict] = None) -> bool:
        """Post ``message`` to Slack and report whether it was delivered.

        Args:
            message: Headline text, rendered in bold.
            level: Key selecting the leading icon; unknown levels use the info icon.
            extra_fields: Optional mapping rendered as one bullet per entry.

        Returns:
            True only when the webhook answers with HTTP 200. False when the
            notifier is disabled, the response has another status code, or the
            request raises. Failures are printed, never raised.
        """
        if not self.enabled:
            print(f"üì¢ [SLACK DISABLED] {message}")
            return False

        level_icons = {
            "info": "‚ÑπÔ∏è",
            "success": "‚úÖ",
            "warning": "‚ö†Ô∏è",
            "error": "‚ùå",
            "test": "üß™"
        }
        icon = level_icons.get(level, '‚ÑπÔ∏è')
        text = f"{icon} *{message}*"

        if extra_fields:
            # Blank line after the headline, then one bullet per field
            bullets = "".join(
                f"\n‚Ä¢ *{field}:* {field_value}"
                for field, field_value in extra_fields.items()
            )
            text = text + "\n" + bullets

        payload = {
            "text": text,
            "username": "ML Pipeline Bot",
            "icon_emoji": ":test_tube:"
        }

        try:
            response = requests.post(self.webhook_url, json=payload, timeout=5)

            if response.status_code != 200:
                print(f"‚ö†Ô∏è  Slack error: {response.status_code}")
                return False

            print("üì¢ Slack notification sent successfully")
            return True

        except Exception as send_error:
            print(f"‚ùå Slack notification failed: {send_error}")
            return False

# Notifier used by every later step; it only prints when no webhook is configured
slack = SlackNotifier(config.SLACK_WEBHOOK_URL)

# Create the Spark session and MLflow client shared by the pipeline steps.
# Any failure here ends the script with exit code 1.

print("\nüîß Step 2: Initializing MLflow and Spark...")

try:
    session_builder = SparkSession.builder.appName("UAT_Inference")
    spark = session_builder.getOrCreate()

    mlflow.set_tracking_uri("databricks")
    mlflow.set_registry_uri("databricks-uc")
    client = MlflowClient()

    print("‚úÖ MLflow and Spark initialized successfully")

except Exception as init_error:
    print(f"‚ùå Failed to initialize: {init_error}")
    sys.exit(1)

# Announce the run before any model or data work starts
startup_details = {
    "Model": config.MODEL_NAME,
    "Model Type": config.MODEL_TYPE.upper(),
    "Alias": f"@{config.STAGING_ALIAS}"
}
slack.send("UAT Inference Pipeline Started", level="test", extra_fields=startup_details)

# =============================================================================
# 🔧 HELPER FUNCTION: CONVERT SPARK VECTOR TO NUMPY
# =============================================================================
def vector_to_array(v):
    """Return ``v`` as a numpy array, passing ``None`` through unchanged.

    Objects exposing ``toArray()`` are converted with that method; anything
    else is handed to ``np.array``.
    """
    if v is None:
        return None
    if hasattr(v, 'toArray'):
        return v.toArray()
    return np.array(v)
 
# 📋 STEP 1: LOAD MODEL FROM STAGING

def load_staging_model() -> Tuple[Any, int, str]:
    """Load the model registered under the staging alias.

    Resolves ``config.MODEL_NAME`` at ``config.STAGING_ALIAS`` through the
    MLflow client, prints the version details and any ``metric_*`` tags, then
    loads the model with ``mlflow.pyfunc``.

    Returns:
        Tuple of (loaded pyfunc model, model version as int, source run id).

    Raises:
        Exception: If the alias cannot be resolved or the model fails to load.
            The error is printed with its traceback before being re-raised.
    """
    print(f"\n{'='*70}")
    print(f"üìã STEP 1: Loading Model from @{config.STAGING_ALIAS}")
    print(f"{'='*70}")
    
    try:
        print(f"üîç Looking for model: {config.MODEL_NAME}@{config.STAGING_ALIAS}")
        
        # Resolve the alias to a concrete model version
        try:
            model_version = client.get_model_version_by_alias(
                config.MODEL_NAME, 
                config.STAGING_ALIAS
            )
        except Exception as e:
            print(f"‚ùå No model found with alias @{config.STAGING_ALIAS}")
            print(f"üí° Please run uat_staging.py first to promote a model")
            # Chain the original error so the registry failure stays in the traceback
            raise Exception(f"Model not found: {e}") from e
        
        version = int(model_version.version)
        run_id = model_version.run_id
        status = model_version.status
        
        print(f"‚úÖ Found model:")
        print(f"   Version: v{version}")
        print(f"   Run ID: {run_id}")
        print(f"   Status: {status}")
        
        # Training metrics are read from version tags prefixed with "metric_"
        tags = model_version.tags
        print(f"\nüìä Training Metrics (from tags):")
        for key, value in tags.items():
            if key.startswith("metric_"):
                metric_name = key.replace("metric_", "")
                print(f"   {metric_name}: {value}")
        
        # Load through the alias URI so the staged version is what gets tested
        model_uri = f"models:/{config.MODEL_NAME}@{config.STAGING_ALIAS}"
        print(f"\n‚è≥ Loading model...")
        model = mlflow.pyfunc.load_model(model_uri)
        
        print(f"\n‚úÖ Model loaded successfully")
        
        return model, version, run_id
        
    except Exception as e:
        print(f"\n‚ùå Failed to load model: {e}")
        traceback.print_exc()
        raise
 
# 📋 STEP 2: LOAD UAT DATA

def load_uat_data() -> Tuple[pd.DataFrame, np.ndarray, np.ndarray]:
    """Load the UAT dataset from the preprocessed Delta table.

    Reads ``config.UAT_INPUT_TABLE`` into pandas, converts each entry of the
    ``features`` column with ``vector_to_array``, and takes the labels from
    ``config.LABEL_COL``.

    Returns:
        Tuple of (full pandas DataFrame, feature matrix, label array).

    Raises:
        ValueError: If the table has no rows, or lacks the ``features`` or
            label column.
        Exception: Any read or conversion failure. It is printed with its
            traceback and re-raised.
    """
    print(f"\n{'='*70}")
    print("üìã STEP 2: Loading UAT Data (Preprocessed)")
    print(f"{'='*70}")
    
    try:
        print(f"üìä Loading from: {config.UAT_INPUT_TABLE}")
        
        # Read from Delta
        df_spark = spark.read.format("delta").table(config.UAT_INPUT_TABLE)
        
        # Show the schema to help diagnose column/type mismatches
        print(f"\nüìä Schema:")
        df_spark.printSchema()
        
        # Collect to the driver as pandas
        df = df_spark.toPandas()
        
        # Fail early with a clear message instead of a KeyError here or a
        # division by zero in the later inference step.
        missing_cols = [
            col for col in ("features", config.LABEL_COL) if col not in df.columns
        ]
        if missing_cols:
            raise ValueError(
                f"UAT input table {config.UAT_INPUT_TABLE} is missing "
                f"required column(s): {missing_cols}"
            )
        if df.empty:
            raise ValueError(
                f"UAT input table {config.UAT_INPUT_TABLE} contains no rows"
            )
        
        print(f"\n‚úÖ Data loaded successfully!")
        print(f"   Total rows: {len(df):,}")
        print(f"   Columns: {list(df.columns)}")
        
        print(f"\nüîÑ Converting PySpark Vectors to numpy arrays...")
        
        # One converted row per 'features' entry, stacked into a single array
        X = np.array([vector_to_array(row) for row in df['features']])
        y_true = df[config.LABEL_COL].values
        
        print(f"‚úÖ Conversion complete!")
        print(f"   Features shape: {X.shape}")
        print(f"   Labels shape: {y_true.shape}")
        
        # Report class balance; label 1 is displayed as "Default"
        unique_labels, counts = np.unique(y_true, return_counts=True)
        print(f"\nüìä Label distribution:")
        for label, count in zip(unique_labels, counts):
            label_name = "Default" if label == 1 else "No Default"
            print(f"      Class {int(label)} ({label_name}): {count:,} ({count / len(y_true) * 100:.1f}%)")
        
        return df, X, y_true
        
    except Exception as e:
        print(f"\n‚ùå Failed to load data: {e}")
        print(f"üí° Verify table exists: {config.UAT_INPUT_TABLE}")
        traceback.print_exc()
        raise
 
# 📋 STEP 3: RUN INFERENCE

def run_inference(model: Any, X: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Generate predictions and positive-class scores for ``X``.

    Args:
        model: Object exposing ``predict(X)``. If it has a ``_model_impl``
            attribute with ``predict_proba``, column 1 of that output is used
            as the score.
        X: Feature matrix, one row per sample.

    Returns:
        Tuple of (predictions, scores), both numpy arrays. When no
        probability method is available, or calling it fails, the scores
        are the predictions cast to float.

    Raises:
        Exception: Any failure from ``model.predict`` or the float cast. It
            is printed with its traceback and re-raised.
    """
    import time

    print(f"\n{'='*70}")
    print("üìã STEP 3: Running Inference")
    print(f"{'='*70}")
    
    try:
        n_samples = len(X)
        print(f"‚è≥ Generating predictions for {n_samples:,} samples...")
        
        # perf_counter is monotonic, so the elapsed time cannot go negative
        start_time = time.perf_counter()
        
        # Coerce to ndarray: predict() may return a list or pandas object,
        # and the fallbacks below rely on ndarray.astype
        y_pred = np.asarray(model.predict(X))
        
        # Scores for ROC-AUC: use predict_proba when the wrapped model has it
        try:
            if hasattr(model, '_model_impl') and hasattr(model._model_impl, 'predict_proba'):
                y_pred_proba = model._model_impl.predict_proba(X)[:, 1]
            else:
                # Fallback: use the hard predictions as scores
                print("   ‚ÑπÔ∏è  predict_proba not available, using binary predictions")
                y_pred_proba = y_pred.astype(float)
        except Exception as e:
            print(f"   ‚ö†Ô∏è  Could not get probabilities: {e}")
            y_pred_proba = y_pred.astype(float)
        
        inference_time = time.perf_counter() - start_time
        # Guard the per-sample average against an empty input
        inference_time_ms = (inference_time / n_samples) * 1000 if n_samples else 0.0
        
        print(f"\n‚úÖ Inference complete")
        print(f"   Predictions generated: {len(y_pred):,}")
        print(f"   Total time: {inference_time:.2f}s")
        print(f"   Time per sample: {inference_time_ms:.3f}ms")
        print(f"   Prediction distribution:\n{pd.Series(y_pred).value_counts()}")
        
        return y_pred, y_pred_proba
        
    except Exception as e:
        print(f"\n‚ùå Inference failed: {e}")
        traceback.print_exc()
        raise
 
# 📋 STEP 4: CALCULATE METRICS

def calculate_metrics(y_true: np.ndarray, y_pred: np.ndarray, y_pred_proba: np.ndarray) -> Dict:
    """Compute binary-classification metrics for the UAT run.

    Args:
        y_true: Ground-truth labels.
        y_pred: Predicted labels.
        y_pred_proba: Scores passed to ``roc_auc_score``.

    Returns:
        Dict with ``accuracy``, ``precision``, ``recall``, ``f1``, ``roc_auc``
        (``None`` when it cannot be computed) and ``confusion_matrix`` as a
        nested list.

    Raises:
        Exception: Any metric failure other than ROC-AUC. It is printed with
            its traceback and re-raised.
    """
    print(f"\n{'='*70}")
    print("üìã STEP 4: Calculating Metrics")
    print(f"{'='*70}")
    
    try:
        # Classification metrics
        metrics = {
            'accuracy': accuracy_score(y_true, y_pred),
            'precision': precision_score(y_true, y_pred, average='binary', zero_division=0),
            'recall': recall_score(y_true, y_pred, average='binary', zero_division=0),
            'f1': f1_score(y_true, y_pred, average='binary', zero_division=0)
        }
        
        # ROC-AUC is best-effort; record None when it cannot be computed
        try:
            metrics['roc_auc'] = roc_auc_score(y_true, y_pred_proba)
        except Exception as e:
            print(f"   ‚ö†Ô∏è  Could not calculate ROC-AUC: {e}")
            metrics['roc_auc'] = None
        
        print(f"üìä Classification Metrics:")
        print(f"   Accuracy:  {metrics['accuracy']:.4f}")
        print(f"   Precision: {metrics['precision']:.4f}")
        print(f"   Recall:    {metrics['recall']:.4f}")
        print(f"   F1 Score:  {metrics['f1']:.4f}")
        if metrics['roc_auc'] is not None:
            print(f"   ROC-AUC:   {metrics['roc_auc']:.4f}")
        
        # Pin the label order so the matrix is always 2x2. Without `labels`,
        # data containing a single class yields a 1x1 matrix and the cell
        # lookups below raise IndexError.
        # NOTE(review): assumes labels are encoded as 0/1 - confirm.
        cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
        print(f"\nüìä Confusion Matrix:")
        print(f"   [[TN={cm[0,0]:<6} FP={cm[0,1]:<6}]")
        print(f"    [FN={cm[1,0]:<6} TP={cm[1,1]:<6}]]")
        
        metrics['confusion_matrix'] = cm.tolist()
        
        return metrics
        
    except Exception as e:
        print(f"\n‚ùå Metric calculation failed: {e}")
        traceback.print_exc()
        raise
 
# 📋 STEP 5: VALIDATE UAT

def validate_uat(metrics: Dict, model_version: int) -> Tuple[str, list]:
    """Compare computed metrics against ``config.UAT_THRESHOLDS``.

    Thresholds named ``min_*`` require ``actual >= threshold``; ``max_*``
    require ``actual <= threshold``; any other name is treated as a minimum.
    A threshold with no matching metric value is not evaluated. Such
    thresholds are listed in the output but do not affect the result.

    Args:
        metrics: Metric name to value, as produced by the metrics step.
        model_version: Version number used in messages and notifications.

    Returns:
        ``("PASSED", [])`` when no evaluated check fails, otherwise
        ``("FAILED", failed_checks)``. Each failed check is a dict with
        ``metric``, ``actual``, ``threshold`` and ``operator``.
    """
    print(f"\n{'='*70}")
    print("üìã STEP 5: UAT Validation")
    print(f"{'='*70}")
    
    print(f"\nüìè Validation Thresholds:")
    for metric_name, threshold_value in config.UAT_THRESHOLDS.items():
        if not metric_name.startswith("max_"):
            print(f"   {metric_name}: ‚â• {threshold_value}")
        else:
            print(f"   {metric_name}: ‚â§ {threshold_value}")
    
    print(f"\nüìä Actual Performance:")
    
    passed_checks = []
    failed_checks = []
    skipped_checks = []
    
    # Threshold names in the config -> keys of the metrics dict
    metric_mapping = {
        'min_accuracy': 'accuracy',
        'min_f1': 'f1',
        'min_roc_auc': 'roc_auc',
        'min_recall': 'recall',
        'max_inference_time_ms': 'inference_time_ms'  # Would need to track separately
    }
    
    for threshold_name, threshold_value in config.UAT_THRESHOLDS.items():
        metric_name = metric_mapping.get(threshold_name)
        
        # No mapping, metric absent, or metric None: cannot be evaluated
        if metric_name is None or metrics.get(metric_name) is None:
            skipped_checks.append(threshold_name)
            continue
        
        actual_value = metrics[metric_name]
        
        # Only "max_" thresholds are upper bounds; everything else is a minimum
        if threshold_name.startswith("max_"):
            passed = actual_value <= threshold_value
            operator = "<="
        else:
            passed = actual_value >= threshold_value
            operator = ">="
        
        status = "‚úÖ PASS" if passed else "‚ùå FAIL"
        print(f"   {threshold_name}: {actual_value:.4f} {operator} {threshold_value} {status}")
        
        if passed:
            passed_checks.append(threshold_name)
        else:
            failed_checks.append({
                'metric': threshold_name,
                'actual': actual_value,
                'threshold': threshold_value,
                'operator': operator
            })
    
    # Surface unevaluated thresholds so a PASS is not mistaken for full coverage
    if skipped_checks:
        skipped_names = ", ".join(str(name) for name in skipped_checks)
        print(f"\n   ‚ö†Ô∏è  Thresholds not evaluated (no metric available): {skipped_names}")
    
    # Determine overall status
    if not failed_checks:
        print(f"\n{'='*70}")
        print("‚úÖ‚úÖ UAT PASSED ‚úÖ‚úÖ")
        print(f"{'='*70}")
        print(f"   Model v{model_version} is ready for production!")
        
        slack.send(
            "UAT Validation PASSED",
            level="success",
            extra_fields={
                "Model": config.MODEL_NAME,
                "Version": f"v{model_version}",
                "Accuracy": f"{metrics.get('accuracy', 0):.4f}",
                "F1 Score": f"{metrics.get('f1', 0):.4f}",
                "Status": "Ready for Production"
            }
        )
        
        return "PASSED", []
    else:
        print(f"\n{'='*70}")
        print("‚ùå‚ùå UAT FAILED ‚ùå‚ùå")
        print(f"{'='*70}")
        
        print(f"\n   Failed checks ({len(failed_checks)}):")
        for check in failed_checks:
            # The operator shown is the requirement that was not met
            print(f"   ‚Ä¢ {check['metric']}: {check['actual']:.4f} "
                  f"{check['operator']} {check['threshold']}")
        
        slack.send(
            "UAT Validation FAILED",
            level="error",
            extra_fields={
                "Model": config.MODEL_NAME,
                "Version": f"v{model_version}",
                "Failed Checks": len(failed_checks),
                "Status": "Needs Retraining"
            }
        )
        
        return "FAILED", failed_checks
 
# 📋 STEP 6: LOG RESULTS

def log_results(
    model_version: int,
    run_id: str,
    metrics: Dict,
    status: str,
    failed_checks: list
) -> None:
    """Append one UAT result row to ``config.UAT_RESULTS_TABLE``.

    Logging is best-effort: any failure is printed with its traceback and
    swallowed so it cannot fail the pipeline.

    Args:
        model_version: Version of the evaluated model (stored as a string).
        run_id: Run id associated with the model version.
        metrics: Metric name to value; ``confusion_matrix`` is stored separately.
        status: Overall UAT status string.
        failed_checks: Failed-check dicts; stored as JSON, or NULL when empty.
    """
    print(f"\n{'='*70}")
    print("üìã STEP 6: Logging Results")
    print(f"{'='*70}")
    
    try:
        # Local import: only this step needs the Spark type classes
        from pyspark.sql.types import (
            DoubleType,
            LongType,
            StringType,
            StructField,
            StructType,
            TimestampType,
        )
        
        # Check if table exists
        table_exists = False
        try:
            spark.table(config.UAT_RESULTS_TABLE)
            table_exists = True
            print(f"   Table exists: Yes")
        except Exception:
            print(f"   Table exists: No (will be created)")
        
        # Test against None, not truthiness: an ROC-AUC of 0.0 is a real value
        roc_auc = metrics.get('roc_auc')
        
        result_data = {
            "timestamp": datetime.now(),
            "model_name": config.MODEL_NAME,
            "model_type": config.MODEL_TYPE,
            "model_version": str(model_version),
            "run_id": run_id,
            "uat_status": status,
            
            # Individual metrics
            "accuracy": float(metrics.get('accuracy', 0)),
            "precision": float(metrics.get('precision', 0)),
            "recall": float(metrics.get('recall', 0)),
            "f1": float(metrics.get('f1', 0)),
            "roc_auc": float(roc_auc) if roc_auc is not None else None,
            
            # Metadata
            "all_metrics_json": json.dumps({k: float(v) if v is not None else None 
                                           for k, v in metrics.items() 
                                           if k != 'confusion_matrix'}),
            "thresholds_json": json.dumps(config.UAT_THRESHOLDS),
            "failed_checks_json": json.dumps(failed_checks) if failed_checks else None,
            "num_failed_checks": len(failed_checks),
            
            # Confusion matrix
            "confusion_matrix_json": json.dumps(metrics.get('confusion_matrix', []))
        }
        
        # Explicit schema: roc_auc and failed_checks_json may be None, and type
        # inference on a single row cannot determine a type for such columns.
        column_types = {
            "timestamp": TimestampType(),
            "model_name": StringType(),
            "model_type": StringType(),
            "model_version": StringType(),
            "run_id": StringType(),
            "uat_status": StringType(),
            "accuracy": DoubleType(),
            "precision": DoubleType(),
            "recall": DoubleType(),
            "f1": DoubleType(),
            "roc_auc": DoubleType(),
            "all_metrics_json": StringType(),
            "thresholds_json": StringType(),
            "failed_checks_json": StringType(),
            "num_failed_checks": LongType(),
            "confusion_matrix_json": StringType(),
        }
        result_schema = StructType(
            [StructField(name, dtype, True) for name, dtype in column_types.items()]
        )
        result_row = tuple(result_data[name] for name in column_types)
        spark_df = spark.createDataFrame([result_row], schema=result_schema)
        
        # Append; allow schema evolution only when the table already exists
        if table_exists:
            spark_df.write.mode("append").option("mergeSchema", "true").saveAsTable(
                config.UAT_RESULTS_TABLE
            )
        else:
            spark_df.write.mode("append").saveAsTable(config.UAT_RESULTS_TABLE)
        
        print(f"\n‚úÖ Results logged successfully")
        print(f"   Output Table: {config.UAT_RESULTS_TABLE}")
        print(f"   Model: {config.MODEL_NAME}")
        print(f"   Version: v{model_version}")
        print(f"   Status: {status}")
        
    except Exception as e:
        # Deliberately non-fatal: the UAT verdict stands even if logging fails
        print(f"\n‚ö†Ô∏è  Failed to log results: {e}")
        traceback.print_exc()
 
# 🎬 MAIN EXECUTION

def main():
    """Run the UAT pipeline end to end.

    Steps: load the staged model, load the UAT data, run inference, compute
    metrics, validate against thresholds, and log the results. Afterwards a
    summary is printed, task values are published when running inside a
    workflow, and a final Slack notification is sent.

    A FAILED validation is a normal outcome: it is reported through the
    status value, not through the exit code. Any exception raised by a step
    is reported and ends the process with exit code 1.
    """
    try:
        print("\n" + "="*80)
        print("üé¨ STARTING UAT INFERENCE PIPELINE")
        print("="*80 + "\n")
        
        # Step 1: Load model
        model, model_version, run_id = load_staging_model()
        
        # Step 2: Load UAT data (the raw DataFrame is not needed here)
        _, X, y_true = load_uat_data()
        
        # Step 3: Run inference
        y_pred, y_pred_proba = run_inference(model, X)
        
        # Step 4: Calculate metrics
        metrics = calculate_metrics(y_true, y_pred, y_pred_proba)
        
        # Step 5: Validate UAT
        status, failed_checks = validate_uat(metrics, model_version)
        
        # Step 6: Log results
        log_results(model_version, run_id, metrics, status, failed_checks)
        
        # Final summary
        print("\n" + "="*80)
        print("‚ú® UAT INFERENCE COMPLETED SUCCESSFULLY ‚ú®")
        print("="*80)
        print(f"\nüìä Final Summary:")
        print(f"   Model: {config.MODEL_NAME}")
        print(f"   Model Type: {config.MODEL_TYPE.upper()}")
        print(f"   Version: v{model_version}")
        print(f"   UAT Status: {status}")
        print(f"\n   Key Metrics:")
        print(f"     ‚Ä¢ Accuracy:  {metrics['accuracy']:.4f}")
        print(f"     ‚Ä¢ Precision: {metrics['precision']:.4f}")
        print(f"     ‚Ä¢ Recall:    {metrics['recall']:.4f}")
        print(f"     ‚Ä¢ F1 Score:  {metrics['f1']:.4f}")
        # Test against None so a legitimate ROC-AUC of 0.0 is still printed
        if metrics.get('roc_auc') is not None:
            print(f"     ‚Ä¢ ROC-AUC:   {metrics['roc_auc']:.4f}")
        
        if status == "PASSED":
            print(f"\nüìå Next Step:")
            print(f"   ‚úÖ Model is ready for production promotion")
            print(f"   Run production_promotion.py to deploy")
        else:
            print(f"\nüìå Next Step:")
            print(f"   ‚ùå Model needs improvement")
            print(f"   Failed checks: {len(failed_checks)}")
            print(f"   Review metrics and retrain with better hyperparameters")
        
        print("="*80 + "\n")
        
        # Publish task values for downstream workflow tasks (best-effort)
        try:
            dbutils.jobs.taskValues.set(key="uat_status", value=status)
            dbutils.jobs.taskValues.set(key="model_version", value=model_version)
            dbutils.jobs.taskValues.set(key="accuracy", value=float(metrics['accuracy']))
            dbutils.jobs.taskValues.set(key="f1_score", value=float(metrics['f1']))
            dbutils.jobs.taskValues.set(key="num_failed_checks", value=len(failed_checks))
            print("‚úÖ Task values saved for workflow")
        except Exception:
            # `except Exception`, not a bare except, so KeyboardInterrupt and
            # SystemExit are not swallowed
            print("‚ÑπÔ∏è  Not running in workflow - skipping task values")
        
        # Send final summary notification
        slack.send(
            "UAT Pipeline Completed",
            level="success" if status == "PASSED" else "warning",
            extra_fields={
                "Model": config.MODEL_NAME,
                "Version": f"v{model_version}",
                "Status": status,
                "Accuracy": f"{metrics['accuracy']:.4f}",
                "F1": f"{metrics['f1']:.4f}",
                "Next Step": "Production" if status == "PASSED" else "Retrain"
            }
        )
        
    except Exception as e:
        print("\n" + "="*80)
        print("‚ùå UAT INFERENCE FAILED")
        print("="*80)
        print(f"Error: {str(e)}")
        print("="*80 + "\n")
        
        slack.send(
            "UAT Pipeline Failed",
            level="error",
            extra_fields={
                "Model": config.MODEL_NAME,
                "Error": str(e)
            }
        )
        
        traceback.print_exc()
        sys.exit(1)

# Run the pipeline only when executed as a script, not when imported

if __name__ == "__main__":
    main()