Dieser Code berechnet die LRP-Heatmaps sowie die Heatmaps der relevantesten Features.<br>
Außerdem wird für die perturbierten Bilder die Differenz zum Originalbild berechnet und visuell dargestellt (die Differenzberechnung basiert auf den PNG-Bildern und nicht auf den rohen Tensordaten).<br>
Zusätzlich werden die relevantesten Features und deren Relevanz in einer TXT-Datei (`info.txt`) abgespeichert.

Dieses Skript nutzt die [crp](crp.yml)-Conda-Umgebung.

In [None]:
import torch
import torchvision.models as models
import torchvision.transforms as T
import numpy as np
from PIL import Image

import json
import urllib.request

import matplotlib.pyplot as plt

from pathlib import Path

# Prefer Apple's MPS backend, then CUDA, and fall back to the CPU.
# The code has only been tested on MPS; if it does not run on NVIDIA CUDA,
# force the CPU with the commented-out line below.
device = torch.device(
    "mps" if torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available()
    else "cpu"
)
# device = torch.device("cpu")
print("Device:", device)

# VGG16 with batch normalisation and its ImageNet-1k weights. The class labels
# are read from the same weights object, so labels and model always match.
weights = models.VGG16_BN_Weights.IMAGENET1K_V1
imagenet_labels = weights.meta["categories"]

model = models.vgg16_bn(weights=weights).to(device)
model.eval()

# Spatial preprocessing (PIL -> PIL): resize the shorter side to 256 px,
# then take a 224x224 center crop.
transform01 = T.Compose([T.Resize(256), T.CenterCrop(224)])
# Conversion to a tensor followed by ImageNet mean/std normalisation.
transform02 = T.Compose([
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

In [None]:
from zennit.composites import EpsilonPlusFlat
from zennit.canonizers import SequentialMergeBatchNorm
from crp.attribution import CondAttribution
from crp.concepts import ChannelConcept
from crp.helper import get_layer_names

from crp.image import imgify

# LRP rule composite from zennit. The canonizer list is passed by keyword:
# in zennit releases where ``epsilon`` is the first positional parameter of
# EpsilonPlusFlat, a positional list would be bound to ``epsilon`` instead of
# ``canonizers``. SequentialMergeBatchNorm handles the BatchNorm layers of
# vgg16_bn before the LRP rules are applied (see the zennit documentation).
composite = EpsilonPlusFlat(canonizers=[SequentialMergeBatchNorm()])
# Conditional (CRP) attribution wrapper around the model; no_param_grad=True
# presumably avoids computing gradients for the model parameters - confirm
# against the CRP documentation.
attribution = CondAttribution(model, no_param_grad=True)
# Treats each channel of a layer as one concept (used to rank channel relevance).
cc = ChannelConcept()

In [None]:
def list_relevance(concept_ids, rel_values):
    """Format concept ids and their relevance as text, one line per concept.

    Each line has the form ``"<id>: <value * 100 with 3 decimals>%"`` and is
    terminated by a newline.

    Args:
        concept_ids: Tensor of concept (channel) indices.
        rel_values: Tensor of relevance values paired element-wise with
            ``concept_ids``; each value is multiplied by 100 for display.

    Returns:
        The formatted text (an empty string when there are no concepts).
    """
    ids = concept_ids.cpu().tolist()
    values = rel_values.cpu().tolist()
    # str.join builds the text in one pass instead of repeated concatenation.
    return "".join(f"{cid}: {val * 100:.3f}%\n" for cid, val in zip(ids, values))

In [None]:
def load_img(img_src):
    """Load an image file and apply the spatial preprocessing ``transform01``.

    Args:
        img_src: Path or file object accepted by ``PIL.Image.open``.

    Returns:
        The resized and center-cropped RGB PIL image. Tensor conversion and
        normalisation (``transform02``) are left to the caller.
    """
    # The context manager releases the file handle that PIL keeps open for
    # lazily loaded images. convert() and transform01 produce new images, so
    # the result stays valid after the file is closed. Converting to RGB keeps
    # grayscale/RGBA/palette inputs compatible with the 3-channel ImageNet
    # normalisation.
    with Image.open(img_src) as image:
        return transform01(image.convert("RGB"))

In [None]:
def crp(sample, layer="features.40", k=6):
    """Explain the model's prediction for ``sample`` with LRP / CRP.

    Args:
        sample: Input tensor for ``model`` (on ``device``). It is flagged with
            ``requires_grad = True`` in place. Only batch entry 0 is used for
            the concept ranking, so presumably a batch of one - TODO confirm.
        layer: Name of the layer whose channels are ranked as concepts. The
            default ``"features.40"`` is the layer this notebook was written
            for (noted as having 512 channel concepts).
        k: Number of most relevant concepts to return (default 6).

    Returns:
        dict with
            'prediction':  predicted class index (int),
            'total':       heatmap explaining the predicted class,
            'topk':        one conditional heatmap per selected concept,
            'rel_values':  relevance share of each selected concept
                           (abs-normalised, so x100 gives percent),
            'concept_ids': channel indices of the selected concepts,
                           ordered by decreasing relevance.
    """
    sample.requires_grad = True

    # Plain forward pass only to find the predicted class; no graph needed.
    with torch.no_grad():
        output = model(sample)
    prediction = torch.argmax(output, dim=1).item()

    conditions = [{"y": [prediction]}]
    # A single attribution pass yields both the full heatmap and the relevance
    # recorded at `layer`; only that one layer needs to be recorded.
    attr = attribution(sample, conditions, composite, record_layer=[layer])
    total_heatmap = attr.heatmap

    # Per-channel relevance of `layer`, normalised by the sum of absolute values.
    rel_c = cc.attribute(attr.relevances[layer], abs_norm=True)

    # The k most relevant concepts and their share of the final classification.
    rel_values, concept_ids = torch.topk(rel_c[0], k)

    # One conditional heatmap per concept: relevance through `layer` is
    # restricted to that channel while still explaining the predicted class.
    conditions = [{layer: [cid], "y": [prediction]} for cid in concept_ids.tolist()]
    attr = attribution(sample, conditions, composite)

    return {
        'prediction': prediction,
        'total': total_heatmap,
        'topk': attr.heatmap,
        'rel_values': rel_values,
        'concept_ids': concept_ids,
    }

In [None]:
def display(src):
    """Compute the full LRP heatmap for a saved input tensor and render it.

    Args:
        src: Path to a tensor saved with ``torch.save`` (model-ready input).

    Returns:
        PIL image of the heatmap, rendered by ``imgify`` with ``symmetric=True``.
    """
    # Load onto the CPU first so tensors saved on another device (CUDA/MPS)
    # can be opened here, then move them to the active device.
    sample = torch.load(src, map_location="cpu").to(device).detach()
    heatmap = crp(sample)['total'].squeeze().cpu()
    return imgify(heatmap, symmetric=True)

def _load_rgb_float(path):
    """Load an image file as an RGB float32 array scaled to [0, 1]."""
    # The context manager closes the file handle PIL keeps open lazily.
    with Image.open(path) as image:
        return np.array(image.convert("RGB")).astype(np.float32) / 255.0


def show_diff(img1_path, img2_path):
    """Visualise the per-pixel absolute difference between two image files.

    The comparison works on the saved (8-bit) image files, not on raw tensors.
    The absolute difference is rescaled so that its maximum becomes 1 and is
    then rendered with ``imgify``.

    Args:
        img1_path: Path to the first image (e.g. the original).
        img2_path: Path to the second image (e.g. the perturbed version).

    Returns:
        A PIL image of the normalised difference. If the two images are
        identical, an all-black RGB image of the same size is returned instead
        of dividing by zero.

    Raises:
        ValueError: If the two images do not have the same size.
    """
    img1 = _load_rgb_float(img1_path)
    img2 = _load_rgb_float(img2_path)
    if img1.shape != img2.shape:
        raise ValueError(
            f"Image sizes differ: {img1_path} {img1.shape[:2]} vs {img2_path} {img2.shape[:2]}"
        )

    diff = np.abs(img1 - img2)
    peak = diff.max()
    if peak == 0:
        # No pixel changed: avoid 0/0 (NaN) and return a black image.
        return Image.fromarray(np.zeros(diff.shape, dtype=np.uint8))
    return imgify(diff / peak)

In [None]:
def saveTop(srcImg, desFolder):
    """Save the top-k concept heatmaps and a text summary for one input tensor.

    Writes two files into ``desFolder``:
      * ``topk.png``: the conditional heatmaps of the selected concepts in one row;
      * ``info.txt``: the predicted label and class index, the header line
        ``Relevant Concepts:``, then one ``<id>: <relevance>%`` line per concept.

    Args:
        srcImg: Path to a tensor saved with ``torch.save``.
        desFolder: Existing directory (str or Path) that receives the outputs.
    """
    # Load onto the CPU first so tensors saved on another device can be opened.
    sample = torch.load(srcImg, map_location="cpu").to(device).detach()
    res = crp(sample)
    out_dir = Path(desFolder)
    concept_ids = res['concept_ids']
    rel_values = res['rel_values']

    img = imgify(res['topk'], symmetric=True, grid=(1, len(concept_ids)))
    img.save(out_dir / "topk.png")

    pred = res['prediction']
    info = (
        f"{imagenet_labels[pred]} ({pred})\n"
        "Relevant Concepts:\n"
        + list_relevance(concept_ids, rel_values)
    )
    # Explicit UTF-8 so the label text is written identically on every platform.
    (out_dir / "info.txt").write_text(info, encoding="utf-8")

In [None]:
def calcAll(dirPath):
    """Process one result directory: the original input and every perturbation.

    Layout used by the paths below:
        dirPath/original.pt, dirPath/original.png
        dirPath/<sub-folder>/perturbed.pt, dirPath/<sub-folder>/perturbed.png

    Outputs:
        * ``saveTop`` results for the original are written into ``dirPath``;
        * for each sub-folder, ``diff.png`` holds the pixel difference between
          ``original.png`` and that folder's ``perturbed.png``, and the
          ``saveTop`` results for ``perturbed.pt`` are written into the folder.
    """
    root = Path(dirPath)
    original_png = root / "original.png"
    saveTop(root / "original.pt", root)

    # Sorted for a deterministic processing order across file systems.
    for folder in sorted(p for p in root.iterdir() if p.is_dir()):
        show_diff(original_png, folder / "perturbed.png").save(folder / "diff.png")
        saveTop(folder / "perturbed.pt", folder)

In [None]:
def saveAll(dirPath):
    """Render the full LRP heatmap for every ``.pt`` tensor below ``dirPath``.

    The search is recursive. Each heatmap is saved next to its tensor as
    ``<tensor stem>_hmap.png``.
    """
    tensor_files = Path(dirPath).rglob("*.pt")
    for tensor_file in tensor_files:
        heatmap_img = display(str(tensor_file))
        heatmap_img.save(f"{tensor_file.parent}/{tensor_file.stem}_hmap.png")

In [None]:
# Full heatmaps first, then diffs and top-k concepts, for the beagle3 results.
for step in (saveAll, calcAll):
    step("results/beagle3.png")

In [None]:
# Result directories (named after their source images) processed in one batch.
imgNames = [
    "beagle1.jpg",
    "dog.jpg",
    "eel1.jpg",
    "espresso.jpg",
    "goldfish.jpg",
    "lizard.jpg",
    "plane.png",
    "snail.png",
    "warplane.jpg",
]

for img_name in imgNames:
    result_path = f"results/{img_name}"
    print(result_path)
    # Heatmaps for every tensor, then diffs and top-k concept summaries.
    saveAll(result_path)
    calcAll(result_path)

In [None]:
import IPython.display

# Optional notification sound once the batch has finished. The audio file is
# not part of the repository, so it is only played when it is present.
# Passing filename= explicitly avoids the string being treated as raw audio data.
_audio_file = Path("audio.mp3")
if _audio_file.exists():
    IPython.display.display(IPython.display.Audio(filename=str(_audio_file), autoplay=True))
else:
    print(f"{_audio_file} not found - skipping notification sound")