# Data Drift Detection with spark-bestfit

This notebook demonstrates **data drift detection** using distribution fitting to monitor
feature distributions over time and alert when significant changes occur.

## What You'll Learn

1. **Establish baseline distributions** from historical data
2. **Monitor distributions** across time periods using KS tests
3. **Detect gradual and abrupt drift** in feature distributions
4. **Set alert thresholds** based on statistical significance
5. **Track multi-feature drift** to identify which features are changing

## Business Context

Production ML models degrade when underlying data distributions shift. This "data drift"
can occur due to:
- Seasonal patterns or trends
- Changes in user behavior
- Data pipeline issues
- External market conditions

**Drift detection enables:**
- Proactive model retraining triggers
- Data quality issue identification
- Early warning for changing business conditions
- SLA compliance for model accuracy

## Prerequisites

```bash
pip install spark-bestfit pandas numpy matplotlib scipy
```

## Setup

In [None]:
import warnings
warnings.filterwarnings('ignore')

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy import stats
from pyspark.sql import SparkSession
import pyspark.sql.functions as F

from spark_bestfit import DistributionFitter

# Create Spark session
spark = SparkSession.builder \
    .appName("Drift-Detection") \
    .master("local[*]") \
    .config("spark.driver.memory", "4g") \
    .getOrCreate()

spark.sparkContext.setLogLevel("WARN")
print(f"Spark version: {spark.version}")

## Part 1: Generate Synthetic Time Series Data

We'll simulate a realistic drift scenario:
- **Month 1 (Baseline)**: Stable Normal distribution N(100, 15)
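- **Months 2-4**: Gradual mean shift (100 → 105 → 110 → 115), with a slightly wider spread in Month 4 (σ = 16)
- **Month 5**: Abrupt distribution change to a right-skewed Gamma(shape=5, scale=20), whose mean stays at 100
- **Month 6**: Partial recovery toward baseline, N(105, 15)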
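
For the equal-variance months, the size of the drift can be estimated before any data is generated. The sketch below assumes two normal distributions that share one standard deviation: the gap between their CDFs is largest midway between the two means, so a mean shift δ gives a population KS distance of 2Φ(δ / 2σ) − 1. The helper name `expected_ks_mean_shift` is hypothetical. Month 4 (σ = 16 in the generating code) and Month 5 (Gamma) fall outside this assumption.

```python
from statistics import NormalDist


def expected_ks_mean_shift(delta, sigma):
    """Population KS distance between N(mu, sigma) and N(mu + delta, sigma)."""
    # The CDF gap is largest midway between the two means.
    return 2.0 * NormalDist().cdf(abs(delta) / (2.0 * sigma)) - 1.0


# Mean shifts of the equal-variance months relative to the N(100, 15) baseline
for label, delta in [("Month 2", 5), ("Month 3", 10), ("Month 6", 5)]:
    ks = expected_ks_mean_shift(delta, 15)
    print(f"{label}: expected KS distance ~ {ks:.3f}")
```

The sample KS statistics reported later should land close to these values, which is a quick sanity check on the synthetic data.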

In [None]:
np.random.seed(42)

# Baseline parameters
baseline_mu = 100
baseline_sigma = 15
samples_per_month = 10000

# Generate data for each period
periods = {
    'Month 1 (Baseline)': np.random.normal(baseline_mu, baseline_sigma, samples_per_month),
    'Month 2': np.random.normal(105, 15, samples_per_month),  # Gradual shift
    'Month 3': np.random.normal(110, 15, samples_per_month),  # More shift
    'Month 4': np.random.normal(115, 16, samples_per_month),  # Shift + variance change
    'Month 5': np.random.gamma(5, 20, samples_per_month),     # Abrupt change!
    'Month 6': np.random.normal(105, 15, samples_per_month),  # Recovery
}

# Show summary statistics
print("Generated Data Summary:")
print(f"{'Period':<20} {'Mean':>10} {'Std':>10} {'Skew':>10}")
print("-" * 52)
for period, data in periods.items():
    print(f"{period:<20} {data.mean():>10.2f} {data.std():>10.2f} {stats.skew(data):>10.2f}")

In [None]:
# Visualize the distributions
fig, axes = plt.subplots(2, 3, figsize=(15, 8))
axes = axes.flatten()

for i, (period, data) in enumerate(periods.items()):
    ax = axes[i]
    ax.hist(data, bins=50, density=True, alpha=0.7, edgecolor='black')
    ax.axvline(data.mean(), color='red', linestyle='--', lw=2, label=f'Mean: {data.mean():.1f}')
    ax.set_title(period)
    ax.set_xlabel('Value')
    ax.set_ylabel('Density')
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.suptitle('Feature Distribution Over Time', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

## Part 2: Establish Baseline Distribution

Fit distributions to the baseline period to create a reference for comparison.
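
The cells in this part rank candidate distributions by AIC (`metric='aic'`). As a reminder of what that metric rewards, here is a minimal sketch of the textbook definition, AIC = 2k − 2 ln L, for a normal distribution fitted by maximum likelihood. The function name `normal_aic` and the demo data are illustrative assumptions. How spark-bestfit computes its `aic` field internally is not shown in this notebook, so treat this as the general formula rather than the library's implementation.

```python
import math
import random
from statistics import fmean, pstdev


def normal_aic(samples):
    """AIC = 2k - 2 ln(L) for a normal distribution fitted by maximum likelihood."""
    n = len(samples)
    mu = fmean(samples)
    sigma = pstdev(samples, mu)  # maximum-likelihood estimate (divides by n)
    # Log-likelihood of a normal evaluated at its own MLE parameters
    log_likelihood = -0.5 * n * (math.log(2.0 * math.pi * sigma**2) + 1.0)
    k = 2  # fitted parameters: mean and standard deviation
    return 2 * k - 2.0 * log_likelihood


random.seed(0)
demo = [random.gauss(100, 15) for _ in range(1000)]  # illustrative data only
print(f"AIC of a normal fit: {normal_aic(demo):.1f}")
```

Lower AIC is better: the score improves with likelihood and is penalized by 2 for each fitted parameter, which discourages picking a flexible distribution that only fits marginally better.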

In [None]:
# Create fitter
fitter = DistributionFitter(spark)

# Create Spark DataFrame for baseline
baseline_data = periods['Month 1 (Baseline)']
baseline_df = spark.createDataFrame(pd.DataFrame({'value': baseline_data}))

# Fit distributions with full metrics (lazy_metrics=False for KS/AD)
baseline_results = fitter.fit(
    baseline_df,
    column='value',
    max_distributions=15,
    lazy_metrics=False  # Need KS statistics for drift detection
)

print(f"Fitted {baseline_results.count()} distributions to baseline")

In [None]:
# Get best baseline fit
baseline_fit = baseline_results.best(n=1, metric='aic')[0]
baseline_samples = baseline_data  # Store raw samples for KS comparison

print("Baseline Distribution:")
print(f"  Best fit: {baseline_fit.distribution}")
print(f"  AIC: {baseline_fit.aic:.2f}")
print(f"  KS statistic: {baseline_fit.ks_statistic:.4f}")
print(f"  Parameters: {baseline_fit.parameters}")

# Show top 5 candidates
print("\nTop 5 Baseline Candidates:")
for i, fit in enumerate(baseline_results.best(n=5, metric='aic'), 1):
    print(f"  {i}. {fit.distribution}: AIC={fit.aic:.1f}, KS={fit.ks_statistic:.4f}")

## Part 3: Monitor Periods with KS Test

Compare each monitoring period against baseline using the two-sample KS test.
This directly compares samples without assuming a specific distribution.
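
One caveat is worth making concrete: whether a given KS statistic is significant depends heavily on sample size. The sketch below uses the standard large-sample approximation for the two-sample critical value, c(α)·sqrt((n + m) / (n·m)) with c(α) = sqrt(−ln(α/2) / 2). The function name `ks_2samp_critical_value` is hypothetical, and SciPy computes its own p-value, so this is only a rough guide.

```python
import math


def ks_2samp_critical_value(n, m, alpha=0.05):
    """Large-sample critical KS statistic for two samples of size n and m."""
    c_alpha = math.sqrt(-0.5 * math.log(alpha / 2.0))
    return c_alpha * math.sqrt((n + m) / (n * m))


for size in (100, 1_000, 10_000):
    crit = ks_2samp_critical_value(size, size)
    print(f"n = m = {size:>6}: reject at the 5% level when KS > {crit:.4f}")
```

With 10,000 samples per period, as generated in Part 1, even a KS statistic of about 0.02 is statistically significant, so it helps to read the KS statistic itself as the effect size alongside the p-value.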

In [None]:
# Monitor each period
drift_results = []

for period, data in periods.items():
    # Two-sample KS test: compare this period to baseline
    ks_stat, p_value = stats.ks_2samp(baseline_samples, data)
    
    # Also track distribution statistics
    mean_shift = abs(data.mean() - baseline_samples.mean())
    std_ratio = data.std() / baseline_samples.std()
    
    drift_results.append({
        'period': period,
        'ks_statistic': ks_stat,
        'p_value': p_value,
        'mean_shift': mean_shift,
        'std_ratio': std_ratio,
        'drift_detected': p_value < 0.05
    })

drift_df = pd.DataFrame(drift_results)
print("Drift Detection Results (KS Test vs Baseline):")
print(drift_df.to_string(index=False))

In [None]:
# Visualize drift detection
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# KS statistic over time
months = range(1, 7)
ax1 = axes[0]
bars = ax1.bar(months, drift_df['ks_statistic'], 
               color=['green' if not d else 'red' for d in drift_df['drift_detected']],
               edgecolor='black', alpha=0.7)
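# Note: 0.05 here is a heuristic cutoff on the KS statistic, not the p-value significance level
ax1.axhline(0.05, color='orange', linestyle='--', lw=2, label='Heuristic KS threshold (0.05)')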
ax1.set_xlabel('Month')
ax1.set_ylabel('KS Statistic')
ax1.set_title('Drift Magnitude Over Time')
ax1.legend()
ax1.grid(True, alpha=0.3)

# P-value over time (log scale for visibility)
# Very small p-values can underflow to 0.0, which a log axis cannot draw, so clip to the axis floor
p_floor = 1e-100
p_plot = drift_df['p_value'].clip(lower=p_floor)
ax2 = axes[1]
ax2.semilogy(months, p_plot, 'bo-', lw=2, markersize=10)
ax2.axhline(0.05, color='red', linestyle='--', lw=2, label='Significance threshold (p=0.05)')
ax2.fill_between(months, p_floor, 0.05, alpha=0.2, color='red', label='Drift detected zone')
ax2.set_xlabel('Month')
ax2.set_ylabel('p-value (log scale)')
ax2.set_title('Statistical Significance of Drift')
ax2.legend()
ax2.grid(True, alpha=0.3)
ax2.set_ylim(p_floor, 1)

plt.tight_layout()
plt.show()

## Part 4: Distribution Fitting for Drift Characterization

Beyond detecting drift, we can characterize *how* the distribution changed
by fitting distributions to each period.

In [None]:
# Fit distributions to each period and track best fit
period_fits = {}

for period, data in periods.items():
    df = spark.createDataFrame(pd.DataFrame({'value': data}))
    results = fitter.fit(
        df,
        column='value',
        max_distributions=10,
        lazy_metrics=False
    )
    best = results.best(n=1, metric='aic')[0]
    period_fits[period] = {
        'distribution': best.distribution,
        'aic': best.aic,
        'ks_statistic': best.ks_statistic,
        'parameters': best.parameters
    }

# Show distribution changes
print("Best-Fit Distribution by Period:")
print(f"{'Period':<20} {'Distribution':<15} {'AIC':<12} {'KS Stat'}")
print("-" * 60)
for period, fit_info in period_fits.items():
    print(f"{period:<20} {fit_info['distribution']:<15} {fit_info['aic']:<12.1f} {fit_info['ks_statistic']:.4f}")

## Part 5: Multi-Feature Drift Monitoring

Real ML models have multiple features. Let's extend to monitor drift across
multiple features simultaneously.
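
When many features are tested at once, each at a 0.05 threshold, the chance of at least one false alarm grows with the number of features. One common remedy is a multiple-testing correction. The sketch below implements the Holm-Bonferroni step-down procedure in plain Python. The function name `holm_bonferroni` and the example p-values are made up for illustration and are not results from this notebook.

```python
def holm_bonferroni(p_values, alpha=0.05):
    """Return {feature: reject?} using the Holm-Bonferroni step-down procedure."""
    ordered = sorted(p_values.items(), key=lambda item: item[1])
    m = len(ordered)
    decisions = {}
    still_rejecting = True
    for rank, (name, p) in enumerate(ordered):
        # Compare the rank-th smallest p-value with alpha / (m - rank);
        # once one test fails, all larger p-values are also retained.
        if still_rejecting and p <= alpha / (m - rank):
            decisions[name] = True
        else:
            still_rejecting = False
            decisions[name] = False
    return decisions


# Made-up p-values for three hypothetical features
example = {"feature_x": 1e-12, "feature_y": 0.03, "feature_z": 0.20}
print(holm_bonferroni(example))
```

In this example `feature_y` has p = 0.03, which passes an uncorrected 0.05 threshold but not the corrected one (0.05 / 2 = 0.025), so only `feature_x` is flagged.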

In [None]:
np.random.seed(123)

# Generate multi-feature data with varying drift patterns
n_samples = 5000

# Baseline period
baseline_multi = pd.DataFrame({
    'feature_a': np.random.normal(50, 10, n_samples),   # Will drift
    'feature_b': np.random.exponential(20, n_samples),  # Stable
    'feature_c': np.random.normal(0, 1, n_samples),     # Will drift severely
})

# Monitoring period (with drift in features A and C)
monitor_multi = pd.DataFrame({
    'feature_a': np.random.normal(55, 10, n_samples),   # Mean shifted +5
    'feature_b': np.random.exponential(20, n_samples),  # Stable
    'feature_c': np.random.gamma(2, 2, n_samples),      # Distribution change!
})

print("Multi-Feature Drift Detection:")
print(f"{'Feature':<12} {'Baseline Mean':>15} {'Monitor Mean':>15} {'KS Stat':>12} {'p-value':>12} {'Drift?'}")
print("-" * 80)

multi_drift = []
for col in baseline_multi.columns:
    ks_stat, p_val = stats.ks_2samp(baseline_multi[col], monitor_multi[col])
    drift_detected = p_val < 0.05
    multi_drift.append({
        'feature': col,
        'ks_statistic': ks_stat,
        'p_value': p_val,
        'drift': drift_detected
    })
    print(f"{col:<12} {baseline_multi[col].mean():>15.2f} {monitor_multi[col].mean():>15.2f} "
          f"{ks_stat:>12.4f} {p_val:>12.4e} {'YES' if drift_detected else 'NO'}")

In [None]:
# Aggregate drift score
drifting_features = sum(1 for d in multi_drift if d['drift'])
total_features = len(multi_drift)
drift_score = drifting_features / total_features

print("\nAggregate Drift Assessment:")
print(f"  Features with drift: {drifting_features}/{total_features}")
print(f"  Drift score: {drift_score:.1%}")

if drift_score > 0.5:
    print("  ALERT: Significant drift detected - consider model retraining")
elif drift_score > 0:
    print("  WARNING: Partial drift detected - monitor closely")
else:
    print("  OK: No significant drift detected")

## Part 6: Alerting Dashboard

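Combine the results above into a four-panel dashboard: per-feature KS statistics, the KS statistic over time, a baseline-versus-monitor overlay for feature C, and the monthly mean against the baseline band.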

In [None]:
# Create drift monitoring dashboard
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# 1. Feature-wise drift comparison
ax1 = axes[0, 0]
features = [d['feature'] for d in multi_drift]
ks_stats = [d['ks_statistic'] for d in multi_drift]
colors = ['red' if d['drift'] else 'green' for d in multi_drift]
ax1.barh(features, ks_stats, color=colors, edgecolor='black', alpha=0.7)
ax1.axvline(0.05, color='orange', linestyle='--', lw=2, label='Alert threshold')
ax1.set_xlabel('KS Statistic')
ax1.set_title('Feature-wise Drift Magnitude')
ax1.legend()

# 2. Time series drift tracking (using earlier data)
ax2 = axes[0, 1]
ax2.plot(range(1, 7), drift_df['ks_statistic'], 'bo-', lw=2, markersize=8)
ax2.fill_between(range(1, 7), 0, drift_df['ks_statistic'], alpha=0.3)
ax2.axhline(0.05, color='red', linestyle='--', lw=2, label='Alert threshold')
for i, row in drift_df.iterrows():
    if row['drift_detected']:
        ax2.scatter(i+1, row['ks_statistic'], color='red', s=200, zorder=5, marker='X')
ax2.set_xlabel('Month')
ax2.set_ylabel('KS Statistic')
ax2.set_title('Drift Evolution Over Time')
ax2.legend()

# 3. Distribution overlay (baseline vs drifted)
ax3 = axes[1, 0]
ax3.hist(baseline_multi['feature_c'], bins=40, density=True, alpha=0.5, 
         label='Baseline', color='blue', edgecolor='black')
ax3.hist(monitor_multi['feature_c'], bins=40, density=True, alpha=0.5, 
         label='Monitor', color='red', edgecolor='black')
ax3.set_xlabel('Feature C Value')
ax3.set_ylabel('Density')
ax3.set_title('Feature C: Distribution Shift (Severe Drift)')
ax3.legend()

# 4. Mean shift tracking
ax4 = axes[1, 1]
means = [periods[p].mean() for p in periods.keys()]
ax4.plot(range(1, 7), means, 'go-', lw=2, markersize=10)
ax4.axhline(baseline_mu, color='blue', linestyle='--', lw=2, label=f'Baseline mean ({baseline_mu})')
ax4.fill_between(range(1, 7), baseline_mu - baseline_sigma, baseline_mu + baseline_sigma, 
                 alpha=0.2, color='blue', label='Baseline ±1 std')
ax4.set_xlabel('Month')
ax4.set_ylabel('Mean Value')
ax4.set_title('Mean Drift Over Time')
ax4.legend()

plt.suptitle('Drift Monitoring Dashboard', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

## Part 7: Streaming Drift Detection Pattern

For real-time monitoring, you can apply the same pattern to streaming data
using Spark Structured Streaming's `foreachBatch`.
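
With a 0.05 threshold, roughly one in twenty drift-free batches will still trigger an alert by chance, which adds up quickly on a stream. One possible mitigation, sketched below, is to alert only after several consecutive drifting batches. The class name `ConsecutiveDriftGate` and its `patience` parameter are hypothetical and are not part of spark-bestfit or Spark.

```python
from collections import deque


class ConsecutiveDriftGate:
    """Signal an alert only after `patience` drifting batches in a row."""

    def __init__(self, patience=3):
        # Keeps only the most recent `patience` per-batch drift flags
        self.recent = deque(maxlen=patience)

    def update(self, drift_detected):
        self.recent.append(bool(drift_detected))
        window_full = len(self.recent) == self.recent.maxlen
        return window_full and all(self.recent)


gate = ConsecutiveDriftGate(patience=3)
flags = [True, False, True, True, True, False]  # made-up per-batch drift flags
alerts = [gate.update(flag) for flag in flags]
print(alerts)  # [False, False, False, False, True, False]
```

A gate like this could be updated inside the batch callback with the result of the per-batch test. The trade-off is slower detection of abrupt changes in exchange for fewer false alarms.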

In [None]:
# Streaming drift detection pattern (for reference)
# This shows HOW to integrate drift detection with streaming data

def create_drift_monitor(fitter, baseline_samples, threshold=0.05):
    """Factory function that returns a streaming batch processor."""
    
    def monitor_batch(batch_df, batch_id):
        """Process each streaming batch for drift detection."""
        if batch_df.count() == 0:
            return
        
        # Collect only the monitored column to the driver for the KS test
        batch_data = batch_df.select('value').toPandas()['value'].to_numpy()
        
        # Two-sample KS test against baseline
        ks_stat, p_value = stats.ks_2samp(baseline_samples, batch_data)
        
        # Alert on drift
        if p_value < threshold:
            # Fit only when drift is flagged, so stable batches skip the fitting cost
            results = fitter.fit(batch_df, column='value', lazy_metrics=True)
            best_fit = results.best(n=1, metric='aic')[0]
            print(f"DRIFT ALERT [Batch {batch_id}]: KS={ks_stat:.4f}, p={p_value:.4e}")
            print(f"  Best-fit distribution for this batch: {best_fit.distribution}")
            # In production: send to monitoring system, trigger retraining, etc.
        else:
            print(f"[Batch {batch_id}] OK: KS={ks_stat:.4f}, p={p_value:.4f}")
    
    return monitor_batch

# Example usage (commented out - requires active stream)
# drift_monitor = create_drift_monitor(fitter, baseline_samples)
# stream_df.writeStream.foreachBatch(drift_monitor).start()

print("Streaming pattern defined. In production:")
print("  1. Create baseline from historical data")
print("  2. Apply foreachBatch with drift_monitor function")
print("  3. Monitor alerts in real-time")

In [None]:
# Simulate streaming batch processing
print("Simulated Streaming Drift Detection:")
print("=" * 60)

drift_monitor = create_drift_monitor(fitter, baseline_samples, threshold=0.05)

# Process each "month" as if it were a streaming batch
for batch_id, (period, data) in enumerate(periods.items()):
    batch_df = spark.createDataFrame(pd.DataFrame({'value': data}))
    print(f"\nProcessing {period}:")
    drift_monitor(batch_df, batch_id)

## Part 8: Business Impact Assessment

Link detected drift to potential model performance degradation.

In [None]:
# Simulate relationship between drift and model accuracy
# In practice, you'd track actual model performance alongside drift

# Hypothetical model accuracy degradation with drift
baseline_accuracy = 0.92
drift_impact = -0.15  # Assumed: a KS statistic of 1.0 would cost 15 percentage points of accuracy

print("Business Impact Assessment:")
print("=" * 70)
print(f"Baseline model accuracy: {baseline_accuracy:.1%}")
print(f"Assumed drift impact: {drift_impact * 100:+.0f} percentage points of accuracy per unit KS statistic")
print()
print(f"{'Period':<20} {'KS Stat':>10} {'Est. Accuracy':>15} {'Revenue Impact'}")
print("-" * 70)

monthly_revenue = 1_000_000  # $1M monthly revenue depends on model

for i, row in drift_df.iterrows():
    ks = row['ks_statistic']
    est_accuracy = max(baseline_accuracy + drift_impact * ks, 0.5)  # Floor at 50%
    accuracy_drop = baseline_accuracy - est_accuracy
    revenue_impact = monthly_revenue * accuracy_drop
    
    print(f"{row['period']:<20} {ks:>10.4f} {est_accuracy:>14.1%} ${revenue_impact:>12,.0f}")

In [None]:
# Calculate retraining decision
retraining_cost = 50_000  # Cost to retrain model
drift_threshold_for_retrain = 0.10  # Retrain if KS > 0.10

print("\nRetraining Decision Analysis:")
print("=" * 70)

for i, row in drift_df.iterrows():
    ks = row['ks_statistic']
    est_accuracy = max(baseline_accuracy + drift_impact * ks, 0.5)
    monthly_loss = monthly_revenue * (baseline_accuracy - est_accuracy)
    
    if ks > drift_threshold_for_retrain and monthly_loss > retraining_cost:
        decision = "RETRAIN"
        reason = f"Loss ${monthly_loss:,.0f} > Retrain cost ${retraining_cost:,.0f}"
    elif ks > drift_threshold_for_retrain:
        decision = "MONITOR"
        reason = f"Drift detected but loss ${monthly_loss:,.0f} <= retrain cost ${retraining_cost:,.0f}"
    else:
        decision = "OK"
        reason = "Drift within acceptable bounds"
    
    print(f"{row['period']}: {decision} - {reason}")
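
Under the hypothetical linear model used in this part, monthly loss equals revenue × 0.15 × KS statistic, so the decision rule above implies a break-even KS statistic at which the loss matches the retraining cost. The sketch below computes it. The function name `break_even_ks` is an assumption for illustration, and the inputs mirror the assumed figures from the cells above.

```python
def break_even_ks(retraining_cost, monthly_revenue, accuracy_loss_per_ks):
    """KS statistic at which the estimated monthly loss equals the retraining cost."""
    return retraining_cost / (monthly_revenue * accuracy_loss_per_ks)


# Same assumed figures as the notebook: $50k retrain cost, $1M revenue, 0.15 loss per unit KS
ks_star = break_even_ks(50_000, 1_000_000, 0.15)
print(f"Break-even KS statistic: {ks_star:.3f}")
```

With these assumptions the break-even point is about 0.33, well above the 0.10 drift threshold, so the `RETRAIN` branch only fires for periods whose KS statistic exceeds roughly one third. Periods between 0.10 and that value fall into `MONITOR`.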

## Summary

This notebook demonstrated data drift detection with spark-bestfit:

1. **Baseline establishment**: Fit distributions to historical data for reference
2. **KS test for drift**: Two-sample KS test to detect distribution changes
3. **Distribution fitting**: Characterize *how* distributions changed
4. **Multi-feature monitoring**: Track drift across multiple features
5. **Alerting patterns**: Set thresholds and visualize drift evolution
6. **Streaming integration**: `foreachBatch` pattern for real-time monitoring
7. **Business impact**: Link drift to model performance and ROI

### Key spark-bestfit Features Used

| Feature | Purpose |
|---------|---------|
| `lazy_metrics=False` | Compute KS statistics for validation |
| `DistributionFitter` | Fit distributions to each period |
| `results.best()` | Identify best-fit distribution |
| `scipy.stats.ks_2samp` (SciPy, not part of spark-bestfit) | Direct sample comparison for drift |

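### Drift Detection Thresholds

The KS-statistic bands below are heuristic effect-size tiers, not values derived from the p-value. The p-value depends on sample size: with 10,000 samples per period, a KS statistic above roughly 0.02 is already significant at the 5% level, so check both columns rather than assuming they move together.

| KS Statistic | p-value | Interpretation |
|--------------|---------|----------------|
| < 0.05 | ≥ 0.05 | No significant drift |
| 0.05 - 0.10 | < 0.05 | Moderate drift - monitor |
| > 0.10 | << 0.05 | Significant drift - investigate/retrain |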

### Production Recommendations

1. **Store baselines**: Serialize fitted distributions for comparison
2. **Track trends**: Monitor KS statistic over time, not just point-in-time
3. **Set tiered alerts**: Warning (yellow) and critical (red) thresholds
4. **Automate response**: Trigger retraining pipelines on critical drift
5. **Root cause analysis**: When drift occurs, investigate which features changed
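
The tiered-alert recommendation can be sketched as a small classifier that combines the p-value with the KS-statistic bands from the thresholds table. The function name `classify_drift` and its default cutoffs are assumptions for illustration. Tune them to your own sample sizes and tolerance for false alarms.

```python
def classify_drift(ks_stat, p_value, warn_ks=0.05, critical_ks=0.10, alpha=0.05):
    """Map a two-sample KS result to 'ok', 'warning', or 'critical'."""
    if p_value >= alpha or ks_stat < warn_ks:
        return "ok"  # not significant, or too small a shift to act on
    if ks_stat <= critical_ks:
        return "warning"  # significant but moderate: monitor
    return "critical"  # significant and large: investigate or retrain


print(classify_drift(0.02, 0.001))   # ok
print(classify_drift(0.07, 1e-6))    # warning
print(classify_drift(0.25, 1e-40))   # critical
```

Requiring both a significant p-value and a minimum KS statistic keeps large samples from raising alerts on shifts that are statistically detectable but practically negligible.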

In [None]:
# Cleanup
spark.stop()