# Visualize Prototypes

In [None]:
import json
import os
import numpy as np
import quanproto.datasets.config_parser as quan_dataset
import matplotlib.pyplot as plt
import skimage as ski
from quanproto.utils.workspace import DATASET_DIR, EXPERIMENTS_PATH

In [None]:
# Location of the push log written during training. The log maps each
# prototype id to the training image index and mask index it is associated with.
push_config = {
    "path": f"{EXPERIMENTS_PATH}/ProtoMask/dogs/protomask_photometric/medium_medium/resnet50/fold_0/models/elegant-river-162/push_log_epoch_100.json",
}

# Dataset / fold / segmentation settings used to resolve image and mask paths.
config = {
    "dataset_dir": DATASET_DIR,
    "dataset": "dogs",
    "fold_idx": 0,
    "seg_method": "sam2",
}

# Prototypes to visualize. The original note says 10 prototypes per class were
# used (stated for CUB200; confirm the count for this dataset). Assuming
# prototypes are grouped contiguously by class, this selects the prototypes of
# class `class_idx` (class 6 -> ids 60..69) — verify against the model layout.
protos_per_class = 10
class_idx = 6
proto_ids = list(range(class_idx * protos_per_class, (class_idx + 1) * protos_per_class))

In [None]:
# Load the push log: a JSON object keyed by prototype id (as a string), whose
# entries carry "img_index" and "mask_index". JSON text is UTF-8, so decode it
# explicitly instead of relying on the platform default encoding.
with open(push_config["path"], "r", encoding="utf-8") as file:
    push_log = json.load(file)

In [None]:
# Resolve the training-split image/mask directories and per-image metadata for
# the configured fold and segmentation method.
dataset = quan_dataset.get_dataset(config["dataset_dir"], config["dataset"])
fold_dirs = dataset.fold_dirs(config["fold_idx"])
# os.fspath keeps the str conversion the former single-arg os.path.join did.
root_img_dir = os.fspath(fold_dirs["train"])
root_mask_dir = os.path.join(fold_dirs["train_segmentations"], config["seg_method"])
print("root_img_dir", root_img_dir)
print("root_mask_dir", root_mask_dir)

data_info = dataset.fold_info(config["fold_idx"], "train")
img_paths = data_info["paths"]
seg_info = data_info["masks"][config["seg_method"]]
mask_paths = seg_info["paths"]
# Per-mask bounding boxes are not available for every dataset (originally noted:
# only cub200 and cars196). Check for the key that is actually read rather than
# the top-level "bboxes" key: that guard could raise KeyError or skip boxes that
# exist. bbox stays None when no boxes are present.
bbox = seg_info.get("bounding_boxes")

In [None]:
# Show the training image of each prototype in a grid with 5 columns (and as
# many rows as needed), outlining its bounding box when boxes are available.
n_cols = 5
n_rows = max(1, (len(proto_ids) + n_cols - 1) // n_cols)
# squeeze=False keeps `axs` 2-D even with a single row, so axs[row, col] always works.
fig, axs = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 5 * n_rows), squeeze=False)

for i, proto_id in enumerate(proto_ids):
    entry = push_log[str(proto_id)]
    img_index = entry["img_index"]
    mask_idx = entry["mask_index"]
    img = ski.io.imread(os.path.join(root_img_dir, img_paths[img_index]))

    ax = axs[i // n_cols, i % n_cols]
    ax.imshow(img)
    ax.set_title(f"Proto {proto_id}")
    ax.axis("off")

    # Only draw the patch when boxes were loaded; without this guard the corner
    # variables were undefined and the cell raised NameError.
    if bbox is not None:
        # assumes (x_min, y_min, x_max, y_max) in original image pixels — TODO confirm
        min_w, min_h, max_w, max_h = (int(v) for v in bbox[img_index][mask_idx][:4])
        rect = plt.Rectangle((min_w, min_h), max_w - min_w, max_h - min_h, edgecolor="r", facecolor="none")
        ax.add_patch(rect)

# Hide grid cells left empty when len(proto_ids) is not a multiple of n_cols.
for ax in axs.ravel()[len(proto_ids):]:
    ax.axis("off")
