In [0]:
import json
import pandas as pd
import pyspark.sql.functions as F

In [0]:
%sql
-- Remove payload log rows that have no recorded execution duration.
DELETE FROM ispl_databricks.ff.model_logs_payload
WHERE execution_duration_ms IS NULL

In [0]:
%sql
-- Show the contents of the serving payload log.
SELECT *
FROM ispl_databricks.ff.model_logs_payload

In [0]:
# Job parameters are exposed as Databricks text widgets: (name, default, label).
_widget_specs = (
    ("model_name", "ff_mw", "Model Name"),
    ("mean_diff_threshold", "0.2", "Mean Diff Threshold (Fraction)"),
    ("std_diff_threshold", "0.3", "Std Diff Threshold (Fraction)"),
    ("pvalue_threshold", "0.05", "Chi-Square p-value Threshold"),
)
for _widget_name, _widget_default, _widget_label in _widget_specs:
    dbutils.widgets.text(_widget_name, _widget_default, _widget_label)

# Widget values always arrive as strings; the three thresholds are parsed to float.
model_name = dbutils.widgets.get("model_name")
mean_diff_threshold, std_diff_threshold, pvalue_threshold = (
    float(dbutils.widgets.get(_threshold_key))
    for _threshold_key in ("mean_diff_threshold", "std_diff_threshold", "pvalue_threshold")
)

# Echo the effective configuration into the run output.
print(f"üèÉ Running data drift detection for model: {model_name}")
print(f"üîπ Mean threshold: {mean_diff_threshold}, Std threshold: {std_diff_threshold}, p-value: {pvalue_threshold}")

# Import dependencies
from scipy.stats import chi2_contingency
import pandas as pd
import numpy as np
from datetime import date
from pyspark.sql import functions as F
import os
import json
# --- Load reference and current (inference) datasets ---

# Current window: requests from the last 7 days that have an execution duration.
# Only the `request` column is read below, so project it before collecting to the driver.
payload_pdf = (
    spark.read.table("ispl_databricks.ff.model_logs_payload")
    .filter("request_time >= current_date() - INTERVAL 7 DAYS")
    .filter(F.col("execution_duration_ms").isNotNull())
    .select("request")
    .toPandas()
)

# Each `request` value is a JSON string whose "dataframe_records" entry holds the
# feature rows. A request may carry several rows, so every record is kept (not only
# the first). Payloads that cannot be parsed are counted and skipped instead of
# aborting the whole run.
result_list = []
skipped_requests = 0
for raw_request in payload_pdf["request"]:
    try:
        records = json.loads(raw_request).get("dataframe_records")
    except (TypeError, ValueError, AttributeError):
        # None / non-string value, invalid JSON, or a JSON document that is not an object.
        records = None
    if not isinstance(records, list) or not records:
        skipped_requests += 1
        continue
    result_list.extend(rec for rec in records if isinstance(rec, dict))

if skipped_requests:
    print(f"Skipped {skipped_requests} request(s) without usable 'dataframe_records'.")

curr_df = pd.DataFrame(result_list)
print(curr_df.columns)

# Reference sample: 20 rows of the feature store. The limit is applied in Spark so the
# full table is not collected to the driver just to keep 20 rows.
# NOTE(review): 20 rows is a very small reference sample for drift statistics — confirm
# this size is intentional.
features_df = (
    spark.table("ispl_databricks.model_logs.mw_final_feature_store")
    .limit(20)
    .toPandas()
)

# `loan_id` is an identifier, not a feature. errors="ignore" keeps this step from
# raising when the column is absent (e.g. no requests were logged in the window and
# curr_df is empty).
ref_df = features_df.drop(columns=["loan_id"], errors="ignore")
curr_df = curr_df.drop(columns=["loan_id"], errors="ignore")
features = curr_df.columns.tolist()







# --- Define features to monitor ---
# Feature roles are inferred from the dtypes of the reference frame.
_categorical_kinds = ["object", "category"]
_numeric_kinds = ["int", "float", "number"]
cat_cols = ref_df.select_dtypes(include=_categorical_kinds).columns.tolist()
num_cols = ref_df.select_dtypes(include=_numeric_kinds).columns.tolist()

# Per-feature result rows are accumulated here, all stamped with today's date.
evaluation_date = date.today()
results = []

# --- 1. Categorical Columns: Chi-Square Test ---
# For every categorical column of the reference frame, compare the category counts of
# the reference and current frames with a chi-square contingency test. A p-value below
# `pvalue_threshold` is recorded as "Drift", otherwise "Stable". A column that raises
# (for example because it is missing from curr_df) is reported and skipped so a single
# bad feature does not stop the run.
for col in cat_cols:
    try:
        ref_counts = ref_df[col].value_counts()
        cur_counts = curr_df[col].value_counts()

        # Align both count vectors on the union of observed categories; a category
        # that is absent on one side contributes a count of 0.
        all_categories = set(ref_counts.index) | set(cur_counts.index)
        ref_aligned = [ref_counts.get(c, 0) for c in all_categories]
        cur_aligned = [cur_counts.get(c, 0) for c in all_categories]

        # Only the p-value is used; statistic, dof and expected counts are discarded.
        _, p, _, _ = chi2_contingency([ref_aligned, cur_aligned])
        drift_status_cat = "Drift" if p < pvalue_threshold else "Stable"

        # One value per field of the 11-field log schema. The tuple previously had only
        # 10 values (no created_at placeholder), which does not match the schema.
        results.append((
            evaluation_date,
            model_name,          # job parameter, not a hard-coded model name
            col,
            "categorical",
            "chi_square",
            float(p),
            drift_status_cat,
            int(len(ref_df)),
            int(len(curr_df)),
            None,                # comment
            None,                # created_at (filled in when the rows are written)
        ))

    except Exception as e:
        print(f"‚ö†Ô∏è Skipped {col} due to error: {e}")

# --- 2. Numeric Columns: Mean & Std Deviation Comparison ---
# For every numeric column of the reference frame, compute the relative change of the
# mean and of the standard deviation between reference and current data. The feature is
# recorded as "Drift" when either relative change exceeds its threshold. A column that
# raises (missing from curr_df, non-numeric values, ...) is reported and skipped.
for col in num_cols:
    try:
        mean_ref, mean_cur = ref_df[col].mean(), curr_df[col].mean()
        std_ref, std_cur = ref_df[col].std(), curr_df[col].std()

        # Relative differences; the 1e-6 term avoids division by zero.
        mean_diff = abs(mean_cur - mean_ref) / (abs(mean_ref) + 1e-6)
        std_diff = abs(std_cur - std_ref) / (abs(std_ref) + 1e-6)

        # Comparisons against NaN are False, so a NaN component never triggers "Drift".
        drift_status_num = "Drift" if (mean_diff > mean_diff_threshold or std_diff > std_diff_threshold) else "Stable"

        # Built-in max() is order-dependent when an operand is NaN, so NaN components
        # are dropped explicitly before taking the maximum. With nothing left, the
        # metric is stored as null rather than NaN.
        valid_diffs = [float(d) for d in (mean_diff, std_diff) if not pd.isna(d)]
        metric_value = max(valid_diffs) if valid_diffs else None
        comment = None if len(valid_diffs) == 2 else "mean and/or std difference not computable (NaN)"

        results.append((
            evaluation_date,
            model_name,          # job parameter, not a hard-coded model name
            col,
            "numeric",
            "mean_std",
            metric_value,
            drift_status_num,
            int(len(ref_df)),
            int(len(curr_df)),
            comment,
            None,                # created_at (filled in when the rows are written)
        ))

    except Exception as e:
        print(f"‚ö†Ô∏è Skipped {col} due to error: {e}")

# -------------------------------------------------------------
# GLOBAL DRIFT STATUS (per-column results are not modified)
# -------------------------------------------------------------
# The run as a whole is "Drift" when at least `global_drift_min_features` features were
# individually flagged. The cut-off is a job parameter; its default of 20 matches the
# value that was previously fixed in code.
dbutils.widgets.text("global_drift_min_features", "20", "Min Drifted Features for Global Drift")
global_drift_min_features = int(dbutils.widgets.get("global_drift_min_features"))

drift_count = sum(1 for r in results if r[6] == "Drift")  # index 6 = drift_status
drift_status = "Drift" if drift_count >= global_drift_min_features else "Stable"

print(f"üîî Per-feature drift count = {drift_count}")
print(f"üåê Global drift status = {drift_status}")
# -------------------------------------------------------------

# --- Create DataFrame & Write to Delta Table ---
from pyspark.sql.types import (
    StructType,
    StructField,
    StringType,
    IntegerType,
    FloatType,
    DateType,
    DoubleType,
    TimestampType

)



# Layout of one drift-log row. Field order matters: rows are positional tuples, and
# every field is nullable.
schema = (
    StructType()
    .add("evaluation_date", DateType(), True)
    .add("model_name", StringType(), True)
    .add("feature_name", StringType(), True)
    .add("feature_type", StringType(), True)
    .add("metric_used", StringType(), True)
    .add("metric_value", DoubleType(), True)
    .add("drift_status", StringType(), True)
    .add("ref_sample_size", IntegerType(), True)
    .add("cur_sample_size", IntegerType(), True)
    .add("comment", StringType(), True)
    .add("created_at", TimestampType(), True)
)

# Persist the per-feature results. Nothing is written when no feature could be evaluated.
if results:
    # `created_at` is a null placeholder in the collected rows; stamp it here.
    # `metric_value` is already DoubleType via the schema, so no extra cast is needed,
    # and the schema has no "global_drift_status" column to drop.
    drift_df = (
        spark.createDataFrame(results, schema)
        .withColumn("created_at", F.current_timestamp())
    )
    drift_df.write.format("delta").mode("append").saveAsTable("ispl_databricks.model_logs.data_drift_log")
    # printSchema() prints directly and returns None, so it is not wrapped in display().
    drift_df.printSchema()
else:
    print("‚ö†Ô∏è No results to log ‚Äî check your reference and inference tables.")

In [0]:
%sql
-- Show the contents of the drift log table.
SELECT *
FROM ispl_databricks.model_logs.data_drift_log

In [0]:
%sql
-- Show the contents of the serving payload log.
SELECT *
FROM ispl_databricks.ff.model_logs_payload


In [0]:
# Publish the global drift status as the "drift_flag" task value of this job task.
dbutils.jobs.taskValues.set(key="drift_flag", value=drift_status)