In [None]:
# Databricks notebook source
# ==================================================================================
# 🚀 CREDIT RISK CLASSIFICATION TRAINING - CONFIG DRIVEN
# ==================================================================================
# Fully integrated with pipeline_config.yml and config.yml
# Supports Random Forest Classifier with preprocessing pipeline
# Logs all experiments to MLflow and evaluates classification metrics
# ==================================================================================

# COMMAND ----------

%pip install scikit-learn pyyaml

# COMMAND ----------

import mlflow
import yaml
import numpy as np
import pandas as pd
import warnings
import time
from datetime import datetime
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score, 
    roc_auc_score, confusion_matrix, classification_report
)
from mlflow.models.signature import infer_signature
from pyspark.sql import SparkSession

# Suppress every Python warning for the rest of the session (this includes
# warnings emitted by scikit-learn and MLflow).
warnings.filterwarnings("ignore")

print("=" * 80)
print("🚀 CREDIT RISK CLASSIFICATION TRAINING PIPELINE")
print("=" * 80)

# ==================================================================================
# ✅ LOAD PIPELINE CONFIGURATION
# ==================================================================================
# Reads pipeline_config.yml from the current working directory and exposes its
# values as the module-level constants used by the functions below.
print("\n📋 Step 1: Loading pipeline configuration...")

try:
    with open("pipeline_config.yml", "r", encoding="utf-8") as f:
        pipeline_cfg = yaml.safe_load(f)

    # safe_load returns None for an empty file; fail with a clear message
    # instead of "'NoneType' object is not subscriptable".
    if not isinstance(pipeline_cfg, dict):
        raise ValueError("pipeline_config.yml is empty or is not a YAML mapping")

    # Extract configuration values
    MODEL_TYPE = pipeline_cfg["model"]["type"]
    CATALOG = pipeline_cfg["model"]["catalog"]
    SCHEMA = pipeline_cfg["model"]["schema"]
    BASE_NAME = pipeline_cfg["model"]["base_name"]

    EXPERIMENT_NAME = pipeline_cfg["experiment"]["name"]
    MODEL_ARTIFACT_PATH = pipeline_cfg["experiment"]["artifact_path"]
    RUN_NAME_PREFIX = pipeline_cfg["experiment"]["run_name_prefix"]

    # Data configuration
    INPUT_TABLE = pipeline_cfg["data"]["input_table"]
    FEATURE_COLS = pipeline_cfg["data"]["features"]
    LABEL_COL = pipeline_cfg["data"]["label"]
    TEST_SIZE = pipeline_cfg["data"]["split"]["test_size"]
    RANDOM_STATE = pipeline_cfg["data"]["split"]["random_state"]
    STRATIFY = pipeline_cfg["data"]["split"]["stratify"]

    # Metrics configuration
    METRICS_CONFIG = pipeline_cfg["metrics"]["classification"]
    PRIMARY_METRIC = METRICS_CONFIG["primary_metric"]
    DIRECTION = METRICS_CONFIG["direction"]
    TRACKED_METRICS = METRICS_CONFIG["tracked_metrics"]
    # A null `threshold_metrics:` entry in YAML means "no thresholds".
    THRESHOLD_METRICS = METRICS_CONFIG["threshold_metrics"] or {}

    # Model ranking sorts descending only for "maximize"; any other value
    # (e.g. a typo such as "max") would silently rank models ascending.
    if DIRECTION not in ("maximize", "minimize"):
        raise ValueError(
            "metrics.classification.direction must be 'maximize' or "
            f"'minimize', got {DIRECTION!r}"
        )

    # train_test_split accepts test_size as a fraction (float) or an absolute
    # sample count (int). ":g" avoids float noise such as 30.000000000000004%.
    if isinstance(TEST_SIZE, float):
        test_split_desc = f"{TEST_SIZE * 100:g}%"
    else:
        test_split_desc = f"{TEST_SIZE} samples"

    print("✅ Pipeline configuration loaded successfully!")
    print("\n📊 Configuration Summary:")
    print(f"   Model Type: {MODEL_TYPE.upper()}")
    print(f"   Experiment: {EXPERIMENT_NAME}")
    print(f"   Input Table: {INPUT_TABLE}")
    print(f"   Features: {len(FEATURE_COLS)} columns")
    print(f"   Label: {LABEL_COL}")
    print(f"   Primary Metric: {PRIMARY_METRIC}")
    print(f"   Test Split: {test_split_desc}")

except FileNotFoundError:
    print("❌ ERROR: pipeline_config.yml not found!")
    print("💡 Please ensure pipeline_config.yml is in the notebook directory")
    raise
except KeyError as e:
    print(f"❌ ERROR: pipeline_config.yml is missing required key {e}")
    raise
except Exception as e:
    print(f"❌ ERROR loading pipeline configuration: {e}")
    raise

print("=" * 80)

# ==================================================================================
# ✅ LOAD EXPERIMENT CONFIGURATIONS (config.yml)
# ==================================================================================
def load_experiment_configs(path="config.yml"):
    """
    Load the Random Forest hyperparameter configurations from a YAML file.

    The file must contain a top-level ``experiments`` list. Every entry must
    be a mapping with a ``name`` key; the training loop also reads each
    entry's ``params`` as the model hyperparameters.

    Args:
        path: Path to the YAML file (default: "config.yml").

    Returns:
        The parsed configuration as a dict.

    Raises:
        FileNotFoundError: If ``path`` does not exist.
        ValueError: If the file is empty, has no ``experiments`` list, or an
            entry is not a mapping with a ``name`` key.
        Any other error while reading/parsing is printed and re-raised.
    """
    print(f"\n📄 Step 2: Loading experiment configurations from {path}...")

    try:
        with open(path, "r", encoding="utf-8") as f:
            config = yaml.safe_load(f)

        # safe_load returns None for an empty file; validate the structure up
        # front instead of failing later with an opaque TypeError/KeyError.
        experiments = config.get("experiments") if isinstance(config, dict) else None
        if not isinstance(experiments, list):
            raise ValueError(f"{path} must define a top-level 'experiments' list")
        for i, exp in enumerate(experiments, 1):
            if not isinstance(exp, dict) or "name" not in exp:
                raise ValueError(
                    f"Experiment #{i} in {path} must be a mapping with a 'name' key"
                )

        print(f"✅ Found {len(experiments)} experiment configuration(s):")
        for i, exp in enumerate(experiments, 1):
            print(f"   {i}. {exp['name']}")

        return config

    except FileNotFoundError:
        print(f"❌ ERROR: {path} not found!")
        print(f"💡 Please create {path} with Random Forest configurations")
        raise
    except Exception as e:
        print(f"❌ ERROR loading experiment configs: {e}")
        raise

# ==================================================================================
# ✅ DATA LOADING FROM DELTA TABLE
# ==================================================================================
def load_data(spark):
    """
    Load the training data from the Delta table named by INPUT_TABLE.

    Selects FEATURE_COLS and LABEL_COL and collects them to the driver with
    ``toPandas()``, so the selected table must fit in driver memory.

    Label encoding for string labels:
    - "yes" becomes 1 and everything else becomes 0.
    - Matching is case-insensitive and ignores surrounding whitespace.
    - Values other than yes/no are reported, because they end up as 0.
    - Non-string label columns are returned unchanged.

    Args:
        spark: Active SparkSession.

    Returns:
        X: pandas DataFrame with the FEATURE_COLS columns.
        y: pandas Series of labels.

    Raises:
        Any error from reading, selecting or converting the table is printed
        and re-raised.
    """
    print("\n📦 Step 3: Loading data from Delta table...")
    print(f"   Table: {INPUT_TABLE}")

    try:
        # Read from Delta
        df = spark.read.format("delta").table(INPUT_TABLE)

        # Select features and label, then convert to pandas on the driver
        df_pd = df.select(*FEATURE_COLS, LABEL_COL).toPandas()

        # Separate features and labels
        X = df_pd[FEATURE_COLS]
        y = df_pd[LABEL_COL]

        # Normalize case/whitespace so "Yes" or " YES " is not read as 0.
        # Also covers pandas' dedicated string dtype, not only object columns.
        if y.dtype == object or pd.api.types.is_string_dtype(y):
            normalized = y.astype(str).str.strip().str.lower()
            unexpected = sorted(set(normalized.unique()) - {"yes", "no"})
            if unexpected:
                print(
                    "⚠️ Unexpected label values mapped to 0 "
                    f"(expected 'yes'/'no'): {unexpected}"
                )
            y = (normalized == "yes").astype(int)

        n_neg = (y == 0).sum()
        n_pos = (y == 1).sum()
        print("✅ Data loaded successfully!")
        print(f"   Total samples: {len(df_pd):,}")
        print(f"   Features shape: {X.shape}")
        print("   Label distribution:")
        print(f"      Class 0 (No Default): {n_neg:,} ({n_neg / len(y) * 100:.1f}%)")
        print(f"      Class 1 (Default): {n_pos:,} ({n_pos / len(y) * 100:.1f}%)")

        return X, y

    except Exception as e:
        print(f"❌ Failed to load data from '{INPUT_TABLE}': {e}")
        print("💡 Verify the table exists and contains required columns")
        raise

# ==================================================================================
# ✅ TRAIN SINGLE EXPERIMENT
# ==================================================================================
def train_single_experiment(X, y, params, run_name):
    """
    Train and evaluate one Random Forest Classifier configuration in its own MLflow run.

    The train/test split uses the module-level TEST_SIZE, RANDOM_STATE and
    STRATIFY settings, so every experiment in a session sees the same split.

    Logged to MLflow:
    - run metadata, hyperparameters and split sizes;
    - every entry of the returned metrics dict;
    - a ``passes_thresholds`` param;
    - the fitted model, under MODEL_ARTIFACT_PATH;
    - a ``feature_importance.csv`` artifact.

    Args:
        X: Feature matrix (the pandas DataFrame returned by load_data).
        y: Target labels; class 1 is treated as the positive class.
        params: Model hyperparameters from config.yml, passed as keyword
            arguments to RandomForestClassifier. A ``random_state`` entry
            overrides the pipeline-level RANDOM_STATE.
        run_name: Name for the MLflow run.

    Returns:
        run_id: MLflow run ID.
        metrics_dict: Evaluation metrics, with these keys:
            train_accuracy, test_accuracy, test_precision, test_recall,
            test_f1, test_roc_auc, train_time (seconds) and
            inference_time (ms per test sample).
    """
    print(f"\n{'='*70}")
    print(f"🔁 Training Experiment: {run_name}")
    print(f"{'='*70}")
    print("📝 Hyperparameters:")
    for k, v in params.items():
        print(f"   {k}: {v}")

    # Split data (stratified for imbalanced classes)
    stratify_param = y if STRATIFY else None
    X_train, X_test, y_train, y_test = train_test_split(
        X, y,
        test_size=TEST_SIZE,
        random_state=RANDOM_STATE,
        stratify=stratify_param,
    )

    print("\n📊 Data Split:")
    print(f"   Train samples: {len(X_train):,}")
    print(f"   Test samples: {len(X_test):,}")
    print(f"   Train class 1: {(y_train == 1).sum():,} ({(y_train == 1).sum() / len(y_train) * 100:.1f}%)")
    print(f"   Test class 1: {(y_test == 1).sum():,} ({(y_test == 1).sum() / len(y_test) * 100:.1f}%)")

    # Config-supplied params win over the pipeline default. Passing
    # random_state both explicitly and via **params would raise TypeError.
    model_params = {"random_state": RANDOM_STATE, **params}

    # Start MLflow run
    with mlflow.start_run(run_name=run_name) as run:
        run_id = run.info.run_id
        print(f"\n🔖 MLflow Run ID: {run_id}")

        # Log configuration metadata
        mlflow.log_param("model_type", MODEL_TYPE)
        mlflow.log_param("experiment_name", run_name)
        mlflow.log_param("timestamp", datetime.now().strftime("%Y-%m-%d %H:%M:%S"))

        # Log all hyperparameters
        for param_name, param_value in params.items():
            mlflow.log_param(param_name, param_value)

        # Log data split info
        mlflow.log_param("train_size", len(X_train))
        mlflow.log_param("test_size", len(X_test))
        mlflow.log_param("stratified_split", STRATIFY)

        # Train model
        print("\n🏋️ Training Random Forest Classifier...")
        start_time = time.perf_counter()
        model = RandomForestClassifier(**model_params)
        model.fit(X_train, y_train)
        train_time = time.perf_counter() - start_time

        print(f"   ✅ Training completed in {train_time:.2f} seconds")

        # Make predictions
        print("\n📊 Evaluating model...")
        y_train_pred = model.predict(X_train)

        # Time only the held-out predictions. inference_time is divided by
        # len(X_test), so train-set predictions must not be inside the window.
        start_time = time.perf_counter()
        y_test_pred = model.predict(X_test)
        y_test_pred_proba = model.predict_proba(X_test)[:, 1]
        inference_time = (time.perf_counter() - start_time) / len(X_test) * 1000  # ms per sample

        # Calculate metrics
        metrics_dict = {
            # Training metrics
            "train_accuracy": accuracy_score(y_train, y_train_pred),

            # Test metrics
            "test_accuracy": accuracy_score(y_test, y_test_pred),
            "test_precision": precision_score(y_test, y_test_pred, zero_division=0),
            "test_recall": recall_score(y_test, y_test_pred, zero_division=0),
            "test_f1": f1_score(y_test, y_test_pred, zero_division=0),
            "test_roc_auc": roc_auc_score(y_test, y_test_pred_proba),

            # Performance metrics
            "train_time": train_time,
            "inference_time": inference_time,
        }

        # Log every metric (including train_time/inference_time) exactly once
        for metric_name, metric_value in metrics_dict.items():
            mlflow.log_metric(metric_name, metric_value)

        # Print metrics
        print("\n✅ Evaluation Results:")
        print(f"   {'Metric':<20} {'Train':<12} {'Test':<12}")
        print(f"   {'-'*44}")
        print(f"   {'Accuracy':<20} {metrics_dict['train_accuracy']:<12.4f} {metrics_dict['test_accuracy']:<12.4f}")
        print(f"   {'Precision':<20} {'-':<12} {metrics_dict['test_precision']:<12.4f}")
        print(f"   {'Recall':<20} {'-':<12} {metrics_dict['test_recall']:<12.4f}")
        print(f"   {'F1 Score':<20} {'-':<12} {metrics_dict['test_f1']:<12.4f}")
        print(f"   {'ROC-AUC':<20} {'-':<12} {metrics_dict['test_roc_auc']:<12.4f}")
        print(f"\n   Training Time: {train_time:.2f}s")
        print(f"   Inference Time: {inference_time:.3f}ms/sample")

        # Confusion matrix. labels=[0, 1] keeps it 2x2 even if only one class
        # appears in y_test/predictions; otherwise cm[1, 0] raises IndexError.
        tn, fp, fn, tp = confusion_matrix(y_test, y_test_pred, labels=[0, 1]).ravel()
        print("\n   Confusion Matrix:")
        print(f"   [[TN={tn:<6} FP={fp:<6}]")
        print(f"    [FN={fn:<6} TP={tp:<6}]]")

        # Threshold check: every configured metric must reach its threshold
        # (all thresholds are treated as "higher is better").
        print("\n🎯 Threshold Check:")
        passes_thresholds = True
        for metric_name, threshold_value in (THRESHOLD_METRICS or {}).items():
            actual_value = metrics_dict.get(metric_name)
            if actual_value is None:
                # An unknown metric name (e.g. a typo in pipeline_config.yml)
                # must not let the model pass silently.
                print(f"   ⚠️ {metric_name}: not computed by this script (threshold: {threshold_value}) - treated as failing")
                passes_thresholds = False
                continue
            passes = actual_value >= threshold_value
            status = "✅" if passes else "❌"
            print(f"   {status} {metric_name}: {actual_value:.4f} (threshold: {threshold_value})")
            if not passes:
                passes_thresholds = False

        mlflow.log_param("passes_thresholds", passes_thresholds)

        # Create model signature (reuses the train predictions computed above)
        signature = infer_signature(X_train, y_train_pred)

        # Log model to MLflow
        print("\n💾 Logging model to MLflow...")
        mlflow.sklearn.log_model(
            model,
            artifact_path=MODEL_ARTIFACT_PATH,
            signature=signature,
            registered_model_name=None,  # Will register later in evaluation script
        )

        # Log feature importance
        if hasattr(model, 'feature_importances_'):
            # feature_names_in_ is set when the model is fit on a DataFrame;
            # fall back to the configured column list otherwise.
            feature_names = getattr(model, "feature_names_in_", FEATURE_COLS)
            feature_importance = pd.DataFrame({
                'feature': feature_names,
                'importance': model.feature_importances_,
            }).sort_values('importance', ascending=False)

            print("\n🎯 Top 10 Most Important Features:")
            for row in feature_importance.head(10).itertuples(index=False):
                print(f"   {row.feature:<25} {row.importance:.4f}")

            # Log the CSV directly from memory rather than first writing
            # feature_importance.csv into the current working directory.
            mlflow.log_text(feature_importance.to_csv(index=False), "feature_importance.csv")

        print(f"\n✅ Experiment '{run_name}' completed successfully!")
        print(f"{'='*70}")

        return run_id, metrics_dict

# ==================================================================================
# ✅ MAIN EXECUTION
# ==================================================================================
if __name__ == "__main__":

    # Initialize MLflow: Databricks tracking server, Unity Catalog registry
    print("\n🔧 Initializing MLflow...")
    try:
        mlflow.set_tracking_uri("databricks")
        mlflow.set_registry_uri("databricks-uc")
        mlflow.set_experiment(EXPERIMENT_NAME)
        print(f"✅ MLflow experiment set: {EXPERIMENT_NAME}")
    except Exception as e:
        print(f"❌ Failed to initialize MLflow: {e}")
        raise

    # Initialize Spark
    print("\n🔧 Initializing Spark...")
    try:
        spark = SparkSession.builder.appName("CreditRiskTraining").getOrCreate()
        print("✅ Spark session created")
    except Exception as e:
        print(f"❌ Failed to initialize Spark: {e}")
        raise

    # Load the experiment configurations (Step 2) before the data (Step 3).
    # A broken config.yml then fails before the table is collected to the
    # driver, and the step headers print in order.
    config = load_experiment_configs()
    X, y = load_data(spark)

    # Start training all experiments
    print("\n" + "=" * 80)
    print("🚀 STARTING TRAINING RUNS")
    print("=" * 80)

    experiments = config["experiments"]
    run_results = []
    total_experiments = len(experiments)

    for idx, exp in enumerate(experiments, 1):
        exp_name = exp["name"]

        print(f"\n[Experiment {idx}/{total_experiments}]")

        try:
            # Looked up inside the try so an entry without `params` is
            # reported and skipped like any other failed experiment.
            exp_params = exp["params"]
            run_id, metrics = train_single_experiment(
                X, y,
                exp_params,
                run_name=f"{RUN_NAME_PREFIX}_{exp_name}",
            )
        except Exception as e:
            print(f"❌ Failed to train {exp_name}: {e}")
            print("   Continuing with next experiment...")
            continue

        run_results.append({
            'name': exp_name,
            'run_id': run_id,
            'metrics': metrics,
            'params': exp_params,
        })

    all_succeeded = bool(run_results) and len(run_results) == total_experiments

    # Display final summary (the banner reflects whether any run failed)
    print("\n" + "=" * 80)
    if all_succeeded:
        print("✅✅✅ ALL TRAINING RUNS COMPLETED ✅✅✅")
    else:
        print(f"⚠️ TRAINING RUNS FINISHED: {len(run_results)}/{total_experiments} SUCCEEDED")
    print("=" * 80)

    if run_results:
        print(f"\n📊 Training Summary ({len(run_results)}/{total_experiments} successful):")
        print(f"\n{'Rank':<6} {'Experiment':<35} {PRIMARY_METRIC.upper():<12} {'ROC-AUC':<10} {'Run ID':<40}")
        print("-" * 103)

        # Sort by primary metric: best first (descending when maximizing)
        sorted_results = sorted(
            run_results,
            key=lambda result: result['metrics'][PRIMARY_METRIC],
            reverse=(DIRECTION == "maximize"),
        )

        for rank, result in enumerate(sorted_results, 1):
            marker = "🏆" if rank == 1 else f"{rank}."
            name = result['name']
            primary_score = result['metrics'][PRIMARY_METRIC]
            roc_auc = result['metrics']['test_roc_auc']
            run_id = result['run_id']

            print(f"{marker:<6} {name:<35} {primary_score:<12.4f} {roc_auc:<10.4f} {run_id}")

        # Highlight best model
        best = sorted_results[0]
        print("\n" + "=" * 80)
        print("🏆 BEST MODEL FROM THIS TRAINING SESSION")
        print("=" * 80)
        print(f"   Name: {best['name']}")
        print(f"   {PRIMARY_METRIC.upper()}: {best['metrics'][PRIMARY_METRIC]:.4f}")
        print(f"   ROC-AUC: {best['metrics']['test_roc_auc']:.4f}")
        print(f"   Accuracy: {best['metrics']['test_accuracy']:.4f}")
        print(f"   Recall: {best['metrics']['test_recall']:.4f}")
        print(f"   Precision: {best['metrics']['test_precision']:.4f}")
        print(f"   Run ID: {best['run_id']}")
        print("\n   Key Hyperparameters:")
        for k in ['n_estimators', 'max_depth', 'min_samples_split', 'class_weight']:
            if k in best['params']:
                print(f"      {k}: {best['params'][k]}")
        print("=" * 80)

    else:
        print("\n⚠️ No successful training runs completed")
        print("💡 Check errors above and fix configurations")

    # Next steps
    print("\n📌 Next Steps:")
    print("   1. Run model_evaluation.py to evaluate ALL models")
    print("   2. Best model will be selected and registered")
    print("   3. Continue with UAT → Production promotion")

    print("\n💡 Note:")
    print(f"   All {len(run_results)} models logged to: {EXPERIMENT_NAME}")
    print(f"   Evaluation will compare based on: {PRIMARY_METRIC}")
    print("   Models passing thresholds will be registered automatically")

    # Save task values for workflow
    try:
        dbutils.jobs.taskValues.set(key="model_type", value=MODEL_TYPE)
        dbutils.jobs.taskValues.set(key="experiment_name", value=EXPERIMENT_NAME)
        dbutils.jobs.taskValues.set(key="num_experiments", value=len(run_results))
        dbutils.jobs.taskValues.set(key="primary_metric", value=PRIMARY_METRIC)

        if run_results:
            dbutils.jobs.taskValues.set(
                key="best_score",
                value=float(sorted_results[0]['metrics'][PRIMARY_METRIC]),
            )
            dbutils.jobs.taskValues.set(
                key="best_run_id",
                value=sorted_results[0]['run_id'],
            )

        print("\n✅ Task values saved for workflow automation")

    except NameError:
        # `dbutils` is not defined outside a Databricks runtime.
        print("\nℹ️ Not running in Databricks workflow - skipping task values")
    except Exception as e:
        # Still best-effort, but report the real cause instead of hiding it
        # (the old bare `except:` also swallowed KeyboardInterrupt).
        print(f"\n⚠️ Could not save task values: {e}")

    if all_succeeded:
        print("\n🎉 Training pipeline completed successfully!")
    elif run_results:
        print(f"\n⚠️ Training pipeline completed with {total_experiments - len(run_results)} failed experiment(s)")
    else:
        print("\n⚠️ Training pipeline completed without any successful runs")
    print("=" * 80)