# Code used to generate the paper figures
This notebook contains the code used to generate the paper figures. 
It includes:
 - the code to generate the heatmaps and plots for LIME image explanations with and without stratification
 - the code to generate the plots of the binomial distribution and the Shapley weight function
 - the code to generate the comparison plots in the Experimental section, starting from the CSV file produced by the long-running experiments.

In [None]:
import utils as ut
import numpy as np
import pandas as pd
import scipy.special
import json, math,cv2
import tensorflow as tf
import sys, os, importlib
import matplotlib.pyplot as plt
pd.set_option('display.max_columns', None)
from matplotlib.colors import LinearSegmentedColormap
from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input

# Let the notebook container use (almost) the full browser width.
from IPython.display import HTML, display

_WIDE_CONTAINER_CSS = "<style>.container { width:98% !important; }</style>"
display(HTML(_WIDE_CONTAINER_CSS))

In [None]:
# Folder layout: inputs under ./data, every generated artefact under ./result.
Main_dir        =   os.getcwd()
DS_path         =   os.path.join(Main_dir, "data")
result_folder   =   os.path.join(Main_dir, "result")
paper_figures   =   os.path.join(result_folder, "Paper_Figures")
json_file       =   os.path.join(DS_path, "imagenet_class_index.json")

# Prepare the output folders, parent before child
# (presumably ut.check_folders creates a missing folder — TODO confirm).
for _out_dir in (result_folder, paper_figures):
    ut.check_folders(_out_dir)

In [None]:
# Black-box classifier to be explained, loaded through the project helper.
model_name = 'ResNet50'
model = ut.load_model(model_name)

In [None]:
# Human-readable ImageNet labels read from imagenet_class_index.json
# (indexing convention defined by ut.get_ImageNet_ClassLabels).
class_names = ut.get_ImageNet_ClassLabels(json_file)

In [None]:
from lime_stratified.lime import lime_image
from lime_stratified.lime.lime_image import LimeImageExplainer
from lime_stratified.lime.wrappers.scikit_image import SegmentationAlgorithm

In [None]:
# Example image used for the single-image figures that follow.
image_name = 'bird5'
file = os.path.join(DS_path, f'{image_name}.png')
image_to_explain = ut.read_process_image(file, model)

# Black-box prediction function
NOTE: depending on the platform, you may need to call the model directly (`model(...)`) instead of `model.predict(...)`.

In [None]:
def bb_predict(imgs):
    """Black-box prediction function handed to LIME.

    Applies the ResNet50 ``preprocess_input`` to a batch of images and returns
    the model's outputs for it.

    Parameters
    ----------
    imgs : array-like
        Batch of raw (not yet preprocessed) images
        (presumably shaped (batch, H, W, 3) — TODO confirm against the LIME fork).

    Returns
    -------
    numpy.ndarray
        Result of ``model.predict`` on the preprocessed batch.

    Notes
    -----
    On some platforms ``model(batch)`` has to be used instead of
    ``model.predict(batch, verbose=False)``.
    """
    # preprocess_input writes its result over the input array when the dtype
    # allows it (documented Keras behaviour), so work on a copy and never alter
    # the caller's images.
    batch = preprocess_input(np.array(imgs, copy=True))
    return model.predict(batch, verbose=False)

In [None]:

# Classify the example image and report the top prediction.
predicted = bb_predict(np.array([image_to_explain]))
predicted_cls_idx, f_x, predicted_cls_lbl = ut.get_class_idx_label_score(predicted, class_names)

print('Predicted Class\t\t: \t', predicted_cls_lbl,
      '\nClass Probability\t:\t', f_x,
      '\nPredicted Class Index\t:\t', predicted_cls_idx)

In [None]:
# Segment the example image into about 50 superpixels and show the original
# image next to the superpixel boundaries.
from matplotlib import rc
from skimage.segmentation import mark_boundaries

rc('text', usetex=True)

seg_algo = 'quickshift'
max_dist, _, _, _ = ut.search_segment_number(image_to_explain, target_seg_no=50)
segments, num_segments, segmenter_fn = ut.own_seg(
    image_to_explain, md=max_dist, ks=4, random_seed=1234, ratio=0.2
)

fig, axes = plt.subplots(1, 2, figsize=(6, 3))
ax_image, ax_segments = axes
ax_image.imshow(image_to_explain)
ax_segments.imshow(mark_boundaries(image_to_explain, segments))
fig.suptitle(f'{num_segments} segments')
fig.savefig(f'{paper_figures}/{image_name}_image_{num_segments}_segments.svg',
            dpi=150, bbox_inches='tight', pad_inches=0.02)
plt.show()

In [None]:
# Explain the example image with (stratified) LIME.
use_stratification = True
lime_explainer = LimeImageExplainer(random_state=1234)
ret = lime_explainer.explain_instance(
    image_to_explain,
    bb_predict,
    labels=class_names,
    segmentation_fn=segmenter_fn,
    top_labels=3,
    hide_color=None,
    use_stratification=use_stratification,
    batch_size=100,
    num_samples=1000,
)
# This LIME fork returns (data, labels, explanation).
X, all_Ys, expl = ret
# Black-box output for the predicted class on every perturbed sample.
Y = all_Ys[:, predicted_cls_idx]

In [None]:
# Surrogate prediction g(x) for the top explained class, then summary
# statistics of the LIME coefficients beta.
xpld_cls = expl.top_labels[0]
g_x = expl.local_pred[xpld_cls][0]
print('g(x) \t\t= ', g_x)

beta = ut.get_beta_from_expl(expl)
print('sum(beta) \t= ', np.sum(beta))

# Coefficient of variation std/mean of the coefficients.
mean_beta, std_beta = np.mean(beta), np.std(beta)
print('CV(beta) \t= ', std_beta / mean_beta)

In [None]:
# Three panels for the top label: positive superpixels alone, positive
# superpixels overlaid on the image, and the signed coefficient heatmap.
from skimage.segmentation import mark_boundaries

expl.image = image_to_explain
heatmap = ut.heatmap_from_beta(segments, beta)
v = np.max(np.abs(heatmap))
print(f'v = {v}')

# Shared mask options; min_weight thresholds at half the heatmap's max magnitude.
top_label = expl.top_labels[0]
mask_kwargs = dict(positive_only=True, num_features=1000, min_weight=v/2)
temp_1, mask_1 = expl.get_image_and_mask(top_label, hide_rest=True, **mask_kwargs)
temp_2, mask_2 = expl.get_image_and_mask(top_label, hide_rest=False, **mask_kwargs)

fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(10, 4))
ax1.imshow(mark_boundaries(temp_1.astype(np.uint8), mask_1))
ax2.imshow(mark_boundaries(temp_2.astype(np.uint8), mask_2))

im = ax3.imshow(heatmap, cmap='bwr', vmin=-v, vmax=v)
fig.colorbar(im, ax=ax3, shrink=0.60)
for panel in (ax1, ax2, ax3):
    ut.axis_off(panel)
ax2.set_title(f'predicted as {class_names[predicted_cls_idx]}  f(x)={f_x:.5}  g(x)={g_x:.5}')
plt.savefig(f'{paper_figures}/{image_name}_image_mask_heatmap_single.svg',
            dpi=150, bbox_inches='tight', pad_inches=0.02)
plt.show()

In [None]:
# Heatmap of the LIME coefficients (built by ut.heatmap_from_beta from the
# segmentation), titled with the number of segments k and CV(beta).
heatmap = ut.heatmap_from_beta(segments, beta)
v = np.max(np.abs(heatmap))  # symmetric colour range centred on zero

fig, ax = plt.subplots(1, 1, figsize=(2.5, 2.5))
im = ax.imshow(heatmap, cmap='bwr', vmin=-v, vmax=v)
fig.colorbar(im, ax=ax, shrink=0.55, orientation='horizontal', anchor=(.5, 1.5))
ut.axis_off(ax)

if use_stratification:
    beta_title = '\\widehat{\\beta}'
else:
    beta_title = '\\beta'
plt.title(f'$k = {num_segments}$ \n $CV({beta_title}) = \\mathbf{{ {ut.get_CV_beta(beta):.3} }}$')
plt.tight_layout()
plt.savefig(f'{paper_figures}/{image_name}_heatmap.svg',
            dpi=150, bbox_inches='tight', pad_inches=0.02)
plt.show()

In [None]:
# Classification scores of the LIME samples (drawn by
# ut.plot_classification_score), titled with RC(Y).
save_kwargs = dict(dpi=150, bbox_inches='tight', pad_inches=0.02)

fig, ax = plt.subplots(1, 1, figsize=(2.5, 2.5))
ax.set_aspect('equal')
ut.plot_classification_score(ax, expl, X, Y, f_x)
rcy_value = ut.get_RCY(Y, f_x)
plt.title(f'$RC(Y) = \\mathbf{{ {rcy_value:.3} }}$')
plt.tight_layout()
plt.savefig(f'{paper_figures}/{image_name}_RC_Y.png', **save_kwargs)

# Generate the plot for the binomial distribution and the Shapley weight function (Figure 3)

In [None]:
# Figure 3: binomial PMF of the number of preserved superpixels |m| (left column)
# versus the Shapley weight function (right column), for several k.
# Top row uses a linear y-axis, bottom row a log y-axis.
plt.rcParams.update({"text.usetex": True})
markers = ['+', 'o', 'x']
ks = [10, 20, 50]
fig, axes = plt.subplots(2, 2, figsize=(6, 3.2), sharex=True)

for col, weight_fn in enumerate([ut.pdf_bern, ut.shapley_p]):
    for marker, k in zip(markers, ks):
        sizes = range(k + 1)
        MD = [weight_fn(k, s) for s in sizes]
        print(np.min(MD))  # smallest weight: how far the log axis must extend
        for row in range(2):
            axes[row, col].plot(sizes, MD, marker=marker, markersize=3, label=f'$k$={k}', lw=.5)
    axes[1, col].set_xlabel('Number of preserved superpixels $|m|$')
    axes[1, col].set_yscale('log')

# Backslashes are escaped explicitly: '\G' / '\w' style sequences are invalid
# string escapes and raise SyntaxWarning on recent Python versions.
axes[0, 0].set_title('\\noindent {\\bf (A)} Binomial distribution PMF $\\mathcal{B}(k, |m|)$\\\\ \\phantom{niim} for a Bernoulli process $B(0.5)$')
axes[0, 0].set_ylabel('Value')
axes[1, 0].set_ylabel('Value (logscale)')
axes[0, 1].set_title('{\\bf (B)} Shapley weight function $\\Gamma(k, |m|)$', loc='right')
axes[0, 1].legend(loc='upper right')

plt.tight_layout(pad=0.2)
plt.savefig(f'{paper_figures}/binom-shapley.pdf', transparent=True, pad_inches=0.1, bbox_inches='tight')
plt.show()

# Generate the paper plots for Figures 4 and 5
This code may take several minutes to run.

In [None]:
# Figures 4 and 5: explain one ImageNet test image with k = 50..200 superpixels,
# using Monte Carlo and stratified sampling. The LIME coefficients and RC(Y)
# are averaged over 10 repeated runs per configuration.
file = os.path.join(DS_path, 'ILSVRC2012_test_00000125.JPEG')
image_to_explain = ut.read_process_image(file, model)

predicted = bb_predict(np.array([image_to_explain]))

(predicted_cls_idx, f_x, predicted_cls_lbl) = ut.get_class_idx_label_score(predicted, class_names)
f_x = predicted[0][predicted_cls_idx]
print('Predicted Class\t\t:\t',predicted_cls_lbl,'\nClass Probability\t:\t', f_x,'\nPredicted Class Index\t:\t', predicted_cls_idx)

for k in [50, 100, 150, 200]:
    max_dist, _, _, _ = ut.search_segment_number(image_to_explain, target_seg_no=k)
    segments, num_segments, segmenter_fn = ut.own_seg(image_to_explain, md=max_dist, ks=4, random_seed=1234, ratio=0.2)

    print(f'k, max_dist\t        :\t {k}, {max_dist}')
    for use_stratification in [False, True]:
        sig = f'{k}_{use_stratification}'
        # Re-created per configuration, so each configuration starts from seed 1234.
        lime_explainer = LimeImageExplainer(random_state=1234)
        beta_arr, rcY_arr = [], []
        print('Explanation for\t\t:\t '+sig, end=' ')
        for repeat in range(10):
            print(repeat+1, end=' ')
            # Pass the raw image, as the other cells do. bb_predict already applies
            # preprocess_input, so preprocessing here too would normalise the
            # perturbed images twice. Per the Keras docs it could also overwrite
            # image_to_explain in place for float arrays. Either way, Y would be
            # inconsistent with f_x.
            ret = lime_explainer.explain_instance(image_to_explain, bb_predict,
                                                  labels=class_names, segmentation_fn=segmenter_fn,
                                                  top_labels=3, hide_color=None, use_stratification=use_stratification,
                                                  batch_size=100, num_samples=1000, progress_bar=False)

            X, all_Ys, expl = ret  # (data, labels, explanation)
            xpld_cls = expl.top_labels[0]
            Y = all_Ys[:, xpld_cls]

            beta_arr.append(ut.get_beta_from_expl(expl))
            rcY_arr.append(ut.get_RCY(Y, f_x))
        print()  # terminate the progress line

        # Coefficients averaged over the repeated runs.
        beta = np.mean(beta_arr, axis=0)

        heatmap = ut.heatmap_from_beta(segments, beta)
        v = np.max(np.abs(heatmap))
        fig, ax = plt.subplots(1, 1, figsize=(2.5, 2.5))
        im = ax.imshow(heatmap, cmap='bwr', vmin=-v, vmax=v)
        fig.colorbar(im, ax=ax, shrink=0.55, orientation='horizontal', anchor=(.5, 1.5))
        ut.axis_off(ax)
        beta_title = '\\widehat{\\beta}' if use_stratification else '\\beta'
        plt.title(f'$k = {num_segments}$ \n $CV({beta_title}) = \\mathbf{{ {ut.get_CV_beta(beta):.3} }}$')
        plt.tight_layout()
        plt.savefig(f'{paper_figures}/heatmap_{sig}.svg', dpi=150, bbox_inches='tight', pad_inches=0.02)
        # No plt.show() here: with the inline backend the plt.show() below
        # displays every open figure, including this heatmap.

        # Scores of the last run, titled with the RC averaged over all runs.
        fig, ax = plt.subplots(1, 1, figsize=(2.5, 2.5))
        plt.gca().set_aspect('equal')
        ut.plot_classification_score(ax, expl, X, Y, f_x)
        Y_title = '\\widehat{Y}' if use_stratification else 'Y'
        plt.title(f'$RC({Y_title}) = \\mathbf{{ {np.mean(rcY_arr):.3} }}$')
        plt.tight_layout()
        plt.savefig(f'{paper_figures}/RC_Y_{sig}.svg', dpi=150, bbox_inches='tight', pad_inches=0.02)
        plt.show()

# Generate the comparison plots (Figure 6)
This code starts from the CSV file produced by the long-running experiments, which can take more than a day to compute.

In [None]:
# Load the precomputed results of the long experiments and derive the
# robustness metrics used in Figure 6
# (presumably one row per image and configuration — TODO confirm).
num_samples                 = 1000
segs_list                   = [50,100,150,200]

# Change the filename if you do not want to use the precomputed results' CSV.
results_csv = os.path.join(result_folder, 'precomputed_results_1000_1_150_[50, 100, 150, 200].csv')
# The dtype key must match the column name exactly: the column is 'hide_color'
# (used below as df.hide_color). pandas silently ignores keys for absent columns.
df = pd.read_csv(results_csv, sep=';', index_col=0,
                 dtype={'hide_color': str, 'filename': str})

# Inter-quantile ranges of Y (5-95% and 1-99%), absolute and relative to f(x).
df['IQR0595_Y'] = df.q95_Y - df.q05_Y
df['IQR0199_Y'] = df.q99_Y - df.q01_Y
df['IQR0595_Y_over_fx'] = df.IQR0595_Y / df.f_x
df['IQR0199_Y_over_fx'] = df.IQR0199_Y / df.f_x
# RC(Y) as plotted in the paper is the relative 1-99% range.
df['RC_Y'] = df.IQR0199_Y_over_fx
df['CV_beta'] = df.cv_beta
df.head(3)

In [None]:
# FIGURE 6  CV PLOTS
# CV(beta) against RC(Y) per image: Monte Carlo LIME (top row) and stratified
# LIME (bottom row), one column per superpixel budget k.


def scatter_cv(ax, x, y, H=1.0):
    """Scatter (x, y) on ``ax`` and highlight the points with y <= H.

    Points above H are drawn in translucent blue and points at or below H in
    red. A dotted red line marks y = H, and the count of points at or below H
    is written on the axes.
    """
    sel = (y > H)
    ax.scatter(x[sel], y[sel], s=5, c='blue', alpha=0.25)
    sel = (y <= H)
    ax.scatter(x[sel], y[sel], s=4, c='red')
    ax.plot([-10, 10], [H, H], ls=':', c='red', lw=1)
    ax.text(1.2, 0.4, f'{np.sum(sel)} images', c='darkred')


fig, ax = plt.subplots(2, 4, figsize=(6, 3), sharey=True, sharex=False)

for j, sss in enumerate(segs_list):
    # Images whose segment count falls in [k-50, k], mean-filled perturbations only.
    df2 = df[(sss-50 <= df.segments) & (df.segments <= sss) & (df.hide_color == 'mean-filled')].copy()

    df_orig = df2[df2.use_stratification == False].sort_values(by='filename')
    df_seg = df2[df2.use_stratification == True].sort_values(by='filename')

    for i in range(2):
        ax[i, j].set_xlim(-0.1, 2.1)
        ax[i, j].set_yscale('log'); ax[i, j].set_ylim(0.2, 20)
        ax[i, j].set_xticks([0, 1, 2])
    ax[0, j].set_xticklabels(['', '', ''])
    # Backslashes escaped explicitly: '\w' is an invalid string escape
    # (SyntaxWarning on recent Python versions).
    ax[0, j].set_xlabel('RC($Y$)')
    ax[1, j].set_xlabel('RC($\\widehat{Y}$)')
    if j == 0:
        ax[0, j].set_ylabel('Monte Carlo\nCV($\\beta$)')
        ax[1, j].set_ylabel('Stratified\nCV($\\widehat\\beta$)')

    scatter_cv(ax[0, j], df_orig.RC_Y, df_orig.CV_beta)
    scatter_cv(ax[1, j], df_seg.RC_Y, df_seg.CV_beta)
    ax[0, j].set_title(f'$k={sss}$', fontsize=11)

plt.tight_layout(pad=0.2, w_pad=0)
plt.savefig(f'{paper_figures}/CV_plots.pdf')

In [None]:
######### FIGURE 6 R2
# R^2 of the LIME surrogate for the same images: Monte Carlo (x axis) versus
# stratified (y axis), one panel per superpixel budget k.
fig, ax = plt.subplots(1, 4, figsize=(6, 2), sharey=True, sharex=True)

for j, sss in enumerate(segs_list):
    df2 = df[(sss-50 <= df.segments) & (df.segments <= sss) & (df.hide_color == 'mean-filled')].copy()

    df_orig = df2[df2.use_stratification == False].sort_values(by='filename')
    df_seg = df2[df2.use_stratification == True].sort_values(by='filename')

    # Both subsets are sorted by filename, so results are paired by position.
    # Compare plain arrays: comparing the Series directly would align (or refuse
    # to compare) on the index labels of two different row subsets of df.
    x = df_orig.r2.to_numpy()
    y = df_seg.r2.to_numpy()
    assert len(x) == len(y), f'k={sss}: {len(x)} Monte Carlo vs {len(y)} stratified results'

    ax[j].set_aspect('equal')
    ax[j].set_xlabel('Monte Carlo $R^2$')
    if j == 0: ax[j].set_ylabel('\nStratified $R^2$')
    ax[j].set_xlim(0, 1)
    ax[j].set_ylim(0, 1)

    # Blue: Monte Carlo at least as good; green: stratified strictly better.
    sel = (x >= y)
    ax[j].scatter(x[sel], y[sel], s=5, c='blue', alpha=0.50)
    sel = (x < y)
    ax[j].scatter(x[sel], y[sel], s=5, c='green', alpha=0.50)
    ax[j].plot([0, 1], [0, 1], ls=':', c='black', lw=1)
    ax[j].text(0.05, 0.75, 'stratif.\nbetter')
    ax[j].text(0.3, 0.05, 'M.C. better')
    ax[j].set_title(f'$k={sss}$', fontsize=11)

plt.tight_layout()
plt.savefig(f'{paper_figures}/R2_plots.pdf')

# Generate plots for the selected examples (Figure 7)
This code may take a few hours to complete.

In [None]:
# Figure 7: selected ImageNet test images, explained with k = 50 and k = 200
# superpixels, with and without stratification (10 LIME runs per configuration).
image_no = [
    #orange  wardrobe  milk_can  lynx  ringneck_snake
    114,     147,      60,       144,  66
]


def _plot_dual(sig, segments, num_segments, beta, rcY_arr, expl, X, Y, f_x):
    """Save and show the side-by-side heatmap / classification-score figure.

    Left panel: heatmap of the run-averaged coefficients ``beta``. Right panel:
    scores of the last LIME run. The title reports the number of segments,
    CV(beta) and the mean RC(Y) over the runs.
    """
    heatmap = ut.heatmap_from_beta(segments, beta)
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(4, 2))
    ax1.set_aspect('equal'); ax2.set_aspect('equal')
    ut.axis_off(ax1); ut.axis_off(ax2)
    v = np.max(np.abs(heatmap))
    ax1.imshow(heatmap, cmap='bwr', vmin=-v, vmax=v)
    ut.plot_classification_score(ax2, expl, X, Y, f_x, plot_everything=False)
    plt.suptitle(f'{num_segments} $|$ {ut.get_CV_beta(beta):.3} $|$ {np.mean(rcY_arr):.3}', fontsize=28, y=1.05)
    plt.savefig(f'{paper_figures}/dual_{sig}.svg', dpi=150, bbox_inches='tight', pad_inches=0.02)
    plt.show()


for ino in image_no:
    file = os.path.join(DS_path, f'ILSVRC2012_test_{ino:08}.JPEG')
    image_to_explain = ut.read_process_image(file, model)
    predicted = bb_predict(np.array([image_to_explain]))
    predicted_cls_idx, f_x, predicted_cls_lbl = ut.get_class_idx_label_score(predicted, class_names)
    f_x = predicted[0][predicted_cls_idx]
    print('Predicted Class\t\t:\t', predicted_cls_lbl,
          '\nClass Probability\t:\t', f_x,
          '\nPredicted Class Index\t:\t', predicted_cls_idx)

    for k in (50, 200):
        max_dist, _, _, _ = ut.search_segment_number(image_to_explain, target_seg_no=k)
        segments, num_segments, segmenter_fn = ut.own_seg(
            image_to_explain, md=max_dist, ks=4, random_seed=1234, ratio=0.2
        )
        print(f'k = {k}  max_dist = {max_dist}')
        for use_stratification in (False, True):
            sig = f'{ino}_{k}_{use_stratification}'
            # Fresh explainer per configuration: every configuration starts from seed 1234.
            lime_explainer = LimeImageExplainer(random_state=1234)
            beta_arr, rcY_arr = [], []
            for repeat in range(10):
                # This LIME fork returns (data, labels, explanation).
                X, all_Ys, expl = lime_explainer.explain_instance(
                    image_to_explain, bb_predict,
                    labels=class_names, segmentation_fn=segmenter_fn,
                    top_labels=3, hide_color=None, use_stratification=use_stratification,
                    batch_size=500, num_samples=1000, progress_bar=False,
                )
                explained_cls = expl.top_labels[0]
                Y = all_Ys[:, explained_cls]
                beta_arr.append(ut.get_beta_from_expl(expl))
                rcY_arr.append(ut.get_RCY(Y, f_x))
            beta = np.mean(beta_arr, axis=0)

            _plot_dual(sig, segments, num_segments, beta, rcY_arr, expl, X, Y, f_x)