# Tutorial 1 - SimplEx for Tabular Data

In this tutorial we create a SimplEx explainer object and use it to explain a test record. The explainer is then saved to disk so that it can be shared and viewed in the Interpretability App (TODO: add link to app).

We will explain the predictions of a PyTorch multi-layer perceptron that we trained separately on the iris dataset from scikit-learn. The `interpretability.models` module provides a few PyTorch model classes whose trained `state_dict`s are available from the Google Drive link below.

### Import the relevant modules

In [1]:
# IMPORTS
# Standard
import os
import pathlib

# Third Party
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn import preprocessing
import torch
import pandas as pd

# Interpretability
from interpretability.interpretability_models import simplex_explainer
from interpretability.interpretability_models.utils import io
from interpretability.models.multilayer_perceptron import IrisMLP # This is the class of the model we have already trained

### Load the data 
Load the data and split it into a corpus (the examples used to build explanations) and a test set (the records we will explain).

In [2]:
# Load the data
X, y = load_iris(return_X_y=True, as_frame=True)

# Get feature names
feature_names = X.columns.to_list()

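# Scale the features. Assumption: the saved model was trained on
# standardised inputs; use whichever scaler was used during training.
# `scaler` is passed to `summary_plot` later to show values in cm again.
scaler = preprocessing.StandardScaler()
X = pd.DataFrame(scaler.fit_transform(X), columns=feature_names)

# Split the data
X_corpus, X_test, y_corpus, y_test = train_test_split(X, y, test_size=0.2)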
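With `test_size=0.2`, the 150 iris rows split into 120 corpus rows and 30 test rows, so valid test indices run from 0 to 29. The sketch below reproduces that arithmetic with a plain NumPy permutation. It is only an illustration: `split_indices` is a hypothetical helper, not how `train_test_split` is implemented internally.

```python
import math

import numpy as np


def split_indices(n_rows, test_size, seed=0):
    """Hypothetical helper: shuffle row indices and cut off a test fraction."""
    rng = np.random.default_rng(seed)
    order = rng.permutation(n_rows)
    n_test = math.ceil(test_size * n_rows)  # round the test fraction up
    return order[n_test:], order[:n_test]  # (corpus indices, test indices)


corpus_idx, test_idx = split_indices(150, test_size=0.2)
print(len(corpus_idx), len(test_idx))  # 120 30
```

Because `train_test_split` shuffles without a fixed `random_state` in the cell above, the record at test index 29 will differ between runs.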



### Download the trained model from Google Drive

You could train your own model with the `IrisMLP` class and load it here, but we have already trained one.

Download the model from https://drive.google.com/file/d/1MbQX1PYABB4XNO9c_SR-Mo3i6HjU0hB-/view?usp=sharing and save it at the path given by `TRAINED_MODEL_STATE_PATH` below. By default, this is `model_cv1.pth` on your desktop.
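A small check before loading makes a misplaced download easy to spot. The sketch below builds the expected path and raises a clear error if the file is missing. `resolve_model_path` is a hypothetical helper, not part of the `interpretability` package, and it assumes the same default location (`~/Desktop/model_cv1.pth`) as the loading cell.

```python
import pathlib


def resolve_model_path(filename="model_cv1.pth", directory=None):
    """Hypothetical helper: locate the downloaded state_dict (default: ~/Desktop)
    and fail early with a clear message if it is not there."""
    directory = pathlib.Path(directory) if directory else pathlib.Path.home() / "Desktop"
    path = directory / filename
    if not path.is_file():
        raise FileNotFoundError(f"Expected the downloaded model at {path}")
    return path
```

In the loading cell you could then write `TRAINED_MODEL_STATE_PATH = resolve_model_path()`.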

### Load the model

In [3]:
## Load the model
model = IrisMLP(n_cont=4, input_feature_num=len(feature_names))

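def load_trained_model(model, trained_model_state_path):
    # map_location="cpu" lets weights saved on a GPU load on a CPU-only machine
    model.load_state_dict(torch.load(trained_model_state_path, map_location="cpu"))
    model.eval()  # switch off training-only behaviour such as dropout
    return model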

desktop_path = pathlib.Path.home() / 'Desktop'

TRAINED_MODEL_STATE_PATH = os.path.join(desktop_path, "model_cv1.pth")
model = load_trained_model(model, TRAINED_MODEL_STATE_PATH)



### Initialize SimplEX
Initialize the explainer object by passing the predictive model and corpus.

In [4]:
my_explainer = simplex_explainer.SimplexTabluarExplainer(
    model,
    X_corpus,
    y_corpus,
    estimator_type="classifier",
    feature_names=feature_names,
    corpus_size=100,
    device="cpu",
)

### Fit the explainer

Fit the explainer on the test data. Fitting learns, for each test record, how much each corpus example contributes to it, which makes the explanations available in the next step.

In [5]:
my_explainer.fit(X_test, y_test, n_epochs=10000)

Weight Fitting Epoch: 2000/10000 ; Error: 400 ; Regulator: 24.3 ; Reg Factor: 1
Weight Fitting Epoch: 4000/10000 ; Error: 118 ; Regulator: 17.5 ; Reg Factor: 1
Weight Fitting Epoch: 6000/10000 ; Error: 66.7 ; Regulator: 11.8 ; Reg Factor: 1
Weight Fitting Epoch: 8000/10000 ; Error: 53.2 ; Regulator: 7.65 ; Reg Factor: 1
Weight Fitting Epoch: 10000/10000 ; Error: 48.1 ; Regulator: 4 ; Reg Factor: 1
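SimplEx explains a test record by approximating the model's latent representation of it as a convex combination (non-negative weights summing to one) of the corpus examples' latent representations. The weight fitting logged above finds those weights by gradient descent. The NumPy sketch below shows the core idea for a single test record, using a softmax to keep the weights on the simplex. It is a simplified illustration under our own assumptions, not the library's implementation: the toy 3-dimensional "latent" vectors are made up, and the "Regulator" term in the library's log is left out.

```python
import numpy as np


def fit_simplex_weights(h_test, H_corpus, n_epochs=5000, lr=0.1):
    """Minimal sketch: find w >= 0 with sum(w) == 1 so that w @ H_corpus ~ h_test."""
    logits = np.zeros(H_corpus.shape[0])
    errors = []
    for _ in range(n_epochs):
        w = np.exp(logits - logits.max())
        w /= w.sum()                               # softmax -> point on the simplex
        residual = w @ H_corpus - h_test           # reconstruction error in latent space
        errors.append(float(residual @ residual))
        grad_w = 2.0 * H_corpus @ residual         # d(error)/dw
        grad_logits = w * (grad_w - w @ grad_w)    # chain rule through the softmax
        logits -= lr * grad_logits
    w = np.exp(logits - logits.max())
    return w / w.sum(), errors


# Toy latent representations of 4 corpus examples (3 latent dimensions)
H = np.array([[1.0, 0.0, 0.0],
              [0.0, 1.0, 0.0],
              [0.0, 0.0, 1.0],
              [1.0, 1.0, 1.0]])
h = 0.7 * H[3] + 0.3 * H[1]  # toy test representation: a mix of corpus rows 3 and 1
weights, errors = fit_simplex_weights(h, H)
print(weights.round(3), errors[0], errors[-1])
```

Most of the weight should end up on corpus row 3, the example that contributes most to the toy test representation.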


### Get the explanation
To explain a different record in the test set, change the index `i`.

In [6]:
# Explain
i = 29
explanation = my_explainer.explain(
    i,
    baseline="median",
)
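`baseline="median"` sets the reference point for feature-level attributions. We assume here, without having checked the library source, that it is the per-feature median of the corpus and that it enters the way SimplEx's Integrated Jacobians use a baseline: attributions are accumulated along the straight path from the baseline to the record. The sketch shows that path integral (integrated gradients) for a hypothetical scalar function `f(x) = sum(x**2)`, whose gradient is known in closed form. The three corpus rows are toy values in iris units.

```python
import numpy as np


def f(x):
    return float(np.sum(x ** 2))  # hypothetical toy model output


def grad_f(x):
    return 2.0 * x  # analytic gradient of the toy model


def path_attributions(x, baseline, n_steps=50):
    """Integrated-gradients style attributions along baseline -> x (midpoint Riemann sum)."""
    alphas = (np.arange(n_steps) + 0.5) / n_steps
    grads = np.stack([grad_f(baseline + a * (x - baseline)) for a in alphas])
    return (x - baseline) * grads.mean(axis=0)


corpus = np.array([[5.1, 3.5, 1.4, 0.2],
                   [6.4, 3.2, 4.5, 1.5],
                   [6.3, 3.3, 6.0, 2.5]])
baseline = np.median(corpus, axis=0)   # assumed meaning of baseline="median"
x = np.array([5.9, 3.0, 4.2, 1.5])     # the test record shown in the summary table
attr = path_attributions(x, baseline)
print(attr.round(3), attr.sum(), f(x) - f(baseline))
```

The attributions add up to `f(x) - f(baseline)`, which is why the choice of baseline matters: it defines what each feature's contribution is measured against.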

### Plot the explanation

In this notebook the explanation is displayed as a styled DataFrame; set `return_type="html"` to view it in a browser instead.
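The plotting call passes `rescaler=scaler` so that the tables show features in their original units (cm) even though the model works on scaled inputs. We assume it expects a fitted scikit-learn-style scaler whose `inverse_transform` undoes the scaling. The NumPy sketch below shows that round trip for standardisation; `ToyStandardScaler` is a hypothetical stand-in, not scikit-learn's class.

```python
import numpy as np


class ToyStandardScaler:
    """Hypothetical stand-in with the fit/transform/inverse_transform shape of a scikit-learn scaler."""

    def fit(self, X):
        self.mean_ = X.mean(axis=0)
        self.scale_ = X.std(axis=0)
        return self

    def transform(self, X):
        return (X - self.mean_) / self.scale_   # features in standard-deviation units

    def inverse_transform(self, Z):
        return Z * self.scale_ + self.mean_     # back to the original units


X_cm = np.array([[5.9, 3.0, 4.2, 1.5],
                 [5.7, 2.9, 4.2, 1.3],
                 [6.3, 3.3, 4.7, 1.6]])
scaler = ToyStandardScaler().fit(X_cm)
Z = scaler.transform(X_cm)
print(np.allclose(scaler.inverse_transform(Z), X_cm))  # True
```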

In [7]:
explain_record_df, display_corpus_df = my_explainer.summary_plot(
    example_importance_threshold=1e-9,
    output_file_prefix="",
    return_type="styled_df",
    rescaler=scaler,
)
display(explain_record_df)
display(display_corpus_df)

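| | sepal length (cm) | sepal width (cm) | petal length (cm) | petal width (cm) | Test Prediction | Test Label |
|---|---|---|---|---|---|---|
| 0 | 5.9 | 3.0 | 4.2 | 1.5 | 0 | 1 |

| | sepal length (cm) | sepal width (cm) | petal length (cm) | petal width (cm) | Example Importance | Corpus Prediction | Corpus Label |
|---|---|---|---|---|---|---|---|
| 0 | 5.7 | 2.9 | 4.2 | 1.3 | 66.24% | 0 | 1 |
| 1 | 6.3 | 3.3 | 4.7 | 1.6 | 18.93% | 1 | 2 |
| 2 | 6.5 | 3.2 | 5.1 | 2.0 | 9.75% | 1 | 1 |
| 3 | 6.8 | 3.2 | 5.9 | 2.3 | 0.53% | 1 | 2 |
| 4 | 5.7 | 3.0 | 4.2 | 1.2 | 0.39% | 1 | 1 |
| 5 | 6.9 | 3.2 | 5.7 | 2.3 | 0.16% | 0 | 0 |
| 6 | 5.6 | 2.8 | 4.9 | 2.0 | 0.15% | 1 | 1 |
| 7 | 6.9 | 3.1 | 5.1 | 2.3 | 0.15% | 1 | 1 |
| 8 | 6.5 | 3.0 | 5.8 | 2.2 | 0.14% | 1 | 1 |
| 9 | 6.9 | 3.1 | 5.4 | 2.1 | 0.13% | 0 | 0 |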
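The corpus table lists examples in decreasing order of importance and leaves out any whose weight falls below `example_importance_threshold`. A minimal sketch of that filtering and percentage formatting, assuming the importances are the fitted SimplEx weights (which sum to one); `rank_examples` and the weights below are hypothetical.

```python
def rank_examples(weights, threshold=1e-9):
    """Return (corpus index, 'xx.xx%') pairs above the threshold, most important first."""
    kept = [(i, w) for i, w in enumerate(weights) if w > threshold]
    kept.sort(key=lambda pair: pair[1], reverse=True)
    return [(i, f"{100 * w:.2f}%") for i, w in kept]


weights = [0.0015, 0.6624, 0.0, 0.1893, 0.0975]
print(rank_examples(weights))
# [(1, '66.24%'), (3, '18.93%'), (4, '9.75%'), (0, '0.15%')]
```

A very small threshold such as `1e-9` keeps almost every example with non-zero weight; raising it trims the table to the examples that actually drive the explanation.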


### Save the explainer to file
This file can now be uploaded to the Interpretability Suite app (TODO: add link), which provides a non-programmatic interface for viewing the explanations. This lets you share the explainer with a colleague who is less fluent in Python.

In [8]:
io.save_explainer(
    my_explainer, "my_new_iris_mlp_simplex_explainer.p"
)

Saving explainer to: /home/rob/Documents/projects/Interpretability/Notebooks/my_new_iris_mlp_simplex_explainer.p
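The `.p` extension suggests that `io.save_explainer` writes a pickle file, though we have not confirmed this in the library. If so, the file round-trips with the standard `pickle` module, as sketched below with a plain dictionary standing in for the explainer (the dictionary contents are hypothetical). Two practical consequences: unpickling can execute code from the file, so only load explainer files from people you trust, and whoever loads a real explainer needs the classes it refers to (here, the `interpretability` package) to be importable.

```python
import os
import pickle
import tempfile

# Hypothetical stand-in for a fitted explainer's state
state = {
    "feature_names": ["sepal length (cm)", "sepal width (cm)"],
    "example_weights": [0.7, 0.3],
}

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "toy_explainer.p")
    with open(path, "wb") as f:
        pickle.dump(state, f)        # serialise the object to disk
    with open(path, "rb") as f:
        restored = pickle.load(f)    # rebuild it from the file

print(restored == state)  # True
```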
