In [None]:
import json
import statistics

import matplotlib.pyplot as plt

from quanproto.evaluation.folder_utils import get_run_info
from quanproto.utils.workspace import EXPERIMENTS_PATH, DATASET_DIR

In [None]:
# NOTE: every run found under `experiment_dir` is evaluated. If the experiment was
# trained on several datasets, point this at a dataset-specific subfolder
# (e.g. PIPNet/cub200) so that runs from different datasets are not mixed.
experiment_config = dict(
    experiment_dir=f"{EXPERIMENTS_PATH}/ProtoMask/dogs",
    dataset_dir=DATASET_DIR,
    # Used below to map a prototype index to its class: index // prototypes_per_class.
    prototypes_per_class=10,
)

In [None]:
# For every run, count how many distinct (img_index, mask_index) pairs the push
# log assigns to each class, and collect those counts per class across runs:
#   unique_masks_per_class[class_index] -> [count_in_run_1, count_in_run_2, ...]
run_info = get_run_info(experiment_config)
prototypes_per_class = experiment_config["prototypes_per_class"]

unique_masks_per_class = {}
for run_name, run in run_info.items():
    if "push_log" not in run:
        # Skip the run; without `continue` the lookup below raised KeyError.
        print(f"Run {run_name} does not have a push log.")
        continue

    with open(run["push_log"], "r") as f:
        push_log = json.load(f)

    # class_index -> set of (img_index, mask_index) pairs used by its prototypes.
    # Assumes push-log keys are prototype indices with each class owning a
    # contiguous block of `prototypes_per_class` indices — TODO confirm.
    class_masks = {}
    for prototype_key, assignment in push_log.items():
        class_index = int(prototype_key) // prototypes_per_class
        class_masks.setdefault(class_index, set()).add(
            (assignment["img_index"], assignment["mask_index"])
        )

    for class_index, masks in class_masks.items():
        unique_masks_per_class.setdefault(class_index, []).append(len(masks))

In [None]:
# Summarise the per-run unique-mask counts of every class. Classes are visited in
# ascending class-index order so list position follows class index instead of
# whatever key order the push-log JSON happened to have.
avg_unique_masks_per_class = []
std_unique_masks_per_class = []

max_index = 0
max_masks = 0

for class_index, num_masks in sorted(unique_masks_per_class.items()):
    class_avg = sum(num_masks) / len(num_masks)
    avg_unique_masks_per_class.append(class_avg)
    # Population standard deviation across runs (0.0 when there is a single run).
    # The previous (max - min) / 2 was half the range, not a standard deviation,
    # even though the plot labels this band "Standard Deviation".
    std_unique_masks_per_class.append(statistics.pstdev(num_masks))
    if class_avg > max_masks:
        max_masks = class_avg
        max_index = class_index

print(f"Max unique masks per class: {max_masks} for class index {max_index}")


In [None]:
# Plot the average number of unique masks per class, with a shaded band of
# +/- one standard deviation (std_unique_masks_per_class) around it.
# x is the list position; it equals the class index when classes 0..C-1 are all
# present and listed in ascending order — TODO confirm for partial runs.
x_positions = range(len(avg_unique_masks_per_class))
band_upper = [avg + std for avg, std in zip(avg_unique_masks_per_class, std_unique_masks_per_class)]
band_lower = [avg - std for avg, std in zip(avg_unique_masks_per_class, std_unique_masks_per_class)]

fig, ax = plt.subplots(figsize=(20, 5))
ax.plot(x_positions, avg_unique_masks_per_class, label="Average Unique Masks per Class", color="green")
ax.fill_between(x_positions, band_upper, band_lower, color="green", alpha=0.2, label="Standard Deviation")
ax.grid(True, linestyle="--", alpha=0.5)
ax.set(
    xlabel="Class Index",
    ylabel="Number of Unique Masks",
    title="Average Unique Masks per Class with Standard Deviation",
)
# The artists above carry labels; show them (the legend was previously commented out).
ax.legend()
fig.tight_layout()
plt.show()