# sample for pytorch forecasting
- c.f. https://towardsdatascience.com/introducing-pytorch-forecasting-64de99b9ef46

In [1]:
import numpy as np
import pandas as pd

In [2]:
from pytorch_forecasting.data.examples import get_stallion_data
data = get_stallion_data()  # load data as pandas dataframe


  iprot._fast_decode(self, iprot, [self.__class__, self.thrift_spec])
  iprot._fast_decode(self, iprot, [self.__class__, self.thrift_spec])
  iprot._fast_decode(self, iprot, [self.__class__, self.thrift_spec])


In [3]:
data.date.drop_duplicates().head(3)

  and should_run_async(code)


0      2013-01-01
7233   2013-02-01
9011   2013-03-01
Name: date, dtype: datetime64[ns]

In [4]:
# add time index
data["time_idx"] = data["date"].dt.year * 12 + data["date"].dt.month
data["time_idx"] -= data["time_idx"].min()
# add additional features
# categories have to be strings
data["month"] = data.date.dt.month.astype(str).astype("category")
data["log_volume"] = np.log(data.volume + 1e-8)
data["avg_volume_by_sku"] = (
    data
    .groupby(["time_idx", "sku"], observed=True)
    .volume.transform("mean")
)
data["avg_volume_by_agency"] = (
    data
    .groupby(["time_idx", "agency"], observed=True)
    .volume.transform("mean")
)
# we want to encode special days as one variable and 
# thus need to first reverse one-hot encoding
special_days = [
    "easter_day", "good_friday", "new_year", "christmas",
    "labor_day", "independence_day", "revolution_day_memorial",
    "regional_games", "fifa_u_17_world_cup", "football_gold_cup",
    "beer_capital", "music_fest"
]
data[special_days] = (
    data[special_days]
    .apply(lambda x: x.map({0: "-", 1: x.name}))
    .astype("category")
)
# show sample data
data.sample(10, random_state=521)

Unnamed: 0,agency,sku,volume,date,industry_volume,soda_volume,avg_max_temp,price_regular,price_actual,discount,...,football_gold_cup,beer_capital,music_fest,discount_in_percent,timeseries,time_idx,month,log_volume,avg_volume_by_sku,avg_volume_by_agency
291,Agency_25,SKU_03,0.5076,2013-01-01,492612703,718394219,25.845238,1264.162234,1152.473405,111.688829,...,-,-,-,8.835008,228,0,1,-0.678062,1225.306376,99.6504
871,Agency_29,SKU_02,8.748,2015-01-01,498567142,762225057,27.584615,1316.098485,1296.804924,19.293561,...,-,-,-,1.465966,177,24,1,2.168825,1634.434615,11.397086
19532,Agency_47,SKU_01,4.968,2013-09-01,454252482,789624076,30.665957,1269.25,1266.49049,2.75951,...,-,-,-,0.217413,322,8,9,1.603017,2625.472644,48.29565
2089,Agency_53,SKU_07,21.6825,2013-10-01,480693900,791658684,29.197727,1193.842373,1128.124395,65.717978,...,-,beer_capital,-,5.504745,240,9,10,3.076505,38.529107,2511.035175
9755,Agency_17,SKU_02,960.552,2015-03-01,515468092,871204688,23.60812,1338.334248,1232.128069,106.206179,...,-,-,music_fest,7.935699,259,26,3,6.867508,2143.677462,396.02214
7561,Agency_05,SKU_03,1184.6535,2014-02-01,425528909,734443953,28.668254,1369.556376,1161.135214,208.421162,...,-,-,-,15.218151,21,13,2,7.077206,1566.643589,1881.866367
19204,Agency_11,SKU_05,5.5593,2017-08-01,623319783,1049868815,31.915385,1922.486644,1651.307674,271.17897,...,-,-,-,14.105636,17,55,8,1.715472,1385.225478,109.6992
8781,Agency_48,SKU_04,4275.1605,2013-03-01,509281531,892192092,26.767857,1761.258209,1546.05967,215.198539,...,-,-,music_fest,12.218455,151,2,3,8.360577,1757.950603,1925.272108
2540,Agency_07,SKU_21,0.0,2015-10-01,544203593,761469815,28.987755,0.0,0.0,0.0,...,-,-,-,0.0,300,33,10,-18.420681,0.0,2418.71955
12084,Agency_21,SKU_03,46.3608,2017-04-01,589969396,940912941,32.47891,1675.922116,1413.571789,262.350327,...,-,-,-,15.654088,181,51,4,3.836454,2034.293024,109.3818


In [5]:
# data[special_days].apply(lambda x: x.map({0: "-", 1: x.name})).astype("category")
data[special_days].head(3)
# .apply(lambda x: print(type(x)))
# len(special_days)


  and should_run_async(code)


Unnamed: 0,easter_day,good_friday,new_year,christmas,labor_day,independence_day,revolution_day_memorial,regional_games,fifa_u_17_world_cup,football_gold_cup,beer_capital,music_fest
0,-,-,new_year,-,-,-,-,-,-,-,-,-
238,-,-,new_year,-,-,-,-,-,-,-,-,-
237,-,-,new_year,-,-,-,-,-,-,-,-,-


In [6]:
data["avg_population_2017"]

  and should_run_async(code)


0         48151
238       32769
237     1219986
236      135561
235     3044268
         ...   
6765      71662
6764    2180611
6763      48146
6771    2180611
6650    1901290
Name: avg_population_2017, Length: 21000, dtype: int64

In [7]:
data[["agency", "sku"]].drop_duplicates().head(3)

Unnamed: 0,agency,sku
0,Agency_22,SKU_01
238,Agency_37,SKU_04
237,Agency_59,SKU_03


In [8]:
from pytorch_forecasting.data import (
    TimeSeriesDataSet,
    GroupNormalizer
)
max_prediction_length = 6  # forecast 6 steps/months
max_encoder_length = 24  # use 24 steps/months of history
training_cutoff = data["time_idx"].max() - max_prediction_length

data_spec = dict(
    time_idx="time_idx",
    target="volume",
    group_ids=["agency", "sku"],
    target_normalizer=GroupNormalizer(
        groups=["agency", "sku"], coerce_positive=1.0
    ),  # use softplus with beta=1.0 and normalize by group
    static_categoricals=["agency", "sku"],
    static_reals=[
        "avg_population_2017",
        "avg_yearly_household_income_2017"
    ],
    time_varying_known_categoricals=["special_days", "month"],
    # group of categorical variables can be treated as 
    # one variable
    variable_groups={"special_days": special_days},
    time_varying_known_reals=[
        "time_idx",
        "price_regular",
        "discount_in_percent"
    ],
    time_varying_unknown_categoricals=[],
    time_varying_unknown_reals=[
        "volume",
        "log_volume",
        "industry_volume",
        "soda_volume",
        "avg_max_temp",
        "avg_volume_by_agency",
        "avg_volume_by_sku",
    ],
)

preprocess_spec = dict(
    add_relative_time_idx=True,  # add as feature
    add_target_scales=True,  # add as feature
    add_encoder_length=True,  # add as feature
)

prediction_spec = dict(
    min_encoder_length=0,  # allow predictions without history
    max_encoder_length=max_encoder_length,
    min_prediction_length=1,
    max_prediction_length=max_prediction_length,
)

training = TimeSeriesDataSet(
    data[lambda x: x.time_idx <= training_cutoff],
    **data_spec,
    **preprocess_spec,
    **prediction_spec,
)
# create validation set (predict=True) which means to predict the
# last max_prediction_length points in time for each series
validation = TimeSeriesDataSet.from_dataset(
    training, data, predict=True, stop_randomization=True
)
# create dataloaders for model
batch_size = 128
train_dataloader = training.to_dataloader(
    train=True, batch_size=batch_size, num_workers=0
)
val_dataloader = validation.to_dataloader(
    train=False, batch_size=batch_size * 10, num_workers=0
)

  and should_run_async(code)


In [9]:
training.get_parameters()

  and should_run_async(code)


{'time_idx': 'time_idx',
 'target': 'volume',
 'group_ids': ['agency', 'sku'],
 'weight': None,
 'max_encoder_length': 24,
 'min_encoder_length': 24,
 'min_prediction_idx': 0,
 'min_prediction_length': 1,
 'max_prediction_length': 6,
 'static_categoricals': ['agency', 'sku'],
 'static_reals': ['avg_population_2017',
  'avg_yearly_household_income_2017',
  'decoder_length',
  'volume_center',
  'volume_scale'],
 'time_varying_known_categoricals': ['special_days', 'month'],
 'time_varying_known_reals': ['time_idx',
  'price_regular',
  'discount_in_percent',
  'relative_time_idx'],
 'time_varying_unknown_categoricals': [],
 'time_varying_unknown_reals': ['volume',
  'log_volume',
  'industry_volume',
  'soda_volume',
  'avg_max_temp',
  'avg_volume_by_agency',
  'avg_volume_by_sku'],
 'variable_groups': {'special_days': ['easter_day',
   'good_friday',
   'new_year',
   'christmas',
   'labor_day',
   'independence_day',
   'revolution_day_memorial',
   'regional_games',
   'fifa_u_17_wo

In [10]:
training.data

{'reals': tensor([[-0.9593, -0.6123,  0.0000,  ..., -2.9171, -1.0676,  1.0738],
         [-0.9593, -0.6123,  0.0000,  ..., -2.1644, -1.0561,  1.3626],
         [-0.9593, -0.6123,  0.0000,  ..., -0.9712, -1.0254,  1.6461],
         ...,
         [ 1.2221,  1.2074,  0.0000,  ..., -0.3227,  1.0434, -1.4095],
         [ 1.2221,  1.2074,  0.0000,  ...,  0.0315,  1.5186, -1.4097],
         [ 1.2221,  1.2074,  0.0000,  ...,  0.3235,  1.4691, -1.4095]]),
 'categoricals': tensor([[ 0,  0,  0,  ...,  0,  0,  0],
         [ 0,  0,  0,  ...,  0,  0,  4],
         [ 0,  0,  3,  ...,  0,  7,  5],
         ...,
         [57, 17,  3,  ...,  0,  0,  6],
         [57, 17,  0,  ...,  0,  0,  7],
         [57, 17,  0,  ...,  0,  0,  8]]),
 'groups': tensor([[ 0,  0],
         [ 0,  0],
         [ 0,  0],
         ...,
         [57, 17],
         [57, 17],
         [57, 17]]),
 'target': tensor([ 80.6760,  98.0640, 133.7040,  ...,   1.2600,   0.0000,   2.5200]),
 'time': tensor([ 0,  1,  2,  ..., 51, 52, 5

In [11]:
training.data["reals"].shape
training.data["groups"].shape


torch.Size([18900, 2])

In [12]:
import pytorch_lightning as pl
from pytorch_lightning.callbacks import (
    EarlyStopping,
    LearningRateLogger
)
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_forecasting.metrics import QuantileLoss
from pytorch_forecasting.models import TemporalFusionTransformer
# stop training, when loss metric does not improve on validation set
early_stop_callback = EarlyStopping(
    monitor="val_loss",
    min_delta=1e-4,
    patience=10,
    verbose=False,
    mode="min"
)
lr_logger = LearningRateLogger()  # log the learning rate
logger = TensorBoardLogger("../result/lightning_logs")  # log to tensorboard
# create trainer
trainer = pl.Trainer(
    max_epochs=30,
    gpus=0,  # train on CPU, use gpus = [0] to run on GPU
    # gpus=[1],  # for GPU
    gradient_clip_val=0.1,
    early_stop_callback=early_stop_callback,
    limit_train_batches=30,  # running validation every 30 batches
    # fast_dev_run=True,  # comment in to quickly check for bugs
    callbacks=[lr_logger],
    logger=logger,
)
# initialise model
tft = TemporalFusionTransformer.from_dataset(
    training,
    learning_rate=0.03,
    hidden_size=16,  # biggest influence network size
    attention_head_size=1,
    dropout=0.1,
    hidden_continuous_size=8,
    output_size=7,  # QuantileLoss has 7 quantiles by default
    loss=QuantileLoss(),
    log_interval=10,  # log example every 10 batches
    reduce_on_plateau_patience=4,  # reduce learning automatically
)
print(tft.size())   # 29.6k parameters in model
# fit network
trainer.fit(
    tft,
    train_dataloader=train_dataloader,
    val_dataloaders=val_dataloader
)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores

   | Name                               | Type                            | Params
----------------------------------------------------------------------------------------
0  | loss                               | QuantileLoss                    | 0     
1  | input_embeddings                   | ModuleDict                      | 1 K   
2  | prescalers                         | ModuleDict                      | 256   
3  | static_variable_selection          | VariableSelectionNetwork        | 3 K   
4  | encoder_variable_selection         | VariableSelectionNetwork        | 8 K   
5  | decoder_variable_selection         | VariableSelectionNetwork        | 2 K   
6  | static_context_variable_selection  | GatedResidualNetwork            | 1 K   
7  | static_context_initial_hidden_lstm | GatedResidualNetwork            | 1 K   
8  | static_context_initial_cell_lstm   | GatedResidualNetwork            | 1 K   
9  | 

1

In [13]:
import torch

  and should_run_async(code)


In [14]:
from pytorch_forecasting.metrics import MAE
# load the best model according to the validation loss (given that
# we use early stopping, this is not necessarily the last epoch)
best_model_path = trainer.checkpoint_callback.best_model_path
print("best_model_path:", best_model_path)
best_tft = TemporalFusionTransformer.load_from_checkpoint(best_model_path)
# calculate mean absolute error on validation set
actuals = torch.cat([y for x, y in iter(val_dataloader)])
predictions = best_tft.predict(val_dataloader)
MAE()(predictions, actuals)

best_model_path: ../result/lightning_logs/default/version_7/checkpoints/epoch=27.ckpt


tensor(264.8931)