# 07b - SARIMA Demand Forecasting Model

## Business Objective
Build a production-ready demand forecasting system using SARIMA to:
- Predict future demand by product category
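- Capture weekly seasonal patterns (seasonal period of 7 days); longer cycles such as monthly seasonality would require a different seasonal period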
- Support inventory planning with 30/60/90 day forecasts

## Approach
- Aggregate demand to daily level by category
- Use SARIMA (Seasonal ARIMA) for time series forecasting
- 80/20 time-based train-test split
- Evaluate with MAPE, MAE, RMSE

In [None]:
# ============================================================================
# IMPORTS AND CONFIGURATION
# ============================================================================

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
from datetime import datetime, timedelta
import warnings
warnings.filterwarnings('ignore')

# Statistical/ML libraries
from statsmodels.tsa.statespace.sarimax import SARIMAX
from statsmodels.tsa.stattools import adfuller, acf, pacf
from statsmodels.graphics.tsaplots import plot_acf, plot_pacf
from statsmodels.stats.diagnostic import acorr_ljungbox
from sklearn.metrics import mean_absolute_error, mean_squared_error
from scipy import stats

# Global plotting defaults: whitegrid theme, wide figures, base font size 10
plt.style.use('seaborn-v0_8-whitegrid')
plt.rcParams.update({
    'figure.figsize': (14, 6),
    'font.size': 10,
})

# Directory holding the processed CSV files, relative to the working directory
DATA_DIR = Path('../..', 'ml', 'data', 'processed')

print("Libraries loaded successfully")

## 1. Data Loading and Preparation

In [None]:
# ============================================================================
# LOAD RAW DATA
# Business Logic: We need order dates, quantities, and categories to
# forecast demand by product category
# ============================================================================

# Read the three source tables from the processed-data directory
orders = pd.read_csv(DATA_DIR / 'orders.csv')
order_items = pd.read_csv(DATA_DIR / 'order_items.csv')
products = pd.read_csv(DATA_DIR / 'products.csv')

# Convert order dates to datetime so they can later be grouped by calendar day
orders['OrderDate'] = pd.to_datetime(orders['OrderDate'])

# Build one row per order line: order header -> line items -> product category.
# Both merges use the pandas default join, so only matching keys are kept.
order_header = orders[['OrderID', 'OrderDate', 'OrderStatus']]
line_items = order_items[['OrderID', 'ProductID', 'Quantity']]
product_categories = products[['ProductID', 'Category']]

df = order_header.merge(line_items, on='OrderID')
df = df.merge(product_categories, on='ProductID')

first_order_day = df['OrderDate'].min().date()
last_order_day = df['OrderDate'].max().date()

print(f"Total records: {len(df):,}")
print(f"Date range: {first_order_day} to {last_order_day}")
print(f"\nCategories available: {df['Category'].nunique()}")
print(df['Category'].value_counts())

In [None]:
# ============================================================================
# DATA CLEANING
# Business Logic: Remove cancelled orders - they represent demand that
# didn't actually occur and would skew our forecasts
# ============================================================================

print("Order Status Distribution:")
print(df['OrderStatus'].value_counts())

# Keep every order line whose status is anything other than 'Cancelled';
# the copy keeps df_clean independent of df.
keep_mask = df['OrderStatus'] != 'Cancelled'
df_clean = df.loc[keep_mask].copy()
print(f"\nRecords after removing cancelled: {len(df_clean):,}")

In [None]:
# ============================================================================
# AGGREGATE TO DAILY DEMAND BY CATEGORY
# Business Logic: Forecasting works best with regular time intervals.
# We aggregate to daily totals per category for consistent time series.
# ============================================================================

def prepare_category_data(df, category):
    """
    Prepare daily demand time series for a specific category.

    Business Logic:
    - Filter to the selected category
    - Sum quantities by calendar date (daily aggregation)
    - Fill missing dates with 0 (days with no orders)
    - Set proper datetime index for time series analysis

    The series spans the first to the last order date of *this* category,
    so different categories can yield series of different lengths.

    Args:
        df: DataFrame with OrderDate (datetime), Quantity, Category columns
        category: Category name to filter

    Returns:
        Series named 'Quantity' with daily demand, indexed by a gap-free
        daily DatetimeIndex named 'Date'

    Raises:
        ValueError: If the category has no rows with a usable order date
    """
    # Only the two columns needed for aggregation; df itself is not modified
    cat_df = df.loc[df['Category'] == category, ['OrderDate', 'Quantity']]

    # Aggregate to daily demand (time-of-day is discarded by grouping on the date)
    daily = cat_df.groupby(cat_df['OrderDate'].dt.date)['Quantity'].sum()
    if daily.empty:
        # Without this guard pd.date_range fails below with an opaque NaT error
        raise ValueError(
            f"No dated order rows found for category {category!r}; "
            "cannot build a daily demand series"
        )
    daily.index = pd.to_datetime(daily.index)

    # Create complete date range and fill missing days with 0
    # Business Logic: Days with no orders still represent 0 demand
    date_range = pd.date_range(start=daily.index.min(), end=daily.index.max(), freq='D')
    daily = daily.reindex(date_range, fill_value=0)
    daily.index.name = 'Date'
    daily.name = 'Quantity'

    return daily

# Every category that still has order lines after cleaning
categories = df_clean['Category'].unique()
print(f"Categories: {list(categories)}\n")

# Worked example: the category with the largest total ordered quantity
category_totals = df_clean.groupby('Category')['Quantity'].sum()
top_category = category_totals.idxmax()
sample_data = prepare_category_data(df_clean, top_category)

sample_start = sample_data.index.min().date()
sample_end = sample_data.index.max().date()
print(f"Sample data for '{top_category}':")
print(f"  Date range: {sample_start} to {sample_end}")
print(f"  Total days: {len(sample_data)}")
print(f"  Mean daily demand: {sample_data.mean():.2f}")
print(f"  Std daily demand: {sample_data.std():.2f}")

## 2. Time Series Analysis Functions

In [None]:
# ============================================================================
# STATIONARITY CHECK
# Business Logic: SARIMA requires stationary data (constant mean/variance).
# If non-stationary, we need differencing (d parameter in SARIMA).
# ============================================================================

def check_stationarity(series, name='Series'):
    """
    Run the Augmented Dickey-Fuller test and print a short report.

    Business Logic:
    - Stationary data keeps the same statistical properties over time
    - Non-stationary data has to be differenced before modeling
    - A p-value below 0.05 is treated as evidence of stationarity

    Args:
        series: Time series data (missing values are dropped before testing)
        name: Label used in the printed report

    Returns:
        True-like boolean when the p-value is below 0.05, otherwise False-like
    """
    adf_result = adfuller(series.dropna(), autolag='AIC')
    adf_statistic = adf_result[0]
    p_value = adf_result[1]
    critical_values = adf_result[4]

    # The 5% significance level decides the verdict
    is_stationary = p_value < 0.05

    print(f"\nStationarity Test for {name}:")
    print(f"  ADF Statistic: {adf_statistic:.4f}")
    print(f"  p-value: {p_value:.6f}")
    print(f"  Critical Values:")
    for level, threshold in critical_values.items():
        print(f"    {level}: {threshold:.4f}")
    print(f"  Result: {'✓ STATIONARY' if is_stationary else '✗ NON-STATIONARY'}")

    return is_stationary

# Run the ADF check on the example category's daily series and keep the
# verdict in `is_stationary` (the report itself is printed by the function).
is_stationary = check_stationarity(sample_data, top_category)

In [None]:
# ============================================================================
# TRAIN-TEST SPLIT
# Business Logic: For time series, we must split chronologically.
# We use 80% for training and 20% for testing to evaluate real-world performance.
# ============================================================================

def train_test_split_ts(series, train_ratio=0.8):
    """
    Cut a time series into an earlier training part and a later test part.

    Business Logic:
    - Time series MUST be split chronologically (never shuffled)
    - The model learns from the past and is scored on the future
    - This mirrors how forecasts are used in practice

    Args:
        series: Time series data with a datetime index
        train_ratio: Share of observations used for training (default 80%)

    Returns:
        train, test series (train holds the first int(len * train_ratio) rows)
    """
    n_train = int(len(series) * train_ratio)

    # Positional cut: everything before n_train trains, the rest tests
    train, test = series.iloc[:n_train], series.iloc[n_train:]

    train_start, train_end = train.index.min().date(), train.index.max().date()
    test_start, test_end = test.index.min().date(), test.index.max().date()

    print(f"Train-Test Split:")
    print(f"  Training: {len(train)} days ({train_start} to {train_end})")
    print(f"  Testing: {len(test)} days ({test_start} to {test_end})")

    return train, test

# Chronological 80/20 split of the example category; `train` and `test`
# are module-level names holding the two resulting series.
train, test = train_test_split_ts(sample_data, train_ratio=0.8)

## 3. SARIMA Model Building

In [None]:
# ============================================================================
# SARIMA MODEL CLASS
# Business Logic: SARIMA captures both non-seasonal and seasonal patterns.
# For retail demand:
#   - Weekly seasonality (s=7): Weekend vs weekday patterns
#   - Monthly patterns captured through longer seasonal periods
# ============================================================================

class SARIMAForecaster:
    """
    Production-ready SARIMA forecasting model for demand prediction.

    SARIMA(p, d, q)(P, D, Q, s) parameters:
    - p: Autoregressive order (how many past values influence current)
    - d: Differencing order (how many times to difference for stationarity)
    - q: Moving average order (how many past errors influence current)
    - P: Seasonal autoregressive order
    - D: Seasonal differencing order
    - Q: Seasonal moving average order
    - s: Seasonal period (7 for weekly, 30 for monthly)

    Typical use: optionally call find_best_parameters(), then fit(), then
    forecast() / evaluate() / check_residuals().
    """

    def __init__(self, seasonal_period=7):
        """
        Initialize forecaster.

        Args:
            seasonal_period: Period for seasonality (default 7 for weekly)
        """
        self.seasonal_period = seasonal_period
        self.model = None           # unfitted SARIMAX instance from the last fit()
        self.fitted_model = None    # results object from the last fit()
        self.order = None           # (p, d, q)
        self.seasonal_order = None  # (P, D, Q, s)

    def find_best_parameters(self, train_data, p_range=(0, 2), d_range=(0, 2),
                             q_range=(0, 2), P_range=(0, 1), D_range=(0, 1), Q_range=(0, 1)):
        """
        Find optimal SARIMA parameters using AIC criterion.

        Business Logic:
        - AIC (Akaike Information Criterion) balances model fit vs complexity
        - Lower AIC indicates better model
        - Every combination inside the given inclusive ranges is fitted

        Args:
            train_data: Training time series
            p_range, d_range, q_range: Inclusive (min, max) non-seasonal ranges
            P_range, D_range, Q_range: Inclusive (min, max) seasonal ranges

        Returns:
            Best (order, seasonal_order) tuple; also stored on the instance

        Raises:
            ValueError: If no parameter combination could be fitted
        """
        from itertools import product

        def inclusive(bounds):
            # (lo, hi) -> lo, lo + 1, ..., hi
            return range(bounds[0], bounds[1] + 1)

        best_aic = float('inf')
        best_params = None
        failed_fits = 0
        last_error = None

        print("Searching for best SARIMA parameters...")

        # product() iterates in the same order as six nested loops (p outermost)
        grid = product(
            inclusive(p_range), inclusive(d_range), inclusive(q_range),
            inclusive(P_range), inclusive(D_range), inclusive(Q_range),
        )
        for p, d, q, P, D, Q in grid:
            order = (p, d, q)
            seasonal_order = (P, D, Q, self.seasonal_period)

            try:
                candidate = SARIMAX(
                    train_data,
                    order=order,
                    seasonal_order=seasonal_order,
                    enforce_stationarity=False,
                    enforce_invertibility=False
                )
                aic = candidate.fit(disp=False, maxiter=100).aic
            except Exception as exc:
                # Best effort: a single failing combination must not abort the
                # search, but the failure is counted and remembered.
                failed_fits += 1
                last_error = exc
                continue

            # Strict '<' keeps the first combination on ties
            if aic < best_aic:
                best_aic = aic
                best_params = (order, seasonal_order)

        if failed_fits:
            print(f"  Skipped {failed_fits} parameter combinations that failed to fit")

        if best_params is None:
            raise ValueError(
                "No SARIMA parameter combination could be fitted "
                f"({failed_fits} candidate fits raised an error)"
            ) from last_error

        print(f"\nBest parameters found:")
        print(f"  Order: {best_params[0]}")
        print(f"  Seasonal Order: {best_params[1]}")
        print(f"  AIC: {best_aic:.2f}")

        self.order = best_params[0]
        self.seasonal_order = best_params[1]

        return best_params

    def fit(self, train_data, order=None, seasonal_order=None):
        """
        Fit SARIMA model to training data.

        Args:
            train_data: Training time series
            order: (p, d, q) tuple or None to use the stored value
            seasonal_order: (P, D, Q, s) tuple or None to use the stored value

        Returns:
            Fitted model (also stored as self.fitted_model)

        Raises:
            ValueError: If no order / seasonal order is available
        """
        if order is not None:
            self.order = order
        if seasonal_order is not None:
            self.seasonal_order = seasonal_order

        if self.order is None or self.seasonal_order is None:
            raise ValueError("Must set parameters via find_best_parameters() or provide them directly")

        print(f"Fitting SARIMA{self.order}x{self.seasonal_order}...")

        self.model = SARIMAX(
            train_data,
            order=self.order,
            seasonal_order=self.seasonal_order,
            enforce_stationarity=False,
            enforce_invertibility=False
        )
        self.fitted_model = self.model.fit(disp=False, maxiter=200)

        print(f"Model fitted successfully")
        print(f"  AIC: {self.fitted_model.aic:.2f}")
        print(f"  BIC: {self.fitted_model.bic:.2f}")

        return self.fitted_model

    def forecast(self, steps, return_conf_int=True, alpha=0.05):
        """
        Generate forecasts for future periods.

        Business Logic:
        - Forecasts include confidence intervals for uncertainty quantification
        - 95% confidence interval (alpha=0.05) is standard for planning
        - Demand cannot be negative, so all values are floored at 0

        Args:
            steps: Number of periods to forecast
            return_conf_int: Whether to include Lower_CI / Upper_CI columns
            alpha: Significance level for confidence intervals

        Returns:
            DataFrame indexed like the model's forecast with a Forecast column
            and, when return_conf_int is True, Lower_CI and Upper_CI columns

        Raises:
            ValueError: If called before fit()
        """
        if self.fitted_model is None:
            raise ValueError("Must fit model before forecasting")

        forecast_result = self.fitted_model.get_forecast(steps=steps)
        forecast_mean = forecast_result.predicted_mean

        forecast_df = pd.DataFrame(
            {'Forecast': forecast_mean.values},
            index=forecast_mean.index
        )

        if return_conf_int:
            conf_int = forecast_result.conf_int(alpha=alpha)
            forecast_df['Lower_CI'] = conf_int.iloc[:, 0].values
            forecast_df['Upper_CI'] = conf_int.iloc[:, 1].values

        # Business Logic: Demand cannot be negative. Flooring every column
        # (including Upper_CI) keeps Lower_CI <= Upper_CI after clipping.
        return forecast_df.clip(lower=0)

    def evaluate(self, test_data):
        """
        Evaluate model performance on test data.

        Forecasts len(test_data) steps ahead from the end of the data the
        model was last fitted on and compares them with test_data by position.

        Business Logic:
        - MAE: Average absolute error (interpretable in units)
        - RMSE: Penalizes large errors more heavily
        - MAPE: Percentage error (scale-independent); days whose actual
          demand is 0 are left out, and MAPE is NaN if every day is 0

        Args:
            test_data: Actual test series

        Returns:
            (metrics, forecast_df) tuple: metrics is a dict with keys
            'MAE', 'RMSE', 'MAPE'; forecast_df is the output of forecast()
        """
        forecast_df = self.forecast(steps=len(test_data))
        predictions = forecast_df['Forecast'].values
        actual = test_data.values

        mae = mean_absolute_error(actual, predictions)
        rmse = np.sqrt(mean_squared_error(actual, predictions))

        # MAPE is undefined where the actual value is 0, so skip those days
        nonzero = actual != 0
        if nonzero.any():
            mape = np.mean(np.abs((actual[nonzero] - predictions[nonzero]) / actual[nonzero])) * 100
        else:
            mape = np.nan

        metrics = {
            'MAE': mae,
            'RMSE': rmse,
            'MAPE': mape
        }

        return metrics, forecast_df

    def check_residuals(self):
        """
        Perform residual diagnostics.

        Business Logic:
        - Good model should have random residuals (no patterns)
        - Residuals should be normally distributed
        - No autocorrelation in residuals

        Returns:
            (diagnostics, residuals) tuple: diagnostics is a dict with the
            residual mean/std, the Ljung-Box p-value at lag 10, the
            Jarque-Bera p-value and the two pass flags (p-value > 0.05);
            residuals are the fitted model's in-sample residuals

        Raises:
            ValueError: If called before fit()
        """
        if self.fitted_model is None:
            raise ValueError("Must fit model first")

        # NOTE(review): with differencing, the first few residuals may be
        # start-up artefacts that distort both tests - consider trimming them.
        residuals = self.fitted_model.resid

        # Ljung-Box test for autocorrelation
        lb_test = acorr_ljungbox(residuals, lags=[10], return_df=True)
        lb_pvalue = lb_test['lb_pvalue'].values[0]

        # Normality test
        jb_stat, jb_pvalue = stats.jarque_bera(residuals)

        diagnostics = {
            'residual_mean': residuals.mean(),
            'residual_std': residuals.std(),
            'ljung_box_pvalue': lb_pvalue,
            'no_autocorrelation': lb_pvalue > 0.05,
            'jarque_bera_pvalue': jb_pvalue,
            'normality': jb_pvalue > 0.05
        }

        return diagnostics, residuals

## 4. Visualization Functions

In [None]:
# ============================================================================
# VISUALIZATION FUNCTIONS
# Business Logic: Visualizations help stakeholders understand forecasts
# and build confidence in the model
# ============================================================================

def plot_actual_vs_predicted(actual, predicted, title='Actual vs Predicted'):
    """
    Draw the actual and the predicted series on one chart.

    Business Logic:
    - Makes it visible how closely the model follows real demand
    - Highlights periods of systematic under- or over-prediction

    Args:
        actual: Series of observed values (index on the x-axis)
        predicted: Series of predicted values (index on the x-axis)
        title: Chart title

    Returns:
        The matplotlib Figure
    """
    fig, ax = plt.subplots(figsize=(14, 6))

    line_width = {'linewidth': 2}
    ax.plot(actual.index, actual.values, label='Actual', color='blue', **line_width)
    ax.plot(predicted.index, predicted.values, label='Predicted',
            color='red', linestyle='--', **line_width)

    ax.set(xlabel='Date', ylabel='Quantity', title=title)
    ax.legend()
    ax.grid(True, alpha=0.3)

    # fig is the figure just created above, i.e. the current figure
    fig.tight_layout()
    return fig

def plot_forecast_with_ci(train, test, forecast_df, title='Demand Forecast'):
    """
    Draw history, optional actuals and the forecast with its interval band.

    Business Logic:
    - The shaded band shows forecast uncertainty
    - A wider band means more uncertainty
    - Supports safety stock planning (upper bound as buffer)

    Args:
        train: Historical series; only its last 90 points are drawn
        test: Actual series for comparison, or None to skip it
        forecast_df: DataFrame with Forecast, Lower_CI and Upper_CI columns
        title: Chart title

    Returns:
        The matplotlib Figure
    """
    fig, ax = plt.subplots(figsize=(14, 7))

    # Limit the history to the most recent 90 observations for readability
    recent_train = train if len(train) <= 90 else train[-90:]
    ax.plot(recent_train.index, recent_train.values, label='Historical',
            linewidth=2, color='blue')

    # Actuals are optional
    if test is not None:
        ax.plot(test.index, test.values, label='Actual',
                linewidth=2, color='green')

    horizon = forecast_df.index
    ax.plot(horizon, forecast_df['Forecast'],
            label='Forecast', linewidth=2, color='red')

    # Shaded band between the lower and upper interval bounds
    ax.fill_between(horizon,
                    forecast_df['Lower_CI'],
                    forecast_df['Upper_CI'],
                    color='red', alpha=0.2, label='95% CI')

    # Dotted marker at the first forecast date
    ax.axvline(x=horizon[0], color='gray', linestyle=':', linewidth=1)

    ax.set(xlabel='Date', ylabel='Quantity', title=title)
    ax.legend(loc='upper left')
    ax.grid(True, alpha=0.3)

    # fig is the figure just created above, i.e. the current figure
    fig.tight_layout()
    return fig

def plot_residual_diagnostics(residuals, title='Residual Diagnostics'):
    """
    Draw a 2x2 panel of residual diagnostics.

    Panels: residuals over time, histogram with a fitted normal curve,
    normal Q-Q plot, and the autocorrelation function up to lag 30.

    Business Logic:
    - Residuals should resemble random noise
    - Visible structure means the model is missing something

    Args:
        residuals: Series of model residuals
        title: Overall figure title

    Returns:
        The matplotlib Figure
    """
    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    (ax_time, ax_hist), (ax_qq, ax_acf) = axes

    # Residuals in time order with a zero reference line
    ax_time.plot(residuals)
    ax_time.axhline(y=0, color='r', linestyle='--')
    ax_time.set_title('Residuals Over Time')
    ax_time.set_xlabel('Time')
    ax_time.set_ylabel('Residual')

    # Density histogram overlaid with a normal pdf using the sample mean/std
    ax_hist.hist(residuals, bins=30, density=True, alpha=0.7, edgecolor='black')
    lower, upper = ax_hist.get_xlim()
    grid = np.linspace(lower, upper, 100)
    normal_pdf = stats.norm.pdf(grid, residuals.mean(), residuals.std())
    ax_hist.plot(grid, normal_pdf, 'r-', linewidth=2)
    ax_hist.set_title('Residual Distribution')
    ax_hist.set_xlabel('Residual')

    # Q-Q plot against the normal distribution
    stats.probplot(residuals, dist="norm", plot=ax_qq)
    ax_qq.set_title('Q-Q Plot')

    # Autocorrelation of the residuals
    plot_acf(residuals, ax=ax_acf, lags=30)
    ax_acf.set_title('ACF of Residuals')

    plt.suptitle(title, fontsize=14, y=1.02)
    plt.tight_layout()
    return fig

## 5. Run Forecasting for a Category

In [None]:
# ============================================================================
# FULL FORECASTING PIPELINE FOR ONE CATEGORY
# ============================================================================

def run_forecast_pipeline(df, category, forecast_horizons=(30, 60, 90),
                          order=(1, 1, 1), seasonal_order=(1, 1, 1, 7)):
    """
    Complete forecasting pipeline for a category.

    Steps: build the daily series, report stationarity, split 80/20, fit on
    the training part, score on the test part, run residual diagnostics on
    that training fit, then refit on the full series and forecast ahead.

    Business Logic:
    - 30 days: Short-term operational planning
    - 60 days: Medium-term inventory planning
    - 90 days: Long-term strategic planning

    Args:
        df: Clean DataFrame with OrderDate, Quantity, Category
        category: Category to forecast
        forecast_horizons: Iterable of forecast lengths in days
        order: Non-seasonal (p, d, q) order used for both fits
        seasonal_order: Seasonal (P, D, Q, s) order used for both fits

    Returns:
        Dictionary with keys 'category', 'train', 'test', 'test_forecast',
        'metrics', 'diagnostics', 'residuals', 'forecasts' and 'model'.
        'metrics', 'diagnostics' and 'residuals' come from the training fit;
        'forecasts' maps '<horizon>_day' to a forecast DataFrame produced by
        'model', which is the forecaster refit on the full series.
    """
    print(f"\n{'='*70}")
    print(f"FORECASTING: {category}")
    print(f"{'='*70}")

    # Step 1: Prepare data
    print("\n[Step 1] Preparing data...")
    data = prepare_category_data(df, category)

    # Step 2: Check stationarity (informational only; the printed report is
    # the output, the verdict does not change the model configuration)
    print("\n[Step 2] Checking stationarity...")
    check_stationarity(data, category)

    # Step 3: Train-test split (80/20)
    print("\n[Step 3] Splitting data...")
    train, test = train_test_split_ts(data, train_ratio=0.8)

    # Step 4: Initialize and fit model
    print("\n[Step 4] Building SARIMA model...")
    # The seasonal period is the last element of the seasonal order
    forecaster = SARIMAForecaster(seasonal_period=seasonal_order[3])

    # Fixed orders keep the run fast; forecaster.find_best_parameters(train)
    # can be used instead to search for them.
    forecaster.fit(train, order=order, seasonal_order=seasonal_order)

    # Step 5: Evaluate on test set
    print("\n[Step 5] Evaluating model...")
    metrics, test_forecast = forecaster.evaluate(test)

    print(f"\nTest Set Performance:")
    print(f"  MAE:  {metrics['MAE']:.2f} units")
    print(f"  RMSE: {metrics['RMSE']:.2f} units")
    print(f"  MAPE: {metrics['MAPE']:.2f}%")

    # Step 6: Residual diagnostics (must run before the refit below so that
    # they describe the model that was actually evaluated)
    print("\n[Step 6] Checking residuals...")
    diagnostics, residuals = forecaster.check_residuals()

    print(f"  Residual Mean: {diagnostics['residual_mean']:.4f} (should be ~0)")
    print(f"  No Autocorrelation: {'✓ PASS' if diagnostics['no_autocorrelation'] else '✗ FAIL'}")
    print(f"  Normality: {'✓ PASS' if diagnostics['normality'] else '✗ FAIL (acceptable)'}")

    # Step 7: Generate future forecasts
    print("\n[Step 7] Generating forecasts...")

    # Refit on full data for production forecasts
    forecaster.fit(data, order=order, seasonal_order=seasonal_order)

    forecasts = {}
    for horizon in forecast_horizons:
        forecast_df = forecaster.forecast(steps=horizon)
        forecasts[f'{horizon}_day'] = forecast_df
        total_forecast = forecast_df['Forecast'].sum()
        print(f"  {horizon}-day forecast: {total_forecast:.0f} total units")

    return {
        'category': category,
        'train': train,
        'test': test,
        'test_forecast': test_forecast,
        'metrics': metrics,
        'diagnostics': diagnostics,
        'residuals': residuals,
        'forecasts': forecasts,
        'model': forecaster
    }

# Run the full pipeline for the highest-volume category; `results` is the
# dictionary consumed by the visualisation cell that follows.
results = run_forecast_pipeline(df_clean, top_category)

In [None]:
# ============================================================================
# VISUALIZATIONS
# ============================================================================

category_label = results['category']

# 1. Actual vs Predicted on the test period
test_actual = results['test']
test_pred = results['test_forecast']['Forecast']
test_pred.index = test_actual.index  # Give predictions the test dates

actual_vs_pred_title = f"Actual vs Predicted - {category_label} (Test Period)"
fig1 = plot_actual_vs_predicted(test_actual, test_pred, title=actual_vs_pred_title)
plt.show()

# 2. 90-day forecast with its confidence band, next to train and test history
forecast_title = f"90-Day Demand Forecast - {category_label}"
fig2 = plot_forecast_with_ci(
    results['train'],
    results['test'],
    results['forecasts']['90_day'],
    title=forecast_title,
)
plt.show()

# 3. Residual diagnostics
diagnostics_title = f"Residual Diagnostics - {category_label}"
fig3 = plot_residual_diagnostics(results['residuals'], title=diagnostics_title)
plt.show()

## 6. Forecast All Categories

In [None]:
# ============================================================================
# RUN FORECASTING FOR ALL CATEGORIES
# Business Logic: Generate forecasts for entire product portfolio
# ============================================================================

def forecast_all_categories(df, categories=None):
    """
    Run forecasting pipeline for multiple categories.

    A category whose pipeline raises is reported on stdout and skipped, so
    one bad category does not stop the whole run.

    Args:
        df: Clean DataFrame
        categories: List of categories (None = all values of df['Category'])

    Returns:
        (all_results, summary_df) tuple: all_results maps each successfully
        forecast category to its pipeline result dictionary; summary_df has
        one row per such category with the columns Category, MAE, RMSE, MAPE,
        30_Day_Forecast, 60_Day_Forecast and 90_Day_Forecast. The columns are
        present even when no category succeeded.
    """
    summary_columns = [
        'Category', 'MAE', 'RMSE', 'MAPE',
        '30_Day_Forecast', '60_Day_Forecast', '90_Day_Forecast',
    ]

    if categories is None:
        categories = df['Category'].unique()

    all_results = {}
    summary_data = []

    for category in categories:
        try:
            category_results = run_forecast_pipeline(df, category)
            metrics = category_results['metrics']
            forecasts = category_results['forecasts']

            summary_row = {
                'Category': category,
                'MAE': metrics['MAE'],
                'RMSE': metrics['RMSE'],
                'MAPE': metrics['MAPE'],
                '30_Day_Forecast': forecasts['30_day']['Forecast'].sum(),
                '60_Day_Forecast': forecasts['60_day']['Forecast'].sum(),
                '90_Day_Forecast': forecasts['90_day']['Forecast'].sum()
            }
        except Exception as e:
            # Deliberate best effort: report the failure and move on
            print(f"\nError forecasting {category}: {str(e)}")
            continue

        # Register the category only once its summary row was built completely
        all_results[category] = category_results
        summary_data.append(summary_row)

    # Explicit columns keep the frame usable downstream when it has no rows
    summary_df = pd.DataFrame(summary_data, columns=summary_columns)
    return all_results, summary_df

# Forecast every category present in the cleaned data
all_results, summary_df = forecast_all_categories(df_clean)

# Per-category summary table under a banner
banner = "=" * 70
print("\n" + banner)
print("FORECAST SUMMARY - ALL CATEGORIES")
print(banner)
print(summary_df.to_string(index=False))

In [None]:
# ============================================================================
# SUMMARY VISUALIZATIONS
# ============================================================================

fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# 1. MAPE by Category (most accurate first), with a 20% reference line
summary_sorted = summary_df.sort_values('MAPE')
axes[0, 0].barh(summary_sorted['Category'], summary_sorted['MAPE'])
axes[0, 0].set_xlabel('MAPE (%)')
axes[0, 0].set_title('Forecast Accuracy by Category (MAPE)')
axes[0, 0].axvline(x=20, color='r', linestyle='--', label='20% threshold')
# The label on the reference line is only drawn once a legend is requested
axes[0, 0].legend()

# 2. 90-Day Forecast by Category
summary_sorted = summary_df.sort_values('90_Day_Forecast', ascending=True)
axes[0, 1].barh(summary_sorted['Category'], summary_sorted['90_Day_Forecast'])
axes[0, 1].set_xlabel('Total Units')
axes[0, 1].set_title('90-Day Demand Forecast by Category')

# 3. MAE vs RMSE comparison (two bars side by side per category)
x = np.arange(len(summary_df))
width = 0.35
axes[1, 0].bar(x - width/2, summary_df['MAE'], width, label='MAE')
axes[1, 0].bar(x + width/2, summary_df['RMSE'], width, label='RMSE')
axes[1, 0].set_xticks(x)
axes[1, 0].set_xticklabels(summary_df['Category'], rotation=45, ha='right')
axes[1, 0].set_ylabel('Error (Units)')
axes[1, 0].set_title('MAE vs RMSE by Category')
axes[1, 0].legend()

# 4. Forecast horizons comparison
summary_df.plot(x='Category', y=['30_Day_Forecast', '60_Day_Forecast', '90_Day_Forecast'],
               kind='bar', ax=axes[1, 1])
axes[1, 1].set_title('Forecast by Horizon')
axes[1, 1].set_ylabel('Total Units')
axes[1, 1].tick_params(axis='x', rotation=45)
axes[1, 1].legend(['30-Day', '60-Day', '90-Day'])

plt.tight_layout()
plt.show()

## 7. Export Results

In [None]:
# ============================================================================
# EXPORT FORECAST RESULTS
# Business Logic: Save forecasts for use in inventory planning systems
# ============================================================================

# Save the per-category summary table
summary_path = DATA_DIR / 'forecast_summary.csv'
summary_df.to_csv(summary_path, index=False)
print(f"Saved forecast summary to: {summary_path}")

# Save the detailed 90-day forecast of each category to its own file.
# The loop variables carry an `export_` prefix on purpose: reusing the names
# `category` / `results` would overwrite the single-category `results`
# dictionary that the earlier visualisation cell reads.
for export_category, export_results in all_results.items():
    # File-name friendly form of the category name
    export_stem = export_category.replace(' ', '_').replace('&', 'and')

    # 90-day forecast with confidence intervals, tagged with its category;
    # the copy leaves the DataFrame stored in all_results untouched
    export_frame = export_results['forecasts']['90_day'].copy()
    export_frame['Category'] = export_category
    export_frame.to_csv(DATA_DIR / f'forecast_90day_{export_stem}.csv')

print(f"\nDetailed forecasts saved for {len(all_results)} categories")

In [None]:
# ============================================================================
# FINAL SUMMARY
# ============================================================================

# Closing report: a banner followed by one multi-line f-string.
print("="*70)
print("SARIMA DEMAND FORECASTING - FINAL SUMMARY")
print("="*70)

# NOTE: the MODEL SPECIFICATION, FILES GENERATED and BUSINESS RECOMMENDATIONS
# sections are static text; they are not read from the fitted models.
# The performance figures are plain means over the rows of summary_df, the
# best category is the row with the smallest MAPE, and the 90-day total is
# the sum of the 90_Day_Forecast column.
print(f"""
MODEL SPECIFICATION:
  Type: SARIMA(1,1,1)(1,1,1,7)
  Seasonal Period: 7 (weekly)
  Train/Test Split: 80/20 (time-based)

OVERALL PERFORMANCE:
  Average MAPE: {summary_df['MAPE'].mean():.2f}%
  Average MAE: {summary_df['MAE'].mean():.2f} units
  Average RMSE: {summary_df['RMSE'].mean():.2f} units

BEST PERFORMING CATEGORY:
  {summary_df.loc[summary_df['MAPE'].idxmin(), 'Category']} (MAPE: {summary_df['MAPE'].min():.2f}%)

90-DAY TOTAL FORECAST:
  {summary_df['90_Day_Forecast'].sum():.0f} units across all categories

FILES GENERATED:
  - forecast_summary.csv (all categories)
  - forecast_90day_[category].csv (detailed forecasts)

BUSINESS RECOMMENDATIONS:
  1. Use forecast + Upper_CI for safety stock planning
  2. Categories with MAPE > 25% need additional review
  3. Re-run model monthly to incorporate new data
  4. Monitor forecast accuracy and adjust as needed
""")