Der Code basiert auf den Tutorials unter https://github.com/rachtibat/zennit-crp/tree/master/tutorials und wählt für jedes ausgewählte Feature die 8 Bilder mit der höchsten Relevanz aus.

Die Features habe ich mithilfe des JS-Skripts [getRel.mjs](getRel.mjs) ausgewählt.

Dieses Skript nutzt die Conda-Umgebung [crp](crp.yml).

In [None]:
import torch
from torchvision.models.vgg import vgg16_bn
import torchvision.models as models
import torchvision.transforms as T
from PIL import Image
from zennit.canonizers import SequentialMergeBatchNorm
from zennit.composites import EpsilonPlusFlat

import os

# Pick the compute device: Apple MPS first, then CUDA, then CPU.
# Only tested on MPS; if it does not run on NVIDIA CUDA, force the CPU by
# uncommenting the override line below.
device = torch.device(
    "mps" if torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available()
    else "cpu"
)
# device = torch.device("cpu")
print("Device:", device)

# Pin the weight set explicitly. The original expression
# `IMAGENET1K_V1.DEFAULT` reached the DEFAULT alias through another enum
# member, which hides which weights are actually loaded. The hand-picked
# channel ids used later depend on these exact weights.
weights = models.VGG16_BN_Weights.IMAGENET1K_V1
model = vgg16_bn(weights=weights).to(device)
model.eval()

# Merge BatchNorm into the preceding layers so that the EpsilonPlusFlat LRP
# rules are applied to the canonized network.
canonizers = [SequentialMergeBatchNorm()]
composite = EpsilonPlusFlat(canonizers)

In [None]:
import torchvision
from crp.concepts import ChannelConcept
from crp.helper import get_layer_names
from crp.attribution import CondAttribution
from crp.visualization import FeatureVisualization
from VGG16_ImageNet.download_imagenet import download

# Treat the output channels of every Conv2d and Linear layer as concepts.
cc = ChannelConcept()

layer_names = get_layer_names(model, [torch.nn.Conv2d, torch.nn.Linear])
layer_map = {layer: cc for layer in layer_names}

attribution = CondAttribution(model)

# Spatial preprocessing only. The ImageNet normalization is kept separate and
# is passed to FeatureVisualization as `preprocess_fn` instead, so the dataset
# itself yields unnormalized ToTensor() output.
transform = T.Compose([T.Resize(256), T.CenterCrop(224), T.ToTensor()])
preprocessing = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

data_path = "ImageNet_data"

# Fetch the dataset only when it is not available locally. (A `data_path is
# None` test can never fire here because data_path is always set above.)
# NOTE(review): download() comes from VGG16_ImageNet.download_imagenet; its
# exact behaviour (what it fetches, where) is not visible here - confirm.
if not os.path.isdir(data_path):
    download(data_path)

# apply no normalization here! (see `preprocessing` above)
imagenet_data = torchvision.datasets.ImageNet(data_path, transform=transform, split="val")

In [None]:
from typing import Dict, Any, Tuple, Iterable
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
def plot_grid_save(ref_c: Dict[int, Any], filename: str, cmap_dim=1, cmap="bwr",
                   vmin=None, vmax=None, symmetric=True, resize=None, padding=True, figsize=(6, 6)):
    """Render a grid of reference images and save it to a PNG file.

    Presumably adapted from ``crp.image.plot_grid``. Instead of displaying
    the figure, it writes the figure to *filename* and closes it.

    Args:
        ref_c: Mapping from row label (the concept id in this notebook) to
            either an iterable of images (one sub-row per key) or a tuple of
            such iterables (several sub-rows per key). Images are converted
            with ``imgify`` (torch.Tensor, np.ndarray or PIL Image). The
            number of columns is taken from the first entry. Shorter rows
            leave empty cells; extra images are ignored.
        filename: Output path. The figure is always encoded as PNG.
        cmap_dim: Sub-row index (0 or 1) that is rendered with *cmap*. All
            other sub-rows are rendered without a colormap.
        cmap, vmin, vmax, symmetric: Forwarded to ``imgify`` for the
            *cmap_dim* sub-row only.
        resize, padding: Forwarded to ``imgify`` for every image.
        figsize: Matplotlib figure size in inches.

    Raises:
        ValueError: If *ref_c* is empty, *cmap_dim* is not 0 or 1, or the
            values of *ref_c* are neither iterables nor tuples of iterables.
    """
    # Local import so the helper does not depend on `imgify` having been
    # imported globally by a later notebook cell.
    from crp.image import imgify

    if not ref_c:
        raise ValueError("'ref_c' must not be empty.")
    if cmap_dim not in (0, 1):
        raise ValueError("'cmap_dim' must be 0 or 1.")

    keys = list(ref_c)
    nrows = len(keys)
    value = ref_c[keys[0]]

    if isinstance(value, tuple) and isinstance(value[0], Iterable):
        nsubrows = len(value)
        ncols = len(value[0])
    elif isinstance(value, Iterable):
        nsubrows = 1
        ncols = len(value)
    else:
        raise ValueError("'ref_c' dictionary must contain an iterable of torch.Tensor, np.ndarray or PIL Image or a tuple of thereof.")

    fig = plt.figure(figsize=figsize)
    try:
        outer = gridspec.GridSpec(nrows, 1, wspace=0, hspace=0.2)

        for i, key in enumerate(keys):
            inner = gridspec.GridSpecFromSubplotSpec(nsubrows, ncols, subplot_spec=outer[i],
                                                     wspace=0, hspace=0.1)

            for sr in range(nsubrows):
                img_list = ref_c[key][sr] if nsubrows > 1 else ref_c[key]

                # zip() stops at the shorter side: rows with fewer images than
                # the first row leave empty cells instead of raising IndexError.
                for c, src in zip(range(ncols), img_list):
                    if sr == cmap_dim:
                        img = imgify(src, cmap=cmap, vmin=vmin, vmax=vmax,
                                     symmetric=symmetric, resize=resize, padding=padding)
                    else:
                        img = imgify(src, resize=resize, padding=padding)

                    ax = fig.add_subplot(inner[sr, c])
                    ax.imshow(img)
                    ax.set_xticks([])
                    ax.set_yticks([])

                    # Label each key once, on its top-left cell.
                    if sr == 0 and c == 0:
                        ax.set_ylabel(key)

        outer.tight_layout(fig)
        fig.savefig(filename, format="png", bbox_inches="tight")
    finally:
        # Always release the figure, even if rendering or saving fails, so
        # repeated calls in a loop do not accumulate open figures.
        plt.close(fig)

In [None]:
# Directory from which FeatureVisualization reads and writes its reference
# data. NOTE(review): get_max_reference() below seems to expect precomputed
# results here (no fv.run call is visible in this file) - confirm.
fv_path = "VGG16_ImageNet"
fv = FeatureVisualization(
    attribution,
    imagenet_data,
    layer_map,
    # The ImageNet normalization is supplied here rather than in the
    # dataset transform.
    preprocess_fn=preprocessing,
    path=fv_path,
)

In [None]:
%matplotlib inline
from crp.image import plot_grid
from crp.image import imgify
from crp.image import vis_opaque_img

# Channel (concept) ids of `target_layer` to visualize, selected with getRel.mjs.
rel = [327, 71, 239, 79, 340, 465, 19, 146, 113, 261, 321, 57, 107, 282, 117, 476, 168, 194, 60, 484, 384, 325, 351, 326, 16, 504, 115, 30, 210, 373, 492, 307, 25, 153, 195, 294, 152, 260, 42, 469, 35, 161, 89, 316, 401, 256, 507, 306, 361, 12, 451, 39, 202, 211, 277, 324, 188, 236, 203, 328, 292, 354, 131, 487]

target_layer = "features.40"  # must be one of the names in layer_map
n_refs = 8                    # number of top reference images per concept

destinationFolder = "results_attr_show"
# exist_ok=True avoids the check-then-create race of a separate exists() test.
os.makedirs(destinationFolder, exist_ok=True)

for r in rel:
    # References ranked by relevance for concept r, restricted to the index
    # range (0, n_refs), and rendered with vis_opaque_img.
    ref_c = fv.get_max_reference([r], target_layer, "relevance", (0, n_refs),
                                 composite=composite, plot_fn=vis_opaque_img)

    plot_grid_save(ref_c, os.path.join(destinationFolder, f"{r}.png"),
                   cmap="bwr", symmetric=True, figsize=(6, 1))