In [1]:
# =============================================================
# Example: NBA Feature Drift Monitoring with Plots
# =============================================================

import pandas as pd
import matplotlib.pyplot as plt
from src.monitoring import drift

# -------------------------------------------------------------
# Step 1: Load baseline and recent datasets
# -------------------------------------------------------------
# Both feature snapshots are read from Parquet files; the paths are relative,
# so they resolve against the current working directory.
# NOTE(review): the digits in the file names look like snapshot timestamps --
# confirm with whatever job writes these files.
baseline, recent = (
    pd.read_parquet("data/features/features_20251210_120000.parquet"),
    pd.read_parquet("data/features/features_20251217_120000.parquet"),
)

# Columns handed to the drift checks: the same three roll_* features, first
# for the home side and then for the away side. The order is reused as the
# x-axis order of the plots later in this script.
numeric_cols = [
    "home_roll_winrate", "home_roll_pts_for", "home_roll_pts_against",
    "away_roll_winrate", "away_roll_pts_for", "away_roll_pts_against",
]

# -------------------------------------------------------------
# Step 2: Compute KS-test and PSI
# -------------------------------------------------------------
# KS report: later indexed by column name, with each entry read through its
# "pvalue" and "drift" keys. alpha and min_samples are forwarded as-is; how
# they are applied is defined in src.monitoring.drift, not here.
ks_report = drift.ks_drift_report(
    baseline=baseline, recent=recent, columns=numeric_cols, alpha=0.05, min_samples=5
)

# PSI report: later indexed by column name and formatted as a number.
psi_values = drift.psi_report(baseline=baseline, recent=recent, columns=numeric_cols, buckets=10)

# -------------------------------------------------------------
# Step 3: Summarize drift
# -------------------------------------------------------------
# Headline: the two counts returned by summarize_drift for the KS report.
num_drifted, num_tested = drift.summarize_drift(ks_report)
print(f"KS-test summary: {num_drifted}/{num_tested} columns show drift")

# Per-feature detail. PSI bands used for the printed label:
#   PSI < 0.1   -> "No drift"
#   PSI < 0.25  -> "Moderate drift"
#   otherwise   -> "Significant drift"
# NOTE(review): this assumes ks_report and psi_values contain an entry for
# every name in numeric_cols and that each p-value is numeric -- confirm what
# the drift helpers return for columns they skip (e.g. under min_samples).
print("\nFeature-wise KS p-values and PSI:")
for col in numeric_cols:
    ks_p = ks_report[col]["pvalue"]
    ks_flag = bool(ks_report[col]["drift"])
    psi_val = psi_values[col]
    if psi_val < 0.1:
        psi_level = "No drift"
    elif psi_val < 0.25:
        psi_level = "Moderate drift"
    else:
        psi_level = "Significant drift"
    print(f"{col}: KS p={ks_p:.3f}, drift={ks_flag}, PSI={psi_val:.3f} ({psi_level})")

# -------------------------------------------------------------
# Step 4: Plot drift summary
# -------------------------------------------------------------
# Figure 1: KS-test p-value per feature. A bar is red when the report's
# "drift" entry for that feature is truthy and green otherwise; the dashed
# line is drawn at 0.05, the same value passed as alpha to the KS report.
ks_pvals = [ks_report[name]["pvalue"] for name in numeric_cols]
ks_flags = [ks_report[name]["drift"] for name in numeric_cols]

fig, ax = plt.subplots(figsize=(10, 4))
ax.bar(numeric_cols, ks_pvals, color=[("red" if flagged else "green") for flagged in ks_flags])
ax.axhline(0.05, color="black", linestyle="--", label="alpha=0.05")
ax.set_ylabel("KS-test p-value")
# Rotate the feature names so the long column labels do not overlap.
plt.setp(ax.get_xticklabels(), rotation=45)
ax.set_title("KS-test Drift per Feature (red=drift)")
ax.legend()
fig.tight_layout()
plt.show()

# Figure 2: PSI per feature. Bars are green below 0.1, orange below 0.25 and
# red otherwise; the two horizontal lines mark those same cut-offs.
psi_vals_plot = [psi_values[name] for name in numeric_cols]
colors = [
    "green" if value < 0.1 else ("orange" if value < 0.25 else "red")
    for value in psi_vals_plot
]

fig, ax = plt.subplots(figsize=(10, 4))
ax.bar(numeric_cols, psi_vals_plot, color=colors)
ax.axhline(0.1, color="black", linestyle="--", label="No drift threshold")
ax.axhline(0.25, color="black", linestyle=":", label="Significant drift threshold")
ax.set_ylabel("PSI")
plt.setp(ax.get_xticklabels(), rotation=45)
ax.set_title("Population Stability Index (PSI) per Feature")
ax.legend()
fig.tight_layout()
plt.show()


ModuleNotFoundError: No module named 'src'