In [4]:
import os

# Step one directory up when the current working directory has no entry named
# "temporal_fusion_transformer_pytorch".
# NOTE(review): presumably this moves the kernel from a notebook sub-folder to
# the project root so the local `data` module can be imported — confirm.
project_marker = "temporal_fusion_transformer_pytorch"
cwd_entries = os.listdir()
if project_marker not in cwd_entries:
    os.chdir("..")

In [None]:
import pickle
import warnings


import pytorch_lightning as pl
import torch
from pytorch_lightning.callbacks import EarlyStopping, LearningRateLogger
from pytorch_lightning.loggers import TensorBoardLogger

from pytorch_forecasting import TimeSeriesDataSet, TemporalFusionTransformer
from pathlib import Path
import pandas as pd
import numpy as np

from pytorch_forecasting.metrics import PoissonLoss, QuantileLoss, SMAPE
from pytorch_forecasting.models.temporal_fusion_transformer.tuning import optimize_hyperparameters
from pytorch_forecasting.utils import profile

In [None]:
from data import get_stallion_data

# Load the stallion data through the project-local helper. The frame returned
# here is extended in place with the derived columns below.
data = get_stallion_data()

# Calendar month of each row and the log-transformed volume; the small constant
# keeps the logarithm finite for a volume of exactly zero.
data["month"] = data["date"].dt.month
data["log_volume"] = np.log(data["volume"] + 1e-8)

# Monthly time index: year * 12 + month, shifted so the earliest month is 0.
absolute_month = data["date"].dt.year * 12 + data["date"].dt.month
data["time_idx"] = absolute_month - absolute_month.min()

# Forecast horizon, encoder window (twice the horizon), and the last time index
# that lies before the final `max_prediction_length` steps of the data.
max_prediction_length = 6
max_encoder_length = max_prediction_length * 2
training_cutoff = data["time_idx"].max() - max_prediction_length

In [None]:
# Training dataset: all rows up to AND including the cutoff. The validation
# dataset below is built with `predict=True`, which (per the pytorch-forecasting
# docs) predicts only the last `max_prediction_length` steps of each series,
# i.e. time_idx > training_cutoff. Using `<=` therefore adds the month at the
# cutoff to training without overlapping the validation targets; the former
# strict `<` silently discarded that month.
training = TimeSeriesDataSet(
    data[lambda x: x.time_idx <= training_cutoff],
    time_idx="time_idx",
    target="volume",
    # one time series per (agency, sku) combination
    group_ids=["agency", "sku"],
    # min == max: the encoder window has a fixed length
    min_encoder_length=max_encoder_length,
    max_encoder_length=max_encoder_length,
    max_prediction_length=max_prediction_length,
    static_categoricals=["agency", "sku"],
    static_reals=["avg_population_2017", "avg_yearly_household_income_2017"],
    # holiday / event indicator columns, declared as known in advance
    time_varying_known_categoricals=[
        "easter_day",
        "good_friday",
        "new_year",
        "christmas",
        "labor_day",
        "independence_day",
        "revolution_day_memorial",
        "regional_games",
        "fifa_u_17_world_cup",
        "football_gold_cup",
        "beer_capital",
        "music_fest",
    ],
    time_varying_known_reals=["time_idx", "price_regular", "discount_in_percent"],
    time_varying_unknown_categoricals=[],
    # the target itself plus covariates that are only observed up to the present
    time_varying_unknown_reals=["volume", "log_volume", "industry_volume", "soda_volume", "avg_max_temp"],
)


# Validation dataset: reuses the configuration of `training` but is built from
# the full frame, so the held-out steps after the cutoff are available.
validation = TimeSeriesDataSet.from_dataset(training, data, predict=True, stop_randomization=True)

# Dataloaders; validation uses a 10x larger batch than training.
batch_size = 128
train_dataloader = training.to_dataloader(train=True, batch_size=batch_size, num_workers=0)
val_dataloader = validation.to_dataloader(train=False, batch_size=batch_size * 10, num_workers=0)

In [None]:
# --- callbacks and logging ---
# Early stopping watches "val_loss" (lower is better) with a minimum
# improvement of 1e-4 and a patience of 15.
early_stop_callback = EarlyStopping(
    monitor="val_loss",
    mode="min",
    min_delta=1e-4,
    patience=15,
    verbose=False,
)
lr_logger = LearningRateLogger()
# Not passed to the trainer below unless the commented `logger` toggle is enabled.
logger = TensorBoardLogger("lightning_logs")

# --- trainer ---
trainer_options = dict(
    max_epochs=100,
    gpus=0,
    weights_summary="top",
    gradient_clip_val=0.1,
    early_stop_callback=early_stop_callback,
    # cap the number of training batches per epoch
    limit_train_batches=20,
    callbacks=[lr_logger],
    # optional debugging toggles, disabled by default:
    # limit_val_batches=1,
    # fast_dev_run=True,
    # logger=logger,
    # profiler=True,
)
trainer = pl.Trainer(**trainer_options)


# --- network ---
# Model hyperparameters are set explicitly; everything else is derived from
# the `training` dataset by `from_dataset`.
tft = TemporalFusionTransformer.from_dataset(
    training,
    # optimisation
    learning_rate=0.05,
    loss=SMAPE(log_space=True),
    reduce_on_plateau_patience=4,
    log_interval=-1,
    # architecture
    hidden_size=16,
    hidden_continuous_size=8,
    attention_head_size=1,
    dropout=0.1,
    output_size=1,
)
print(f"Number of parameters in network: {tft.size()/1e3:.1f}k")

In [None]:
# Run the trainer's learning-rate finder on the model, with learning rates
# capped at 0.3 and an early-stop threshold of 1000.
lr_finder_options = dict(
    train_dataloader=train_dataloader,
    val_dataloaders=val_dataloader,
    max_lr=0.3,
    early_stop_threshold=1000.0,
)
res = trainer.lr_find(tft, **lr_finder_options)

# Report the suggested rate, then plot the result with the suggestion marked.
print(f"suggested learning rate: {res.suggestion()}")
fig = res.plot(suggest=True, show=True)
fig.show()

In [None]:
# Fit the network on the training dataloader, validating on the validation
# dataloader.
trainer.fit(
    tft, train_dataloader=train_dataloader, val_dataloaders=val_dataloader,
)

# Make a prediction on the entire validation set and also return the index
# describing which series / time step each prediction belongs to.
# `fast_dev_run=True` is deliberately NOT passed: per the pytorch-forecasting
# docs that flag restricts `predict` to the first batch only, which contradicts
# predicting on the whole validation set.
preds, index = tft.predict(val_dataloader, return_index=True)

In [None]:
# Tune hyperparameters: 15 trials of at most 30 epochs each.
search_ranges = dict(
    gradient_clip_val_range=(0.01, 1.0),
    hidden_size_range=(16, 64),
    hidden_continuous_size_range=(8, 64),
    attention_head_size_range=(1, 4),
    dropout_range=(0.1, 0.3),
    # identical lower and upper bound: the range contains only 0.03
    learning_rate_range=(0.03, 0.03),
)
study = optimize_hyperparameters(
    train_dataloader,
    val_dataloader,
    model_path="optuna_test",
    n_trials=15,
    max_epochs=30,
    trainer_kwargs={"limit_train_batches": 20, "logger": logger},
    reduce_on_plateau_patience=4,
    **search_ranges,
)

# Persist the returned study object so it can be inspected later.
with open("test_study.pickle", "wb") as study_file:
    pickle.dump(study, study_file)

In [None]:
# Profile a `trainer.fit` run and write the result to "profile.prof".
# NOTE(review): the keyword arguments in `fit_arguments` are presumably
# forwarded by `profile` to `trainer.fit` — confirm against
# pytorch_forecasting.utils.profile.
fit_arguments = dict(
    model=tft,
    train_dataloader=train_dataloader,
    val_dataloaders=val_dataloader,
)
profile(
    trainer.fit,
    profile_fname="profile.prof",
    filter="pytorch_forecasting",
    period=0.001,
    **fit_arguments,
)