# Counterfactual Explanations
## Wachter's Method
***

* Example taken from the [alibi](https://docs.seldon.io/projects/alibi/en/stable/examples/cf_mnist.html) documentation
* Counterfactual instances (new images) are generated to explain classifications on the MNIST dataset, using the method presented in [Wachter et al, 2017]



[Wachter et al, 2017] Wachter, S., Mittelstadt, B., and Russell, C. (2017). Counterfactual explanations without opening the black box: Automated decisions and the GDPR. *Harv. JL & Tech*.
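
As a rough illustration of the idea behind the method: the counterfactual search minimizes a loss that trades off reaching the target prediction against staying close to the original instance. The sketch below assumes the form $(f_t(x') - p_t)^2 + \lambda \, L_1(x', x)$ with a plain L1 distance. The names `toy_predict` and `wachter_loss`, and the toy weights, are hypothetical and exist only for this example; they are not part of alibi.

```python
import numpy as np

def toy_predict(x):
    """Hypothetical 2-class classifier: softmax over a fixed linear map."""
    weights = np.array([[1.0, -1.0], [-1.0, 1.0]])
    logits = x @ weights
    exp = np.exp(logits - logits.max(axis=1, keepdims=True))
    return exp / exp.sum(axis=1, keepdims=True)

def wachter_loss(predict_fn, x_cf, x_orig, target_class, target_proba=1.0, lam=0.1):
    """Squared gap to the target probability plus lambda-weighted L1 distance."""
    proba = predict_fn(x_cf)[0, target_class]
    prediction_term = (proba - target_proba) ** 2
    distance_term = np.abs(x_cf - x_orig).sum()
    return prediction_term + lam * distance_term

x_orig = np.array([[1.0, 0.0]])  # classified mostly as class 0 by toy_predict
x_cf = np.array([[0.0, 1.0]])    # candidate counterfactual, mostly class 1

loss_orig = wachter_loss(toy_predict, x_orig, x_orig, target_class=1)
loss_cf = wachter_loss(toy_predict, x_cf, x_orig, target_class=1)
print(loss_orig, loss_cf)
```

In this toy setting the candidate has a lower loss than the unchanged instance, because the gain in target-class probability outweighs the distance penalty at `lam=0.1`. A larger `lam` would favor staying closer to the original.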

## Libraries

In [None]:
import tensorflow as tf
tf.get_logger().setLevel(40) # suppress deprecation messages
tf.compat.v1.disable_v2_behavior() # disable TF2 behaviour as alibi code still relies on TF1 constructs
from tensorflow.keras.layers import Conv2D, Dense, Dropout, Flatten, MaxPooling2D, Input
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.utils import to_categorical
import matplotlib
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import os
from time import time

In [None]:
print('TF version: ', tf.__version__)
print('Eager execution enabled: ', tf.executing_eagerly()) # False

## The alibi Python library
***
* [alibi](https://docs.seldon.io/projects/alibi/en/stable/index.html) is a library that implements several methods for explaining machine learning models
* This example uses the [Counterfactual](https://docs.seldon.io/projects/alibi/en/stable/methods/CF.html) method

In [None]:
from alibi.explainers import Counterfactual

In [None]:
#pip install alibi
#!pip uninstall typing_extensions --yes

In [None]:
#pip install typing_extensions==4.7.1

## MNIST dataset

In [None]:
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
print('x_train shape:', x_train.shape, 'y_train shape:', y_train.shape)

plt.figure()
plt.imshow(x_test[1], cmap="gray")
plt.show()

## Image preprocessing

In [None]:
x_train = x_train.astype('float32') / 255
x_test = x_test.astype('float32') / 255
x_train = np.reshape(x_train, x_train.shape + (1,))
x_test = np.reshape(x_test, x_test.shape + (1,))
print('x_train shape:', x_train.shape, 'x_test shape:', x_test.shape)
y_train = to_categorical(y_train)
y_test = to_categorical(y_test)
print('y_train shape:', y_train.shape, 'y_test shape:', y_test.shape)
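
The `to_categorical` call above turns integer class labels into one-hot rows, which is the label format `categorical_crossentropy` expects. A minimal NumPy sketch of that encoding follows; `one_hot` is a hypothetical helper written for illustration, not the Keras function itself.

```python
import numpy as np

def one_hot(labels, num_classes):
    """Return a (len(labels), num_classes) matrix with a single 1 per row."""
    labels = np.asarray(labels, dtype=int)
    encoded = np.zeros((labels.size, num_classes), dtype="float32")
    encoded[np.arange(labels.size), labels] = 1.0
    return encoded

toy_labels = [7, 2, 1]  # made-up digit labels
encoded = one_hot(toy_labels, num_classes=10)
print(encoded.shape)
print(encoded.argmax(axis=1))  # argmax recovers the integer labels
```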

## Image scaling

In [None]:
xmin, xmax = -.5, .5
x_train = ((x_train - x_train.min()) / (x_train.max() - x_train.min())) * (xmax - xmin) + xmin
x_test = ((x_test - x_test.min()) / (x_test.max() - x_test.min())) * (xmax - xmin) + xmin
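
The two assignments above apply the same min-max mapping to `[-0.5, 0.5]`. A small sketch of that mapping as a reusable function, checked on a toy array; `rescale` is a hypothetical helper name and it assumes the input is not constant (otherwise the span would be zero).

```python
import numpy as np

def rescale(x, new_min=-0.5, new_max=0.5):
    """Linearly map x from [x.min(), x.max()] to [new_min, new_max]."""
    x = np.asarray(x, dtype="float32")
    span = x.max() - x.min()  # assumed to be non-zero
    return (x - x.min()) / span * (new_max - new_min) + new_min

toy_image = np.array([[0.0, 0.25], [0.5, 1.0]])  # toy "pixels" already in [0, 1]
scaled = rescale(toy_image)
print(scaled.min(), scaled.max())
```

Because the mapping is linear, relative pixel intensities are preserved; only the range is shifted so that it is centered on zero.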

## ML model
***
* Convolutional network for image classification

In [None]:
def cnn_model():
    x_in = Input(shape=(28, 28, 1))
    x = Conv2D(filters=64, kernel_size=2, padding='same', activation='relu')(x_in)
    x = MaxPooling2D(pool_size=2)(x)
    x = Dropout(0.3)(x)

    x = Conv2D(filters=32, kernel_size=2, padding='same', activation='relu')(x)
    x = MaxPooling2D(pool_size=2)(x)
    x = Dropout(0.3)(x)

    x = Flatten()(x)
    x = Dense(256, activation='relu')(x)
    x = Dropout(0.5)(x)
    x_out = Dense(10, activation='softmax')(x)

    cnn = Model(inputs=x_in, outputs=x_out)
    cnn.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

    return cnn
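
To make the architecture easier to follow, the sketch below walks through the tensor sizes and counts the trainable parameters by hand, which can be compared against the output of `cnn.summary()`. The helper functions and the layer labels in the dictionary are hypothetical names for illustration; Keras may name the layers differently.

```python
def conv2d_params(kernel_size, in_channels, filters):
    """Weights plus one bias per filter for a square-kernel Conv2D layer."""
    return kernel_size * kernel_size * in_channels * filters + filters

def dense_params(in_units, out_units):
    """Weights plus one bias per output unit for a Dense layer."""
    return in_units * out_units + out_units

# padding='same' keeps the 28x28 size; each MaxPooling2D(pool_size=2) halves it.
side_after_pooling = 28 // 2 // 2               # 28 -> 14 -> 7
flattened_units = side_after_pooling ** 2 * 32  # 7 * 7 * 32 channels

layer_params = {
    "conv_1": conv2d_params(2, 1, 64),
    "conv_2": conv2d_params(2, 64, 32),
    "dense_1": dense_params(flattened_units, 256),
    "dense_2": dense_params(256, 10),
}
total_params = sum(layer_params.values())
print(layer_params)
print("total:", total_params)
```

Pooling and dropout layers add no parameters, so almost all of the weights sit in the first dense layer.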

## Training

In [None]:
cnn = cnn_model()
cnn.summary()
cnn.fit(x_train, y_train, batch_size=64, epochs=3, verbose=0)
cnn.save('mnist_cnn.h5')

## Load and evaluate the CNN model

In [None]:
cnn = load_model('mnist_cnn.h5')
score = cnn.evaluate(x_test, y_test, verbose=0)
print('Test accuracy: ', score[1])
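
Since the model was compiled with `metrics=['accuracy']`, `score[0]` is the loss and `score[1]` is the accuracy. With one-hot labels, that accuracy amounts to comparing the argmax of the predicted probabilities with the argmax of the labels. A minimal sketch on made-up toy values; `accuracy_from_probabilities` is a hypothetical helper.

```python
import numpy as np

def accuracy_from_probabilities(probabilities, one_hot_labels):
    """Fraction of rows whose most probable class matches the one-hot label."""
    predicted = np.argmax(probabilities, axis=1)
    actual = np.argmax(one_hot_labels, axis=1)
    return float(np.mean(predicted == actual))

# Toy values for illustration only, not model output.
toy_probabilities = np.array([[0.1, 0.9], [0.8, 0.2], [0.3, 0.7], [0.6, 0.4]])
toy_labels = np.array([[0, 1], [1, 0], [1, 0], [1, 0]])
toy_accuracy = accuracy_from_probabilities(toy_probabilities, toy_labels)
print(toy_accuracy)
```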

## Image from the test set

In [None]:
X = x_test[0].reshape((1,) + x_test[0].shape)
plt.imshow(X.reshape(28, 28), cmap="gray");

## Parameters for generating the counterfactual instance

In [None]:
shape = (1,) + x_train.shape[1:] # shape of the instance to explain, with a leading batch dimension
target_proba = 1.0 # desired probability for the new prediction
tol = 0.01 # tolerance: with target_proba = 1.0, accept counterfactuals with p(class) > 0.99
target_class = 'other' # any class other than the original prediction (7 here)
max_iter = 1000
lam_init = 1e-1 # initial value of lambda
max_lam_steps = 10 # number of steps used to search for a different value of lambda
learning_rate_init = 0.1
feature_range = (x_train.min(),x_train.max()) # minimum and maximum feature values
                                              # allowed in the perturbed instance

In [None]:
shape
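
With `target_proba = 1.0` and `tol = 0.01`, a candidate is accepted once its target-class probability is above roughly 0.99. The sketch below assumes the acceptance rule is "the probability lies within `tol` of `target_proba`"; `meets_target` is a hypothetical helper written for illustration, not a function from alibi.

```python
def meets_target(class_proba, target_proba=1.0, tol=0.01):
    """True when the class probability is within tol of the desired probability."""
    return abs(class_proba - target_proba) <= tol

print(meets_target(0.995))  # close enough to 1.0
print(meets_target(0.95))   # too far from 1.0
```

A looser `tol` makes counterfactuals easier to find, at the cost of accepting less confident predictions for the new class.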

## Counterfactual explanation
***


In [None]:
# initialize explainer
cf = Counterfactual(cnn, shape=shape, target_proba=target_proba, tol=tol,
                    target_class=target_class, max_iter=max_iter, lam_init=lam_init,
                    max_lam_steps=max_lam_steps, learning_rate_init=learning_rate_init,
                    feature_range=feature_range)

start_time = time()
explanation = cf.explain(X) # instance whose prediction is being explained
print('Elapsed time: {:.3f} sec'.format(time() - start_time))


The `explain()` method returns an `Explanation` object with the following attributes:

* cf: dictionary containing the counterfactual instance found with the smallest distance to the test instance. It has the following keys:
    * X: the counterfactual instance
    * distance: distance to the original instance
    * lambda: value of $\lambda$ corresponding to the counterfactual
    * index: the step in the search procedure when the counterfactual was found
    * class: predicted class of the counterfactual
    * proba: predicted class probabilities of the counterfactual
    * loss: counterfactual loss
* orig_class: predicted class of the original instance
* orig_proba: predicted class probabilities of the original instance
* all: dictionary of all instances encountered during the search that satisfy the counterfactual constraint but have a higher distance to the original instance than the returned counterfactual. This is organized by levels of $\lambda$, i.e. `explanation['all'][0]` will be a list of dictionaries corresponding to instances satisfying the counterfactual condition found in the first iteration over $\lambda$ during bisection.
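
The keys listed above can be read like those of any dictionary. The sketch below formats a one-line summary from a mock dictionary shaped like `cf`; the values are made up for illustration and are not results from this notebook, and `summarize_counterfactual` is a hypothetical helper.

```python
def summarize_counterfactual(cf, orig_class):
    """Build a short text summary from the documented keys of a cf dictionary."""
    return (
        f"class {orig_class} -> {cf['class']} "
        f"(distance={cf['distance']:.3f}, lambda={cf['lambda']:.3g}, step={cf['index']})"
    )

# Made-up values that only mimic the documented structure.
mock_cf = {"class": 9, "distance": 12.5, "lambda": 0.1, "index": 42}
summary = summarize_counterfactual(mock_cf, orig_class=7)
print(summary)
```

With a real explanation, the same function could presumably be called as `summarize_counterfactual(explanation.cf, explanation.orig_class)`.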

In [None]:
pred_class = explanation.cf['class']
proba = explanation.cf['proba'][0][pred_class]

print(f'Counterfactual prediction: {pred_class} with probability {proba}')

plt.figure()
plt.imshow(explanation.cf['X'].reshape(28, 28), cmap="gray")
plt.show()

* Starting from the image of a 7, the counterfactual instance moves toward the nearest class as determined by the model and the data, in this case a 9.

* The evolution of the counterfactual over the iterations is shown below.

In [None]:
# Collect the first counterfactual found at each lambda step (steps with none are skipped)
n_cfs = np.array([len(explanation.all[iter_cf]) for iter_cf in range(max_lam_steps)])
examples = {}
for ix, n in enumerate(n_cfs):
    if n>0:
        examples[ix] = {'ix': ix, 'lambda': explanation.all[ix][0]['lambda'],
                       'X': explanation.all[ix][0]['X']}
        print(ix, "lambda:", explanation.all[ix][0]['lambda'])
columns = len(examples) + 1
rows = 1

fig = plt.figure(figsize=(16,6))

for i, key in enumerate(examples.keys()):
    
    ax = plt.subplot(rows, columns, i+1)
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
    plt.imshow(examples[key]['X'].reshape(28,28))
    plt.title(f'Iteration: {key}')

In [None]:
# lambda of the counterfactual stored for the first lambda step
# (assumes key 0 exists in `examples`, i.e. a counterfactual was found at that step)
examples[0]["lambda"]

### Searching for a specific target

In [None]:
target_class = 1 # search for a counterfactual instance that is predicted as class 1

cf = Counterfactual(cnn, shape=shape, target_proba=target_proba, tol=tol,
                    target_class=target_class, max_iter=max_iter, lam_init=lam_init,
                    max_lam_steps=max_lam_steps, learning_rate_init=learning_rate_init,
                    feature_range=feature_range)

start_time = time()
explanation = cf.explain(X)
print('Elapsed time: {:.3f} sec'.format(time() - start_time))

### Counterfactual instance found

In [None]:
pred_class = explanation.cf['class']
proba = explanation.cf['proba'][0][pred_class]
print(f'Counterfactual prediction: {pred_class} with probability {proba}')

plt.figure()
plt.imshow(explanation.cf['X'].reshape(28, 28), cmap="gray")
plt.show()

* Now that a specific target class is given, the search can no longer move toward the class closest to the instance being explained (a 9, as seen earlier), so the counterfactual may be less interpretable.
* The difference between the counterfactual and the original image:

In [None]:
plt.figure()
plt.imshow((explanation.cf['X'] - X).reshape(28, 28), cmap="gray");
plt.show()
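
Besides plotting the difference image, it can help to put numbers on the perturbation. The sketch below computes a few simple statistics on toy arrays with the same 4-D layout as `X`; `perturbation_stats` and the `threshold` value are hypothetical choices made for illustration.

```python
import numpy as np

def perturbation_stats(x_cf, x_orig, threshold=1e-3):
    """Summarize how far a counterfactual is from the original instance."""
    delta = np.asarray(x_cf) - np.asarray(x_orig)
    return {
        "l1_distance": float(np.abs(delta).sum()),
        "max_change": float(np.abs(delta).max()),
        "changed_pixels": int((np.abs(delta) > threshold).sum()),
    }

# Toy 4x4 "image" at the minimum of the scaled range, with two pixels changed.
toy_orig = np.zeros((1, 4, 4, 1)) - 0.5
toy_cf = toy_orig.copy()
toy_cf[0, 1, 1, 0] += 0.5
toy_cf[0, 2, 2, 0] += 0.25
stats = perturbation_stats(toy_cf, toy_orig)
print(stats)
```

With the notebook's variables, the call would be `perturbation_stats(explanation.cf['X'], X)`.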
