# Prototype explanations

This script generates the typical prototype explanations for a prediction.

1) Make sure to run the explanation script first: it generates the JSON file containing the nearest training-set images for each prototype.

In [None]:
import torch
import json
import os
import skimage as ski
import albumentations as A
from albumentations.pytorch import ToTensorV2
from quanproto.augmentation import enums
from quanproto.datasets import functional as F
import matplotlib.pyplot as plt
import numpy as np

from quanproto.models import helper
from quanproto.models import receptive_field
from quanproto.metrics import helpers

from quanproto.utils.explanations import load_model
from quanproto.utils.evaluation import get_experiment_info
from quanproto.utils.evaluation import get_experiment_results

import quanproto.utils.dataloader as qf
import quanproto.datasets.functional as sf

In [None]:
USER = os.environ["USER"]

experiment_config = {
    "experiment_dir": f"/home/{USER}/repos/QuanProto/experiments/PIPNet",
    "dataset_dir": f"/home/{USER}/data/quanproto",
    "model": "pipnet",
    "dataset": "cub200",
    "augmentation": "geometric_photometric",
    "feature": "resnet50",
    "fold": 0,
    "run": "chocolate-hill-3",
    "explanation": "prp",
}

RESULT_DIR = f"/home/{USER}/repos/Quanproto/example_maps"

k = 3 # number of top k prototypes to explain the image
test_image_idx = 2 # index of the image to explain


In [None]:
# Look up the experiment bookkeeping and the dataset described by the config.
experiment_info = get_experiment_info(experiment_config)

dataset = sf.get_dataset(
    experiment_config["dataset_dir"],
    experiment_config["dataset"],
)

# Training split of the selected fold.
train_data_info = dataset.fold_info(experiment_config["fold"], "train")
train_root_dir = dataset.fold_dirs(experiment_config["fold"])["train"]

# Held-out test split.
test_data_info = dataset.test_info()
test_root_dir = dataset.test_dirs()["test"]

# Narrow the bookkeeping down to the selected fold and run.
fold_info = experiment_info[f"fold_{experiment_config['fold']}"]
run_info = fold_info[experiment_config["run"]]

# A run without a "config" entry is reported as a missing file.
try:
    config_path = run_info["config"]
except KeyError:
    raise FileNotFoundError(f"Could not find config file for run {experiment_config['run']}")

with open(config_path, "r") as f:
    config = json.load(f)

# Rebuild the model from the stored config and weights, then move it to the
# GPU and switch it to evaluation mode.
model = load_model(
    experiment_config["model"],
    experiment_config["explanation"],
    config,
    dataset.num_classes(),
    dataset.multi_label(),
    run_info["state_dict"],
)
model.cuda()
model.eval()

# Arguments: 1 is the prototype spatial size (1x1), 224 is the input image
# size (224x224).
proto_layer_rf_info = receptive_field.prototype_receptive_field(model.backbone, 1, 224)


# Load the JSON written by the explanation script. It is indexed by prototype
# id (as a string) and lists the nearest training images of each prototype.
technique = "topk_prototype_images"
experiment_results = get_experiment_results(
    experiment_config, technique=technique
)
fold_info = experiment_results[f"fold_{experiment_config['fold']}"]
run_info = fold_info[experiment_config["run"]]

# Only the path lookup can raise KeyError, so keep the try body to that line
# and report what is actually missing (the results file, not the run config).
try:
    topk_images_path = run_info[technique]
except KeyError as err:
    raise FileNotFoundError(
        f"Could not find {technique} results for run {experiment_config['run']}; "
        "run the explanation script first to generate them"
    ) from err

with open(topk_images_path, "r") as f:
    topk_prototype_images = json.load(f)


In [None]:
# Channel-wise statistics passed to A.Normalize below.
IMAGENET_MEAN: tuple[float, float, float] = (0.485, 0.456, 0.406)
IMAGENET_STD: tuple[float, float, float] = (0.229, 0.224, 0.225)

# The "crop_resize" pipeline is only selected when both the train and the
# test split provide bounding boxes; otherwise plain "resize" is used.
has_bboxes = "bboxes" in train_data_info and "bboxes" in test_data_info
train_bboxes = train_data_info["bboxes"] if has_bboxes else None
test_bboxes = test_data_info["bboxes"] if has_bboxes else None

# Selected pipeline first, then normalisation and conversion to a tensor.
transform = A.Compose(
    enums.AugmentationPipelines["crop_resize" if has_bboxes else "resize"]
    + [
        A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
        ToTensorV2(),
    ]
)

In [None]:
def load_img(root, data_info, idx):
    """Read one image from disk and return it as a single-image GPU batch.

    The image is converted to three channels where needed, passed through the
    module-level ``transform`` and moved to the GPU.

    Args:
        root: Directory that the entries of ``data_info["paths"]`` are joined to.
        data_info: Mapping with a ``"paths"`` sequence and, optionally, a
            ``"bboxes"`` sequence indexed the same way.
        idx: Index of the image inside ``data_info``.

    Returns:
        A float tensor on the GPU with a leading batch dimension of size 1.

    Raises:
        ValueError: If the image has exactly two channels.
    """
    image = ski.io.imread(os.path.join(root, data_info["paths"][idx]))

    # Bring every supported layout to three channels.
    if image.ndim == 2:
        image = ski.color.gray2rgb(image)
    elif image.shape[2] == 4:
        image = ski.color.rgba2rgb(image)
    elif image.shape[2] == 2:
        raise ValueError("Image has 2 channels")

    # When boxes are available, hand their combination to the transform as
    # the cropping region.
    transform_kwargs = {}
    if "bboxes" in data_info:
        transform_kwargs["cropping_bbox"] = F.combine_bounding_boxes(
            data_info["bboxes"][idx]
        )
    transformed = transform(image=image, **transform_kwargs)["image"]

    # The transform may hand back something other than a tensor.
    if not isinstance(transformed, torch.Tensor):
        transformed = torch.tensor(transformed)

    # float -> GPU -> add the batch dimension
    return transformed.float().cuda().unsqueeze(0)

In [None]:
# Explain the selected test image.
test_img = load_img(test_root_dir, test_data_info, test_image_idx)

with torch.no_grad():
    logits, similarity_maps, _ = model.explain(test_img)

# Peak similarity of every prototype: reduce over the two spatial axes only,
# so the result keeps one row per image and one column per prototype whatever
# the batch size or prototype count.
# Assumes similarity_maps is laid out as (batch, prototypes, height, width).
similarity_scores = torch.amax(similarity_maps, dim=(2, 3))

# Indices of the k most similar prototypes for each image.
_, topk_indices = torch.topk(similarity_scores, k=k, dim=1)

# Keep only the similarity maps of those top-k prototypes.
similarity_maps = torch.stack(
    [maps[indices] for maps, indices in zip(similarity_maps, topk_indices)]
)
proto_rf_info = receptive_field.prototype_rf(similarity_maps, proto_layer_rf_info)
saliency_maps = model.saliency_maps(test_img, topk_indices)

In [None]:
def train_image_explanation(prototype_id):
    """Explain a prototype through its stored nearest training images.

    Args:
        prototype_id: Integer id of the prototype. Its string form is the key
            into ``topk_prototype_images``.

    Returns:
        A tuple ``(train_imgs, saliency_maps, bboxes, labels)``: the batch of
        loaded training images, the saliency maps computed for
        ``prototype_id`` on that batch, one bounding box per saliency map,
        and the ``"labels"`` entry stored for the prototype.

    Raises:
        ValueError: If no training image ids are stored for the prototype.
    """
    prototype_info = topk_prototype_images[str(prototype_id)]

    img_ids = prototype_info["ids"]
    if not img_ids:
        raise ValueError(
            f"No nearest training images stored for prototype {prototype_id}"
        )

    # Build the batch with one concatenation instead of re-copying a growing
    # tensor for every image.
    train_imgs = torch.cat(
        [load_img(train_root_dir, train_data_info, img_id) for img_id in img_ids],
        0,
    )

    # NOTE(review): the outputs of this call are not used below. It is kept in
    # case model.saliency_maps relies on state set by model.explain — confirm
    # and drop the call if it does not.
    with torch.no_grad():
        model.explain(train_imgs)

    # One column holding the same prototype id for every image in the batch.
    proto_id_tensor = (
        torch.tensor([prototype_id]).cuda().repeat(train_imgs.shape[0]).unsqueeze(1)
    )
    saliency_maps = model.saliency_maps(train_imgs, proto_id_tensor)

    bboxes = [
        helpers.bounding_box(saliency_map.squeeze()) for saliency_map in saliency_maps
    ]

    return train_imgs, saliency_maps, bboxes, prototype_info["labels"]


In [None]:
def show_test_explanation(img, saliency_map, similarity_map, proto_rf_bb):
    """Plot the five-panel explanation of one prototype on the test image.

    Panels, left to right: the original image, the percentile-masked saliency
    map over the image, the saliency bounding box, the prototype receptive
    field box, and the raw similarity map.

    Args:
        img: Normalised image batch holding a single image; the normalisation
            is inverted before plotting.
        saliency_map: Saliency map of the prototype for this image.
        similarity_map: Similarity map of the prototype for this image.
        proto_rf_bb: Receptive-field box, unpacked as
            ``(lower_y, upper_y, lower_x, upper_x)``.
    """
    mask = helpers.percentile_mask(saliency_map)
    saliency_bb = helpers.bounding_box(saliency_map)
    masked_saliency = saliency_map * mask

    # Undo the dataloader normalisation and go to (height, width, channels).
    rgb = helper.invert_normalize(img).cpu().squeeze(0).permute(1, 2, 0).numpy()

    _, (ax_img, ax_sal, ax_bb, ax_rf, ax_sim) = plt.subplots(1, 5, figsize=(15, 5))

    # The first four panels all use the image as background.
    titled_axes = (
        (ax_img, "Original Image"),
        (ax_sal, "Cropped Saliency Map"),
        (ax_bb, "Bounding Box"),
        (ax_rf, "Prototype RF"),
    )
    for ax, title in titled_axes:
        ax.imshow(rgb)
        ax.set_title(title)
        ax.axis("off")

    # Masked saliency drawn half-transparent over the image.
    ax_sal.imshow(masked_saliency.cpu(), alpha=0.5, cmap="viridis")

    # Bounding box of the saliency map.
    y0, y1, x0, x1 = saliency_bb
    ax_bb.add_patch(
        plt.Rectangle(
            (x0, y0),
            x1 - x0,
            y1 - y0,
            linewidth=1,
            edgecolor="r",
            facecolor="none",
        )
    )

    # Receptive field of the prototype; the rectangle is inset by 1 px at the
    # top-left and is 3 px smaller than the raw box in each direction.
    y0, y1, x0, x1 = proto_rf_bb
    ax_rf.add_patch(
        plt.Rectangle(
            (x0 + 1, y0 + 1),
            x1 - x0 - 3,
            y1 - y0 - 3,
            linewidth=2,
            edgecolor="r",
            facecolor="none",
        )
    )

    # The similarity map keeps its axis ticks visible.
    ax_sim.imshow(similarity_map.cpu(), cmap="viridis")
    ax_sim.set_title("Similarity Map")

    plt.show()



In [None]:
def show_train_images(train_images, bboxes):
    """Plot a row of training images, each with its bounding box outlined.

    Args:
        train_images: Batch of normalised images, indexed along its first
            dimension; the normalisation is inverted before plotting.
        bboxes: One box per image, each unpacked as
            ``(lower_y, upper_y, lower_x, upper_x)``.
    """
    n_images = train_images.shape[0]

    # squeeze=False always returns a 2-D array of axes. Without it a single
    # image would yield a bare Axes object that cannot be indexed.
    _, axes = plt.subplots(1, n_images, figsize=(15, 3), squeeze=False)

    for i, ax in enumerate(axes[0]):
        # Undo the normalisation and go to (height, width, channels).
        img = helper.invert_normalize(train_images[i]).cpu()
        img = img.squeeze(0).permute(1, 2, 0).numpy()

        lower_y, upper_y, lower_x, upper_x = bboxes[i]
        ax.imshow(img)
        ax.add_patch(
            plt.Rectangle(
                (lower_x, lower_y),
                upper_x - lower_x,
                upper_y - lower_y,
                linewidth=1,
                edgecolor="r",
                facecolor="none",
            )
        )
        ax.set_title(f"Train Image {i}")
        ax.axis("off")

    plt.show()



# Top-1 Prototype

## Test Input Reasoning:

In [None]:
# Explanation of the test image for the first entry of the top-k selection.
show_test_explanation(
    img=test_img,
    saliency_map=saliency_maps[0][0],
    similarity_map=similarity_maps[0][0],
    proto_rf_bb=proto_rf_info[0][0],
)

## Nearest Train Input Reasoning:

In [None]:
# Nearest training images of the first top-k prototype, with their boxes.
train_images, _, bbs, labels = train_image_explanation(
    prototype_id=topk_indices.tolist()[0][0]
)
print(labels)
show_train_images(train_images=train_images, bboxes=bbs)

# Top-2 Prototype

## Test Input Reasoning:

In [None]:
# Explanation of the test image for the second entry of the top-k selection.
show_test_explanation(
    img=test_img,
    saliency_map=saliency_maps[0][1],
    similarity_map=similarity_maps[0][1],
    proto_rf_bb=proto_rf_info[0][1],
)

## Nearest Train Input Reasoning:

In [None]:
# Nearest training images of the second top-k prototype, with their boxes.
train_images, _, bbs, labels = train_image_explanation(
    prototype_id=topk_indices.tolist()[0][1]
)
print(labels)
show_train_images(train_images=train_images, bboxes=bbs)

# Top-3 Prototype

## Test Input Reasoning:

In [None]:
# Explanation of the test image for the third entry of the top-k selection.
show_test_explanation(
    img=test_img,
    saliency_map=saliency_maps[0][2],
    similarity_map=similarity_maps[0][2],
    proto_rf_bb=proto_rf_info[0][2],
)

## Nearest Train Input Reasoning:

In [None]:
# Nearest training images of the third top-k prototype, with their boxes.
train_images, _, bbs, labels = train_image_explanation(
    prototype_id=topk_indices.tolist()[0][2]
)
print(labels)
show_train_images(train_images=train_images, bboxes=bbs)