# Chronos vs. SARIMA: One-Step-Ahead Oil Price Forecasts

In [13]:
import warnings
import transformers
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from datetime import datetime
from chronos import ChronosPipeline
from pmdarima.arima import auto_arima
from statsmodels.tsa.statespace.sarimax import SARIMAX
from tqdm import tqdm
from sklearn.metrics import mean_absolute_percentage_error, mean_absolute_error, root_mean_squared_error
from darts import TimeSeries


In [14]:
# Import the data
def load_and_prepare_data(file_path, parse_dates=False):
    """
    Load a CSV file with a 'Date' column, sort it chronologically and index it by 'Date'.

    Parameters
    ----------
    file_path : str, path-like or file-like
        Anything accepted by ``pd.read_csv``.
    parse_dates : bool, default False
        If True, convert 'Date' to ``datetime64`` before sorting, so the returned
        index is a ``DatetimeIndex``. By default 'Date' is kept exactly as read
        from the file (strings for 'YYYY-MM-DD' data). In that case the
        lexicographic sort is chronological only for ISO-formatted dates.

    Returns
    -------
    pd.DataFrame
        The file's rows sorted by 'Date' (stable for rows sharing a date), with
        'Date' as the index.
    """
    df = pd.read_csv(file_path)
    if parse_dates:
        df["Date"] = pd.to_datetime(df["Date"])
    # A stable sort keeps the file order of rows that share the same date.
    return df.sort_values("Date", kind="stable").set_index("Date")

In [15]:
# Load the pre-split train/test frames and turn 'Date' back into an ordinary column.
# NOTE(review): neither frame is used by the forecasting cells below, which work on `df` / `data`.
train_df = load_and_prepare_data('../../data/Final_data/train_df.csv').reset_index()
test_df = load_and_prepare_data('../../data/Final_data/test_df.csv').reset_index()


In [27]:
# Inspect the held-out test frame (display only; it is not referenced by the cells below).
test_df

Unnamed: 0,Date,Day_ahead_price (€/MWh),Solar_radiation (W/m2),Wind_speed (m/s),Temperature (°C),Biomass (GWh),Hard_coal (GWh),Hydro (GWh),Lignite (GWh),Natural_gas (GWh),...,Lag_1_day,Lag_2_days,Lag_3_days,Lag_4_days,Lag_5_days,Lag_6_days,Lag_7_days,Day_of_week,Month,Rolling_mean_7
0,2022-07-01,314.38,127.46,3.42,15.87,111.601,168.773,52.930,332.065,182.244,...,325.48,315.54,331.52,316.65,201.67,240.39,274.30,4,7,292.23
1,2022-07-02,218.92,339.67,2.75,18.16,112.222,115.877,46.251,258.776,91.636,...,314.38,325.48,315.54,331.52,316.65,201.67,240.39,5,7,289.17
2,2022-07-03,200.11,318.04,2.86,21.03,112.037,90.515,41.357,277.929,85.647,...,218.92,314.38,325.48,315.54,331.52,316.65,201.67,6,7,288.94
3,2022-07-04,293.89,286.96,3.08,19.14,112.874,140.425,43.641,311.034,123.629,...,200.11,218.92,314.38,325.48,315.54,331.52,316.65,0,7,285.69
4,2022-07-05,318.37,287.62,3.08,18.56,112.856,182.225,50.027,302.135,146.554,...,293.89,200.11,218.92,314.38,325.48,315.54,331.52,1,7,283.81
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
754,2024-07-24,66.61,225.04,3.47,17.54,110.007,43.469,85.857,199.246,194.291,...,79.62,88.75,58.45,59.32,86.47,90.75,76.79,2,7,75.71
755,2024-07-25,78.34,272.71,2.12,17.85,110.410,50.676,82.632,195.983,209.610,...,66.61,79.62,88.75,58.45,59.32,86.47,90.75,3,7,73.94
756,2024-07-26,93.04,172.33,2.60,19.09,110.852,42.333,79.531,205.273,205.773,...,78.34,66.61,79.62,88.75,58.45,59.32,86.47,4,7,74.88
757,2024-07-27,80.74,176.67,2.05,19.63,110.479,33.307,74.958,184.012,216.412,...,93.04,78.34,66.61,79.62,88.75,58.45,59.32,5,7,77.94


In [16]:
# Load the full dataset. 'Date' is returned to an ordinary column because the next
# cells select it explicitly.
df = load_and_prepare_data('../../data/Final_data/final_data_july.csv').reset_index()

# Column holding the series to forecast.
target_column = "Oil_price (EUR)"

In [28]:
# Inclusive forecast window: a one-step-ahead forecast is produced for every date from
# start_date to end_date. Both dates must be present in `data`'s index (looked up with get_loc).
# NOTE(review): the displayed test_df ends on 2024-07-27, one day before end_date — confirm intent.
start_date, end_date = "2022-07-01", "2024-07-28"

In [29]:
# Univariate target as a pandas Series indexed by 'Date' and named after the target column.
# Selecting the column directly (rather than DataFrame.squeeze) always yields a Series,
# even for a single-row frame.
data = df.set_index("Date")[target_column]

# The rolling-origin loops below locate dates with Index.get_loc and slice positionally
# with iloc, so the index must be unique and sorted.
assert data.index.is_unique, "duplicate dates in the target series"
assert data.index.is_monotonic_increasing, "target series is not sorted by date"

data

Date
2012-01-08    103.71
2012-01-09    103.64
2012-01-10    104.22
2012-01-11    103.93
2012-01-12    102.26
               ...  
2024-07-24     75.75
2024-07-25     76.36
2024-07-26     75.21
2024-07-27     74.79
2024-07-28     74.37
Name: Oil_price (EUR), Length: 4586, dtype: float64

In [30]:
# Length of the full target series: the history before start_date plus the forecast window.
data.shape

(4586,)

## SARIMAX

In [31]:
# Select the (seasonal) ARIMA orders once. Only observations strictly before the first
# forecast date are used, so the order search never sees the evaluation window.
history_before_forecasts = data[data.index < start_date]

best_sarima_model = auto_arima(
    y=history_before_forecasts,
    seasonal=True,
    # NOTE(review): on daily data m=12 is a 12-day seasonal cycle; confirm this is
    # intended rather than e.g. a weekly (7-day) period.
    m=12,
    start_p=0,
    start_P=0,
    start_q=0,
    start_Q=0,
)

print(best_sarima_model.summary())

                                     SARIMAX Results                                      
Dep. Variable:                                  y   No. Observations:                 3827
Model:             SARIMAX(0, 1, 1)x(1, 0, 1, 12)   Log Likelihood               -5203.850
Date:                            Mon, 14 Oct 2024   AIC                          10415.700
Time:                                    12:42:36   BIC                          10440.698
Sample:                                01-08-2012   HQIC                         10424.581
                                     - 06-30-2022                                         
Covariance Type:                              opg                                         
                 coef    std err          z      P>|z|      [0.025      0.975]
------------------------------------------------------------------------------
ma.L1          0.1434      0.008     17.518      0.000       0.127       0.159
ar.S.L12      -0.8406      0.055   

In [32]:
# Rolling-origin evaluation. For each date in [start_date, end_date]:
# - refit SARIMA (orders chosen above) on every observation strictly before that date;
# - record the one-step-ahead forecast mean and standard deviation.
first_idx = data.index.get_loc(start_date)
last_idx = data.index.get_loc(end_date)

records = []
for t in tqdm(range(first_idx, last_idx + 1)):
    history = data.iloc[:t]  # excludes position t, so there is no look-ahead

    # Warnings raised while building/fitting the model are silenced.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        fitted = SARIMAX(
            endog=history,
            order=best_sarima_model.order,
            seasonal_order=best_sarima_model.seasonal_order,
            trend="c" if best_sarima_model.with_intercept else None,
        ).fit(disp=0)

    one_step = fitted.get_forecast(steps=1)
    records.append(
        dict(
            date=data.index[t],
            actual=data.values[t],
            mean=one_step.predicted_mean.item(),
            std=one_step.var_pred_mean.item() ** 0.5,
        )
    )

sarima_forecasts = pd.DataFrame(records)

100%|██████████| 759/759 [07:52<00:00,  1.61it/s]


In [43]:
import plotly.graph_objects as go

# Shaded predictive bands to draw, as {number of standard deviations: fill opacity}.
# Add `2: 0.1` to also draw the +/- 2 std. dev. band.
SARIMA_BANDS = {1: 0.2}

fig = go.Figure()

# Actual values
fig.add_trace(go.Scatter(
    x=sarima_forecasts["date"],
    y=sarima_forecasts["actual"],
    mode="lines",
    line=dict(color="#3f4751", width=1),
    name="Actual",
))

# One-step-ahead forecast mean
fig.add_trace(go.Scatter(
    x=sarima_forecasts["date"],
    y=sarima_forecasts["mean"],
    mode="lines",
    line=dict(color="#ca8a04", width=1),
    name="Predicted",
))

# Each band is an unlabelled upper edge followed immediately by a lower edge.
# The lower edge uses fill="tonexty", which fills back to the previously added
# trace, so each pair must be added consecutively.
for n_std, opacity in SARIMA_BANDS.items():
    fig.add_trace(go.Scatter(
        x=sarima_forecasts["date"],
        y=sarima_forecasts["mean"] + n_std * sarima_forecasts["std"],
        fill=None,
        mode="lines",
        line=dict(color="#ca8a04", width=0.5),
        showlegend=False,
    ))
    fig.add_trace(go.Scatter(
        x=sarima_forecasts["date"],
        y=sarima_forecasts["mean"] - n_std * sarima_forecasts["std"],
        fill="tonexty",
        mode="lines",
        line=dict(color="#ca8a04", width=0.5),
        name=f"Predicted +/- {n_std} Std. Dev.",
        opacity=opacity,
    ))

fig.update_layout(
    title="SARIMA Forecast for Oil Price",
    xaxis_title="Time",
    yaxis_title=target_column,
    legend=dict(x=1.05, y=1),
    margin=dict(l=50, r=50, t=50, b=50),
    template="plotly_white",
    width=800,
    height=450,
)

fig.show()


In [44]:
# Point-forecast accuracy of the SARIMA one-step-ahead means over the forecast window.
sarima_actual = sarima_forecasts["actual"]
sarima_pred = sarima_forecasts["mean"]
sarima_metrics = pd.DataFrame(
    {
        "Metric": ["RMSE", "MAE"],
        "Value": [
            root_mean_squared_error(y_true=sarima_actual, y_pred=sarima_pred),
            mean_absolute_error(y_true=sarima_actual, y_pred=sarima_pred),
        ],
    }
).set_index("Metric")

In [45]:
# Display the SARIMA RMSE/MAE table.
sarima_metrics

Unnamed: 0_level_0,Value
Metric,Unnamed: 1_level_1
RMSE,1.171477
MAE,0.764133


In [39]:
# First SARIMA one-step-ahead forecasts (columns: date, actual, mean, std).
sarima_forecasts.head()

Unnamed: 0,date,actual,mean,std
0,2022-07-01,100.14,105.764671,0.942853
1,2022-07-02,100.88,99.358515,0.947096
2,2022-07-03,101.61,101.196429,0.947291
3,2022-07-04,102.34,101.59519,0.94719
4,2022-07-05,93.72,102.676857,0.947146


In [40]:
# Last SARIMA one-step-ahead forecasts; the loop runs through end_date inclusive.
sarima_forecasts.tail()

Unnamed: 0,date,actual,mean,std
754,2024-07-24,75.75,74.912822,0.985543
755,2024-07-25,76.36,75.870638,0.985513
756,2024-07-26,75.21,76.430252,0.985432
757,2024-07-27,74.79,75.033324,0.985489
758,2024-07-28,74.37,74.746351,0.985388


## Chronos Pipeline

In [37]:
# Pick the available accelerator instead of hard-coding Apple's MPS backend, so the
# notebook also runs on CUDA machines and on plain CPUs.
if torch.cuda.is_available():
    chronos_device = "cuda"
elif torch.backends.mps.is_available():
    chronos_device = "mps"
else:
    chronos_device = "cpu"

chronos_model = ChronosPipeline.from_pretrained(
    "amazon/chronos-t5-large",
    device_map=chronos_device,
    torch_dtype=torch.bfloat16,
)

In [38]:
# Rolling-origin evaluation of Chronos. For each date in [start_date, end_date]:
# - pass the model every observation strictly before that date;
# - summarise 100 sampled one-step-ahead forecasts by their mean and sample std.
start_time = datetime.now()

# Rows are collected in a separate list so that `chronos_forecasts` is always a
# DataFrame, even if the (slow) loop is interrupted part-way through.
chronos_records = []
try:
    for t in tqdm(range(data.index.get_loc(start_date), data.index.get_loc(end_date) + 1)):
        context = data.iloc[:t]  # excludes position t, so there is no look-ahead

        # Re-seed before every call so each date's samples are reproducible.
        transformers.set_seed(42)
        samples = chronos_model.predict(
            context=torch.from_numpy(context.values),
            prediction_length=1,
            num_samples=100,
        ).detach().cpu().numpy().flatten()

        chronos_records.append({
            "date": data.index[t],
            "actual": data.values[t],
            "mean": np.mean(samples),
            "std": np.std(samples, ddof=1),
        })
finally:
    # Runs on completion and on interruption; any exception still propagates afterwards.
    # Explicit columns keep downstream column lookups valid even with zero rows.
    chronos_forecasts = pd.DataFrame(chronos_records, columns=["date", "actual", "mean", "std"])
    end_time = datetime.now()
    print(f"\nRunning time of Chronos model: {end_time - start_time}")

  1%|          | 9/759 [00:55<1:16:46,  6.14s/it]


KeyboardInterrupt: 

In [None]:
# (number of completed one-step-ahead forecasts, 4 columns: date, actual, mean, std)
chronos_forecasts.shape

In [None]:
# First Chronos one-step-ahead forecasts, same layout as sarima_forecasts.
chronos_forecasts.head()

In [None]:
# Chronos one-step-ahead forecasts against the actual series, with +/- 1 and +/- 2
# sample-standard-deviation bands.
# The dates come from the CSV as strings. Converting them gives matplotlib a real date
# axis instead of a categorical axis with one tick label per forecast date.
chronos_dates = pd.to_datetime(chronos_forecasts["date"]).values
chronos_actual = chronos_forecasts["actual"].values
chronos_mean = chronos_forecasts["mean"].values
chronos_std = chronos_forecasts["std"].values

fig, ax = plt.subplots(figsize=(8, 4.5))
ax.plot(chronos_dates, chronos_actual, color="#3f4751", lw=1, label="Actual")
ax.plot(chronos_dates, chronos_mean, color="#009ad3", lw=1, label="Predicted")
for n_std, alpha in ((1, 0.2), (2, 0.1)):
    ax.fill_between(
        chronos_dates,
        chronos_mean + n_std * chronos_std,
        chronos_mean - n_std * chronos_std,
        color="#009ad3",
        alpha=alpha,
        lw=1,
        label=f"Predicted +/- {n_std} Std. Dev.",
    )
ax.set(title="Chronos Forecast for Oil Price", xlabel="Time", ylabel=target_column)
ax.tick_params(axis="both", which="both", labelsize=7)
fig.legend(bbox_to_anchor=(1, 0, 0.3, 1), frameon=False)
fig.tight_layout()
# plt.show() works with the inline backend; Figure.show() only warns there.
plt.show()


In [None]:
# Point-forecast accuracy of the Chronos one-step-ahead means over the forecast window.
chronos_actual_values = chronos_forecasts["actual"]
chronos_pred_values = chronos_forecasts["mean"]
chronos_metrics = pd.DataFrame(
    {
        "Metric": ["RMSE", "MAE"],
        "Value": [
            root_mean_squared_error(y_true=chronos_actual_values, y_pred=chronos_pred_values),
            mean_absolute_error(y_true=chronos_actual_values, y_pred=chronos_pred_values),
        ],
    }
).set_index("Metric")

In [None]:
# Display the Chronos RMSE/MAE table (same layout as sarima_metrics).
chronos_metrics

In [None]:
# 30-day-ahead probabilistic forecast from the end of the evaluation window.
# It uses the Chronos model loaded above and the target history up to and including end_date.
# Univariate: only the target history is passed as context (no covariates).
transformers.set_seed(42)
forecast_context = data.iloc[: data.index.get_loc(end_date) + 1]
test_forecast = chronos_model.predict(
    context=torch.from_numpy(forecast_context.values),
    prediction_length=30,  # number of days ahead to predict
    num_samples=100,  # number of sample paths drawn from the model
)
