In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
from PIL import Image
import torchvision.transforms as transforms
import traceback
import random

In [None]:
import cv2

def show_cam_on_image(img: np.ndarray,
                      mask: np.ndarray,
                      use_rgb: bool = False,
                      colormap: int = cv2.COLORMAP_JET) -> np.ndarray:
    """ This function overlays the cam mask on the image as an heatmap.
    By default the heatmap is in BGR format.

    Note: the colormap is applied to ``1 - mask`` (not ``mask``), i.e. in reverse.

    :param img: The base image in RGB or BGR format, float values in [0, 1].
    :param mask: The cam mask; values are clipped to [0, 1] before colouring.
    :param use_rgb: Whether to use an RGB or BGR heatmap, this should be set to True if 'img' is in RGB format.
    :param colormap: The OpenCV colormap to be used.
    :returns: The default image with the cam overlay, as uint8, rescaled so the
        brightest value of ``heatmap + img`` maps to 255.
    :raises Exception: if ``img`` contains values above 1.
    """
    # Validate the input before doing any colormap work.
    if np.max(img) > 1:
        raise Exception(
            "The input image should np.float32 in the range [0, 1]")

    # Masks normalised only by their maximum can still contain negative values;
    # without clipping, 255 * (1 - mask) can exceed 255 and the uint8 cast wraps.
    inverted = 1 - np.clip(mask, 0, 1)
    heatmap = cv2.applyColorMap(np.uint8(255 * inverted), colormap)
    if use_rgb:
        heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    heatmap = np.float32(heatmap) / 255

    cam = heatmap + img
    cam = cam / np.max(cam)
    return np.uint8(255 * cam)

def open_image(filename):
    """Load an image file as a 224x224 RGB float array with values in [0, 1].

    :param filename: Path to the image file.
    :returns: np.ndarray of shape (224, 224, 3), RGB, scaled to [0, 1].
    """
    transformer = transforms.Compose([transforms.Resize((224, 224))])
    # The context manager closes the file handle that Image.open keeps open.
    # convert("RGB") makes grayscale / palette / RGBA inputs 3-channel, which the
    # 3-channel heatmap addition in show_cam_on_image requires.
    with Image.open(filename) as src:
        resized = transformer(src.convert("RGB"))
    return np.array(resized) / 255.0

In [None]:
def get_path_mask(path_masks, dataset_name, model_name, method):
    """Return the folder holding the saved masks for one dataset/model/method.

    The folder is ``<dataset_name>_<model_name>_<method>`` inside `path_masks`.
    """
    folder_name = f"{dataset_name}_{model_name}_{method}"
    return path_masks / Path(folder_name)

def load_original_image(path_images, filename):
    """Load the original .jpg image that corresponds to `filename`.

    Whatever extension `filename` has (e.g. the name of a saved mask file) is
    replaced by ".jpg", and the result is resolved inside `path_images`.

    :param path_images: Folder containing the original images.
    :param filename: File name whose stem identifies the image.
    :returns: The image as returned by ``open_image``.
    """
    # with_suffix handles extensions of any length, unlike slicing off 4 chars.
    jpg_name = Path(filename).with_suffix(".jpg")
    return open_image(path_images / jpg_name)

In [None]:
# Methods whose masks are plotted, one panel each, in this order.
methods = [
    "grad_cam",
    "rise",
    "extremal_perturbations",
    "igos_pp",
    "rt_saliency",
    "explainer",
]

In [None]:
# Root folder containing one "<dataset>_<model>_<method>" mask folder per run.
path_masks = Path("..") / "src" / "evaluation" / "masks"

In [None]:
def _images_dir(dataset_name):
    """Return the folder holding the original images for `dataset_name`."""
    if dataset_name == "COCO":
        return Path("../datasets/COCO2014/val2014/")
    return Path("../datasets/VOC2007/VOCdevkit/VOC2007/JPEGImages/")


def _load_normalized_mask(mask_dir, filename):
    """Load the "arr_0" mask saved for `filename` in `mask_dir`, scaled so its max is 1."""
    npz_path = mask_dir / Path(filename).with_suffix(".npz")
    # NpzFile keeps the archive open; the context manager closes it.
    # (The second positional argument of np.load is mmap_mode, so nothing else is passed.)
    with np.load(npz_path) as archive:
        mask = archive["arr_0"]
    peak = mask.max()
    # Out-of-place division also works for integer arrays; an all-zero mask
    # is returned unchanged instead of becoming NaN via 0/0.
    if peak > 0:
        mask = mask / peak
    return mask


def plot_filename(filename, dataset_name, model_name, methods):
    """Plot an image next to its mask overlays and save the figure as a PDF.

    The first panel is the original image; each following panel overlays the
    mask of one entry of `methods`, read from the folder returned by
    ``get_path_mask(path_masks, ...)`` (module-level `path_masks`). The figure
    is written to ``ATTENUATED_<dataset_name>_<model_name>/<stem>.pdf``.

    Loading failures are reported with a traceback instead of being raised:
    if the image cannot be loaded nothing is plotted; if one mask cannot be
    loaded its panel is left blank and the remaining methods are still drawn.

    :param filename: Name of a mask file; only its stem is used.
    :param dataset_name: "COCO" selects the COCO val2014 images, anything else VOC2007.
    :param model_name: Model part of the mask folder name.
    :param methods: Method names, one panel each, in order.
    """
    try:
        x = load_original_image(_images_dir(dataset_name), filename)
    except Exception:
        # Best effort: report and skip this image rather than abort the batch loop.
        traceback.print_exc()
        return

    masks = []
    for method in methods:
        mask_dir = get_path_mask(path_masks, dataset_name, model_name, method)
        try:
            masks.append(_load_normalized_mask(mask_dir, filename))
        except Exception:
            traceback.print_exc()
            # Keep `masks` aligned with `methods`; this panel stays blank.
            masks.append(None)

    n_methods = len(methods)
    fig, axes = plt.subplots(1, n_methods + 1, figsize=(15, 4), squeeze=False)
    axes = axes[0]
    axes[0].imshow(x, vmin=0, vmax=1)
    for ax, mask in zip(axes[1:], masks):
        if mask is not None:
            # NOTE(review): x comes from PIL (RGB) but use_rgb keeps its default
            # False; left unchanged because it alters the rendered colours.
            ax.imshow(show_cam_on_image(x, mask), vmin=0, vmax=1)
    for ax in axes:
        ax.axis("off")

    fig.tight_layout()
    outfolder = Path("ATTENUATED_" + dataset_name + "_" + model_name)
    outfolder.mkdir(exist_ok=True, parents=True)
    fig.savefig(outfolder / Path(filename).with_suffix(".pdf"), bbox_inches='tight')


In [None]:
n_images = 5
dataset_name = "VOC"
model_name = "vgg16"
# Seed for choosing which images to plot, so Restart & Run All reproduces the same figures.
seed = 0

p_explainer = get_path_mask(path_masks, dataset_name, model_name, "explainer")
# glob() order is filesystem-dependent; sort first so the seeded draw is reproducible.
candidates = sorted(p_explainer.glob("*.png"))
if not candidates:
    print("No .png masks found in", p_explainer)
for p in random.Random(seed).sample(candidates, min(n_images, len(candidates))):
    filename = p.name
    plot_filename(filename, dataset_name, model_name, methods)