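## Setup and imports

Mount Google Drive (used later to save the model, training history and SHAP values), then install the shap library and import the dependencies: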



In [None]:
# Mount Google Drive at /content/drive; later cells write results under /content/drive/MyDrive.
from google.colab import drive

drive.mount(mountpoint='/content/drive')

In [None]:
!pip install shap

In [None]:
import random

import matplotlib.pyplot as plt
import numpy as np
import shap
import tensorflow as tf
from tensorflow import keras

# Seed Python's, NumPy's and TensorFlow's RNGs for reproducibility.
# Python's built-in `random` was not seeded before; it is seeded too in case a
# library draws from it internally.
seed = 42
random.seed(seed)
np.random.seed(seed)
tf.random.set_seed(seed)

# Experiment switch used throughout the notebook:
#   True  -> train on randomly permuted labels, save under /content/drive/MyDrive/mnist_rand
#   False -> train on the true labels,          save under /content/drive/MyDrive/mnist
random_ = True

In [None]:
# Load the MNIST train/test splits.
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

# Keep the genuine training labels and a shuffled copy of them; the shuffled
# copy is what the random-label experiment (random_ = True) trains on.
y_train_true = y_train
y_random_labels = np.random.permutation(y_train)

# Add a trailing channel axis for Conv2D and divide by 255 to scale pixels to [0, 1].
x_train = np.expand_dims(x_train, axis=-1).astype("float32") / 255
x_test = np.expand_dims(x_test, axis=-1).astype("float32") / 255

In [None]:
# Small CNN: two conv + max-pool stages, then a dense head with a 10-way softmax.
inputs = keras.Input(shape=(28, 28, 1))
x = keras.layers.Conv2D(32, (3, 3), activation='relu')(inputs)
x = keras.layers.MaxPooling2D((2, 2))(x)
x = keras.layers.Conv2D(64, (3, 3), activation='relu')(x)
x = keras.layers.MaxPooling2D((2, 2))(x)
x = keras.layers.Flatten()(x)
x = keras.layers.Dense(256, activation='relu')(x)
outputs = keras.layers.Dense(10, activation='softmax')(x)
model = keras.Model(inputs=inputs, outputs=outputs, name="CNN")

# Integer class labels with softmax outputs -> sparse categorical cross-entropy.
model.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(),
    optimizer=keras.optimizers.Adam(),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
)

# Random-label run: 100 epochs on shuffled labels; otherwise 50 epochs on true labels.
# Validation always uses the genuine test labels.
history = model.fit(
    x_train,
    y_random_labels if random_ else y_train_true,
    epochs=100 if random_ else 50,
    validation_data=(x_test, y_test),
    batch_size=128,
)

In [None]:
# Display names for the 10 MNIST classes ('0' .. '9')
class_names = [str(digit) for digit in range(10)]

# Save one training example per category, keyed by the label the model is
# actually trained on: shuffled labels in the random-label run, true labels
# otherwise (always using the shuffled labels mislabels images when random_ is False).
train_labels = y_random_labels if random_ else y_train_true
images_dict = dict()
for i, label in enumerate(train_labels):
  if len(images_dict) == 10:
    break
  if label not in images_dict:
    images_dict[label] = x_train[i].reshape((28, 28))


def plot_categories(images):
  """Plot one image per class in a single row, preceded by an empty "reference" panel.

  Args:
    images: dict mapping a class index (0-9) to an image array. A trailing
      singleton channel axis (e.g. shape (28, 28, 1)) is squeezed before plotting.
  """
  # One short row of small square panels; a tall figure would only add whitespace.
  fig, axes = plt.subplots(1, 11, figsize=(16, 2))
  axes = axes.flatten()

  # Hide all frames up front so classes missing from `images` leave a blank slot.
  for ax in axes:
    ax.set_axis_off()

  # First panel: a single fully transparent RGBA pixel acting as a placeholder column.
  ax = axes[0]
  dummy_array = np.array([[[0, 0, 0, 0]]], dtype='uint8')
  ax.set_title("reference")
  ax.imshow(dummy_array, interpolation='nearest')

  # Panel k+1 shows the example for class k.
  for k, v in images.items():
    ax = axes[int(k) + 1]
    ax.imshow(np.squeeze(v), cmap=plt.cm.binary)
    ax.set_title(f"{class_names[k]}")

  plt.tight_layout()
  plt.show()


# Use the function to plot
plot_categories(images_dict)

In [None]:
# Background (reference) distribution for DeepExplainer: 5000 distinct test
# images sampled without replacement. Note: explanation cost grows with this size.
background = x_test[np.random.choice(len(x_test), size=5000, replace=False)]

# DeepExplainer attributes the model's outputs relative to that background.
e = shap.DeepExplainer(model, data=background)

In [None]:
# Keep the first test image encountered for each label, stopping once 10 labels are seen.
x_test_dict = {}
for idx, label in enumerate(y_test):
  if len(x_test_dict) == 10:
    break
  x_test_dict.setdefault(label, x_test[idx])

# Stack the examples into one batch, ordered by ascending class label.
x_test_each_class = np.array([x_test_dict[c] for c in sorted(x_test_dict)])

print(f"x_test_each_class tensor has shape: {x_test_each_class.shape}")

In [None]:
# SHAP attributions for the one-per-class test batch. Later cells index the
# result per output class — the exact return layout depends on the shap version.
shap_values = e.shap_values(X=x_test_each_class)

In [None]:
# Row of the raw test images chosen for each class.
plot_categories(images=x_test_dict)

# Blank line separating the two figures in the output.
print("")

# SHAP attribution maps for the same images (pixel values passed negated).
shap.image_plot(shap_values, np.negative(x_test_each_class))

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import shap
from PIL import Image

# Render the SHAP explanation of each class's test example (x_test_each_class,
# ordered 0-9) to its own PNG, then stack the PNGs vertically into one image.
# Relies on the explainer `e` from an earlier cell.

# Only the random-label run gets the "_rand" tag, so the true-label run does not
# produce files that look like random-label results.
name_suffix = "_rand" if random_ else ""

shap_images = []

for i in range(10):
    sample = x_test_each_class[i].reshape(1, 28, 28, 1)

    # NOTE: this rebinds the notebook-level `shap_values`; after the loop it holds
    # the explanation of the last digit only, which the next cell saves.
    shap_values = e.shap_values(sample)

    # Draw without displaying, save the current figure, then release it.
    shap.image_plot(shap_values, -sample, show=False)
    image_path = f"shap_image_digit{name_suffix}_{i}.png"
    plt.savefig(image_path)
    plt.close()

    # Image.open is lazy and keeps the file open; load the pixels into an array
    # and close the file right away.
    with Image.open(image_path) as img:
        shap_images.append(np.array(img))

# Concatenate the images vertically (all PNGs come from identically configured plots).
combined_image = np.concatenate(shap_images, axis=0)

combined_image_path = f"combined_shap_images{name_suffix}.png"
Image.fromarray(combined_image).save(combined_image_path)

plt.imshow(combined_image)
plt.axis('off')
plt.show()


In [None]:
import os
import numpy as np

# Store each entry of `shap_values` (presumably one array per output class)
# under key 'shap_value_<i>', with singleton axes squeezed out.
# NOTE(review): `shap_values` is whatever the previous cell left behind — confirm
# that is the explanation you intend to save.
shap_values_dict = {f'shap_value_{i}': arr.squeeze() for i, arr in enumerate(shap_values)}

shap_values_path = (
    '/content/drive/MyDrive/mnist_rand/shap_values_rand_mnist.npz'
    if random_
    else '/content/drive/MyDrive/mnist/shap_values_mnist.npz'
)
# np.savez does not create missing folders; make sure the Drive folder exists.
os.makedirs(os.path.dirname(shap_values_path), exist_ok=True)
np.savez(shap_values_path, **shap_values_dict)


In [None]:
# Reference row of test images, then per-class SHAP maps for the class-4 example only
# (x_test_each_class is ordered by class label, so index 4 is the digit 4).
plot_categories(x_test_dict)
shap_values = e.shap_values(x_test_each_class[4:5])
shap.image_plot(shap_values, -x_test_each_class[4:5])

In [None]:
import os
import numpy as np

# `shap_values` here is the explanation of the single digit-4 example from the
# previous cell; store each entry under 'shap_value_<i>' with singleton axes squeezed.
shap_values_dict_single_dress = {f'shap_value_{i}': arr.squeeze() for i, arr in enumerate(shap_values)}

digit4_path = (
    '/content/drive/MyDrive/mnist_rand/shap_values_mnist_2.npz'
    if random_
    else '/content/drive/MyDrive/mnist/shap_values_mnist_2.npz'
)
# np.savez does not create missing folders; make sure the Drive folder exists.
os.makedirs(os.path.dirname(digit4_path), exist_ok=True)
np.savez(digit4_path, **shap_values_dict_single_dress)

In [None]:
# Number of entries in shap_values (presumably one per model output class; depends on the shap return format)
len(shap_values)

In [None]:
# Shape of the first attribution array (expected (1, 28, 28, 1) for the single digit-4 sample)
np.shape(shap_values[0])

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Plot the SHAP map for each output class on ONE shared symmetric colour scale,
# so panels can be compared directly.
# Assumes `shap_values` is a list of 10 arrays shaped (1, 28, 28, 1), as produced
# by the single-sample explanation above — TODO confirm for your shap version.

# Largest absolute attribution across all classes defines the shared scale.
max_intensity = np.max([np.max(np.abs(image.squeeze(axis=0))) for image in shap_values])
min_intensity = -max_intensity

fig, axes = plt.subplots(nrows=2, ncols=5, figsize=(10, 4))

for i, ax in enumerate(axes.flat):
    # Drop the batch and channel axes -> (28, 28)
    image = shap_values[i].squeeze()
    im = ax.imshow(image, cmap='RdBu_r', vmin=min_intensity, vmax=max_intensity)
    ax.set_title(f"class {i}")
    ax.axis('off')

fig.suptitle("SHAP values per output class (shared scale)")
# A single colorbar is valid because every panel uses the same vmin/vmax.
cbar = fig.colorbar(im, ax=axes, orientation='vertical', fraction=0.03, pad=0.1)

plt.show()


In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Plot the SHAP map for each output class, each on its OWN symmetric colour
# scale (highlights within-class structure; panels are not directly comparable).
# Assumes `shap_values` is a list of 10 arrays shaped (1, 28, 28, 1) — TODO confirm.

fig, axes = plt.subplots(nrows=2, ncols=5, figsize=(12, 4))

for i, ax in enumerate(axes.flat):
    # Drop the batch and channel axes -> (28, 28)
    image = shap_values[i].squeeze()

    # Symmetric range around zero so white always means "no contribution".
    max_abs_intensity = np.max(np.abs(image))
    im = ax.imshow(image, cmap='RdBu_r', vmin=-max_abs_intensity, vmax=max_abs_intensity)
    ax.set_title(f"class {i}")
    ax.axis('off')

    # Every panel has its own scale, so each needs its own colorbar to be readable.
    fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

fig.suptitle("SHAP values per output class (per-panel scale)")
plt.show()


In [None]:
import os
import pickle

# Persist the trained model and its per-epoch training metrics to Drive, in a
# folder that separates the random-label run from the true-label run.
output_dir = "/content/drive/MyDrive/mnist_rand" if random_ else "/content/drive/MyDrive/mnist"
# Saving into a missing Drive folder fails; create it if needed.
os.makedirs(output_dir, exist_ok=True)

# The `.keras` extension selects Keras' native archive format (not HDF5).
model.save(os.path.join(output_dir, "your_model.keras"))

# history.history is the dict of metric lists recorded by model.fit.
# Pickle files execute code when loaded — only load this from a trusted location.
with open(os.path.join(output_dir, "training_history.pkl"), "wb") as file:
    pickle.dump(history.history, file)


In [None]:
# Print the shape of every attribution array in shap_values, one per line.
for explanation in shap_values:
    print(explanation.shape)

In [None]:
# Select a specific digit to study, and how many of its test images to explain.
digit_to_analyze = 4
iteration_ = 500

# All test images labelled `digit_to_analyze`, in dataset order, capped at `iteration_`.
# Built directly into a new array: this does not overwrite `x_test_dict`, which the
# plot_categories cells above still rely on.
x_test_digit = x_test[y_test == digit_to_analyze][:iteration_]
assert len(x_test_digit) > 0, f"no test images found for digit {digit_to_analyze}"

print(f"x_test_digit tensor has shape: {x_test_digit.shape}")

In [None]:
# Calculate SHAP values for every collected sample of the selected digit and keep
# only the attribution map for that digit's own output class.
# (Iterates over x_test_digit — the 10-image x_test_each_class has too few entries.)
shap_arrays = []
for sample in x_test_digit[:iteration_]:
  # Add a batch axis: (28, 28, 1) -> (1, 28, 28, 1).
  # check_additivity=False skips shap's check that attributions sum to the model output.
  shap_values = e.shap_values(sample[np.newaxis], check_additivity=False)
  shap_arrays.append(shap_values[digit_to_analyze])

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Average the per-sample SHAP maps (for the analyzed digit's output class) into one map.
mean_shap_array = np.asarray(shap_arrays).mean(axis=0)
print(mean_shap_array.shape)
# Drop singleton axes so imshow receives a 2-D map.
mean_shap_array = np.squeeze(mean_shap_array)

# Diverging colormap on a range symmetric around zero.
plt.imshow(
    mean_shap_array,
    cmap='RdBu_r',
    vmin=-np.abs(mean_shap_array).max(),
    vmax=np.abs(mean_shap_array).max(),
)
plt.title('Mean of Shap Arrays')
plt.colorbar()  # colour <-> attribution value
plt.axis('off')
plt.show()


In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Pixel-wise mean of the collected test images of `digit_to_analyze`.
# This is a plain image average, not SHAP values; it gets its own name so the
# mean SHAP map from the previous cell (`mean_shap_array`) is not overwritten.
mean_digit_image = np.mean(x_test_digit[:iteration_], axis=0)
print(mean_digit_image.shape)
# Drop singleton axes -> 2-D image for imshow.
mean_digit_image = mean_digit_image.squeeze()

plt.imshow(mean_digit_image)
plt.title(f'Mean - digit {digit_to_analyze}')
plt.colorbar()  # colour <-> mean pixel intensity
plt.axis('off')
plt.show()
