In [None]:
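'''
Share of the per-layer scores attributed to each crop size, per model.

For each model, load the stored per-label results and sum the scores of each
crop size at every layer. Convert the sums to percentages of the layer total,
average them over labels, and plot one curve per crop size.
'''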

import torch
import torch.nn as nn
import numpy as np
from tqdm import tqdm
from matplotlib import pyplot as plt
from matplotlib import cm

# Pick the first CUDA GPU when one is visible, otherwise fall back to the CPU.
# NOTE(review): not used by the cells shown here (checkpoints load onto the CPU).
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print(device)

import os

model_list = ['googlenet', 'resnet50', 'alexnet']
label_list = list(range(50))
directory = "./fidelity_of_neuron/full_16/"

# Human-readable model names (for plot titles / captions).
_DISPLAY_NAMES = {"googlenet": "GoogLeNet", "resnet50": "ResNet50", "alexnet": "AlexNet"}


def _abbreviate_layers(model_name, layer_names):
    """Return shortened layer names for use as x-axis tick labels.

    Raises:
        ValueError: if ``model_name`` has no abbreviation rule.
    """
    if model_name == "resnet50":
        return [layer.replace('layer', '') for layer in layer_names]
    if model_name == "googlenet":
        return [layer.replace('inception', '') for layer in layer_names]
    if model_name == "alexnet":
        return [layer.replace('features.', 'f').replace('classifier.', 'c') for layer in layer_names]
    raise ValueError(f"Unknown model name: {model_name}")


# Results for all models, keyed by model name, for the plotting cell.
all_percentage_scores = {}
all_layers = {}
all_abbreviated_layers = {}
all_display_names = {}

for model_name in model_list:
    # Validate before loading every checkpoint rather than after.
    if model_name not in _DISPLAY_NAMES:
        raise ValueError(f"Unknown model name: {model_name}")
    print(f"Processing model: {model_name}")
    # One entry per label ("run"): a list with one {crop size: summed score} dict per layer.
    crop_size_score_all_runs = []

    for label in tqdm(label_list):
        path = os.path.join(directory, f"store_{model_name}_label{label}_tau16.pth")
        # weights_only=False unpickles arbitrary Python objects: only load trusted files.
        store = torch.load(path, map_location="cpu", weights_only=False)
        FW = store["FW"]
        layers = store["layers"]
        num_layer = len(layers) - 1
        del store

        sizes = set(FW.labels_concept_data)
        crop_size_score = []
        # Score layers are read from FW.layers[1] .. FW.layers[num_layer].
        for layer in range(1, num_layer + 1):
            crop_size_score_layer = dict.fromkeys(sizes, 0)
            for cfm in FW.layers[layer].values():
                for size, score in cfm.label_scores.items():
                    crop_size_score_layer[size] += score
            crop_size_score.append(crop_size_score_layer)
        crop_size_score_all_runs.append(crop_size_score)

    if not crop_size_score_all_runs:
        raise ValueError(f"No results loaded for model {model_name}: label_list is empty")

    # Sum, over runs, each size's percentage share of the layer total. The size
    # keys and layer count come from the last loaded store.
    percentage_scores = {size: [0] * num_layer for size in sizes}
    for run_scores in crop_size_score_all_runs:
        for layer_idx, layer_scores in enumerate(run_scores):
            total_score = sum(layer_scores.values())
            for size, score in layer_scores.items():
                # A layer whose total score is zero contributes 0% to every size.
                percentage = (score / total_score) * 100 if total_score != 0 else 0
                percentage_scores[size][layer_idx] += percentage

    # Average the accumulated percentages across runs (labels).
    num_runs = len(crop_size_score_all_runs)
    percentage_scores = {
        size: [total / num_runs for total in per_layer]
        for size, per_layer in percentage_scores.items()
    }

    # Drop the last layer name so `layers` has num_layer entries, one per score
    # column. Slicing (instead of pop) avoids mutating the loaded list in place.
    layers = layers[:-1]
    name = _DISPLAY_NAMES[model_name]
    abbreviated_layers = _abbreviate_layers(model_name, layers)

    # Store data for plotting.
    all_percentage_scores[model_name] = percentage_scores
    all_layers[model_name] = layers
    all_abbreviated_layers[model_name] = abbreviated_layers
    all_display_names[model_name] = name

In [None]:
# Plot, for each model, the averaged percentage share of each crop size per layer.
markers = ['o', 's', 'D', '^', 'v', 'p', 'P', '*', 'X', 'd', 'H', 'h', '8', '<', '>', '1', '2', '3', '4', 'x', '|', '_', '']
display_names = {"googlenet": "GoogLeNet", "resnet50": "ResNet50", "alexnet": "AlexNet"}
# Size the colormap for the model with the most crop sizes, plus one spare
# entry: index 0 is never used (curves use colors(i + 1)).
num_sizes = max(len(scores) for scores in all_percentage_scores.values())
# cm.get_cmap was deprecated in Matplotlib 3.7 and removed in 3.9; pyplot.get_cmap remains.
colors = plt.get_cmap('plasma_r', num_sizes + 1)
available = ['default'] + plt.style.available

# Plotting
# NOTE(review): available[17] depends on the installed Matplotlib's style list;
# consider naming the intended style explicitly.
with plt.style.context(available[17]):
    fig, axes = plt.subplots(1, 3, figsize=(24, 6), sharey=True, gridspec_kw={'width_ratios': [0.8, 1, 1.2]})
    for idx, model_name in enumerate(model_list):
        ax = axes[idx]
        percentage_scores = all_percentage_scores[model_name]
        abbreviated_layers = all_abbreviated_layers[model_name]
        num_layer = len(abbreviated_layers)

        for i, size in enumerate(percentage_scores):
            # NOTE(review): scores are plotted in reverse order relative to the
            # layer names; presumably FW.layers is indexed from the output side — confirm.
            ax.plot(
                abbreviated_layers,
                percentage_scores[size][::-1],
                marker=markers[i % len(markers)],  # cycle rather than IndexError
                label=f'Size {size}',
                color=colors(i + 1),
                markersize=15,
                linewidth=5,
            )
        ax.tick_params(axis='y', labelsize=25)
        ax.set_title(display_names.get(model_name, model_name.capitalize()), fontsize=30)
        # One tick per layer category (positions 0 .. num_layer-1); an extra
        # position would add an empty tick and widen the axis.
        ax.set_xticks(range(num_layer))
        if idx == 0:
            ax.set_ylabel('Percentage (%)', fontsize=28)
        ax.tick_params(axis='x', labelsize=25)
    handles, labels = axes[0].get_legend_handles_labels()
    fig.legend(handles, labels, loc='upper center', ncol=len(labels), fontsize=25)
    fig.text(0.5, -0.04, 'Layers', ha='center', fontsize=25)

    plt.tight_layout(rect=[0, 0, 1, 0.9])  # type: ignore
    # Save before show(): show() may close the figure and leave savefig a blank
    # canvas. bbox_inches='tight' keeps the 'Layers' label (below y=0) and the legend.
    fig.savefig('Percentage_scores_all_models.png', bbox_inches='tight')
    plt.show()