In [None]:
import pandas as pd
import lightning.pytorch as pl
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger
from pytorch_forecasting import TimeSeriesDataSet
from pytorch_forecasting.data.encoders import EncoderNormalizer
from pytorch_forecasting.metrics.quantile import QuantileLoss
from pytorch_forecasting.models.temporal_fusion_transformer import TemporalFusionTransformer

In [None]:
# Load data and build the training / validation datasets
#
# NOTE(review): the constants below look like they assume 6 time steps per day
# (7*6 = one-week horizon, 15*6 = 15-day encoder window) — confirm against the CSV.

# Seed before anything stochastic (dataloader sampling, model init in later cells) is set up.
pl.seed_everything(42)

data = pd.read_csv('dataset/preliminaryData/bss_activity_meteorological_popular-hours.csv')

max_prediction_length = 7*6   # forecast horizon, in time steps
max_encoder_length = 15*6     # maximum look-back window, in time steps
num_workers = 32
holdout_length = 30*4*6       # most recent time steps excluded from this experiment entirely

# Drop the most recent `holdout_length` steps. `.copy()` gives an independent frame so the
# column assignments below do not operate on a slice of the raw CSV frame.
data = data[data["time_idx"] < data["time_idx"].max() - holdout_length].copy()

# Categorical columns are passed to TimeSeriesDataSet as strings.
categorical_columns = ["station", "month", "weekday", "is_weekend", "time_of_day"]
data[categorical_columns] = data[categorical_columns].astype(str)

# The cutoff MUST be derived from the filtered frame: computed on the unfiltered frame it
# lands past the end of the remaining data, and the training set then swallows the whole
# validation prediction window (data leakage).
training_cutoff = data["time_idx"].max() - max_prediction_length

training = TimeSeriesDataSet(
    # Training ends exactly where the validation prediction window (last
    # max_prediction_length steps) begins.
    data[lambda x: x.time_idx <= training_cutoff],
    group_ids=["station"],
    target="activity",
    time_idx="time_idx",
    min_encoder_length=max_encoder_length // 2,
    max_encoder_length=max_encoder_length,
    min_prediction_length=1,
    max_prediction_length=max_prediction_length,
    # NOTE(review): if avg_activity_by_station was computed over the full CSV (including the
    # validation period) it leaks future information — confirm how it was built.
    time_varying_unknown_reals=["activity", 'temp', 'humidity', 'avg_activity_by_station', 'log_activity'],
    static_categoricals=["station"],
    time_varying_known_categoricals=["weekday", "is_weekend", "time_of_day", "month"],
    time_varying_known_reals=["is_public_hours"],
    target_normalizer=EncoderNormalizer(transformation="softplus"),
    # Lags in time steps; presumably 1 day, 1 week and 1 year at 6 steps/day — confirm.
    lags={"activity": [6, 6*7, 6*365]},
    add_relative_time_idx=True,
    add_encoder_length=True,
    add_target_scales=True,
)

# predict=True: one validation sample per series, forecasting its last max_prediction_length steps.
validation = TimeSeriesDataSet.from_dataset(training, data, predict=True, stop_randomization=True)
batch_size = 128
train_dataloader = training.to_dataloader(train=True, batch_size=batch_size, num_workers=num_workers)
val_dataloader = validation.to_dataloader(train=False, batch_size=batch_size * 10, num_workers=num_workers)

In [None]:
# Model tuning
#
# Hyperparameter search for the TFT: 50 trials of up to 50 epochs, each epoch limited to
# 30 training batches on a single device. The study object is pickled afterwards so the
# results can be inspected (or tuning resumed) later.

import pickle
from pytorch_forecasting.models.temporal_fusion_transformer.tuning import optimize_hyperparameters

# Ranges sampled by the search.
# NOTE(review): with use_learning_rate_finder=True the learning rate is presumably picked by
# the LR finder rather than sampled from learning_rate_range — confirm in the library docs.
search_space = dict(
    gradient_clip_val_range=(0.01, 1.0),
    hidden_size_range=(8, 256),
    hidden_continuous_size_range=(8, 186),
    attention_head_size_range=(1, 4),
    learning_rate_range=(0.0001, 0.1),
    dropout_range=(0.1, 0.5),
)

study = optimize_hyperparameters(
    train_dataloader,
    val_dataloader,
    model_path="study_tft",
    log_dir="study_tft",
    n_trials=50,
    max_epochs=50,
    reduce_on_plateau_patience=4,
    use_learning_rate_finder=True,
    trainer_kwargs=dict(limit_train_batches=30, devices=1),
    **search_space,
)

# Persist the study to disk.
with open("study_tft.pkl", "wb") as study_file:
    pickle.dump(study, study_file)

# Report the best hyperparameters found.
print(study.best_trial.params)

In [None]:
# Model run
#
# Train the TFT with fixed hyperparameters (presumably copied from an earlier tuning run —
# study.best_trial.params is not read here).

early_stop_callback = EarlyStopping(monitor="val_loss", min_delta=1e-4, patience=5, verbose=True, mode="min")
# Explicitly track the lowest val_loss. Without a monitored ModelCheckpoint, Lightning's
# default checkpoint only keeps the latest epoch, so `trainer.checkpoint_callback.best_model_path`
# (used in the evaluation cell) would point at the model saved after early-stopping patience
# ran out instead of the best one.
checkpoint_callback = ModelCheckpoint(monitor="val_loss", mode="min", save_top_k=1)
logger = TensorBoardLogger("lightning_logs")

trainer = pl.Trainer(
    max_epochs=45,
    accelerator='auto',
    enable_model_summary=True,
    gradient_clip_val=0.023,
    callbacks=[early_stop_callback, checkpoint_callback],
    logger=logger)


tft = TemporalFusionTransformer.from_dataset(
    training,
    learning_rate=0.001,
    hidden_size=51,
    attention_head_size=4,
    dropout=0.11493,
    lstm_layers=2,
    hidden_continuous_size=31,
    loss=QuantileLoss(),
    log_interval=10,  # log every 10 batches
    optimizer="Ranger",
    reduce_on_plateau_patience=4,
)

trainer.fit(
    tft,
    train_dataloaders=train_dataloader,
    val_dataloaders=val_dataloader)

In [None]:
# Model evaluation and interpretability
#
# Reload the best checkpoint, then save under performance/tft/:
#   * example forecasts and the worst forecasts by MAE,
#   * variable-importance / attention plots,
#   * predicted-vs-actual curves for every non-lagged feature.

from pathlib import Path
from pytorch_forecasting.metrics.point import MAE

output_dir = Path("performance/tft")
# savefig does not create missing directories.
output_dir.mkdir(parents=True, exist_ok=True)
n_examples = 10  # number of example / worst-case forecasts to plot

# Lags configured on the training dataset ("activity" is the only lagged variable there).
lags = list(training.lags.values())[0]

best_model_path = trainer.checkpoint_callback.best_model_path
best_tft = TemporalFusionTransformer.load_from_checkpoint(best_model_path)

# Raw network output is what plot_prediction / interpret_output consume.
raw_predictions = best_tft.predict(val_dataloader, mode="raw", return_x=True)
# Point predictions plus inputs (x) and targets (y) in a single pass over the validation set.
predictions = best_tft.predict(val_dataloader, return_x=True, return_y=True)

# Per-sample MAE averaged over the prediction horizon, worst samples first.
mean_losses = MAE(reduction="none")(predictions.output, predictions.y).mean(1)
indices = mean_losses.argsort(descending=True)
# Do not index past the number of validation samples.
n_plotted = min(n_examples, len(indices))

for idx in range(n_plotted):
    fig = best_tft.plot_prediction(raw_predictions.x, raw_predictions.output, idx=idx, add_loss_to_title=True)
    fig.savefig(output_dir / f"performance_{idx}.jpg")

for rank in range(n_plotted):
    fig = best_tft.plot_prediction(
        raw_predictions.x,
        raw_predictions.output,
        idx=indices[rank],
        add_loss_to_title=MAE(quantiles=best_tft.loss.quantiles),
    )
    fig.savefig(output_dir / f"worst_performance_{rank}.jpg")

interpretation = best_tft.interpret_output(raw_predictions.output, reduction="sum")
for key, value in best_tft.plot_interpretation(interpretation).items():
    value.savefig(output_dir / f"importance_{key}.jpg")

predictions_vs_actuals = best_tft.calculate_prediction_actual_by_variable(predictions.x, predictions.output)
lagged_features = {f"activity_lagged_by_{lag}" for lag in lags}
# Sorted: iterating a set of strings is not ordered reproducibly across interpreter runs,
# so without sorting the plotted order would change from run to run.
all_features = sorted(set(predictions_vs_actuals['support'].keys()) - lagged_features)

figs = [best_tft.plot_prediction_actual_by_variable(predictions_vs_actuals, name=feature) for feature in all_features]
# Name each file after its feature so the plot can be identified from the file name.
for feature, fig in zip(all_features, figs):
    fig.savefig(output_dir / f"_p_vs_a_{feature}.jpg")