In [8]:
import os
import pickle
from collections import defaultdict

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
from mpl_toolkits.axes_grid1 import make_axes_locatable

In [None]:
# Load the pickle file
# NOTE(security): pickle.load can execute arbitrary code while deserializing;
# only point pickle_path at a file you produced yourself or otherwise trust.
pickle_path = "./min_covers.pkl"  # Relative to the current working directory; adjust if the file lives elsewhere
with open(pickle_path, "rb") as file:
    # The parsing cell iterates this object as {dataset: {model: transforms}}
    # and reads transforms[0].
    # NOTE(review): confirm that layout against whatever wrote the pickle.
    min_covers = pickle.load(file)

# Define the dataset and model names
# dataset_names drives the per-dataset loops and is embedded in the output
# PDF file names; the order here is the processing order.
dataset_names = [
    "imagenet_1k",
    "imagenet_1k_real",
    "imagenet_IN_plus_real",
    "imagenet_r",
    "imagenet_a",
    "imagenet_sketch",
    "objectnet",
]
# The order of model_names is the left-to-right column order of the summary figures.
model_names = ["resnet18", "resnet50", "vit32", "vgg16", "alexnet", "clip_vit_l_14"]
# Display labels for the model keys (used as subplot titles and result keys).
pretty_model_names = {
    "resnet18": "ResNet-18",
    "resnet50": "ResNet-50",
    "vit32": "ViT-B/32",
    "vgg16": "VGG-16",
    "alexnet": "AlexNet",
    "clip_vit_l_14": "CLIP ViT-L/14",
}
# Display labels for the dataset keys (used as result keys in the percentage table).
pretty_dataset_names = {
    "imagenet_1k": "ImageNet",
    "imagenet_1k_real": "ImageNet ReaL",
    "imagenet_IN_plus_real": "ImageNet + ReaL",
    "imagenet_r": "ImageNet-R",
    "imagenet_a": "ImageNet-A",
    "imagenet_sketch": "ImageNet-Sketch",
    "objectnet": "ObjectNet",
}

In [9]:
# Parse the first 36 transforms of every (dataset, model) pair and track the
# global size range and location-frequency range so all plots share one scale.
parsed_top_36_transforms = defaultdict(lambda: defaultdict(list))
location_frequencies = []

# Seed the running extrema in this cell. They used to be read before any
# assignment, which only worked when a previous kernel session had left the
# names behind and raised NameError after Restart & Run All.
global_min_size = float("inf")
global_max_size = float("-inf")

for dataset, models in min_covers.items():
    for model, transforms in models.items():
        # NOTE(review): assumes transforms[0] is a sequence of transform
        # strings, presumably already ranked — confirm against the pickle writer.
        top_36 = transforms[0][:36]
        locations = []
        for transform in top_36:
            # Each string is split as "<label>:<x>_<y>" + "_Size:" + "<size>".
            loc_str, size_str = transform.split("_Size:")
            x, y = map(int, loc_str.split(":")[1].split("_"))
            size = int(size_str)
            locations.append((x, y))
            parsed_top_36_transforms[dataset][model].append(((x, y), size))
            global_min_size = min(global_min_size, size)
            global_max_size = max(global_max_size, size)

        # One occurrence count per distinct location of this (dataset, model) pair.
        location_frequencies.extend([locations.count(loc) for loc in set(locations)])

# Fail with an explicit message instead of the bare "min() arg is an empty
# sequence" when the pickle held no transforms at all.
if not location_frequencies:
    raise ValueError("min_covers contains no transforms; cannot derive plot scales")

global_min_value = min(location_frequencies)
global_max_value = max(location_frequencies)


# Function to create a summary plot with numbers inside each cell of the heatmap
def plot_dataset_summary_with_numbers(dataset, vmin, vmax, xmin, xmax):
    """Build the two-row summary figure for one dataset, one column per model.

    The top row shows, for each model, a heatmap of how often each (x, y)
    location occurs among the parsed transforms, with the count written inside
    every cell. The bottom row shows a histogram of the transform sizes.

    Reads the notebook globals ``model_names``, ``pretty_model_names`` and
    ``parsed_top_36_transforms``.

    Args:
        dataset: Key into ``parsed_top_36_transforms``.
        vmin: Lower limit of the heatmap colour scale.
        vmax: Upper limit of the heatmap colour scale.
        xmin: Lower bound of the histogram range and x-axis.
        xmax: Upper bound of the histogram range and x-axis.

    Returns:
        The matplotlib figure; the caller is responsible for saving/closing it.

    Raises:
        ValueError: If a model has no parsed transforms for ``dataset``.
    """
    # squeeze=False keeps axs two-dimensional even with a single model, so the
    # axs[row, col] indexing below is always valid.
    fig, axs = plt.subplots(
        2,
        len(model_names),
        figsize=(20, 8),
        gridspec_kw={"height_ratios": [3, 1]},
        squeeze=False,
    )
    # fig.suptitle(f"Summary Plots for {pretty_dataset_names[dataset]}", fontsize=16)

    # Cells at or above the midpoint of the colour scale fall on the dark end
    # of "YlGnBu" and get white labels; lighter cells get black labels. The
    # midpoint is (vmin + vmax) / 2, not half the range.
    text_color_threshold = (vmin + vmax) / 2

    for idx, model in enumerate(model_names):
        entries = parsed_top_36_transforms[dataset][model]
        if not entries:
            raise ValueError(
                f"No parsed transforms for dataset {dataset!r} and model {model!r}"
            )
        locations = [(loc[0], loc[1]) for loc, _ in entries]
        sizes = [size for _, size in entries]

        # The grid is sized from the largest x and y seen for this model.
        n_rows = max(x for x, _ in locations) + 1
        n_cols = max(y for _, y in locations) + 1
        location_matrix = np.zeros((n_rows, n_cols))
        # NOTE(review): x indexes rows (vertical axis) and y indexes columns
        # here — confirm this matches the transform coordinate convention.
        for x, y in locations:
            location_matrix[x, y] += 1

        heatmap_ax = axs[0, idx]
        im = heatmap_ax.matshow(location_matrix, cmap="YlGnBu", vmin=vmin, vmax=vmax)
        heatmap_ax.set_xticks([])
        heatmap_ax.set_yticks([])
        for i in range(location_matrix.shape[0]):
            for j in range(location_matrix.shape[1]):
                count = location_matrix[i, j]
                # matshow places column j on the x-axis and row i on the y-axis.
                heatmap_ax.text(
                    j,
                    i,
                    str(int(count)),
                    ha="center",
                    va="center",
                    fontsize=22,
                    color="black" if count < text_color_threshold else "white",
                )
        heatmap_ax.set_title(pretty_model_names[model], fontsize=16)

        hist_ax = axs[1, idx]
        hist_ax.hist(sizes, bins=20, edgecolor="black", range=(xmin, xmax))
        hist_ax.set_xlabel("Size", fontsize=16)
        hist_ax.set_xlim(xmin, xmax)

    axs[0, 0].set_ylabel("Location Heatmap", fontsize=16)
    axs[1, 0].set_ylabel("Size Frequency", fontsize=16)

    # A single colourbar next to the last heatmap is enough because every
    # heatmap uses the same vmin/vmax.
    divider = make_axes_locatable(axs[0, -1])
    cax = divider.append_axes("right", size="5%", pad=0.05)
    fig.colorbar(im, cax=cax, label="Frequency")
    fig.tight_layout(rect=[0, 0, 1, 0.96])
    return fig


# Creating and saving summary plots with numbers inside each cell of the heatmap for each dataset
for dataset in dataset_names:
    pdf_path = f"./pdfs/{dataset}_summary_numbers_plots.pdf"
    # The output directory is not created automatically; make sure it exists
    # so the export also works on a fresh checkout.
    os.makedirs(os.path.dirname(pdf_path), exist_ok=True)
    with PdfPages(pdf_path) as pdf:
        # All datasets share the same colour and size scales so the resulting
        # figures are directly comparable.
        fig = plot_dataset_summary_with_numbers(
            dataset,
            global_min_value,
            global_max_value,
            global_min_size,
            global_max_size,
        )
        pdf.savefig(fig, bbox_inches="tight", pad_inches=0)
        # Close each figure after saving so figures do not accumulate in memory.
        plt.close(fig)


# Share of the parsed transforms that sit at one specific location
def calculate_percentage_for_location(location_x, location_y):
    """Return, per dataset and model, the percentage of transforms at a location.

    Iterates the notebook globals ``dataset_names`` and ``model_names`` and
    looks each pair up in ``parsed_top_36_transforms``.

    Args:
        location_x: First coordinate of the location to count.
        location_y: Second coordinate of the location to count.

    Returns:
        A ``defaultdict(dict)`` mapping pretty dataset name to
        ``{pretty model name: percentage}``, with percentages in 0-100.
    """
    target = (location_x, location_y)
    result = defaultdict(dict)
    for dataset_key in dataset_names:
        for model_key in model_names:
            entries = parsed_top_36_transforms[dataset_key][model_key]
            matches = sum(1 for loc, _ in entries if (loc[0], loc[1]) == target)
            share = (matches / len(entries)) * 100
            result[pretty_dataset_names[dataset_key]][pretty_model_names[model_key]] = share
    return result


# Calculating the percentage for the location of interest.
# The result variables keep their "_1_1" suffix so other cells that refer to
# them keep working even if location_x / location_y are changed here.
location_x, location_y = 1, 1
percentage_for_location_1_1 = calculate_percentage_for_location(location_x, location_y)

# Calculating the overall average percentage across all datasets and classifiers;
# every (dataset, model) pair carries equal weight.
total_percentage = 0
total_count = 0
for dataset_percentages in percentage_for_location_1_1.values():
    for percentage in dataset_percentages.values():
        total_percentage += percentage
        total_count += 1

# Report the empty case explicitly instead of a bare ZeroDivisionError.
if total_count == 0:
    raise ValueError(
        "No percentages were computed; check dataset_names and model_names"
    )

overall_average_percentage_for_location_1_1 = total_percentage / total_count
# The label is built from location_x / location_y so it cannot drift from the
# location that was actually evaluated.
print(
    f"Overall average percentage for location ({location_x}, {location_y}): "
    f"{overall_average_percentage_for_location_1_1:.2f}%"
)

Overall average percentage for location (1, 1): 26.65%
