# Temporal Fusion Transformer (TFT) - Keras implementation

Paper: https://arxiv.org/pdf/1912.09363.pdf

![](https://github.com/rohanmohapatra/tft-transformer-keras/blob/master/images/optimal_params.png?raw=True)

---
## [![GitHub](https://img.shields.io/badge/GitHub-000?style=flat&logo=github&link=https://github.com/tkostas/tft-transformer-keras)](https://github.com/tkostas/tft-transformer-keras) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/rohanmohapatra/tft-transformer-keras/blob/master/TFT_Transformer_Keras.ipynb)



## Download the Project to use Dependencies

In [None]:
# Fetch the project sources and switch into the checkout; the import cell
# further down loads modules from its `src` package.
! git clone https://github.com/rohanmohapatra/tft-transformer-keras.git
%cd tft-transformer-keras

## Install Requirements

In [None]:
# Install the packages listed in requirements.txt (run from inside the checkout).
!pip install -r requirements.txt

## Import packages

In [None]:
from src.config import Config
from src.datasets.managers import datasets_factory
from src.eval import quick_evaluation
from src.plots import plot_history, plot_examples
from src.train_utils import build_model, save_weights_and_inference_model

## Auxiliary Class Definitions

In [None]:
class AttrDict(dict):
    """Dictionary whose entries can also be read and written as attributes."""

    def __init__(self, *args, **kwargs):
        """Accept the same positional and keyword arguments as ``dict``."""
        super().__init__(*args, **kwargs)
        # Point the instance namespace at the dict itself so that
        # ``obj.key`` and ``obj['key']`` share the same storage.
        self.__dict__ = self

## Define Model Optimal Parameters for
- Electricity
- Traffic
- Favorita

#### Change the dataset to use the required optimal config


|                       |Electricity|Traffic|Retail|Vol.  |
|-----------------------|-----------|-------|------|------|
|**Dataset Details**    |           |       |      |      |
|Target Type            |R          |[0, 1] |R     |R     |
|Number of Entities     |370        |440    |130k  |41    |
|Number of Samples      |500k       |500k   |500k  |~100k |
|**Network Parameters** |           |       |      |      |
|k                      |168        |168    |90    |252   |
|τmax                   |24         |24     |30    |5     |
|Dropout Rate           |0.1        |0.3    |0.1   |0.3   |
|State Size             |160        |320    |240   |160   |
|Number of Heads        |4          |4      |4     |1     |
|**Training Parameters**|           |       |      |      |
|Minibatch Size         |64         |128    |128   |64    |
|Learning Rate          |0.001      |0.001  |0.001 |0.01  |
|Max Gradient Norm      |0.01       |100    |100   |0.01  |

In [None]:
# Run-level settings shared by every dataset.
# 'dataset' selects which entry of `optimal_params` is merged in later, so it
# must be one of that dict's keys ('electricity', 'traffic', 'favorita').
config = dict(
    dataset='electricity',
    model_version=1,
    log_dir='logs',
    optimizer='adamw',
    load_model_weights=None,
    sample_sz=0,
    l2_reg=None,
    clip_value=None,
    masked_value=None,
    n_samples_to_plot=50,
    plot_attn_weights=False,
)

In [None]:
# Per-dataset hyperparameters, keyed by the value of config['dataset'].
# The values mirror the hyperparameter table in this notebook:
#   n_enc_steps  -> k                  n_dec_steps -> τmax
#   d_model      -> State Size         batch_sz    -> Minibatch Size
#   lr           -> Learning Rate      clip_norm   -> Max Gradient Norm
#   dropout_rate -> Dropout Rate
# 'epochs' has no counterpart in that table.
# The total timeseries length is n_enc_steps + n_dec_steps.
optimal_params = {
    'electricity': {
        'n_enc_steps': 168,
        'n_dec_steps': 24,
        'd_model': 160,
        'epochs': 50,
        'batch_sz': 64,
        'lr': 0.001,
        'clip_norm': 0.01,
        'dropout_rate': 0.1,
    },
    'traffic': {
        'n_enc_steps': 168,
        'n_dec_steps': 24,
        'd_model': 320,
        'epochs': 100,
        'batch_sz': 128,
        'lr': 0.001,
        'clip_norm': 100,
        'dropout_rate': 0.3,
    },
    # 'favorita' corresponds to the "Retail" column of the table.
    'favorita': {
        'n_enc_steps': 90,
        'n_dec_steps': 30,
        'd_model': 240,
        'epochs': 50,
        'batch_sz': 128,
        'lr': 0.001,
        # Max Gradient Norm for Retail is 100 in the table (was 0.01 here).
        'clip_norm': 100,
        'dropout_rate': 0.1,
    },
}

In [None]:
# Overlay the hyperparameters of the selected dataset onto the run settings.
dataset_params = optimal_params[config['dataset']]
config.update(dataset_params)

## Setup

In [None]:
# Wrap the plain settings dict for attribute-style access, then pass it to the
# project's Config; `config` is rebound to the resulting object.
attr_config = AttrDict(config)
config = Config(attr_config)

In [None]:
# Names reused by the plotting calls in the training cell.
dataset_name = config.dataset
model_version = config.model_version

# Look up the dataset manager registered under this name and instantiate it.
make_dataset = datasets_factory[dataset_name]
dset = make_dataset(
    ts_len=config.ts_len,
    n_enc_steps=config.n_enc_steps,
    sample_sz=config.sample_sz,
)

## Train

In [None]:
# Extract (inputs, targets) pairs for the train / validation / test splits.
splits = dset.extract_xy_pairs(ts_len=config.ts_len,
                               n_dec_steps=config.n_dec_steps)
(x_train, y_train), (x_val, y_val), (x_test, y_test) = splits

callbacks, model = build_model(config, dset, x_train)

# Fit on the training split, validating on the validation split each epoch.
fit_options = dict(
    batch_size=config.batch_sz,
    epochs=config.epochs,
    validation_data=(x_val, y_val),
    shuffle=True,
    callbacks=callbacks,
)
history = model.fit(x_train, y_train, **fit_options)

plot_history(history, dataset_name, model_version)

save_weights_and_inference_model(model, config=config)

# NOTE(review): the splits are passed positionally as x_test, x_train, x_val,
# y_test, y_train, y_val; confirm this matches quick_evaluation's signature.
quick_evaluation(model,
                 x_test, x_train, x_val,
                 y_test, y_train, y_val,
                 config)

# Plot example predictions for each split, in train / val / test order.
for split_tag, split_x, split_y in (('train', x_train, y_train),
                                    ('val', x_val, y_val),
                                    ('test', x_test, y_test)):
    plot_examples(model, split_x, split_y,
                  quantiles=config.quantiles,
                  dataset_name=dataset_name,
                  tag=split_tag,
                  plot_n_samples=config.n_samples_to_plot,
                  plot_attn_weights=config.plot_attn_weights)