# Chronos

In [1]:
import warnings
import transformers
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from datetime import datetime
from chronos import ChronosPipeline
from pmdarima.arima import auto_arima
from statsmodels.tsa.statespace.sarimax import SARIMAX
from tqdm import tqdm
from sklearn.metrics import mean_absolute_percentage_error, mean_absolute_error, root_mean_squared_error
from darts import TimeSeries

In [2]:
# Import the data
def load_and_prepare_data(file_path):
    """
    Load energy prices data from a CSV file, ensure chronological order, and convert 'Date' to datetime.
    """
    df = pd.read_csv(file_path)
    df.sort_values('Date', inplace=True)
    df.set_index('Date', inplace=True)
    df = pd.DataFrame(df)
    return df

In [3]:
train_df = load_and_prepare_data('../../data/Final_data/train_df.csv')
test_df = load_and_prepare_data('../../data/Final_data/test_df.csv')

train_df.reset_index(inplace=True)
test_df.reset_index(inplace=True)

In [4]:
# Import the data
df = load_and_prepare_data('../../data/Final_data/final_data_july.csv')

# Reset the index
df = df.reset_index()

target_column = "Day_ahead_price (€/MWh)"

In [5]:
# date of first forecast
start_date = "2022-07-01"

# date of last forecast
end_date = "2024-07-28"

In [6]:
# create a TimeSeries object from df
data = df[['Date', target_column]]
# transform to a pandas series
data = data.set_index('Date')
data = data.squeeze()
data

Date
2012-01-08    26.83
2012-01-09    47.91
2012-01-10    45.77
2012-01-11    47.83
2012-01-12    43.10
              ...  
2024-07-24    66.61
2024-07-25    78.34
2024-07-26    93.04
2024-07-27    80.74
2024-07-28    43.96
Name: Day_ahead_price (€/MWh), Length: 4586, dtype: float64

In [7]:
data.shape

(4586,)

# Sarimax

In [8]:
best_sarima_model = auto_arima(
    y=data[data.index < start_date],
    start_p=0,
    start_q=0,
    start_P=0,
    start_Q=0,
    m=12,
    seasonal=True,
)

print(best_sarima_model.summary())

KeyboardInterrupt: 

In [68]:
# create a list for storing the forecasts
sarima_forecasts = []

# loop across the dates
for t in tqdm(range(data.index.get_loc(start_date), data.index.get_loc(end_date) + 1)):

    # extract the training data
    context = data.iloc[:t]

    # train the model
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        sarima_model = SARIMAX(
            endog=context,
            order=best_sarima_model.order,
            seasonal_order=best_sarima_model.seasonal_order,
            trend="c" if best_sarima_model.with_intercept else None,
        ).fit(disp=0)

    # generate the one-step-ahead forecast
    sarima_forecast = sarima_model.get_forecast(steps=1)

    # save the forecast
    sarima_forecasts.append({
        "date": data.index[t],
        "actual": data.values[t],
        "mean": sarima_forecast.predicted_mean.item(),
        "std": sarima_forecast.var_pred_mean.item() ** 0.5,
    })

# cast the forecasts to data frame
sarima_forecasts = pd.DataFrame(sarima_forecasts)

100%|██████████| 759/759 [1:23:33<00:00,  6.61s/it]    


In [69]:
import plotly.graph_objects as go

fig = go.Figure()

# Add actual data trace
fig.add_trace(go.Scatter(
    x=sarima_forecasts["date"],
    y=sarima_forecasts["actual"],
    mode='lines',
    line=dict(color="#3f4751", width=1),
    name="Actual"
))

# Add predicted mean trace
fig.add_trace(go.Scatter(
    x=sarima_forecasts["date"],
    y=sarima_forecasts["mean"],
    mode='lines',
    line=dict(color="#ca8a04", width=1),
    name="Predicted"
))

# Add predicted +/- 1 standard deviation band
fig.add_trace(go.Scatter(
    x=sarima_forecasts["date"],
    y=sarima_forecasts["mean"] + sarima_forecasts["std"],
    fill=None,
    mode='lines',
    line=dict(color="#ca8a04", width=0.5),
    showlegend=False
))

fig.add_trace(go.Scatter(
    x=sarima_forecasts["date"],
    y=sarima_forecasts["mean"] - sarima_forecasts["std"],
    fill='tonexty',  # Fill to previous trace
    mode='lines',
    line=dict(color="#ca8a04", width=0.5),
    name="Predicted +/- 1 Std. Dev.",
    opacity=0.2
))

# Customize layout
fig.update_layout(
    title=f"SARIMA Forecast for {target_column}",
    xaxis_title="Time",
    yaxis_title="Value",
    legend=dict(x=1.05, y=1),
    margin=dict(l=50, r=50, t=50, b=50),
    template="plotly",
    width=1500,
    height=450
)

# Show plot
fig.show()

In [82]:
sarima_metrics = pd.DataFrame(
    columns=["Metric", "Value"],
    data=[
        {"Metric": "RMSE", "Value": root_mean_squared_error(
            y_true=sarima_forecasts["actual"], y_pred=sarima_forecasts["mean"])},
        {"Metric": "MAE", "Value": mean_absolute_error(
            y_true=sarima_forecasts["actual"], y_pred=sarima_forecasts["mean"])},
        {"Metric": "MAPE", "Value": mean_absolute_percentage_error(
            y_true=sarima_forecasts["actual"], y_pred=sarima_forecasts["mean"])},
        {"Metric": "SMAPE", "Value": np.mean(np.abs(sarima_forecasts["actual"] - sarima_forecasts["mean"]) / (
            np.abs(sarima_forecasts["actual"]) + np.abs(sarima_forecasts["mean"])) / 2)},
        {"Metric": "MSE", "Value": np.mean(
            (sarima_forecasts["actual"] - sarima_forecasts["mean"]) ** 2)},
    ]
).set_index("Metric")

In [83]:
sarima_metrics

Unnamed: 0_level_0,Value
Metric,Unnamed: 1_level_1
RMSE,39.979538
MAE,26.200492
MAPE,0.853567
SMAPE,0.066294
MSE,1598.363479


In [84]:
sarima_forecasts.head()

Unnamed: 0,date,actual,mean,std
0,2022-07-01,314.38,303.158149,17.09631
1,2022-07-02,218.92,281.91548,17.095132
2,2022-07-03,200.11,234.283706,17.122559
3,2022-07-04,293.89,232.698931,17.12911
4,2022-07-05,318.37,304.099594,17.15512


In [85]:
sarima_forecasts.tail()

Unnamed: 0,date,actual,mean,std
754,2024-07-24,66.61,75.859862,22.453644
755,2024-07-25,78.34,74.965748,22.451608
756,2024-07-26,93.04,78.533397,22.449213
757,2024-07-27,80.74,81.927899,22.447786
758,2024-07-28,43.96,80.922192,22.445344


## Chronos Pipeline

In [8]:
chronos_model = ChronosPipeline.from_pretrained(
    "amazon/chronos-t5-large",
    device_map="mps",
    torch_dtype=torch.bfloat16,
)

In [9]:
# save the start time
start_time = datetime.now()

# create a list for storing the forecasts
chronos_forecasts = []

transformers.set_seed(42)

NUM_SAMPLES = 200

# loop across the dates
for t in tqdm(range(data.index.get_loc(start_date), data.index.get_loc(end_date) + 1)):

    # extract the context window
    context = data.iloc[:t]

    # generate the one-step-ahead forecast

    chronos_forecast = chronos_model.predict(
        context=torch.from_numpy(context.values),
        prediction_length=1,
        num_samples=NUM_SAMPLES
    ).detach().cpu().numpy().flatten()

    # save the forecasts
    chronos_forecasts.append({
        "date": data.index[t],
        "actual": data.values[t],
        "mean": np.mean(chronos_forecast),
        "std": np.std(chronos_forecast, ddof=1),
    })

# cast the forecasts to data frame
chronos_forecasts = pd.DataFrame(chronos_forecasts)

# save the end time
end_time = datetime.now()

print(f"\nRunning time of Chronos model: {end_time - start_time}")

100%|██████████| 759/759 [3:24:42<00:00, 16.18s/it]  


Running time of Chronos model: 3:24:42.789220





In [10]:
context = data.iloc[:3827]
context

Date
2012-01-08     26.83
2012-01-09     47.91
2012-01-10     45.77
2012-01-11     47.83
2012-01-12     43.10
               ...  
2022-06-26    201.67
2022-06-27    316.65
2022-06-28    331.52
2022-06-29    315.54
2022-06-30    325.48
Name: Day_ahead_price (€/MWh), Length: 3827, dtype: float64

In [11]:
chronos_forecasts.head()

Unnamed: 0,date,actual,mean,std
0,2022-07-01,314.38,313.131665,12.720491
1,2022-07-02,218.92,267.421053,14.782506
2,2022-07-03,200.11,188.621953,13.640906
3,2022-07-04,293.89,264.571808,16.175186
4,2022-07-05,318.37,301.065168,13.961436


In [12]:
chronos_forecasts.tail()

Unnamed: 0,date,actual,mean,std
754,2024-07-24,66.61,84.920032,7.230547
755,2024-07-25,78.34,79.262749,8.652299
756,2024-07-26,93.04,77.795405,7.74718
757,2024-07-27,80.74,63.346303,8.276722
758,2024-07-28,43.96,67.684268,7.487943


In [13]:
import plotly.graph_objs as go
import plotly.io as pio

# Create traces for actual and predicted values
trace_actual = go.Scatter(
    x=chronos_forecasts["date"].values,
    y=chronos_forecasts["actual"].values,
    mode='lines',
    name='Actual',
    line=dict(color='#3f4751', width=1)
)

trace_predicted = go.Scatter(
    x=chronos_forecasts["date"].values,
    y=chronos_forecasts["mean"].values,
    mode='lines',
    name='Predicted',
    line=dict(color='#009ad3', width=1)
)

# Create traces for confidence intervals
trace_std_1 = go.Scatter(
    x=chronos_forecasts["date"].values,
    y=chronos_forecasts["mean"].values + chronos_forecasts["std"].values,
    mode='lines',
    name='Predicted +/- 1 Std. Dev.',
    line=dict(color='#009ad3', width=0),
    fill='tonexty',
    fillcolor='rgba(0, 154, 211, 0.2)'
)

trace_std_1_neg = go.Scatter(
    x=chronos_forecasts["date"].values,
    y=chronos_forecasts["mean"].values - chronos_forecasts["std"].values,
    mode='lines',
    line=dict(color='#009ad3', width=0),
    showlegend=False,
    fill='tonexty',
    fillcolor='rgba(0, 154, 211, 0.2)'
)

trace_std_2 = go.Scatter(
    x=chronos_forecasts["date"].values,
    y=chronos_forecasts["mean"].values + 2 * chronos_forecasts["std"].values,
    mode='lines',
    name='Predicted +/- 2 Std. Dev.',
    line=dict(color='#009ad3', width=0),
    fill='tonexty',
    fillcolor='rgba(0, 154, 211, 0.1)'
)

trace_std_2_neg = go.Scatter(
    x=chronos_forecasts["date"].values,
    y=chronos_forecasts["mean"].values - 2 * chronos_forecasts["std"].values,
    mode='lines',
    line=dict(color='#009ad3', width=0),
    showlegend=False,
    fill='tonexty',
    fillcolor='rgba(0, 154, 211, 0.1)'
)

# Create the figure with all the traces
fig = go.Figure()
fig.add_trace(trace_actual)
fig.add_trace(trace_predicted)
fig.add_trace(trace_std_2)
fig.add_trace(trace_std_2_neg)
fig.add_trace(trace_std_1)
fig.add_trace(trace_std_1_neg)

# Set layout options
fig.update_layout(
    xaxis_title='Time',
    yaxis_title='Value',
    legend=dict(x=1.05, y=1),
    margin=dict(t=20, b=20, l=20, r=20),
    width=1200,
    height=450
)

# Show the plot
pio.show(fig)

In [14]:
import pandas as pd
import numpy as np
from darts import TimeSeries
from darts.metrics import rmse, mae, mape, smape, mse

# Convert the chronos_forecasts DataFrame to Darts TimeSeries objects
actual_series = TimeSeries.from_series(chronos_forecasts["actual"])
predicted_series = TimeSeries.from_series(chronos_forecasts["mean"])

# Calculate the metrics using Darts functions
rmse_value = rmse(actual_series, predicted_series)
mae_value = mae(actual_series, predicted_series)
mape_value = mape(actual_series, predicted_series)
smape_value = smape(actual_series, predicted_series)
mse_value = mse(actual_series, predicted_series)

# Create the metrics DataFrame
chronos_metrics = pd.DataFrame(
    columns=["Metric", "Value"],
    data=[
        {"Metric": "RMSE", "Value": rmse_value},
        {"Metric": "MAE", "Value": mae_value},
        {"Metric": "MAPE", "Value": mape_value},
        {"Metric": "SMAPE", "Value": smape_value},
        {"Metric": "MSE", "Value": mse_value},
    ]
).set_index("Metric")

# Print the DataFrame
print(chronos_metrics)

              Value
Metric             
RMSE      34.862928
MAE       23.033619
MAPE      95.728846
SMAPE     23.087482
MSE     1215.423723


In [15]:
chronos_metrics

Unnamed: 0_level_0,Value
Metric,Unnamed: 1_level_1
RMSE,34.862928
MAE,23.033619
MAPE,95.728846
SMAPE,23.087482
MSE,1215.423723


In [16]:
# Save Chronos metrics as csv
chronos_metrics.to_csv(f'chronos_metrics_with_lags_{NUM_SAMPLES}.csv')