# Time Series Forecasting Model Testing - Improved Version

This notebook tests multiple forecasting models (Naive, SARIMAX) on price data and provides comprehensive comparison and logging functionality.

In [1]:
# ============================================================================
# Cell 1: ALL IMPORTS AND CONFIGURATION
# ============================================================================

import sys
import json
import sqlite3
import logging
from pathlib import Path
from datetime import datetime
from typing import Dict, List, Optional, Tuple

# Data processing
import pandas as pd
import numpy as np
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import root_mean_squared_error, mean_squared_error

# Visualization
import matplotlib.pyplot as plt
import plotly.graph_objects as go
from IPython.display import display, HTML

# Set up project paths
project_root = Path.cwd().parent if Path.cwd().name == "tests" else Path.cwd()
src_path = project_root / "src"
if src_path.exists() and str(src_path) not in sys.path:
    sys.path.insert(0, str(src_path))

# Model imports (after path setup)
from models.naive_model import run_naive_model, evaluate_rmse
from models.sarimax_model import run_sarimax

# Logging configuration
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)

# Global configuration
CONFIG = {
    'DB_PATH': src_path / "data" / "WARP.db",
    'LOGS_DB_PATH': src_path / "data" / "logs.db",
    'HORIZON': 168,
    'TARGET_COL': "Price",
    'FEATURE_COLS': [
        "Load", "shortwave_radiation", "temperature_2m", "direct_normal_irradiance", 
        "diffuse_radiation", "Flow_NO", "yearday_cos", "Flow_GB", "month", "is_dst", 
        "yearday_sin", "is_non_working_day", "hour_cos", "is_weekend", "cloud_cover", 
        "weekday_sin", "hour_sin", "weekday_cos"
    ],
    'TRAIN_START': pd.Timestamp("2025-01-01 00:00:00", tz="UTC"),
    'TRAIN_END': pd.Timestamp("2025-03-14 23:00:00", tz="UTC"),
    'FORECAST_START': pd.Timestamp("2025-03-15 00:00:00", tz="UTC")
}

logger.info("📚 All imports loaded and configuration set")

2025-05-23 16:26:26,747 - sarimax - INFO - 📚 All imports loaded and configuration set


In [2]:
# ============================================================================
# Cell 2: UTILITY FUNCTIONS
# ============================================================================

def safe_rmse(y_true: pd.Series, y_pred: pd.Series) -> float:
    """Calculate RMSE with safety checks for NaN values"""
    try:
        common_idx = y_true.index.intersection(y_pred.index)
        if len(common_idx) == 0:
            return np.nan
        
        y_true_vals = y_true.loc[common_idx].dropna()
        y_pred_vals = y_pred.loc[common_idx].dropna()
        
        if len(y_true_vals) == 0 or len(y_pred_vals) == 0:
            return np.nan
            
        common_final = y_true_vals.index.intersection(y_pred_vals.index)
        if len(common_final) == 0:
            return np.nan
            
        return root_mean_squared_error(
            y_true_vals.loc[common_final], 
            y_pred_vals.loc[common_final]
        )
    except Exception as e:
        logger.warning(f"Error calculating RMSE: {e}")
        return np.nan

def log_rmse_to_sqlite(
    model_name: str, variant: str, train_start: str, train_end: str,
    forecast_start: str, forecast_end: str, rmse_overall: float,
    rmse_per_day: Dict[str, float], rmse_per_hour: Dict[str, float],
    parameters: Dict, features_used: List[str]
):
    """Log RMSE results to central logs database"""
    
    def ensure_directory_exists(path: Path):
        if not path.parent.exists():
            path.parent.mkdir(parents=True, exist_ok=True)
    
    ensure_directory_exists(CONFIG['LOGS_DB_PATH'])
    
    conn = sqlite3.connect(CONFIG['LOGS_DB_PATH'])
    cursor = conn.cursor()
    
    cursor.execute("""
        CREATE TABLE IF NOT EXISTS model_rmse_logs (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            model_name TEXT NOT NULL,
            variant TEXT NOT NULL,
            train_start TEXT NOT NULL,
            train_end TEXT NOT NULL,
            forecast_start TEXT NOT NULL,
            forecast_end TEXT NOT NULL,
            forecast_horizon INTEGER NOT NULL,
            rmse_json TEXT NOT NULL,
            parameters_json TEXT NOT NULL,
            features_used_json TEXT NOT NULL,
            created_at TEXT NOT NULL
        )
    """)
    
    rmse_json = json.dumps({
        "overall": rmse_overall,
        "per_day": rmse_per_day,
        "per_hour": rmse_per_hour
    })
    
    cursor.execute("""
        INSERT INTO model_rmse_logs (
            model_name, variant, train_start, train_end,
            forecast_start, forecast_end, forecast_horizon,
            rmse_json, parameters_json, features_used_json, created_at
        ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
    """, (
        model_name, variant, train_start, train_end,
        forecast_start, forecast_end, CONFIG['HORIZON'],
        rmse_json, json.dumps(parameters), json.dumps(features_used),
        datetime.utcnow().isoformat()
    ))
    
    conn.commit()
    conn.close()
    logger.info(f"✅ RMSE logged for {model_name} ({variant})")

def fix_timezone_mismatch(series_with_tz: pd.Series, series_without_tz: pd.Series) -> pd.Series:
    """Fix timezone mismatches between series"""
    if hasattr(series_without_tz.index, 'tz') and series_without_tz.index.tz is None:
        # Convert naive to UTC
        series_without_tz.index = pd.DatetimeIndex(series_without_tz.index).tz_localize('UTC')
    return series_without_tz

logger.info("🔧 Utility functions defined")

2025-05-23 16:26:26,754 - sarimax - INFO - 🔧 Utility functions defined


In [3]:
# ============================================================================
# Cell 3: DATA LOADING AND PREPARATION
# ============================================================================

logger.info("📊 Loading and preparing data...")

# Load data
conn = sqlite3.connect(CONFIG['DB_PATH'])
df = pd.read_sql("SELECT * FROM master_warp", conn)
conn.close()

# Data preparation
df["target_datetime"] = pd.to_datetime(df["target_datetime"], utc=True)
df = df.sort_values("target_datetime").set_index("target_datetime")

# Targets and features
y = df[CONFIG['TARGET_COL']].dropna()
y = y[~y.index.duplicated()]
X = df[CONFIG['FEATURE_COLS']].loc[y.index].dropna()

logger.info(f"✅ Data loaded: {len(y)} rows, {len(CONFIG['FEATURE_COLS'])} features")
logger.info(f"📅 Data range: {y.index.min()} to {y.index.max()}")

print(f"Dataset shape: {df.shape}")
print(f"Target variable: {CONFIG['TARGET_COL']}")
print(f"Number of features: {len(CONFIG['FEATURE_COLS'])}")
print(f"Features: {CONFIG['FEATURE_COLS']}")

2025-05-23 16:26:26,759 - sarimax - INFO - 📊 Loading and preparing data...
2025-05-23 16:26:26,787 - sarimax - INFO - ✅ Data loaded: 3768 rows, 18 features
2025-05-23 16:26:26,788 - sarimax - INFO - 📅 Data range: 2025-01-01 00:00:00+00:00 to 2025-06-06 23:00:00+00:00


Dataset shape: (3768, 31)
Target variable: Price
Number of features: 18
Features: ['Load', 'shortwave_radiation', 'temperature_2m', 'direct_normal_irradiance', 'diffuse_radiation', 'Flow_NO', 'yearday_cos', 'Flow_GB', 'month', 'is_dst', 'yearday_sin', 'is_non_working_day', 'hour_cos', 'is_weekend', 'cloud_cover', 'weekday_sin', 'hour_sin', 'weekday_cos']


In [4]:
# ============================================================================
# Cell 4: TRAIN/TEST SPLIT
# ============================================================================

# Calculate forecast horizon
fh = pd.date_range(
    start=CONFIG['FORECAST_START'], 
    periods=CONFIG['HORIZON'], 
    freq="h"
)

# Create splits
y_train_sarimax = y.loc[CONFIG['TRAIN_START']:CONFIG['TRAIN_END']]
X_train_sarimax = X.loc[CONFIG['TRAIN_START']:CONFIG['TRAIN_END']]
X_test_sarimax = X.loc[fh]

# Naive model preparation
y_train_naive = y.loc[CONFIG['TRAIN_END'] - pd.Timedelta(hours=CONFIG['HORIZON']-1):CONFIG['TRAIN_END']]
y_test = y.loc[fh]

# Feature scaling
scaler = StandardScaler()
X_train_scaled = pd.DataFrame(
    scaler.fit_transform(X_train_sarimax), 
    index=X_train_sarimax.index, 
    columns=X_train_sarimax.columns
)
X_test_scaled = pd.DataFrame(
    scaler.transform(X_test_sarimax), 
    index=X_test_sarimax.index, 
    columns=X_test_sarimax.columns
)

logger.info(f"🔄 Data split complete:")
logger.info(f"  - Training: {len(y_train_sarimax)} samples")
logger.info(f"  - Forecast: {len(y_test)} samples")
logger.info(f"  - Forecast period: {fh[0]} to {fh[-1]}")

print(f"\nTraining period: {CONFIG['TRAIN_START']} to {CONFIG['TRAIN_END']}")
print(f"Forecast period: {fh[0]} to {fh[-1]}")
print(f"Training samples: {len(y_train_sarimax)}")
print(f"Forecast samples: {len(y_test)}")

2025-05-23 16:26:26,797 - sarimax - INFO - 🔄 Data split complete:
2025-05-23 16:26:26,797 - sarimax - INFO -   - Training: 1752 samples
2025-05-23 16:26:26,797 - sarimax - INFO -   - Forecast: 168 samples
2025-05-23 16:26:26,798 - sarimax - INFO -   - Forecast period: 2025-03-15 00:00:00+00:00 to 2025-03-21 23:00:00+00:00



Training period: 2025-01-01 00:00:00+00:00 to 2025-03-14 23:00:00+00:00
Forecast period: 2025-03-15 00:00:00+00:00 to 2025-03-21 23:00:00+00:00
Training samples: 1752
Forecast samples: 168


In [5]:
# ============================================================================
# Cell 5: MODEL TRAINING AND EVALUATION
# ============================================================================

logger.info("🤖 Training and evaluating models...")

results = {}

# 1. NAIVE MODEL
logger.info("🟡 Running Naive model...")
try:
    naive_preds = run_naive_model(y, lag=CONFIG['HORIZON']).loc[fh]
    rmse_naive = safe_rmse(y.loc[fh], naive_preds)
    results['naive'] = {
        'predictions': naive_preds,
        'rmse': rmse_naive,
        'model_type': 'Naive',
        'parameters': {'lag': CONFIG['HORIZON']}
    }
    logger.info(f"  ✅ Naive RMSE: {rmse_naive:.4f}")
except Exception as e:
    logger.error(f"  ❌ Naive model failed: {e}")
    results['naive'] = None

# 2. SARIMAX WITHOUT FEATURES
logger.info("🔵 Running SARIMAX without features...")
try:
    # Fix frequency issues
    y_train_freq = y_train_sarimax.copy()
    y_train_freq.index = pd.DatetimeIndex(y_train_freq.index, freq='h')
    
    sarimax_preds_nf, rmse_sarimax_nf = run_sarimax(
        y_train_freq,
        X_train=None,
        X_test=pd.DataFrame(index=pd.DatetimeIndex(fh, freq='h')),
        order=(1, 1, 1),
        seasonal_order=(1, 1, 1, 24)
    )
    
    # Calculate RMSE if not returned
    if rmse_sarimax_nf is None and sarimax_preds_nf is not None:
        rmse_sarimax_nf = safe_rmse(y.loc[fh], sarimax_preds_nf)
    
    results['sarimax_no_exog'] = {
        'predictions': sarimax_preds_nf,
        'rmse': rmse_sarimax_nf,
        'model_type': 'SARIMAX',
        'parameters': {'order': (1, 1, 1), 'seasonal_order': (1, 1, 1, 24), 'exog': False}
    }
    logger.info(f"  ✅ SARIMAX (no exog) RMSE: {rmse_sarimax_nf:.4f}")
except Exception as e:
    logger.error(f"  ❌ SARIMAX (no exog) failed: {e}")
    results['sarimax_no_exog'] = None

# 3. SARIMAX WITH FEATURES
logger.info("🟢 Running SARIMAX with features...")
try:
    # Fix frequency and timezone issues
    X_train_scaled_freq = X_train_scaled.copy()
    X_train_scaled_freq.index = pd.DatetimeIndex(X_train_scaled_freq.index, freq='h')
    X_test_scaled_freq = X_test_scaled.copy()
    X_test_scaled_freq.index = pd.DatetimeIndex(X_test_scaled_freq.index, freq='h')
    
    sarimax_preds_wf, rmse_sarimax_wf = run_sarimax(
        y_train_freq,
        X_train_scaled_freq,
        X_test_scaled_freq,
        order=(1, 1, 1),
        seasonal_order=(1, 1, 1, 24)
    )
    
    # Calculate RMSE if not returned
    if rmse_sarimax_wf is None and sarimax_preds_wf is not None:
        rmse_sarimax_wf = safe_rmse(y.loc[fh], sarimax_preds_wf)
    
    results['sarimax_with_exog'] = {
        'predictions': sarimax_preds_wf,
        'rmse': rmse_sarimax_wf,
        'model_type': 'SARIMAX',
        'parameters': {'order': (1, 1, 1), 'seasonal_order': (1, 1, 1, 24), 'exog': True},
        'features': CONFIG['FEATURE_COLS']
    }
    logger.info(f"  ✅ SARIMAX (with exog) RMSE: {rmse_sarimax_wf:.4f}")
except Exception as e:
    logger.error(f"  ❌ SARIMAX (with exog) failed: {e}")
    results['sarimax_with_exog'] = None

# SUMMARY
valid_results = {k: v for k, v in results.items() if v is not None}
if valid_results:
    best_model = min(valid_results.items(), key=lambda x: x[1]['rmse'])
    logger.info(f"🏆 Best model: {best_model[0]} with RMSE: {best_model[1]['rmse']:.4f}")
    
    print("📊 MODEL COMPARISON RESULTS:")
    print("=" * 50)
    for name, result in valid_results.items():
        print(f"{name.replace('_', ' ').title():20}: {result['rmse']:.4f}")
    print("=" * 50)
    print(f"🏆 WINNER: {best_model[0].replace('_', ' ').title()}")
else:
    logger.error("❌ No models completed successfully")

2025-05-23 16:26:26,805 - sarimax - INFO - 🤖 Training and evaluating models...
2025-05-23 16:26:26,806 - sarimax - INFO - 🟡 Running Naive model...
2025-05-23 16:26:26,807 - sarimax - INFO -   ✅ Naive RMSE: 0.0432
2025-05-23 16:26:26,807 - sarimax - INFO - 🔵 Running SARIMAX without features...
2025-05-23 16:26:26,808 - sarimax - INFO - 📈 Fitting SARIMAX with order=(1, 1, 1), seasonal_order=(1, 1, 1, 24)
2025-05-23 16:26:42,552 - sarimax - INFO - 📊 RMSE: 0.03
2025-05-23 16:26:42,564 - sarimax - INFO -   ✅ SARIMAX (no exog) RMSE: 0.0382
2025-05-23 16:26:42,569 - sarimax - INFO - 🟢 Running SARIMAX with features...
2025-05-23 16:26:42,572 - sarimax - INFO - 📈 Fitting SARIMAX with order=(1, 1, 1), seasonal_order=(1, 1, 1, 24)
2025-05-23 16:27:49,100 - sarimax - INFO - 📊 RMSE: 0.04
2025-05-23 16:27:49,133 - sarimax - INFO -   ✅ SARIMAX (with exog) RMSE: 0.0245
2025-05-23 16:27:49,134 - sarimax - INFO - 🏆 Best model: sarimax_with_exog with RMSE: 0.0245


📊 MODEL COMPARISON RESULTS:
Naive               : 0.0432
Sarimax No Exog     : 0.0382
Sarimax With Exog   : 0.0245
🏆 WINNER: Sarimax With Exog


In [6]:
# ============================================================================
# Cell 6: DETAILED ANALYSIS AND LOGGING
# ============================================================================

logger.info("📈 Performing detailed analysis...")

def calculate_detailed_rmse(actual: pd.Series, predicted: pd.Series) -> Dict:
    """Calculate detailed RMSE metrics"""
    if predicted is None:
        return None
    
    try:
        # Fix timezone issues
        if hasattr(predicted.index, 'tz') and predicted.index.tz is None:
            predicted = fix_timezone_mismatch(actual, predicted)
        
        # Overall RMSE
        overall_rmse = safe_rmse(actual, predicted)
        
        # Daily RMSE
        df_combined = pd.DataFrame({
            'actual': actual,
            'predicted': predicted
        }).dropna()
        
        df_combined['date'] = df_combined.index.date
        daily_rmse = df_combined.groupby('date').apply(
            lambda x: safe_rmse(x['actual'], x['predicted'])
        ).dropna()
        
        # Convert daily RMSE to required format (day 1-7)
        rmse_per_day = {}
        for i, (date, rmse_val) in enumerate(daily_rmse.items(), 1):
            if i <= 7:
                rmse_per_day[str(i)] = round(float(rmse_val), 4)
        
        # Hourly absolute errors
        hourly_errors = np.abs(actual - predicted).dropna()
        rmse_per_hour = {}
        for i, error in enumerate(hourly_errors.values[:168]):  # Limit to 168 hours
            rmse_per_hour[str(i)] = round(float(error), 4)
        
        return {
            'overall': round(float(overall_rmse), 4),
            'per_day': rmse_per_day,
            'per_hour': rmse_per_hour
        }
    except Exception as e:
        logger.error(f"Error in detailed RMSE calculation: {e}")
        return None

# Calculate detailed metrics and log results
for model_name, result in valid_results.items():
    if result and result['predictions'] is not None:
        detailed_rmse = calculate_detailed_rmse(y.loc[fh], result['predictions'])
        
        if detailed_rmse:
            # Log to database
            try:
                log_rmse_to_sqlite(
                    model_name=result['model_type'],
                    variant=model_name,
                    train_start=CONFIG['TRAIN_START'].isoformat(),
                    train_end=CONFIG['TRAIN_END'].isoformat(),
                    forecast_start=CONFIG['FORECAST_START'].isoformat(),
                    forecast_end=(CONFIG['FORECAST_START'] + pd.Timedelta(hours=CONFIG['HORIZON']-1)).isoformat(),
                    rmse_overall=detailed_rmse['overall'],
                    rmse_per_day=detailed_rmse['per_day'],
                    rmse_per_hour=detailed_rmse['per_hour'],
                    parameters=result['parameters'],
                    features_used=result.get('features', [])
                )
            except Exception as e:
                logger.error(f"Failed to log {model_name}: {e}")

logger.info("✅ Detailed analysis complete and results logged")

2025-05-23 16:27:49,146 - sarimax - INFO - 📈 Performing detailed analysis...
  daily_rmse = df_combined.groupby('date').apply(
2025-05-23 16:27:49,167 - sarimax - INFO - ✅ RMSE logged for Naive (naive)
  daily_rmse = df_combined.groupby('date').apply(
2025-05-23 16:27:49,188 - sarimax - INFO - ✅ RMSE logged for SARIMAX (sarimax_no_exog)
  daily_rmse = df_combined.groupby('date').apply(
2025-05-23 16:27:49,204 - sarimax - INFO - ✅ RMSE logged for SARIMAX (sarimax_with_exog)
2025-05-23 16:27:49,204 - sarimax - INFO - ✅ Detailed analysis complete and results logged


In [7]:
# ============================================================================
# Cell 7: VISUALIZATION
# ============================================================================

logger.info("📊 Creating visualizations...")

# Interactive Plotly visualization
fig = go.Figure()

# Training data
fig.add_trace(go.Scatter(
    x=y_train_sarimax.index,
    y=y_train_sarimax.values,
    mode="lines",
    name="Training Data",
    line=dict(color="lightgray", width=1),
    opacity=0.7
))

# Actual values
fig.add_trace(go.Scatter(
    x=y.loc[fh].index,
    y=y.loc[fh].values,
    mode="lines",
    name="Actual",
    line=dict(color="black", width=3)
))

# Model predictions
colors = ["orange", "steelblue", "forestgreen"]
styles = ["dash", "dot", "dashdot"]

for i, (name, result) in enumerate(valid_results.items()):
    if result and result['predictions'] is not None:
        preds = result['predictions']
        # Fix timezone if needed
        if hasattr(preds.index, 'tz') and preds.index.tz is None:
            preds = fix_timezone_mismatch(y.loc[fh], preds)
        
        fig.add_trace(go.Scatter(
            x=preds.index,
            y=preds.values,
            mode="lines",
            name=f"{name.replace('_', ' ').title()} (RMSE: {result['rmse']:.3f})",
            line=dict(color=colors[i % len(colors)], dash=styles[i % len(styles)], width=2)
        ))

fig.update_layout(
    title="Time Series Forecasting Model Comparison",
    xaxis_title="Time (UTC)",
    yaxis_title="Price",
    template="plotly_white",
    legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="center", x=0.5),
    hovermode="x unified",
    height=600
)

fig.show()

# Performance summary table
summary_data = []
for name, result in valid_results.items():
    if result:
        summary_data.append({
            'Model': name.replace('_', ' ').title(),
            'RMSE': f"{result['rmse']:.4f}",
            'Parameters': str(result['parameters']),
            'Features Used': 'Yes' if result.get('features') else 'No'
        })

if summary_data:
    summary_df = pd.DataFrame(summary_data)
    summary_df = summary_df.sort_values('RMSE')
    
    display(HTML("<h3>📊 Model Performance Summary</h3>"))
    display(summary_df)

logger.info("✅ Visualizations complete")

2025-05-23 16:27:49,211 - sarimax - INFO - 📊 Creating visualizations...


Unnamed: 0,Model,RMSE,Parameters,Features Used
2,Sarimax With Exog,0.0245,"{'order': (1, 1, 1), 'seasonal_order': (1, 1, ...",Yes
1,Sarimax No Exog,0.0382,"{'order': (1, 1, 1), 'seasonal_order': (1, 1, ...",No
0,Naive,0.0432,{'lag': 168},No


2025-05-23 16:27:49,443 - sarimax - INFO - ✅ Visualizations complete


In [8]:
# ============================================================================
# Cell 8: ENHANCED ROLLING WINDOW VALIDATION WITH PARAMETER CAPTURE
# ============================================================================

import time
import numpy as np
from statsmodels.tsa.statespace.sarimax import SARIMAX

logger.info("🔄 Starting enhanced rolling window validation with parameter capture...")

def extract_model_info(model_result, model_name, hyperparams=None):
    """Extract comprehensive model information"""
    
    model_info = {
        'parameters': {},
        'hyperparameters': hyperparams or {},
        'diagnostics': {},
        'convergence': {},
        'summary': None
    }
    
    if model_result is None:
        return model_info
        
    try:
        # Model parameters (coefficients)
        if hasattr(model_result, 'params'):
            model_info['parameters'] = dict(model_result.params)
            
        # Model diagnostics
        if hasattr(model_result, 'aic'):
            model_info['diagnostics']['aic'] = float(model_result.aic)
        if hasattr(model_result, 'bic'):
            model_info['diagnostics']['bic'] = float(model_result.bic)
        if hasattr(model_result, 'llf'):
            model_info['diagnostics']['log_likelihood'] = float(model_result.llf)
        if hasattr(model_result, 'hqic'):
            model_info['diagnostics']['hqic'] = float(model_result.hqic)
            
        # Convergence information
        if hasattr(model_result, 'mle_retvals'):
            retvals = model_result.mle_retvals
            model_info['convergence'] = {
                'converged': retvals.get('converged', False),
                'iterations': retvals.get('iterations', None),
                'function_calls': retvals.get('fcalls', None),
                'gradient_calls': retvals.get('gcalls', None),
                'warning_flag': retvals.get('warnflag', None)
            }
            
        # Model summary (truncated)
        if hasattr(model_result, 'summary'):
            try:
                summary_str = str(model_result.summary())
                # Keep first 2000 characters to avoid bloat
                model_info['summary'] = summary_str[:2000] + "..." if len(summary_str) > 2000 else summary_str
            except:
                model_info['summary'] = "Summary extraction failed"
                
    except Exception as e:
        logger.warning(f"Error extracting model info for {model_name}: {e}")
        
    return model_info

def run_sarimax_with_info(y_train, X_train, X_test, order, seasonal_order):
    """Run SARIMAX and return predictions, RMSE, and model info"""
    try:
        # Create and fit model
        if X_train is not None:
            model = SARIMAX(y_train, exog=X_train, order=order, seasonal_order=seasonal_order)
        else:
            model = SARIMAX(y_train, order=order, seasonal_order=seasonal_order)
            
        model_result = model.fit(disp=False, maxiter=100)
        
        # Make predictions
        if X_test is not None:
            predictions = model_result.forecast(steps=len(X_test), exog=X_test)
        else:
            predictions = model_result.forecast(steps=len(X_test) if X_test is not None else 168)
            
        # Calculate RMSE
        # Note: you'll need to align this with your actual values
        rmse = None  # Will be calculated later with actual values
        
        return predictions, rmse, model_result
        
    except Exception as e:
        logger.error(f"SARIMAX failed: {e}")
        return None, None, None

def enhanced_rolling_window_validation(n_windows: int = 3) -> pd.DataFrame:
    """Enhanced rolling window validation with comprehensive logging"""
    rolling_results = []
    model_info_log = []  # Store detailed model information
    
    for i in range(n_windows):
        logger.info(f"📊 Processing window {i+1}/{n_windows}")
        
        # Adjust time windows
        delta = pd.Timedelta(days=i)
        train_start_i = CONFIG['TRAIN_START'] + delta
        train_end_i = CONFIG['TRAIN_END'] + delta
        forecast_start_i = train_end_i + pd.Timedelta(hours=1)
        fh_i = pd.date_range(start=forecast_start_i, periods=CONFIG['HORIZON'], freq="h")
        
        # Check data availability
        if fh_i[-1] not in y.index:
            logger.warning(f"Insufficient data for window {i+1}, skipping...")
            continue
        
        window_results = {
            'window': i+1,
            'train_start': train_start_i,
            'train_end': train_end_i,
            'forecast_start': forecast_start_i,
            'forecast_end': fh_i[-1]
        }
        
        # Get actual values for this window
        y_actual_i = y.loc[fh_i]
        y_train_i = y.loc[train_start_i:train_end_i]
        
        # Log window characteristics
        train_stats = {
            'mean': float(y_train_i.mean()),
            'std': float(y_train_i.std()),
            'min': float(y_train_i.min()),
            'max': float(y_train_i.max()),
            'count': len(y_train_i)
        }
        
        actual_stats = {
            'mean': float(y_actual_i.mean()),
            'std': float(y_actual_i.std()),
            'min': float(y_actual_i.min()),
            'max': float(y_actual_i.max()),
            'count': len(y_actual_i)
        }
        
        # Test each model
        for model_name in ['naive', 'sarimax_no_exog', 'sarimax_with_exog']:
            start_time = time.time()
            
            try:
                if model_name == 'naive':
                    # Naive model
                    naive_preds_i = run_naive_model(y.loc[:train_end_i], lag=CONFIG['HORIZON']).loc[fh_i]
                    rmse_i = safe_rmse(y_actual_i, naive_preds_i)
                    
                    # Store model info for naive
                    model_info = {
                        'parameters': {'lag': CONFIG['HORIZON']},
                        'hyperparameters': {'model_type': 'naive', 'lag': CONFIG['HORIZON']},
                        'diagnostics': {},
                        'convergence': {'converged': True},
                        'summary': f"Naive model with lag={CONFIG['HORIZON']}"
                    }
                    
                else:
                    # SARIMAX models
                    y_train_i.index = pd.DatetimeIndex(y_train_i.index, freq='h')
                    order = (1, 1, 1)
                    seasonal_order = (1, 1, 1, 24)
                    
                    if model_name == 'sarimax_no_exog':
                        preds_i, rmse_i, model_result = run_sarimax_with_info(
                            y_train_i, None, 
                            pd.DataFrame(index=pd.DatetimeIndex(fh_i, freq='h')),
                            order, seasonal_order
                        )
                        
                        hyperparams = {
                            'order': order,
                            'seasonal_order': seasonal_order,
                            'model_type': 'sarimax_no_exog'
                        }
                        
                    else:  # sarimax_with_exog
                        X_train_i = X.loc[train_start_i:train_end_i]
                        X_test_i = X.loc[fh_i]
                        scaler_i = StandardScaler()
                        X_train_scaled_i = pd.DataFrame(
                            scaler_i.fit_transform(X_train_i),
                            index=X_train_i.index, columns=X_train_i.columns
                        )
                        X_test_scaled_i = pd.DataFrame(
                            scaler_i.transform(X_test_i),
                            index=X_test_i.index, columns=X_test_i.columns
                        )
                        X_train_scaled_i.index = pd.DatetimeIndex(X_train_scaled_i.index, freq='h')
                        X_test_scaled_i.index = pd.DatetimeIndex(X_test_scaled_i.index, freq='h')
                        
                        preds_i, rmse_i, model_result = run_sarimax_with_info(
                            y_train_i, X_train_scaled_i, X_test_scaled_i,
                            order, seasonal_order
                        )
                        
                        hyperparams = {
                            'order': order,
                            'seasonal_order': seasonal_order,
                            'model_type': 'sarimax_with_exog',
                            'n_features': len(X_train_scaled_i.columns),
                            'feature_names': list(X_train_scaled_i.columns),
                            'scaler_type': 'StandardScaler'
                        }
                        
                        # Feature drift analysis
                        feature_stats = {}
                        for col in X_train_scaled_i.columns:
                            train_mean = X_train_scaled_i[col].mean()
                            test_mean = X_test_scaled_i[col].mean()
                            feature_stats[col] = {
                                'train_mean': float(train_mean),
                                'test_mean': float(test_mean),
                                'drift': float(abs(test_mean - train_mean))
                            }
                    
                    # Calculate RMSE if not already done
                    if rmse_i is None and preds_i is not None:
                        rmse_i = safe_rmse(y_actual_i, preds_i)
                    
                    # Extract model information
                    model_info = extract_model_info(model_result, model_name, hyperparams)
                
                # Store results
                window_results[model_name] = rmse_i
                execution_time = time.time() - start_time
                
                # Store detailed model info for logging
                detailed_info = {
                    'window_id': i+1,
                    'model_name': model_name,
                    'train_start': train_start_i.isoformat(),
                    'train_end': train_end_i.isoformat(),
                    'forecast_start': forecast_start_i.isoformat(),
                    'forecast_end': fh_i[-1].isoformat(),
                    'rmse': rmse_i,
                    'train_stats': train_stats,
                    'actual_stats': actual_stats,
                    'model_parameters': model_info['parameters'],
                    'hyperparameters': model_info['hyperparameters'],
                    'model_diagnostics': model_info['diagnostics'],
                    'convergence_info': model_info['convergence'],
                    'model_summary': model_info['summary'],
                    'feature_stats': feature_stats if model_name == 'sarimax_with_exog' else None,
                    'execution_time': execution_time
                }
                
                model_info_log.append(detailed_info)
                
                logger.info(f"  {model_name}: RMSE = {rmse_i:.6f}, Time = {execution_time:.2f}s")
                
                # Log convergence issues
                if model_name != 'naive' and model_result:
                    conv_info = model_info['convergence']
                    if not conv_info.get('converged', True):
                        logger.warning(f"  ⚠️  {model_name} did not converge properly")
                    if conv_info.get('warning_flag', 0) > 0:
                        logger.warning(f"  ⚠️  {model_name} has optimization warnings")
                
            except Exception as e:
                logger.error(f"Model {model_name} failed in window {i+1}: {str(e)}")
                window_results[model_name] = np.nan
                
                # Store failed attempt info
                failed_info = {
                    'window_id': i+1,
                    'model_name': model_name,
                    'train_start': train_start_i.isoformat(),
                    'train_end': train_end_i.isoformat(),
                    'forecast_start': forecast_start_i.isoformat(),
                    'forecast_end': fh_i[-1].isoformat(),
                    'rmse': None,
                    'train_stats': train_stats,
                    'actual_stats': actual_stats,
                    'execution_time': time.time() - start_time,
                    'notes': f"Model failed: {str(e)}"
                }
                model_info_log.append(failed_info)
        
        rolling_results.append(window_results)
        logger.info(f"Window {i+1} complete\n")
    
    # Store model info for Cell 9
    globals()['model_info_log'] = model_info_log
    
    return pd.DataFrame(rolling_results)

# Run enhanced rolling validation
rolling_df = enhanced_rolling_window_validation(n_windows=3)

# Display results
display(HTML("<h3>🔄 Enhanced Rolling Window Validation Results</h3>"))
display(rolling_df)

logger.info("✅ Enhanced rolling window validation complete with parameter capture")

2025-05-23 16:27:49,457 - sarimax - INFO - 🔄 Starting enhanced rolling window validation with parameter capture...
2025-05-23 16:27:49,458 - sarimax - INFO - 📊 Processing window 1/3
2025-05-23 16:27:49,460 - sarimax - ERROR - Model naive failed in window 1: "None of [DatetimeIndex(['2025-03-15 00:00:00+00:00', '2025-03-15 01:00:00+00:00',\n               '2025-03-15 02:00:00+00:00', '2025-03-15 03:00:00+00:00',\n               '2025-03-15 04:00:00+00:00', '2025-03-15 05:00:00+00:00',\n               '2025-03-15 06:00:00+00:00', '2025-03-15 07:00:00+00:00',\n               '2025-03-15 08:00:00+00:00', '2025-03-15 09:00:00+00:00',\n               ...\n               '2025-03-21 14:00:00+00:00', '2025-03-21 15:00:00+00:00',\n               '2025-03-21 16:00:00+00:00', '2025-03-21 17:00:00+00:00',\n               '2025-03-21 18:00:00+00:00', '2025-03-21 19:00:00+00:00',\n               '2025-03-21 20:00:00+00:00', '2025-03-21 21:00:00+00:00',\n               '2025-03-21 22:00:00+00:00', '2

Unnamed: 0,window,train_start,train_end,forecast_start,forecast_end,naive,sarimax_no_exog,sarimax_with_exog
0,1,2025-01-01 00:00:00+00:00,2025-03-14 23:00:00+00:00,2025-03-15 00:00:00+00:00,2025-03-21 23:00:00+00:00,,0.038786,0.025547
1,2,2025-01-02 00:00:00+00:00,2025-03-15 23:00:00+00:00,2025-03-16 00:00:00+00:00,2025-03-22 23:00:00+00:00,,0.04087,0.051685
2,3,2025-01-03 00:00:00+00:00,2025-03-16 23:00:00+00:00,2025-03-17 00:00:00+00:00,2025-03-23 23:00:00+00:00,,0.037788,0.099248


2025-05-23 17:01:50,675 - sarimax - INFO - ✅ Enhanced rolling window validation complete with parameter capture


In [9]:
# ============================================================================
# Cell 9: SAVE ROLLING WINDOW RESULTS (INLINE FUNCTIONS)
# ============================================================================

import sqlite3
import json
from datetime import datetime

logger.info("💾 Saving rolling window results...")

def log_rolling_window_to_sqlite(window_id, model_name, train_start, train_end, 
                                forecast_start, forecast_end, rmse, notes=None):
    """Simple logging function for rolling window results"""
    
    conn = sqlite3.connect(CONFIG['LOGS_DB_PATH'])
    cursor = conn.cursor()
    
    # Create table if not exists
    cursor.execute('''
        CREATE TABLE IF NOT EXISTS rolling_window_logs (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            window_id INTEGER NOT NULL,
            model_name TEXT NOT NULL,
            train_start TEXT NOT NULL,
            train_end TEXT NOT NULL,
            forecast_start TEXT NOT NULL,
            forecast_end TEXT NOT NULL,
            rmse REAL,
            notes TEXT,
            created_at TEXT NOT NULL
        )
    ''')
    
    # Insert record
    cursor.execute('''
        INSERT INTO rolling_window_logs (
            window_id, model_name, train_start, train_end, 
            forecast_start, forecast_end, rmse, notes, created_at
        ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
    ''', (
        window_id, model_name, train_start, train_end,
        forecast_start, forecast_end, rmse, notes,
        datetime.utcnow().isoformat()
    ))
    
    conn.commit()
    conn.close()

def analyze_performance_degradation():
    """Analyze performance trends from logged data"""
    
    conn = sqlite3.connect(CONFIG['LOGS_DB_PATH'])
    cursor = conn.cursor()
    
    cursor.execute('''
        SELECT window_id, model_name, rmse 
        FROM rolling_window_logs 
        ORDER BY created_at DESC, window_id ASC
        LIMIT 50
    ''')
    
    results = cursor.fetchall()
    conn.close()
    
    # Group by model
    model_data = {}
    for window_id, model_name, rmse in results:
        if model_name not in model_data:
            model_data[model_name] = []
        if rmse is not None:
            model_data[model_name].append((window_id, rmse))
    
    # Calculate degradation
    analysis = {'degradation_summary': {}}
    for model_name, data in model_data.items():
        if len(data) >= 2:
            data.sort()  # Sort by window_id
            first_rmse = data[0][1]
            last_rmse = data[-1][1]
            
            degradation_pct = ((last_rmse - first_rmse) / first_rmse) * 100
            
            analysis['degradation_summary'][model_name] = {
                'first_window_rmse': first_rmse,
                'last_window_rmse': last_rmse,
                'degradation_percent': degradation_pct,
                'trend': 'SEVERE' if degradation_pct > 100 else 
                        'SIGNIFICANT' if degradation_pct > 50 else
                        'MODERATE' if degradation_pct > 20 else
                        'STABLE'
            }
    
    return analysis

# ============================================================================
# Cell 9: SAVE ROLLING WINDOW RESULTS WITH DETAILED MODEL INFO
# ============================================================================

import sqlite3
import json
from datetime import datetime

logger.info("💾 Saving rolling window results with detailed model information...")

def log_rolling_window_to_sqlite(window_id, model_name, train_start, train_end, 
                                forecast_start, forecast_end, rmse, 
                                model_parameters=None, hyperparameters=None,
                                model_diagnostics=None, convergence_info=None,
                                model_summary=None, feature_stats=None,
                                train_stats=None, actual_stats=None,
                                execution_time=None, notes=None):
    """Enhanced logging function with model parameters and diagnostics"""
    
    conn = sqlite3.connect(CONFIG['LOGS_DB_PATH'])
    cursor = conn.cursor()
    
    # Create enhanced table
    cursor.execute('''
        CREATE TABLE IF NOT EXISTS rolling_window_logs (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            window_id INTEGER NOT NULL,
            model_name TEXT NOT NULL,
            train_start TEXT NOT NULL,
            train_end TEXT NOT NULL,
            forecast_start TEXT NOT NULL,
            forecast_end TEXT NOT NULL,
            rmse REAL,
            model_parameters_json TEXT,
            hyperparameters_json TEXT,
            model_diagnostics_json TEXT,
            convergence_info_json TEXT,
            model_summary TEXT,
            feature_stats_json TEXT,
            train_stats_json TEXT,
            actual_stats_json TEXT,
            execution_time_seconds REAL,
            notes TEXT,
            created_at TEXT NOT NULL
        )
    ''')
    
    # Convert dictionaries to JSON
    model_parameters_json = json.dumps(model_parameters) if model_parameters else None
    hyperparameters_json = json.dumps(hyperparameters) if hyperparameters else None
    model_diagnostics_json = json.dumps(model_diagnostics) if model_diagnostics else None
    convergence_info_json = json.dumps(convergence_info) if convergence_info else None
    feature_stats_json = json.dumps(feature_stats) if feature_stats else None
    train_stats_json = json.dumps(train_stats) if train_stats else None
    actual_stats_json = json.dumps(actual_stats) if actual_stats else None
    
    # Insert record
    cursor.execute('''
        INSERT INTO rolling_window_logs (
            window_id, model_name, train_start, train_end, 
            forecast_start, forecast_end, rmse,
            model_parameters_json, hyperparameters_json, model_diagnostics_json,
            convergence_info_json, model_summary, feature_stats_json,
            train_stats_json, actual_stats_json, execution_time_seconds, notes, created_at
        ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
    ''', (
        window_id, model_name, train_start, train_end,
        forecast_start, forecast_end, rmse,
        model_parameters_json, hyperparameters_json, model_diagnostics_json,
        convergence_info_json, model_summary, feature_stats_json,
        train_stats_json, actual_stats_json, execution_time, notes,
        datetime.utcnow().isoformat()
    ))
    
    conn.commit()
    conn.close()

def analyze_model_parameter_changes():
    """Analyze how model parameters change across windows"""
    
    conn = sqlite3.connect(CONFIG['LOGS_DB_PATH'])
    cursor = conn.cursor()
    
    cursor.execute('''
        SELECT window_id, model_name, rmse, model_parameters_json, 
               hyperparameters_json, model_diagnostics_json, convergence_info_json
        FROM rolling_window_logs 
        ORDER BY window_id ASC, model_name ASC
    ''')
    
    results = cursor.fetchall()
    conn.close()
    
    # Group by model
    model_analysis = {}
    for window_id, model_name, rmse, params_json, hyper_json, diag_json, conv_json in results:
        if model_name not in model_analysis:
            model_analysis[model_name] = []
            
        entry = {
            'window_id': window_id,
            'rmse': rmse,
            'parameters': json.loads(params_json) if params_json else {},
            'hyperparameters': json.loads(hyper_json) if hyper_json else {},
            'diagnostics': json.loads(diag_json) if diag_json else {},
            'convergence': json.loads(conv_json) if conv_json else {}
        }
        model_analysis[model_name].append(entry)
    
    return model_analysis

# Save detailed model information if available
if 'model_info_log' in globals() and model_info_log:
    
    for info in model_info_log:
        try:
            log_rolling_window_to_sqlite(
                window_id=info['window_id'],
                model_name=info['model_name'],
                train_start=info['train_start'],
                train_end=info['train_end'],
                forecast_start=info['forecast_start'],
                forecast_end=info['forecast_end'],
                rmse=info['rmse'],
                model_parameters=info.get('model_parameters'),
                hyperparameters=info.get('hyperparameters'),
                model_diagnostics=info.get('model_diagnostics'),
                convergence_info=info.get('convergence_info'),
                model_summary=info.get('model_summary'),
                feature_stats=info.get('feature_stats'),
                train_stats=info.get('train_stats'),
                actual_stats=info.get('actual_stats'),
                execution_time=info.get('execution_time'),
                notes=info.get('notes')
            )
            logger.info(f"✅ Logged detailed info for window {info['window_id']}, model {info['model_name']}")
        except Exception as e:
            logger.error(f"Failed to log detailed info: {e}")
    
    # Analyze parameter changes
    try:
        analysis = analyze_model_parameter_changes()
        
        print("\n📊 MODEL PARAMETER ANALYSIS ACROSS WINDOWS:")
        print("=" * 70)
        
        for model_name, windows in analysis.items():
            print(f"\n🔍 {model_name.upper()} ANALYSIS:")
            
            for window_data in windows:
                window_id = window_data['window_id']
                rmse = window_data['rmse']
                diagnostics = window_data['diagnostics']
                convergence = window_data['convergence']
                
                print(f"\n  Window {window_id}:")
                print(f"    RMSE: {rmse:.6f}" if rmse else "    RMSE: FAILED")
                
                # Show key diagnostics
                if diagnostics:
                    if 'aic' in diagnostics:
                        print(f"    AIC: {diagnostics['aic']:.2f}")
                    if 'bic' in diagnostics:
                        print(f"    BIC: {diagnostics['bic']:.2f}")
                
                # Show convergence issues
                if convergence and not convergence.get('converged', True):
                    print(f"    ⚠️  Convergence: FAILED")
                    if 'iterations' in convergence:
                        print(f"    Iterations: {convergence['iterations']}")
                
                # Show key parameters for SARIMAX models
                if model_name != 'naive':
                    params = window_data['parameters']
                    if params:
                        # Show first few parameters to avoid clutter
                        param_keys = list(params.keys())[:5]
                        for key in param_keys:
                            if isinstance(params[key], (int, float)):
                                print(f"    {key}: {params[key]:.4f}")
        
        # Parameter stability analysis
        print(f"\n🔍 PARAMETER STABILITY ANALYSIS:")
        for model_name, windows in analysis.items():
            if len(windows) > 1 and model_name != 'naive':
                print(f"\n{model_name.upper()}:")
                
                # Check AIC/BIC trends
                aics = [w['diagnostics'].get('aic') for w in windows if w['diagnostics'].get('aic')]
                if len(aics) > 1:
                    aic_trend = aics[-1] - aics[0]
                    print(f"  AIC trend: {aic_trend:+.2f} (lower is better)")
                
                # Check convergence consistency
                converged_windows = [w['convergence'].get('converged', True) for w in windows]
                convergence_rate = sum(converged_windows) / len(converged_windows) * 100
                print(f"  Convergence rate: {convergence_rate:.1f}%")
                
                if convergence_rate < 100:
                    print(f"  ⚠️  Model stability issues detected!")
        
        print(f"\n💾 Detailed results saved to 'rolling_window_logs' table in src/data/logs.db")
        
    except Exception as e:
        logger.error(f"Error in parameter analysis: {e}")
        print("Parameter analysis failed, but logging succeeded.")

else:
    logger.warning("❌ No detailed model info found. Make sure enhanced Cell 8 ran successfully.")
    print("No detailed model information to save.")

logger.info("✅ Enhanced rolling window results logging complete")

logger.info("✅ Rolling window results logging complete")

2025-05-23 17:01:50,706 - sarimax - INFO - 💾 Saving rolling window results...
2025-05-23 17:01:50,707 - sarimax - INFO - 💾 Saving rolling window results with detailed model information...
2025-05-23 17:01:50,710 - sarimax - ERROR - Failed to log detailed info: table rolling_window_logs has no column named model_parameters_json
2025-05-23 17:01:50,711 - sarimax - ERROR - Failed to log detailed info: table rolling_window_logs has no column named model_parameters_json
2025-05-23 17:01:50,712 - sarimax - ERROR - Failed to log detailed info: table rolling_window_logs has no column named model_parameters_json
2025-05-23 17:01:50,714 - sarimax - ERROR - Failed to log detailed info: table rolling_window_logs has no column named model_parameters_json
2025-05-23 17:01:50,716 - sarimax - ERROR - Failed to log detailed info: table rolling_window_logs has no column named model_parameters_json
2025-05-23 17:01:50,717 - sarimax - ERROR - Failed to log detailed info: table rolling_window_logs has no c

Parameter analysis failed, but logging succeeded.


## Summary

This notebook provides a comprehensive framework for time series model testing with:

- **Automated logging** to SQLite database
- **Professional error handling** and timezone management
- **Interactive visualizations** with Plotly
- **Detailed performance metrics** (overall, daily, hourly RMSE)
- **Modular design** for easy extension
- **Optional rolling window validation**

All results are automatically logged to your logs.db database with the structure you specified.