# Analysis of Model flowerclass-efficientnetv2-2 2: with XAI SHAP method

### Goals

* Apply SHAP to explain decisions leading to model errors in `flowerclass-efficientnetv2-2-analysis2-imgvis` notebook
* Focus on the local explanation of a single image; hence use Kernel SHAP, as we do not need to calculate explanations for many instances to obtain a global interpretation

Note:

Adapted from https://shap-lrjball.readthedocs.io/en/latest/example_notebooks/kernel_explainer/ImageNet%20VGG16%20Model%20with%20Keras.html

In [None]:
import math, re, os
import numpy as np
import tensorflow as tf

# Turn on memory growth for every physical GPU that TensorFlow can see.
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    try:
        # Currently, memory growth needs to be the same across GPUs
        for device in gpus:
            tf.config.experimental.set_memory_growth(device, True)
        logical_gpus = tf.config.experimental.list_logical_devices('GPU')
        print(f"{len(gpus)} Physical GPUs, {len(logical_gpus)} Logical GPUs")
    except RuntimeError as e:
        # Memory growth must be set before GPUs have been initialized
        print(e)


import tensorflow_hub as hub

from flowerclass_read_tf_ds import get_validation_dataset

from tqdm import tqdm
import matplotlib.pyplot as plt

In [None]:
tf.test.gpu_device_name()

# I. Data Loading

In [None]:
image_size = 224
batch_size = 64

In [None]:
# Human-readable flower class names; the position in the list is the class index
# (the trailing comments give the index range covered by each row).
# NOTE(review): 'cyclamen ' and 'hippeastrum ' carry a trailing space — confirm this is intentional.
class_names = ['pink primrose',    'hard-leaved pocket orchid', 'canterbury bells', 'sweet pea',     'wild geranium',     'tiger lily',           'moon orchid',              'bird of paradise', 'monkshood',        'globe thistle',         # 00 - 09
           'snapdragon',       "colt's foot",               'king protea',      'spear thistle', 'yellow iris',       'globe-flower',         'purple coneflower',        'peruvian lily',    'balloon flower',   'giant white arum lily', # 10 - 19
           'fire lily',        'pincushion flower',         'fritillary',       'red ginger',    'grape hyacinth',    'corn poppy',           'prince of wales feathers', 'stemless gentian', 'artichoke',        'sweet william',         # 20 - 29
           'carnation',        'garden phlox',              'love in the mist', 'cosmos',        'alpine sea holly',  'ruby-lipped cattleya', 'cape flower',              'great masterwort', 'siam tulip',       'lenten rose',           # 30 - 39
           'barberton daisy',  'daffodil',                  'sword lily',       'poinsettia',    'bolero deep blue',  'wallflower',           'marigold',                 'buttercup',        'daisy',            'common dandelion',      # 40 - 49
           'petunia',          'wild pansy',                'primula',          'sunflower',     'lilac hibiscus',    'bishop of llandaff',   'gaura',                    'geranium',         'orange dahlia',    'pink-yellow dahlia',    # 50 - 59
           'cautleya spicata', 'japanese anemone',          'black-eyed susan', 'silverbush',    'californian poppy', 'osteospermum',         'spring crocus',            'iris',             'windflower',       'tree poppy',            # 60 - 69
           'gazania',          'azalea',                    'water lily',       'rose',          'thorn apple',       'morning glory',        'passion flower',           'lotus',            'toad lily',        'anthurium',             # 70 - 79
           'frangipani',       'clematis',                  'hibiscus',         'columbine',     'desert-rose',       'tree mallow',          'magnolia',                 'cyclamen ',        'watercress',       'canna lily',            # 80 - 89
           'hippeastrum ',     'bee balm',                  'pink quill',       'foxglove',      'bougainvillea',     'camellia',             'mallow',                   'mexican petunia',  'bromelia',         'blanket flower',        # 90 - 99
           'trumpet creeper',  'blackberry lily',           'common tulip',     'wild rose']                                                                                                                                               # 100 - 103
# number of classes (displayed as the cell output)
len(class_names)

In [None]:
# Lookup table: class index -> class name
class_names_mapping = dict(enumerate(class_names))

# II. Model Loading and Predictions: EfficientNetV2

In [None]:
effnet2_base = "https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet21k_s/feature_vector/2"

In [None]:
    effnet2_tfhub = tf.keras.Sequential([
    # Explicitly define the input shape so the model can be properly
    # loaded by the TFLiteConverter
    tf.keras.layers.InputLayer(input_shape=(image_size, image_size,3)),
    hub.KerasLayer(effnet2_base, trainable=False),
    tf.keras.layers.Dropout(rate=0.2),
    tf.keras.layers.Dense(104, activation='softmax')
])
effnet2_tfhub.build((None, image_size, image_size,3,)) #This is to be used for subclassed models, which do not know at instantiation time what their inputs look like.


effnet2_tfhub.summary()

In [None]:
# Load the checkpoint of phase `best_phase`; the phase number is zero-padded
# to four digits in the checkpoint file name (e.g. cp-0012.ckpt).
best_phase = 12
effnet2_tfhub.load_weights(f"../input/flowerclass-efficientnetv2-2/training/cp-{best_phase:04d}.ckpt")

# II. Explaining model decisions

In [None]:
import shap
from skimage.segmentation import slic
# make a color map
from matplotlib.colors import LinearSegmentedColormap

In [None]:
data_path = "../input/tpu-getting-started/tfrecords-jpeg-224x224"
VALIDATION_FILENAMES = tf.io.gfile.glob(data_path + '/val/*.tfrec')

In [None]:
def get_images_by_ids(image_ids_search):
    """Collect the validation images whose id is listed in ``image_ids_search``.

    The whole validation dataset built from ``VALIDATION_FILENAMES`` is scanned.

    Args:
        image_ids_search: container of image ids to look for; every id yielded
            by the dataset is tested against it with ``in``.

    Returns:
        ``((images, labels), image_ids_found)`` where ``images`` stacks the
        matched images along a new first axis, ``labels`` holds the argmax of
        each matched label cast to int64, and ``image_ids_found`` lists the
        matched ids in dataset order.

    Raises:
        ValueError: if none of the requested ids occurs in the dataset.
    """
    # NOTE(review): the meaning of the positional arguments is defined by
    # get_validation_dataset (project module) — confirm there.
    ds_valid = get_validation_dataset(VALIDATION_FILENAMES, 1, (image_size, image_size), None, True)

    imgs_found = []
    image_ids_found = []
    labels_found = []
    for imgs, labels, imgs_id in tqdm(ds_valid):
        for img, img_id, label in zip(imgs, imgs_id, labels):
            if img_id in image_ids_search:
                image_ids_found.append(img_id)
                imgs_found.append(img)
                labels_found.append(tf.argmax(label))

    if not imgs_found:
        # Fail with a clear message instead of handing empty lists to
        # tf.stack / tf.concat below.
        raise ValueError(f"None of the requested image ids were found: {image_ids_search!r}")

    # NOTE(review): if each label is 1-D, tf.argmax yields scalars and
    # concatenating several of them may fail — verify with more than one
    # matching id (tf.stack would be the usual choice for scalars).
    return (tf.stack(imgs_found, 0), tf.cast(tf.concat(labels_found, 0), tf.int64)), image_ids_found

# IIa). globe-flower predictions

Here I dive deeper to understand a prediction for the globe-flower class analyzed in flowerclass_efficientnetv2_2_analysis2_imgvis.ipynb.

## Image ed3a59a35

The image for analysis has the id ed3a59a35.



In [None]:
image_id_investigate = "ed3a59a35"

In [None]:
batch_found, imgage_ids_found = get_images_by_ids([image_id_investigate])

In [None]:
plt.imshow(batch_found[0][0].numpy())

In [None]:
img = batch_found[0][0].numpy()

In [None]:
img.shape

Segment the image into superpixels

In [None]:
# Segment the image with SLIC, using the same algorithm and parameters as for
# the anchor method. Note: for LIME we used their default quickshift algorithm.
kwargs = dict(n_segments=15, compactness=20, sigma=.5)
segments_slic = slic(img, **kwargs)
segments_slic.shape

In [None]:
# define a function that depends on a binary mask representing if an image region is hidden
def mask_image(zs, segmentation, image, background=None):
    '''
    Build perturbed copies of ``image`` in which selected superpixels are hidden.

    zs: binary matrix of perturbed examples of the form (nsamples, nfeatures);
        ``zs[i, j] == 0`` hides superpixel ``j`` in sample ``i``
    segmentation: superpixel label of every pixel (same height/width as
        ``image``); column ``j`` of ``zs`` refers to the pixels labelled ``j``
    image: original image (height, width, channels) to mask based on superpixels
    background: value(s) written into hidden superpixels; defaults to the
        per-channel mean of ``image``

    returns: array of shape (nsamples, height, width, channels)
    '''
    zs = np.asarray(zs)
    if background is None:
        # one mean value per channel, used as background
        background = image.mean((0, 1))
    out = np.zeros((zs.shape[0], image.shape[0], image.shape[1], image.shape[2]))
    for sample, z in zip(out, zs):
        # `sample` is a view into `out`, so writing to it fills the result in place
        sample[...] = image
        # labels of all superpixels switched off in this sample, masked in one pass
        hidden = np.flatnonzero(z == 0)
        sample[np.isin(segmentation, hidden)] = background
    return out

def f(z):
    '''Prediction function handed to SHAP: turns the on/off superpixel vectors
    ``z`` into masked copies of ``img`` and returns the model predictions for them.
    '''
    # hidden superpixels are filled with the value 255 in all channels
    # NOTE(review): 255 is white only if img is on a 0-255 scale — confirm the pixel range delivered by the dataset
    masked_images = mask_image(z, segments_slic, img, 255)
    return effnet2_tfhub.predict(masked_images)

Parameters:

* model: black-box model to explain
* data: representative samples from our (train) dataset to help define missingness for all our 50 features/superpixels. In our setup a feature value of 0 means "superpixel hidden", and `mask_image` fills hidden superpixels with the background value passed in `f` (255). As every superpixel is hidden in the same way, we just need one "sample" consisting of a zero for each superpixel, i.e. an array of shape (1, 50).
* link: feature value to model output connection. Not needed in our case, hence default identity is suitable

In [None]:
# use Kernel SHAP to explain the network's predictions
# The background data is a single all-zero row with 50 features; mask_image treats a 0 as "superpixel hidden".
# NOTE(review): 50 features are declared although slic is called with n_segments=15 — columns without a
# matching segment label do not change the masked image; confirm that 50 is intended.
explainer = shap.KernelExplainer(model=f, data=np.zeros((1,50)))

Parameters:

* X: sample to explain. In our special case of image explanations we pass a single sample of all ones (every superpixel visible). SHAP perturbs it into multiple on/off samples, which are then turned into perturbed versions of the original image by `mask_image` inside the prediction function `f`.
* nsamples: number of perturbed samples SHAP generates and re-evaluates our model on. The perturbations switch entries of the superpixel vector between the sample to explain (ones) and our representative background sample (zeros); they affect the image only indirectly, through `mask_image`.
* l1_reg:   l1 regularization to use for feature selection of our local models. The auto option currently uses "aic" when less than 20% of the possible sample
    space is enumerated, otherwise it uses no regularization.

In [None]:
# Explain the sample of all ones (no superpixel hidden); the feature count (50) must match the explainer's background data.
shap_values = explainer.shap_values(X=np.ones((1,50)), nsamples=1000)

For each of our flower classes, we have a matrix of shap values:

In [None]:
len(shap_values)


In [None]:
shap_values[0].shape

In [None]:
np.expand_dims(img.copy(), axis=0).shape

In [None]:
# Predict on the single image (with a leading batch axis added) and order the
# class indices by decreasing predicted score.
single_image_batch = img.copy()[np.newaxis, ...]
preds = effnet2_tfhub.predict(single_image_batch)
top_preds = (-preds).argsort()

In [None]:
preds.shape, preds[0, :5]

In [None]:
top_preds

In [None]:
class_names_mapping[15]

In [None]:
# Diverging colormap for SHAP values: the first color fades out towards the
# middle of the map, the second color fades in from there.
# Every entry is an RGBA tuple (red, green, blue, alpha).
colors = [(245/255, 39/255, 87/255, alpha) for alpha in np.linspace(1, 0, 100)]
colors += [(24/255, 196/255, 93/255, alpha) for alpha in np.linspace(0, 1, 100)]
cm = LinearSegmentedColormap.from_list("shap", colors)
cm

In [None]:
def fill_segmentation(values, segmentation):
    '''Paint every superpixel with its SHAP value and return the resulting image.

    values: SHAP values; values[k] belongs to the superpixel labelled k
    segmentation: superpixel label of every pixel
    returns: float array shaped like segmentation; pixels whose label has no
        entry in values stay 0
    '''
    filled = np.zeros(segmentation.shape, dtype=float)
    for label, shap_value in enumerate(values):
        filled[segmentation == label] = shap_value
    return filled

Visualize the SHAP values of the superpixels for the top 3 classes with the highest class probability:

In [None]:
# class indices of the explained image, ordered by decreasing predicted score
inds = top_preds[0]

# one panel for the original image plus one panel per explained class
fig, axes = plt.subplots(nrows=1, ncols=4, figsize=(16, 8))

# first panel: the original image
axes[0].imshow(img)
axes[0].axis('off')

# largest absolute SHAP value over all classes (last feature column excluded);
# used as a symmetric color range so that all panels share one scale
max_val = np.max([np.max(np.abs(class_values[:, :-1])) for class_values in shap_values])

# remaining panels: SHAP heat map for each of the three top-ranked classes
for ax, class_idx in zip(axes[1:], inds[:3]):
    m = fill_segmentation(shap_values[class_idx][0], segments_slic)

    ax.set_title(class_names_mapping[class_idx])

    # faint copy of the image underneath the heat map
    ax.imshow(img, alpha=0.08, cmap='gray', vmin=0, vmax=255) # .convert('LA')

    im = ax.imshow(m, cmap=cm, vmin=-max_val, vmax=max_val)

    ax.axis('off')

cb = fig.colorbar(im, ax=axes.ravel().tolist(), label="SHAP value", orientation="horizontal", aspect=60)
cb.outline.set_visible(False)
plt.show()

Results:

* For the predicted globe-flower class, according to SHAP, the model focuses more on the surrounding areas than on the relevant flower itself. This is in contrast to LIME and the anchor method, which both identified the flower as the most relevant part, in alignment with reasoning.