In [None]:
#!/usr/bin/env python3
"""
Enterprise MLOps: Model Validation & Deployment Gate

Responsibilities:
- Resolve the currently deployed model via its registry alias (default: "champion"; no fallback lookup)
- Load candidate run metrics & params from MLflow
- Enforce metric + parameter thresholds
- Emit CI/CD-friendly outputs
- Block deployment if validation fails
"""

import argparse
import logging
import math
import sys
from typing import Dict, Iterable, Optional

import mlflow
from mlflow.tracking import MlflowClient
from pyspark.sql import SparkSession
from pyspark.sql import SparkSession
from pyspark.dbutils import DBUtils

In [None]:

# ------------------------------------------------------------------------------
# Logging
# ------------------------------------------------------------------------------
# Configure root logging at INFO with a pipe-separated
# "timestamp | level | message" layout, then grab this module's named logger.
# NOTE(review): logging.basicConfig is a no-op when the root logger already
# has handlers (possible in a hosted notebook kernel) — confirm the format
# actually takes effect in the target runtime.
logging.basicConfig(format="%(asctime)s | %(levelname)s | %(message)s", level=logging.INFO)
logger = logging.getLogger(name="model-validation")

In [None]:

# ------------------------------------------------------------------------------
# Databricks Context
# ------------------------------------------------------------------------------
# getOrCreate() hands back the already-active SparkSession when one exists,
# otherwise it builds a new session.
spark = (
    SparkSession.builder
    .getOrCreate()
)
# DBUtils wrapper bound to that session; used below for job task values.
dbutils = DBUtils(spark)

# ------------------------------------------------------------------------------
# Threshold Configuration (environment-aware if needed later)
# ------------------------------------------------------------------------------
# Per-metric bound the candidate must meet. Insertion order is the order in
# which failures are reported. Every entry is a floor, except log_loss, which
# is a ceiling.
METRIC_THRESHOLDS = dict(
    accuracy=0.85,
    f1_score=0.80,
    precision=0.75,
    recall=0.75,
    log_loss=0.70,  # lower is better
)

# Inclusive (low, high) bounds each run parameter must fall within.
# Example: enforce reproducibility / governance
REQUIRED_PARAMS = dict(
    learning_rate=(0.001, 0.5),
    num_leaves=(8, 512),
)

In [None]:

# ------------------------------------------------------------------------------
# Helpers
# ------------------------------------------------------------------------------
def emit_ci_output(key: str, value: str) -> None:
    """Write one ``::mlflow-run-output::<key>=<value>`` marker line to stdout.

    These marker lines are the "CI/CD-friendly outputs" this notebook emits;
    presumably a pipeline step scrapes them from the job log.
    """
    marker = "::mlflow-run-output::"
    print(marker + f"{key}={value}")

def set_task_value(key: str, value: str) -> None:
    """Publish ``key``/``value`` as a Databricks job task value, best-effort.

    Any exception raised by ``dbutils.jobs.taskValues`` is logged as a warning
    and otherwise ignored, so a publishing problem never aborts validation.
    """
    try:
        task_values = dbutils.jobs.taskValues
        task_values.set(key=key, value=value)
    except Exception as err:
        logger.warning("Failed to set task value %s: %s", key, err)

In [None]:

def resolve_deployed_model(
    client: MlflowClient,
    model_name: str,
    alias: str = "champion"
) -> tuple[Optional[str], Optional[str]]:
    """Look up the registered model version that currently holds ``alias``.

    Args:
        client: MLflow client used for the registry lookup.
        model_name: Registered model name.
        alias: Registry alias marking the deployed version.

    Returns:
        ``(version, run_id)`` of the aliased model version, with ``version``
        as a string. Returns ``(None, None)`` if the lookup raises for any
        reason; the caller treats that as "nothing deployed yet".
    """
    try:
        mv = client.get_model_version_by_alias(model_name, alias)
    except Exception as exc:
        # Best-effort: a missing alias is expected on first deployment, but
        # this also catches auth/network/registry errors. Log the cause so an
        # outage isn't silently mistaken for "no Champion yet".
        logger.info(
            "Alias '%s' not found for model %s (%s: %s).",
            alias, model_name, type(exc).__name__, exc,
        )
        return None, None

    logger.info(
        "Resolved deployed model via alias '%s': version=%s run_id=%s",
        alias, mv.version, mv.run_id
    )
    return str(mv.version), mv.run_id

In [None]:

def evaluate_metric_thresholds(
    metrics: Dict[str, float],
    thresholds: Dict[str, float],
    lower_is_better: Iterable[str] = ("log_loss",),
) -> list[str]:
    """Check candidate metrics against per-metric thresholds.

    A metric passes when its value is ``>= threshold``. Metrics named in
    ``lower_is_better`` are the exception: they pass when ``<= threshold``.
    Missing and non-finite (NaN/inf) values always fail.

    Args:
        metrics: Metric name -> logged value for the candidate run.
        thresholds: Metric name -> bound to enforce; iteration order is the
            order in which failures are reported.
        lower_is_better: Metric names whose threshold is a ceiling rather
            than a floor. Defaults to ``("log_loss",)``, the previously
            hard-coded rule. A single string is treated as one name.

    Returns:
        Human-readable failure descriptions; empty when every metric passes.
    """
    # Guard against set("log_loss") splitting a bare string into characters.
    if isinstance(lower_is_better, str):
        lower = {lower_is_better}
    else:
        lower = set(lower_is_better)

    failures: list[str] = []

    for name, threshold in thresholds.items():
        if name not in metrics:
            failures.append(f"{name}: missing")
            continue

        value = float(metrics[name])

        # NaN compares False against everything, so without this guard a NaN
        # metric (or +/-inf in the "good" direction) would pass the gate.
        if not math.isfinite(value):
            failures.append(f"{name}: not finite ({value})")
            continue

        if name in lower:
            if value > threshold:
                failures.append(f"{name}: {value:.4f} > {threshold}")
        elif value < threshold:
            failures.append(f"{name}: {value:.4f} < {threshold}")

    return failures

In [None]:

def evaluate_param_constraints(
    params: Dict[str, str],
    constraints: Dict[str, tuple]
) -> list[str]:
    """Check that each constrained run param is numeric and within bounds.

    Args:
        params: Param name -> raw (string) value logged on the candidate run.
        constraints: Param name -> inclusive ``(low, high)`` bounds.

    Returns:
        Failure descriptions in ``constraints`` order. A param is reported as
        missing, not numeric, or out of range; the list is empty if all pass.
    """
    problems = []

    for param_name, bounds in constraints.items():
        low, high = bounds

        if param_name not in params:
            problems.append(f"{param_name}: missing")
            continue

        raw_value = params[param_name]
        try:
            numeric = float(raw_value)
        except Exception:
            problems.append(f"{param_name}: not numeric ({raw_value})")
            continue

        in_range = low <= numeric <= high
        if not in_range:
            problems.append(f"{param_name}: {numeric} not in [{low}, {high}]")

    return problems

In [None]:

# ------------------------------------------------------------------------------
# Core Logic
# ------------------------------------------------------------------------------
def _load_candidate_run(
    client: MlflowClient,
    candidate_run_id: str
) -> tuple[Dict[str, float], Dict[str, str]]:
    """Fetch the candidate run and return its ``(metrics, params)`` dicts.

    Raises:
        RuntimeError: if ``client.get_run`` raises for ``candidate_run_id``.
    """
    try:
        run = client.get_run(candidate_run_id)
    except Exception as exc:
        logger.error("Candidate run not found: %s", exc)
        raise RuntimeError("Invalid candidate_run_id") from exc

    return run.data.metrics or {}, run.data.params or {}


def _publish_deployed_model(
    deployed_version: Optional[str],
    deployed_run_id: Optional[str]
) -> None:
    """Publish the current Champion's version/run_id as task values and CI output.

    When no Champion exists, placeholders are published instead
    (version ``"0"``, run_id ``""``).
    """
    if deployed_version:
        set_task_value("deployed_model_version", deployed_version)
        set_task_value("deployed_run_id", deployed_run_id)
        emit_ci_output("deployed_model_version", deployed_version)
        return

    logger.info("No deployed Champion found. Treating as first deployment.")
    set_task_value("deployed_model_version", "0")
    # Previously only the Champion branch published deployed_run_id, so a
    # downstream task reading it (without a default) could fail on the first
    # deployment only. Publish a placeholder so both paths set the same keys.
    set_task_value("deployed_run_id", "")
    emit_ci_output("deployed_model_version", "0")


def validate_candidate_model(
    env: str,
    model_name: str,
    candidate_run_id: str
) -> None:
    """Gate a candidate MLflow run against metric and parameter thresholds.

    Loads the candidate run and publishes the current Champion (if any). It
    then evaluates ``METRIC_THRESHOLDS`` and ``REQUIRED_PARAMS`` and publishes
    the verdict via CI output lines and Databricks task values.

    Args:
        env: Environment name; only used in log output here.
        model_name: Registered model whose ``champion`` alias is resolved.
        candidate_run_id: MLflow run_id of the candidate to validate.

    Raises:
        RuntimeError: if the candidate run can't be fetched, or if any metric
            or parameter check fails (deployment must be blocked).
    """
    logger.info(
        "Starting model validation | env=%s model_name=%s candidate_run_id=%s",
        env, model_name, candidate_run_id
    )

    client = MlflowClient()

    # --------------------------------------------------------------------------
    # Fetch candidate run
    # --------------------------------------------------------------------------
    candidate_metrics, candidate_params = _load_candidate_run(client, candidate_run_id)

    logger.info("Candidate metrics: %s", candidate_metrics)
    logger.info("Candidate params: %s", candidate_params)

    # --------------------------------------------------------------------------
    # Resolve currently deployed model (Champion)
    # --------------------------------------------------------------------------
    deployed_version, deployed_run_id = resolve_deployed_model(
        client, model_name, alias="champion"
    )
    _publish_deployed_model(deployed_version, deployed_run_id)

    # --------------------------------------------------------------------------
    # Metric threshold + parameter constraint validation
    # --------------------------------------------------------------------------
    failures = (
        evaluate_metric_thresholds(candidate_metrics, METRIC_THRESHOLDS)
        + evaluate_param_constraints(candidate_params, REQUIRED_PARAMS)
    )

    # --------------------------------------------------------------------------
    # Final Decision
    # --------------------------------------------------------------------------
    if failures:
        logger.error("❌ Model validation FAILED")
        for f in failures:
            logger.error("   - %s", f)

        emit_ci_output("model_validation", "FAILED")
        set_task_value("model_validation", "FAILED")

        raise RuntimeError(
            "Candidate model failed validation thresholds. "
            "Blocking deployment."
        )

    logger.info("✅ Model validation PASSED")
    emit_ci_output("model_validation", "PASSED")
    set_task_value("model_validation", "PASSED")

    # --------------------------------------------------------------------------
    # Emit promotion metadata
    # --------------------------------------------------------------------------
    emit_ci_output("candidate_run_id", candidate_run_id)
    set_task_value("candidate_run_id", candidate_run_id)

    logger.info("Model validation completed successfully.")

In [None]:

# ------------------------------------------------------------------------------
# CLI
# ------------------------------------------------------------------------------
if __name__ == "__main__":
    # All three flags are mandatory; argparse exits with a usage error if any is absent.
    cli_parser = argparse.ArgumentParser(description="Enterprise MLOps Model Validation")
    cli_parser.add_argument("--env", required=True, help="Environment name")
    cli_parser.add_argument("--model_name", required=True, help="Registered model name")
    cli_parser.add_argument("--candidate_run_id", required=True, help="MLflow run_id of trained candidate model")

    # parse_known_args tolerates extra arguments (presumably injected by the
    # job runner) instead of failing on them.
    cli_args, extra_args = cli_parser.parse_known_args()
    if extra_args:
        logger.info("Ignoring unknown args: %s", extra_args)

    try:
        validate_candidate_model(
            env=cli_args.env,
            model_name=cli_args.model_name,
            candidate_run_id=cli_args.candidate_run_id,
        )
    except Exception:
        # Log the full traceback, then exit non-zero so the pipeline step fails.
        logger.exception("Model validation failed")
        sys.exit(1)