# TabNet Interpretability

## Interpretability

Interpretability is the extent to which cause and effect can be observed within a system, and the degree to which a model can be understood in human terms. It lets us see what a model is actually learning, what additional information it can offer, and why it makes each decision. We can then evaluate all of these in the context of the real-world problem we are trying to solve.

TabNet uses sequential attention to choose which features to reason from at each decision step, and this is what makes it interpretable. Feature selection is instance-wise, i.e. the selected features can differ for each row of the dataset. TabNet therefore supports two kinds of interpretability. Local interpretability shows how important each feature is for a single row and how the features are combined. Global interpretability quantifies the contribution of each feature to the trained model across the whole dataset.
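
At each decision step, TabNet's attentive transformer turns its scores into a feature mask. By default it uses sparsemax, and the regressor in this notebook sets mask_type='sparsemax' explicitly. Softmax always gives every feature some positive weight. Sparsemax projects the scores onto the probability simplex, so it can assign exactly zero weight to features, which is what makes the masks easy to read. The following is a minimal NumPy sketch of sparsemax itself, not pytorch_tabnet's implementation. The function name sparsemax is our own.

```python
import numpy as np


def sparsemax(z):
    """Row-wise sparsemax: project each row of z onto the probability simplex."""
    z = np.atleast_2d(np.asarray(z, dtype=float))
    z_sorted = -np.sort(-z, axis=1)            # each row sorted in descending order
    k = np.arange(1, z.shape[1] + 1)           # candidate support sizes 1..D
    cumsum = np.cumsum(z_sorted, axis=1)
    support = 1 + k * z_sorted > cumsum        # True for entries that stay non-zero
    k_z = support.sum(axis=1)                  # support size per row (always >= 1)
    tau = (cumsum[np.arange(z.shape[0]), k_z - 1] - 1) / k_z  # per-row threshold
    return np.maximum(z - tau[:, None], 0.0)   # shift by tau and clip at zero


# One dominant score takes all of the weight; the other features get exactly 0
mask = sparsemax([[2.0, 1.0, 0.1, -1.0]])
```

The output rows are non-negative and sum to one, like softmax, but features below the threshold tau drop out entirely.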


Here we examine TabNet's interpretability on the synthetic Friedman3 regression dataset, which has four input features.
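
The friedman3Dataset module used below is a local helper whose code is not shown. Readers without it can use the following sketch. It generates data from the standard Friedman #3 definition, y = arctan((x1*x2 - 1/(x1*x3)) / x0) + noise, with the same input ranges as scikit-learn's make_friedman3. The helper make_friedman3_like is hypothetical. We are assuming the local module follows the same definition, and its samples will not match these.

```python
import numpy as np


def make_friedman3_like(n_samples=100, noise=0.0, seed=0):
    """Hypothetical helper: Friedman #3 inputs and target from the standard definition."""
    rng = np.random.default_rng(seed)
    X = np.column_stack([
        rng.uniform(0, 100, n_samples),                   # feature_0
        rng.uniform(40 * np.pi, 560 * np.pi, n_samples),  # feature_1
        rng.uniform(0, 1, n_samples),                     # feature_2
        rng.uniform(1, 11, n_samples),                    # feature_3
    ])
    y = np.arctan((X[:, 1] * X[:, 2] - 1 / (X[:, 1] * X[:, 3])) / X[:, 0])
    y += noise * rng.standard_normal(n_samples)           # optional Gaussian noise
    return X, y


X_demo, y_demo = make_friedman3_like(n_samples=100)
```

Because x1 >= 40*pi and x3 >= 1, the term 1/(x1*x3) is at most about 0.008. Feature_3 therefore has very little influence on the target. This is a useful sanity check when reading the importance plots later in the notebook.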

## Importing Libraries

In [None]:
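import sys
import warnings

# Make the local helper modules in ../../src (friedman3Dataset, dataset) importable
sys.path.insert(0, '../../src')
# Hide library warnings to keep the notebook output readable
warnings.filterwarnings('ignore')

import numpy as np
import pandas as pd

# plotting
import matplotlib.pyplot as plt
import plotly.express as px

import torch
from pytorch_tabnet.tab_model import TabNetRegressor
from sklearn.model_selection import train_test_split

# local helpers (not shown in this notebook)
import friedman3Dataset
import dataset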

## Generating the dataset

In [None]:
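n_features = 4   # Friedman3 has four input features
n_samples = 100
n_target = 1
X, Y = friedman3Dataset.friedman3_data(n_samples)

# Wrap the data in the local CustomDataset class (not used further below)
real_dataset = dataset.CustomDataset(X, Y)

# Hold out 20% of the rows; fix random_state so the split is reproducible
X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size=0.2, random_state=42)

# TabNetRegressor expects 2-D targets of shape (n_samples, n_targets)
y_train = np.reshape(y_train, (-1, 1))
y_test = np.reshape(y_test, (-1, 1))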

## TabNet Regressor

In [None]:
n_epochs = 1000
batch_size = 32

regressor = TabNetRegressor(
    optimizer_fn=torch.optim.Adam,
    optimizer_params=dict(lr=0.001),
    mask_type='sparsemax',   # sparse feature masks, easier to interpret
    verbose=1,               # print the metrics every epoch
)

regressor.fit(
    X_train=X_train, y_train=y_train,
    eval_set=[(X_train, y_train), (X_test, y_test)],
    eval_name=['train', 'valid'],
    eval_metric=['mae'],
    max_epochs=n_epochs,
    batch_size=batch_size,
    patience=50,             # early stopping monitors the last eval set ('valid')
)
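
With patience=50, training ends once the validation MAE has not improved for 50 consecutive epochs, even if max_epochs has not been reached. This is the last metric in eval_metric, computed on the last entry of eval_set. The plain-Python sketch below illustrates that patience rule on a list of per-epoch scores. It is a generic illustration rather than pytorch_tabnet's own callback, and early_stopping_epoch is a hypothetical name.

```python
def early_stopping_epoch(val_scores, patience):
    """Return (best_epoch, stop_epoch) for a lower-is-better metric such as MAE.

    Generic patience rule: stop once `patience` epochs have passed since the
    best score so far without any further improvement.
    """
    best_score = float("inf")
    best_epoch = 0
    for epoch, score in enumerate(val_scores):
        if score < best_score:
            best_score, best_epoch = score, epoch  # new best: reset the waiting period
        elif epoch - best_epoch >= patience:
            return best_epoch, epoch               # waited long enough without improvement
    return best_epoch, len(val_scores) - 1          # ran out of epochs first


# Best MAE at epoch 2; with patience=3 training stops at epoch 5
result = early_stopping_epoch([5, 4, 3, 3.5, 3.2, 3.1, 3.4], patience=3)
```

At the end of training, pytorch_tabnet reports the best epoch and uses the weights from that epoch. In this notebook, X_test both drives early stopping and is explained later, so a separate hold-out split would keep the two roles independent.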

In [None]:
# Per-row feature importances (explainability_matrix) and per-step masks (masks, a dict keyed by decision step)
explainability_matrix, masks = regressor.explain(X_test)

# Normalize each row so that its importances sum to one (1e-8 avoids division by zero)
normalized_explain_mat = np.divide(
    explainability_matrix,
    explainability_matrix.sum(axis=1).reshape(-1, 1) + 1e-8,
)

# Predictions on the same rows, used later to correlate importances with the output
val_preds = regressor.predict(X_test)

# Matrix shown in the local-interpretability heatmap (features only, no prediction column)
explain_and_preds = normalized_explain_mat
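
In the TabNet paper, a row's explanation is a weighted sum of that row's masks over the decision steps. Each step is weighted by how much it contributes to the output, namely the sum of the ReLU-activated outputs of that step's decision block. Dividing by the row total then gives importances that sum to one. To our understanding of the library source, explain() already returns the un-normalized sum, which is why the cell above only divides each row by its total. The NumPy sketch below reproduces the aggregation on made-up masks and step weights. The names aggregate_masks, masks_demo and weights_demo are hypothetical.

```python
import numpy as np


def aggregate_masks(step_masks, step_weights, eps=1e-8):
    """Combine per-step feature masks into one row-normalized importance matrix.

    step_masks:   list of (n_samples, n_features) arrays, one per decision step
    step_weights: (n_samples, n_steps) array of non-negative step contributions
    """
    n_samples, n_features = step_masks[0].shape
    agg = np.zeros((n_samples, n_features))
    for step, mask in enumerate(step_masks):
        agg += step_weights[:, [step]] * mask  # weight each row's mask by that step's contribution
    # Row-normalize, mirroring normalized_explain_mat in the cell above
    return agg / (agg.sum(axis=1, keepdims=True) + eps)


# Two rows, four features, two decision steps (toy values)
masks_demo = [np.array([[1.0, 0.0, 0.0, 0.0], [0.5, 0.5, 0.0, 0.0]]),
              np.array([[0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]])]
weights_demo = np.array([[3.0, 1.0], [1.0, 1.0]])
importances_demo = aggregate_masks(masks_demo, weights_demo)
```

In the first toy row, step 0 carries three times the weight of step 1, so feature 0 ends up with 0.75 of the importance and feature 2 with 0.25.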

## Feature importance

Like scikit-learn's tree ensembles, TabNet exposes a global ranking of the features by overall importance through its feature_importances_ attribute:

In [None]:
feat_importances = regressor.feature_importances_
indices = np.argsort(feat_importances)   # ascending, so the most important feature is drawn on top
features = ['feature_{}'.format(i) for i in range(n_features)]

# plot
fig, ax = plt.subplots(figsize=(10, 2))
ax.set_title("Overall feature importances")
ax.barh(range(len(feat_importances)), feat_importances[indices], color="b", align="center")
ax.set_yticks(range(len(feat_importances)))
ax.set_yticklabels([features[idx] for idx in indices])
# show all features (with only four, no top-k cut-off is needed)
ax.set_ylim([-1, len(feat_importances)])
plt.show()
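
To our understanding of the pytorch_tabnet source, feature_importances_ is computed from the training rows. The library sums the un-normalized explanation matrix over those rows and rescales the totals to sum to one. The sketch below applies that recipe to a toy matrix, using a hypothetical helper named global_importances. Applying it to the notebook's explainability_matrix would give a ranking for the test rows, not an exact reproduction of feature_importances_.

```python
import numpy as np


def global_importances(explain_matrix):
    """Sum per-row (un-normalized) importances over all rows and rescale to sum to one."""
    totals = np.asarray(explain_matrix, dtype=float).sum(axis=0)
    return totals / totals.sum()


M_demo = np.array([[3.0, 0.0, 1.0, 0.0],
                   [0.5, 0.5, 0.0, 1.0]])
g = global_importances(M_demo)
ranking = np.argsort(g)[::-1]   # feature indices from most to least important
```

Because the rows are summed before normalization, rows with larger step contributions weigh more in the global ranking than they would if normalized_explain_mat were averaged instead.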

## Local interpretability

Beyond this overall ranking, TabNet also lets us inspect how important each feature is for each individual row. Here we do so for the test split, which also served as the validation set during training:

In [None]:
px.imshow(explain_and_preds,
          labels=dict(x="Features", y="Samples", color="Importance"),
          x=features,
          title="Sample-wise feature importance",
          color_continuous_scale='Jet',
          height=1000)

The importance of some features varies noticeably from row to row. This per-sample view shows how the model's reliance on each feature changes across the dataset, which a single overall ranking cannot reveal.
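
The row-to-row variation in the heatmap can also be summarized numerically. A feature with a high mean and a low spread is used consistently. A large standard deviation indicates that the model relies on that feature only for some rows. Below is a minimal sketch using a hypothetical helper, importance_spread. In the notebook, it would be called as importance_spread(normalized_explain_mat, features).

```python
import numpy as np


def importance_spread(normalized_importances, feature_names):
    """Map each feature name to the (mean, standard deviation) of its per-row importance."""
    mat = np.asarray(normalized_importances, dtype=float)
    means = mat.mean(axis=0)
    stds = mat.std(axis=0)   # population standard deviation across rows
    return {name: (m, s) for name, m, s in zip(feature_names, means, stds)}


stats = importance_spread(np.array([[0.75, 0.0, 0.25, 0.0],
                                    [0.25, 0.25, 0.0, 0.5]]),
                          ["feature_0", "feature_1", "feature_2", "feature_3"])
```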

We can also compute a correlation matrix between the per-feature importances and the model's predictions:

In [None]:
# Append the predictions as an extra column, then correlate all columns pairwise
explain_and_preds = np.hstack([normalized_explain_mat, val_preds.reshape(-1, 1)])
correlation_importance = np.corrcoef(explain_and_preds.T)
px.imshow(correlation_importance,
          labels=dict(x="Features", y="Features", color="Correlation"),
          x=features + ["prediction"], y=features + ["prediction"],
          title="Correlation between per-feature importances and predictions",
          color_continuous_scale='Jet')
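
A feature may never be selected on the rows being explained. Its column in normalized_explain_mat is then constant, and np.corrcoef fills its row and column with NaN. Because the first cell suppresses warnings, this happens silently. The sketch below shows one way to guard against it, using a hypothetical helper named safe_corrcoef that drops zero-variance columns before computing the correlations.

```python
import numpy as np


def safe_corrcoef(columns, names):
    """Correlation matrix over the columns of a 2-D array, skipping zero-variance columns.

    Returns the correlation matrix and the names of the columns that were kept.
    """
    columns = np.asarray(columns, dtype=float)
    keep = columns.std(axis=0) > 0                  # constant columns would yield NaN
    kept_names = [name for name, k in zip(names, keep) if k]
    return np.corrcoef(columns[:, keep].T), kept_names


cols_demo = np.array([[1.0, 0.0, 2.0],
                      [2.0, 0.0, 4.5],
                      [3.0, 0.0, 6.0]])
corr_demo, names_demo = safe_corrcoef(cols_demo, ["a", "b", "c"])
```

In the notebook, corr, names = safe_corrcoef(explain_and_preds, features + ["prediction"]) would replace the np.corrcoef call, with names passed as x and y to px.imshow.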