# TabNet and Interpretability

From the abstract of [TabNet](https://arxiv.org/pdf/1908.07442.pdf):

> "*TabNet uses [sequential attention](https://arxiv.org/pdf/1706.03762.pdf) to choose which features to reason from at each decision step, enabling interpretability*"

Here we shall look at the interpretability of TabNet in the context of the Kaggle [Jane Street Market Prediction competition](https://www.kaggle.com/c/jane-street-market-prediction). For this notebook we shall be using the recently released version 3.0.0 of [PyTorch TabNet](https://github.com/dreamquark-ai/tabnet).

In [None]:
import numpy  as np
import pandas as pd

# plotting
import matplotlib.pyplot as plt
import plotly.express as px

!pip install -q pytorch-tabnet
from pytorch_tabnet.tab_model import TabNetRegressor

!pip install -q datatable 
import datatable as dt

import torch

Load in the data and create `X_train` and `y_train`. Here we shall be using regression, with `resp` as the target. A well-known advantage of deep neural networks is that they have little need for *a priori* feature engineering, which is currently a key aspect of tree-based methods for tabular data. In view of this we shall give TabNet access to all 130 features of this dataset.

For the training data we shall use the first 495 days of the `train.csv` file, and for the validation data we shall use the remaining 5 days.

In [None]:
# read in the train dataset
train_data = dt.fread('../input/jane-street-market-prediction/train.csv').to_pandas()

# filter out the zero weights
train_data = train_data.query('weight > 0').reset_index(drop = True)

# split for train and validation
validation_data = train_data.query('date   > 494').reset_index(drop = True) 
train_data      = train_data.query('date <=  494').reset_index(drop = True) 

# training data
X_train = train_data.loc[:, train_data.columns.str.contains('feature')]
X_train = X_train.fillna(X_train.mean()).to_numpy()
y_train = train_data.loc[:, 'resp'].to_numpy().reshape(-1, 1)

# validation data
X_valid = validation_data.loc[:, validation_data.columns.str.contains('feature')]
X_valid = X_valid.fillna(X_valid.mean()).to_numpy()
y_valid = validation_data.loc[:, 'resp'].to_numpy().reshape(-1, 1)
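
The date-based split and the mean imputation can be tried out on a toy frame. The sketch below is illustrative only: the data, the cut-off day and the names `toy`, `is_valid` and `train_means` are assumptions, not part of the competition data. It also shows one common variation, in which the gaps in both splits are filled with the column means of the training rows, so that no statistics from the held-out days are used.

```python
import numpy as np
import pandas as pd

# Toy stand-in for train.csv: six "days", one feature with gaps (hypothetical data)
toy = pd.DataFrame({
    "date":      [0, 1, 2, 3, 4, 5],
    "feature_0": [1.0, np.nan, 3.0, 5.0, np.nan, 7.0],
    "resp":      [0.1, -0.2, 0.3, 0.0, 0.2, -0.1],
})

# The last two days are held out, mirroring a date-based split
is_valid = toy["date"] > 3
feature_cols = toy.columns[toy.columns.str.contains("feature")]

train_part = toy.loc[~is_valid, feature_cols]
valid_part = toy.loc[is_valid, feature_cols]

# Column means computed on the training rows only
train_means = train_part.mean()

X_train_toy = train_part.fillna(train_means).to_numpy()
X_valid_toy = valid_part.fillna(train_means).to_numpy()

print(X_train_toy.ravel())
print(X_valid_toy.ravel())
```

Whether to impute the validation rows with their own means or with the training means is a design choice; the second option matches what would be available at prediction time.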

We shall now perform the regression.
PyTorch TabNet has two options for the masking function, `sparsemax` [[1]](https://arxiv.org/pdf/1602.02068.pdf) and 
`entmax` [[2]](https://papers.nips.cc/paper/2004/file/96f2b50b5d3613adf9c27049b2a888c7-Paper.pdf); here we shall use `entmax`. The `gamma` parameter is the coefficient for feature reuse in the masks: a value close to 1 makes the mask selection least correlated between decision steps.
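
To see what a sparse masking function does, it helps to compare ordinary softmax with sparsemax, the projection onto the probability simplex described in reference [1]. The following is a minimal NumPy sketch written for illustration; it is not the implementation used inside pytorch-tabnet, and the `scores` vector is made up.

```python
import numpy as np

def softmax(z):
    """Dense normalisation: every entry receives a strictly positive weight."""
    e = np.exp(z - z.max())
    return e / e.sum()

def sparsemax(z):
    """Euclidean projection of z onto the probability simplex."""
    z_sorted = np.sort(z)[::-1]             # scores in descending order
    cumsum = np.cumsum(z_sorted)
    k = np.arange(1, z.size + 1)
    support = 1 + k * z_sorted > cumsum     # True for entries that stay non-zero
    k_max = k[support][-1]                  # size of the support
    tau = (cumsum[k_max - 1] - 1) / k_max   # threshold subtracted from every score
    return np.maximum(z - tau, 0.0)

scores = np.array([1.0, 0.8, 0.1, -1.0])    # hypothetical attention scores

print(softmax(scores))     # all four weights are positive
print(sparsemax(scores))   # the two lowest scores are set exactly to zero
```

Both outputs sum to one, but only sparsemax assigns exact zeros, which is what lets a mask switch features off completely. Entmax can be thought of as sitting between these two behaviours.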

In [None]:
%%time

# define the batch size, here 2^13
BS = 8192

# Training for more epochs might improve the model performance at the cost of longer training time
MAX_EPOCH = 20

regressor = TabNetRegressor(n_d=64, n_a=64, 
                            n_steps         =5, 
                            gamma           =1.2,
                            n_independent   =2, 
                            n_shared        =2,
                            lambda_sparse   =0., 
                            seed            =0,
                            clip_value      =1,
                            mask_type       ='entmax',
                            device_name     ='auto',
                            optimizer_fn=torch.optim.Adam,
                            optimizer_params=dict(lr=2e-3),
                            scheduler_params=dict(max_lr=0.05,
                                                  steps_per_epoch=int(X_train.shape[0] / BS),
                                                  epochs=MAX_EPOCH,
                                                  is_batch_level=True),
                            scheduler_fn=torch.optim.lr_scheduler.OneCycleLR,
                            verbose=1)

regressor.fit(X_train=X_train, y_train=y_train,
          eval_set=[(X_train, y_train), (X_valid, y_valid)],
          eval_name=["train", "valid"],
          eval_metric=["mae"],
          batch_size=BS,
          virtual_batch_size=256,
          max_epochs=MAX_EPOCH,
          drop_last=True,
          pin_memory=True)
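
The `steps_per_epoch` value passed to the scheduler has to agree with the number of batches actually produced in each epoch, since `OneCycleLR` raises an error if it is stepped more often than its total number of steps. A minimal sketch of the arithmetic, assuming the training loader behaves like a standard PyTorch `DataLoader`; the helper `batches_per_epoch` and the row count are hypothetical.

```python
def batches_per_epoch(n_rows: int, batch_size: int, drop_last: bool) -> int:
    """Number of batches yielded in one pass over n_rows samples."""
    full, remainder = divmod(n_rows, batch_size)
    if drop_last or remainder == 0:
        return full          # an incomplete final batch is discarded
    return full + 1          # the incomplete final batch is kept

n_rows = 1_500_000           # hypothetical number of training rows
BS = 8192
MAX_EPOCH = 20

steps = batches_per_epoch(n_rows, BS, drop_last=True)
total_steps = steps * MAX_EPOCH

print(steps, total_steps)
```

With `drop_last=True` the floor division used for `steps_per_epoch` matches the batch count; with `drop_last=False` one extra step per epoch would be needed whenever the row count is not a multiple of the batch size.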

In [None]:
explainability_matrix , masks = regressor.explain(X_valid)

# Normalize the importance by sample
normalized_explain_mat = np.divide(explainability_matrix, explainability_matrix.sum(axis=1).reshape(-1, 1)+1e-8)

# Add prediction to better understand correlation between features and predictions
val_preds = regressor.predict(X_valid)

explain_and_preds = np.hstack([normalized_explain_mat, val_preds.reshape(-1, 1)])
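
The row-wise normalisation used here can be checked on a small array. In this sketch the `raw` matrix is made-up data standing in for the explainability matrix; the small constant in the denominator guards against division by zero for rows with no attributed importance.

```python
import numpy as np

# Hypothetical raw importances for 3 samples and 4 features
raw = np.array([[2.0, 0.0, 1.0, 1.0],
                [0.0, 0.0, 0.0, 0.0],   # a row with no attributed importance
                [0.5, 1.5, 0.0, 0.0]])

# keepdims=True keeps the row sums as a column, so broadcasting divides row by row
row_sums = raw.sum(axis=1, keepdims=True)
normalized = raw / (row_sums + 1e-8)

print(normalized.sum(axis=1))
```

After normalisation each non-empty row sums to (almost exactly) one, so importances are comparable between samples, while the all-zero row stays at zero instead of producing NaN values.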

## Feature importance
As with many estimators, TabNet exposes a global ranking of the features by their overall importance, here via the `feature_importances_` attribute:

In [None]:
feat_importances = regressor.feature_importances_
# argsort is ascending, so the most important features end up at the top of the chart
indices = np.argsort(feat_importances)
features = ['feature_{}'.format(i) for i in range(0, 130)]
# plot
fig, ax = plt.subplots(figsize=(10, 6))
ax.set_title("Top 25 feature importances")
ax.barh(range(len(feat_importances)), feat_importances[indices], color="b", align="center")
ax.set_yticks(range(len(feat_importances)))
ax.set_yticklabels([features[idx] for idx in indices])
# all features
# ax.set_ylim([-1, len(feat_importances)])
# Top 25 features (the half-unit offset keeps the lowest bar from being clipped)
ax.set_ylim([len(feat_importances) - 25.5, len(feat_importances) - 0.5])
plt.show()
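
If only the names of the top features are needed, rather than a chart, the same `argsort` idea can be used directly. A small sketch with a made-up `importances` vector standing in for the fitted model's importances:

```python
import numpy as np

importances = np.array([0.05, 0.40, 0.10, 0.30, 0.15])   # hypothetical values
names = ['feature_{}'.format(i) for i in range(importances.size)]

k = 3
# argsort is ascending, so reverse it and keep the first k positions
top = np.argsort(importances)[::-1][:k]

for idx in top:
    print(names[idx], importances[idx])
```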

## Local interpretability
However, the beauty of TabNet is that it allows us not only to obtain the overall feature importances, but also to inspect the importance of each feature for each individual row, here for the validation data:

In [None]:
px.imshow(explain_and_preds[:,:],
          labels=dict(x="Features", y="Samples", color="Importance"),
          x=features+["prediction"],
          title="Sample wise feature importance",
          color_continuous_scale='Jet',
          height=1000)

It is particularly interesting to see the daily variation in the Tag 22 features `feature_64`, `feature_65` and `feature_66`, as well as the daily importance of `feature_37` when the Tag 22 features are seemingly *'inactive'*. The most important message here is that a single overall ranking of features does not tell the whole story: which features drive a prediction can change from one row to the next.
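
One way to make such day-to-day variation easier to read is to average the per-row importances within each day. The sketch below uses made-up arrays; in the notebook, the normalised explainability matrix and the `date` column of the validation data would play the roles of `importance` and `day`.

```python
import numpy as np

# Hypothetical per-row normalised importances (6 rows, 3 features)
importance = np.array([[0.7, 0.2, 0.1],
                       [0.5, 0.4, 0.1],
                       [0.1, 0.1, 0.8],
                       [0.3, 0.1, 0.6],
                       [0.6, 0.3, 0.1],
                       [0.2, 0.2, 0.6]])
# Hypothetical day of each row
day = np.array([495, 495, 496, 496, 497, 497])

# Map each row to the index of its day, then accumulate the rows per day
days, inverse = np.unique(day, return_inverse=True)
sums = np.zeros((days.size, importance.shape[1]))
np.add.at(sums, inverse, importance)
counts = np.bincount(inverse)

daily_mean = sums / counts[:, None]   # one row per day, one column per feature
print(days)
print(daily_mean)
```

Each row of `daily_mean` can then be plotted as a line or a heat map with far fewer rows than the sample-wise plot.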

We can also produce a correlation matrix of the feature importances with respect to each other:

In [None]:
correlation_importance = np.corrcoef(explain_and_preds.T)
px.imshow(correlation_importance,
          labels=dict(x="Features", y="Features", color="Correlation"),
          x=features+["prediction"], y=features+["prediction"],
          title="Correlation between attention mechanism for each feature and predictions",
          color_continuous_scale='Jet')

As we can see, in this example the matrix seems somewhat uninformative.
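
One thing that can make a correlation matrix of this kind harder to read is a feature whose importance never varies, for example one that the masks never select: its Pearson correlation with every other column is undefined and shows up as NaN. Whether this applies to the run above has not been checked here; the following is only a sketch, on random made-up data, of filtering such columns out first.

```python
import numpy as np

rng = np.random.default_rng(0)
toy = rng.random((50, 4))            # hypothetical importances: 50 rows, 4 features
toy[:, 2] = 0.0                      # a feature that is never selected

# Correlating everything leaves NaN in the row and column of the constant feature
with np.errstate(invalid="ignore", divide="ignore"):
    full = np.corrcoef(toy.T)

# Keep only the columns that actually vary before correlating
varying = toy.std(axis=0) > 0
corr = np.corrcoef(toy[:, varying].T)

print(np.isnan(full).any(), np.isnan(corr).any())
```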

# Related reading
* [Sercan O. Arik and Tomas Pfister "*TabNet: Attentive Interpretable Tabular Learning*", arXiv:1908.07442 (2019)](https://arxiv.org/pdf/1908.07442.pdf)
* [TabNet on AI Platform: High-performance, Explainable Tabular Learning](https://cloud.google.com/blog/products/ai-machine-learning/ml-model-tabnet-is-easy-to-use-on-cloud-ai-platform) (Google Cloud)
* [pytorch-tabnet](https://github.com/dreamquark-ai/tabnet) (GitHub)
* [Christoph Molnar, Giuseppe Casalicchio, and Bernd Bischl "*Interpretable Machine Learning -- A Brief History, State-of-the-Art and Challenges*", arXiv:2010.09337 (2020)](https://arxiv.org/pdf/2010.09337.pdf)

***See also***:

* [Jane Street: TabNet 3.0.0 starter notebook](https://www.kaggle.com/carlmcbrideellis/jane-street-tabnet-3-0-0-starter-notebook) - a simple notebook using TabNet classification for the Jane Street competition.