In [None]:
# Automatically reload imported modules before executing each cell, so edits to
# the project sources take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import sys
# Make the project sources under ../src importable (presumably the home of the
# `evaluation`, `models` and `utils` packages imported below).
sys.path.append('../src')

In [None]:
import itertools
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from torchray.attribution.grad_cam import grad_cam
from torchray.attribution.guided_backprop import guided_backprop
from torchray.attribution.rise import rise
from torchray.utils import get_device

from evaluation.eval_utils.compute_masks import vedaldi2019
from evaluation.eval_utils.compute_scores import segmented_generator, get_model_and_data
from models.explainer_classifier import ExplainerClassifierModel
from utils.image_utils import get_unnormalized_image

In [None]:
import cv2

def show_cam_on_image(img: np.ndarray,
                      mask: np.ndarray,
                      use_rgb: bool = False,
                      colormap: int = cv2.COLORMAP_JET) -> np.ndarray:
    """Overlay a saliency mask on an image as a colour heatmap.

    The mask is inverted (``1 - mask``) before the colormap is applied, and the
    heatmap is in BGR channel order unless ``use_rgb`` is True.

    NOTE(review): with the default ``use_rgb=False`` and an RGB ``img`` displayed as
    RGB, the BGR/RGB swap combined with the mask inversion appears to approximately
    reproduce JET's "high = red" look -- confirm before changing either default.

    :param img: Base image in RGB or BGR order, float with values in [0, 1].
    :param mask: Saliency mask whose spatial size matches ``img``. Values are
        clipped to [0, 1] before colouring.
    :param use_rgb: Convert the heatmap to RGB; set to True if ``img`` is RGB and
        the result should use a true colormap.
    :param colormap: The OpenCV colormap to be used.
    :returns: uint8 image of ``heatmap + img``, rescaled so its maximum maps to 255.
    :raises ValueError: If ``img`` contains values greater than 1.
    """
    # Validate before doing any work.
    if np.max(img) > 1:
        raise ValueError(
            "The input image should be np.float32 in the range [0, 1]")

    # Clip before the uint8 cast: values outside [0, 1] (e.g. negative gradients)
    # would otherwise overflow 255 and wrap around.
    mask = np.clip(mask, 0, 1)
    heatmap = cv2.applyColorMap(np.uint8(255 * (1 - mask)), colormap)
    if use_rgb:
        heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    heatmap = np.float32(heatmap) / 255

    cam = heatmap + img
    cam = cam / np.max(cam)
    return np.uint8(255 * cam)

In [None]:
# Dataset: PASCAL VOC 2007 and its per-class segmentation annotations.
dataset_name = "VOC"
data_path = Path("../datasets/VOC2007")
path_segmentation = data_path / "VOCdevkit" / "VOC2007" / "SegmentationClass"

# Classifier under analysis and the folder holding precomputed masks.
model_name = "vgg16"
model_path = "../src/checkpoints/pretrained_classifiers/vgg16_voc.ckpt"
path_masks = Path("../src/evaluation/masks")

In [None]:
# The 20 PASCAL VOC object classes, listed in label-index order.
voc_classes = [
    "aeroplane", "bicycle", "bird", "boat", "bottle",
    "bus", "car", "cat", "chair", "cow",
    "diningtable", "dog", "horse", "motorbike", "person",
    "pottedplant", "sheep", "sofa", "train", "tvmonitor",
]
# Label index -> class name.
d_classes = dict(enumerate(voc_classes))

In [None]:
def open_segmentation_mask(segmentation_filename, size=(224, 224)):
    """Load a segmentation image as a grayscale float mask.

    The image is converted to 8-bit grayscale ('L'), resized with torchvision's
    ``Resize`` (default interpolation, bilinear) and scaled from [0, 255] to [0, 1].

    :param segmentation_filename: Path of the segmentation image file.
    :param size: Output ``(height, width)``; defaults to 224x224.
    :returns: 2-D float ndarray with values in [0, 1].
    """
    import torchvision.transforms as transforms
    from PIL import Image

    # NOTE(review): VOC SegmentationClass PNGs are presumably palette images, so the
    # 'L' conversion gives the luminance of the palette colours rather than class
    # indices, and bilinear resizing blends neighbouring labels. That is fine for
    # display, but the result should not be used as a label map.
    transformer = transforms.Compose([transforms.Resize(size)])
    # The context manager closes the file; convert() returns an independent image.
    with Image.open(segmentation_filename) as seg_image:
        mask = seg_image.convert('L')
    mask = transformer(mask)
    mask = np.array(mask) / 255.0
    # Optional binarisation, left disabled:
    # mask[mask > 0] = 1
    return mask


In [None]:
def plot_torch_image(x):
    """Draw a tensor image on the current matplotlib axes and hide the axes.

    The tensor is detached, moved to CPU and squeezed. A 2-D result is shown as-is;
    anything else is treated as channel-first and transposed to (H, W, C).
    The display range is fixed to [0, 1].
    """
    arr = np.squeeze(x.detach().cpu().numpy())
    shown = arr if arr.ndim == 2 else np.transpose(arr, (1, 2, 0))
    plt.imshow(shown, vmin=0, vmax=1)
    plt.axis("off")


In [None]:
# Load the pretrained classifier and the data module from the paths configured above.
model, data_module = get_model_and_data(data_path, dataset_name, model_name, model_path)

In [None]:
# Unnormalised image of every sample that has a segmentation, for visual browsing.
imgs = []
for x, category_id, filename in segmented_generator(data_module, path_segmentation):
    imgs.append(get_unnormalized_image(x))

In [None]:
# Preview the first 100 images. Each panel is titled with its index in `imgs`,
# which (assuming segmented_generator yields samples in a fixed order) is the
# sample index `c` used below to select an example.
plt.figure(figsize=(20, 20))
for idx in range(min(100, len(imgs))):
    plt.subplot(10, 10, idx + 1)
    plot_torch_image(imgs[idx])
    plt.title(str(idx), fontsize=8)

In [None]:
# Preview images 100-199, each panel titled with its index in `imgs`.
plt.figure(figsize=(20, 20))
start = 100
for pos, idx in enumerate(range(start, min(start + 100, len(imgs)))):
    plt.subplot(10, 10, pos + 1)
    plot_torch_image(imgs[idx])
    plt.title(str(idx), fontsize=8)

In [None]:
# Select one sample (0-based position in segmented_generator order) to explain.
# Curated candidate indices:
#   positive examples: 54, 145, 4, 77
#   negative examples (horses and people): 6, 165, 175, 67
#   mixed results: 21, 189
#   other candidates: 192, 53
c = 35

# Jump straight to the c-th sample. Unlike a manual countdown, this leaves `c`
# unchanged and fails loudly if `c` is past the end, instead of silently keeping
# the last sample.
try:
    x, category_id, filename = next(
        itertools.islice(segmented_generator(data_module, path_segmentation), c, None))
except StopIteration:
    raise IndexError(f"sample index {c} is beyond the end of the segmented dataset") from None


In [None]:
# Ground-truth class indices of the selected sample.
category_id

In [None]:
# Class names of the selected sample's labels.
list(map(d_classes.__getitem__, category_id))

In [None]:
seg_mask = open_segmentation_mask(path_segmentation / filename)

plt.figure(figsize=(10, 5))
plt.subplot(121)
# No colorbar here: imshow ignores the colormap for 3-channel data, so a bar would
# not describe the pixels.
plot_torch_image(get_unnormalized_image(x))
plt.title("image")
plt.subplot(122)
plt.imshow(seg_mask)
plt.title("segmentation mask")
plt.colorbar();

In [None]:
# Optional diagnostic: histogram of the segmentation-mask values (disabled).
# plt.hist(seg_mask.flatten(), 100);

In [None]:
# Classes compared across attribution methods, looked up by name -> [1, 7, 11, 13, 14].
classes = [voc_classes.index(name) for name in ("bicycle", "cat", "dog", "motorbike", "person")]


In [None]:
num_classes = len(voc_classes)
device = get_device()
explainer_checkpoint_path = "../src/checkpoints/explainer_vgg16_voc.ckpt"

# load_from_checkpoint is a classmethod that builds a fresh model from the
# checkpoint, so instantiating ExplainerClassifierModel first is wasted work (newer
# Lightning versions refuse the instance call outright). map_location lets a
# GPU-saved checkpoint load on whatever device is available.
explainer_classifier = ExplainerClassifierModel.load_from_checkpoint(
    explainer_checkpoint_path, num_classes=num_classes, map_location=device)
explainer = explainer_classifier.explainer

model = model.to(device)
explainer = explainer.to(device)
explainer.freeze()

In [None]:
# Move the selected sample to the same device as the models.
x = x.to(device)

In [None]:
# Per-class probabilities of the classifier (independent sigmoid per class).
# Call the module rather than .forward so registered hooks run. no_grad avoids
# building an autograd graph that would be discarded immediately.
with torch.no_grad():
    p_ci = model(x).sigmoid().cpu().numpy().squeeze()

In [None]:
# Explainer
# Explainer: one sigmoid-squashed mask per class, indexed by class along axis 0.
# Call the module (not .forward) so hooks run, and skip autograd bookkeeping.
with torch.no_grad():
    explainer_masks = explainer(x).sigmoid().cpu().numpy().squeeze()

In [None]:
# Grid: the original image in the first panel, then one panel per explainer mask.
# Titles read "index: class name: predicted probability"; a leading "*" marks
# classes in this sample's ground-truth labels. Expects the full per-class stack,
# i.e. run before explainer_masks is reduced to the selected classes.
true_classes = np.unique(category_id)

n_cols = 7
n_panels = len(explainer_masks) + 1  # +1 for the original image
n_rows = (n_panels + n_cols - 1) // n_cols
fig, axes = plt.subplots(n_rows, n_cols, figsize=(40, 20 * n_rows / 3), squeeze=False)

axes[0][0].imshow(np.transpose(get_unnormalized_image(x[0, ...].cpu()), (1, 2, 0)))
axes[0][0].axis("off")

for ci, mask in enumerate(explainer_masks):
    row, col = divmod(ci + 1, n_cols)
    ax = axes[row][col]
    ax.imshow(mask)
    ax.axis("off")
    marker = "*" if ci in true_classes else ""
    ax.set_title(f"{marker}{ci}: {d_classes[ci]}: {p_ci[ci]:.2f}", fontsize=12)

# Hide any unused trailing panels.
for ax in axes.flat[n_panels:]:
    ax.axis("off")


In [None]:
# Extremal perturbation (vedaldi2019): one mask per selected class, stacked on axis 0.
vedaldi_masks = np.array([
    vedaldi2019(model, x, cls).detach().cpu().numpy().squeeze()
    for cls in classes
])

In [None]:
# Grad-CAM (torchray) at the last layer of the classifier's feature extractor,
# resized to the input (resize=True), one map per selected class.
gradcam_masks = np.array([
    grad_cam(model, x, cls, saliency_layer=model.feature_extractor[-1], resize=True)
    .detach().cpu().numpy().squeeze()
    for cls in classes
])
# Min-max normalise jointly over all selected classes, so relative strength between
# classes is preserved. A constant map has zero range: leave it at 0 rather than
# dividing by zero and producing NaNs.
gradcam_masks = gradcam_masks - np.min(gradcam_masks)
peak = np.max(gradcam_masks)
if peak > 0:
    gradcam_masks = gradcam_masks / peak

In [None]:
# RISE: a single call returns saliency for every class (indexed along axis 0 after
# squeeze); keep the selected classes and min-max normalise each map independently.
# RISE is stochastic (random occlusion masks, assumed drawn from torch's global RNG),
# so seed it for reproducible maps.
torch.manual_seed(0)
segmentations = rise(model, x).detach().cpu().numpy().squeeze()
rise_masks = []
for cls in classes:
    class_mask = segmentations[cls] - np.amin(segmentations[cls])
    peak = np.amax(class_mask)
    if peak > 0:  # constant map: keep zeros instead of dividing by zero
        class_mask = class_mask / peak
    rise_masks.append(class_mask)
# Stack like the other methods' masks.
rise_masks = np.array(rise_masks)

In [None]:
# Guided backprop: raw saliency per selected class, resized to the input.
# No normalisation is applied here; values are unbounded and may be negative.
gbackprop_masks = np.array([
    guided_backprop(model, x, cls, resize=True).detach().cpu().numpy().squeeze()
    for cls in classes
])


In [None]:
# Ground-truth label indices of the selected sample next to their class names.
(category_id, [d_classes[label] for label in category_id])

In [None]:
# Keep only the masks of the selected classes (rows follow the order of `classes`).
# Guarded so that re-running this cell does not index the already-reduced array,
# which would raise IndexError or pick the wrong rows.
if len(explainer_masks) == num_classes:
    explainer_masks = np.take(explainer_masks, classes, axis=0)

In [None]:
im_ori = np.transpose(get_unnormalized_image(x).detach().cpu().numpy().squeeze(), (1, 2, 0))
n_classes = len(classes)

# One row per attribution method, one column per selected class.
method_masks = {
    "ours": explainer_masks,
    "Grad-CAM": gradcam_masks,
    "RISE": rise_masks,
    "extremal perturbation": vedaldi_masks,
}
all_masks = list(method_masks.values())
n_methods = len(all_masks)

plt.figure(figsize=(20, 10))
for i, (method, masks) in enumerate(method_masks.items()):
    for j, c in enumerate(classes):
        plt.subplot(n_methods, n_classes, i * n_classes + j + 1)
        # show_cam_on_image returns a 3-channel uint8 image; vmin/vmax would have no effect.
        plt.imshow(show_cam_on_image(im_ori, masks[j]))
        plt.title(f"{method}: {d_classes[c]}")
        plt.axis("off")


In [None]:
# Guided-backprop maps in reversed grayscale (darker = larger), on a scale shared by
# all classes; values below 0 are clipped to white by vmin=0.
cmax = np.max(gbackprop_masks)
plt.figure(figsize=(20, 4))
for j, c in enumerate(classes):
    plt.subplot(1, len(classes), j + 1)
    plt.imshow(gbackprop_masks[j], vmin=0, vmax=cmax, cmap=plt.cm.gray_r)
    plt.title(d_classes[c])
    plt.axis("off")

In [None]:
# Save each method's overlay for every selected class as <method>_<class>.png.
c_list = [d_classes[e] for e in classes]
im_ori = np.transpose(get_unnormalized_image(x).detach().cpu().numpy().squeeze(), (1, 2, 0))

# Output folder, named after the kind of example selected (positive / negative / mixed).
save_folder = Path("negative")


def save_fig(masks, im_ori, c_list, base_name, save_folder):
    """Save ``show_cam_on_image`` overlays of ``masks`` on ``im_ori``, one PNG per class.

    ``masks`` and ``c_list`` are paired positionally. Files are written to
    ``save_folder`` (created if missing) as ``{base_name}_{class}.png``.
    """
    save_folder.mkdir(exist_ok=True, parents=True)
    for mask, c in zip(masks, c_list):
        im = Image.fromarray(show_cam_on_image(im_ori, mask))
        im.save(save_folder / f"{base_name}_{c}.png")


save_fig(explainer_masks, im_ori, c_list, "ours", save_folder)
save_fig(vedaldi_masks, im_ori, c_list, "vedaldi", save_folder)
save_fig(gradcam_masks, im_ori, c_list, "gradcam", save_folder)

# Guided-backprop gradients are unbounded and can be negative. Scale by the global
# maximum and clip to [0, 1] so the uint8 cast inside show_cam_on_image cannot wrap.
cmax = np.max(gbackprop_masks)
if cmax > 0:
    gbackprop_scaled = np.clip(gbackprop_masks / cmax, 0, 1)
else:
    gbackprop_scaled = np.zeros_like(gbackprop_masks)
save_fig(gbackprop_scaled, im_ori, c_list, "guided_backprop", save_folder)
