In [None]:
import os
import numpy as np
import json

import tensorflow as tf
from tensorflow.keras.preprocessing.image import load_img, img_to_array, array_to_img
from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input, decode_predictions
#from tensorflow.keras.applications.mobilenet_v2 import MobileNetV2, preprocess_input, decode_predictions
from tensorflow.keras.datasets import mnist

from PIL import Image as PilImage

from omnixai.data.image import Image
from omnixai.preprocessing.image import Resize
from omnixai.explainers.vision import IntegratedGradientImage
from omnixai.explainers.vision import GradCAM
from omnixai.explainers.vision import ShapImage
from omnixai.explainers.vision import LimeImage

## ResNet50 explanations

In [None]:
def load_image(path, size=(224, 224)):
    """Load an image file as an RGB omnixai ``Image`` resized to ``size``.

    Args:
        path: path to the image file on disk.
        size: target ``(height, width)`` passed to omnixai ``Resize``;
            defaults to the 224x224 input used by ResNet50 in this notebook.

    Returns:
        An omnixai ``Image`` holding the resized RGB picture.
    """
    # Open via a context manager so the file handle is released deterministically;
    # ``convert`` returns a new in-memory image, so it stays valid after close.
    with PilImage.open(path) as pil_img:
        rgb_img = pil_img.convert('RGB')
    return Resize(size).transform(Image(rgb_img))

def preprocess(images):
    """Convert a batch of omnixai images into a ResNet50-ready numpy array.

    Each image is converted to a numpy array with ``img_to_array``. The arrays
    are stacked along a new leading batch axis and passed through the ResNet50
    ``preprocess_input``.

    Args:
        images: an iterable of omnixai ``Image`` objects (e.g. an ``Image`` batch).

    Returns:
        The array returned by ``preprocess_input``.
    """
    batch = np.stack([img_to_array(img.to_pil()) for img in images], axis=0)
    # The result of preprocess_input must be used: Keras only partially modifies
    # its input in place, and the RGB->BGR channel reordering lives in the
    # returned array, not in ``batch``.
    return preprocess_input(batch)

img_dir = os.path.join(os.getcwd(), 'images')

# Keys of the class-index JSON are stringified integer class ids; element [1]
# of each entry is used as the human-readable label.
with open('imagenet_class_index.json', 'r') as index_file:
    class_idx = json.load(index_file)
idx2label = [class_idx[str(class_id)][1] for class_id in range(len(class_idx))]

In [None]:
resnet = ResNet50(weights='imagenet')

# NOTE(review): layers[-5] is counted from the output end of the model;
# check resnet.summary() to see exactly which layer Grad-CAM is reading.
gradcam = GradCAM(
    model=resnet,
    target_layer=resnet.layers[-5],
    preprocess_function=preprocess,
)
integrated_gradients = IntegratedGradientImage(
    model=resnet,
    preprocess_function=preprocess,
)

def predict_and_explain(img_path, explainer):
    """Print ResNet50's top-2 decoded predictions for an image, then plot an explanation.

    Args:
        img_path: image file name, resolved relative to ``img_dir``.
        explainer: an omnixai vision explainer whose ``explain`` result is plotted.
    """
    image = load_image(os.path.join(img_dir, img_path))
    scores = resnet.predict(preprocess(image))
    print(decode_predictions(scores, top=2))
    explanation = explainer.explain(image)
    explanation.ipython_plot(index=0, class_names=idx2label)

### Grad-CAM

In [None]:
predict_and_explain(img_path='n02088364_beagle.JPEG', explainer=gradcam)

In [None]:
predict_and_explain(img_path='n01531178_goldfinch.JPEG', explainer=gradcam)

In [None]:
predict_and_explain(img_path='n02119789_kit_fox.JPEG', explainer=gradcam)

In [None]:
dog_cat = load_image(os.path.join(img_dir, 'dog_cat.png'))
dog_cat_scores = resnet.predict(preprocess(dog_cat))
print(decode_predictions(dog_cat_scores, top=2))

# Explicit ImageNet class ids to explain: 243 - bull mastiff, 281 - tabby, tabby cat
BULL_MASTIFF, TABBY_CAT = 243, 281
gradcam.explain(dog_cat).ipython_plot(index=0, class_names=idx2label)
gradcam.explain(dog_cat, y=[BULL_MASTIFF]).ipython_plot(index=0, class_names=idx2label)
gradcam.explain(dog_cat, y=[TABBY_CAT]).ipython_plot(index=0, class_names=idx2label)

In [None]:
predict_and_explain(img_path='n01806143_peacock.JPEG', explainer=gradcam)

In [None]:
predict_and_explain(img_path='n01695060_Komodo_dragon.JPEG', explainer=gradcam)

In [None]:
# Same setup as `gradcam`, but targeting layers[-3] instead of layers[-5] for comparison.
gradcam2 = GradCAM(
    model=resnet,
    target_layer=resnet.layers[-3],
    preprocess_function=preprocess,
)
predict_and_explain(img_path='n01695060_Komodo_dragon.JPEG', explainer=gradcam2)

In [None]:
predict_and_explain(img_path='american_egret.jpg', explainer=gradcam)

In [None]:
predict_and_explain(img_path='n02087046_toy_terrier.JPEG', explainer=gradcam)

### Integrated Gradients

In [None]:
predict_and_explain(img_path='n02088364_beagle.JPEG', explainer=integrated_gradients)

In [None]:
predict_and_explain(img_path='dog_cat.png', explainer=integrated_gradients)

In [None]:
predict_and_explain(img_path='n01695060_Komodo_dragon.JPEG', explainer=integrated_gradients)

In [None]:
predict_and_explain(img_path='american_egret.jpg', explainer=integrated_gradients)

### LIME

In [None]:
def _resnet_predict(images):
    """Prediction function handed to LIME: preprocess the images, then score them with ResNet50."""
    return resnet.predict(preprocess(images))


lime = LimeImage(predict_function=_resnet_predict)

In [None]:
lime_explanation = lime.explain(
    load_image(os.path.join(img_dir, 'n02088364_beagle.JPEG')),
    hide_color=0,
    num_samples=1000,
)
lime_explanation.ipython_plot(index=0, class_names=idx2label)

In [None]:
lime_explanation = lime.explain(
    load_image(os.path.join(img_dir, 'dog_cat.png')),
    hide_color=0,
    num_samples=1000,
)
lime_explanation.ipython_plot(index=0, class_names=idx2label)

In [None]:
lime_explanation = lime.explain(
    load_image(os.path.join(img_dir, 'n01806143_peacock.JPEG')),
    hide_color=0,
    num_samples=1000,
)
lime_explanation.ipython_plot(index=0, class_names=idx2label)

In [None]:
lime_explanation = lime.explain(
    load_image(os.path.join(img_dir, 'n01695060_Komodo_dragon.JPEG')),
    hide_color=0,
    num_samples=1000,
)
lime_explanation.ipython_plot(index=0, class_names=idx2label)

In [None]:
lime_explanation = lime.explain(
    load_image(os.path.join(img_dir, 'american_egret.jpg')),
    hide_color=0,
    num_samples=1000,
)
lime_explanation.ipython_plot(index=0, class_names=idx2label)

In [None]:
lime_explanation = lime.explain(
    load_image(os.path.join(img_dir, 'n02087046_toy_terrier.JPEG')),
    hide_color=0,
    num_samples=1000,
)
lime_explanation.ipython_plot(index=0, class_names=idx2label)

## MNIST explanations with SHAP

In [None]:
import shap

mnist_model = tf.keras.models.load_model('Simple_MNIST_convnet')
(x_train, _), (x_test, _) = mnist.load_data()

img_rows, img_cols = 28, 28
input_shape = (img_rows, img_cols, 1)

# Add a trailing single-channel axis, cast to float32 and divide by 255.
x_train = x_train.reshape(x_train.shape[0], *input_shape).astype('float32') / 255
x_test = x_test.reshape(x_test.shape[0], *input_shape).astype('float32') / 255

In [None]:
# The background sample drives SHAP's DeepExplainer baseline. Draw it from a
# seeded generator so the attributions are reproducible across Restart & Run All.
BACKGROUND_SEED = 0
BACKGROUND_SIZE = 100
background_rng = np.random.default_rng(BACKGROUND_SEED)
background = x_train[background_rng.choice(x_train.shape[0], BACKGROUND_SIZE, replace=False)]
shap1 = shap.DeepExplainer(mnist_model, background)

In [None]:
# SHAP values for the first ten test digits; the negated inputs are passed as the pixel values to draw.
first_ten = x_test[0:10]
shap.image_plot(shap1.shap_values(first_ten), -first_ten)

In [None]:
# Same plot for just the first five test digits.
first_five = x_test[0:5]
shap.image_plot(shap1.shap_values(first_five), -first_five)

In [None]:
# Reload the raw test digits under new names instead of overwriting `x_test`.
# The DeepExplainer cells above need the normalized (N, 28, 28, 1) array, and
# re-running them after this cell would otherwise fail on an omnixai Image.
_, (mnist_test_raw, _) = mnist.load_data()
mnist_test_images = Image(mnist_test_raw.astype('float32'), batched=True)

# Divide by 255 and append a channel axis, mirroring the preprocessing applied
# to x_train/x_test above.
shap2 = ShapImage(
    model=mnist_model,
    preprocess_function=lambda x: np.expand_dims(x.to_numpy() / 255, axis=-1),
)
explanations = shap2.explain(mnist_test_images[0:10])

In [None]:
explanations.ipython_plot(index=0)  # first explained test digit

In [None]:
explanations.ipython_plot(index=3)  # fourth explained test digit

In [None]:
explanations.ipython_plot(index=4)  # fifth explained test digit

In [None]:
explanations.ipython_plot(index=9)  # tenth (last) explained test digit

In [None]:
explanations.ipython_plot(index=8)  # ninth explained test digit