# In [None]:
# 🎯 MODEL EVALUATION SCRIPT
import mlflow
from mlflow.tracking import MlflowClient
import pandas as pd
import json
import yaml
import sys
from datetime import datetime
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, StringType, DoubleType, TimestampType
from delta.tables import DeltaTable
from typing import Dict, List

print("=" * 80)
print("üéØ MODEL EVALUATION PIPELINE ‚Äî UNIQUE ENTRIES ONLY")
print("=" * 80)

# -------------------------
# 1️⃣ Load Config
# -------------------------

print("\nüìã Loading configuration from pipeline_config.yml...")

# Any failure while reading or validating the config aborts the script with
# exit code 1, because everything below depends on `pipeline_cfg`.
try:
    # Explicit UTF-8 so decoding does not depend on the platform default encoding.
    with open("pipeline_config.yml", "r", encoding="utf-8") as f:
        pipeline_cfg = yaml.safe_load(f)

    # An empty file parses to None and a scalar/list document is not usable
    # either; fail here with a clear message instead of a TypeError later.
    if not isinstance(pipeline_cfg, dict):
        raise ValueError("pipeline_config.yml must contain a top-level mapping")

    print("‚úÖ Configuration Loaded\n")

except Exception as e:
    print(f"‚ùå Error loading config: {e}")
    sys.exit(1)

class Config:
    """Flat, attribute-style view of the parsed pipeline configuration.

    The constructor copies the values the evaluation steps need out of the
    nested ``cfg`` mapping, builds the fully-qualified model name and prints
    a short summary. A missing key raises ``KeyError``.
    """

    def __init__(self, cfg):
        # --- model identity ---
        model_section = cfg["model"]
        self.MODEL_TYPE = model_section["type"]
        catalog = model_section["catalog"]
        schema = model_section["schema"]
        base_name = model_section["base_name"]
        # <catalog>.<schema>.<base_name>_<model type>
        self.MODEL_NAME = "{}.{}.{}_{}".format(
            catalog, schema, base_name, self.MODEL_TYPE
        )

        # --- MLflow experiment ---
        self.EXPERIMENT_NAME = cfg["experiment"]["name"]
        self.ARTIFACT_PATH = cfg["experiment"]["artifact_path"]

        # --- metric settings (always read from the "classification" section) ---
        classification = cfg["metrics"]["classification"]
        metric_attributes = (
            ("PRIMARY_METRIC", "primary_metric"),
            ("TRACKED_METRICS", "tracked_metrics"),
            ("DIRECTION", "direction"),
            ("THRESHOLD_METRICS", "threshold_metrics"),
        )
        for attribute, key in metric_attributes:
            setattr(self, attribute, classification[key])

        # --- logging target and comparison window ---
        self.EVALUATION_LOG_TABLE = cfg["tables"]["evaluation_log"]
        self.RECENT_N = cfg["comparison"]["recent_n"]

        summary_lines = (
            "\nüìå Evaluation Config Summary:",
            f"   Model: {self.MODEL_NAME}",
            f"   Primary Metric: {self.PRIMARY_METRIC} ({self.DIRECTION})",
            f"   Logging Table: {self.EVALUATION_LOG_TABLE}",
        )
        for line in summary_lines:
            print(line)

# Build the module-level configuration object; the constructor prints a short summary.
config = Config(pipeline_cfg)
print("=" * 80)

# -------------------------
# 2️⃣ Initialize MLflow & Spark
# -------------------------

# `spark`, `client` and `experiment` are module-level names read by the
# functions defined further down. Any failure during initialisation is
# reported and terminates the script with exit code 1.
try:
    spark = SparkSession.builder.appName("ModelEvaluationOnly").getOrCreate()
    mlflow.set_tracking_uri("databricks")
    client = MlflowClient()

    experiment = mlflow.get_experiment_by_name(config.EXPERIMENT_NAME)
    if experiment is None:
        # Raised inside the try block so the handler below reports it and exits.
        raise Exception(f"Experiment not found: {config.EXPERIMENT_NAME}")

    print("\nüî• MLflow + Spark loaded successfully\n")

except Exception as e:
    print(f"‚ùå MLflow Init Failed: {e}")
    sys.exit(1)


# -------------------------
# TABLE SCHEMA DEFINITION
# -------------------------

def get_evaluation_table_schema():
    """Return the fixed ``StructType`` used for the evaluation log table.

    Every column is declared nullable. The column order here is the column
    order of the resulting schema.
    """
    # (column name, Spark type class) in table order.
    column_types = (
        ("timestamp", TimestampType),
        ("run_id", StringType),
        ("run_name", StringType),
        ("model_name", StringType),
        ("primary_metric", StringType),
        ("primary_metric_value", DoubleType),
        ("all_metrics_json", StringType),
        ("params_json", StringType),
        ("model_uri", StringType),
    )
    fields = [StructField(name, dtype(), True) for name, dtype in column_types]
    return StructType(fields)


# -------------------------
# AUTO TABLE CREATION
# -------------------------

def ensure_table_exists():
    """Create the evaluation Delta table if it cannot be described.

    Runs ``DESCRIBE TABLE`` on ``config.EVALUATION_LOG_TABLE``. If that
    raises, the table is assumed to be missing and an empty table with the
    schema from ``get_evaluation_table_schema()`` is written.

    Only ``Exception`` subclasses are treated as "table missing", so
    ``KeyboardInterrupt`` and ``SystemExit`` propagate instead of triggering
    a table creation. An error raised while creating the table propagates to
    the caller.
    """
    try:
        spark.sql(f"DESCRIBE TABLE {config.EVALUATION_LOG_TABLE}")
    except Exception:
        print(f"üÜï Creating new Delta table: {config.EVALUATION_LOG_TABLE}")
        schema = get_evaluation_table_schema()
        empty_df = spark.createDataFrame([], schema)
        empty_df.write.format("delta").option("overwriteSchema", "true").saveAsTable(config.EVALUATION_LOG_TABLE)
        print(f"‚úÖ Table created: {config.EVALUATION_LOG_TABLE}")
    else:
        print(f"üìå Table exists: {config.EVALUATION_LOG_TABLE}")


# -------------------------
# CHECK IF ALREADY LOGGED
# -------------------------

def get_existing_run_ids() -> set:
    """Return the set of ``run_id`` values already stored in the evaluation table.

    Returns:
        A set of run ids. If the lookup fails for any reason, an empty set is
        returned (best effort) and the failure reason is printed, because an
        empty result means the caller will treat every run as new.
    """
    try:
        rows = spark.sql(f"""
            SELECT DISTINCT run_id 
            FROM {config.EVALUATION_LOG_TABLE}
        """).collect()
    except Exception as e:
        # Keep the best-effort fallback, but surface why the lookup failed
        # instead of silently disabling duplicate filtering.
        print("üìä No existing entries found (new table)")
        print(f"   (run_id lookup failed: {e})")
        return set()

    existing_ids = {row.run_id for row in rows}
    print(f"üìä Found {len(existing_ids)} existing entries in evaluation table")
    return existing_ids


# -------------------------
# 3️⃣ Fetch Runs
# -------------------------

def get_recent_runs():
    """Fetch the experiment's top runs and keep those not yet logged.

    Runs are requested from MLflow ordered by the primary metric (descending
    when ``config.DIRECTION`` is ``"maximize"``, ascending otherwise) and
    capped at ``config.RECENT_N``. Runs whose id is already present in the
    evaluation table are dropped.

    Returns:
        A list of dicts with keys ``run_id``, ``run_name``,
        ``primary_metric``, ``all_metrics``, ``params``, ``model_uri`` and
        ``timestamp``; an empty list when the search returns nothing.
    """
    print("\nüìç Fetching Experiment Runs...")

    sort_direction = "DESC" if config.DIRECTION == "maximize" else "ASC"
    order = f"metrics.{config.PRIMARY_METRIC} {sort_direction}"

    candidates = client.search_runs(
        [experiment.experiment_id],
        order_by=[order],
        max_results=config.RECENT_N,
    )
    if not candidates:
        print("‚ö† No model runs found.")
        return []

    # Ids already written to the evaluation table; these runs are skipped.
    already_logged = get_existing_run_ids()

    fresh_runs = []
    for candidate in candidates:
        run_id = candidate.info.run_id
        if run_id not in already_logged:
            # Keep only the tracked metrics that this run actually reported.
            reported = candidate.data.metrics
            tracked = {}
            for metric_name in config.TRACKED_METRICS:
                if metric_name in reported:
                    tracked[metric_name] = reported.get(metric_name)

            fresh_runs.append({
                "run_id": run_id,
                "run_name": candidate.info.run_name or "unnamed_run",
                "primary_metric": candidate.data.metrics.get(config.PRIMARY_METRIC),
                "all_metrics": tracked,
                "params": candidate.data.params,
                "model_uri": f"runs:/{run_id}/{config.ARTIFACT_PATH}",
                # start_time is divided by 1000 before conversion to a datetime.
                "timestamp": datetime.fromtimestamp(candidate.info.start_time / 1000),
            })

    skipped = len(candidates) - len(fresh_runs)
    print(f"üìå {len(fresh_runs)} NEW runs to log (Skipped {skipped} already logged)")
    return fresh_runs


# -------------------------
# 4️⃣ Log Evaluation Results (NO DUPLICATES)
# -------------------------

def log_results(run_list):
    """Append one row per entry of ``run_list`` to the evaluation Delta table.

    Args:
        run_list: Dicts as produced by ``get_recent_runs``. When empty, a
            message is printed and nothing is written.

    A falsy primary metric (missing or zero) is stored as ``0.0``. Metrics
    and params are stored as JSON strings. A failure during the write is
    printed, not raised.
    """
    if not run_list:
        print("\n‚úÖ No new runs to log (all already exist)")
        return

    print(f"\nüìù Logging {len(run_list)} new evaluation results...")

    # One flat record per run, matching the evaluation table's columns.
    rows = [
        {
            "timestamp": datetime.now(),
            "run_id": entry["run_id"],
            "run_name": entry["run_name"],
            "model_name": config.MODEL_NAME,
            "primary_metric": config.PRIMARY_METRIC,
            "primary_metric_value": float(entry["primary_metric"]) if entry["primary_metric"] else 0.0,
            "all_metrics_json": json.dumps(entry["all_metrics"]),
            "params_json": json.dumps(entry["params"]),
            "model_uri": entry["model_uri"],
        }
        for entry in run_list
    ]

    pandas_frame = pd.DataFrame(rows)
    spark_frame = spark.createDataFrame(pandas_frame, schema=get_evaluation_table_schema())

    try:
        # Append mode: existing rows in the table are left untouched.
        writer = spark_frame.write.format("delta").mode("append")
        writer.saveAsTable(config.EVALUATION_LOG_TABLE)
        print(f"‚úÖ Successfully logged {len(run_list)} new entries")
    except Exception as e:
        print(f"‚ùå Failed to log results: {e}")


# -------------------------
# 5️⃣ Display Summary
# -------------------------

def show_summary(run_list):
    """Print a ranked list of the runs in ``run_list`` (first 10 only).

    Args:
        run_list: Dicts with at least ``run_name`` and ``primary_metric``
            keys. When empty, a message is printed and nothing else happens.

    A missing (``None``) primary metric is displayed as ``0.0000``.
    """
    if not run_list:
        print("\nüìä No new runs to display")
        return

    print("\nüìä NEWLY LOGGED MODEL RESULTS:\n")
    for rank, run in enumerate(run_list[:10], 1):
        pm_val = run['primary_metric']
        # The fallback must be chosen before formatting: a conditional placed
        # after the colon becomes part of the format spec and raises ValueError.
        shown_value = pm_val if pm_val is not None else 0.0
        print(f"{rank}. {run['run_name']} ‚Üí {config.PRIMARY_METRIC}: {shown_value:.4f}")


# -------------------------
# 6️⃣ Show All Unique Experiments
# -------------------------

def show_all_experiments():
    """Print every distinct row of the evaluation table, best metric first.

    Rows are distinct on (run_name, primary_metric_value, timestamp). They
    are sorted descending when ``config.DIRECTION`` is ``"maximize"`` and
    ascending otherwise, matching the ordering used when fetching runs.
    A NULL metric value is shown as ``N/A`` instead of aborting the listing.
    Any failure is printed, not raised.
    """
    print("\nüìã ALL UNIQUE EXPERIMENTS IN EVALUATION TABLE:\n")

    try:
        # Respect the configured optimisation direction so "best" is listed first.
        sort_direction = "DESC" if config.DIRECTION == "maximize" else "ASC"

        all_runs = spark.sql(f"""
            SELECT DISTINCT run_name, primary_metric_value, timestamp
            FROM {config.EVALUATION_LOG_TABLE}
            ORDER BY primary_metric_value {sort_direction}
        """)

        results = all_runs.collect()

        for idx, row in enumerate(results, 1):
            value = row.primary_metric_value
            # The column is nullable; formatting None with ".4f" would raise.
            shown_value = f"{value:.4f}" if value is not None else "N/A"
            print(f"{idx}. {row.run_name} ‚Üí {config.PRIMARY_METRIC}: {shown_value}")

        print(f"\nüìä Total Unique Experiments: {len(results)}")

    except Exception as e:
        print(f"‚ö† Could not fetch experiments: {e}")


# -------------------------
# 🚀 MAIN EXECUTION
# -------------------------

def main():
    """Run the evaluation pipeline end to end.

    Order: make sure the log table exists, collect runs that are not yet
    logged, write them, print them, then print the table's full contents.
    """
    ensure_table_exists()

    # Runs returned here exclude ids already present in the table.
    pending_runs = get_recent_runs()

    # Both steps consume the same list: first persist, then summarise.
    for step in (log_results, show_summary):
        step(pending_runs)

    show_all_experiments()

    for line in ("\nüéâ Evaluation Completed Successfully!", "=" * 80):
        print(line)

# Run the pipeline only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()