In [1]:
import os
from functools import partial

import pandas as pd
import statsforecast
from statsforecast import StatsForecast
from statsforecast.feature_engineering import mstl_decomposition
from statsforecast.models import ARIMA, MSTL
from utilsforecast.evaluation import evaluate
from utilsforecast.losses import smape, mase


  from tqdm.autonotebook import tqdm


In [2]:
os.environ['NIXTLA_ID_AS_COL'] = '1'

In [6]:
df = pd.read_parquet('https://datasets-nixtla.s3.amazonaws.com/m4-hourly.parquet')
uids = df['unique_id'].unique()[:10]
df = df[df['unique_id'].isin(uids)]
df.head()


Unnamed: 0,unique_id,ds,y
0,H1,1,605.0
1,H1,2,586.0
2,H1,3,586.0
3,H1,4,559.0
4,H1,5,511.0


In [7]:
freq = 1
season_length = 24
horizon = 2 * season_length
valid = df.groupby('unique_id').tail(horizon)
train = df.drop(valid.index)
model = MSTL(season_length=24)
transformed_df, X_df = mstl_decomposition(train, model=model, freq=freq, h=horizon)



In [8]:
transformed_df.head()


Unnamed: 0,unique_id,ds,y,trend,seasonal
0,H1,1,605.0,502.87291,131.419934
1,H1,2,586.0,507.873456,93.100015
2,H1,3,586.0,512.822533,82.155386
3,H1,4,559.0,517.717481,42.412749
4,H1,5,511.0,522.555849,-11.40189


In [9]:
X_df.head()


Unnamed: 0,unique_id,ds,trend,seasonal
0,H1,701,643.801348,-29.189627
1,H1,702,644.328207,-99.680432
2,H1,703,644.749693,-141.169014
3,H1,704,645.086883,-173.325625
4,H1,705,645.356634,-195.86253


In [10]:
sf = StatsForecast(
    models=[ARIMA(order=(1, 0, 1), season_length=season_length)],
    freq=freq
)
preds = sf.forecast(h=horizon, df=transformed_df, X_df=X_df)
preds.head()

Unnamed: 0,unique_id,ds,ARIMA
0,H1,701,612.737671
1,H1,702,542.851807
2,H1,703,501.931824
3,H1,704,470.24826
4,H1,705,448.115814


In [11]:
transformed_df

Unnamed: 0,unique_id,ds,y,trend,seasonal
0,H1,1,605.0,502.872910,131.419934
1,H1,2,586.0,507.873456,93.100015
2,H1,3,586.0,512.822533,82.155386
3,H1,4,559.0,517.717481,42.412749
4,H1,5,511.0,522.555849,-11.401890
...,...,...,...,...,...
6995,H107,696,4708.0,3947.720625,676.891540
6996,H107,697,4451.0,3955.741399,530.573828
6997,H107,698,4303.0,3963.834683,382.338985
6998,H107,699,4207.0,3971.979313,274.809658
