# TiDE
This notebook walks through how to use TiDE and benchmarks it against N-HiTS.

TiDE (Time-series Dense Encoder) is a pure deep learning (DL) encoder-decoder architecture built from dense (MLP) layers. What sets it apart is that its temporal decoder can help mitigate the effect of anomalous samples on a forecast (Fig. 4 in the paper).

See the original paper and model description here: [http://arxiv.org/abs/2304.08424](http://arxiv.org/abs/2304.08424).
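To make the "dense encoder" idea concrete, here is a minimal NumPy sketch of the residual block that TiDE stacks in its encoder and decoder, based on our reading of the paper: a one-hidden-layer MLP with ReLU, dropout on the output projection, a purely linear skip connection, and layer norm at the output. The weights are random placeholders, the layer norm has no learnable scale or shift, and this is not Darts' `TiDEModel` code.

```python
import numpy as np


def layer_norm(x, eps=1e-5):
    # normalize each row to zero mean and unit variance (no learnable scale/shift)
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


def residual_block(x, w_hidden, b_hidden, w_out, b_out, w_skip, b_skip,
                   dropout=0.0, rng=None):
    """Forward pass of a TiDE-style dense residual block (sketch).

    x has shape (batch, in_dim); the result has shape (batch, out_dim).
    Dropout is only applied when both dropout > 0 and an rng are given (training).
    """
    hidden = np.maximum(0.0, x @ w_hidden + b_hidden)  # dense layer + ReLU
    out = hidden @ w_out + b_out                        # linear projection to out_dim
    if dropout > 0.0 and rng is not None:
        # inverted dropout on the projected output
        mask = rng.random(out.shape) >= dropout
        out = out * mask / (1.0 - dropout)
    skip = x @ w_skip + b_skip                          # purely linear skip connection
    return layer_norm(out + skip)


rng = np.random.default_rng(0)
batch, in_dim, hidden_dim, out_dim = 4, 12, 32, 8
x = rng.normal(size=(batch, in_dim))
params = dict(
    w_hidden=rng.normal(size=(in_dim, hidden_dim)) * 0.1,
    b_hidden=np.zeros(hidden_dim),
    w_out=rng.normal(size=(hidden_dim, out_dim)) * 0.1,
    b_out=np.zeros(out_dim),
    w_skip=rng.normal(size=(in_dim, out_dim)) * 0.1,
    b_skip=np.zeros(out_dim),
)
y = residual_block(x, **params)
print(y.shape)  # (4, 8)
```

The skip connection lets the block fall back to a linear map when the nonlinear path adds little, which is part of why TiDE stays competitive with simple linear baselines.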

In [None]:
# fix python path if working locally
from utils import fix_pythonpath_if_working_locally

fix_pythonpath_if_working_locally()
%matplotlib inline

In [None]:
import torch
import numpy as np
import pandas as pd
import shutil

from darts.models import NHiTSModel, TiDEModel
from darts.datasets import AusBeerDataset
from darts.dataprocessing.transformers.scaler import Scaler
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from darts.metrics import mae, mse

import matplotlib.pyplot as plt

np.random.seed(42)
torch.manual_seed(42)

# Model Parameter Setup
Boilerplate code is no fun, especially when training multiple models to compare their performance. To avoid it, common configuration values are defined once and reused by the different models.
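The reuse relies on Python's `**` dict unpacking. Below is a minimal sketch with a hypothetical `make_model` function standing in for a Darts model constructor (it only echoes its arguments). It shows how to add or override model-specific values, and why rebinding a key in a shared dict after a model has been created may not reach that model if the constructor kept its own copy.

```python
def make_model(input_chunk_length, output_chunk_length, batch_size,
               use_reversible_instance_norm=False):
    # hypothetical stand-in for a model constructor that just records its arguments
    return {
        "input_chunk_length": input_chunk_length,
        "output_chunk_length": output_chunk_length,
        "batch_size": batch_size,
        "use_reversible_instance_norm": use_reversible_instance_norm,
    }


shared_args = {"input_chunk_length": 12, "output_chunk_length": 12, "batch_size": 256}

# shared values are unpacked with **; model-specific values are passed next to them
plain = make_model(**shared_args)
with_rin = make_model(**shared_args, use_reversible_instance_norm=True)

# repeating a key that is already in the dict raises TypeError ...
try:
    make_model(**shared_args, batch_size=32)
except TypeError as err:
    print("TypeError:", err)

# ... so override a shared value by merging into a new dict instead
small_batch = make_model(**{**shared_args, "batch_size": 32})

# a function that stores a copy of a dict does not see keys rebound later
trainer_kwargs = {"max_epochs": 200, "callbacks": []}
stored_copy = dict(trainer_kwargs)  # the kind of copy a constructor might keep
trainer_kwargs["callbacks"] = ["early_stopping"]
print(stored_copy["callbacks"])  # []
```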

A few interesting things about these parameters:
1. Gradient clipping

This mitigates exploding gradients during backpropagation by setting an upper limit on the gradient norm for each batch.

2. Learning rate

Most of the learning happens in the earlier epochs. As training goes on, it is often helpful to reduce the learning rate to fine-tune the model. That said, training for too long can also lead to significant overfitting.

3. Early stopping

To avoid overfitting, early stopping is used: the validation loss is monitored, and training is stopped once it fails to improve according to preset conditions (here, by at least `min_delta` within `patience` epochs).
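Gradient clipping and early stopping can each be sketched in a few lines. The NumPy example below shows (a) clipping by global L2 norm, which is what PyTorch Lightning's `gradient_clip_val` does under its default `gradient_clip_algorithm="norm"`, and (b) a patience-based stopping rule in the spirit of the `EarlyStopping` settings used here (`patience`, `min_delta`, `mode="min"`). `clip_by_global_norm` and `EarlyStopper` are hypothetical helpers for illustration, not Lightning's implementation.

```python
import numpy as np


def clip_by_global_norm(grads, max_norm):
    """Rescale a list of gradient arrays so their combined L2 norm is at most max_norm."""
    total_norm = np.sqrt(sum(float(np.sum(g ** 2)) for g in grads))
    if total_norm <= max_norm:
        return grads, total_norm
    scale = max_norm / total_norm  # same factor for every array keeps the direction
    return [g * scale for g in grads], total_norm


class EarlyStopper:
    """Stop once the monitored value has not improved by more than min_delta
    for `patience` consecutive checks (mode="min": lower is better)."""

    def __init__(self, patience=10, min_delta=1e-3):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.wait = 0

    def should_stop(self, value):
        if value < self.best - self.min_delta:
            self.best = value  # meaningful improvement: reset the counter
            self.wait = 0
        else:
            self.wait += 1
        return self.wait >= self.patience


# gradients with global norm sqrt(3^2 + 4^2 + 12^2) = 13, clipped to norm 1
clipped, norm = clip_by_global_norm([np.array([3.0, 4.0]), np.array([12.0])], max_norm=1.0)
print(norm)  # 13.0

stopper = EarlyStopper(patience=3, min_delta=1e-3)
val_losses = [1.0, 0.8, 0.7995, 0.7992, 0.7991, 0.5]
stopped_epoch = None
for epoch, loss in enumerate(val_losses):
    if stopper.should_stop(loss):
        stopped_epoch = epoch
        break
print(stopped_epoch)  # 4: three epochs without an improvement larger than 1e-3
```

Note that the later drop to 0.5 is never seen: early stopping trades the chance of a late improvement for shorter training.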

In [None]:
optimizer_kwargs = {
    "lr": 1e-3,
}

pl_trainer_kwargs = {
    "gradient_clip_val": 1,
    "max_epochs": 200,
    "accelerator": "auto",
    "callbacks": [],
}

lr_scheduler_cls = torch.optim.lr_scheduler.ExponentialLR
lr_scheduler_kwargs = {
    "gamma": 0.999,
}

early_stopping_args = {
    "monitor": "val_loss",
    "patience": 10,
    "min_delta": 1e-3,
    "mode": "min",
}

common_model_args = {
    "input_chunk_length": 12,
    "output_chunk_length": 12,
    "optimizer_kwargs": optimizer_kwargs,
    "pl_trainer_kwargs": pl_trainer_kwargs,
    "lr_scheduler_cls": lr_scheduler_cls,
    "lr_scheduler_kwargs": lr_scheduler_kwargs,
    "likelihood": None,
    "save_checkpoints": True,
    "batch_size": 256,
}
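As a worked example with the values above, and assuming the scheduler is stepped once per epoch (PyTorch Lightning's default interval), `ExponentialLR` with `gamma=0.999` decays the learning rate only gently: after the full 200 epochs it is still about 82% of its starting value, and halving it would take roughly 693 epochs. Early stopping usually ends training sooner, so the actual decay is smaller still.

```python
import math

lr0 = 1e-3        # optimizer_kwargs["lr"]
gamma = 0.999     # lr_scheduler_kwargs["gamma"]
max_epochs = 200  # pl_trainer_kwargs["max_epochs"]


# ExponentialLR multiplies the learning rate by gamma at each scheduler step:
#   lr_t = lr0 * gamma ** t
def lr_at(step):
    return lr0 * gamma ** step


final_lr = lr_at(max_epochs)
epochs_to_halve = math.log(0.5) / math.log(gamma)

print(f"lr after {max_epochs} steps: {final_lr:.2e}")         # about 8.19e-04
print(f"steps needed to halve the lr: {epochs_to_halve:.0f}")  # about 693
```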

# Data Loading and preparation
We consider the Australian quarterly beer sales in megaliters. 

Before training, we split the data into train, validation, and test sets. The model will learn from the train set, use the validation set to determine when to stop training, and finally be evaluated on the test set.

To avoid leaking information from the validation and test sets, we fit the scaler on the train set only and then apply the same transformation to the validation and test sets.

In [None]:
series = AusBeerDataset().load()

train, temp = series.split_after(0.6)
val, test = temp.split_after(0.5)
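The two calls produce roughly a 60/20/20 split: the first keeps 60% of the series for training, and the second halves the remaining 40% into validation and test sets. A plain-Python sketch of the arithmetic, using a hypothetical `split_after_fraction` helper on a placeholder series of 200 points (not the real dataset); Darts' `split_after` may place the split point slightly differently due to rounding.

```python
def split_after_fraction(values, fraction):
    # hypothetical helper: the first part keeps the first int(len * fraction) points
    cut = int(len(values) * fraction)
    return values[:cut], values[cut:]


values = list(range(200))  # placeholder series of 200 points
train_part, rest = split_after_fraction(values, 0.6)
val_part, test_part = split_after_fraction(rest, 0.5)
print(len(train_part), len(val_part), len(test_part))  # 120 40 40
```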

In [None]:
train.plot(label="train")
val.plot(label="val")
test.plot(label="test")

In [None]:
scaler = Scaler()  # min-max scaling by default; Scaler(StandardScaler()) from sklearn.preprocessing would standardize instead
train = scaler.fit_transform(train)
val = scaler.transform(val)
test = scaler.transform(test)
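Fitting the scaler on the training set only means the validation and test sets are transformed with statistics the model could have known at training time. `Scaler()` wraps scikit-learn's `MinMaxScaler` by default; the sketch below re-implements min-max scaling in NumPy (a hypothetical `MinMax` class, single component only, made-up values) to show the consequence: values outside the training range map outside [0, 1], which is expected rather than a bug.

```python
import numpy as np


class MinMax:
    """Minimal min-max scaler: learns min/max only from the data passed to fit."""

    def fit(self, x):
        self.min_ = float(np.min(x))
        self.max_ = float(np.max(x))
        return self

    def transform(self, x):
        return (np.asarray(x, dtype=float) - self.min_) / (self.max_ - self.min_)

    def inverse_transform(self, z):
        return np.asarray(z, dtype=float) * (self.max_ - self.min_) + self.min_


toy_train = np.array([400.0, 450.0, 500.0])
toy_test = np.array([380.0, 520.0])

toy_scaler = MinMax().fit(toy_train)   # statistics come from the training data only
print(toy_scaler.transform(toy_train))  # 0.0, 0.5, 1.0
print(toy_scaler.transform(toy_test))   # -0.2, 1.2: outside [0, 1] is fine
```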

# Model configuration
Apart from the shared arguments above, NHiTS and TiDE use their default parameters. The only exception is that TiDE is tested both with and without Reversible Instance Normalization (RIN).
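Reversible instance normalization normalizes each input window by its own mean and standard deviation, lets the model forecast in that normalized space, and then applies the inverse transform to the output, which helps when the level or scale of the series shifts over time. Below is a minimal NumPy sketch for a single univariate window; it omits the learnable affine parameters of the original RevIN method, the `eps` value is an assumption, and the seasonal-repeat "model" is a hypothetical placeholder. Darts' implementation may differ in details.

```python
import numpy as np


def rin_normalize(window, eps=1e-5):
    """Normalize one input window by its own mean/std and return the stats."""
    mean = window.mean()
    std = np.sqrt(window.var() + eps)
    return (window - mean) / std, (mean, std)


def rin_denormalize(values, stats):
    """Invert rin_normalize using the stats of the window the forecast came from."""
    mean, std = stats
    return values * std + mean


rng = np.random.default_rng(0)
window = 500.0 + 20.0 * rng.standard_normal(12)  # made-up input window of 12 points

z, stats = rin_normalize(window)               # model input: zero mean, ~unit variance
z_forecast = np.tile(z[-4:], 3)                # placeholder model: repeat the last 4 steps
forecast = rin_denormalize(z_forecast, stats)  # back on the original scale
```

Because the transform is affine and inverted exactly, the placeholder forecast ends up equal to repeating the last 4 raw values; a real model benefits because it only ever sees inputs on a comparable scale.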

We then iterate through the model dictionary and train each model. When using early stopping, it is important to save checkpoints: training continues past the best epoch until the stopping criterion is met, and the best weights are restored from the checkpoint once training has completed.
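The idea behind restoring the best checkpoint can be sketched without any framework: keep a copy of the weights whenever the validation loss reaches a new minimum, and return that copy at the end. In this toy example the "model" is a single hypothetical weight `w`, training keeps pushing it past the optimum, and the validation loss is smallest at `w = 3`. Darts achieves the same effect with checkpoint files on disk rather than in-memory copies.

```python
import copy


def train_with_best_restore(model_state, train_step, val_loss_fn, epochs):
    """Train for `epochs`, remember the state with the lowest validation loss,
    and return that state instead of the last one."""
    best_loss = float("inf")
    best_state = copy.deepcopy(model_state)
    for epoch in range(epochs):
        model_state = train_step(model_state, epoch)
        loss = val_loss_fn(model_state)
        if loss < best_loss:
            best_loss = loss
            best_state = copy.deepcopy(model_state)  # "checkpoint" the best weights
    return best_state, best_loss


def toy_step(state, epoch):
    # hypothetical training step that overshoots the optimum
    return {"w": state["w"] + 1.0}


def toy_val_loss(state):
    # validation loss minimized at w = 3
    return (state["w"] - 3.0) ** 2


best_state, best_loss = train_with_best_restore({"w": 0.0}, toy_step, toy_val_loss, epochs=6)
print(best_state, best_loss)  # {'w': 3.0} 0.0
```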

In [None]:
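def fresh_model_args():
    # each model gets its own EarlyStopping callback; it is passed in at
    # construction time because the model keeps its own copy of the trainer
    # kwargs, so rebinding pl_trainer_kwargs["callbacks"] later has no effect
    trainer_kwargs = {
        **pl_trainer_kwargs,
        "callbacks": [EarlyStopping(**early_stopping_args)],
    }
    return {**common_model_args, "pl_trainer_kwargs": trainer_kwargs}


model_nhits = NHiTSModel(
    **fresh_model_args(),
)

model_tide = TiDEModel(
    **fresh_model_args(),
    use_reversible_instance_norm=False,
)

model_tide_rin = TiDEModel(
    **fresh_model_args(),
    use_reversible_instance_norm=True,
)

model_list = {
    "NHiTS": model_nhits,
    "TiDE": model_tide,
    "TiDE+RIN": model_tide_rin,
}

In [None]:
for name, model in model_list.items():
    model.fit(
        series=train,
        val_series=val,
        verbose=False,
    )

    # load_from_checkpoint returns a new model instead of updating `model`
    # in place, so store the restored best model back in the dict
    model_list[name] = model.load_from_checkpoint(
        model_name=model.model_name, best=True
    )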

In [None]:
scaled_series = scaler.transform(series)
pred_steps = common_model_args["output_chunk_length"] * 2
pred_input = test[:-pred_steps]
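Since `pred_steps` is twice `output_chunk_length`, each model has to forecast two 12-step chunks. Darts handles `n > output_chunk_length` auto-regressively, feeding the model's own predictions back in as input for the next chunk. The plain-Python sketch below shows that loop with a hypothetical seasonal-naive "model" (it repeats the last 4 quarters) and made-up history values.

```python
def rollout(history, predict_chunk, input_len, n):
    """Forecast n steps by repeatedly predicting one chunk and appending it
    to the context that feeds the next chunk."""
    context = list(history)
    forecast = []
    while len(forecast) < n:
        chunk = predict_chunk(context[-input_len:])
        forecast.extend(chunk)
        context.extend(chunk)  # predictions become inputs for the next step
    return forecast[:n]


def seasonal_naive_chunk(window, season=4, out_len=12):
    # hypothetical toy model: repeat the last observed season (quarterly data -> 4)
    last_season = window[-season:]
    return [last_season[i % season] for i in range(out_len)]


toy_history = [float(v) for v in range(1, 17)]  # 16 made-up observations
toy_forecast = rollout(toy_history, seasonal_naive_chunk, input_len=12, n=24)
print(len(toy_forecast), toy_forecast[:4])  # 24 [13.0, 14.0, 15.0, 16.0]
```

Errors in the first chunk propagate into the second, which is one reason forecasts usually degrade further out in the horizon.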

In [None]:
fig, ax = plt.subplots(figsize=(15, 5))

pred_input.plot(label="input", ax=ax)
test[-pred_steps:].plot(label="ground truth", ax=ax)

result_accumulator = {}

for model_name, model in model_list.items():
    pred_series = model.predict(n=pred_steps, series=pred_input)
    pred_series.plot(label=model_name, ax=ax)

    result_accumulator[model_name] = {
        "mae": mae(scaled_series[-pred_steps:], pred_series),
        "mse": mse(scaled_series[-pred_steps:], pred_series),
    }
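For reference, MAE averages the absolute errors, while MSE averages the squared errors and therefore penalizes a few large misses more heavily. The scores here are computed on the scaled series, so they are in scaled units rather than megaliters. A NumPy sketch with made-up numbers (`mae_np` and `mse_np` are illustrative helpers, not the `darts.metrics` functions):

```python
import numpy as np


def mae_np(actual, pred):
    # mean absolute error: average magnitude of the errors
    return float(np.mean(np.abs(np.asarray(actual) - np.asarray(pred))))


def mse_np(actual, pred):
    # mean squared error: squaring makes large misses weigh more
    return float(np.mean((np.asarray(actual) - np.asarray(pred)) ** 2))


actual = [0.5, 0.6, 0.7, 0.8]
pred = [0.5, 0.6, 0.7, 1.2]  # one large miss of 0.4
print(mae_np(actual, pred), mse_np(actual, pred))  # approximately 0.1 and 0.04
```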

# Results
In this case, TiDE is shown to outperform NHiTS. Inclusion of reversible instance normalization (RIN) helps reduce the forecasting error as well; however, it is not always going to improve performance.

In [None]:
results_df = pd.DataFrame.from_dict(result_accumulator, orient="index")
results_df.plot.bar()

# Cleanup
A concern when using model checkpointing is that it can use a significant amount of disk space when training a large number of models. Be sure to clean up when you no longer need your model artifacts!

In [None]:
try:
    shutil.rmtree("darts_logs")
except FileNotFoundError:
    pass