In [None]:
# 🚀 PRODUCTION BATCH INFERENCE
 
from databricks.sdk import WorkspaceClient
from pyspark.sql import SparkSession
import pandas as pd
import numpy as np
import yaml
import traceback
import sys  # ‚Üê ADDED
from datetime import datetime
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score

print("=" * 80)
print("üöÄ PRODUCTION INFERENCE PIPELINE (NEW WORKFLOW)")
print("=" * 80)

# 1️⃣ Load pipeline configuration
#
# Reads pipeline_config.yml and derives the module-level settings the rest of
# the script relies on: MODEL_TYPE, BASE_NAME, ENDPOINT_NAME, INPUT_TABLE,
# FEATURES, LABEL, OUTPUT_TABLE and BATCH_SIZE. Any failure is printed with a
# traceback and the process exits with status 1.

print("\nüìã Loading pipeline_config.yml...")

import os

try:
    config_path = "/Workspace/Repos/vipultak7171@gmail.com/ml-credit-risk/dev_env/pipeline_config.yml"

    if not os.path.exists(config_path):
        config_path = "pipeline_config.yml"  # fallback if repo path fails

    # Decode explicitly as UTF-8 so parsing does not depend on the platform
    # locale's default encoding.
    with open(config_path, "r", encoding="utf-8") as f:
        pipeline_cfg = yaml.safe_load(f)

    # An empty file parses to None and a scalar-only file to a non-mapping;
    # report that directly instead of failing on the first subscript below.
    if not isinstance(pipeline_cfg, dict):
        raise ValueError(f"{config_path} must contain a YAML mapping at the top level")

    MODEL_TYPE = pipeline_cfg["model"]["type"]
    BASE_NAME = pipeline_cfg["model"]["base_name"]

    # Endpoint name is rendered from the YAML format rule; underscores in the
    # base name are replaced with hyphens before formatting.
    ENDPOINT_NAME = pipeline_cfg["serving"]["endpoint_name_format"].format(
        base_name=BASE_NAME.replace("_", "-"),
        model_type=MODEL_TYPE,
    )

    data_cfg = pipeline_cfg["data"]
    INPUT_TABLE = data_cfg["input_table"]
    FEATURES = data_cfg["features"]
    LABEL = data_cfg["label"]

    # Output table is rendered from the YAML format rule. `or {}` also covers
    # an `inference:` key that is present but null, so a missing
    # output_table_format surfaces as a KeyError naming that key.
    inference_cfg = pipeline_cfg.get("inference") or {}
    OUTPUT_TABLE = inference_cfg["output_table_format"].format(
        catalog=pipeline_cfg["model"]["catalog"],
        schema=pipeline_cfg["model"]["schema"],
        model_type=MODEL_TYPE,
    )

    # The batch size is later used as a divisor and slice width, so reject
    # anything but a positive integer here, where the failure is handled.
    BATCH_SIZE = inference_cfg.get("batch_size", 100)
    if isinstance(BATCH_SIZE, bool) or not isinstance(BATCH_SIZE, int) or BATCH_SIZE <= 0:
        raise ValueError(
            f"inference.batch_size must be a positive integer, got {BATCH_SIZE!r}"
        )

    print(f"üìå Using configs:")
    print(f"   ‚û§ Endpoint: {ENDPOINT_NAME}")
    print(f"   ‚û§ Input Table: {INPUT_TABLE}")
    print(f"   ‚û§ Output Table: {OUTPUT_TABLE}")
    print(f"   ‚û§ Batch Size: {BATCH_SIZE}")

except Exception as e:
    print(f"‚ùå Failed to load config: {e}")
    traceback.print_exc()
    sys.exit(1)
 
# 2️⃣ Initialize clients

# Create (or reuse) the Spark session and build the Databricks workspace
# client. WorkspaceClient() is constructed with no explicit arguments.
# Either failing is fatal: the error is printed and the process exits 1.
try:
    session_builder = SparkSession.builder.appName("ProductionInference")
    spark = session_builder.getOrCreate()
    ws = WorkspaceClient()
except Exception as init_error:
    print(f"‚ùå Initialization failed: {init_error}")
    traceback.print_exc()
    sys.exit(1)
else:
    print("\n‚úÖ Spark + WorkspaceClient initialized")
 
# 3️⃣ Load input data
 
print(f"\nüì• Loading data from {INPUT_TABLE}...")

# Read the whole input table and convert it to a pandas DataFrame (`df`),
# then check it against the configured schema. A missing feature column is
# fatal; a missing label column only produces a warning.
try:
    input_sdf = spark.read.table(INPUT_TABLE)
    df = input_sdf.toPandas()

    # Every configured feature must be present as a column.
    missing = [feature for feature in FEATURES if feature not in df.columns]
    if missing:
        raise ValueError(f"Missing input features: {missing}")

    label_available = LABEL in df.columns
    if not label_available:
        print(f"‚ö†Ô∏è No label column found ‚Üí metrics will be skipped")

    row_count = len(df)
    print(f"üìå Total records: {row_count:,}")

except Exception as load_error:
    print(f"‚ùå Data load error: {load_error}")
    traceback.print_exc()
    sys.exit(1)
 
# 4️⃣ Run inference using Serving Endpoint

print("\nüîÑ Running inference through serving endpoint...")

# Score `df` in consecutive slices of BATCH_SIZE rows. Each slice's feature
# columns are sent to the serving endpoint as a list of row dicts and the
# returned predictions are appended to `predictions` in row order, so the
# final list lines up with `df` row-for-row. Any failure is printed with a
# traceback and the process exits with status 1.
predictions = []

try:
    total_rows = len(df)

    # Validate inside the try so a bad batch size takes the handled exit path
    # rather than raising an uncaught ZeroDivisionError/TypeError.
    if isinstance(BATCH_SIZE, bool) or not isinstance(BATCH_SIZE, int) or BATCH_SIZE <= 0:
        raise ValueError(f"BATCH_SIZE must be a positive integer, got {BATCH_SIZE!r}")

    # Ceiling division: a row count that is an exact multiple of BATCH_SIZE
    # must not produce an extra empty batch, otherwise the reported total is
    # off by one and the final progress line is never printed.
    n_batches = -(-total_rows // BATCH_SIZE)

    for i, start in enumerate(range(0, total_rows, BATCH_SIZE)):
        batch = df.iloc[start : start + BATCH_SIZE]

        response = ws.serving_endpoints.query(
            name=ENDPOINT_NAME,
            dataframe_records=batch[FEATURES].to_dict("records"),
        )

        # One prediction per submitted row is required for the later
        # column assignment; detect a mismatch at the batch that caused it.
        batch_predictions = response.predictions
        if batch_predictions is None or len(batch_predictions) != len(batch):
            returned = 0 if batch_predictions is None else len(batch_predictions)
            raise RuntimeError(
                f"Endpoint returned {returned} predictions for a batch of "
                f"{len(batch)} rows (batch {i + 1}/{n_batches})"
            )

        predictions.extend(batch_predictions)

        # Report every 10th batch and always the last one.
        if (i + 1) % 10 == 0 or (i + 1) == n_batches:
            print(f"   Processed batch {i+1}/{n_batches}")

    print(f"‚úÖ Inference complete: {len(predictions)} predictions generated")

except Exception as e:
    print(f"‚ùå Inference failed: {e}")
    traceback.print_exc()
    sys.exit(1)
 
# 5️⃣ Build output dataframe
 
print("\nüì¶ Preparing output results...")

# Attach the raw endpoint output, then coerce it to a nullable integer
# column: values that cannot be parsed as numbers become <NA>.
df["prediction"] = predictions
numeric_prediction = pd.to_numeric(df["prediction"], errors="coerce")
df["prediction"] = numeric_prediction.astype("Int64")

# Human-readable Yes/No label; any prediction other than 0 or 1 maps to a
# missing value.
label_mapping = {0: "No", 1: "Yes"}
df["prediction_label"] = df["prediction"].map(label_mapping)

# Run metadata stamped identically on every row.
run_metadata = {
    "prediction_timestamp": datetime.now(),
    "model_type": MODEL_TYPE.upper(),
    "endpoint": ENDPOINT_NAME,
}
for column_name, column_value in run_metadata.items():
    df[column_name] = column_value
 
# 6️⃣ Save to Delta Table
 
print(f"\nüíæ Saving results to {OUTPUT_TABLE}...")

# Append the scored rows to the output Delta table, allowing schema merge.
# The write is best-effort: a failure is reported with its traceback and the
# script carries on.
try:
    spark_df = spark.createDataFrame(df)
    delta_writer = (
        spark_df.write
        .mode("append")
        .format("delta")
        .option("mergeSchema", "true")
    )
    delta_writer.saveAsTable(OUTPUT_TABLE)
except Exception as write_error:
    print(f"‚ö†Ô∏è Write failure but continuing: {write_error}")
    traceback.print_exc()
else:
    print(f"‚úÖ Results saved to: {OUTPUT_TABLE}")
 
# 7️⃣ Compute model performance (if true labels present)

# Classification metrics are computed only when the label column exists.
# Labels and predictions are normalised to integers; rows where either side
# is missing or non-numeric are excluded (and counted) instead of crashing.
if LABEL in df.columns:
    print("\nüìä Evaluating model performance (Classification)...")

    def _label_to_binary(value):
        """Map yes/no strings (any case, padded) to 1/0; pass other values through."""
        if isinstance(value, str):
            return {"no": 0, "yes": 1}.get(value.strip().lower(), value)
        return value

    # errors="coerce" turns anything that is still not numeric into a missing
    # value so it can be filtered out below.
    true_numeric = pd.to_numeric(df[LABEL].map(_label_to_binary), errors="coerce")
    pred_numeric = pd.to_numeric(df["prediction"], errors="coerce")

    # The prediction column may hold <NA> (it was built with coercion), and a
    # direct integer cast of <NA> raises; keep only fully populated rows.
    scorable = true_numeric.notna() & pred_numeric.notna()
    n_skipped = int((~scorable).sum())
    if n_skipped:
        print(f"   Skipped {n_skipped:,} row(s) with a missing or non-numeric label/prediction")

    y_true = true_numeric[scorable].astype(int)
    y_pred = pred_numeric[scorable].astype(int)

    if y_true.empty:
        # Nothing to score; the metric functions reject empty input.
        metrics = {}
        print("   No rows with both a valid label and a valid prediction; metrics skipped.")
    else:
        metrics = {
            "accuracy": accuracy_score(y_true, y_pred),
            "precision": precision_score(y_true, y_pred, zero_division=0),
            "recall": recall_score(y_true, y_pred, zero_division=0),
            "f1": f1_score(y_true, y_pred, zero_division=0),
        }

        # NOTE: ROC AUC is computed here from hard class predictions, not
        # scores. It is best-effort: an undefined score (e.g. a single class
        # in y_true) is recorded as None rather than aborting the run.
        try:
            metrics["roc_auc"] = roc_auc_score(y_true, y_pred)
        except ValueError as exc:
            metrics["roc_auc"] = None
            print(f"   roc_auc unavailable: {exc}")

        print("üìä Metrics:")
        for k, v in metrics.items():
            if v is not None:
                print(f"   ‚û§ {k}: {round(v,4)}")

else:
    print("\n‚ÑπÔ∏è No true labels found ‚Üí skipping metrics.")
 
# 8️⃣ Final Summary

# Closing report: echoes the input table, the output table, the number of
# rows in `df`, and the endpoint name between two 80-column rules.
divider = "=" * 80
summary_lines = [
    "\n" + divider,
    f"üéØ PRODUCTION INFERENCE COMPLETED",
    divider,
    f"üìå Input Data     : {INPUT_TABLE}",
    f"üìå Saved Results  : {OUTPUT_TABLE}",
    f"üìå Predictions    : {len(df):,}",
    f"üìå Endpoint       : {ENDPOINT_NAME}",
    divider,
]
for summary_line in summary_lines:
    print(summary_line)

print("‚úÖ Script completed successfully")
