# Snowcore Industries - Demand Sensing Model

## Business Objective

Snowcore Industries requires accurate demand forecasting to optimize procurement planning and reduce inventory costs. This notebook implements a multi-model demand sensing system that predicts `forecasted_material_demand_qty` by combining:
- Historical consumption patterns
- External macro-economic indicators (construction starts, clinical trial spend, industrial production, PMI)
- Seasonal patterns

## Technical Approach

We use a **multi-model comparison approach** with three regression algorithms:
1. **XGBoost** (primary) - Gradient boosted decision trees optimized for structured data
2. **Random Forest** - Ensemble of decision trees using bagging
3. **Linear Regression** - Simple baseline for comparison

The models are compared on Mean Absolute Percentage Error (MAPE), and the lowest-MAPE model is reported. XGBoost is flagged as the deployed model by default and generates the production forecasts.

## Learning Objectives

After completing this notebook, you will understand:
1. How to engineer features from historical demand data and external indicators
2. The differences between XGBoost, Random Forest, and Linear Regression for demand forecasting
3. How to evaluate regression models using MAE, RMSE, MAPE, and R²
4. Best practices for temporal train/test splits to prevent data leakage
5. How to generate backtests and future forecasts with confidence intervals

## Prerequisites

- **Mathematics**: Basic statistics (mean, standard deviation, correlation)
- **ML Concepts**: Supervised learning, regression, train/test splits, cross-validation
- **Python**: Pandas, basic SQL, Snowpark familiarity
- **Domain**: Understanding of supply chain/procurement planning (helpful but not required)

## Notebook Structure

| Section | Purpose |
|---------|---------|
| 1. Environment Setup | Session configuration, imports, visualization theme |
| 2. Data Exploration | Load demand data, explore distributions |
| 3. Feature Engineering | Create historical, external, and seasonal features |
| 4. Algorithm Overview | Conceptual explanation of XGBoost, Random Forest, Linear Regression |
| 5. Model Training | Train all three models on training data |
| 6. Evaluation | Compare model performance with metrics and visualizations |
| 7. Production Output | Save predictions, model registry, feature importance |
| 8. Key Takeaways | Summary, interpretation guidelines, limitations |

## Output Artifacts

| Table | Description |
|-------|-------------|
| `ATOMIC.DEMAND_FORECAST_PREDICTIONS` | 90-day forecasts + 6-month backtests with confidence intervals |
| `ATOMIC.MODEL_REGISTRY` | Model performance metrics for tracking |
| `ATOMIC.MODEL_FEATURE_IMPORTANCE` | Feature importance scores from XGBoost |

## 1. Environment Setup

Configure the Snowflake session, imports, visualization theme, and a query helper.

In [None]:
# =============================================================================
# ENVIRONMENT SETUP
# =============================================================================
# Configure Snowflake session, imports, visualization theme, and helper functions

from snowflake.snowpark.context import get_active_session
session = get_active_session()

# Import required libraries
from snowflake.snowpark import functions as F
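from snowflake.snowpark.types import (
    StructType, StructField, LongType, StringType, TimestampType,
    FloatType, IntegerType, BooleanType, VariantType,
)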
from snowflake.ml.modeling.xgboost import XGBRegressor
from snowflake.ml.modeling.ensemble import RandomForestRegressor
from snowflake.ml.modeling.linear_model import LinearRegression
from snowflake.ml.modeling.preprocessing import StandardScaler
import datetime
import matplotlib.pyplot as plt
import numpy as np

# =============================================================================
# DARK THEME VISUALIZATION SETUP (Snowflake-inspired)
# =============================================================================
# Configure matplotlib for dark backgrounds with colorblind-safe palette

plt.style.use('dark_background')
plt.rcParams.update({
    # Background colors (soft dark gray, not pure black)
    'figure.facecolor': '#121212',
    'axes.facecolor': '#121212',
    
    # Text colors (off-white to reduce glare)
    'text.color': '#E5E5E7',
    'axes.labelcolor': '#E5E5E7',
    'xtick.color': '#A1A1A6',
    'ytick.color': '#A1A1A6',
    
    # Grid and axes (subtle, not distracting)
    'axes.edgecolor': '#3A3A3C',
    'grid.color': '#2C2C2E',
    'grid.alpha': 0.6,
    
    # Line and marker styling
    'lines.linewidth': 2,
    'lines.markersize': 8,
    
    # Figure quality and sizing
    'figure.dpi': 150,
    'savefig.dpi': 200,
    'figure.figsize': (10, 6),
    
    # Font configuration
    'font.family': 'sans-serif',
    'font.size': 11,
    'axes.titlesize': 14,
    'axes.labelsize': 12,
})

# Colorblind-safe desaturated palette for dark backgrounds
COLORS = ['#64D2FF', '#FF9F0A', '#5AC8FA', '#FFD60A', '#11567F']
plt.rcParams['axes.prop_cycle'] = plt.cycler(color=COLORS)

# =============================================================================
# FAIL-FAST ERROR HANDLING
# =============================================================================
# All queries must fail loudly - no silent failures

def execute_query(session, query: str, name: str = "query"):
    """
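    Build a Snowpark DataFrame for a SQL query with fail-fast error handling.

    Note: Snowpark DataFrames are lazily evaluated, so many SQL errors only
    surface when an action such as show(), collect(), or count() runs.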
    
    Args:
        session: Snowflake Snowpark session
        query: SQL query string
        name: Descriptive name for error messages
    
    Returns:
        Snowpark DataFrame with query results
    
    Raises:
        RuntimeError: If query fails or returns None
    """
    try:
        result = session.sql(query)
        if result is None:
            raise RuntimeError(f"Query '{name}' returned None")
        return result
    except Exception as e:
        raise RuntimeError(f"Query '{name}' failed: {e}") from e

print(f"Connected to: {session.get_current_database()}.{session.get_current_schema()}")
print("Visualization theme: Dark mode with colorblind-safe palette")

## 2. Data Exploration

Load and explore demand data to understand distributions, patterns, and data quality before feature engineering.

In [None]:
# =============================================================================
# DATA EXPLORATION
# =============================================================================
# Load demand data and explore key statistics

demand_df = session.table("ATOMIC.DEMAND_ACTUAL")
record_count = demand_df.count()
print(f"Total demand records: {record_count:,}")

if record_count == 0:
    raise RuntimeError("DEMAND_ACTUAL table is empty - cannot proceed with analysis")

# Date range
print("\n=== Date Range ===")
demand_df.select(
    F.min("DEMAND_DATE").alias("MIN_DATE"),
    F.max("DEMAND_DATE").alias("MAX_DATE")
).show()

# Demand by source
print("\n=== Demand by Source ===")
demand_df.group_by("DEMAND_SOURCE").agg(
    F.count("*").alias("COUNT"),
    F.sum("ACTUAL_QUANTITY").alias("TOTAL_QTY")
).order_by("TOTAL_QTY", ascending=False).show()

# Summary statistics
print("\n=== Demand Quantity Statistics ===")
demand_df.select(
    F.avg("ACTUAL_QUANTITY").alias("AVG_QTY"),
    F.stddev("ACTUAL_QUANTITY").alias("STD_QTY"),
    F.min("ACTUAL_QUANTITY").alias("MIN_QTY"),
    F.max("ACTUAL_QUANTITY").alias("MAX_QTY"),
    F.percentile_cont(0.5).within_group("ACTUAL_QUANTITY").alias("MEDIAN_QTY")
).show()

In [None]:
# =============================================================================
# DATA EXPLORATION VISUALIZATIONS
# =============================================================================
# Visualize demand distributions and patterns

# Get data for plotting
demand_stats = demand_df.group_by("DEMAND_SOURCE").agg(
    F.count("*").alias("COUNT"),
    F.avg("ACTUAL_QUANTITY").alias("AVG_QTY"),
    F.stddev("ACTUAL_QUANTITY").alias("STD_QTY")
).to_pandas()

# Monthly demand trends
monthly_demand = demand_df.select(
    F.date_trunc("month", F.col("DEMAND_DATE")).alias("MONTH"),
    F.col("ACTUAL_QUANTITY")
).group_by("MONTH").agg(
    F.sum("ACTUAL_QUANTITY").alias("TOTAL_QTY"),
    F.count("*").alias("ORDER_COUNT")
).order_by("MONTH").to_pandas()

fig, axes = plt.subplots(1, 3, figsize=(16, 4))

# Plot 1: Demand by Source
ax1 = axes[0]
bars = ax1.barh(demand_stats['DEMAND_SOURCE'], demand_stats['COUNT'], color=COLORS[0])
ax1.set_xlabel('Number of Records')
ax1.set_title('Demand Records by Source')
ax1.grid(axis='x', alpha=0.3)

# Plot 2: Monthly Demand Trend
ax2 = axes[1]
if len(monthly_demand) > 0:
    ax2.plot(monthly_demand['MONTH'], monthly_demand['TOTAL_QTY'], 
             color=COLORS[0], marker='o', markersize=4)
    ax2.fill_between(monthly_demand['MONTH'], monthly_demand['TOTAL_QTY'], 
                     alpha=0.3, color=COLORS[0])
ax2.set_xlabel('Month')
ax2.set_ylabel('Total Quantity')
ax2.set_title('Monthly Demand Trend')
ax2.tick_params(axis='x', rotation=45)
ax2.grid(alpha=0.3)

# Plot 3: Average Demand by Source
ax3 = axes[2]
bars = ax3.barh(demand_stats['DEMAND_SOURCE'], demand_stats['AVG_QTY'], 
                color=COLORS[1], xerr=demand_stats['STD_QTY'].fillna(0), 
                capsize=3, error_kw={'color': '#A1A1A6', 'alpha': 0.7})
ax3.set_xlabel('Average Quantity (with Std Dev)')
ax3.set_title('Average Demand by Source')
ax3.grid(axis='x', alpha=0.3)

plt.tight_layout()
plt.savefig('/tmp/demand_exploration.png', dpi=150, bbox_inches='tight', 
            facecolor='#121212', edgecolor='none')
plt.show()

print("\nData exploration visualizations saved to /tmp/demand_exploration.png")

## 3. Feature Engineering

Create features from three sources:
1. **Internal features**: Historical consumption patterns (average, variability)
2. **External features**: Macro-economic indicators from marketplace data
3. **Seasonal features**: Time-based patterns (day of week, month, quarter)

> **Note**: All features are cast to FLOAT for XGBoost compatibility (INTEGER maps to object type in Snowpark ML).
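
The internal features summarise each product/site pair's history. One caveat is worth illustrating: an average taken over the full history also includes rows dated after the row being predicted. The sketch below is a hypothetical plain-Python alternative, not the notebook's SQL, that computes a "prior rows only" mean. The function name `prior_mean_features`, the key `P1/S1`, and the quantities are made up for illustration.

```python
from collections import defaultdict

def prior_mean_features(rows):
    """Attach to each row the mean quantity of strictly earlier rows for the same key.

    rows: iterable of (date, key, quantity); dates must sort chronologically
    (ISO strings here). Assumes at most one row per key per date.
    """
    totals = defaultdict(float)
    counts = defaultdict(int)
    enriched = []
    for date, key, qty in sorted(rows):
        prior_mean = totals[key] / counts[key] if counts[key] else None
        enriched.append((date, key, qty, prior_mean))
        # Update running totals only after the feature for this row is emitted.
        totals[key] += qty
        counts[key] += 1
    return enriched

# Hypothetical demand history for one product/site pair.
history = [
    ("2024-01-01", "P1/S1", 100.0),
    ("2024-01-02", "P1/S1", 120.0),
    ("2024-01-03", "P1/S1", 80.0),
]
features = prior_mean_features(history)
for row in features:
    print(row)
```

The first row has no prior history, so its feature is `None`; a real pipeline would need a fallback value there, much as the SQL in this section uses `COALESCE` defaults.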

In [None]:
# Feature engineering SQL with external indicators
# NOTE: All features must be FLOAT for XGBoost compatibility (INTEGER maps to object type)
feature_sql = """
WITH demand_stats AS (
    SELECT 
        PRODUCT_ID,
        SITE_ID,
        AVG(ACTUAL_QUANTITY) AS AVG_QTY,
        STDDEV(ACTUAL_QUANTITY) AS STD_QTY,
        COUNT(*) AS RECORD_COUNT
    FROM ATOMIC.DEMAND_ACTUAL
    GROUP BY PRODUCT_ID, SITE_ID
),
-- Get external indicators (pivoted)
external_indicators AS (
    SELECT 
        DATE_TRUNC('week', INDICATOR_DATE) AS WEEK_DATE,
        AVG(CASE WHEN INDICATOR_NAME = 'Construction Starts Index' THEN INDICATOR_VALUE END) AS CONSTRUCTION_IDX,
        AVG(CASE WHEN INDICATOR_NAME = 'Clinical Trial Spend ($B)' THEN INDICATOR_VALUE END) AS CLINICAL_IDX,
        AVG(CASE WHEN INDICATOR_NAME = 'Industrial Production Index' THEN INDICATOR_VALUE END) AS PRODUCTION_IDX,
        AVG(CASE WHEN INDICATOR_NAME = 'Manufacturing PMI' THEN INDICATOR_VALUE END) AS PMI_IDX
    FROM ATOMIC.MARKETPLACE_INDICATORS
    GROUP BY DATE_TRUNC('week', INDICATOR_DATE)
)
SELECT 
    d.DEMAND_ACTUAL_ID,
    d.PRODUCT_ID,
    d.SITE_ID,
    d.DEMAND_DATE,
    d.ACTUAL_QUANTITY,
    -- Internal features (all FLOAT for XGBoost)
    CAST(COALESCE(ds.AVG_QTY, 100) AS FLOAT) AS HIST_AVG_QTY,
    CAST(COALESCE(ds.STD_QTY, 10) AS FLOAT) AS HIST_STD_QTY,
    CAST(COALESCE(p.PRODUCT_CATEGORY_ID, 1) AS FLOAT) AS PRODUCT_CATEGORY_ID,
    -- External features
    CAST(COALESCE(ei.CONSTRUCTION_IDX, 100) AS FLOAT) AS CONSTRUCTION_IDX,
    CAST(COALESCE(ei.CLINICAL_IDX, 2.5) AS FLOAT) AS CLINICAL_IDX,
    CAST(COALESCE(ei.PRODUCTION_IDX, 98) AS FLOAT) AS PRODUCTION_IDX,
    CAST(COALESCE(ei.PMI_IDX, 50) AS FLOAT) AS PMI_IDX,
    -- Seasonal features (all FLOAT for XGBoost)
    CAST(DAYOFWEEK(d.DEMAND_DATE) AS FLOAT) AS DAY_OF_WEEK,
    CAST(MONTH(d.DEMAND_DATE) AS FLOAT) AS MONTH_NUM,
    CAST(QUARTER(d.DEMAND_DATE) AS FLOAT) AS QUARTER_NUM,
    -- Target
    CAST(d.ACTUAL_QUANTITY AS FLOAT) AS TARGET
FROM ATOMIC.DEMAND_ACTUAL d
LEFT JOIN demand_stats ds ON d.PRODUCT_ID = ds.PRODUCT_ID AND d.SITE_ID = ds.SITE_ID
LEFT JOIN ATOMIC.PRODUCT p ON d.PRODUCT_ID = p.PRODUCT_ID
LEFT JOIN external_indicators ei ON DATE_TRUNC('week', d.DEMAND_DATE) = ei.WEEK_DATE
WHERE ds.RECORD_COUNT >= 2
"""

features_df = session.sql(feature_sql)
print(f"Feature dataset rows: {features_df.count():,}")
features_df.show(5)

In [None]:
# Define features and target
# Internal features: historical consumption patterns
# External features: macro-economic indicators
# Seasonal features: time-based patterns
feature_cols = [
    # Internal
    'HIST_AVG_QTY', 'HIST_STD_QTY', 'PRODUCT_CATEGORY_ID',
    # External macro-economic
    'CONSTRUCTION_IDX', 'CLINICAL_IDX', 'PRODUCTION_IDX', 'PMI_IDX',
    # Seasonal
    'DAY_OF_WEEK', 'MONTH_NUM', 'QUARTER_NUM'
]
target_col = 'TARGET'

print(f"Feature count: {len(feature_cols)}")
print(f"Features: {feature_cols}")

# Train/test split (temporal - 80/20)
date_stats = features_df.select(
    F.min("DEMAND_DATE").alias("MIN_DATE"),
    F.max("DEMAND_DATE").alias("MAX_DATE")
).collect()[0]
min_date = date_stats["MIN_DATE"]
max_date = date_stats["MAX_DATE"]

# Calculate 80% point as days from min
total_days = (max_date - min_date).days
split_days = int(total_days * 0.8)
split_date = min_date + datetime.timedelta(days=split_days)

train_df = features_df.filter(F.col("DEMAND_DATE") < F.lit(split_date))
test_df = features_df.filter(F.col("DEMAND_DATE") >= F.lit(split_date))

print(f"\nSplit date: {split_date}")
print(f"Training samples: {train_df.count():,}")
print(f"Testing samples: {test_df.count():,}")

## 4. Algorithm Overview

Before training, let's understand the three regression algorithms we'll compare.

### What is XGBoost?

**XGBoost (Extreme Gradient Boosting)** is an optimized gradient boosting algorithm that builds an ensemble of decision trees sequentially. Each tree corrects the errors of the previous trees.

**Key intuition**: Instead of building one complex tree, XGBoost builds many simple trees where each tree focuses on the mistakes of the previous ones.

**Mathematical formulation**:
$$\hat{y}_i = \sum_{k=1}^{K} f_k(x_i)$$

Where $f_k$ is the prediction from tree $k$, and trees are added to minimize the loss function.
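
To make the additive formula concrete, here is a toy boosting loop in plain Python. It is a simplified sketch under stated assumptions: squared-error loss, depth-1 trees (stumps) on a single feature, a constant base prediction, and a learning rate that scales each tree's contribution. It omits XGBoost's regularisation and second-order gradient information, and the data are made up.

```python
def fit_stump(xs, residuals):
    """Fit a depth-1 tree: pick the threshold that minimises squared error."""
    best = None
    for threshold in sorted(set(xs))[1:]:
        left = [r for x, r in zip(xs, residuals) if x < threshold]
        right = [r for x, r in zip(xs, residuals) if x >= threshold]
        left_mean = sum(left) / len(left)
        right_mean = sum(right) / len(right)
        sse = (sum((r - left_mean) ** 2 for r in left)
               + sum((r - right_mean) ** 2 for r in right))
        if best is None or sse < best[0]:
            best = (sse, threshold, left_mean, right_mean)
    return best[1:]  # (threshold, left_value, right_value)

def stump_predict(stump, x):
    threshold, left_value, right_value = stump
    return left_value if x < threshold else right_value

def boost(xs, ys, n_trees, learning_rate=0.1):
    """Add stumps one at a time, each fit to the residuals of the ensemble so far."""
    base = sum(ys) / len(ys)
    preds = [base] * len(xs)
    stumps = []
    for _ in range(n_trees):
        residuals = [y - p for y, p in zip(ys, preds)]
        stump = fit_stump(xs, residuals)
        stumps.append(stump)
        preds = [p + learning_rate * stump_predict(stump, x)
                 for p, x in zip(preds, xs)]
    return base, learning_rate, stumps

def predict(model, x):
    base, learning_rate, stumps = model
    return base + learning_rate * sum(stump_predict(s, x) for s in stumps)

def mse(model, xs, ys):
    return sum((predict(model, x) - y) ** 2 for x, y in zip(xs, ys)) / len(xs)

# Made-up data with a step between x = 4 and x = 5.
xs = [1, 2, 3, 4, 5, 6, 7, 8]
ys = [10.0, 12.0, 11.0, 13.0, 30.0, 32.0, 31.0, 33.0]
for n in (0, 1, 10, 100):
    print(n, round(mse(boost(xs, ys, n), xs, ys), 3))
```

Training error shrinks as trees are added because each new stump targets what the current ensemble still gets wrong. A smaller learning rate slows that shrinkage, which is why `learning_rate` and `n_estimators` are usually tuned together.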

### What is Random Forest?

**Random Forest** is an ensemble of decision trees trained independently using bagging (bootstrap aggregating). Each tree sees a random subset of data and features, and predictions are averaged.

**Key intuition**: "Wisdom of crowds" - averaging many diverse trees reduces overfitting and improves generalization.

**Mathematical formulation**:
$$\hat{y} = \frac{1}{K} \sum_{k=1}^{K} T_k(x)$$

Where $T_k$ is the prediction from tree $k$ trained on a bootstrap sample.
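
The averaging formula can be sketched in a few lines of plain Python. This is an illustrative toy, not the Snowpark ML implementation: each "tree" is a depth-1 stump on one feature, only rows are resampled (a real random forest also samples features), and the data are made up.

```python
import random

def fit_stump(points):
    """Fit a depth-1 regression tree to (x, y) pairs and return a predict function."""
    xs = sorted({x for x, _ in points})
    if len(xs) < 2:
        # Degenerate bootstrap sample with one distinct x: predict its mean.
        mean = sum(y for _, y in points) / len(points)
        return lambda x: mean
    best = None
    for threshold in xs[1:]:
        left = [y for x, y in points if x < threshold]
        right = [y for x, y in points if x >= threshold]
        left_mean = sum(left) / len(left)
        right_mean = sum(right) / len(right)
        sse = (sum((y - left_mean) ** 2 for y in left)
               + sum((y - right_mean) ** 2 for y in right))
        if best is None or sse < best[0]:
            best = (sse, threshold, left_mean, right_mean)
    _, threshold, left_mean, right_mean = best
    return lambda x: left_mean if x < threshold else right_mean

def fit_forest(points, n_trees=25, seed=42):
    """Train each tree independently on a bootstrap sample (drawn with replacement)."""
    rng = random.Random(seed)
    trees = []
    for _ in range(n_trees):
        sample = [rng.choice(points) for _ in points]
        trees.append(fit_stump(sample))
    return trees

def predict_forest(trees, x):
    """Average the per-tree predictions, as in the formula above."""
    return sum(tree(x) for tree in trees) / len(trees)

# Made-up data with a step between x = 4 and x = 5.
points = [(1, 10.0), (2, 12.0), (3, 11.0), (4, 13.0),
          (5, 30.0), (6, 32.0), (7, 31.0), (8, 33.0)]
trees = fit_forest(points)
print(predict_forest(trees, 2), predict_forest(trees, 7))
```

Unlike boosting, the trees here never see each other's errors, so they can be trained in parallel. The variance reduction comes purely from averaging models fit to different resamples.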

### Why Multi-Model Comparison?

| Model | Strengths | Weaknesses |
|-------|-----------|------------|
| **XGBoost** | Handles non-linear relationships, robust to outliers | Can overfit with many features |
| **Random Forest** | Less prone to overfitting, parallelizable | May miss subtle patterns |
| **Linear Regression** | Interpretable, fast, good baseline | Assumes linear relationships |

### Key Hyperparameters

| Parameter | XGBoost | Random Forest | Purpose |
|-----------|---------|---------------|---------|
| `n_estimators` | 100 | 100 | Number of trees in ensemble |
| `max_depth` | 6 | 8 | Maximum tree depth (controls complexity) |
| `learning_rate` | 0.1 | N/A | Step size for gradient updates |
| `random_state` | 42 | 42 | Reproducibility seed |

## 5. Model Training

Train multiple models and compare performance:
- **XGBoost Regressor** (primary model) - Gradient boosted trees
- **Random Forest Regressor** (ensemble baseline) - Bagged decision trees
- **Linear Regression** (simple baseline) - Linear model for comparison

In [None]:
# Train multiple models for comparison
models = {}

# 1. XGBoost Regressor (Primary Model)
print("Training XGBoost model...")
xgb_model = XGBRegressor(
    input_cols=feature_cols,
    label_cols=[target_col],
    output_cols=["PREDICTION"],
    n_estimators=100,
    max_depth=6,
    learning_rate=0.1,
    random_state=42
)
xgb_model.fit(train_df)
models['XGBoost'] = xgb_model
print("  XGBoost training complete!")

# 2. Random Forest Regressor (Ensemble Baseline)
print("Training Random Forest model...")
rf_model = RandomForestRegressor(
    input_cols=feature_cols,
    label_cols=[target_col],
    output_cols=["PREDICTION"],
    n_estimators=100,
    max_depth=8,
    random_state=42
)
rf_model.fit(train_df)
models['RandomForest'] = rf_model
print("  Random Forest training complete!")

# 3. Linear Regression (Simple Baseline)
print("Training Linear Regression model...")
lr_model = LinearRegression(
    input_cols=feature_cols,
    label_cols=[target_col],
    output_cols=["PREDICTION"]
)
lr_model.fit(train_df)
models['LinearRegression'] = lr_model
print("  Linear Regression training complete!")

print(f"\nAll {len(models)} models trained successfully.")

In [None]:
# Extract and save feature importance from XGBoost (primary model)
feature_importance_data = []
importances = xgb_model.to_xgboost().feature_importances_

for i, feat in enumerate(feature_cols):
    # Determine feature type based on naming conventions
    if feat.startswith('HIST_'):
        feat_type = 'Internal'
    elif feat.endswith('_IDX'):
        feat_type = 'External'
    elif feat in ['DAY_OF_WEEK', 'MONTH_NUM', 'QUARTER_NUM']:
        feat_type = 'Seasonal'
    else:
        feat_type = 'Derived'
    
    description = {
        'HIST_AVG_QTY': 'Historical average demand quantity',
        'HIST_STD_QTY': 'Historical demand variability',
        'PRODUCT_CATEGORY_ID': 'Product category classification',
        'CONSTRUCTION_IDX': 'Construction starts index (economic indicator)',
        'CLINICAL_IDX': 'Clinical trial spend (industry indicator)',
        'PRODUCTION_IDX': 'Industrial production index (economic indicator)',
        'PMI_IDX': 'Manufacturing PMI (economic indicator)',
        'DAY_OF_WEEK': 'Day of week seasonality',
        'MONTH_NUM': 'Monthly seasonality',
        'QUARTER_NUM': 'Quarterly seasonality'
    }.get(feat, f'Feature: {feat}')
    
    feature_importance_data.append({
        'MODEL_VERSION': 'v5.0.0',
        'FEATURE_NAME': feat,
        'IMPORTANCE_SCORE': float(importances[i]),
        'FEATURE_TYPE': feat_type,
        'DESCRIPTION': description
    })

# Sort by importance and display
feature_importance_data.sort(key=lambda x: x['IMPORTANCE_SCORE'], reverse=True)
print("=== Feature Importance (XGBoost) ===")
for fi in feature_importance_data[:5]:
    print(f"  {fi['FEATURE_NAME']:25s} {fi['IMPORTANCE_SCORE']:.4f} ({fi['FEATURE_TYPE']})")

# Save to table
fi_df = session.create_dataframe(feature_importance_data)
fi_df.write.mode("overwrite").save_as_table("ATOMIC.MODEL_FEATURE_IMPORTANCE")
print(f"\nFeature importance saved to ATOMIC.MODEL_FEATURE_IMPORTANCE ({len(feature_importance_data)} features)")

## 6. Evaluation

Evaluate each model on the held-out test set and compare MAE, RMSE, MAPE, and R².

In [None]:
# Evaluate all models and compare performance
print("=== Multi-Model Performance Comparison ===\n")

# Define training parameters for each model (for registry tracking)
training_params = {
    'XGBoost': {
        'n_estimators': 100,
        'max_depth': 6,
        'learning_rate': 0.1,
        'random_state': 42
    },
    'RandomForest': {
        'n_estimators': 100,
        'max_depth': 8,
        'random_state': 42
    },
    'LinearRegression': {}
}

model_results = []
best_model = None
best_mape = float('inf')

for model_name, model in models.items():
    # Generate predictions
    predictions_df = model.predict(test_df)
    
    # Calculate metrics
    # NULLIF turns zero actuals into NULL so MAPE skips them instead of dividing by zero
    error = F.col("PREDICTION") - F.col(target_col)
    safe_target = F.call_function("NULLIF", F.col(target_col), F.lit(0))
    metrics = predictions_df.select(
        F.avg(F.abs(error)).alias("MAE"),
        F.sqrt(F.avg(F.pow(error, 2))).alias("RMSE"),
        F.avg(F.abs(error / safe_target) * 100).alias("MAPE"),
        F.corr(F.col("PREDICTION"), F.col(target_col)).alias("CORR")
    ).collect()[0]
    
    mae = float(metrics['MAE'])
    rmse = float(metrics['RMSE'])
    mape = float(metrics['MAPE'])
    corr = float(metrics['CORR']) if metrics['CORR'] is not None else 0.0
    r2 = corr ** 2  # Squared Pearson correlation, used here as an approximation of R²
    
    print(f"{model_name}:")
    print(f"  MAE:  {mae:.2f}")
    print(f"  RMSE: {rmse:.2f}")
    print(f"  MAPE: {mape:.2f}%")
    print(f"  R²:   {r2:.4f}")
    print()
    
    # Track best model
    if mape < best_mape:
        best_mape = mape
        best_model = model_name
    
    model_results.append({
        'MODEL_NAME': f'Demand Sensing {model_name}',
        'MODEL_VERSION': 'v5.0.0',
        'ALGORITHM': model_name,
        'MAE': mae,
        'RMSE': rmse,
        'MAPE': mape,
        'R2_SCORE': r2,
        'FEATURE_COUNT': len(feature_cols),
        'TRAINING_SAMPLES': train_df.count(),
        'IS_DEPLOYED': (model_name == 'XGBoost'),  # Deploy XGBoost by default
        'TRAINING_PARAMETERS': training_params.get(model_name, {})
    })

print(f"Best Model: {best_model} (MAPE: {best_mape:.2f}%)")

In [None]:
# =============================================================================
# MODEL EVALUATION VISUALIZATIONS
# =============================================================================
# Visualize model performance comparison

import pandas as pd

# Create DataFrame from model results for plotting
results_df = pd.DataFrame(model_results)

fig, axes = plt.subplots(1, 3, figsize=(16, 5))

# Plot 1: Model Performance Comparison (MAPE - lower is better)
ax1 = axes[0]
models_list = results_df['ALGORITHM'].tolist()
mape_values = results_df['MAPE'].tolist()
bars = ax1.bar(models_list, mape_values, color=[COLORS[0], COLORS[1], COLORS[2]])
ax1.set_ylabel('MAPE (%)')
ax1.set_title('Model Comparison: MAPE (Lower is Better)')
ax1.axhline(y=min(mape_values), color=COLORS[3], linestyle='--', 
            label=f'Best: {min(mape_values):.2f}%', alpha=0.7)
ax1.legend(loc='upper right')
ax1.grid(axis='y', alpha=0.3)

# Add value labels on bars
for bar, val in zip(bars, mape_values):
    ax1.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.5, 
             f'{val:.1f}%', ha='center', va='bottom', fontsize=10, color='#E5E5E7')

# Plot 2: R² Score Comparison (higher is better)
ax2 = axes[1]
r2_values = results_df['R2_SCORE'].tolist()
bars = ax2.bar(models_list, r2_values, color=[COLORS[0], COLORS[1], COLORS[2]])
ax2.set_ylabel('R² Score')
ax2.set_title('Model Comparison: R² (Higher is Better)')
ax2.axhline(y=max(r2_values), color=COLORS[3], linestyle='--', 
            label=f'Best: {max(r2_values):.4f}', alpha=0.7)
ax2.legend(loc='lower right')
ax2.grid(axis='y', alpha=0.3)

# Plot 3: MAE and RMSE Comparison
ax3 = axes[2]
x = np.arange(len(models_list))
width = 0.35
mae_values = results_df['MAE'].tolist()
rmse_values = results_df['RMSE'].tolist()

bars1 = ax3.bar(x - width/2, mae_values, width, label='MAE', color=COLORS[0])
bars2 = ax3.bar(x + width/2, rmse_values, width, label='RMSE', color=COLORS[1])
ax3.set_ylabel('Error (units)')
ax3.set_title('MAE vs RMSE by Model')
ax3.set_xticks(x)
ax3.set_xticklabels(models_list)
ax3.legend()
ax3.grid(axis='y', alpha=0.3)

plt.tight_layout()
plt.savefig('/tmp/model_comparison.png', dpi=150, bbox_inches='tight',
            facecolor='#121212', edgecolor='none')
plt.show()

print(f"\nBest performing model: {best_model}")
print("Model comparison visualizations saved to /tmp/model_comparison.png")

## 7. Production Output

Save model performance metrics to `MODEL_REGISTRY` for tracking, generate 90-day forecasts + 6-month backtests, and persist to `DEMAND_FORECAST_PREDICTIONS`.

In [None]:
# =============================================================================
# SAVE TO MODEL REGISTRY
# =============================================================================
# Persist model metrics with training metadata
# Table schema: MODEL_ID, MODEL_NAME, MODEL_VERSION, ALGORITHM, TRAINING_DATE,
#               MAE, RMSE, MAPE, R2_SCORE, FEATURE_COUNT, TRAINING_SAMPLES,
#               IS_DEPLOYED, DEPLOYMENT_DATE, MODEL_ARTIFACT_PATH,
#               TRAINING_PARAMETERS, CREATED_TIMESTAMP

import json

# Get max MODEL_ID from registry to generate new IDs
max_id_result = session.sql("SELECT COALESCE(MAX(MODEL_ID), 0) AS MAX_ID FROM ATOMIC.MODEL_REGISTRY").collect()[0]
next_model_id = int(max_id_result['MAX_ID']) + 1

# Build rows in exact column order matching table schema
current_timestamp = datetime.datetime.now()
registry_rows = []
for idx, result in enumerate(model_results):
    row = (
        next_model_id + idx,                                                           # MODEL_ID
        result['MODEL_NAME'],                                                          # MODEL_NAME
        result['MODEL_VERSION'],                                                       # MODEL_VERSION
        result['ALGORITHM'],                                                           # ALGORITHM
        current_timestamp,                                                             # TRAINING_DATE
        float(result['MAE']),                                                          # MAE
        float(result['RMSE']),                                                         # RMSE
        float(result['MAPE']),                                                         # MAPE
        float(result['R2_SCORE']),                                                     # R2_SCORE
        int(result['FEATURE_COUNT']),                                                  # FEATURE_COUNT
        int(result['TRAINING_SAMPLES']),                                               # TRAINING_SAMPLES
        result['IS_DEPLOYED'],                                                         # IS_DEPLOYED
        current_timestamp if result['IS_DEPLOYED'] else None,                          # DEPLOYMENT_DATE
        f"snowflake://models/demand_sensing/{result['MODEL_VERSION']}/{result['ALGORITHM'].lower()}",  # MODEL_ARTIFACT_PATH
        result.get('TRAINING_PARAMETERS', {}),                                         # TRAINING_PARAMETERS
        current_timestamp                                                              # CREATED_TIMESTAMP
    )
    registry_rows.append(row)

# Define explicit schema matching table DDL
registry_schema = StructType([
    StructField("MODEL_ID", LongType(), nullable=False),
    StructField("MODEL_NAME", StringType(), nullable=False),
    StructField("MODEL_VERSION", StringType(), nullable=False),
    StructField("ALGORITHM", StringType(), nullable=True),
    StructField("TRAINING_DATE", TimestampType(), nullable=True),
    StructField("MAE", FloatType(), nullable=True),
    StructField("RMSE", FloatType(), nullable=True),
    StructField("MAPE", FloatType(), nullable=True),
    StructField("R2_SCORE", FloatType(), nullable=True),
    StructField("FEATURE_COUNT", IntegerType(), nullable=True),
    StructField("TRAINING_SAMPLES", LongType(), nullable=True),
    StructField("IS_DEPLOYED", BooleanType(), nullable=True),
    StructField("DEPLOYMENT_DATE", TimestampType(), nullable=True),
    StructField("MODEL_ARTIFACT_PATH", StringType(), nullable=True),
    StructField("TRAINING_PARAMETERS", VariantType(), nullable=True),
    StructField("CREATED_TIMESTAMP", TimestampType(), nullable=True)
])

registry_df = session.create_dataframe(registry_rows, schema=registry_schema)

# Write to registry with fail-fast error handling
registry_df.write.mode("append").save_as_table("ATOMIC.MODEL_REGISTRY")
print(f"Saved {len(model_results)} models to ATOMIC.MODEL_REGISTRY")

# Verify write succeeded
verify_count = session.sql("SELECT COUNT(*) AS CNT FROM ATOMIC.MODEL_REGISTRY WHERE MODEL_VERSION = 'v5.0.0'").collect()[0]['CNT']
if verify_count < len(model_results):
    raise RuntimeError(f"Registry write verification failed: expected {len(model_results)}, found {verify_count}")
print(f"Verified {verify_count} v5.0.0 models in registry")

# Show registry contents
print("\n=== Model Registry Summary ===")
session.sql("""
    SELECT MODEL_NAME, MODEL_VERSION, ALGORITHM, MAPE, R2_SCORE, IS_DEPLOYED
    FROM ATOMIC.MODEL_REGISTRY
    WHERE MODEL_VERSION = 'v5.0.0'
    ORDER BY MAPE ASC
""").show()

# =============================================================================
# GENERATE BACKTESTS + 90-DAY FORECASTS
# =============================================================================
# Backtests: Last 6 months of historical data (for accuracy metrics)
# Forecasts: Next 90 days from today (for planning)

print("\n=== Generating Backtests + 90-Day Forecasts ===")

forecast_sql = """
WITH demand_stats AS (
    SELECT 
        PRODUCT_ID,
        SITE_ID,
        AVG(ACTUAL_QUANTITY) AS AVG_QTY,
        STDDEV(ACTUAL_QUANTITY) AS STD_QTY
    FROM ATOMIC.DEMAND_ACTUAL
    GROUP BY PRODUCT_ID, SITE_ID
),
-- Get the date range of actual demand data for backtesting
actual_date_range AS (
    SELECT 
        MAX(DEMAND_DATE) AS MAX_ACTUAL_DATE,
        DATEADD(month, -6, MAX(DEMAND_DATE)) AS BACKTEST_START
    FROM ATOMIC.DEMAND_ACTUAL
),
-- BACKTEST dates: last 6 months of historical data where we have actuals
backtest_dates AS (
    SELECT DISTINCT da.DEMAND_DATE AS FORECAST_DATE
    FROM ATOMIC.DEMAND_ACTUAL da
    CROSS JOIN actual_date_range adr
    WHERE da.DEMAND_DATE >= adr.BACKTEST_START
),
-- FORECAST dates: next 90 days from today (future predictions)
forecast_dates AS (
    SELECT DATEADD(day, seq4(), CURRENT_DATE()) AS FORECAST_DATE 
    FROM TABLE(GENERATOR(ROWCOUNT => 90))
),
-- Union both backtest and forecast dates
all_dates AS (
    SELECT FORECAST_DATE FROM backtest_dates
    UNION
    SELECT FORECAST_DATE FROM forecast_dates
),
-- Restrict scoring to 100 product/site pairs; without ORDER BY the subset is arbitrary
combos AS (SELECT DISTINCT PRODUCT_ID, SITE_ID FROM demand_stats LIMIT 100),
-- External indicators - use historical for backtests, latest for forecasts
external_indicators AS (
    SELECT 
        DATE_TRUNC('week', INDICATOR_DATE) AS WEEK_DATE,
        AVG(CASE WHEN INDICATOR_NAME = 'Construction Starts Index' THEN INDICATOR_VALUE END) AS CONSTRUCTION_IDX,
        AVG(CASE WHEN INDICATOR_NAME = 'Clinical Trial Spend ($B)' THEN INDICATOR_VALUE END) AS CLINICAL_IDX,
        AVG(CASE WHEN INDICATOR_NAME = 'Industrial Production Index' THEN INDICATOR_VALUE END) AS PRODUCTION_IDX,
        AVG(CASE WHEN INDICATOR_NAME = 'Manufacturing PMI' THEN INDICATOR_VALUE END) AS PMI_IDX
    FROM ATOMIC.MARKETPLACE_INDICATORS
    GROUP BY DATE_TRUNC('week', INDICATOR_DATE)
),
external_latest AS (
    SELECT 
        AVG(CONSTRUCTION_IDX) AS CONSTRUCTION_IDX,
        AVG(CLINICAL_IDX) AS CLINICAL_IDX,
        AVG(PRODUCTION_IDX) AS PRODUCTION_IDX,
        AVG(PMI_IDX) AS PMI_IDX
    FROM external_indicators
    WHERE WEEK_DATE = (SELECT MAX(WEEK_DATE) FROM external_indicators)
)
SELECT 
    c.PRODUCT_ID, c.SITE_ID, d.FORECAST_DATE AS PREDICTION_DATE,
    CAST(COALESCE(ds.AVG_QTY, 100) AS FLOAT) AS HIST_AVG_QTY,
    CAST(COALESCE(ds.STD_QTY, 10) AS FLOAT) AS HIST_STD_QTY,
    CAST(COALESCE(p.PRODUCT_CATEGORY_ID, 1) AS FLOAT) AS PRODUCT_CATEGORY_ID,
    -- Use historical indicators for backtests, latest for future forecasts
    CAST(COALESCE(ei.CONSTRUCTION_IDX, el.CONSTRUCTION_IDX, 100) AS FLOAT) AS CONSTRUCTION_IDX,
    CAST(COALESCE(ei.CLINICAL_IDX, el.CLINICAL_IDX, 2.5) AS FLOAT) AS CLINICAL_IDX,
    CAST(COALESCE(ei.PRODUCTION_IDX, el.PRODUCTION_IDX, 98) AS FLOAT) AS PRODUCTION_IDX,
    CAST(COALESCE(ei.PMI_IDX, el.PMI_IDX, 50) AS FLOAT) AS PMI_IDX,
    CAST(DAYOFWEEK(d.FORECAST_DATE) AS FLOAT) AS DAY_OF_WEEK,
    CAST(MONTH(d.FORECAST_DATE) AS FLOAT) AS MONTH_NUM,
    CAST(QUARTER(d.FORECAST_DATE) AS FLOAT) AS QUARTER_NUM
FROM combos c 
CROSS JOIN all_dates d
CROSS JOIN external_latest el
LEFT JOIN demand_stats ds ON c.PRODUCT_ID = ds.PRODUCT_ID AND c.SITE_ID = ds.SITE_ID
LEFT JOIN ATOMIC.PRODUCT p ON c.PRODUCT_ID = p.PRODUCT_ID
LEFT JOIN external_indicators ei ON DATE_TRUNC('week', d.FORECAST_DATE) = ei.WEEK_DATE
"""

forecast_input = session.sql(forecast_sql)
forecast_output = xgb_model.predict(forecast_input)

# Calculate approximate 90% intervals around each prediction
# Heuristic: uses each product/site's historical std deviation and assumes roughly normal errors
output_df = forecast_output.select(
    F.col("PRODUCT_ID"), 
    F.col("SITE_ID"), 
    F.col("PREDICTION_DATE"),
    F.col("PREDICTION").alias("FORECASTED_MATERIAL_DEMAND_QTY"),
    # 90% confidence interval (+/- 1.645 * std estimate)
    F.greatest(F.lit(0), F.col("PREDICTION") - F.col("HIST_STD_QTY") * 1.645).alias("PREDICTION_LOWER_BOUND"),
    (F.col("PREDICTION") + F.col("HIST_STD_QTY") * 1.645).alias("PREDICTION_UPPER_BOUND"),
    F.lit(0.90).alias("CONFIDENCE_LEVEL"),
    F.lit("v5.0.0").alias("MODEL_VERSION"),
    F.current_timestamp().alias("CREATED_TIMESTAMP")
)

# Write to table
output_df.write.mode("overwrite").save_as_table("ATOMIC.DEMAND_FORECAST_PREDICTIONS")
forecast_count = output_df.count()

# Verify write and count backtests vs forecasts
verify_forecast = session.sql("SELECT COUNT(*) AS CNT FROM ATOMIC.DEMAND_FORECAST_PREDICTIONS").collect()[0]['CNT']
if verify_forecast != forecast_count:
    raise RuntimeError(f"Forecast write verification failed: wrote {forecast_count}, found {verify_forecast}")

backtest_count = session.sql("""
    SELECT COUNT(*) AS CNT FROM ATOMIC.DEMAND_FORECAST_PREDICTIONS 
    WHERE PREDICTION_DATE <= (SELECT MAX(DEMAND_DATE) FROM ATOMIC.DEMAND_ACTUAL)
""").collect()[0]['CNT']
future_count = forecast_count - backtest_count

print(f"\n{forecast_count:,} total predictions written to ATOMIC.DEMAND_FORECAST_PREDICTIONS")
print(f"  - Backtests (historical): {backtest_count:,} (for accuracy metrics)")
print(f"  - Forecasts (future): {future_count:,} (for planning)")
print("  Model: XGBoost v5.0.0")
print("  Confidence Level: 90%")

## 8. Key Takeaways & Interpretation Guide

### What the Model Learned

1. **Historical patterns dominate**: `HIST_AVG_QTY` is typically the strongest predictor, indicating demand tends to follow historical averages
2. **Seasonal effects matter**: Day of week and month contribute to predictions, capturing weekly/monthly demand cycles
3. **External indicators can add value**: Macro-economic signals like PMI and construction starts may improve forecasting accuracy; check their scores in `ATOMIC.MODEL_FEATURE_IMPORTANCE` to confirm

### Interpretation Guidelines

| Metric | Good Value | Interpretation |
|--------|------------|----------------|
| **MAPE** | < 20% | Predictions within 20% of actuals on average |
| **R²** | > 0.7 | Model explains >70% of demand variance |
| **MAE** | Context-dependent | Average absolute error in demand units |
| **RMSE** | < 1.5× MAE | Low RMSE/MAE ratio indicates few large errors |
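
As a worked example of the four metrics, the sketch below computes them in plain Python on made-up numbers. It also contrasts true R² (one minus the ratio of residual to total sum of squares) with the squared Pearson correlation, which the evaluation cell uses as an approximation. The function name `regression_metrics` and the sample values are illustrative assumptions.

```python
import math

def regression_metrics(actuals, predictions):
    n = len(actuals)
    errors = [p - a for a, p in zip(actuals, predictions)]
    mae = sum(abs(e) for e in errors) / n
    rmse = math.sqrt(sum(e * e for e in errors) / n)
    # MAPE is undefined where the actual is zero, so those rows are skipped.
    pct = [abs(e / a) * 100 for a, e in zip(actuals, errors) if a != 0]
    mape = sum(pct) / len(pct) if pct else float("nan")
    # True R²: 1 - SS_res / SS_tot.
    mean_a = sum(actuals) / n
    ss_res = sum(e * e for e in errors)
    ss_tot = sum((a - mean_a) ** 2 for a in actuals)
    r2 = 1 - ss_res / ss_tot
    # Squared Pearson correlation between actuals and predictions.
    mean_p = sum(predictions) / n
    cov = sum((a - mean_a) * (p - mean_p) for a, p in zip(actuals, predictions))
    var_p = sum((p - mean_p) ** 2 for p in predictions)
    corr_sq = cov ** 2 / (ss_tot * var_p)
    return {"MAE": mae, "RMSE": rmse, "MAPE": mape, "R2": r2, "CORR_SQ": corr_sq}

# Made-up example: every prediction is exactly 50 units too high.
actuals = [100.0, 200.0, 300.0, 400.0]
predictions = [150.0, 250.0, 350.0, 450.0]
metrics = regression_metrics(actuals, predictions)
for name, value in metrics.items():
    print(f"{name}: {value:.4f}")
```

Because the predictions are perfectly correlated with the actuals but consistently biased, the squared correlation is 1.0 while true R² is about 0.8. A squared-correlation "R²" can therefore look better than the model's actual fit when forecasts are systematically high or low.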

### Confidence Intervals

- **90% Confidence**: `PREDICTION_LOWER_BOUND` to `PREDICTION_UPPER_BOUND`
- **Interpretation**: If forecast errors are roughly normal with a spread similar to the historical standard deviation, about 90% of actual values should fall within this range
- **Formula**: $\hat{y} \pm 1.645 \times \sigma_{historical}$
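
The formula above can be sketched in plain Python. The helper name `interval_bounds` and the sample values are illustrative assumptions; the lower bound is floored at zero, matching the `GREATEST(0, ...)` logic used when the predictions are written. The z-value comes from the standard normal quantile function, which shows where 1.645 originates.

```python
from statistics import NormalDist

def interval_bounds(prediction, hist_std, confidence=0.90):
    """Return (lower, upper) for a two-sided normal interval, floored at zero."""
    # Two-sided interval: (1 - confidence) / 2 of the probability sits in each tail.
    z = NormalDist().inv_cdf(0.5 + confidence / 2)
    lower = max(0.0, prediction - z * hist_std)
    upper = prediction + z * hist_std
    return lower, upper

z_90 = NormalDist().inv_cdf(0.95)
print(f"z for 90%: {z_90:.3f}")

# Made-up forecasts: one well above zero, one where the floor applies.
print(interval_bounds(120.0, 20.0))
print(interval_bounds(10.0, 20.0))
```

When the floor applies, the interval is no longer symmetric around the prediction, so its real coverage can differ from the nominal 90%.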

### Limitations & Considerations

1. **Assumes stationarity**: Model may not adapt quickly to structural changes in demand patterns
2. **External data lag**: Macro-economic indicators may be published with delay
3. **Limited to 90-day horizon**: Forecast accuracy degrades beyond this window
4. **No supply constraints**: Model predicts demand, not what can be fulfilled
5. **Historical features span the full history**: `HIST_AVG_QTY` and `HIST_STD_QTY` are aggregated over all of `DEMAND_ACTUAL`, including the test period, so test-set metrics may be optimistic

### Mathematical Recap

**XGBoost Objective (simplified)**:
$$L(\theta) = \sum_i l(y_i, \hat{y}_i) + \sum_k \Omega(f_k)$$

Where $l$ is the loss function (MSE for regression) and $\Omega$ is the regularization term preventing overfitting.
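
For reference, the regularization term is commonly written as follows in the XGBoost literature. The symbols $T$, $w_j$, $\gamma$, and $\lambda$ do not appear elsewhere in this notebook and are introduced here only for illustration:

$$\Omega(f) = \gamma T + \frac{1}{2} \lambda \sum_{j=1}^{T} w_j^2$$

Here $T$ is the number of leaves in tree $f$, $w_j$ is the score assigned to leaf $j$, $\gamma$ penalizes each additional leaf, and $\lambda$ is an L2 penalty on the leaf scores. Larger values of either penalty favor smaller, smoother trees.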

### Next Steps

1. **Monitor forecast accuracy**: Compare backtests to actuals weekly
2. **Retrain periodically**: Update model monthly with new data
3. **Expand features**: Consider adding supplier lead times, promotions, weather
4. **Alert thresholds**: Set up alerts when MAPE exceeds acceptable levels