# Chronos

In [1]:
import warnings
import transformers
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from datetime import datetime
from chronos import ChronosPipeline
from pmdarima.arima import auto_arima
from statsmodels.tsa.statespace.sarimax import SARIMAX
from tqdm import tqdm
from sklearn.metrics import mean_absolute_percentage_error, mean_absolute_error, root_mean_squared_error
from darts import TimeSeries

In [2]:
# Import the data
def load_and_prepare_data(file_path):
    """
    Load energy prices data from a CSV file, ensure chronological order, and convert 'Date' to datetime.
    """
    df = pd.read_csv(file_path)
    df.sort_values('Date', inplace=True)
    df.set_index('Date', inplace=True)
    df = pd.DataFrame(df)
    return df

In [3]:
train_df = load_and_prepare_data('../../data/Final_data/train_df.csv')
test_df = load_and_prepare_data('../../data/Final_data/test_df.csv')

train_df.reset_index(inplace=True)
test_df.reset_index(inplace=True)

In [4]:
# Import the data
df = load_and_prepare_data('../../data/Final_data/final_data_july.csv')

# Reset the index
df = df.reset_index()

target_column = "Day_ahead_price (€/MWh)"

In [5]:
# date of first forecast
start_date = "2022-07-01"

# date of last forecast
end_date = "2024-07-28"

In [6]:
# create a TimeSeries object from df
data = df[['Date', target_column]]
# transform to a pandas series
data = data.set_index('Date')
data = data.squeeze()
data

Date
2012-01-08    26.83
2012-01-09    47.91
2012-01-10    45.77
2012-01-11    47.83
2012-01-12    43.10
              ...  
2024-07-24    66.61
2024-07-25    78.34
2024-07-26    93.04
2024-07-27    80.74
2024-07-28    43.96
Name: Day_ahead_price (€/MWh), Length: 4586, dtype: float64

In [7]:
data.shape

(4586,)

# Sarimax

In [8]:
best_sarima_model = auto_arima(
    y=data[data.index < start_date],
    start_p=0,
    start_q=0,
    start_P=0,
    start_Q=0,
    m=12,
    seasonal=True,
)

print(best_sarima_model.summary())

KeyboardInterrupt: 

In [68]:
# create a list for storing the forecasts
sarima_forecasts = []

# loop across the dates
for t in tqdm(range(data.index.get_loc(start_date), data.index.get_loc(end_date) + 1)):

    # extract the training data
    context = data.iloc[:t]

    # train the model
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        sarima_model = SARIMAX(
            endog=context,
            order=best_sarima_model.order,
            seasonal_order=best_sarima_model.seasonal_order,
            trend="c" if best_sarima_model.with_intercept else None,
        ).fit(disp=0)

    # generate the one-step-ahead forecast
    sarima_forecast = sarima_model.get_forecast(steps=1)

    # save the forecast
    sarima_forecasts.append({
        "date": data.index[t],
        "actual": data.values[t],
        "mean": sarima_forecast.predicted_mean.item(),
        "std": sarima_forecast.var_pred_mean.item() ** 0.5,
    })

# cast the forecasts to data frame
sarima_forecasts = pd.DataFrame(sarima_forecasts)

100%|██████████| 759/759 [1:23:33<00:00,  6.61s/it]    


In [69]:
import plotly.graph_objects as go

fig = go.Figure()

# Add actual data trace
fig.add_trace(go.Scatter(
    x=sarima_forecasts["date"],
    y=sarima_forecasts["actual"],
    mode='lines',
    line=dict(color="#3f4751", width=1),
    name="Actual"
))

# Add predicted mean trace
fig.add_trace(go.Scatter(
    x=sarima_forecasts["date"],
    y=sarima_forecasts["mean"],
    mode='lines',
    line=dict(color="#ca8a04", width=1),
    name="Predicted"
))

# Add predicted +/- 1 standard deviation band
fig.add_trace(go.Scatter(
    x=sarima_forecasts["date"],
    y=sarima_forecasts["mean"] + sarima_forecasts["std"],
    fill=None,
    mode='lines',
    line=dict(color="#ca8a04", width=0.5),
    showlegend=False
))

fig.add_trace(go.Scatter(
    x=sarima_forecasts["date"],
    y=sarima_forecasts["mean"] - sarima_forecasts["std"],
    fill='tonexty',  # Fill to previous trace
    mode='lines',
    line=dict(color="#ca8a04", width=0.5),
    name="Predicted +/- 1 Std. Dev.",
    opacity=0.2
))

# Customize layout
fig.update_layout(
    title=f"SARIMA Forecast for {target_column}",
    xaxis_title="Time",
    yaxis_title="Value",
    legend=dict(x=1.05, y=1),
    margin=dict(l=50, r=50, t=50, b=50),
    template="plotly",
    width=1500,
    height=450
)

# Show plot
fig.show()

In [82]:
sarima_metrics = pd.DataFrame(
    columns=["Metric", "Value"],
    data=[
        {"Metric": "RMSE", "Value": root_mean_squared_error(
            y_true=sarima_forecasts["actual"], y_pred=sarima_forecasts["mean"])},
        {"Metric": "MAE", "Value": mean_absolute_error(
            y_true=sarima_forecasts["actual"], y_pred=sarima_forecasts["mean"])},
        {"Metric": "MAPE", "Value": mean_absolute_percentage_error(
            y_true=sarima_forecasts["actual"], y_pred=sarima_forecasts["mean"])},
        {"Metric": "SMAPE", "Value": np.mean(np.abs(sarima_forecasts["actual"] - sarima_forecasts["mean"]) / (
            np.abs(sarima_forecasts["actual"]) + np.abs(sarima_forecasts["mean"])) / 2)},
        {"Metric": "MSE", "Value": np.mean(
            (sarima_forecasts["actual"] - sarima_forecasts["mean"]) ** 2)},
    ]
).set_index("Metric")

In [83]:
sarima_metrics

Unnamed: 0_level_0,Value
Metric,Unnamed: 1_level_1
RMSE,39.979538
MAE,26.200492
MAPE,0.853567
SMAPE,0.066294
MSE,1598.363479


In [84]:
sarima_forecasts.head()

Unnamed: 0,date,actual,mean,std
0,2022-07-01,314.38,303.158149,17.09631
1,2022-07-02,218.92,281.91548,17.095132
2,2022-07-03,200.11,234.283706,17.122559
3,2022-07-04,293.89,232.698931,17.12911
4,2022-07-05,318.37,304.099594,17.15512


In [85]:
sarima_forecasts.tail()

Unnamed: 0,date,actual,mean,std
754,2024-07-24,66.61,75.859862,22.453644
755,2024-07-25,78.34,74.965748,22.451608
756,2024-07-26,93.04,78.533397,22.449213
757,2024-07-27,80.74,81.927899,22.447786
758,2024-07-28,43.96,80.922192,22.445344


## Chronos Pipeline

In [9]:
chronos_model = ChronosPipeline.from_pretrained(
    "amazon/chronos-t5-large",
    device_map="mps",
    torch_dtype=torch.bfloat16,
)

In [10]:
# save the start time
start_time = datetime.now()

# create a list for storing the forecasts
chronos_forecasts = []

transformers.set_seed(42)

NUM_SAMPLES = 200

# loop across the dates
for t in tqdm(range(data.index.get_loc(start_date), data.index.get_loc(end_date) + 1)):

    # extract the context window
    context = data.iloc[:t]

    # generate the one-step-ahead forecast

    chronos_forecast = chronos_model.predict(
        context=torch.from_numpy(context.values),
        prediction_length=1,
        num_samples=NUM_SAMPLES
    ).detach().cpu().numpy().flatten()

    # save the forecasts
    chronos_forecasts.append({
        "date": data.index[t],
        "actual": data.values[t],
        "mean": np.mean(chronos_forecast),
        "std": np.std(chronos_forecast, ddof=1),
    })

# cast the forecasts to data frame
chronos_forecasts = pd.DataFrame(chronos_forecasts)

# save the end time
end_time = datetime.now()

print(f"\nRunning time of Chronos model: {end_time - start_time}")

100%|██████████| 759/759 [1:22:06<00:00,  6.49s/it]


Running time of Chronos model: 1:22:06.780198





In [28]:
context = data.iloc[:3827]
context

Unnamed: 0_level_0,Day_ahead_price (€/MWh),Solar_radiation (W/m2),Wind_speed (m/s),Temperature (°C),Biomass (GWh),Hard_coal (GWh),Hydro (GWh),Lignite (GWh),Natural_gas (GWh),Other (GWh),...,Lag_1_day,Lag_2_days,Lag_3_days,Lag_4_days,Lag_5_days,Lag_6_days,Lag_7_days,Day_of_week,Month,Rolling_mean_7
Date,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
2012-01-08,26.83,17.54,5.21,3.74,98.605,189.718,48.467,354.178,256.892,52.178,...,32.58,36.26,20.35,32.16,35.03,33.82,18.19,6,1,31.00
2012-01-09,47.91,13.04,4.24,3.80,98.605,344.154,49.054,382.756,282.438,60.752,...,26.83,32.58,36.26,20.35,32.16,35.03,33.82,0,1,33.02
2012-01-10,45.77,28.71,4.30,4.81,98.605,360.126,51.143,334.267,267.311,62.106,...,47.91,26.83,32.58,36.26,20.35,32.16,35.03,1,1,34.55
2012-01-11,47.83,21.58,4.08,5.14,98.605,360.330,50.693,385.000,277.343,60.862,...,45.77,47.91,26.83,32.58,36.26,20.35,32.16,2,1,36.79
2012-01-12,43.10,25.12,6.77,4.98,98.605,306.521,50.732,332.985,266.820,56.922,...,47.83,45.77,47.91,26.83,32.58,36.26,20.35,3,1,40.04
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2022-06-26,201.67,276.12,2.55,21.41,112.242,95.208,50.184,287.024,115.448,65.477,...,240.39,274.30,295.08,323.34,288.59,267.37,154.65,6,6,270.11
2022-06-27,316.65,226.29,3.22,20.08,112.200,155.084,52.595,318.812,208.367,64.947,...,201.67,240.39,274.30,295.08,323.34,288.59,267.37,0,6,277.15
2022-06-28,331.52,271.71,2.66,18.55,112.600,155.718,55.593,330.188,201.166,64.451,...,316.65,201.67,240.39,274.30,295.08,323.34,288.59,1,6,283.28
2022-06-29,315.54,219.42,2.61,19.88,112.771,152.185,57.691,334.479,211.518,66.015,...,331.52,316.65,201.67,240.39,274.30,295.08,323.34,2,6,282.16


In [12]:
chronos_forecasts.head()

Unnamed: 0,date,actual,mean,std
0,2022-07-01,314.38,312.135243,11.120469
1,2022-07-02,218.92,268.382704,14.436066
2,2022-07-03,200.11,189.05795,13.956886
3,2022-07-04,293.89,264.319289,16.543639
4,2022-07-05,318.37,301.928182,13.341038


In [13]:
chronos_forecasts.tail()

Unnamed: 0,date,actual,mean,std
754,2024-07-24,66.61,84.896215,7.078767
755,2024-07-25,78.34,78.436724,8.331457
756,2024-07-26,93.04,78.545977,8.027084
757,2024-07-27,80.74,63.325562,8.26518
758,2024-07-28,43.96,68.110365,7.153772


In [14]:
import plotly.graph_objs as go
import plotly.io as pio

# Create traces for actual and predicted values
trace_actual = go.Scatter(
    x=chronos_forecasts["date"].values,
    y=chronos_forecasts["actual"].values,
    mode='lines',
    name='Actual',
    line=dict(color='#3f4751', width=1)
)

trace_predicted = go.Scatter(
    x=chronos_forecasts["date"].values,
    y=chronos_forecasts["mean"].values,
    mode='lines',
    name='Predicted',
    line=dict(color='#009ad3', width=1)
)

# Create traces for confidence intervals
trace_std_1 = go.Scatter(
    x=chronos_forecasts["date"].values,
    y=chronos_forecasts["mean"].values + chronos_forecasts["std"].values,
    mode='lines',
    name='Predicted +/- 1 Std. Dev.',
    line=dict(color='#009ad3', width=0),
    fill='tonexty',
    fillcolor='rgba(0, 154, 211, 0.2)'
)

trace_std_1_neg = go.Scatter(
    x=chronos_forecasts["date"].values,
    y=chronos_forecasts["mean"].values - chronos_forecasts["std"].values,
    mode='lines',
    line=dict(color='#009ad3', width=0),
    showlegend=False,
    fill='tonexty',
    fillcolor='rgba(0, 154, 211, 0.2)'
)

trace_std_2 = go.Scatter(
    x=chronos_forecasts["date"].values,
    y=chronos_forecasts["mean"].values + 2 * chronos_forecasts["std"].values,
    mode='lines',
    name='Predicted +/- 2 Std. Dev.',
    line=dict(color='#009ad3', width=0),
    fill='tonexty',
    fillcolor='rgba(0, 154, 211, 0.1)'
)

trace_std_2_neg = go.Scatter(
    x=chronos_forecasts["date"].values,
    y=chronos_forecasts["mean"].values - 2 * chronos_forecasts["std"].values,
    mode='lines',
    line=dict(color='#009ad3', width=0),
    showlegend=False,
    fill='tonexty',
    fillcolor='rgba(0, 154, 211, 0.1)'
)

# Create the figure with all the traces
fig = go.Figure()
fig.add_trace(trace_actual)
fig.add_trace(trace_predicted)
fig.add_trace(trace_std_2)
fig.add_trace(trace_std_2_neg)
fig.add_trace(trace_std_1)
fig.add_trace(trace_std_1_neg)

# Set layout options
fig.update_layout(
    xaxis_title='Time',
    yaxis_title='Value',
    legend=dict(x=1.05, y=1),
    margin=dict(t=20, b=20, l=20, r=20),
    width=1200,
    height=450
)

# Show the plot
pio.show(fig)

In [20]:
chronos_metrics = pd.DataFrame(
    columns=["Metric", "Value"],
    data=[
        {"Metric": "RMSE", "Value": root_mean_squared_error(
            y_true=chronos_forecasts["actual"], y_pred=chronos_forecasts["mean"])},
        {"Metric": "MAE", "Value": mean_absolute_error(
            y_true=chronos_forecasts["actual"], y_pred=chronos_forecasts["mean"])},
        {"Metric": "MAPE", "Value": 100 * mean_absolute_percentage_error(
            y_true=chronos_forecasts["actual"], y_pred=chronos_forecasts["mean"])},
        {"Metric": "SMAPE", "Value": 100 * np.mean(np.abs(chronos_forecasts["actual"] - chronos_forecasts["mean"]) / (
            (np.abs(chronos_forecasts["actual"]) + np.abs(chronos_forecasts["mean"])) / 2))},
        {"Metric": "MSE", "Value": np.mean(
            (chronos_forecasts["actual"] - chronos_forecasts["mean"]) ** 2)},
    ]
).set_index("Metric")

In [21]:
chronos_metrics

Unnamed: 0_level_0,Value
Metric,Unnamed: 1_level_1
RMSE,34.911347
MAE,23.044318
MAPE,94.858059
SMAPE,23.111305
MSE,1218.80218


In [22]:
# Save Chronos metrics as csv
chronos_metrics.to_csv(f'chronos_metrics_{NUM_SAMPLES}.csv')