## Init notebook

Run the cells in this section once per kernel session.

## Preliminaries

In [1]:
import quantus

  from .autonotebook import tqdm as notebook_tqdm
2023-09-13 17:30:10.435382: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:


# Import libraries.
from IPython.display import clear_output
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import matplotlib.font_manager as font_manager
import warnings
import torch
import captum
import torchvision
import wandb
from tqdm import tqdm
import os
import json

from zennit import attribution as zattr
from zennit import image as zimage
from zennit import composites as zcomp

from models import models
from data import dataloaders, datasets, transforms
from attribution import zennit_utils as zutils
from utils import arguments as argument_utils
from main import *

# Pick the first CUDA device if one is available, otherwise fall back to the CPU.
# (Previously a bare `except` hid the CPU case: get_device_name(0) raised and nothing was printed.)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    print('Using device:', torch.cuda.get_device_name(device))
else:
    print('Using device:', device)

/home/lweber/work-code/Quantus/quantus/__init__.py
Using device: NVIDIA TITAN RTX


## Download or Load Results

In [3]:
# Credentials: wandb.Api() reads WANDB_API_KEY from the environment or the key stored
# by `wandb login`. Never hardcode the key in the notebook.
# NOTE(review): an API key used to be committed in this cell -- revoke/rotate it.
wandb_projectname = "denoise-sanity-checks-2"
legend_str = {"top_down": "top-down", "bottom_up": "bottom-up"}
method_str = {"gradient": "Gradient", "lrp-epsilon": r"LRP-$\varepsilon$", "lrp-zplus": r"LRP-$z^+$", "guided-backprop": "Guided Backprop", "grad-cam": "GradCAM"}
jsonsavepath = "/media/lweber/f3ed2aae-a7bf-4a55-b50d-ea8fb534f1f5/mptc/eMPRT-sMPRT-scores-from-wandb.json"
figurepath = "/media/lweber/f3ed2aae-a7bf-4a55-b50d-ea8fb534f1f5/mptc/figures/raw-plots-sMPRT-eMPRT"
redownload_results = False

# Exclusion filter: a run is skipped if its config matches ANY key/value pair listed here.
download_filter = {}
#download_filter = {"xai_methodname": "grad-cam"}

# Get all results from wandb (cached as JSON; set redownload_results=True to refresh).
# `scores` is a list of (config, summary["scores"]) pairs.
if not os.path.exists(jsonsavepath) or redownload_results:
    print("Downloading Results from wandb...")
    scores = []
    api = wandb.Api()
    entity, project = "leanderweber", wandb_projectname  # set to your entity and project
    runs = api.runs(entity + "/" + project)
    for run in tqdm(runs, total=len(runs)):
        # Drop wandb-internal config entries (keys starting with '_').
        config = {k: v for k, v in run.config.items() if not k.startswith('_')}
        if any(config[s] == v for s, v in download_filter.items()):
            continue
        # ._json_dict gives a plain, JSON-serialisable dict and omits large files.
        summary = run.summary._json_dict
        if "scores" in summary:
            scores.append((config, summary["scores"]))

    os.makedirs(os.path.dirname(jsonsavepath), exist_ok=True)
    with open(jsonsavepath, "w") as jsonfile:
        json.dump(scores, jsonfile)

else:
    with open(jsonsavepath, "r") as jsonfile:
        # JSON stores the pairs as lists; normalise to tuples like the download branch.
        scores = [tuple(pair) for pair in json.load(jsonfile)]

Downloading Results from wandb...


375it [04:26,  1.41it/s]                         


## Plots



In [None]:
# Global figure style: serif Computer Modern text, CM math, larger font.
cmfont = font_manager.FontProperties(fname=mpl.get_data_path() + '/fonts/ttf/cmr10.ttf')
mpl.rcParams['font.family'] = 'serif'
mpl.rcParams['font.serif'] = cmfont.get_name()
mpl.rcParams['mathtext.fontset'] = 'cm'
mpl.rcParams['axes.unicode_minus'] = False
plt.rcParams.update({'font.size': 15})

# Config keys that split runs into separate figures ("between") and into
# separate line groups inside one figure ("within").
between_plot_filter = ["eval_layer_order", "model_name"]
within_plot_filter = ["xai_methodname"]

print("Setting up filters...")
between_plot_filters = []
within_plot_filters = []
for config, _ in scores:
    # Project each run's config onto the key sets; keep first-seen unique combinations.
    between_candidate = {key: config[key] for key in between_plot_filter}
    if between_candidate not in between_plot_filters:
        between_plot_filters.append(between_candidate)

    within_candidate = {key: config[key] for key in within_plot_filter}
    if within_candidate not in within_plot_filters:
        within_plot_filters.append(within_candidate)

In [None]:
# Define functions

def select_runs(runs, filter):
    """Keep the ``(config, result)`` pairs whose config matches every entry of ``filter``.

    Args:
        runs: iterable of 2-element ``(config, result)`` pairs; ``config`` is a mapping.
        filter: mapping of config key -> required value. An empty mapping keeps every run.

    Returns:
        A new list of ``(config, result)`` tuples, in input order.

    Raises:
        KeyError: if a run's config lacks one of the filter's keys.
    """
    selected = []
    for config, result in runs:
        checks = [config[key] == wanted for key, wanted in filter.items()]
        if all(checks):
            selected.append((config, result))
    return selected

def plot_mprt_lineplot(runs, title, within_plot_filters, savefilename):
    """Plot mean +/- std SSIM per layer, one line per (method filter, noise-draw count).

    Args:
        runs: list of ``(config, scores)`` pairs. ``scores`` is indexed by layer name;
            the layer order on the x axis is taken from the first run. Each
            ``scores[layer]`` is concatenated with ``+=``, so it is presumably a list
            of per-sample SSIM values -- TODO confirm.
        title: axes title.
        within_plot_filters: list of config filters; each filter gets its own colour.
        savefilename: path passed to ``fig.savefig`` (parent dirs are created).

    Raises:
        ValueError: if ``runs`` is empty (there are no layers to plot).
    """
    if not runs:
        raise ValueError(f"No runs to plot for '{savefilename}'.")

    fig, ax = plt.subplots(1, 1, figsize=(6, 6))

    ax.set_title(title)

    ax.set_ylabel("SSIM")
    ax.set_ylim([0.0, 1.1])
    ax.set_yticks([0, 0.5, 1.0])
    ax.set_yticklabels([0, 0.5, 1])

    layers = list(runs[0][1].keys())
    x_positions = list(range(len(layers)))
    ax.set_xlabel("Layers")
    ax.set_xticks(x_positions)
    ax.set_xticklabels(layers)

    std_alpha = 0.2
    # matplotlib.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap remains.
    palette = plt.get_cmap("tab10")(np.linspace(0, 1, 10))
    linewidth = 2.5

    for w, wpf in enumerate(within_plot_filters):
        color = palette[w % len(palette)]  # wrap around instead of IndexError beyond 10 groups
        wpf_runs = select_runs(runs, wpf)

        # Sorted so that line opacity decreases monotonically with the noise-draw count.
        noisedraw_values = sorted({c["xai_n_noisedraws"] for c, _ in wpf_runs})

        for n, noisedraws in enumerate(noisedraw_values):
            ndf_runs = select_runs(wpf_runs, {"xai_n_noisedraws": noisedraws})

            # Pool the scores of all matching runs, per layer.
            pooled = [[] for _ in layers]
            for _, r in ndf_runs:
                for l, lay in enumerate(layers):
                    pooled[l] += r[lay]

            means = np.array([np.mean(values) for values in pooled])
            stds = np.array([np.std(values) for values in pooled])

            methodname = ndf_runs[-1][0]["xai_methodname"]
            ax.plot(x_positions, means, alpha=1.0 * (0.8 ** n), linewidth=linewidth, marker=".", color=color,
                    label=r"{} with N={}".format(method_str[methodname], noisedraws))
            ax.fill_between(x_positions, means + stds, means - stds, facecolor=color, alpha=std_alpha)

    ax.legend()
    ax.grid(True)
    fig.tight_layout()

    os.makedirs(os.path.dirname(savefilename) or ".", exist_ok=True)
    fig.savefig(savefilename)
    plt.show()
    plt.close(fig)  # avoid accumulating open figures when called in a loop

def plot_twineplot():
    """Placeholder for the sMPRT vs. eMPRT twin plot.

    TODO: not implemented yet; currently does nothing and returns ``None``.
    """
    return None

### sMPRT - Line Plots

In [None]:
selection_filters = [{"eval_metricname": "smprt"}]
plotted_xai_n_noisedraws = [1, 300]

# One figure per (selection filter, between-plot filter) combination.
for selection_filter in selection_filters:
    selected_runs = [
        (c, r)
        for c, r in select_runs(scores, selection_filter)
        if c["xai_n_noisedraws"] in plotted_xai_n_noisedraws
    ]
    for bpf in between_plot_filters:
        bpf_runs = select_runs(selected_runs, bpf)
        if not bpf_runs:
            # Not every model / layer-order combination has runs for this selection;
            # plotting an empty group would crash the plot function.
            print(f"Skipping {bpf}: no matching runs.")
            continue

        print(bpf)
        # e.g. "smprt-lineplot-imagenet-<layer order>-<model name>"
        fname = "-".join(["smprt-lineplot-imagenet", *(str(v) for v in bpf.values())])
        filepath = os.path.join(figurepath, fname)
        plot_mprt_lineplot(bpf_runs, "sMPRT", within_plot_filters, filepath)


### eMPRT - Line Plots

### sMPRT vs. eMPRT - Twin Plots

In [None]:
# Plot how sMPRT SSIM changes with the number of noise draws, for the first,
# middle and last layer of every model / layer-order combination.
between_plot_filter = ["eval_layer_order", "model_name"]
within_plot_filter = ["xai_methodname"]
x_axis_attribute = "xai_n_noisedraws"
twine_savedir = "/media/lweber/f3ed2aae-a7bf-4a55-b50d-ea8fb534f1f5/mptc/figures/smprt-raw"

# NOTE(review): this cell used to iterate over an undefined `res`; it now uses the
# sMPRT subset of `scores` (matching the plot title and save directory) -- confirm.
twine_runs = select_runs(scores, {"eval_metricname": "smprt"})

# Unique filter combinations for this cell only (do not clobber the global
# between_plot_filters / within_plot_filters used by other cells).
print("Setting up filters...")
twine_between_filters = []
twine_within_filters = []
for config, _ in twine_runs:
    candidate = {key: config[key] for key in between_plot_filter}
    if candidate not in twine_between_filters:
        twine_between_filters.append(candidate)
    candidate = {key: config[key] for key in within_plot_filter}
    if candidate not in twine_within_filters:
        twine_within_filters.append(candidate)

print(len(twine_runs))

alphas = [1.0, 0.75, 0.5]
markers = [".", "d", "D"]
std_alpha = 0.2
linewidth = 2.5
# matplotlib.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap remains.
palette = plt.get_cmap("tab10")(np.linspace(0, 1, 10))

for bpf in twine_between_filters:
    bpf_runs = select_runs(twine_runs, bpf)

    # Font / rcParams are configured once in the "General Plot Config" cell.
    fig, ax = plt.subplots(1, 1, figsize=(6, 6))
    ax.set_title(f"S-MPRT Variation over #Noise Draws (ImageNet, {bpf['model_name']}, {legend_str[bpf['eval_layer_order']]})")

    ax.set_ylabel("SSIM")
    ax.set_ylim([0.0, 1.1])
    ax.set_yticks([0, 0.5, 1.0])
    ax.set_yticklabels([0, 0.5, 1])

    ax.set_xlabel("#Noise Draws")
    ax.set_xlim([0, 300])
    ax.set_xticks([1, 50, 100, 200, 300])
    ax.set_xticklabels([1, 50, 100, 200, 300])

    # First, middle and last layer, in the layer order stored in the first run.
    n_layers = len(bpf_runs[0][1])
    layer_ids = [0, n_layers // 2, n_layers - 1]

    for w, wpf in enumerate(twine_within_filters):
        color = palette[w % len(palette)]
        wpf_runs = select_runs(bpf_runs, wpf)
        if not wpf_runs:
            continue  # this method has no runs for this model / layer order

        x_axis_attribute_vals = sorted({c[x_axis_attribute] for c, _ in wpf_runs})

        for l, layer_id in enumerate(layer_ids):
            # Pool all per-run scores of this layer for each x-axis value.
            runs_to_plot = {p: [] for p in x_axis_attribute_vals}
            for c, m in wpf_runs:
                # `scores` entries are keyed by layer name; index them positionally.
                layer_scores = list(m.values())[layer_id] if isinstance(m, dict) else m[layer_id]
                runs_to_plot[c[x_axis_attribute]].extend(np.ravel(layer_scores).tolist())

            means = np.array([np.mean(v) for v in runs_to_plot.values()])
            stds = np.array([np.std(v) for v in runs_to_plot.values()])
            delta = means[-1] - means[0]

            ax.plot(x_axis_attribute_vals, means, alpha=alphas[l], linewidth=linewidth, marker=markers[l], color=color,
                    label=r"{} Layer {}, $\Delta=${:.2f}".format(method_str[wpf["xai_methodname"]], layer_id, delta))
            ax.fill_between(x_axis_attribute_vals, means + stds, means - stds, facecolor=color, alpha=std_alpha)

    ax.legend()
    ax.grid(True)
    fig.tight_layout()

    os.makedirs(twine_savedir, exist_ok=True)
    savefile = os.path.join(twine_savedir, "{}-{}-{}.svg".format("imagenet", bpf['model_name'], legend_str[bpf['eval_layer_order']]))
    fig.savefig(savefile)
    plt.show()
    plt.close(fig)