# Analysis of Model flowerclass-efficientnetv2-2 2: with XAI LIME method

### Goals

* Apply LIME to explain decisions leading to model errors in `flowerclass-efficientnetv2-2-analysis2-imgvis` notebook

In [None]:
import math, re, os
import numpy as np
import tensorflow as tf
import tensorflow_addons as tfa
print(tf.__version__)
print(tfa.__version__)

from flowerclass_read_tf_ds import get_datasets, display_batch_by_class, display_batch_of_images #, load_dataset, display_batch_of_images, batch_to_numpy_images_and_labels, display_one_flower
import tensorflow_hub as hub
import pandas as pd
import math
import plotly_express as px
from tqdm import tqdm
import seaborn as sns
import matplotlib.pyplot as plt
import itertools

In [None]:
tf.test.gpu_device_name()

# I. Data prep, model Loading and Predictions with EfficientNetV2

In [None]:
image_size = 224
batch_size = 64

In [None]:
# Flower class names; a name's position in the list is used as its integer class
# index (e.g. `class_names[cl]` for predicted labels). The trailing comments give the
# index range covered by each row.
# NOTE(review): 'cyclamen ' and 'hippeastrum ' carry a trailing space - confirm this is intended.
class_names = ['pink primrose',    'hard-leaved pocket orchid', 'canterbury bells', 'sweet pea',     'wild geranium',     'tiger lily',           'moon orchid',              'bird of paradise', 'monkshood',        'globe thistle',         # 00 - 09
           'snapdragon',       "colt's foot",               'king protea',      'spear thistle', 'yellow iris',       'globe-flower',         'purple coneflower',        'peruvian lily',    'balloon flower',   'giant white arum lily', # 10 - 19
           'fire lily',        'pincushion flower',         'fritillary',       'red ginger',    'grape hyacinth',    'corn poppy',           'prince of wales feathers', 'stemless gentian', 'artichoke',        'sweet william',         # 20 - 29
           'carnation',        'garden phlox',              'love in the mist', 'cosmos',        'alpine sea holly',  'ruby-lipped cattleya', 'cape flower',              'great masterwort', 'siam tulip',       'lenten rose',           # 30 - 39
           'barberton daisy',  'daffodil',                  'sword lily',       'poinsettia',    'bolero deep blue',  'wallflower',           'marigold',                 'buttercup',        'daisy',            'common dandelion',      # 40 - 49
           'petunia',          'wild pansy',                'primula',          'sunflower',     'lilac hibiscus',    'bishop of llandaff',   'gaura',                    'geranium',         'orange dahlia',    'pink-yellow dahlia',    # 50 - 59
           'cautleya spicata', 'japanese anemone',          'black-eyed susan', 'silverbush',    'californian poppy', 'osteospermum',         'spring crocus',            'iris',             'windflower',       'tree poppy',            # 60 - 69
           'gazania',          'azalea',                    'water lily',       'rose',          'thorn apple',       'morning glory',        'passion flower',           'lotus',            'toad lily',        'anthurium',             # 70 - 79
           'frangipani',       'clematis',                  'hibiscus',         'columbine',     'desert-rose',       'tree mallow',          'magnolia',                 'cyclamen ',        'watercress',       'canna lily',            # 80 - 89
           'hippeastrum ',     'bee balm',                  'pink quill',       'foxglove',      'bougainvillea',     'camellia',             'mallow',                   'mexican petunia',  'bromelia',         'blanket flower',        # 90 - 99
           'trumpet creeper',  'blackberry lily',           'common tulip',     'wild rose']                                                                                                                                               # 100 - 103
len(class_names)

In [None]:
effnet2_base = "https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet21k_s/feature_vector/2"

In [None]:
    # Classifier: frozen (trainable=False) TF Hub feature extractor loaded from
    # `effnet2_base`, followed by dropout and a 104-unit softmax layer.
    # NOTE(review): this top-level statement is indented; it presumably only runs
    # because the notebook front end strips a cell's leading indentation and would
    # be an IndentationError in a plain .py export - consider dedenting.
    effnet2_tfhub = tf.keras.Sequential([
    # Explicitly define the input shape so the model can be properly
    # loaded by the TFLiteConverter
    tf.keras.layers.InputLayer(input_shape=(image_size, image_size,3)),
    hub.KerasLayer(effnet2_base, trainable=False),
    tf.keras.layers.Dropout(rate=0.2),
    tf.keras.layers.Dense(104, activation='softmax')
])
effnet2_tfhub.build((None, image_size, image_size,3,)) #This is to be used for subclassed models, which do not know at instantiation time what their inputs look like.


effnet2_tfhub.summary()

In [None]:
# Restore the weights of the selected training checkpoint. The file name is
# "cp-" + the checkpoint number left-padded with zeros to 4 digits + ".ckpt".
best_phase = 12
effnet2_tfhub.load_weights("../input/flowerclass-efficientnetv2-2/training/"+"cp-"+f"{best_phase}".rjust(4, '0')+".ckpt")

In [None]:
from sklearn.metrics import classification_report, confusion_matrix

Ensure that validation data loader returns fixed order of elements.

In [None]:
# Load the splits together with the image ids so that every prediction can be
# traced back to the image it belongs to.
ds_train, ds_valid, ds_test = get_datasets(BATCH_SIZE=batch_size, IMAGE_SIZE=(image_size, image_size),
                                           RESIZE=None, tpu=False, with_id=True)

# Per-batch results are collected in their own lists and concatenated once at the
# end, so the final names are never bound to two different kinds of object.
pred_batches = []
label_batches = []
id_batches = []
for imgs, label, imgs_id in tqdm(ds_valid):
    # predict_on_batch runs a single inference pass for the batch; calling
    # Model.predict inside a loop adds per-call setup cost and prints one
    # progress bar per batch on top of the tqdm bar.
    batch_probs = np.asarray(effnet2_tfhub.predict_on_batch(imgs))
    pred_batches.append(batch_probs.argmax(1))
    # labels are reduced with argmax over axis 1 (assumes one label vector per image)
    label_batches.append(label.numpy().argmax(1))
    id_batches.append(imgs_id.numpy())

img_preds = np.concatenate(pred_batches)
img_labels = np.concatenate(label_batches)
img_ids = np.concatenate(id_batches)


In [None]:
# One row per validation image: predicted class, ground-truth class and image id.
# The ids are decoded from bytes to str while the frame is built.
val_results = (
    pd.DataFrame({'pred': img_preds, "label": img_labels, "id": img_ids})
    .assign(id=lambda frame: frame['id'].map(lambda raw_id: raw_id.decode()))
)

In [None]:
val_results.head()

# II. Explaining model decisions

In [None]:
def get_images_by_ids(image_ids_search):
    '''Collect the validation images whose id is contained in `image_ids_search`.

    The validation split is re-created through `get_datasets` and scanned once,
    batch by batch.

    Parameters:
    * image_ids_search: collection of image ids to look for

    Returns:
    * ((images, labels), image_ids_found): the matching images stacked along a new
      first axis, their class indices (argmax of each label) as an int64 tensor with
      one entry per image, and the list of matched ids in the order they were found.

    Raises:
    * ValueError: if none of the requested ids occurs in the validation split.
    '''
    # only the validation split is searched; the other two splits are not needed
    _, ds_valid, _ = get_datasets(BATCH_SIZE=batch_size, IMAGE_SIZE=(image_size, image_size),
                                  RESIZE=None, tpu=False, with_id=True)

    imgs_found = []
    image_ids_found = []
    labels_found = []
    for imgs, labels, imgs_id in tqdm(ds_valid):
        for img, img_id, label in zip(imgs, imgs_id, labels):
            if img_id in image_ids_search:
                image_ids_found.append(img_id)
                imgs_found.append(img)
                labels_found.append(tf.argmax(label))

    if not imgs_found:
        raise ValueError(
            f"none of the requested image ids were found in the validation split: {list(image_ids_search)}")

    # Each entry of labels_found is a single class index, so the entries have to be
    # stacked: tf.concat rejects scalar inputs as soon as more than one image matched.
    return (tf.stack(imgs_found, 0), tf.cast(tf.stack(labels_found, 0), tf.int64)), image_ids_found

In [None]:
import lime
from lime import lime_image
from skimage.segmentation import mark_boundaries

# IIa). globe-flower predictions

Here I dive deeper to understand a prediction for the globe-flower class analyzed in `flowerclass_efficientnetv2_2_analysis2_imgvis.ipynb`.


## FP Image ed3a59a35

The image for analysis has the id ed3a59a35.

In [None]:
image_id_investigate = "ed3a59a35"

In [None]:
batch_found,  imgage_ids_found= get_images_by_ids([image_id_investigate])

In [None]:
plt.imshow(batch_found[0][0].numpy())

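Create the LIME image explainer. The perturbed samples themselves are generated later, in `explain_instance`, by randomly switching superpixels on/off (see below) rather than by sampling from a Normal(0,1).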

Default parameters:
* kernel_width: width of kernel used in similarity measure to weight the surrounding data points when used in the surrogate model.
* kernel: kernel type/function used in similarity measure. I used the default exponential here. After similarities (with `distance_metric` parameter below) between perturbed samples are calculated, we apply the kernel function to get the weights for the surrogate model. 
* feature_selection: how to set the number of features in the surrogate model (when calling `explain_instance` below). Due to our default choice of 'auto' together with the high `num_features` below, features are selected as follows: select the superpixels/features that have the highest product of (absolute) feature weight * explained image pixels. This weight for features is derived from a Ridge model (`Ridge(alpha=0.01)`) trained on all perturbed instances. This feature selection step happens before the actual surrogate model (`model_regressor` parameter below) is fit on the selected subset of features. Method flow: segmentation approach to create superpixels/features > feature selection of superpixels > fit of the surrogate model on the selected features

In [None]:
explainer = lime_image.LimeImageExplainer(kernel_width=0.25, kernel=None, feature_selection='auto', random_state=42) # all params are defaults 

Create superpixels, create perturbed samples. For each class (in `top_labels`) perform feature selection, fit the surrogate model and get the model coefficients.

Parameters:
* labels: iterable with labels to be explained. as we chose top_labels parameter it is not used.
* top_labels: choose the top n classes (highest probability). *note that the top_labels=12 was chosen to include the ground truth class 'lotus'.*
* num_features: use default 100000 
* num_samples: we use the default of 1000 samples in the neighbourhood of our image as training data to train our surrogate model
* distance_metric: cosine (default) distance metric for the similarities between perturbed samples and the image to be explained. In practice it calculates similarity between perturbations only.
* model_regressor: uses by default a Ridge model (`sklearn.linear_model.Ridge(alpha=1)`) as the surrogate model. The model is trained on perturbed samples with weights given through the kernel function
* segmentation_fn: quickshift segmentation algorithm from skimage to create superpixels ([ref](https://scikit-image.org/docs/dev/api/skimage.segmentation.html?highlight=quickshift#skimage.segmentation.quickshift)). The image to be explained will be perturbed by turning on/off each superpixel with 50% probability to create `num_samples` samples.

In [None]:
%%time
explanation = explainer.explain_instance(image=batch_found[0][0].numpy().astype('double'), 
                                         classifier_fn=effnet2_tfhub.predict, 
                                         top_labels=12, hide_color=0, num_samples=1000,
                                        random_seed=42)

As seen in the flowerclass_efficientnetv2_2_analysis2_imgvis notebook, the top prediction is globe-flower. The true class, lotus, is not the top prediction; it only shows up further down the list of top labels.

In [None]:
explanation.top_labels

In [None]:
[class_names[cl] for cl in explanation.top_labels]

Get pre-computed feature importance for superpixel of class selected (`explanation.top_labels[0]`) and display top 5 most important super pixels.

Parameters:
* num_features: number of superpixels to include in explanation
* positive_only: only take superpixels positively contributing
* hide_rest: mark the non-explaining part of the image

In [None]:
temp, mask = explanation.get_image_and_mask(explanation.top_labels[0], positive_only=True, num_features=5, hide_rest=True)
plt.imshow(mark_boundaries(temp / 2 + 0.5, mask))

> Note that repeated execution of `explainer.explain_instance` above (without further changes) has lead to slight variation of the
superpixels selected, despite the fixed `random_seed`.

In [None]:
temp, mask = explanation.get_image_and_mask(explanation.top_labels[0], positive_only=True, num_features=5, hide_rest=False)
plt.imshow(mark_boundaries(temp / 2 + 0.5, mask))

> model focus mostly on lower end of the flower petals. The irrelevant background of gras is excluded.

Increasing the number of most important superpixels from 5 to 10 now includes
irrelevant background, but also parts of the flower stem:

In [None]:
temp, mask = explanation.get_image_and_mask(explanation.top_labels[0], positive_only=True, num_features=10, hide_rest=True)
plt.imshow(mark_boundaries(temp / 2 + 0.5, mask))

In [None]:
temp, mask = explanation.get_image_and_mask(explanation.top_labels[0], positive_only=False, num_features=10, hide_rest=False)
plt.imshow(mark_boundaries(temp / 2 + 0.5, mask))

In [None]:
def plot_colormap_explain(image, explanation, top_class):
    '''plot segments of superpixels colored by the feature weight of the surrogate model

    Parameters:
    * image: image shown in the left panel for reference
    * explanation: LIME explanation providing `top_labels`, `local_exp` and `segments`
    * top_class: position in `explanation.top_labels` of the class to plot (0 = first label)
    '''

    ind = explanation.top_labels[top_class]

    # superpixel id -> weight of the surrogate model for the selected class
    dict_heatmap = dict(explanation.local_exp[ind])
    # superpixels without a weight are drawn as neutral (0) instead of None
    heatmap = np.vectorize(lambda segment: dict_heatmap.get(segment, 0.0),
                           otypes=[float])(explanation.segments)

    # Symmetric color limits from the largest absolute weight: limits of
    # +/- heatmap.max() clip strongly negative weights and are inverted when
    # all weights are negative.
    weight_limit = np.abs(heatmap).max()

    fig, axes = plt.subplots(1, 2, figsize=(10,4))

    # NOTE(review): other cells rescale with `temp / 2 + 0.5` before plotting; if the
    # images are in [-1, 1] this raw imshow clips the negative values - confirm.
    axes[0].imshow(image)
    axes[0].set_title('image')

    img = axes[1].imshow(heatmap, cmap = 'RdBu', vmin = -weight_limit, vmax = weight_limit)
    axes[1].set_title(f'surrogate model weights (label {ind})')
    _ = plt.colorbar(img, ax=axes[1])
    
plot_colormap_explain(image=batch_found[0][0].numpy(), explanation=explanation, top_class=0)

> The colormap indicates that hte surrogate model identifies the lower part of the petals as the most import part
of the image for the globe-flower.

In the following I evaluate which superpixels are indicative of the class lotus which is the ground truth of the
image under investigation.

In [None]:
index_lotus_truth = 77

In [None]:

temp, mask = explanation.get_image_and_mask(index_lotus_truth, positive_only=True, num_features=5, hide_rest=True)
plt.imshow(mark_boundaries(temp / 2 + 0.5, mask))

In [None]:
# Look up the rank of the ground-truth class in the explained labels instead of
# hard-coding its position in `explanation.top_labels`.
plot_colormap_explain(image=batch_found[0][0].numpy(), explanation=explanation,
                      top_class=list(explanation.top_labels).index(index_lotus_truth))

> * A different part of the flower is indicative of the true class lotus compared to the predicted globe-flower class. This can help
to get insights into why the image was wrongly predicted as globe-flower while for the class lotus only a small part of the image was deemed relevant.
> * Also, the weights in the right image are much lower than for the predicted class above