In [None]:
import torch
import matplotlib.pyplot as plt
import numpy as np
import matplotlib.cm as cm 
from scipy.stats import sem, t

# config
# Layers whose stored results are loaded, listed from layer4.2 down to layer1.0.
resnet_layers = [
    "layer4.2", "layer4.1", "layer4.0",
    "layer3.5", "layer3.4", "layer3.3", "layer3.2", "layer3.1", "layer3.0",
    "layer2.3", "layer2.2", "layer2.1", "layer2.0",
    "layer1.2", "layer1.1", "layer1.0",
]
# Same idea for GoogLeNet, listed from inception5b down to inception3a.
googlenet_layers = [
    "inception5b", "inception5a",
    "inception4e", "inception4d", "inception4c", "inception4b", "inception4a",
    "inception3b", "inception3a",
]

# Tick labels for the x axis: the layer lists with every second entry
# (odd index) replaced by an empty string.
plot_resnet_layers = [
    name if index % 2 == 0 else ""
    for index, name in enumerate(resnet_layers)
]
plot_googlenet_layers = [
    name if index % 2 == 0 else ""
    for index, name in enumerate(googlenet_layers)
]

# Alternative (wider) settings kept for reference:
# labels=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
# taus = [10, 30, 50]
labels = [0]
taus = [10]

# Upper y-axis limit of each subplot.
max_y_lim_resnet = 0.98
max_y_lim_googlenet = 0.67

# Two-sided confidence level used for the error bars.
confidence_level = 0.99

def _load_multiple_labels(model_name, labels, layer, tau):
    """Load the stored results of every label and pool them.

    For each label the file
    ``./minimization/store_{model_name}_{layer}_{label}_tau_{tau}.pth`` is
    read with ``torch.load`` and unpacked into five entries. The first three
    are concatenated across labels; the last two are mappings whose values
    are collected (in mapping order) across labels.

    Returns:
        A five-element list
        ``[mean_acc, mean_below_ori, mean_min_acc, all_acc, all_ori_acc]``
        of plain Python lists.
    """
    pooled_mean, pooled_below, pooled_min = [], [], []
    pooled_all, pooled_ori = [], []

    for label in labels:
        store_path = "./minimization/" + f"store_{model_name}_{layer}_{label}_tau_{tau}.pth"
        # weights_only=False unpickles arbitrary objects: only point this at
        # files you produced yourself.
        stored = torch.load(store_path, weights_only=False)
        mean_part, below_part, min_part, acc_by_key, ori_by_key = stored

        pooled_mean.extend(mean_part)
        pooled_below.extend(below_part)
        pooled_min.extend(min_part)

        # Only the values are kept; the keys are dropped.
        pooled_all.extend(acc_by_key.values())
        pooled_ori.extend(ori_by_key.values())

    return [pooled_mean, pooled_below, pooled_min, pooled_all, pooled_ori]
    

# Maps each tau to a one-element list holding int(1 * tau).
# NOTE(review): not referenced in the visible part of this file - confirm it
# is still needed.
list_num_node_per_comb = {
    tau: [int(1 * tau)]
    for tau in taus
}

def load_runs(model_name, load_layers, taus):
    """Load the pooled results of every layer for every tau.

    The label ids are taken from the module-level ``labels``.

    Returns:
        A dict mapping each tau to a list with one entry per layer (in
        ``load_layers`` order); each entry is the five-element list returned
        by ``_load_multiple_labels``.
    """
    runs_by_tau = {}
    for tau in taus:
        per_layer = []
        for layer in load_layers:
            per_layer.append(_load_multiple_labels(model_name, labels, layer, tau))
        runs_by_tau[tau] = per_layer
    return runs_by_tau
    
# Load the stored results of both models (reads files under ./minimization/).
list_run_resnet = load_runs("resnet50", resnet_layers, taus)
list_run_googlenet = load_runs("googlenet", googlenet_layers, taus)

# matplotlib.cm.get_cmap was deprecated in Matplotlib 3.7 and removed in 3.9;
# pyplot.get_cmap accepts the same (name, lut) arguments and works on both
# old and new releases.
# plot_model indexes these colormaps with ``tau_index + 1``, so entry 0 is
# never used for a bar.
cmap_upper = plt.get_cmap('viridis_r', len(taus) + 1)
cmap_lower = plt.get_cmap('magma_r', len(taus) * 2 + 1)

# Plotting
# Two side-by-side axes; the left one is 1.4x as wide as the right one.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 8), gridspec_kw={'width_ratios': [1.4, 1]})

def _ci_half_width(all_acc, all_ori_acc, confidence):
    """Return the half-width of the two-sided t-interval of mean(acc - ori).

    ``all_acc[i]`` is an iterable of values that are each compared against
    the single reference value ``all_ori_acc[i]``. Returns NaN when fewer
    than two differences are available (no interval can be estimated).
    """
    diffs = []
    for i, accs in enumerate(all_acc):
        ori = all_ori_acc[i]
        diffs.extend(acc - ori for acc in accs)
    if len(diffs) < 2:
        return float("nan")
    t_critical = t.ppf((1 + confidence) / 2, len(diffs) - 1)
    return sem(diffs) * t_critical


def plot_model(ax, load_layers, list_run, max_y_lim, subtitle, plot_layers=None):
    """Draw the grouped bar chart of one model on ``ax``.

    For every tau in the module-level ``taus`` two bars per layer are drawn
    at the same position: first the mean of entry 1 of the layer's record
    (legend "Negative"), then, on top of it, the mean of entry 0 (legend
    "Total") with an error bar. The error bar is the half-width of the
    two-sided ``confidence_level`` Student-t interval of the mean of
    ``acc - ori``, taken over all values of entry 3 paired by index with
    entry 4.

    Also reads the module-level ``cmap_upper``, ``cmap_lower`` and
    ``confidence_level``.

    Args:
        ax: Matplotlib axes to draw on.
        load_layers: Layer names; their count fixes the number of x positions.
        list_run: Mapping tau -> list of per-layer records (as built by
            ``load_runs``), expected to be in ``load_layers`` order.
        max_y_lim: Upper y-axis limit (the lower limit is fixed at -0.12).
        subtitle: Axes title; "ResNet50" and "GoogLeNet" additionally select
            a background colour.
        plot_layers: Optional tick labels used instead of ``load_layers``
            (empty strings leave a tick unlabelled).
    """
    x = np.arange(len(load_layers))
    bar_width = 0.6
    tick_names = plot_layers if plot_layers else load_layers

    if subtitle == "ResNet50":
        ax.set_facecolor('#f5edf3')
    if subtitle == "GoogLeNet":
        ax.set_facecolor('#e7f1e6')

    n_taus = len(taus)
    slot_width = bar_width / n_taus
    for tau_index, tau in enumerate(taus):
        scores = list_run[tau]

        list_mean_acc = [np.mean(score[0]) for score in scores]
        list_mean_below_ori = [np.mean(score[1]) for score in scores]

        # Confidence-interval half-width of the mean difference, per layer.
        list_ci_acc = [
            _ci_half_width(score[3], score[4], confidence_level)
            for score in scores
        ]

        # Bars are centre-aligned, so the group of n_taus bars is centred on
        # the tick when the offsets run from -(n-1)/2 to +(n-1)/2 slots.
        positions = x + (tau_index - (n_taus - 1) / 2) * slot_width

        ax.bar(positions, list_mean_below_ori, slot_width,
               label=f'$\\tau = {tau}$ (Negative)', color=cmap_lower(tau_index + 1), alpha=0.7)

        ax.bar(positions, list_mean_acc, slot_width,
               label=f'$\\tau = {tau}$ (Total)', color=cmap_upper(tau_index + 1),
               yerr=list_ci_acc, capsize=5, ecolor='red', error_kw={'linewidth': 2, 'capthick': 3})

    # Strip the common prefix so tick labels read e.g. "4.2" or "5b".
    abbreviated_layers = [
        layer.replace('layer', '') if 'layer' in layer else layer.replace('inception', '')
        for layer in tick_names
    ]
    ax.set_xticks(x)
    ax.set_xticklabels(abbreviated_layers, fontsize=28)
    ax.tick_params(axis='y', labelsize=28)
    ax.set_ylim(-0.12, max_y_lim)
    ax.grid(True, axis='y', linestyle='--', alpha=0.8)
    ax.axhline(y=0, color='r', linestyle='--', label='Baseline', alpha=0.6, linewidth=3)
    ax.set_title(subtitle, fontsize=38)
    ax.set_xlabel('Layers', fontsize=28)

plot_model(ax1, resnet_layers, list_run_resnet, max_y_lim_resnet, "ResNet50", plot_resnet_layers)
plot_model(ax2, googlenet_layers, list_run_googlenet, max_y_lim_googlenet, "GoogLeNet", plot_googlenet_layers)

# The legend entries are taken from the left axes only and shared by the
# whole figure. They are bound to ``legend_labels`` rather than ``labels``
# so the module-level ``labels`` config (read by load_runs) is not
# overwritten when this code runs.
handles, legend_labels = ax1.get_legend_handles_labels()
fig.legend(handles, legend_labels, loc='lower center', ncol=4, fontsize=28, bbox_to_anchor=(0.5, -0.11))

ax1.set_ylabel('Loss difference', fontsize=38)
plt.tight_layout(rect=[0, 0.06, 1, 0.95])
plt.show()
