In [None]:
import pandas as pd
from statsmodels.tsa.statespace.sarimax import SARIMAX
from sklearn.metrics import mean_squared_error
import numpy as np
import matplotlib.pyplot as plt
from pmdarima import auto_arima


In [None]:
# Prepare the modelling frame and the chronological train/test split.
# NOTE(review): `df` is not created in any visible cell; this cell assumes an
# earlier cell loaded it with 'date', 'demand', 'temperature', 'holiday' and
# 'production' columns — confirm, or add the load step above.

TRAIN_FRACTION = 0.8  # share of rows (oldest first) used for fitting

# A remaining 'date' column means the frame has not been prepared yet.
# Guarding on it keeps the cell re-runnable: a second run would otherwise
# raise KeyError on 'date' and lose another block of rows to the lag/dropna step.
if 'date' in df.columns:
    df = (
        df.assign(date=pd.to_datetime(df['date']))
          .set_index('date')
          .sort_index()  # lags and the split below rely on chronological order
    )

    # Lagged demand features: one row back and seven rows back.
    df = df.assign(
        demand_lag1=df['demand'].shift(1),
        demand_lag7=df['demand'].shift(7),
    )

    # Remove the lag warm-up rows. Note that dropna() without arguments also
    # removes any other row that has a missing value in any column.
    df = df.dropna()

# Define target and exogenous variables
y = df['demand']
exog = df[['temperature', 'holiday', 'production', 'demand_lag1', 'demand_lag7']]

# Train-test split by position: the first TRAIN_FRACTION of rows is used for
# training, the remainder for testing.
split_idx = int(len(df) * TRAIN_FRACTION)
if split_idx == 0 or split_idx == len(df):
    raise ValueError(
        f"Cannot split {len(df)} rows into non-empty train and test sets."
    )
y_train, y_test = y.iloc[:split_idx], y.iloc[split_idx:]
exog_train, exog_test = exog.iloc[:split_idx], exog.iloc[split_idx:]

# Automatically find the best SARIMA order using auto_arima's stepwise search.
# The regressors are passed as `X`: pmdarima renamed the `exogenous` keyword
# to `X` (deprecated in 1.8, removed in 2.0), so on current releases the old
# spelling no longer feeds the regressors into the search.
stepwise_model = auto_arima(y_train,
                            X=exog_train,
                            seasonal=True,
                            m=7,  # seasonal period of 7 observations (weekly on daily data)
                            stepwise=True,
                            suppress_warnings=True,
                            error_action="ignore",  # skip candidate orders that fail to fit
                            trace=True)  # print each candidate as it is evaluated

# Extract the selected non-seasonal and seasonal orders for the SARIMAX fit
order = stepwise_model.order
seasonal_order = stepwise_model.seasonal_order

# Build the SARIMAX model on the training window with the selected orders.
# Stationarity and invertibility constraints are both switched off.
model = SARIMAX(
    endog=y_train,
    exog=exog_train,
    order=order,
    seasonal_order=seasonal_order,
    enforce_stationarity=False,
    enforce_invertibility=False,
)

# Estimate the parameters; disp=False silences the optimizer's console output.
results = model.fit(disp=False)

# Forecast the hold-out period.
# The horizon is given as a step count rather than start/end date labels:
# label-based prediction past the end of the sample needs a date index with a
# known frequency, which is not guaranteed once rows have been dropped.
forecast = results.forecast(steps=len(y_test), exog=exog_test)
# Re-attach the test dates so the forecast lines up with y_test when scoring
# and plotting, whatever index the model generated for the new steps.
forecast.index = y_test.index

# NOTE(review): exog_test includes demand_lag1/demand_lag7, i.e. observed
# demand from inside the test period, so this score is for forecasts that can
# see recent actuals — confirm that is the intended evaluation.

# Evaluation
mse = mean_squared_error(y_test, forecast)
print(f"Test MSE: {mse:.2f}")



In [None]:
# Plot the tail of the training history, the held-out test period and the
# forecast on one set of axes.
HISTORY_POINTS = 60  # trailing training observations shown for context

fig, ax = plt.subplots(figsize=(14, 6))
ax.plot(y_train.iloc[-HISTORY_POINTS:], label='Train')
ax.plot(y_test, label='Test')
ax.plot(forecast, label='Forecast', linestyle='--')
# Axis labels so the figure can be read on its own when the notebook is skimmed.
ax.set(title="Energy Demand Forecast", xlabel="Date", ylabel="Demand")
ax.legend()
plt.show()