In [1]:
## This notebook shows the functionality of the LocalizationTest.


In [2]:
import sys
sys.path.append("/home/motzkus/work/xai-quantification-toolbox")

In [3]:
from xai_quantification_toolbox.measures.localization_test import *
from xai_quantification_toolbox.loaders.model_interface import *
#from xai_quantification_toolbox.quantifier.base import *
#from ....NoiseGrad.src.models import ResNet18



#!pip install captum
#!pip install opencv-python

import torch
import torchvision
from torchvision import transforms
import numpy as np

import collections
from xml.etree import ElementTree
import xmltodict
import cv2
#import h5py
#from tqdm import tqdm
from captum.attr import Saliency, IntegratedGradients
#from pathlib import Path
import warnings

# Retrieve source code.
#from drive.MyDrive.Projects.xai_quantification_toolbox import * #import xaiquantificationtoolbox

# Notebook settings.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
warnings.filterwarnings("ignore", category=UserWarning)
#%load_ext autoreload
#%autoreload 2



In [4]:
### Load model, data and attributions.

In [5]:
# Load pre-trained vgg16 model.
model = torchvision.models.vgg16(pretrained=False)
model.classifier[-1] = torch.nn.Linear(4096, 20)
model.to(device)
model.load_state_dict(torch.load("../../../xai_discriminability/models/pytorch/vgg16_voc/model.pt", map_location=device))
model.eval()

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1

In [6]:
class VOC2012Sample:
    """ Implements a pascal voc 2012 sample. """

    def __init__(self, datum, filename, label, one_hot_label, binary_mask):
        self.datum = datum
        self.filename = filename
        self.label = label
        self.one_hot_label = one_hot_label
        self.binary_mask = binary_mask

class VOC2012Dataset:
    """ Implements the pascal voc 2012 dataset. """

    def __init__(self, datapath, partition, classidx=None):
        """ Initialize pascal voc 2012 dataset. """
        #super().__init__(datapath, partition)
        self.datapath = datapath
        self.partition = partition
        self.samples = []

        self.cmap = ['aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow',
                     'diningtable', 'dog', 'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train',
                     'tvmonitor']
        if not classidx:
            self.classes = self.cmap

        else:
            self.classes = []
            for idx in classidx:
                self.classes.append(self.cmap[int(idx)])

        self.labels = []

        if not classidx:
            f = open(datapath + "ImageSets/Main/" + partition + ".txt", "r")
        else:
            f = []
            for idx in classidx:
                with open(datapath + "ImageSets/Main/" + self.cmap[int(idx)] + "_" + partition + ".txt", "r") as classfile:
                    for line in classfile:
                        filename, in_class = [value for value in line.split(" ") if value]
                        if in_class.startswith("1") and (filename not in f):
                            f.append(filename)

        for line in f:
            if line.endswith("\n"):
                line = line[:-1]
            # get image filepath
            self.samples.append(datapath + "JPEGImages/" + line + ".jpg")

            # parse annotations
            tree = ElementTree.parse(datapath + "Annotations/" + line + ".xml")
            xml_data = tree.getroot()
            xmlstr = ElementTree.tostring(xml_data, encoding="utf-8", method="xml")
            annotation = dict(xmltodict.parse(xmlstr))['annotation']

            objects = annotation["object"]

            if type(objects) != list:
                self.labels.append([objects['name']])

            else:
                label = []
                for object in annotation['object']:
                    if type(object) == collections.OrderedDict:
                        if object['name'] not in label:
                            label.append(object['name'])

                self.labels.append(label)
        
        print("{} samples loaded".format(len(self.samples)))

    def __getitem__(self, index):
        """ Get the datapoint at index. """

        filename = self.samples[index]
        label = self.labels[index]

        image = self.preprocess_image(filename)
        one_hot_label = self.preprocess_label(label)
        binary_mask = self.preprocess_binary_mask(filename)

        sample = VOC2012Sample(
            image,
            filename,
            label,
            one_hot_label,
            binary_mask
        )

        return sample

    def classname_to_idx(self, class_name):
        """ convert a classname to an index. """
        return self.cmap.index(class_name)

    def preprocess_image(self, image):

        read_image = cv2.imread(image, cv2.IMREAD_COLOR)
        image_resized = cv2.resize(read_image, (224, 224), interpolation=cv2.INTER_CUBIC)
        image_normalized = image_resized.astype(np.float32) / 127.5 - 1.0

        return image_normalized

    def preprocess_label(self, label):
        """ Convert label to one hot encoding. """
        one_hot_label = np.zeros(len(self.cmap))

        for classname in label:
            one_hot_label[self.cmap.index(classname)] = 1

        return one_hot_label

    def preprocess_binary_mask(self, filename):
        """ Get the bounding box as binary mask."""

        binary_mask = {}
        #filename = extract_filename(filename)
        filename = filename.split("/")[-1].split(".")[0]

        # parse annotations
        tree = ElementTree.parse(self.datapath + "Annotations/" + filename + ".xml")
        xml_data = tree.getroot()
        xmlstr = ElementTree.tostring(xml_data, encoding="utf-8", method="xml")
        annotation = dict(xmltodict.parse(xmlstr))['annotation']

        width = int(annotation["size"]["width"])
        height = int(annotation["size"]["height"])

        # iterate objects
        objects = annotation["object"]

        if type(objects) != list:
            # self.labels.append([objects['name']])
            mask = np.zeros((width, height), dtype=int)

            mask[int(objects['bndbox']['xmin']):int(objects['bndbox']['xmax']), int(objects['bndbox']['ymin']):int(objects['bndbox']['ymax'])] = 1

            binary_mask[objects['name']] = mask

        else:
            for object in annotation['object']:
                if type(object) == collections.OrderedDict:
                    if object['name'] in binary_mask.keys():
                        mask = binary_mask[object['name']]
                    else:
                        mask = np.zeros((width, height), dtype=np.uint8)

                    mask[int(object['bndbox']['xmin']):int(object['bndbox']['xmax']), int(object['bndbox']['ymin']):int(object['bndbox']['ymax'])] = 1

                    binary_mask[object['name']] = mask

        # preprocess binary masks to fit shape of image data
        for key in binary_mask.keys():
            # binary_mask[key] = tf.image.resize(binary_mask[key][:, :, np.newaxis], [224, 224]).numpy().astype(int)
            binary_mask[key] = cv2.resize(binary_mask[key], (224, 224), interpolation=cv2.INTER_NEAREST).astype(np.int)[:, :, np.newaxis]

        return binary_mask

In [7]:
classidx = 4

dataset = VOC2012Dataset("../../../data/VOC2012/", "val", classidx=[classidx])

341 samples loaded


In [8]:
data = [dataset[i] for i in range(10)]
#x_batch, y_batch, a_batch, s_batch
#a_batch = explain(model, x_batch.to(device), y_batch.to(device), explanation_func="Saliency")
x_batch = np.array([sample.datum for sample in data])
y_batch = np.array([sample.one_hot_label for sample in data])
s_batch = np.array([sample.binary_mask[dataset.cmap[classidx]] for sample in data])[:, :, :, 0]

In [9]:
a_batch_saliency = explain(model.to(device), x_batch, classidx, explanation_func="Saliency")

In [10]:
print(a_batch_saliency.shape)
a_batch_saliency = a_batch_saliency.cpu().numpy()

print(s_batch.shape)

torch.Size([10, 224, 224])
(10, 224, 224)


In [11]:
### Option 1. Evaluate the localization authenticity of attributions in one line of code.

In [12]:
# One-liner to measure pointing game results of provided attributions.
scores = PointingGame()(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency, s_batch=s_batch, device=device, **{"explanation_func": "Saliency"})
print(scores)

[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]


In [13]:
# One-liner to measure the attribution localization results of provided attributions.
scores = AttributionLocalization()(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency, s_batch=s_batch, device=device, **{"explanation_func": "Saliency"})
print(scores)

[0.061176937, 0.003850803, 0.018578487, 0.004330946, 0.03572835, 0.03354629, 0.0054374468, 0.021349076, 0.013994433, 0.010142415]


In [None]:
# One-liner to measure the top-k intersection results of provided attributions.
scores = TopKIntersection()(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency, s_batch=s_batch, device=device, **{"explanation_func": "Saliency"})
print(scores)

In [None]:
# One-liner to measure the relevance rank accuracy results of provided attributions.
scores = RelevanceRankAccuracy()(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency, s_batch=s_batch, device=device, **{"explanation_func": "Saliency"})
print(scores)

In [14]:
### Option 2. Evaluate the robustness of provided attributions while enjoying more functionality of Quantifier and Graph.

In [15]:
# Specify the tests.
tests = [RobustnessTest(**{
    "similarity_function": similarity_fn,
    "perturbation_function": gaussian_blur,
}) for similarity_fn in [lipschitz_constant, distance_euclidean, cosine]]

# Load attributions of another explanation method.
a_batch_intgrad = IntegratedGradients(model).attribute(inputs=x_batch, targets=y_batch)

# Init the quantifier object.
quantifier = Quantifier(measures=tests, io_object=h5py.File("PATH_TO_H5PY_FILE"), checkpoints=..)

# Score the tests.
results = [quantifier.score(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch)
           for a_batch in [a_batch_saliency, a_batch_intgrad]]

# Plot Saliency vs Integrated Gradients.
Plotting(results, show=False, path_to_save="PATH_TO_SAVE_FIGURE")




SyntaxError: invalid syntax (<ipython-input-15-40298831611c>, line 11)