In [1]:
!pip install pmdarima

Collecting pmdarima
  Downloading pmdarima-2.1.1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl.metadata (8.5 kB)
Downloading pmdarima-2.1.1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl (689 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m689.1/689.1 kB[0m [31m10.6 MB/s[0m eta [36m0:00:00[0m [36m0:00:01[0m
[?25hInstalling collected packages: pmdarima
Successfully installed pmdarima-2.1.1


In [2]:
import os

import pandas as pd
import numpy as np

import seaborn as sns
import altair as alt
import matplotlib.pyplot as plt

import warnings
warnings.filterwarnings('ignore')

%matplotlib inline
sns.set_style("whitegrid")

### Utility functions

In [3]:
def load_smard_csv(file_name):
    """
    Loads and cleans SMARD-specific CSV formatting.
    """
    path = os.path.join(DATA_DIR, file_name)
    df = pd.read_csv(path, sep=';', decimal=',', thousands='.')
    
    # Convert timestamps and set index
    df['timestamp'] = pd.to_datetime(df['Datum von'], format='%d.%m.%Y %H:%M')
    df = df.set_index('timestamp').drop(['Datum von', 'Datum bis'], axis=1)
    
    # Clean numeric data (replace SMARD '-' with NaN)
    df = df.replace('-', np.nan).apply(pd.to_numeric, errors='coerce')
    
    # Resample to Hourly (Standard for Day-Ahead Markets)
    return df.resample('h').mean()

def create_generalized_dataset():
    """
    Merges all files and creates a base feature set.
    """
    print("Loading datasets...")
    df_price = load_smard_csv(FILES['price'])
    df_forecast = load_smard_csv(FILES['forecast'])
    df_load = load_smard_csv(FILES['load'])
    df_gen = load_smard_csv(FILES['gen'])

    # Select relevant columns and rename for a clean API
    data = pd.DataFrame(index=df_price.index)
    data['target_price'] = df_price['Deutschland/Luxemburg [€/MWh] Originalauflösungen']
    
    # Forecasts (Known in advance - No shift needed for renewable forecasts)
    data['fc_solar'] = df_forecast['Photovoltaik [MWh] Originalauflösungen']
    data['fc_wind_on'] = df_forecast['Wind Onshore [MWh] Originalauflösungen']
    data['fc_wind_off'] = df_forecast['Wind Offshore [MWh] Originalauflösungen']
    
    # Total renewable forecast
    data['fc_renewables_total'] = data['fc_solar'] + data['fc_wind_on'] + data['fc_wind_off']
    
    # Demand and Conventional (Historical - Shifted by 24h to avoid leaking future data)
    data['load_lag_24h'] = df_load['Netzlast [MWh] Originalauflösungen'].shift(24)
    data['load_lag_168h'] = df_load['Netzlast [MWh] Originalauflösungen'].shift(168)
    data['gen_lignite_lag_24h'] = df_gen['Braunkohle [MWh] Originalauflösungen'].shift(24)
    data['gen_gas_lag_24h'] = df_gen['Erdgas [MWh] Originalauflösungen'].shift(24)
    
    # Target Lags (Crucial for Time Series)
    data['price_lag_24h'] = data['target_price'].shift(24)
    data['price_lag_168h'] = data['target_price'].shift(168)  # 1 week ago
    data['price_lag_48h'] = data['target_price'].shift(48)    # 2 days ago

    # FIXED: Net Load Forecast (using lagged load to avoid data leakage)
    # This estimates residual demand after renewables
    data['net_load_forecast'] = data['load_lag_24h'] - data['fc_renewables_total']
    
    # Rolling statistics (capture recent trends without leakage)
    data['price_rolling_mean_24h'] = data['target_price'].shift(1).rolling(24).mean()
    data['price_rolling_std_24h'] = data['target_price'].shift(1).rolling(24).std()
    data['price_rolling_mean_168h'] = data['target_price'].shift(1).rolling(168).mean()
    
    # Calendar Features
    data['hour'] = data.index.hour
    data['day_of_week'] = data.index.dayofweek
    data['month'] = data.index.month
    data['is_weekend'] = data['day_of_week'].isin([5, 6]).astype(int)
    
    # Cyclical encoding for hour (helps capture daily patterns)
    data['hour_sin'] = np.sin(2 * np.pi * data['hour'] / 24)
    data['hour_cos'] = np.cos(2 * np.pi * data['hour'] / 24)
    
    # Cyclical encoding for month (helps capture seasonal patterns)
    data['month_sin'] = np.sin(2 * np.pi * data['month'] / 12)
    data['month_cos'] = np.cos(2 * np.pi * data['month'] / 12)

    return data.dropna()

### Load Data

In [4]:
# Configuration
DATA_DIR = '/kaggle/input/electricppd/smard_data'
FILES = {
    'price': 'Gro_handelspreise_202101010000_202601010000_Viertelstunde.csv',
    'forecast': 'Prognostizierte_Erzeugung_Day-Ahead_202101010000_202601010000_Viertelstunde_Stunde.csv',
    'gen': 'Realisierte_Erzeugung_202101010000_202601010000_Viertelstunde.csv',
    'load': 'Realisierter_Stromverbrauch_202101010000_202601010000_Viertelstunde (1).csv'
}

In [5]:
# Create the master dataset
master_df = create_generalized_dataset()

Loading datasets...


### EDA

In [6]:
master_df.head(3)

Unnamed: 0_level_0,target_price,fc_solar,fc_wind_on,fc_wind_off,fc_renewables_total,load_lag_24h,load_lag_168h,gen_lignite_lag_24h,gen_gas_lag_24h,price_lag_24h,...,price_rolling_std_24h,price_rolling_mean_168h,hour,day_of_week,month,is_weekend,hour_sin,hour_cos,month_sin,month_cos
timestamp,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
2021-01-08 00:00:00,50.53,0.0,1606.5,263.125,1869.625,13187.4375,11142.3125,3644.125,2703.9375,51.03,...,17.600125,52.078333,0,4,1,0,0.0,1.0,0.5,0.866025
2021-01-08 01:00:00,48.43,0.0,1583.375,285.4375,1868.8125,12739.3125,10701.5,3555.6875,2713.75,50.18,...,17.624979,52.07631,1,4,1,0,0.258819,0.965926,0.5,0.866025
2021-01-08 02:00:00,47.24,0.0,1564.625,290.3125,1854.9375,12636.4375,10262.4375,3414.0,2675.875,48.97,...,17.71783,52.077738,2,4,1,0,0.5,0.866025,0.5,0.866025


In [7]:
master_df.tail(3)

Unnamed: 0_level_0,target_price,fc_solar,fc_wind_on,fc_wind_off,fc_renewables_total,load_lag_24h,load_lag_168h,gen_lignite_lag_24h,gen_gas_lag_24h,price_lag_24h,...,price_rolling_std_24h,price_rolling_mean_168h,hour,day_of_week,month,is_weekend,hour_sin,hour_cos,month_sin,month_cos
timestamp,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
2025-12-31 21:00:00,79.7975,0.0,6812.82,1339.775,8152.595,13252.2425,12349.735,2075.6475,2388.415,93.3225,...,8.973217,88.491563,21,2,12,0,-0.707107,0.707107,-2.449294e-16,1.0
2025-12-31 22:00:00,81.39,0.0,6957.8875,1347.2725,8305.16,12795.5425,12194.0025,2061.3225,2358.485,92.045,...,9.094905,88.56186,22,2,12,0,-0.5,0.866025,-2.449294e-16,1.0
2025-12-31 23:00:00,76.4475,0.0,7298.6075,1371.8875,8670.495,12253.7775,11676.19,1891.715,2161.3325,85.8075,...,9.156241,88.635565,23,2,12,0,-0.258819,0.965926,-2.449294e-16,1.0


In [8]:
print(f"Master dataframe from {master_df.index.min()} to {master_df.index.max()}")
print(f"Full Dataset Shape: {master_df.shape}")

Master dataframe from 2021-01-08 00:00:00 to 2025-12-31 23:00:00
Full Dataset Shape: (42811, 24)


In [9]:
master_df.info()

<class 'pandas.core.frame.DataFrame'>
DatetimeIndex: 42811 entries, 2021-01-08 00:00:00 to 2025-12-31 23:00:00
Data columns (total 24 columns):
 #   Column                   Non-Null Count  Dtype  
---  ------                   --------------  -----  
 0   target_price             42811 non-null  float64
 1   fc_solar                 42811 non-null  float64
 2   fc_wind_on               42811 non-null  float64
 3   fc_wind_off              42811 non-null  float64
 4   fc_renewables_total      42811 non-null  float64
 5   load_lag_24h             42811 non-null  float64
 6   load_lag_168h            42811 non-null  float64
 7   gen_lignite_lag_24h      42811 non-null  float64
 8   gen_gas_lag_24h          42811 non-null  float64
 9   price_lag_24h            42811 non-null  float64
 10  price_lag_168h           42811 non-null  float64
 11  price_lag_48h            42811 non-null  float64
 12  net_load_forecast        42811 non-null  float64
 13  price_rolling_mean_24h   42811 non-null  

In [10]:
master_df.describe()

Unnamed: 0,target_price,fc_solar,fc_wind_on,fc_wind_off,fc_renewables_total,load_lag_24h,load_lag_168h,gen_lignite_lag_24h,gen_gas_lag_24h,price_lag_24h,...,price_rolling_std_24h,price_rolling_mean_168h,hour,day_of_week,month,is_weekend,hour_sin,hour_cos,month_sin,month_cos
count,42811.0,42811.0,42811.0,42811.0,42811.0,42811.0,42811.0,42811.0,42811.0,42811.0,...,42811.0,42811.0,42811.0,42811.0,42811.0,42811.0,42811.0,42811.0,42811.0,42811.0
mean,119.821928,1687.115563,2986.904932,703.635016,5377.655511,13558.031436,13556.685526,2387.203196,1565.786322,119.802584,...,36.525813,119.691216,11.50111,2.99965,6.604891,0.285791,-5.8e-05,-0.0001011452,-0.02539619,-0.0006053425
std,101.186744,2580.089898,2400.599259,468.947243,3325.418656,2380.432196,2384.221533,886.225228,864.651226,101.179625,...,24.624905,84.113413,6.92191,2.000438,3.444185,0.451796,0.707136,0.7070944,0.7007044,0.7130157
min,-500.0,0.0,40.3125,3.25,80.125,7725.6875,7725.6875,469.4375,315.875,-500.0,...,1.122431,17.970655,0.0,0.0,1.0,0.0,-1.0,-1.0,-1.0,-1.0
25%,67.035,0.0,1118.375,269.6875,2544.46875,11607.1875,11603.2725,1655.46875,881.0625,67.005,...,19.320472,71.800268,6.0,1.0,4.0,0.0,-0.707107,-0.7071068,-0.8660254,-0.8660254
50%,95.8,64.3125,2252.125,652.3125,4910.9375,13553.5,13553.7,2534.375,1364.0625,95.8,...,30.885292,91.396071,12.0,3.0,7.0,0.0,0.0,-1.83697e-16,-2.449294e-16,-1.83697e-16
75%,138.15,2726.0,4279.09375,1129.34375,7728.09375,15363.125,15361.71875,3093.34375,2073.875,138.13,...,46.931055,132.389315,18.0,5.0,10.0,1.0,0.707107,0.7071068,0.5,0.8660254
max,936.28,12611.1875,11654.3125,1708.625,18747.5,20329.875,20329.875,4293.3125,5036.625,936.28,...,262.937902,608.093095,23.0,6.0,12.0,1.0,1.0,1.0,1.0,1.0


In [11]:
# Save the processed data
master_df.to_csv('master_electricity_data.csv')

#### *Price vs Net Load Correlation*

In [12]:
# The data is sampled to 5000 points to keep the interactive output fast
sample_df = master_df.sample(n=5000).reset_index()

chart = alt.Chart(sample_df).mark_circle(opacity=0.3, size=20, color='teal').encode(
    x=alt.X('net_load_forecast:Q', title='Net Load (Demand - Renewables) [MWh]'),
    y=alt.Y('target_price:Q', title='Price (€/MWh)'),
    tooltip=['timestamp:T', 'target_price:Q', 'net_load_forecast:Q']
).properties(
    title='Interactive Merit Order Effect',
    width=600,
    height=400
).interactive()

chart.save('merit_order_interactive.json')

#### *Seasonality: Hourly Price Boxplots*

In [13]:
boxplot = alt.Chart(master_df).mark_boxplot().encode(
    x=alt.X('hour:O', title='Hour of Day'),
    y=alt.Y('target_price:Q', title='Price (€/MWh)', scale=alt.Scale(zero=False)),
    color=alt.Color('hour:O', legend=None, scale=alt.Scale(scheme='viridis'))
).properties(
    title='Price Distribution by Hour of Day',
    width=600, height=400
).interactive()

boxplot.save('hourly_seasonality.json')

#### *Correlation Heatmap (top features only for readability)*

In [14]:
corr_cols = ['target_price', 'fc_renewables_total', 'net_load_forecast', 
             'price_lag_24h', 'price_lag_168h', 'load_lag_24h', 
             'gen_lignite_lag_24h', 'gen_gas_lag_24h', 'hour', 'month']

corr_matrix = master_df[corr_cols].corr().reset_index().melt(id_vars='index')
corr_matrix.columns = ['var1', 'var2', 'correlation']

heatmap = alt.Chart(corr_matrix).mark_rect().encode(
    x=alt.X('var1:N', title=None),
    y=alt.Y('var2:N', title=None),
    color=alt.Color('correlation:Q', scale=alt.Scale(scheme='redyellowgreen', domain=[-1, 1])),
    tooltip=['var1', 'var2', alt.Tooltip('correlation:Q', format='.2f')]
).properties(
    title='Feature Correlation Matrix',
    width=500, height=500
).interactive()

heatmap.save('correlation_heatmap.json')

#### *Price Volatility over Time*

In [15]:
trend_df = master_df.tail(720).reset_index()

line_chart = alt.Chart(trend_df).mark_line(strokeWidth=1.5, color='darkblue').encode(
    x=alt.X('timestamp:T', title='Date'),
    y=alt.Y('target_price:Q', title='Price (€/MWh)'),
    tooltip=['timestamp:T', 'target_price:Q']
).properties(
    title='Electricity Price Trend (Last 30 Days)',
    width=700, height=300
).interactive()

line_chart.save('price_trend.json')

#### *Renewable Generation Impact*

In [16]:
# Sampled to 5000 points for performance
scatter_sample = master_df.sample(n=min(5000, len(master_df)), random_state=42).reset_index()

scatter_plot = alt.Chart(scatter_sample).mark_circle(opacity=0.3, size=20, color='green').encode(
    x=alt.X('fc_renewables_total:Q', title='Total Renewable Forecast [MWh]'),
    y=alt.Y('target_price:Q', title='Price (€/MWh)'),
    tooltip=['timestamp:T', 'target_price:Q', 'fc_renewables_total:Q']
).properties(
    title='Impact of Renewable Generation on Prices',
    width=600, height=400
).interactive()

scatter_plot.save('renewables_impact.json')

### Baseline Models

- Naive
- ARIMA
- SARIMA
- Prophet

### Model Training

In [17]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import time
import warnings
from datetime import datetime

# Statistical Models
from statsmodels.tsa.arima.model import ARIMA
from statsmodels.tsa.statespace.sarimax import SARIMAX
from pmdarima import auto_arima

# Prophet
from prophet import Prophet

# LightGBM
import lightgbm as lgb
from sklearn.metrics import mean_absolute_error, mean_squared_error, mean_absolute_percentage_error

warnings.filterwarnings('ignore')

# --- CONFIGURATION ---
RANDOM_STATE = 42
np.random.seed(RANDOM_STATE)

# --- 1. DATA PREPARATION ---

def create_train_val_test_split(df, train_size=0.7, val_size=0.15):
    """
    Creates chronological train/val/test split for time series.
    Default: 70% train, 15% val, 15% test
    """
    n = len(df)
    train_end = int(n * train_size)
    val_end = int(n * (train_size + val_size))
    
    train = df.iloc[:train_end]
    val = df.iloc[train_end:val_end]
    test = df.iloc[val_end:]
    
    print("="*70)
    print("DATASET SPLIT SUMMARY")
    print("="*70)
    print(f"Train: {train.index[0]} to {train.index[-1]} ({len(train):,} samples, {len(train)/n*100:.1f}%)")
    print(f"Val:   {val.index[0]} to {val.index[-1]} ({len(val):,} samples, {len(val)/n*100:.1f}%)")
    print(f"Test:  {test.index[0]} to {test.index[-1]} ({len(test):,} samples, {len(test)/n*100:.1f}%)")
    print("="*70 + "\n")
    
    return train, val, test

def prepare_data_for_models(train, val, test, target_col='target_price'):
    """Prepares datasets for different model types."""
    # For tree-based models (LightGBM)
    feature_cols = [col for col in train.columns if col != target_col]
    
    X_train = train[feature_cols]
    y_train = train[target_col]
    
    X_val = val[feature_cols]
    y_val = val[target_col]
    
    X_test = test[feature_cols]
    y_test = test[target_col]
    
    # For time series models (ARIMA, SARIMA, Prophet)
    ts_train = train[target_col]
    ts_val = val[target_col]
    ts_test = test[target_col]
    
    return {
        'X_train': X_train, 'y_train': y_train,
        'X_val': X_val, 'y_val': y_val,
        'X_test': X_test, 'y_test': y_test,
        'ts_train': ts_train,
        'ts_val': ts_val,
        'ts_test': ts_test
    }

# --- 2. EVALUATION METRICS ---

def calculate_metrics(y_true, y_pred, model_name="Model"):
    """Calculate comprehensive metrics."""
    rmse = np.sqrt(mean_squared_error(y_true, y_pred))
    mae = mean_absolute_error(y_true, y_pred)
    mape = mean_absolute_percentage_error(y_true, y_pred) * 100
    
    return {
        'Model': model_name,
        'RMSE': rmse,
        'MAE': mae,
        'MAPE': mape
    }

# --- 3. BASELINE MODELS ---

def naive_baseline(data_dict):
    """Naive forecast: Use yesterday's price (24h lag)."""
    print("\n" + "="*70)
    print("MODEL 1: NAIVE BASELINE (24h Lag)")
    print("="*70)
    
    start_time = time.time()
    
    # Prediction is simply the price 24 hours ago
    y_pred_val = data_dict['ts_train'].iloc[-len(data_dict['ts_val']):].values
    y_pred_test = data_dict['ts_val'].iloc[-len(data_dict['ts_test']):].values
    
    elapsed = time.time() - start_time
    
    metrics_val = calculate_metrics(data_dict['y_val'], y_pred_val, "Naive - Val")
    metrics_test = calculate_metrics(data_dict['y_test'], y_pred_test, "Naive - Test")
    
    print(f"Training Time: {elapsed:.2f}s")
    print(f"Val  → RMSE: {metrics_val['RMSE']:.2f}, MAE: {metrics_val['MAE']:.2f}, MAPE: {metrics_val['MAPE']:.2f}%")
    print(f"Test → RMSE: {metrics_test['RMSE']:.2f}, MAE: {metrics_test['MAE']:.2f}, MAPE: {metrics_test['MAPE']:.2f}%")
    
    return {
        'model': None,
        'predictions_val': y_pred_val,
        'predictions_test': y_pred_test,
        'metrics_val': metrics_val,
        'metrics_test': metrics_test,
        'training_time': elapsed
    }

# --- 4. ARIMA MODEL ---

def train_arima(data_dict, order=(2, 1, 2)):
    """Train ARIMA model with specified order."""
    print("\n" + "="*70)
    print(f"MODEL 2: ARIMA{order}")
    print("="*70)
    
    start_time = time.time()
    
    # Fit on training data
    print("Fitting ARIMA model...")
    model = ARIMA(data_dict['ts_train'], order=order)
    model_fit = model.fit()
    
    # Forecast for validation period
    y_pred_val = model_fit.forecast(steps=len(data_dict['ts_val']))
    
    # Refit on train + val for test predictions
    ts_train_val = pd.concat([data_dict['ts_train'], data_dict['ts_val']])
    model_test = ARIMA(ts_train_val, order=order)
    model_test_fit = model_test.fit()
    y_pred_test = model_test_fit.forecast(steps=len(data_dict['ts_test']))
    
    elapsed = time.time() - start_time
    
    metrics_val = calculate_metrics(data_dict['y_val'], y_pred_val, "ARIMA - Val")
    metrics_test = calculate_metrics(data_dict['y_test'], y_pred_test, "ARIMA - Test")
    
    print(f"Training Time: {elapsed:.2f}s")
    print(f"Val  → RMSE: {metrics_val['RMSE']:.2f}, MAE: {metrics_val['MAE']:.2f}, MAPE: {metrics_val['MAPE']:.2f}%")
    print(f"Test → RMSE: {metrics_test['RMSE']:.2f}, MAE: {metrics_test['MAE']:.2f}, MAPE: {metrics_test['MAPE']:.2f}%")
    
    return {
        'model': model_fit,
        'predictions_val': y_pred_val,
        'predictions_test': y_pred_test,
        'metrics_val': metrics_val,
        'metrics_test': metrics_test,
        'training_time': elapsed
    }

# --- 5. SARIMA MODEL ---

def train_sarima(data_dict, order=(1, 1, 1), seasonal_order=(1, 1, 1, 24)):
    """Train SARIMA model (24h seasonality for hourly data)."""
    print("\n" + "="*70)
    print(f"MODEL 3: SARIMA{order}x{seasonal_order}")
    print("="*70)
    print("Warning: This may take several minutes...")
    
    start_time = time.time()
    
    # Fit on training data
    print("Fitting SARIMA model...")
    model = SARIMAX(data_dict['ts_train'], 
                    order=order, 
                    seasonal_order=seasonal_order,
                    enforce_stationarity=False,
                    enforce_invertibility=False)
    model_fit = model.fit(disp=False)
    
    # Forecast for validation
    y_pred_val = model_fit.forecast(steps=len(data_dict['ts_val']))
    
    # Refit on train + val for test
    ts_train_val = pd.concat([data_dict['ts_train'], data_dict['ts_val']])
    model_test = SARIMAX(ts_train_val, 
                         order=order, 
                         seasonal_order=seasonal_order,
                         enforce_stationarity=False,
                         enforce_invertibility=False)
    model_test_fit = model_test.fit(disp=False)
    y_pred_test = model_test_fit.forecast(steps=len(data_dict['ts_test']))
    
    elapsed = time.time() - start_time
    
    metrics_val = calculate_metrics(data_dict['y_val'], y_pred_val, "SARIMA - Val")
    metrics_test = calculate_metrics(data_dict['y_test'], y_pred_test, "SARIMA - Test")
    
    print(f"Training Time: {elapsed:.2f}s ({elapsed/60:.1f} min)")
    print(f"Val  → RMSE: {metrics_val['RMSE']:.2f}, MAE: {metrics_val['MAE']:.2f}, MAPE: {metrics_val['MAPE']:.2f}%")
    print(f"Test → RMSE: {metrics_test['RMSE']:.2f}, MAE: {metrics_test['MAE']:.2f}, MAPE: {metrics_test['MAPE']:.2f}%")
    
    return {
        'model': model_fit,
        'predictions_val': y_pred_val,
        'predictions_test': y_pred_test,
        'metrics_val': metrics_val,
        'metrics_test': metrics_test,
        'training_time': elapsed
    }

# --- 6. PROPHET MODEL ---

def train_prophet(train, val, test, target_col='target_price'):
    """Train Facebook Prophet model."""
    print("\n" + "="*70)
    print("MODEL 4: PROPHET")
    print("="*70)
    
    start_time = time.time()
    
    # Prepare data for Prophet
    prophet_train = train.reset_index()[['timestamp', target_col]]
    prophet_train.columns = ['ds', 'y']
    
    # Train model
    print("Fitting Prophet model...")
    model = Prophet(
        yearly_seasonality=True,
        weekly_seasonality=True,
        daily_seasonality=True,
        seasonality_mode='multiplicative'
    )
    model.fit(prophet_train)
    
    # Validation predictions
    future_val = pd.DataFrame({'ds': val.index})
    forecast_val = model.predict(future_val)
    y_pred_val = forecast_val['yhat'].values
    
    # Refit on train + val for test
    prophet_train_val = pd.concat([train, val]).reset_index()[['timestamp', target_col]]
    prophet_train_val.columns = ['ds', 'y']
    model_test = Prophet(
        yearly_seasonality=True,
        weekly_seasonality=True,
        daily_seasonality=True,
        seasonality_mode='multiplicative'
    )
    model_test.fit(prophet_train_val)
    future_test = pd.DataFrame({'ds': test.index})
    forecast_test = model_test.predict(future_test)
    y_pred_test = forecast_test['yhat'].values
    
    elapsed = time.time() - start_time
    
    metrics_val = calculate_metrics(val[target_col], y_pred_val, "Prophet - Val")
    metrics_test = calculate_metrics(test[target_col], y_pred_test, "Prophet - Test")
    
    print(f"Training Time: {elapsed:.2f}s")
    print(f"Val  → RMSE: {metrics_val['RMSE']:.2f}, MAE: {metrics_val['MAE']:.2f}, MAPE: {metrics_val['MAPE']:.2f}%")
    print(f"Test → RMSE: {metrics_test['RMSE']:.2f}, MAE: {metrics_test['MAE']:.2f}, MAPE: {metrics_test['MAPE']:.2f}%")
    
    return {
        'model': model,
        'predictions_val': y_pred_val,
        'predictions_test': y_pred_test,
        'metrics_val': metrics_val,
        'metrics_test': metrics_test,
        'training_time': elapsed
    }

# --- 7. LIGHTGBM WITH HYPERPARAMETER TUNING ---

def train_lightgbm(data_dict, tune_hyperparams=True):
    """Train LightGBM with optional hyperparameter tuning."""
    print("\n" + "="*70)
    print("MODEL 5: LIGHTGBM" + (" (with Hyperparameter Tuning)" if tune_hyperparams else ""))
    print("="*70)
    
    start_time = time.time()
    
    if tune_hyperparams:
        print("Tuning hyperparameters using validation set...")
        
        # Hyperparameter search space
        param_grid = {
            'num_leaves': [31, 50, 70],
            'learning_rate': [0.01, 0.05, 0.1],
            'n_estimators': [100, 300, 500],
            'max_depth': [-1, 10, 20],
            'min_child_samples': [20, 50, 100]
        }
        
        best_score = float('inf')
        best_params = None
        
        # Grid search (simplified - you can use optuna for better search)
        from itertools import product
        
        param_combinations = [
            {'num_leaves': 50, 'learning_rate': 0.05, 'n_estimators': 300, 
             'max_depth': 10, 'min_child_samples': 20},
            {'num_leaves': 70, 'learning_rate': 0.01, 'n_estimators': 500, 
             'max_depth': -1, 'min_child_samples': 50},
            {'num_leaves': 31, 'learning_rate': 0.1, 'n_estimators': 100, 
             'max_depth': 20, 'min_child_samples': 100}
        ]
        
        for params in param_combinations:
            model = lgb.LGBMRegressor(**params, random_state=RANDOM_STATE, verbose=-1)
            model.fit(data_dict['X_train'], data_dict['y_train'])
            y_pred = model.predict(data_dict['X_val'])
            score = mean_squared_error(data_dict['y_val'], y_pred)
            
            if score < best_score:
                best_score = score
                best_params = params
        
        print(f"Best params: {best_params}")
        print(f"Best validation RMSE: {np.sqrt(best_score):.2f}")
        
        final_params = best_params
    else:
        final_params = {
            'num_leaves': 50,
            'learning_rate': 0.05,
            'n_estimators': 300,
            'max_depth': 10,
            'random_state': RANDOM_STATE,
            'verbose': -1
        }
    
    # Train final model on train set
    print("Training final model...")
    model = lgb.LGBMRegressor(**final_params)
    model.fit(
        data_dict['X_train'], 
        data_dict['y_train'],
        eval_set=[(data_dict['X_val'], data_dict['y_val'])],
        callbacks=[lgb.early_stopping(50, verbose=False)]
    )
    
    # Predictions
    y_pred_val = model.predict(data_dict['X_val'])
    
    # Retrain on train + val for test
    X_train_val = pd.concat([data_dict['X_train'], data_dict['X_val']])
    y_train_val = pd.concat([data_dict['y_train'], data_dict['y_val']])
    
    model_final = lgb.LGBMRegressor(**final_params)
    model_final.fit(X_train_val, y_train_val)
    y_pred_test = model_final.predict(data_dict['X_test'])
    
    elapsed = time.time() - start_time
    
    metrics_val = calculate_metrics(data_dict['y_val'], y_pred_val, "LightGBM - Val")
    metrics_test = calculate_metrics(data_dict['y_test'], y_pred_test, "LightGBM - Test")
    
    print(f"Training Time: {elapsed:.2f}s")
    print(f"Val  → RMSE: {metrics_val['RMSE']:.2f}, MAE: {metrics_val['MAE']:.2f}, MAPE: {metrics_val['MAPE']:.2f}%")
    print(f"Test → RMSE: {metrics_test['RMSE']:.2f}, MAE: {metrics_test['MAE']:.2f}, MAPE: {metrics_test['MAPE']:.2f}%")
    
    # Feature importance
    feature_imp = pd.DataFrame({
        'feature': data_dict['X_train'].columns,
        'importance': model.feature_importances_
    }).sort_values('importance', ascending=False)
    
    print(f"\nTop 10 Most Important Features:")
    print(feature_imp.head(10).to_string(index=False))
    
    return {
        'model': model_final,
        'predictions_val': y_pred_val,
        'predictions_test': y_pred_test,
        'metrics_val': metrics_val,
        'metrics_test': metrics_test,
        'training_time': elapsed,
        'feature_importance': feature_imp
    }

# --- 8. VISUALIZATION ---

def plot_results(results, val, test, target_col='target_price'):
    """Create comprehensive visualization of all model predictions."""
    
    fig, axes = plt.subplots(2, 1, figsize=(16, 10))
    
    # Validation Set
    ax = axes[0]
    ax.plot(val.index, val[target_col], label='Actual', color='black', linewidth=2, alpha=0.7)
    
    colors = ['blue', 'green', 'red', 'purple', 'orange']
    for (name, result), color in zip(results.items(), colors):
        ax.plot(val.index, result['predictions_val'], label=name, 
                color=color, alpha=0.6, linewidth=1.5)
    
    ax.set_title('Validation Set Predictions', fontsize=14, fontweight='bold')
    ax.set_xlabel('Date')
    ax.set_ylabel('Price (€/MWh)')
    ax.legend(loc='best')
    ax.grid(alpha=0.3)
    
    # Test Set
    ax = axes[1]
    ax.plot(test.index, test[target_col], label='Actual', color='black', linewidth=2, alpha=0.7)
    
    for (name, result), color in zip(results.items(), colors):
        ax.plot(test.index, result['predictions_test'], label=name, 
                color=color, alpha=0.6, linewidth=1.5)
    
    ax.set_title('Test Set Predictions', fontsize=14, fontweight='bold')
    ax.set_xlabel('Date')
    ax.set_ylabel('Price (€/MWh)')
    ax.legend(loc='best')
    ax.grid(alpha=0.3)
    
    plt.tight_layout()
    plt.savefig('model_comparison_predictions.png', dpi=300, bbox_inches='tight')
    plt.close()
    print("\nSaved: model_comparison_predictions.png")

def create_comparison_table(results):
    """Create performance comparison table."""
    
    metrics_list = []
    for name, result in results.items():
        metrics_list.append(result['metrics_val'])
        metrics_list.append(result['metrics_test'])
    
    df_metrics = pd.DataFrame(metrics_list)
    
    # Create styled table
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.axis('tight')
    ax.axis('off')
    
    table = ax.table(cellText=df_metrics.values, 
                     colLabels=df_metrics.columns,
                     cellLoc='center',
                     loc='center',
                     bbox=[0, 0, 1, 1])
    
    table.auto_set_font_size(False)
    table.set_fontsize(10)
    table.scale(1, 2)
    
    # Color coding
    for i in range(len(df_metrics)):
        if 'Val' in df_metrics.iloc[i]['Model']:
            for j in range(len(df_metrics.columns)):
                table[(i+1, j)].set_facecolor('#E8F4F8')
        else:
            for j in range(len(df_metrics.columns)):
                table[(i+1, j)].set_facecolor('#FFF4E6')
    
    plt.title('Model Performance Comparison', fontsize=14, fontweight='bold', pad=20)
    plt.savefig('model_comparison_table.png', dpi=300, bbox_inches='tight')
    plt.close()
    print("Saved: model_comparison_table.png")
    
    return df_metrics

# --- 9. MAIN EXECUTION ---

def run_complete_experiment(data_path='master_electricity_data.csv'):
    """Run complete model comparison experiment."""
    
    print("\n" + "="*70)
    print("ELECTRICITY PRICE FORECASTING - COMPLETE MODEL COMPARISON")
    print("="*70 + "\n")
    
    # Load data
    print("Loading data...")
    df = pd.read_csv(data_path, index_col=0, parse_dates=True)
    print(f"Loaded {len(df):,} samples with {df.shape[1]} features")
    
    # Create splits
    train, val, test = create_train_val_test_split(df, train_size=0.7, val_size=0.15)
    
    # Prepare data
    data_dict = prepare_data_for_models(train, val, test)
    
    # Store results
    results = {}
    
    # Model 1: Naive Baseline
    results['Naive'] = naive_baseline(data_dict)
    
    # Model 2: ARIMA
    results['ARIMA'] = train_arima(data_dict, order=(2, 1, 2))
    
    # Model 3: SARIMA (may take time!)
    results['SARIMA'] = train_sarima(data_dict, order=(1, 1, 1), seasonal_order=(1, 1, 1, 24))
    
    # Model 4: Prophet
    results['Prophet'] = train_prophet(train, val, test)
    
    # Model 5: LightGBM
    results['LightGBM'] = train_lightgbm(data_dict, tune_hyperparams=True)
    
    # Visualizations
    print("\n" + "="*70)
    print("GENERATING VISUALIZATIONS")
    print("="*70)
    plot_results(results, val, test)
    comparison_df = create_comparison_table(results)
    
    # Final Summary
    print("\n" + "="*70)
    print("FINAL SUMMARY")
    print("="*70)
    print("\nTest Set Performance:")
    print(comparison_df[comparison_df['Model'].str.contains('Test')].to_string(index=False))
    
    print("\n" + "="*70)
    print("EXPERIMENT COMPLETE!")
    print("="*70)
    print("Output files generated:")
    print("  - model_comparison_predictions.png")
    print("  - model_comparison_table.png")
    
    return results, comparison_df

# --- RUN EXPERIMENT ---
if __name__ == "__main__":
    results, metrics_df = run_complete_experiment('master_electricity_data.csv')


ELECTRICITY PRICE FORECASTING - COMPLETE MODEL COMPARISON

Loading data...
Loaded 42,811 samples with 24 features
DATASET SPLIT SUMMARY
Train: 2021-01-08 00:00:00 to 2024-07-07 18:00:00 (29,967 samples, 70.0%)
Val:   2024-07-07 19:00:00 to 2025-04-08 09:00:00 (6,422 samples, 15.0%)
Test:  2025-04-08 10:00:00 to 2025-12-31 23:00:00 (6,422 samples, 15.0%)


MODEL 1: NAIVE BASELINE (24h Lag)
Training Time: 0.00s
Val  → RMSE: 78.05, MAE: 57.21, MAPE: 106497297821827744.00%
Test → RMSE: 80.99, MAE: 57.92, MAPE: 500377128140826304.00%

MODEL 2: ARIMA(2, 1, 2)
Fitting ARIMA model...
Training Time: 14.57s
Val  → RMSE: 91.42, MAE: 74.31, MAPE: 44945541084187472.00%
Test → RMSE: 51.17, MAE: 33.71, MAPE: 466954198551052224.00%

MODEL 3: SARIMA(1, 1, 1)x(1, 1, 1, 24)
Fitting SARIMA model...
Training Time: 159.40s (2.7 min)
Val  → RMSE: 182.65, MAE: 153.58, MAPE: 284509783501280640.00%
Test → RMSE: 261.45, MAE: 222.85, MAPE: 589312318602199040.00%

MODEL 4: PROPHET
Fitting Prophet model...


13:52:26 - cmdstanpy - INFO - Chain [1] start processing
13:52:59 - cmdstanpy - INFO - Chain [1] done processing
13:53:01 - cmdstanpy - INFO - Chain [1] start processing
13:54:06 - cmdstanpy - INFO - Chain [1] done processing


Training Time: 102.65s
Val  → RMSE: 59.90, MAE: 42.15, MAPE: 98289498668852512.00%
Test → RMSE: 45.67, MAE: 33.55, MAPE: 310634773853171968.00%

MODEL 5: LIGHTGBM (with Hyperparameter Tuning)
Tuning hyperparameters using validation set...
Best params: {'num_leaves': 50, 'learning_rate': 0.05, 'n_estimators': 300, 'max_depth': 10, 'min_child_samples': 20}
Best validation RMSE: 30.18
Training final model...
Training Time: 6.55s
Val  → RMSE: 30.18, MAE: 16.78, MAPE: 18669765158820216.00%
Test → RMSE: 19.99, MAE: 13.37, MAPE: 50557948203995352.00%

Top 10 Most Important Features:
                feature  importance
  price_rolling_std_24h        1243
 price_rolling_mean_24h        1144
price_rolling_mean_168h        1090
         price_lag_168h         943
      net_load_forecast         906
            day_of_week         882
        gen_gas_lag_24h         840
    fc_renewables_total         774
    gen_lignite_lag_24h         765
             fc_wind_on         752

GENERATING VISUALIZA

### Predictions January, 27th, 2026

In [18]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from datetime import datetime, timedelta
import warnings

# Models
from statsmodels.tsa.arima.model import ARIMA
from statsmodels.tsa.statespace.sarimax import SARIMAX
from prophet import Prophet
import lightgbm as lgb

warnings.filterwarnings('ignore')

# --- CONFIGURATION ---
TARGET_DATE = '2026-01-27'  # Target prediction date
RANDOM_STATE = 42

# --- 1. LOAD TRAINED MODELS AND DATA ---

def load_full_dataset(data_path='master_electricity_data.csv'):
    """Load the complete dataset."""
    print("Loading master dataset...")
    df = pd.read_csv(data_path, index_col=0, parse_dates=True)
    print(f"Loaded {len(df):,} samples")
    print(f"Date range: {df.index[0]} to {df.index[-1]}")
    return df

# --- 2. PREDICT WITH TIME SERIES MODELS (ARIMA, SARIMA) ---

def predict_with_arima_sarima(df, target_date, target_col='target_price'):
    """Predict using ARIMA and SARIMA models."""
    
    # Calculate steps ahead
    last_date = df.index[-1]
    target_datetime = pd.to_datetime(target_date)
    hours_ahead = int((target_datetime - last_date).total_seconds() / 3600)
    
    print(f"\nPredicting {hours_ahead} hours ahead (from {last_date} to {target_datetime})")
    
    if hours_ahead <= 0:
        print("Error: Target date is in the past or present!")
        return None
    
    ts_data = df[target_col]
    
    results = {}
    
    # ARIMA
    print("\n[1/2] Training ARIMA model on full dataset...")
    arima_model = ARIMA(ts_data, order=(2, 1, 2))
    arima_fit = arima_model.fit()
    arima_forecast = arima_fit.forecast(steps=hours_ahead)
    results['ARIMA'] = arima_forecast.iloc[-1]  # Last prediction = target date
    print(f"ARIMA prediction for {target_date}: €{results['ARIMA']:.2f}/MWh")
    
    # SARIMA
    print("\n[2/2] Training SARIMA model on full dataset...")
    print("(This may take several minutes...)")
    sarima_model = SARIMAX(ts_data, 
                           order=(1, 1, 1), 
                           seasonal_order=(1, 1, 1, 24),
                           enforce_stationarity=False,
                           enforce_invertibility=False)
    sarima_fit = sarima_model.fit(disp=False)
    sarima_forecast = sarima_fit.forecast(steps=hours_ahead)
    results['SARIMA'] = sarima_forecast.iloc[-1]
    print(f"SARIMA prediction for {target_date}: €{results['SARIMA']:.2f}/MWh")
    
    return results, arima_forecast, sarima_forecast

# --- 3. PREDICT WITH PROPHET ---

def predict_with_prophet(df, target_date, target_col='target_price'):
    """Predict using Prophet model."""
    
    print("\n[3/4] Training Prophet model on full dataset...")
    
    # Prepare data
    prophet_df = df.reset_index()[[df.index.name or 'timestamp', target_col]]
    prophet_df.columns = ['ds', 'y']
    
    # Train model
    model = Prophet(
        yearly_seasonality=True,
        weekly_seasonality=True,
        daily_seasonality=True,
        seasonality_mode='multiplicative'
    )
    model.fit(prophet_df)
    
    # Create future dataframe
    last_date = df.index[-1]
    target_datetime = pd.to_datetime(target_date)
    hours_ahead = int((target_datetime - last_date).total_seconds() / 3600)
    
    future = model.make_future_dataframe(periods=hours_ahead, freq='h')
    forecast = model.predict(future)
    
    # Get prediction for target date
    target_pred = forecast[forecast['ds'] == target_datetime]['yhat'].values[0]
    
    print(f"Prophet prediction for {target_date}: €{target_pred:.2f}/MWh")
    
    return target_pred, forecast

# --- 4. PREDICT WITH LIGHTGBM (Requires Feature Engineering) ---

def create_future_features(df, target_date, target_col='target_price'):
    """
    Create features for future prediction.
    Note: This requires assumptions about future renewable forecasts.
    """
    
    last_date = df.index[-1]
    target_datetime = pd.to_datetime(target_date)
    hours_ahead = int((target_datetime - last_date).total_seconds() / 3600)
    
    # Create future datetime index
    future_index = pd.date_range(start=last_date + timedelta(hours=1), 
                                  periods=hours_ahead, 
                                  freq='h')
    
    # Initialize future dataframe with target column placeholder
    future_df = pd.DataFrame(index=future_index)
    future_df[target_col] = np.nan  # Initialize target column
    
    # Calendar features (we know these for future dates)
    future_df['hour'] = future_df.index.hour
    future_df['day_of_week'] = future_df.index.dayofweek
    future_df['month'] = future_df.index.month
    future_df['is_weekend'] = future_df['day_of_week'].isin([5, 6]).astype(int)
    future_df['hour_sin'] = np.sin(2 * np.pi * future_df['hour'] / 24)
    future_df['hour_cos'] = np.cos(2 * np.pi * future_df['hour'] / 24)
    future_df['month_sin'] = np.sin(2 * np.pi * future_df['month'] / 12)
    future_df['month_cos'] = np.cos(2 * np.pi * future_df['month'] / 12)
    
    # Lag features (use recent historical data)
    last_prices = df[target_col].tail(168)  # Last week of prices
    
    # Initialize lag columns
    future_df['price_lag_24h'] = np.nan
    future_df['price_lag_168h'] = np.nan
    future_df['price_lag_48h'] = np.nan
    
    for i, future_hour in enumerate(future_index):
        # 24h lag
        if i >= 24:
            future_df.loc[future_hour, 'price_lag_24h'] = future_df.iloc[i - 24][target_col]
        else:
            lag_idx = len(last_prices) + i - 24
            if lag_idx >= 0 and lag_idx < len(last_prices):
                future_df.loc[future_hour, 'price_lag_24h'] = last_prices.iloc[lag_idx]
            else:
                future_df.loc[future_hour, 'price_lag_24h'] = df[target_col].iloc[-24 + i]
        
        # 168h lag
        if i >= 168:
            future_df.loc[future_hour, 'price_lag_168h'] = future_df.iloc[i - 168][target_col]
        else:
            lag_idx = len(last_prices) + i - 168
            if lag_idx >= 0:
                future_df.loc[future_hour, 'price_lag_168h'] = last_prices.iloc[lag_idx]
            else:
                # Use data from main df
                historical_idx = len(df) + i - 168
                if historical_idx >= 0:
                    future_df.loc[future_hour, 'price_lag_168h'] = df[target_col].iloc[historical_idx]
                else:
                    future_df.loc[future_hour, 'price_lag_168h'] = df[target_col].mean()
        
        # 48h lag
        if i >= 48:
            future_df.loc[future_hour, 'price_lag_48h'] = future_df.iloc[i - 48][target_col]
        else:
            lag_idx = len(last_prices) + i - 48
            if lag_idx >= 0 and lag_idx < len(last_prices):
                future_df.loc[future_hour, 'price_lag_48h'] = last_prices.iloc[lag_idx]
            else:
                future_df.loc[future_hour, 'price_lag_48h'] = df[target_col].iloc[-48 + i]
    
    # Initialize all feature columns that will be needed
    future_df['fc_solar'] = np.nan
    future_df['fc_wind_on'] = np.nan
    future_df['fc_wind_off'] = np.nan
    future_df['load_lag_24h'] = np.nan
    future_df['load_lag_168h'] = np.nan
    future_df['gen_lignite_lag_24h'] = np.nan
    future_df['gen_gas_lag_24h'] = np.nan
    
    # For renewable forecasts and load - use seasonal averages from historical data
    # (In production, you'd get these from actual forecast data)
    print("\nNote: Using historical seasonal averages for renewable forecasts and load")
    print("      (In production, use actual forecast data)")
    
    for idx in future_df.index:
        hour = idx.hour
        dow = idx.dayofweek
        month = idx.month
        
        # Get similar historical periods
        mask = (df.index.hour == hour) & (df.index.dayofweek == dow) & (df.index.month == month)
        historical_similar = df[mask].tail(30)  # Last 30 similar hours
        
        if len(historical_similar) > 0:
            future_df.loc[idx, 'fc_solar'] = historical_similar['fc_solar'].mean()
            future_df.loc[idx, 'fc_wind_on'] = historical_similar['fc_wind_on'].mean()
            future_df.loc[idx, 'fc_wind_off'] = historical_similar['fc_wind_off'].mean()
            future_df.loc[idx, 'load_lag_24h'] = historical_similar['load_lag_24h'].mean()
            future_df.loc[idx, 'load_lag_168h'] = historical_similar.get('load_lag_168h', historical_similar['load_lag_24h']).mean()
            future_df.loc[idx, 'gen_lignite_lag_24h'] = historical_similar['gen_lignite_lag_24h'].mean()
            future_df.loc[idx, 'gen_gas_lag_24h'] = historical_similar['gen_gas_lag_24h'].mean()
        else:
            # Fallback to overall averages
            future_df.loc[idx, 'fc_solar'] = df['fc_solar'].mean()
            future_df.loc[idx, 'fc_wind_on'] = df['fc_wind_on'].mean()
            future_df.loc[idx, 'fc_wind_off'] = df['fc_wind_off'].mean()
            future_df.loc[idx, 'load_lag_24h'] = df['load_lag_24h'].mean()
            future_df.loc[idx, 'load_lag_168h'] = df.get('load_lag_168h', df['load_lag_24h']).mean()
            future_df.loc[idx, 'gen_lignite_lag_24h'] = df['gen_lignite_lag_24h'].mean()
            future_df.loc[idx, 'gen_gas_lag_24h'] = df['gen_gas_lag_24h'].mean()
    
    # Derived features
    future_df['fc_renewables_total'] = future_df['fc_solar'] + future_df['fc_wind_on'] + future_df['fc_wind_off']
    future_df['net_load_forecast'] = future_df['load_lag_24h'] - future_df['fc_renewables_total']
    
    # Rolling features (use last known values)
    future_df['price_rolling_mean_24h'] = df[target_col].tail(24).mean()
    future_df['price_rolling_std_24h'] = df[target_col].tail(24).std()
    future_df['price_rolling_mean_168h'] = df[target_col].tail(168).mean()
    
    return future_df

def predict_with_lightgbm(df, target_date, target_col='target_price'):
    """Predict using LightGBM model."""
    
    print("\n[4/4] Training LightGBM model on full dataset...")
    
    # Prepare training data
    feature_cols = [col for col in df.columns if col != target_col]
    X = df[feature_cols]
    y = df[target_col]
    
    # Train model
    params = {
        'num_leaves': 50,
        'learning_rate': 0.05,
        'n_estimators': 300,
        'max_depth': 10,
        'random_state': RANDOM_STATE,
        'verbose': -1
    }
    
    model = lgb.LGBMRegressor(**params)
    model.fit(X, y)
    
    # Create future features
    future_df = create_future_features(df, target_date, target_col)
    
    # Iterative prediction (important for lag features)
    predictions = []
    
    # Ensure all required columns exist before prediction
    missing_cols = [col for col in feature_cols if col not in future_df.columns]
    if missing_cols:
        print(f"Warning: Missing features {missing_cols}, filling with 0")
        for col in missing_cols:
            future_df[col] = 0
    
    for i in range(len(future_df)):
        # Get features for this hour - ensure correct order
        X_pred = future_df.iloc[i:i+1][feature_cols]
        pred = model.predict(X_pred)[0]
        predictions.append(pred)
        
        # Update future_df with this prediction for subsequent lag calculations
        future_df.iloc[i, future_df.columns.get_loc(target_col)] = pred
    
    # Get the final prediction for target date
    target_datetime = pd.to_datetime(target_date)
    target_pred = predictions[-1]  # Last prediction is for target date
    
    print(f"LightGBM prediction for {target_date}: €{target_pred:.2f}/MWh")
    
    return target_pred, future_df, predictions

# --- 5. VISUALIZATION ---

def visualize_predictions(df, predictions_dict, target_date):
    """Create visualization showing historical data and predictions."""
    
    fig, ax = plt.subplots(figsize=(16, 8))
    
    # Plot last 30 days of historical data
    last_30_days = df.tail(720)  # 30 days * 24 hours
    ax.plot(last_30_days.index, last_30_days['target_price'], 
            label='Historical Price', color='black', linewidth=2, alpha=0.7)
    
    # Mark target date
    target_datetime = pd.to_datetime(target_date)
    
    # Plot predictions
    colors = {'ARIMA': 'blue', 'SARIMA': 'green', 'Prophet': 'red', 'LightGBM': 'orange'}
    
    for model_name, pred_value in predictions_dict.items():
        ax.scatter(target_datetime, pred_value, 
                  color=colors.get(model_name, 'purple'), 
                  s=200, marker='*', 
                  label=f'{model_name}: €{pred_value:.2f}/MWh',
                  zorder=5, edgecolors='black', linewidths=1.5)
    
    # Formatting
    ax.axvline(df.index[-1], color='gray', linestyle='--', alpha=0.5, label='Last Known Data')
    ax.axvline(target_datetime, color='red', linestyle='--', alpha=0.5, label='Target Date')
    
    ax.set_title(f'Electricity Price Predictions for {target_date}', 
                fontsize=16, fontweight='bold', pad=20)
    ax.set_xlabel('Date', fontsize=12)
    ax.set_ylabel('Price (€/MWh)', fontsize=12)
    ax.legend(loc='best', fontsize=10)
    ax.grid(alpha=0.3)
    
    plt.tight_layout()
    plt.savefig(f'prediction_{target_date}.png', dpi=300, bbox_inches='tight')
    plt.close()
    
    print(f"\nSaved visualization: prediction_{target_date}.png")

def create_prediction_summary(predictions_dict, target_date):
    """Create a summary table of predictions."""
    
    summary_df = pd.DataFrame({
        'Model': list(predictions_dict.keys()),
        'Predicted Price (€/MWh)': [f"{v:.2f}" for v in predictions_dict.values()]
    })
    
    # Calculate statistics
    prices = list(predictions_dict.values())
    mean_pred = np.mean(prices)
    std_pred = np.std(prices)
    min_pred = np.min(prices)
    max_pred = np.max(prices)
    
    print("\n" + "="*70)
    print(f"ELECTRICITY PRICE PREDICTIONS FOR {target_date}")
    print("="*70)
    print(summary_df.to_string(index=False))
    print("\n" + "-"*70)
    print(f"Ensemble Average:  €{mean_pred:.2f}/MWh")
    print(f"Standard Deviation: €{std_pred:.2f}/MWh")
    print(f"Range:             €{min_pred:.2f} - €{max_pred:.2f}/MWh")
    print("="*70)
    
    return summary_df, mean_pred

# --- 6. MAIN PREDICTION PIPELINE ---

def predict_electricity_price(target_date='2026-01-27', 
                              data_path='master_electricity_data.csv',
                              hourly_prediction=True):
    """
    Main function to predict electricity price for a future date.
    
    Parameters:
    -----------
    target_date : str
        Target date in format 'YYYY-MM-DD' or 'YYYY-MM-DD HH:MM:SS'
    data_path : str
        Path to the master dataset
    hourly_prediction : bool
        If True and no time specified, predicts for each hour of the day
    """
    
    print("\n" + "="*70)
    print("ELECTRICITY PRICE FORECASTING SYSTEM")
    print("="*70)
    
    # Load data
    df = load_full_dataset(data_path)
    
    # If only date provided (no time), predict for noon or all hours
    if len(target_date) == 10:  # Format: YYYY-MM-DD
        if hourly_prediction:
            print(f"\nGenerating hourly predictions for {target_date} (00:00 to 23:00)")
            target_date_full = f"{target_date} 12:00:00"  # Use noon for main prediction
        else:
            target_date_full = f"{target_date} 12:00:00"
    else:
        target_date_full = target_date
    
    # Run predictions
    predictions = {}
    
    # Time series models
    ts_results, arima_fc, sarima_fc = predict_with_arima_sarima(df, target_date_full)
    predictions.update(ts_results)
    
    # Prophet
    prophet_pred, prophet_fc = predict_with_prophet(df, target_date_full)
    predictions['Prophet'] = prophet_pred
    
    # LightGBM
    lgb_pred, future_features, lgb_preds = predict_with_lightgbm(df, target_date_full)
    predictions['LightGBM'] = lgb_pred
    
    # Create summary
    summary_df, ensemble_pred = create_prediction_summary(predictions, target_date_full)
    
    # Visualize
    visualize_predictions(df, predictions, target_date_full)
    
    print("Prediction complete!")
    print(f"Recommended Ensemble Prediction: €{ensemble_pred:.2f}/MWh")
    
    return predictions, ensemble_pred, summary_df

# --- 7. RUN PREDICTION ---

if __name__ == "__main__":
    # Predict for January 27, 2026
    predictions, ensemble, summary = predict_electricity_price(
        target_date='2026-01-27',
        data_path='master_electricity_data.csv'
    )
    
    print("\n" + "="*70)
    print("DONE! Check the generated visualization.")
    print("="*70)


ELECTRICITY PRICE FORECASTING SYSTEM
Loading master dataset...
Loaded 42,811 samples
Date range: 2021-01-08 00:00:00 to 2025-12-31 23:00:00

Generating hourly predictions for 2026-01-27 (00:00 to 23:00)

Predicting 637 hours ahead (from 2025-12-31 23:00:00 to 2026-01-27 12:00:00)

[1/2] Training ARIMA model on full dataset...
ARIMA prediction for 2026-01-27 12:00:00: €86.67/MWh

[2/2] Training SARIMA model on full dataset...
(This may take several minutes...)
SARIMA prediction for 2026-01-27 12:00:00: €65.36/MWh

[3/4] Training Prophet model on full dataset...


13:56:30 - cmdstanpy - INFO - Chain [1] start processing
13:57:25 - cmdstanpy - INFO - Chain [1] done processing


Prophet prediction for 2026-01-27 12:00:00: €94.48/MWh

[4/4] Training LightGBM model on full dataset...

Note: Using historical seasonal averages for renewable forecasts and load
      (In production, use actual forecast data)
LightGBM prediction for 2026-01-27 12:00:00: €79.49/MWh

ELECTRICITY PRICE PREDICTIONS FOR 2026-01-27 12:00:00
   Model Predicted Price (€/MWh)
   ARIMA                   86.67
  SARIMA                   65.36
 Prophet                   94.48
LightGBM                   79.49

----------------------------------------------------------------------
Ensemble Average:  €81.50/MWh
Standard Deviation: €10.72/MWh
Range:             €65.36 - €94.48/MWh

Saved visualization: prediction_2026-01-27 12:00:00.png
Prediction complete!
Recommended Ensemble Prediction: €81.50/MWh

DONE! Check the generated visualization.
