In [None]:
# Install the repository one level up (`..`) in editable mode, presumably so
# the project package imported later (`honeywell.*`) resolves to the working
# tree, then restart the Python process so the newly installed packages are
# picked up by subsequent cells.
%pip install -e ..
%restart_python

In [None]:
from pathlib import Path
import sys

# Make the repo's `src/` directory (one level above the notebook's working
# directory) importable. Guarded so that re-running this cell does not keep
# appending duplicate entries to sys.path.
_src_dir = str(Path.cwd().parent / "src")
if _src_dir not in sys.path:
    sys.path.append(_src_dir)

In [None]:
# ------------------------------------------------------------------
# NOW safe to import sklearn / pandas / mlflow.pyfunc
# ------------------------------------------------------------------
import pandas as pd
import numpy as np
import sklearn
import lightgbm
import mlflow.pyfunc
import sys
import subprocess
import site
import mlflow
from loguru import logger

# `logger` here is loguru's, which formats messages with str.format-style
# "{}" placeholders (not %-style). With "%s" the version arguments would be
# silently ignored and never printed.
logger.info("Runtime versions:")
logger.info("  mlflow={}", mlflow.__version__)
logger.info("  sklearn={}", sklearn.__version__)
logger.info("  pandas={}", pd.__version__)
logger.info("  lightgbm={}", lightgbm.__version__)

In [None]:
#!/usr/bin/env python3
"""
Enterprise MLOps: Model Validation & Deployment Gate (Databricks + MLflow)

Responsibilities:
- Consume exact registered model artifact via model URI (chain of custody)
- Run deployability checks (artifact, signature, inference)
- Load candidate run metrics & params
- Enforce metric + parameter thresholds
- Resolve currently deployed model (alias-first)
- Attach Challenger alias on pass
- Emit CI/CD-friendly outputs
- Block deployment if validation fails
"""

import argparse
import logging
import os
import sys
from typing import Dict, Optional, Tuple
import subprocess
import sys
from honeywell.config import ProjectConfig
from honeywell.run_inference_chk import run_real_data_inference_sanity_check
import mlflow
import pandas as pd
import numpy as np
from mlflow.models import get_model_info
from mlflow.tracking import MlflowClient
from pyspark.sql import SparkSession
from pyspark.dbutils import DBUtils
import builtins
# NOTE(review): rebinds the name `list` to the builtin, presumably to undo an
# earlier notebook cell that shadowed it. The `list[str]` annotations below
# rely on `list` being the builtin type.
list = builtins.list 

# Project configuration. Within this file it is only passed to the real-data
# inference sanity check. NOTE(review): the YAML path is resolved relative to
# the current working directory, and `env` is hard-coded to "dev" regardless
# of the `env` widget read in the entrypoint; confirm this is intended.
config = ProjectConfig.from_yaml(config_path="../project_config_honeywell.yml", env="dev")

# ------------------------------------------------------------------------------
# Logging (centralized, CI-friendly)
# ------------------------------------------------------------------------------
# Configure the root logger at INFO with a pipe-separated format, then rebind
# `logger`. This replaces the loguru logger imported in an earlier cell, so the
# helpers below log through stdlib `logging` under the name "model-validation".
# NOTE(review): logging.basicConfig() does nothing if the root logger already
# has handlers. If INFO lines are missing from the output, consider force=True.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
)
logger = logging.getLogger("model-validation")

# ------------------------------------------------------------------------------
# Databricks Context
# ------------------------------------------------------------------------------
# Reuse (or create) the Spark session. `spark` is handed to the inference
# sanity check; `dbutils` is used for job task values and widget reads.
spark = SparkSession.builder.getOrCreate()
dbutils = DBUtils(spark)

# ------------------------------------------------------------------------------
# Threshold Configuration (env-aware if needed later)
# ------------------------------------------------------------------------------
# Acceptance bound for each metric logged on the candidate run. Most are
# minimums; log_loss is lower-is-better and is checked as a maximum by the
# metric validator.
METRIC_THRESHOLDS = dict(
    accuracy=0.85,
    f1_score=0.80,
    precision=0.75,
    recall=0.75,
    roc_auc=0.85,
    log_loss=0.5,  # lower is better -> upper bound
)

# Inclusive (low, high) bounds for LightGBM training params logged on the
# candidate run (basic governance; Phase-2+ optional).
REQUIRED_PARAMS = dict(
    learning_rate=(0.0005, 0.1),
    n_estimators=(100, 3000),
    max_depth=(2, 12),
)


# ------------------------------------------------------------------------------
# Helpers: CI/CD outputs + Databricks task values
# ------------------------------------------------------------------------------

def emit_ci_output(key: str, value: str) -> None:
    """Write a ``::mlflow-run-output::<key>=<value>`` marker line to stdout.

    A generic, CI-friendly output line; when run inside Databricks it is just
    an extra line in the cell/driver output and is otherwise harmless.
    """
    marker = "::mlflow-run-output::{}={}".format(key, value)
    print(marker)


def set_task_value(key: str, value: str) -> None:
    """Publish ``key``/``value`` via ``dbutils.jobs.taskValues`` (best effort).

    Any exception raised while resolving or calling the task-values API is
    logged as a warning and never propagated.
    """
    try:
        publish = dbutils.jobs.taskValues.set
        publish(key=key, value=value)
    except Exception as err:
        logger.warning("Failed to set task value %s: %s", key, err)


# ------------------------------------------------------------------------------
# Helpers: model registry resolution
# ------------------------------------------------------------------------------

def resolve_deployed_model(
    client: MlflowClient,
    model_name: str,
    alias: str = "champion",
) -> Tuple[Optional[str], Optional[str]]:
    """Look up the registered model version currently holding ``alias``.

    Args:
        client: MLflow client used for the registry lookup.
        model_name: Registered model name.
        alias: Registry alias identifying the deployed version.

    Returns:
        ``(version, run_id)`` of the aliased version (version as ``str``), or
        ``(None, None)`` if the lookup raises for any reason.
    """
    try:
        mv = client.get_model_version_by_alias(model_name, alias)
    except Exception as exc:
        # Deliberately best-effort: a missing alias means "first deployment"
        # and must not block validation. The underlying error is logged so a
        # registry/permission failure is not silently mistaken for "no
        # deployed model".
        logger.info(
            "Alias '%s' not found for model %s (%s: %s).",
            alias,
            model_name,
            type(exc).__name__,
            exc,
        )
        return None, None

    logger.info(
        "Resolved deployed model via alias '%s': version=%s run_id=%s",
        alias,
        mv.version,
        mv.run_id,
    )
    return str(mv.version), mv.run_id


# ------------------------------------------------------------------------------
# Helpers: metric + parameter validation
# ------------------------------------------------------------------------------

def evaluate_metric_thresholds(
    metrics: Dict[str, float],
    thresholds: Dict[str, float],
    lower_is_better: Tuple[str, ...] = ("log_loss",),
) -> list[str]:
    """Check candidate metrics against acceptance thresholds.

    Metrics named in ``lower_is_better`` must not exceed their threshold; all
    others must not fall below it.

    Args:
        metrics: Metric name -> logged value.
        thresholds: Metric name -> acceptance bound.
        lower_is_better: Metric names whose threshold is an upper bound.

    Returns:
        Human-readable failure strings (empty list means every check passed).
        A metric fails if it is missing, non-numeric, NaN, or out of bounds.
    """
    failures: list[str] = []

    for name, threshold in thresholds.items():
        if name not in metrics:
            failures.append(f"{name}: missing")
            continue

        try:
            value = float(metrics[name])
        except (TypeError, ValueError):
            failures.append(f"{name}: not numeric ({metrics[name]})")
            continue

        # NaN compares False against everything, so without this it would
        # silently pass both the "<" and ">" checks below.
        if np.isnan(value):
            failures.append(f"{name}: NaN")
            continue

        if name in lower_is_better:
            if value > threshold:
                failures.append(f"{name}: {value:.4f} > {threshold}")
        elif value < threshold:
            failures.append(f"{name}: {value:.4f} < {threshold}")

    return failures


def evaluate_param_constraints(
    params: Dict[str, str],
    constraints: Dict[str, tuple],
) -> list[str]:
    """Check logged run params against inclusive numeric ranges.

    Args:
        params: Param name -> logged value (MLflow stores params as strings).
        constraints: Param name -> ``(low, high)`` inclusive bounds.

    Returns:
        Failure strings for params that are missing, not parseable as a
        float, or outside ``[low, high]``. NaN also counts as out of range.
    """
    problems: list[str] = []

    for param, bounds in constraints.items():
        lower, upper = bounds

        if param not in params:
            problems.append(f"{param}: missing")
            continue

        raw = params[param]
        try:
            numeric = float(raw)
        except Exception:
            problems.append(f"{param}: not numeric ({raw})")
            continue

        # Chained comparison kept on purpose: it is False for NaN, so NaN fails.
        if not lower <= numeric <= upper:
            problems.append(f"{param}: {numeric} not in [{lower}, {upper}]")

    return problems


# ------------------------------------------------------------------------------
# Phase-1 Deployability Checks (Databricks MLOps Book)
# ------------------------------------------------------------------------------

# File suffixes accepted as evidence of a serialized model payload. Besides the
# pickle/torch/ONNX/joblib/bin set, this includes native files written by
# common MLflow flavors (e.g. LightGBM's ``model.lgb``, XGBoost, torch ``.pth``,
# Keras), which would otherwise be rejected as "no model binary".
_MODEL_BINARY_SUFFIXES = (
    ".pkl", ".pt", ".pth", ".onnx", ".joblib", ".bin",
    ".lgb", ".xgb", ".ubj", ".h5", ".keras",
)

# Dtypes forced onto the sampled real-data row before prediction so that no
# pandas ``category`` dtypes reach the model (they caused a ``<U0`` dtype
# error at predict time).
_INFERENCE_SAMPLE_DTYPES = {
    "SOC": "float64",
    "Voltage": "float64",
    "Current": "float64",
    "Battery_Temp": "float64",
    "Ambient_Temp": "float64",
    "Charging_Duration": "float64",
    "Degradation_Rate": "float64",
    "Efficiency": "float64",
    "Charging_Cycles": "int64",
    "Battery_Type": "str",
    "EV_Model": "str",
    "Charging_Mode": "str",
}


def _load_pyfunc_model(model_uri: str):
    """Load ``model_uri`` as an MLflow pyfunc model, or raise ``RuntimeError``."""
    try:
        model = mlflow.pyfunc.load_model(model_uri)
    except Exception as exc:
        logger.error("❌ Failed to load model from URI: %s", model_uri)
        raise RuntimeError("Model artifact cannot be loaded") from exc

    logger.info("✅ Model loads successfully from URI")
    return model


def _check_model_artifacts(model_uri: str) -> None:
    """Require an ``MLmodel`` file and a serialized model file in the artifact."""
    local_path = mlflow.artifacts.download_artifacts(model_uri)

    if not os.path.exists(os.path.join(local_path, "MLmodel")):
        raise RuntimeError("MLmodel file missing in artifact directory")

    # Walk the whole tree: some flavors store the payload in a subdirectory
    # (e.g. ``data/``) rather than next to ``MLmodel``.
    binary_found = any(
        filename.endswith(_MODEL_BINARY_SUFFIXES)
        for _root, _dirs, filenames in os.walk(local_path)
        for filename in filenames
    )
    if not binary_found:
        raise RuntimeError("No model binary found in artifact directory")

    logger.info("✅ Artifact existence checks passed")


def _check_model_signature(model_uri: str) -> None:
    """Require the registered model to carry an MLflow model signature."""
    if get_model_info(model_uri).signature is None:
        raise RuntimeError("Model signature is missing")

    logger.info("✅ Model signature exists")


def _check_real_data_inference(model, model_uri: str) -> None:
    """Run ``model.predict`` on one real-data row; log and re-raise on failure."""
    try:
        test_data = run_real_data_inference_sanity_check(
            model=model,
            model_uri=model_uri,
            spark=spark,
            config=config,
            n_rows=1,
        )
        # Fail with a clear message instead of an opaque KeyError from
        # astype() on an empty frame.
        if test_data is None or test_data.empty:
            raise RuntimeError(
                "No real-data rows available for the inference sanity check"
            )

        # Round-trip the first row through plain Python records so pandas
        # re-infers dtypes (dropping categoricals), then force the schema.
        sample_df = pd.DataFrame(test_data.head(1).to_dict(orient="records"))
        sample_df = sample_df.astype(_INFERENCE_SAMPLE_DTYPES)

        logger.info("Executing prediction with cleaned real-data row...")
        preds = model.predict(sample_df)
        logger.info("✅ Prediction successful: %s", preds)
    except Exception as exc:
        logger.error("❌ Prediction execution failed: %s", exc)
        raise


def run_deployability_checks(model_uri: str) -> None:
    """Phase-1 deployability gate for the registered model at ``model_uri``.

    Runs, in order, and stops at the first failure:
      1. load the model as a pyfunc;
      2. verify ``MLmodel`` and a serialized model file exist in the artifact;
      3. verify the model has a signature;
      4. predict on one real-data row.

    Raises:
        RuntimeError: the model cannot be loaded, artifacts or signature are
            missing, or no real-data row is available.
        Exception: errors raised while preparing the sample or predicting are
            logged and re-raised unchanged.
    """
    logger.info("Running deployability checks on model URI: %s", model_uri)

    model = _load_pyfunc_model(model_uri)
    _check_model_artifacts(model_uri)
    _check_model_signature(model_uri)
    _check_real_data_inference(model, model_uri)


# ------------------------------------------------------------------------------
# Core Validation Logic
# ------------------------------------------------------------------------------

def _version_from_model_uri(model_uri: str) -> Optional[str]:
    """Return ``<version>`` from ``models:/<name>/<version>``, else ``None``.

    Alias URIs (``models:/<name>@alias``) or any non-numeric tail yield None.
    """
    tail = model_uri.rstrip("/").rsplit("/", 1)[-1]
    return tail if tail.isdigit() else None


def _check_num_leaves_vs_depth(params: Dict[str, str]) -> list[str]:
    """LightGBM governance: require ``num_leaves <= 2 ** max_depth``.

    Skipped (no failure) when either param is absent or non-positive.
    Unparseable values are reported as a validation error.
    """
    try:
        max_depth = int(params.get("max_depth", -1))
        num_leaves = int(params.get("num_leaves", -1))
    except (TypeError, ValueError) as exc:
        return [f"num_leaves/max_depth validation error: {exc}"]

    if max_depth > 0 and num_leaves > 0:
        max_allowed_leaves = 2 ** max_depth
        if num_leaves > max_allowed_leaves:
            return [f"num_leaves {num_leaves} > 2^max_depth ({max_allowed_leaves})"]
    return []


def validate_candidate_model(
    env: str,
    model_name: str,
    model_uri: str,
    candidate_run_id: str,
) -> None:
    """Gate a registered model version for deployment.

    Runs deployability checks on ``model_uri``, then validates the candidate
    run's metrics and params. It publishes the current Champion and the
    outcome as CI outputs and Databricks task values. On success it attaches
    the "Challenger" alias and tags the candidate version ``validation_status=
    passed``; on failure it tags the candidate ``failed`` (best effort).

    Args:
        env: Environment name (only logged here).
        model_name: Registered model name.
        model_uri: ``models:/<name>/<version>`` URI of the candidate version.
        candidate_run_id: MLflow run id that produced the candidate.

    Raises:
        RuntimeError: the candidate run cannot be fetched, or any metric or
            parameter check fails. Errors from the deployability checks
            propagate unchanged.
    """
    logger.info(
        "Starting model validation | env=%s model_name=%s model_uri=%s run_id=%s",
        env,
        model_name,
        model_uri,
        candidate_run_id,
    )

    client = MlflowClient()
    candidate_version = _version_from_model_uri(model_uri)

    # Phase-1: deployability checks on the registered artifact.
    run_deployability_checks(model_uri)

    # Fetch candidate run.
    try:
        run = client.get_run(candidate_run_id)
    except Exception as exc:
        logger.error("Candidate run not found: %s", exc)
        raise RuntimeError("Invalid candidate_run_id") from exc

    candidate_metrics = run.data.metrics or {}
    candidate_params = run.data.params or {}

    logger.info("Candidate metrics: %s", candidate_metrics)
    logger.info("Candidate params: %s", candidate_params)

    # Resolve currently deployed model (Champion).
    deployed_version, deployed_run_id = resolve_deployed_model(
        client, model_name, alias="champion"
    )

    if deployed_version:
        set_task_value("deployed_model_version", deployed_version)
        set_task_value("deployed_run_id", deployed_run_id or "")
        emit_ci_output("deployed_model_version", deployed_version)
    else:
        logger.info("No deployed Champion found. Treating as first deployment.")
        set_task_value("deployed_model_version", "0")
        # Always publish deployed_run_id so downstream tasks can read it
        # unconditionally.
        set_task_value("deployed_run_id", "")
        emit_ci_output("deployed_model_version", "0")

    # Metric thresholds, parameter ranges, LightGBM leaves-vs-depth rule.
    metric_failures = evaluate_metric_thresholds(
        candidate_metrics, METRIC_THRESHOLDS
    )
    param_failures = evaluate_param_constraints(
        candidate_params, REQUIRED_PARAMS
    )
    param_failures.extend(_check_num_leaves_vs_depth(candidate_params))

    failures = metric_failures + param_failures

    if failures:
        logger.error("❌ Model validation FAILED")
        for failure in failures:
            logger.error("   - %s", failure)

        emit_ci_output("model_validation", "FAILED")
        set_task_value("model_validation", "FAILED")

        # Optional governance: tag the candidate version itself, not whichever
        # version currently holds the "Challenger" alias (an earlier model).
        if candidate_version is None:
            logger.warning(
                "Cannot tag validation_status: no numeric version in %s", model_uri
            )
        else:
            try:
                client.set_model_version_tag(
                    name=model_name,
                    version=candidate_version,
                    key="validation_status",
                    value="failed",
                )
            except Exception as exc:
                # Never let tagging mask the validation failure itself.
                logger.warning(
                    "Failed to tag model version %s as failed: %s",
                    candidate_version,
                    exc,
                )

        raise RuntimeError(
            "Candidate model failed validation thresholds. "
            "Blocking deployment."
        )

    # PASS: attach Challenger alias to the validated version.
    logger.info("✅ Model validation PASSED")

    if candidate_version is None:
        logger.warning(
            "Failed to attach Challenger alias: no numeric version in model_uri %s",
            model_uri,
        )
    else:
        try:
            client.set_registered_model_alias(
                name=model_name,
                alias="Challenger",
                version=candidate_version,
            )
            client.set_model_version_tag(
                name=model_name,
                version=candidate_version,
                key="validation_status",
                value="passed",
            )
            logger.info(
                "Attached 'Challenger' alias to model version %s", candidate_version
            )
        except Exception as exc:
            logger.warning("Failed to attach Challenger alias: %s", exc)

    emit_ci_output("model_validation", "PASSED")
    set_task_value("model_validation", "PASSED")

    # Emit promotion metadata.
    emit_ci_output("candidate_run_id", candidate_run_id)
    set_task_value("candidate_run_id", candidate_run_id)

    emit_ci_output("model_uri", model_uri)
    set_task_value("model_uri", model_uri)

    logger.info("Model validation completed successfully.")


# ------------------------------------------------------------------------------
# CLI Entrypoint
# ------------------------------------------------------------------------------
if __name__ == "__main__":
    # `site` is only imported in an earlier notebook cell; import it here so
    # this script cell also runs standalone.
    import site

    # Inputs come from Databricks widgets / job parameters.
    def get_widget(name: str, default: str = "") -> str:
        """Return the value of widget ``name``, or ``default`` if unavailable."""
        try:
            return dbutils.widgets.get(name)
        except Exception:
            return default

    # ------------------------------------------------------------------
    # Read widgets
    # ------------------------------------------------------------------
    arg_env = get_widget("env", "dev")
    arg_model_name = get_widget("model_name")
    arg_latest_version = get_widget("latest_version")
    arg_candidate_run_id = get_widget("candidate_run_id")

    if not all([arg_model_name, arg_latest_version, arg_candidate_run_id]):
        raise RuntimeError("model_name, latest_version, candidate_run_id are required")

    # Built once so dependency installation and validation are guaranteed to
    # target the same registered model version.
    model_uri = f"models:/{arg_model_name}/{arg_latest_version}"
    logger.info("Using model_uri: %s", model_uri)

    # ------------------------------------------------------------------
    # Install model dependencies
    # ------------------------------------------------------------------
    deps_path = mlflow.pyfunc.get_model_dependencies(model_uri)
    logger.info("📦 Model deps file: %s", deps_path)

    subprocess.check_call([sys.executable, "-m", "pip", "install", "--upgrade", "pip"])
    subprocess.check_call([sys.executable, "-m", "pip", "install", "-r", deps_path])

    # ------------------------------------------------------------------
    # Prefer pip-installed packages for *subsequent* imports.
    # NOTE: modules already imported in this process (mlflow, pandas, numpy,
    # sklearn, lightgbm, ...) are NOT reloaded, so version changes to those
    # only take effect after a Python restart.
    # ------------------------------------------------------------------
    venv_site = site.getsitepackages()[0]
    if venv_site not in sys.path:
        sys.path.insert(0, venv_site)

    logger.info("🔧 Using venv site-packages: %s", venv_site)
    logger.info("🔧 sys.path[0] = %s", sys.path[0])

    logger.info(
        "✅ Model dependencies installed (already-imported modules are not reloaded)"
    )
    try:
        validate_candidate_model(
            env=arg_env,
            model_name=arg_model_name,
            model_uri=model_uri,
            candidate_run_id=arg_candidate_run_id,
        )
    except Exception:
        logger.exception("Model validation failed")
        sys.exit(1)
