# Visualize Prototypes

In [None]:
import json
import os
import numpy as np
import quanproto.datasets.config_parser as quan_dataset
import matplotlib.pyplot as plt
import skimage as ski
from quanproto.utils.workspace import DATASET_DIR

In [None]:
# Location of the push log (JSON) produced when prototypes were pushed onto
# training patches; here the log of epoch 110 of one ProtoPNet run.
# NOTE(review): machine-specific absolute path — adjust before running elsewhere.
push_config = {
    "path": "/home/pschlinge/repos/quanproto/experiments/ProtoPNet/cub200/geometric_photometric/medium_medium/resnet50/fold_0/models/ocean-elegant-908/push_log_epoch_110.json",
}

# Dataset selection: which dataset to load and which cross-validation fold.
config = {
    "dataset": "cub200",
    "fold_idx": 0,
    "dataset_dir": DATASET_DIR,
}

# 10 prototypes per class were used on CUB200; show prototype ids 0..9.
proto_ids = list(range(10))

In [None]:
# Load the push log: a JSON object keyed by prototype id (as a string); each
# entry is read below for its "img_id" and "patch_index" fields.
# JSON is UTF-8 by specification, so decode it explicitly instead of relying on
# the platform's locale encoding.
with open(push_config["path"], encoding="utf-8") as file:
    push_log = json.load(file)

In [None]:
# Resolve the training split of the configured fold: the image root directory,
# the per-image relative paths and, when the dataset provides them, bounding boxes.
dataset = quan_dataset.get_dataset(config["dataset_dir"], config["dataset"])
# os.path.join with a single argument is a no-op; image paths are joined onto
# this root later, which works for str and path-like values alike.
root_dir = dataset.fold_dirs(config["fold_idx"])["train"]
data_info = dataset.fold_info(config["fold_idx"], "train")
img_paths = data_info["paths"]
# Bounding boxes only exist for some datasets (e.g. cub200 and cars196);
# None means the full image is shown.
bbox = data_info["bboxes"] if "bboxes" in data_info else None

In [None]:
# Plot every prototype in proto_ids: the training image recorded in the push log
# (cropped to its bounding box when one exists, then resized to IMG_SIZE) with
# the prototype's patch outlined in red.
IMG_SIZE = 224  # side length the image is resized to before drawing
FEATURE_GRID = 7  # assumes a 7x7 feature map at 224x224 input — TODO confirm for other backbones
PATCH_STRIDE = IMG_SIZE // FEATURE_GRID  # pixels per feature-map cell (32)
N_COLS = 5

# Derive the grid from the number of prototypes instead of a fixed 2x5 layout.
n_rows = max(1, (len(proto_ids) + N_COLS - 1) // N_COLS)
fig, axs = plt.subplots(n_rows, N_COLS, figsize=(4 * N_COLS, 5 * n_rows), squeeze=False)

for ax, proto_id in zip(axs.flat, proto_ids):
    entry = push_log[str(proto_id)]
    img_index = entry["img_id"]
    patch_idx = entry["patch_index"]  # min_w, min_h, max_w, max_h (feature-map cells)
    img = ski.io.imread(os.path.join(root_dir, img_paths[img_index]))
    if img.ndim == 2:
        # Single-channel images would otherwise be rendered with a false-colour colormap.
        img = ski.color.gray2rgb(img)

    # Crop the image to its bounding box, stored as (min_w, min_h, max_w, max_h).
    if bbox is not None:
        box = bbox[img_index][0]
        box_x0, box_y0, box_x1, box_y1 = (int(box[k]) for k in range(4))
        img = img[box_y0:box_y1, box_x0:box_x1]

    img = ski.transform.resize(img, (IMG_SIZE, IMG_SIZE))

    # Feature-map cell indices -> pixel coordinates in the resized image.
    x0, y0, x1, y1 = (patch_idx[k] * PATCH_STRIDE for k in range(4))

    ax.imshow(img)
    ax.set_title(f"Proto {proto_id}")
    ax.axis("off")
    # imshow centres pixels on integer coordinates, so pixel edges lie at -0.5.
    ax.add_patch(
        plt.Rectangle(
            (x0 - 0.5, y0 - 0.5), x1 - x0, y1 - y0, edgecolor="r", facecolor="none"
        )
    )

# Hide grid cells that did not receive a prototype.
for ax in axs.flat[len(proto_ids):]:
    ax.axis("off")

plt.show()
