# Model Training and Comparison

This notebook trains and compares multiple forecasting models.

## Models Covered:
1. Baseline models (Naive, Seasonal Naive)
2. Statistical models (ARIMA, Prophet)
3. Machine learning models (XGBoost, LightGBM)
4. Ensemble methods

In [None]:
# Imports
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

sys.path.append(str(Path.cwd().parent))

from src.data_loader import load_sales_data, split_train_test
from src.feature_engineering import FeatureEngineer
from src.baselines import BaselineForecaster
from src.arima_models import ARIMAForecaster
from src.prophet_model import ProphetForecaster
from src.ml_models import MLForecaster
from src.ensemble import SimpleEnsemble
from src.evaluation import ModelEvaluator, compute_metrics
from src.visualization import (
    plot_feature_importance,
    plot_forecast,
    plot_model_comparison,
)

%matplotlib inline

## 1. Prepare Data

In [None]:
# Load the raw sales file from data/raw under the parent of the working directory
raw_sales_path = Path.cwd().parent / "data" / "raw" / "sales_data.parquet"
df = load_sales_data(raw_sales_path)

# Sum sales so there is a single row per (store_id, date) pair
sales_per_store_date = df.groupby(['store_id', 'date'])['sales'].sum()
df_agg = sales_per_store_date.reset_index()

# Request a three-way split with test_weeks=6 and validation_weeks=8
train_df, val_df, test_df = split_train_test(
    df_agg,
    test_weeks=6,
    validation_weeks=8,
)

print(f"Train: {len(train_df)} | Val: {len(val_df)} | Test: {len(test_df)}")

## 2. Baseline Models

In [None]:
# Fit the baselines on train + validation. A model fitted on `train_df` alone
# forecasts the `horizon` steps that follow the training window, which is the
# validation period, yet it is scored against the test period.
# NOTE(review): assumes split_train_test returns chronological, contiguous
# train -> validation -> test windows (the feature-split cell slices on the
# same boundaries) -- confirm against src.data_loader.
history_df = pd.concat([train_df, val_df])

# Aggregate to daily totals across all stores.
# `train_series` keeps its name because the ARIMA and Prophet cells read it;
# it holds every observation that precedes the test window.
train_series = history_df.groupby('date')['sales'].sum().values
test_series = test_df.groupby('date')['sales'].sum().values
horizon = len(test_series)  # one forecast step per test observation

# Initialize evaluator (collects one result row per evaluate_model call)
evaluator = ModelEvaluator()

# Naive
naive = BaselineForecaster(method='naive')
naive_forecast = naive.fit_predict(train_series, horizon)
evaluator.evaluate_model('Naive', test_series, naive_forecast, train_series)

# Seasonal Naive (season_length=7 -- presumably a weekly cycle on daily data)
seasonal_naive = BaselineForecaster(method='seasonal_naive', season_length=7)
sn_forecast = seasonal_naive.fit_predict(train_series, horizon)
evaluator.evaluate_model('Seasonal Naive', test_series, sn_forecast, train_series)

print("Baseline Results:")
print(evaluator.get_results())

## 3. Statistical Models

In [None]:
# ARIMA (this may take a few minutes)
print("Training ARIMA...")

# Seasonal configuration; m=7 is presumably a weekly period on daily data
arima_params = {'seasonal': True, 'm': 7}
arima = ARIMAForecaster(**arima_params)
arima.fit(train_series)

# Forecast one step per test observation and record the scores
arima_forecast = arima.predict(steps=horizon)
evaluator.evaluate_model('ARIMA', test_series, arima_forecast, train_series)

print("ARIMA training complete!")

In [None]:
# Prophet
print("Training Prophet...")

# Fit on train + validation so that the `horizon` steps forecast after the last
# fitted date fall on the test window instead of the validation window.
# NOTE(review): assumes the three splits are chronological and contiguous -- confirm.
prophet_history = pd.concat([train_df, val_df])

# Daily totals, renamed to the ds/y column names passed to fit() below.
# rename() is used instead of assigning to .columns so the frame is built in
# one expression and re-running the cell is idempotent.
train_prophet = (
    prophet_history.groupby('date')['sales'].sum()
    .reset_index()
    .rename(columns={'date': 'ds', 'sales': 'y'})
)

prophet = ProphetForecaster(seasonality_mode='multiplicative')
prophet.fit(train_prophet, date_col='ds', target_col='y')

# The point forecast is read from the 'yhat' column of the returned frame
prophet_forecast_df = prophet.predict(steps=horizon)
prophet_forecast = prophet_forecast_df['yhat'].values

evaluator.evaluate_model('Prophet', test_series, prophet_forecast, train_series)
print("Prophet training complete!")

## 4. Machine Learning Models

In [None]:
# Build features on the stacked splits, then cut them back apart by date
combined_df = pd.concat([train_df, val_df, test_df])
engineer = FeatureEngineer(combined_df, date_column='date')
featured_df = engineer.create_all_features(
    target_column='sales',
    group_columns=['store_id'],
)

# The earliest validation and test dates act as the split boundaries
val_start = val_df['date'].min()
test_start = test_df['date'].min()
feature_dates = featured_df['date']

# Rows containing any NaN are dropped from each split
train_feat = featured_df[feature_dates < val_start].dropna()
val_feat = featured_df[(feature_dates >= val_start) & (feature_dates < test_start)].dropna()
test_feat = featured_df[feature_dates >= test_start].dropna()

# Every column except the date, the store id and the target is a model input
excluded_cols = ['date', 'store_id', 'sales']
feature_cols = [col for col in train_feat.columns if col not in excluded_cols]

X_train, y_train = train_feat[feature_cols], train_feat['sales']
X_val, y_val = val_feat[feature_cols], val_feat['sales']
X_test, y_test = test_feat[feature_cols], test_feat['sales']

print(f"Features: {len(feature_cols)}")

In [None]:
# XGBoost
print("Training XGBoost...")
xgb = MLForecaster(model_type='xgboost', max_depth=6, learning_rate=0.1, n_estimators=100)
xgb.fit(X_train, y_train, eval_set=[(X_val, y_val)], early_stopping_rounds=10, verbose=False)
xgb_pred = xgb.predict(X_test)  # row-level predictions, kept for later inspection

# Score on daily totals rather than on individual (store, date) rows.
# The baseline, ARIMA and Prophet rows in the evaluator are computed on the
# daily total series, so row-level errors would be on a different scale and
# the MAE comparison / get_best_model would be misleading.
# NOTE(review): assumes xgb.predict returns one value per row of X_test -- confirm.
xgb_daily = (
    pd.DataFrame({
        'date': test_feat['date'].to_numpy(),
        'actual': y_test.to_numpy(),
        'pred': np.asarray(xgb_pred).ravel(),
    })
    .groupby('date')[['actual', 'pred']]
    .sum()
)
# Daily totals of the rows the model was trained on (reference series for the evaluator)
xgb_train_daily = train_feat.groupby('date')['sales'].sum().to_numpy()

evaluator.evaluate_model(
    'XGBoost',
    xgb_daily['actual'].to_numpy(),
    xgb_daily['pred'].to_numpy(),
    xgb_train_daily,
)
print("XGBoost complete!")

In [None]:
# LightGBM
print("Training LightGBM...")
lgb = MLForecaster(model_type='lightgbm', max_depth=6, learning_rate=0.1, n_estimators=100)
lgb.fit(X_train, y_train, eval_set=[(X_val, y_val)], early_stopping_rounds=10, verbose=False)
lgb_pred = lgb.predict(X_test)  # row-level predictions, kept for later inspection

# Score on daily totals rather than on individual (store, date) rows.
# The baseline, ARIMA and Prophet rows in the evaluator are computed on the
# daily total series, so row-level errors would be on a different scale and
# the MAE comparison / get_best_model would be misleading.
# NOTE(review): assumes lgb.predict returns one value per row of X_test -- confirm.
lgb_daily = (
    pd.DataFrame({
        'date': test_feat['date'].to_numpy(),
        'actual': y_test.to_numpy(),
        'pred': np.asarray(lgb_pred).ravel(),
    })
    .groupby('date')[['actual', 'pred']]
    .sum()
)
# Daily totals of the rows the model was trained on (reference series for the evaluator)
lgb_train_daily = train_feat.groupby('date')['sales'].sum().to_numpy()

evaluator.evaluate_model(
    'LightGBM',
    lgb_daily['actual'].to_numpy(),
    lgb_daily['pred'].to_numpy(),
    lgb_train_daily,
)
print("LightGBM complete!")

## 5. Ensemble

In [None]:
# Score the ensemble against the same daily test totals its member forecasts
# were produced for. Rebuilding the truth from `test_feat` would tie it to the
# rows that survived dropna() in the ML feature split, which can differ.
test_agg = test_series

# Member forecasts to average.
# NOTE(review): ARIMA and the ML models are not part of the ensemble -- confirm this is intended.
member_forecasts = {
    'Naive': naive_forecast,
    'Seasonal_Naive': sn_forecast,
    'Prophet': prophet_forecast,
}

# Trim every member and the truth to the shortest length so they line up
min_len = min(len(test_agg), *(len(fc) for fc in member_forecasts.values()))
forecasts = {name: fc[:min_len] for name, fc in member_forecasts.items()}

ensemble = SimpleEnsemble(method='mean')
ensemble_forecast = ensemble.combine(forecasts)

# Pass the training series as in every other evaluate_model call, so the
# Ensemble row is computed with the same inputs as the other models
evaluator.evaluate_model('Ensemble', test_agg[:min_len], ensemble_forecast, train_series)
print("Ensemble complete!")

## 6. Results Comparison

In [None]:
# Pull the accumulated results from the evaluator and print them in full
results = evaluator.get_results()
print("\nFinal Results:", results.to_string(), sep="\n")

# Ask the evaluator which model ranks best on MAE
best = evaluator.get_best_model(metric='mae')
print(f"\nBest Model (MAE): {best}")

In [None]:
# Compare all evaluated models on MAE
comparison_kwargs = {'metric': 'mae', 'title': 'Model Comparison by MAE'}
plot_model_comparison(results, **comparison_kwargs)

## 7. Feature Importance (XGBoost)

In [None]:
# Number of features to request and plot; defined once so both calls stay in sync
TOP_N_FEATURES = 20

# plot_feature_importance is imported with the other src.visualization helpers
# in the imports cell at the top of the notebook
importance = xgb.get_feature_importance(top_n=TOP_N_FEATURES)
plot_feature_importance(importance, top_n=TOP_N_FEATURES, title='XGBoost Feature Importance')

## 8. Conclusions

### Key Findings:

1. **Best Performer**: Typically XGBoost or LightGBM with engineered features (confirm against the results table in Section 6)
2. **Baseline Value**: Seasonal Naive provides a good baseline
3. **Feature Importance**: Lag features are typically the most important (see the feature importance plot in Section 7)
4. **Ensemble Benefit**: Combining models reduces variance

### Recommendations:

- Use XGBoost/LightGBM for production
- Maintain Seasonal Naive as fallback
- Consider ensemble for critical forecasts
- Monitor performance over time
- Retrain periodically with new data