In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import sys
sys.path.append('../src')

In [None]:
import torch
import cv2
import random
import matplotlib.pyplot as plt
import numpy as np

from torchray.utils import get_device
from pathlib import Path

from data.dataloader import VOCDataModule, COCODataModule
from utils.image_utils import get_unnormalized_image
from utils.helper import get_targets_from_annotations, get_filename_from_annotations, extract_masks
from models.explainer_classifier import ExplainerClassifierModel

In [None]:
def show_cam_on_image(img: np.ndarray,
                      mask: np.ndarray,
                      use_rgb: bool = False,
                      colormap: int = cv2.COLORMAP_JET) -> np.ndarray:
    """Blend a colour-mapped rendering of ``mask`` onto ``img``.

    The mask is inverted (``1 - mask``) before the OpenCV colormap is applied.
    The heatmap is left in OpenCV's BGR channel order unless ``use_rgb`` is set.
    :param img: The base image in RGB or BGR format; its maximum value must not exceed 1.
    :param mask: The cam mask. NOTE(review): presumably values in [0, 1] with the
        same height/width as ``img`` - confirm against callers.
    :param use_rgb: Convert the heatmap from BGR to RGB before blending; set this
        when ``img`` is RGB and the colours should match the colormap's nominal order.
    :param colormap: The OpenCV colormap to be used.
    :returns: ``uint8`` array holding heatmap + image, rescaled so its maximum is 255.
    :raises Exception: If ``np.max(img)`` is greater than 1.
    """
    # Scale the inverted mask to 0..255 and cast to uint8 before colour-mapping.
    inverted_levels = np.uint8(255 * (1 - mask))
    overlay = cv2.applyColorMap(inverted_levels, colormap)
    if use_rgb:
        overlay = cv2.cvtColor(overlay, cv2.COLOR_BGR2RGB)
    # Bring the heatmap back to floats in [0, 1] so it can be summed with the image.
    overlay = np.float32(overlay) / 255

    if np.max(img) > 1:
        raise Exception(
            "The input image should np.float32 in the range [0, 1]")

    # Additive blend, then renormalise by the peak value before converting to uint8.
    blended = overlay + img
    return np.uint8(255 * (blended / np.max(blended)))

In [None]:
def get_class_mask(explainer, image, class_id):
    """Return the explainer's sigmoid-activated mask for one class as a NumPy array.

    :param explainer: Callable applied to ``image``; its result is indexed as
        ``output[0][class_id]`` to obtain the logits of the requested class mask.
    :param image: Input forwarded to ``explainer``.
        NOTE(review): presumably a batch holding a single image - confirm.
    :param class_id: Index of the class whose mask is extracted.
    :returns: The mask after ``sigmoid``, moved to the CPU and converted to ``np.ndarray``.
    """
    # Inference only: skip building an autograd graph for the forward pass.
    with torch.no_grad():
        class_mask = explainer(image)[0][class_id].sigmoid()

    # detach() guards the NumPy conversion, which raises on tensors that
    # require grad (e.g. if the explainer re-enables grad internally).
    return class_mask.detach().cpu().numpy()

In [None]:
def get_topk_classes(explainer, image, k=5):
    """Rank classes by the mean activation of their explainer mask and return the top ``k``.

    :param explainer: Callable applied to ``image``; ``output[0]`` is treated as the
        stack of per-class mask logits.
        NOTE(review): presumably shaped (classes, H, W) - confirm.
    :param image: Input forwarded to ``explainer``.
    :param k: Number of classes to return.
    :returns: Tuple ``(values, class_ids)`` of NumPy arrays: the ``k`` largest mean
        sigmoid mask values and the class indices they belong to, in the order
        returned by ``Tensor.topk``.
    """
    # Inference only: skip building an autograd graph for the forward pass.
    with torch.no_grad():
        class_masks = explainer(image)[0].sigmoid()
        # Average each class mask over dims 1 and 2 to get one score per class.
        class_mask_means = class_masks.mean(dim=(1, 2))
        values, topk_classes = class_mask_means.topk(k)

    # detach() guards the NumPy conversion, which raises on tensors that require grad.
    return values.detach().cpu().numpy(), topk_classes.cpu().numpy()

In [None]:
def get_class_scores(explainer, classifier, image, class_id):
    """Score one class with the classifier on the plain image and on the masked image.

    The mask is the sigmoid of ``explainer(image)[0][class_id]``; it is given two
    leading singleton dimensions and multiplied element-wise with ``image``.

    :param explainer: Callable producing per-class mask logits for ``image``.
    :param classifier: Callable producing class logits; element ``[0]`` of its output is used.
    :param image: Input passed to both models.
        NOTE(review): presumably a single-image batch - confirm.
    :param class_id: Index of the class to score.
    :returns: Tuple ``(unmasked_prob, masked_prob)`` of sigmoid probabilities for
        ``class_id`` (tensors, not Python floats).
    """
    mask = explainer(image)[0][class_id].sigmoid()

    # Plain image first, then the masked image - same call order as before.
    logits_plain = classifier(image)[0]
    logits_masked = classifier(mask[None, None] * image)[0]

    return logits_plain.sigmoid()[class_id], logits_masked.sigmoid()[class_id]

In [None]:
def get_test_dataloader(dataset, data_path):
    """Build the test dataloader for the requested dataset with a test batch size of 1.

    :param dataset: Dataset identifier, either ``"VOC"`` or ``"COCO"``.
    :param data_path: Location of the dataset, forwarded to the data module.
    :returns: The object returned by the data module's ``test_dataloader()``.
    :raises ValueError: If ``dataset`` is not one of the supported identifiers.
    """
    if dataset == "VOC":
        data_module = VOCDataModule(data_path, test_batch_size=1)
    elif dataset == "COCO":
        data_module = COCODataModule(data_path, test_batch_size=1)
    else:
        # Fail with a clear message instead of an UnboundLocalError further down.
        raise ValueError(
            f"Unsupported dataset {dataset!r}; expected 'VOC' or 'COCO'.")

    data_module.setup(stage="test")
    return data_module.test_dataloader()

In [None]:
# --- Experiment configuration -------------------------------------------------
dataset = "VOC"
num_classes = 20
data_path = Path("../datasets/VOC2007/")
classifier_type = "resnet50"
explainer_classifier_path = Path("../src/checkpoints/explainer_resnet50_voc.ckpt")
# Attribution images are written below a per-classifier sub-directory.
output_dir = Path(f"./topk_attributions/{classifier_type}")

# --- Model loading ------------------------------------------------------------
explainer_classifier = ExplainerClassifierModel.load_from_checkpoint(
    explainer_classifier_path,
    num_classes=num_classes,
    dataset=dataset,
    classifier_type=classifier_type,
)

# Move both sub-models to the selected device and call freeze() on each.
# NOTE(review): freeze() presumably disables gradient tracking / switches to eval - confirm.
device = get_device()
explainer = explainer_classifier.explainer.to(device)
explainer.freeze()
classifier = explainer_classifier.classifier.to(device)
classifier.freeze()

# --- Data ---------------------------------------------------------------------
# Iterates the whole test dataloader once and keeps every (index, batch) pair
# in memory so that later cells can pick random items from the list.
dataloader = get_test_dataloader(dataset, data_path)
image_list = list(enumerate(dataloader))

In [None]:
# Save and display, for a random subset of test images, the original image, the
# aggregated attribution over the target classes, and the attribution for each of
# the explainer's top-k classes.
output_dir.mkdir(parents=True, exist_ok=True)

n_images = 8
top_k = 5
n_panels = top_k + 2  # original + aggregated + one panel per top-k class

# Draw the subset without shuffling image_list in place, so re-running this cell
# does not alter state created by an earlier cell. No seed is set: every run
# picks a different subset.
selected_items = random.sample(image_list, min(n_images, len(image_list)))

for _, (image, annotation) in selected_items:
    image = image.to(device)
    # [:-4] drops the last four characters of the returned name
    # (presumably a three-letter file extension such as ".jpg" - confirm).
    filename = get_filename_from_annotations(annotation, dataset=dataset)[:-4]
    targets = get_targets_from_annotations(annotation, dataset=dataset)
    # Indices of the classes marked with 1.0 in the first target row.
    target_classes = [idx for idx, val in enumerate(targets[0]) if val == 1.0]
    topk_values, topk_classes = get_topk_classes(explainer, image, k=top_k)

    fig, axes = plt.subplots(1, n_panels, figsize=(25, 5))

    # Panel 1: the un-normalised input image, converted from CHW to HWC.
    original_image = np.transpose(get_unnormalized_image(image).cpu().numpy().squeeze(), (1, 2, 0))
    plt.imsave(output_dir / f"{filename}_original.png", original_image, format="png")
    axes[0].imshow(original_image)

    # Panel 2: attribution built from the mask aggregated over the target classes.
    segmentations = explainer(image)
    aggregated_mask, _ = extract_masks(segmentations, targets)
    aggregated_mask = aggregated_mask[0].cpu().numpy()
    aggregated_attribution = show_cam_on_image(original_image, aggregated_mask)
    plt.imsave(output_dir / f"{filename}_aggregated.png", aggregated_attribution, format="png")
    axes[1].imshow(aggregated_attribution)

    # Remaining panels: one attribution per top-k class, in rank order.
    for j, class_id in enumerate(topk_classes):
        # Only the unmasked classifier probability is shown in the title.
        unmasked_class_prob, _ = get_class_scores(explainer, classifier, image, class_id)
        class_mask = get_class_mask(explainer, image, class_id)
        attribution = show_cam_on_image(original_image, class_mask)
        plt.imsave(output_dir / f"{filename}_rank_{j}_class_{class_id+1}.png", attribution, format="png")

        ax = axes[j + 2]
        # attribution is an RGB uint8 image, for which imshow ignores vmin/vmax.
        ax.imshow(attribution)
        # Classes present in the ground-truth targets are flagged with "**".
        class_title = f"{class_id+1}**" if class_id in target_classes else f"{class_id+1}"
        ax.set_title(f"{class_title}: CLS={unmasked_class_prob*100:.2f}, MASK={topk_values[j]*100:.2f}")

    for ax in axes:
        ax.axis("off")

    # Render this figure now and release it so figures do not pile up across iterations.
    plt.show()
    plt.close(fig)