# Prototype explanations

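This notebook generates the typical prototype explanation for a single prediction: it saves saliency overlays for the test image's top-k most similar prototypes and for the nearest training image of each of those prototypes.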

1) Make sure to run the explanation script first. It generates a JSON file containing, for each prototype, the nearest images in the training set (loaded below as the `topk_prototype_images` result).

In [None]:
import torch
import json
import os
import skimage as ski
from quanproto.metrics import helpers

from quanproto.utils.workspace import EXPERIMENTS_PATH, DATASET_DIR
from quanproto.evaluation import folder_utils as eval
from quanproto.dataloader.single_augmentation import test_dataloader, prune_dataloader
from quanproto.explanations.config_parser import load_model
from torch.nn.functional import max_pool2d
from quanproto.utils.vis_helper import save_image_mask
from quanproto.metrics.helpers import label_prediction

In [None]:

# Settings that select the run to explain; the keys are read by the cells below.
experiment_config = dict(
    experiment_dir=f"{EXPERIMENTS_PATH}/ProtoPool/nico",
    dataset_dir=DATASET_DIR,
    model="protopool",
    run="ocean-mountain-547",
    explanation="prp",
    train_phase="fine_tune",
    crop=True,
)

# Index of the image to explain (used to index the test dataset further down).
# Other indices that were tried: 970, 6010.
test_image_idx = 205
# test_image_idx = 6010 # index of the image to explain


In [None]:
# Info from all runs in the experiment dir, narrowed to the selected run.
run_info = eval.get_run_info(experiment_config)
run_info = run_info[experiment_config["run"]]
run_topk_info = eval.get_technique_results(
    experiment_config)["topk_prototype_images"]

# Nearest training images per prototype, keyed by prototype index.
# The encoding is given explicitly so the read does not depend on the locale.
with open(run_topk_info[experiment_config["run"]], "r", encoding="utf-8") as f:
    topk_info = json.load(f)

# Load the config file the run was trained with.
with open(run_info["config"], "r", encoding="utf-8") as f:
    run_config = json.load(f)

# Make sure to use the datasets on your computer and not the path that was
# used to train the model.
run_config["dataset_dir"] = experiment_config["dataset_dir"]

dataloader = test_dataloader(run_config, crop=experiment_config["crop"])

# Build the model with the selected explanation technique, using the
# run_info entry for the chosen training phase.
model = load_model(
    run_config,
    experiment_config["explanation"],
    run_info[experiment_config["train_phase"]],
)

# The visualization functions below move their inputs to the GPU, so the
# model has to live there as well.
model.cuda()
model.eval()

In [None]:
def vis_prototypes(ids, dataloader, model, topk_info, out_dir=None, max_images=1):
    """
    Save saliency overlays of the nearest images of each given prototype.

    For every prototype index in `ids`, the images listed under
    `topk_info[str(index)]["ids"]` are fetched with
    `dataloader.dataset.getitem_by_id`, the saliency map of that prototype is
    computed for each of them, masked with `helpers.percentile_mask`, and
    written with `save_image_mask` to
    `<out_dir>/train_nearest_img_<i>_prototype_<index>.png`.

    Args:
        ids: 2-D array or tensor of prototype indices (images x prototypes),
            e.g. the return value of `vis_test_images`.
        dataloader: loader whose dataset provides `getitem_by_id`.
        model: model providing `saliency_maps(images, prototype_indices)`.
            Inputs are moved to the GPU, so the model must be there too.
        topk_info: mapping from the prototype index (as a string) to a dict
            whose "ids" entry lists the ids of the nearest images.
        out_dir: directory the overlays are written to. Defaults to
            "<experiment_dir>/topk_explain" taken from the module-level
            `experiment_config`.
        max_images: number of nearest images saved per prototype. The default
            of 1 saves only the first one; pass None to save all of them.
    """
    if out_dir is None:
        # Single quotes inside the f-string keep this valid before Python 3.12.
        out_dir = f"{experiment_config['experiment_dir']}/topk_explain"
    os.makedirs(out_dir, exist_ok=True)

    for img_i in range(ids.shape[0]):  # all images
        for proto_j in range(ids.shape[1]):  # all prototypes
            # int() gives the same dictionary key for numpy and torch inputs.
            proto_idx = int(ids[img_i][proto_j])
            img_ids = topk_info[f"{proto_idx}"]["ids"]
            if len(img_ids) == 0:
                # Nothing to visualize for this prototype.
                continue

            # NOTE(review): assumes the dataset yields 3x224x224 images — confirm
            # against the dataloader configuration.
            img_batch = torch.zeros((len(img_ids), 3, 224, 224))
            for proto_nearest_i, proto_nearest_id in enumerate(img_ids):
                img, _ = dataloader.dataset.getitem_by_id(proto_nearest_id)
                img_batch[proto_nearest_i] = img

            img_batch = img_batch.cuda()

            with torch.no_grad():
                # Bx1 tensor holding the same prototype index for every image.
                prototype_idx = torch.tensor([proto_idx]).cuda()
                prototype_idx = prototype_idx.expand(
                    img_batch.shape[0], prototype_idx.shape[0]
                )
                saliency_maps = model.saliency_maps(img_batch, prototype_idx)
                # One mask per image, computed on the single requested prototype map.
                saliency_masks = torch.stack(
                    [
                        helpers.percentile_mask(saliency_maps[b, 0])
                        for b in range(saliency_maps.shape[0])
                    ]
                ).unsqueeze(1)
                saliency_maps = saliency_maps * saliency_masks

                if max_images is None:
                    num_to_save = img_batch.shape[0]
                else:
                    num_to_save = min(max_images, img_batch.shape[0])

                for save_img_i in range(num_to_save):
                    save_image_mask(
                        img_batch[save_img_i],
                        saliency_maps[save_img_i][0],
                        f"{out_dir}/train_nearest_img_{save_img_i}_prototype_{proto_idx}.png",
                    )

In [None]:
def vis_test_images(
    ids,
    dataloader,
    model,
    k=5,
    multi_label_threshold=1.4552792310714722,
    out_dir=None,
):
    """
    Save saliency overlays of the top-k prototypes for the given test images.

    The images are read from `dataloader.dataset`, passed through
    `model.explain`, and the k prototypes with the highest spatially
    max-pooled similarity are selected per image. For each of them the
    saliency map is masked with `helpers.percentile_mask` and written with
    `save_image_mask` to `<out_dir>/test_img_<i>_prototype_<j>.png`, where
    `i` is the position in `ids` and `j` the rank within the top-k (not the
    prototype index).

    Args:
        ids: sequence of dataset indices of the images to explain.
        dataloader: loader whose dataset is indexed with the entries of `ids`
            and returns `(image, label)` pairs.
        model: model providing `explain`, `saliency_maps` and `multi_label`.
            Inputs are moved to the GPU, so the model must be there too.
        k: number of prototypes to visualize per image; clamped to the number
            of prototypes the model reports similarities for.
        multi_label_threshold: value passed as third argument to
            `label_prediction` for multi-label models. The default is the
            constant previously hard-coded here — presumably a decision
            threshold tuned for one run; verify before reusing.
        out_dir: directory the overlays are written to. Defaults to
            "<experiment_dir>/topk_explain" taken from the module-level
            `experiment_config`.

    Returns:
        numpy array with the indices of the top-k prototypes of every image
        (one row per image).
    """
    if out_dir is None:
        # Single quotes inside the f-string keep this valid before Python 3.12.
        out_dir = f"{experiment_config['experiment_dir']}/topk_explain"
    os.makedirs(out_dir, exist_ok=True)

    # 1) get the image tensors from the dataloader
    # NOTE(review): assumes the dataset yields 3x224x224 images — confirm
    # against the dataloader configuration.
    img_batch = torch.zeros((len(ids), 3, 224, 224))
    labels = []
    for i, img_id in enumerate(ids):
        img, label = dataloader.dataset[img_id]
        img_batch[i] = img
        labels.append(label)
        print(f"Loaded image {i+1}/{len(ids)}: {img_id} label: {label}")

    img_batch = img_batch.cuda()

    with torch.no_grad():

        # region TopK Prototypes ---------------------------------------------------------------
        logits, similarity_maps, _ = model.explain(img_batch)
        if not model.multi_label:
            if logits.dim() == 2:
                print(f"max logit id: {logits.argmax(dim=1)}")
            else:
                print(f"max logit id: {logits.argmax()}")
        else:
            logits = label_prediction(logits, model.multi_label, multi_label_threshold)
            print(f"shape: {logits.shape}")
            print(f"logits: {logits}")
            # Stack the labels so every prediction row is compared with the
            # label of its own image rather than with the last one loaded.
            label = torch.stack([torch.as_tensor(lbl) for lbl in labels])
            matches = logits.to(device='cpu') == label
            # how many 1s / 0s are at the same position in label and prediction
            print(f"true pos: {(matches & (label == 1.0)).sum()}/{label.sum()}")
            print(f"true neg: {(matches & (label == 0.0)).sum()}/{(label * -1 + 1).sum()}")

        # Global max over the spatial dimensions gives one score per prototype.
        similarity_scores = (
            max_pool2d(similarity_maps, kernel_size=similarity_maps.shape[2:])
            .squeeze(-1)
            .squeeze(-1)
        )
        # torch.topk raises when k exceeds the number of prototypes.
        k = min(k, similarity_scores.shape[1])
        _, topk_indices = torch.topk(similarity_scores, k=k, dim=1)
        # endregion TopK Prototypes ------------------------------------------------------------
        saliency_maps = model.saliency_maps(img_batch, topk_indices)
        # One mask per (image, prototype) pair, stacked back to the map layout.
        saliency_masks = torch.stack(
            [
                torch.stack(
                    [
                        helpers.percentile_mask(saliency_maps[b, i])
                        for i in range(k)
                    ]
                )
                for b in range(saliency_maps.shape[0])
            ]
        )
        saliency_maps = saliency_maps * saliency_masks

        for i in range(img_batch.shape[0]):
            img = img_batch[i]
            saliency_map = saliency_maps[i]

            for j in range(saliency_map.shape[0]):
                save_image_mask(
                    img,
                    saliency_map[j],
                    f"{out_dir}/test_img_{i}_prototype_{j}.png",
                )
    return topk_indices.detach().cpu().numpy()


In [None]:
# Explain the selected test image and keep the indices of its top-k prototypes.
prototype_ids = vis_test_images(
    ids=[test_image_idx],
    dataloader=dataloader,
    model=model,
)

In [None]:
# Loader used to fetch each prototype's nearest images by id
# (presumably the training split the top-k file was computed on — confirm).
train_data = prune_dataloader(run_config, crop=experiment_config["crop"])
vis_prototypes(
    ids=prototype_ids,
    dataloader=train_data,
    model=model,
    topk_info=topk_info,
)