# Towards Robust Interpretability with Self-Explaining Neural Networks

Authors: *Rico Mossinkoff, Yke Rusticus, Roberto Schiavone, Ewoud Vermeij*

In [1]:
import warnings
warnings.filterwarnings("ignore")

from api import mnist, compas
from api.load import load_compas, load_mnist, RegLambda, HType, NConcepts
from api.utils import MNIST_TEST_SET

## MNIST
For the MNIST dataset, different models can be loaded by passing the following parameters to the `load_mnist` function:
- `h_type=HType.INPUT` and `reg_lambda`
- `h_type=HType.CNN`, `reg_lambda` and `n_concepts`

If `load_mnist` is called with `h_type=HType.INPUT` and `n_concepts`, `n_concepts` is safely ignored.

The default parameters for MNIST are `h_type=HType.INPUT`, `n_concepts=NConcepts.FIVE` and `reg_lambda=RegLambda.E4`.

The possible values for each parameter are shown below.

In [None]:
# List every member of the HType enum, one per line, under a header line.
print('h_type possible values:', *HType, sep='\n')

In [None]:
# List every RegLambda member together with its numeric value.
print('reg_lambda possible values:')
for strength in RegLambda:
    # 0 and 1 are shown verbatim; every other value uses exponent notation.
    if strength.value in (0, 1):
        shown = '{}'.format(strength.value)
    else:
        shown = '{:0.0e}'.format(strength.value)
    print(': '.join((str(strength), shown)))

In [None]:
# List every NConcepts member together with its value.
print('n_concepts possible values:')
for option in NConcepts:
    # print() applies str() to each argument, so this yields "<member>: <value>".
    print(option, option.value, sep=': ')

In [None]:
# Load an MNIST model, overriding only `n_concepts`; the other parameters
# keep their defaults.
model = load_mnist(n_concepts=NConcepts.TWENTY)
# Equivalent explicit call, going by the defaults documented above
# (h_type=HType.INPUT, reg_lambda=RegLambda.E4) — TODO confirm against `load_mnist`.
# NOTE(review): with HType.INPUT, `n_concepts` is documented as ignored; if a
# concept-based model was intended here, h_type=HType.CNN must be passed explicitly.
# model = load_mnist(reg_lambda=RegLambda.E4, h_type=HType.INPUT, n_concepts=NConcepts.TWENTY)

In [None]:
# `n` is passed to `mnist.get_digit_image` here and to the noise-activation plot
# in the next cell (presumably the digit to visualise — confirm against `api.mnist`).
n = 3
# NOTE(review): in the visible notebook `x` is not read again before it is
# reassigned to the COMPAS feature dict further down.
x = mnist.get_digit_image(n)

In [None]:
# Plot for the currently loaded `model` and the digit selector `n`.
# NOTE(review): the meaning of the trailing 5 is not visible here — confirm
# against `mnist.plot_digit_noise_activation`.
mnist.plot_digit_noise_activation(model, n, 5)

In [None]:
# Plot for the currently loaded `model`.
# NOTE(review): the meaning of the literal 5 is not visible here — confirm
# against `mnist.plot_digit_stability_concept_grid`.
mnist.plot_digit_stability_concept_grid(model, 5)

In [None]:
# Rebinds `model`: from here on it refers to the COMPAS model loaded with
# RegLambda.ONE, no longer the MNIST model loaded earlier.
model = load_compas(RegLambda.ONE)

In [None]:
# Hand-built COMPAS input: one float per feature name, in the order listed.
# Every feature is switched on (1.0) except `Female`, which is 0.0.
# NOTE(review): several flags that look mutually exclusive are all 1.0 here —
# presumably a synthetic probe rather than a realistic record; confirm.
x = dict(
    Two_yr_Recidivism=1.,
    Number_of_Priors=1.,
    Age_Above_FourtyFive=1.,
    Age_Below_TwentyFive=1.,
    African_American=1.,
    Asian=1.,
    Hispanic=1.,
    Native_American=1.,
    Other=1.,
    Female=0.,
    Misdemeanor=1.,
)

In [None]:
# Pass the COMPAS `model` and the hand-built feature dict `x` to the plot helper.
compas.plot_lipschitz_feature(model, x)

# Model evaluations


In [2]:
# The return value is unpacked into exactly eight items; only the sixth is kept,
# as `test` (it is what the evaluation cell below passes to `compas.evaluate`).
_, _, _, _, _, test , _, _ = compas.load_compas_data(shuffle=False, batch_size=64)

  1%|▏         | 82/6172 [00:00<00:07, 812.16it/s]

/home/evermeij/Documents/Github/FACT/SENN/data/fairml/doc/example_notebooks/propublica_data_for_fairml.csv


100%|██████████| 6172/6172 [00:03<00:00, 1659.75it/s]


In [3]:
# Evaluate every COMPAS and MNIST model configuration and record its accuracy.
#
# `compass_acc` and `mnist_acc` are plain lists filled in iteration order.
# The MNIST faithfulness cell further down indexes `mnist_acc` positionally,
# so the nesting order HType -> NConcepts -> RegLambda must not change.
# NOTE(review): the trailing positional `False` passed to the loaders is not
# explained in this notebook — confirm its meaning against `api.load`.
print("<--------COMPAS MODELS-------->")
compass_acc = []
for reg in RegLambda:
    model = load_compas(reg, False)
    acc = compas.evaluate(model, test, print_freq=0, return_acc=True)
    compass_acc.append(acc)
    # `.name` is the bare member name (e.g. "E4"); it avoids slicing a
    # hard-coded class-name prefix length off str(reg).
    strreg = ("Lambda: " + reg.name + ",").ljust(14, ' ')
    print("COMPAS model - %s accuracy: %0.4f" %(strreg, acc))

print("\n<--------MNIST MODELS--------->")
mnist_acc = []
for ht in HType:
    for concept in NConcepts:
        for reg in RegLambda:
            model = load_mnist(reg, ht, concept, False)
            acc = mnist.evaluate(model, MNIST_TEST_SET, print_freq=0, return_acc=True)
            mnist_acc.append(acc)
            # Column widths (14 / 20 / 15) keep the printed table aligned.
            strht = ("h(x): " + ht.name + ",").ljust(14, ' ')
            strconcept = ("nconcepts: " + concept.name + ",").ljust(20, ' ')
            strreg = ("Lambda: " + reg.name + ",").ljust(15, ' ')
            print("MNIST model - %s %s %s accuracy: %0.5f" %(strht, strconcept, strreg, acc))
            

<--------COMPAS MODELS-------->
COMPAS model - Lambda: ZERO,  accuracy: 0.8153
COMPAS model - Lambda: E4,    accuracy: 0.8060
COMPAS model - Lambda: E3,    accuracy: 0.8134
COMPAS model - Lambda: E2,    accuracy: 0.8041
COMPAS model - Lambda: E1,    accuracy: 0.7910
COMPAS model - Lambda: ONE,   accuracy: 0.6754

<--------MNIST MODELS--------->
MNIST model - h(x): CNN,     nconcepts: FIVE,     Lambda: ZERO,   accuracy: 0.98850
MNIST model - h(x): CNN,     nconcepts: FIVE,     Lambda: E4,     accuracy: 0.98750
MNIST model - h(x): CNN,     nconcepts: FIVE,     Lambda: E3,     accuracy: 0.98650
MNIST model - h(x): CNN,     nconcepts: FIVE,     Lambda: E2,     accuracy: 0.98250
MNIST model - h(x): CNN,     nconcepts: FIVE,     Lambda: E1,     accuracy: 0.96580
MNIST model - h(x): CNN,     nconcepts: FIVE,     Lambda: ONE,    accuracy: 0.15570
MNIST model - h(x): CNN,     nconcepts: TWENTY,   Lambda: ZERO,   accuracy: 0.99100
MNIST model - h(x): CNN,     nconcepts: TWENTY,   Lambda: E4,    

# MNIST faithfulness test

In [6]:
# Leave-one-concept-out faithfulness check for every MNIST model configuration.
#
# For each model, the stored baseline accuracy is printed, followed by one line
# per entry of `faith_acc` reporting how the accuracy moved relative to it.
# NOTE(review): relies on `mnist_acc` having been filled by the evaluation cell
# with the same HType -> NConcepts -> RegLambda loop order; `count` walks that
# list in lock-step with the loops below.
# NOTE(review): the saved output of this cell ends in
# "ValueError: too many values to unpack (expected 2)". The traceback is not
# preserved, so which configuration triggers it still needs to be confirmed.
count = 0
for ht in HType:
    for concept in NConcepts:
        for reg in RegLambda:
            model = load_mnist(reg, ht, concept, False)
            baseline_acc = mnist_acc[count]
            # `.name` is the bare member name; it avoids slicing a hard-coded
            # class-name prefix length off str(...).
            strht = ("h(x): " + ht.name + ",").ljust(14, ' ')
            strconcept = ("nconcepts: " + concept.name + ",").ljust(20, ' ')
            strreg = ("Lambda: " + reg.name + ",").ljust(15, ' ')
            print("MNIST model - %s %s %s accuracy: %0.5f" %(strht, strconcept, strreg, baseline_acc))
            # NOTE(review): cuda=True presumably requires a GPU — confirm before
            # running on a CPU-only machine.
            faith_acc = mnist.faith_evaluate(model, concept.value, MNIST_TEST_SET, print_freq=0, return_acc=True, cuda=True)
            # A dedicated loop name keeps the notebook-global `x` (the COMPAS
            # feature dict) from being overwritten by this loop.
            for concept_idx in faith_acc:
                ablated_acc = faith_acc[concept_idx]
                if ablated_acc > baseline_acc:
                    line = "INcrease {:05.4f}".format(ablated_acc - baseline_acc)
                else:
                    # Equal accuracy is also reported here, as a zero decrease.
                    line = "DEcrease {:05.4f}".format(baseline_acc - ablated_acc)
                print("Leaving concept %d out, giving an %s. accuracy result: %0.4f" %(concept_idx, line, ablated_acc))
            print("")
            count += 1

MNIST model - h(x): CNN,     nconcepts: FIVE,     Lambda: ZERO,   accuracy: 0.98850
Leaving concept 0 out, giving an DEcrease 0.0003. accuracy result: 0.9882
Leaving concept 1 out, giving an INcrease 0.0006. accuracy result: 0.9891
Leaving concept 2 out, giving an DEcrease 0.0002. accuracy result: 0.9883
Leaving concept 3 out, giving an DEcrease 0.0001. accuracy result: 0.9884
Leaving concept 4 out, giving an INcrease 0.0002. accuracy result: 0.9887

MNIST model - h(x): CNN,     nconcepts: FIVE,     Lambda: E4,     accuracy: 0.98750
Leaving concept 0 out, giving an INcrease 0.0006. accuracy result: 0.9881
Leaving concept 1 out, giving an INcrease 0.0004. accuracy result: 0.9879
Leaving concept 2 out, giving an DEcrease 0.0000. accuracy result: 0.9875
Leaving concept 3 out, giving an INcrease 0.0005. accuracy result: 0.9880
Leaving concept 4 out, giving an INcrease 0.0003. accuracy result: 0.9878

MNIST model - h(x): CNN,     nconcepts: FIVE,     Lambda: E3,     accuracy: 0.98650
Leavin

Leaving concept 0 out, giving an DEcrease 0.0001. accuracy result: 0.9705
Leaving concept 1 out, giving an DEcrease 0.0010. accuracy result: 0.9696
Leaving concept 2 out, giving an INcrease 0.0002. accuracy result: 0.9708
Leaving concept 3 out, giving an DEcrease 0.0007. accuracy result: 0.9699
Leaving concept 4 out, giving an DEcrease 0.0003. accuracy result: 0.9703
Leaving concept 5 out, giving an DEcrease 0.0014. accuracy result: 0.9692
Leaving concept 6 out, giving an INcrease 0.0003. accuracy result: 0.9709
Leaving concept 7 out, giving an DEcrease 0.0015. accuracy result: 0.9691
Leaving concept 8 out, giving an DEcrease 0.0003. accuracy result: 0.9703
Leaving concept 9 out, giving an DEcrease 0.0007. accuracy result: 0.9699
Leaving concept 10 out, giving an INcrease 0.0008. accuracy result: 0.9714
Leaving concept 11 out, giving an DEcrease 0.0005. accuracy result: 0.9701
Leaving concept 12 out, giving an DEcrease 0.0001. accuracy result: 0.9705
Leaving concept 13 out, giving an D

ValueError: too many values to unpack (expected 2)