In [1]:
import os
# Sets CUDA_VISIBLE_DEVICES in the first cell so the value is already in the
# environment when the later import cell runs (presumably to restrict CUDA
# libraries to the first GPU).
VISIBLE_GPU_IDS = '0'
os.environ['CUDA_VISIBLE_DEVICES'] = VISIBLE_GPU_IDS

In [None]:
from ExplainableModels import ExplainableModel
from agreement import attributions_preprocessing
import os
from matplotlib import pyplot as plt
import matplotlib as mtpl
from mpl_toolkits.axes_grid1 import make_axes_locatable

from typing import List
import numpy as np

In [4]:
def load_data(data_name:str, n_classes:int, batch_size:int=1,algorithm:str='SmoothGrad'):
    """Compute test-split attributions and load the matching test dataset.

    An ``ExplainableModel`` is created with ``model_name="resnet50"`` and
    ``train_data_name=data_name``; it is then asked to explain the ``"test"``
    split of the same dataset. The test split itself is loaded separately via
    ``ExplainableModel.load_data``.

    Args:
        data_name: Dataset identifier, used both for the model's training data
            name and for the data that is explained/loaded.
        n_classes: Number of classes passed to the model constructor.
        batch_size: Batch size forwarded to ``explain_dataset`` (default 1).
        algorithm: Attribution algorithm name forwarded to ``explain_dataset``
            (default ``'SmoothGrad'``).

    Returns:
        A 2-tuple ``(attributions, dataset)``: the value returned by
        ``explain_dataset`` and the value returned by
        ``ExplainableModel.load_data``.
    """
    split = "test"

    explainer = ExplainableModel(
        model_name="resnet50",
        train_data_name=data_name,
        n_classes=n_classes,
    )

    explain_kwargs = dict(
        algorithm=algorithm,
        data_name=data_name,
        data_split=split,
        batch_size=batch_size,
    )
    attributions = explainer.explain_dataset(**explain_kwargs)

    dataset = ExplainableModel.load_data(data_name=data_name, data_split=split)

    return attributions, dataset

In [15]:
def plot(imgs: List[np.ndarray]):
    """Draw up to four images side by side in a one-row figure.

    Panels are titled, left to right: the input ``x``, its attribution
    ``I(x, f)``, the adversarial input ``x_adv`` and its attribution
    ``I(x_adv, f)``. The panels at odd positions (the 2nd and 4th) get a
    vertical colorbar.

    Args:
        imgs: Arrays to display, one per panel, each handed to ``imshow``.
            Images are paired with axes and titles through ``zip``, so
            entries beyond the fourth are ignored and, if fewer than four
            are given, the remaining axes are left untouched.

    Returns:
        None. The figure is created through pyplot; showing it is left to
        the caller.
    """
    titles = (
        r'$\boldsymbol{x}$',
        r'$\mathcal{I}(\boldsymbol{x},f)$',
        r'$\boldsymbol{x}_{\text{adv}}$',
        r'$\mathcal{I}(\boldsymbol{x}_{\text{adv}}, f)$',
    )

    # Fixed 1x4 grid with manual spacing between the panels.
    fig, axs = plt.subplots(1, 4, figsize=(20, 10), constrained_layout=False)
    fig.subplots_adjust(wspace=0.2, hspace=0.2)

    for position, (axis, image, heading) in enumerate(zip(axs, imgs, titles)):
        # Thick black frame around the panel.
        frame = axis.patch
        frame.set_edgecolor('black')
        frame.set_linewidth(3)

        axis.set_title(heading, fontsize=26, pad=12)
        mappable = axis.imshow(image, cmap="terrain")

        # Create a divider for each image to ensure colorbars don't affect
        # sizing: every panel gets the same slim side axes, used or not.
        side_axes = make_axes_locatable(axis).append_axes("right", size="3%", pad=0.15)

        wants_colorbar = position % 2 == 1
        if not wants_colorbar:
            # Even positions: discard the side axes again.
            side_axes.remove()
        else:
            colorbar = fig.colorbar(mappable, cax=side_axes, orientation='vertical')
            colorbar.ax.tick_params(labelsize=16)

        # Images carry no meaningful axis ticks.
        axis.set_xticks([])
        axis.set_yticks([])

In [None]:
def plots_dataset(data_name:str, n_classes:int, batch_size:int,algorithm:str):
    """Show one four-panel figure per test sample of a dataset.

    Attributions and the matching test split are obtained from ``load_data``.
    For every ``(name, attribution)`` entry of the attributions, the sample
    image and its preprocessed attribution map are passed to ``plot`` and the
    resulting figure is shown.

    Args:
        data_name: Dataset identifier forwarded to ``load_data``.
        n_classes: Number of classes forwarded to ``load_data``.
        batch_size: Batch size forwarded to ``load_data``.
        algorithm: Attribution algorithm name forwarded to ``load_data``.

    Returns:
        None. Figures are displayed with ``plt.show()``.
    """
    attributions_clean, data_clean = load_data(data_name, n_classes, batch_size, algorithm)

    for name, attribution in attributions_clean.items():
        # First element of the dataset item is the image tensor; the transpose
        # moves its first axis last (presumably channel-first -> channel-last
        # so that imshow can render it).
        image = data_clean.__getitem__from_name__(name)[0].numpy().transpose(1, 2, 0)

        # The three attribution panels use identical preprocessing arguments,
        # so the map is computed once and reused instead of three times.
        attribution_map = attributions_preprocessing(
            attribution, "l1", None, 1, 'quantil_local', None, None
        ).detach().cpu().numpy()

        # NOTE(review): ``plot`` titles panels 3 and 4 as the adversarial input
        # and its attribution, but only the clean attribution is available
        # here, so it is shown in all three slots - confirm whether
        # adversarial data should be plotted instead.
        plot([image, attribution_map, attribution_map, attribution_map])

        plt.show()

In [None]:
# Render the attribution figures for the "dermamnist" dataset (7 classes)
# using the SmoothGrad algorithm, explaining 50 samples per batch.
plots_dataset(
    data_name="dermamnist",
    n_classes=7,
    batch_size=50,
    algorithm='SmoothGrad',
)