In [4]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import json

# Work on a private copy: on Matplotlib releases where `get_cmap` hands back the
# globally registered instance, calling `set_under` on it directly would change
# 'Reds_r' for every other plot in the session.
cmap = plt.get_cmap('Reds_r').copy()
# Values below the lower limit of the colour scale are drawn in cyan.
cmap.set_under(color='cyan')

def update_dict_(d, u):
    """Merge ``u`` into ``d`` in place and return ``d``.

    ``d`` itself is modified (via ``d.update(u)``); the returned object is the
    same ``d`` that was passed in, which lets the call be used in an expression.
    """
    d.update(u)
    return d

def layer_renamer(s):
    """Turn a raw layer name into a short display label.

    Every "model." substring is stripped first. Then:

    * "fc" becomes "[Head] fc";
    * any other name that does not contain "layer" becomes "[0.0] conv:7x7";
    * a name such as "layer3.0.conv2" becomes "[3.0] conv2:3x3", i.e.
      "[<first part without 'layer'>.<second part>] <operator>".

    The operator is the last dot-separated part, except that names ending in
    "downsample.0" / "downsample.1" are shown as "downsample" /
    "downsample bn". "conv1", "conv2" and "conv3" get a ":1x1", ":3x3" and
    ":1x1" suffix respectively.
    """
    name = s.replace("model.", "")

    if name == "fc":
        return "[Head] fc"
    if "layer" not in name:
        return "[0.0] conv:7x7"

    parts = name.split(".")

    if name.endswith("downsample.1"):
        op = "downsample bn"
    elif name.endswith("downsample.0"):
        op = "downsample"
    else:
        op = parts[-1]

    # Append the kernel size to the convolution names.
    for conv, kernel in (("conv1", "1x1"), ("conv2", "3x3"), ("conv3", "1x1")):
        op = op.replace(conv, f"{conv}:{kernel}")

    stage = parts[0].replace("layer", "")
    return f"[{stage}.{parts[1]}] {op}"

# Display names for the model ids, built once at import time instead of on
# every `model_renamer` call (the function is applied to each row of the
# results frame).
# NOTE(review): the eps=0 entry omits the "L2, " prefix that the other L2
# entries carry — presumably intentional; confirm.
_MODEL_DISPLAY_NAMES = {
    'robust_resnet50_l2_eps0': r'PGD Adversarial Training ($\epsilon$=0)',
    'robust_resnet50_l2_eps0.01': r'PGD Adversarial Training (L2, $\epsilon$=0.01)',
    'robust_resnet50_l2_eps0.03': r'PGD Adversarial Training (L2, $\epsilon$=0.03)',
    'robust_resnet50_l2_eps0.05': r'PGD Adversarial Training (L2, $\epsilon$=0.05)',
    'robust_resnet50_l2_eps0.1': r'PGD Adversarial Training (L2, $\epsilon$=0.1)',
    'robust_resnet50_l2_eps0.25': r'PGD Adversarial Training (L2, $\epsilon$=0.25)',
    'robust_resnet50_l2_eps0.5': r'PGD Adversarial Training (L2, $\epsilon$=0.5)',
    'robust_resnet50_l2_eps1': r'PGD Adversarial Training (L2, $\epsilon$=1)',
    'robust_resnet50_l2_eps3': r'PGD Adversarial Training (L2, $\epsilon$=3)',
    'robust_resnet50_l2_eps5': r'PGD Adversarial Training (L2, $\epsilon$=5)',
    'robust_resnet50_linf_eps0.5': r'PGD Adversarial Training (Linf, $\epsilon$=0.5)',
    'robust_resnet50_linf_eps1.0': r'PGD Adversarial Training (Linf, $\epsilon$=1.0)',
    'robust_resnet50_linf_eps2.0': r'PGD Adversarial Training (Linf, $\epsilon$=2.0)',
    'robust_resnet50_linf_eps4.0': r'PGD Adversarial Training (Linf, $\epsilon$=4.0)',
    'robust_resnet50_linf_eps8.0': r'PGD Adversarial Training (Linf, $\epsilon$=8.0)',
    'resnet50_augmix_180ep': 'AugMix (180Ep)',
    'resnet50.autoaugment_270ep': 'AutoAugment (270Ep)',
    'resnet50_deepaugment': 'DeepAugment',
    'resnet50_deepaugment_augmix': 'DeepAugment+AugMix',
    'resnet50_diffusionnoise_fixed_nonoise': 'Diffusion-like Noise',
    'resnet50.fastautoaugment_270ep': 'FastAutoAugment (270Ep)',
    'resnet50_noisymix': 'NoisyMix',
    'resnet50_opticsaugment': 'OpticsAugment',
    'resnet50_pixmix_180ep': 'PixMix (180Ep)',
    'resnet50_pixmix_90ep': 'PixMix (90Ep)',
    'resnet50_prime': 'PRIME',
    'resnet50.randaugment_270ep': 'RandAugment (270Ep)',
    'resnet50_tsbias_debiased': 'Texture/Shape-debiased',
    'resnet50_tsbias_sbias': 'Texture/Shape Shape-biased',
    'resnet50_tsbias_tbias': 'Texture/Shape Texture-biased',
    'tv_resnet50': 'Baseline',
    'resnet50_dino': 'DINOv1',
    'resnet50_moco_v3_1000ep': 'MoCo v3 (1000Ep)',
    'resnet50_moco_v3_100ep': 'MoCo v3 (100Ep)',
    'resnet50_moco_v3_300ep': 'MoCo v3 (300Ep)',
    'resnet50_simclrv2': 'SimCLRv2',
    'resnet50_swav': 'SwAV',
    'resnet50_trained_on_SIN': 'ShapeNet (SIN)',
    'resnet50_trained_on_SIN_and_IN': 'ShapeNet (SIN+IN)',
    'resnet50_trained_on_SIN_and_IN_then_finetuned_on_IN': r'ShapeNet (SIN+IN$\rightarrow$IN)',
    'resnet50.a1_in1k': 'timm A1',
    'resnet50.a1h_in1k': 'timm A1h',
    'resnet50.a2_in1k': 'timm A2',
    'resnet50.a3_in1k': 'timm A3',
    'resnet50.b1k_in1k': 'timm B1k',
    'resnet50.b2k_in1k': 'timm B2k',
    'resnet50.c1_in1k': 'timm C1',
    'resnet50.c2_in1k': 'timm C2',
    'resnet50.d_in1k': 'timm D',
    'tv2_resnet50': 'TorchVision 2',
}


def model_renamer(s):
    """Return the display name for model id ``s``.

    Ids without an entry in the lookup table are returned unchanged. Some
    display names contain ``$...$`` math markup (e.g. ``$\\epsilon$``).
    """
    return _MODEL_DISPLAY_NAMES.get(s, s)


# Accuracy per model id, in percent (the plots below print these values with a
# "%" suffix and use them on the "ImageNet-1k val. accuracy [%]" axis).
model_accuracy = {
    'tv_resnet50': 76.15,
    # robust_resnet50_l2_*
    'robust_resnet50_l2_eps0': 75.81,
    'robust_resnet50_l2_eps0.01': 75.67,
    'robust_resnet50_l2_eps0.03': 75.77,
    'robust_resnet50_l2_eps0.05': 75.58,
    'robust_resnet50_l2_eps0.1': 74.79,
    'robust_resnet50_l2_eps0.25': 74.14,
    'robust_resnet50_l2_eps0.5': 73.17,
    'robust_resnet50_l2_eps1': 70.42,
    'robust_resnet50_l2_eps3': 62.83,
    'robust_resnet50_l2_eps5': 56.14,
    # robust_resnet50_linf_*
    'robust_resnet50_linf_eps0.5': 73.74,
    'robust_resnet50_linf_eps1.0': 72.04,
    'robust_resnet50_linf_eps2.0': 69.09,
    'robust_resnet50_linf_eps4.0': 63.87,
    'robust_resnet50_linf_eps8.0': 54.53,
    # augmentation-style models
    'resnet50.autoaugment_270ep': 77.50,
    'resnet50.fastautoaugment_270ep': 77.65,
    'resnet50.randaugment_270ep': 77.64,
    'resnet50_augmix_180ep': 77.53,
    'resnet50_deepaugment': 77.65,
    'resnet50_deepaugment_augmix': 75.80,
    'resnet50_diffusionnoise_fixed_nonoise': 67.22,
    'resnet50_noisymix': 77.05,
    'resnet50_opticsaugment': 74.22,
    'resnet50_prime': 76.91,
    'resnet50_pixmix_180ep': 78.09,
    'resnet50_pixmix_90ep': 77.36,
    # resnet50_trained_on_SIN*
    'resnet50_trained_on_SIN': 60.18,
    'resnet50_trained_on_SIN_and_IN': 74.59,
    'resnet50_trained_on_SIN_and_IN_then_finetuned_on_IN': 76.72,
    # resnet50_tsbias_*
    'resnet50_tsbias_debiased': 76.89,
    'resnet50_tsbias_sbias': 76.21,
    'resnet50_tsbias_tbias': 75.27,
    # DINO / MoCo v3 / SimCLRv2 / SwAV
    'resnet50_dino': 75.28,
    'resnet50_moco_v3_1000ep': 74.60,
    'resnet50_moco_v3_100ep': 68.91,
    'resnet50_moco_v3_300ep': 72.80,
    'resnet50_simclrv2': 74.90,
    'resnet50_swav': 75.31,
    # resnet50.*_in1k and tv2
    'resnet50.a1_in1k': 80.10,
    'resnet50.a1h_in1k': 80.10,
    'resnet50.a2_in1k': 79.80,
    'resnet50.a3_in1k': 77.55,
    'resnet50.b1k_in1k': 79.16,
    'resnet50.b2k_in1k': 79.27,
    'resnet50.c1_in1k': 79.76,
    'resnet50.c2_in1k': 79.92,
    'resnet50.d_in1k': 79.89,
    'tv2_resnet50': 80.34,
}

In [5]:
# Load the raw results, drop the 'resnet50_frozen_random' model and reduce
# `pred_cos_sim` to its mean and standard deviation per (model_id, layer).
df = (
    pd.read_csv('../raw-results/results.csv')
    .set_index(['model_id', 'layer'])
    .query("model_id != 'resnet50_frozen_random'")
    .groupby(['model_id', 'layer'])
    .agg({"pred_cos_sim": ("mean", "std")})
    .reset_index()
)
df.columns = ["model_id", "layer", "mean", "std"]

# The JSON maps a category name to the list of model ids belonging to it.
with open("r50_categories.json", "r") as f:
    r50_categories = json.load(f)

# Flat list of every categorised model id, plus the inverse lookup
# model id -> category name.
categorized_resnets = [
    model for members in r50_categories.values() for model in members
]
r50_category_inv_index = {
    model: category
    for category, members in r50_categories.items()
    for model in members
}

df["model"] = df["model_id"].apply(model_renamer)
df["layer"] = df["layer"].apply(layer_renamer)
# Models that appear in no category get None.
df["category"] = df["model_id"].apply(lambda model: r50_category_inv_index.get(model, None))

# Models without a known accuracy get the sentinel -10000000.
df["accuracy"] = df["model_id"].apply(lambda model: model_accuracy.get(model, -10000000))

df = df.sort_values(["category", "model_id"])

# Sanity check: the analysis below expects exactly 50 distinct models.
assert len(df["model_id"].unique()) == 50

In [None]:
plt.figure(figsize=(17, 10))

# One row per (category, model_id), one column per layer; each cell holds the
# "mean" column of `df` for that pair.
df_t = df.pivot(index=["category", "model_id"], columns="layer", values="mean")
# Reorder the columns by label. Assigning `sorted(...)` to `.columns` would only
# rename the headers in place and mislabel the data if they were ever unsorted.
df_t = df_t[sorted(df_t.columns)]

# make baseline the first row
# `tv_resnet50` is the id that `model_renamer` labels "Baseline". It is looked up
# by id rather than by a fixed row position so the reordering does not silently
# pick another model when the model list changes.
c = next(key for key in df_t.index if key[1] == "tv_resnet50")
df_t = df_t.loc[[c] + [x for x in df_t.index if x != c]]

# Plotted quantity, in percent: 100 * (1 - mean `pred_cos_sim`).
mat = (1 - df_t.values) * 100

sns.heatmap(mat, cbar_kws={'label': 'Layer Criticality [%]', 'pad': 0.07}, xticklabels=True, yticklabels=True, vmin=0, vmax=100, cmap="viridis_r")

# Layer names go on top; the bottom edge is used for the per-layer summary below.
plt.tick_params(axis='both', which='major', labelbottom=False, bottom=False, top=True, labeltop=True)

# Tick positions at the centre of each heatmap cell.
layer_ticks = np.arange(0.5, len(df_t.columns))
model_ticks = np.arange(0.5, len(df_t.index))

plt.xticks(layer_ticks, df_t.columns, rotation=90)
# only use second level for y labels
plt.yticks(model_ticks, [f"{model_renamer(key[1])} [{model_accuracy[key[1]]:.2f}%]" for key in df_t.index])

ax = plt.gca()


def add_label_patch(ax, xy, w, h, s, c, **kwargs):
    """Draw a coloured rectangle on ``ax`` with the text ``s`` centred inside it.

    ``xy`` is the rectangle's anchor corner, ``w`` / ``h`` its width and height
    (all in data coordinates) and ``c`` the face colour, which is applied with
    alpha 0.6. The text is rotated by 90 degrees. Extra keyword arguments are
    forwarded to ``plt.Rectangle``.
    """
    from matplotlib.colors import to_rgba

    rect = plt.Rectangle(xy, w, h, facecolor=to_rgba(c, 0.6), edgecolor=to_rgba('black', 0.6), **kwargs)
    ax.add_patch(rect)
    ax.text(xy[0] + w/2, xy[1] + h/2, s, rotation=90, color='black', va='center', ha='center')


def _mean_std_labels(values, axis):
    """Return "(mean±std)" tick labels for ``values`` reduced along ``axis``."""
    # Raw f-string: "\p" is not a valid escape sequence in a normal string literal.
    return [rf"({m:.0f}$\pm${s:.0f})" for m, s in zip(values.mean(axis=axis), values.std(axis=axis))]


# Category bars left of the y tick labels (drawn outside the axes, hence
# clip_on=False). The row spans are hard-coded: 1 baseline row, then 15, 18, 6
# and 10 rows.
x_off = -21

add_label_patch(ax, (x_off, 0), 2, 1, "", 'k', clip_on=False, linewidth = 1, alpha=1)
add_label_patch(ax, (x_off, 1), 2, 15, "Adversarial Training", 'C0', clip_on=False, linewidth = 1)
add_label_patch(ax, (x_off, 16), 2, 18, "Augmentations", 'C1',clip_on=False, linewidth = 1)
add_label_patch(ax, (x_off, 34), 2, 6, "SSL", 'C2', clip_on=False, linewidth = 1)
add_label_patch(ax, (x_off, 40), 2, 10, "Improved Training", 'C3', clip_on=False, linewidth = 1)

# Bottom edge: mean and std of each column (layer) of `mat`.
bottom_topax = ax.secondary_xaxis('bottom')
bottom_topax.set_xticks(layer_ticks, _mean_std_labels(mat, axis=0), rotation=90, c="gray", fontsize=7)
bottom_topax.tick_params(axis='x', which='major', labelbottom=True, bottom=False, top=False, labeltop=False)

# Right edge: mean and std of each row (model) of `mat`.
right_ax = ax.secondary_yaxis('right')
right_ax.set_yticks(model_ticks, _mean_std_labels(mat, axis=1), c="gray", fontsize=7)
right_ax.tick_params(axis='y', which='major', labelright=True, right=False, left=False, labelleft=False)


plt.ylabel(None)
plt.xlabel(None)

plt.savefig("weight_reset_r50_training.pdf", bbox_inches="tight")

In [None]:
from scipy.stats import spearmanr, kendalltau


def _criticality_and_accuracy(frame):
    """Collapse ``frame`` to one row per (category, model_id).

    "mean" becomes 100 * (1 - average of the per-layer means) and "accuracy"
    is averaged over the model's rows.
    """
    grouped = frame.groupby(["category", "model_id"])
    summary = grouped.agg({"mean": lambda x: 100* (1 - x.mean()), "accuracy": "mean"})
    return summary.reset_index()


# All models, and the same summary without the adversarially trained ones.
acc_series = _criticality_and_accuracy(df)
acc_series_wo_adv = _criticality_and_accuracy(df.query("category != 'adversarial_training'"))

# Rank correlations between the per-model "mean" and "accuracy" columns,
# first for all models and then without adversarial training.
for summary in (acc_series, acc_series_wo_adv):
    print(spearmanr(summary["mean"], summary["accuracy"]))
    print(kendalltau(summary["mean"], summary["accuracy"]))

In [None]:

plt.figure(figsize=(6, 6/1.6))

# Category keys in plotting order, with the colour used for each of them.
category_order = ["baseline", "adversarial_training", "augmentation", "contrastive", "training_recipes"]
category_palette = ["black", "C0", "C1", "C2", "C3"]

scatter = sns.scatterplot(
    data=acc_series,
    x="mean",
    y="accuracy",
    hue="category",
    hue_order=category_order,
    palette=category_palette,
)

# Rebuild the legend and replace the raw category keys by readable names; the
# entries are addressed by position, in the same order as `category_order`.
legend = plt.legend()
for position, text in enumerate(["Baseline", "Adversarial Training", "Augmentations", "SSL", "Improved Training"]):
    legend.texts[position].set_text(text)
legend.set_title("Category")

plt.xlabel("Avg. Layer Criticality [%]")
plt.ylabel("ImageNet-1k val. accuracy [%]")

plt.savefig("weight_reset_r50_acc_vs_crit.pdf", bbox_inches="tight")

In [None]:
# Average criticality of the adversarially trained models as a function of the
# epsilon and norm encoded in their model id.
df_adv = df.query("category == 'adversarial_training'").copy()
# The id has the form "robust_resnet50_<norm>_eps<value>": stripping the prefix
# leaves the epsilon value, and the second-to-last "_" field is the norm
# ("l2" -> "L2", "linf" -> "Linf").
df_adv["eps"] = df_adv["model_id"].apply(lambda x: float(x.replace("robust_resnet50_linf_eps", "").replace("robust_resnet50_l2_eps", "")))
df_adv['Norm'] = df_adv["model_id"].apply(lambda x: x.split("_")[-2].capitalize())

# Average the per-layer means for each (Norm, eps) and convert to criticality
# in percent: 100 * (1 - mean).
df_adv = (1 - df_adv.groupby(["Norm", "eps"])["mean"].mean().to_frame("avg_crit")) * 100
df_adv = df_adv.reset_index()


plt.figure(figsize=(6, 6/1.6))
sns.scatterplot(data=df_adv, x="eps", y="avg_crit", hue="Norm", palette="viridis", s=100)

# Raw string: "\e" is not a valid escape sequence in a normal string literal.
plt.xlabel(r"Training Attack Budget ($\epsilon$)")
plt.ylabel("Avg. Criticality [%]")


plt.savefig("weight_reset_adv_training.pdf", bbox_inches="tight")

# Last expression of the cell: shows the aggregated table in the notebook.
df_adv

In [None]:
plt.figure(figsize=(17, 10))

# One row per (category, model_id), one column per layer; each cell holds the
# "std" column of `df` for that pair.
df_t = df.pivot(index=["category", "model_id"], columns="layer", values="std")
# Reorder the columns by label. Assigning `sorted(...)` to `.columns` would only
# rename the headers in place and mislabel the data if they were ever unsorted.
df_t = df_t[sorted(df_t.columns)]

# make baseline the first row
# `tv_resnet50` is the id that `model_renamer` labels "Baseline". It is looked up
# by id rather than by a fixed row position so the reordering does not silently
# pick another model when the model list changes.
c = next(key for key in df_t.index if key[1] == "tv_resnet50")
df_t = df_t.loc[[c] + [x for x in df_t.index if x != c]]

# Plotted quantity, in percent.
mat = (df_t.values) * 100

# NOTE(review): the colour bar says "Std. Err." but the plotted values come from
# the "std" column — confirm which of the two is intended.
sns.heatmap(mat, cbar_kws={'label': 'Layer Criticality Std. Err. [%]', 'pad': 0.07}, xticklabels=True, yticklabels=True, vmin=0, vmax=100, cmap="Reds")

# Layer names go on top; the bottom edge is used for the per-layer summary below.
plt.tick_params(axis='both', which='major', labelbottom=False, bottom=False, top=True, labeltop=True)

# Tick positions at the centre of each heatmap cell.
layer_ticks = np.arange(0.5, len(df_t.columns))
model_ticks = np.arange(0.5, len(df_t.index))

plt.xticks(layer_ticks, df_t.columns, rotation=90)
# only use second level for y labels
plt.yticks(model_ticks, [model_renamer(key[1]) for key in df_t.index])

ax = plt.gca()


def add_label_patch(ax, xy, w, h, s, c, **kwargs):
    """Draw a coloured rectangle on ``ax`` with the text ``s`` centred inside it.

    ``xy`` is the rectangle's anchor corner, ``w`` / ``h`` its width and height
    (all in data coordinates) and ``c`` the face colour, which is applied with
    alpha 0.6. The text is rotated by 90 degrees. Extra keyword arguments are
    forwarded to ``plt.Rectangle``.
    """
    from matplotlib.colors import to_rgba

    rect = plt.Rectangle(xy, w, h, facecolor=to_rgba(c, 0.6), edgecolor=to_rgba('black', 0.6), **kwargs)
    ax.add_patch(rect)
    ax.text(xy[0] + w/2, xy[1] + h/2, s, rotation=90, color='black', va='center', ha='center')


def _mean_std_labels(values, axis):
    """Return "(mean±std)" tick labels for ``values`` reduced along ``axis``."""
    # Raw f-string: "\p" is not a valid escape sequence in a normal string literal.
    return [rf"({m:.0f}$\pm${s:.0f})" for m, s in zip(values.mean(axis=axis), values.std(axis=axis))]


# Category bars left of the y tick labels (drawn outside the axes, hence
# clip_on=False). The row spans are hard-coded: 1 baseline row, then 15, 18, 6
# and 10 rows.
x_off = -17

add_label_patch(ax, (x_off, 0), 2, 1, "", 'k', clip_on=False, linewidth = 1, alpha=1)
add_label_patch(ax, (x_off, 1), 2, 15, "Adversarial Training", 'C0', clip_on=False, linewidth = 1)
add_label_patch(ax, (x_off, 16), 2, 18, "Augmentations", 'C1',clip_on=False, linewidth = 1)
add_label_patch(ax, (x_off, 34), 2, 6, "SSL", 'C2', clip_on=False, linewidth = 1)
add_label_patch(ax, (x_off, 40), 2, 10, "Improved Training", 'C3', clip_on=False, linewidth = 1)

# Bottom edge: mean and std of each column (layer) of `mat`.
bottom_topax = ax.secondary_xaxis('bottom')
bottom_topax.set_xticks(layer_ticks, _mean_std_labels(mat, axis=0), rotation=90, c="gray", fontsize=7)
bottom_topax.tick_params(axis='x', which='major', labelbottom=True, bottom=False, top=False, labeltop=False)

# Right edge: mean and std of each row (model) of `mat`.
right_ax = ax.secondary_yaxis('right')
right_ax.set_yticks(model_ticks, _mean_std_labels(mat, axis=1), c="gray", fontsize=7)
right_ax.tick_params(axis='y', which='major', labelright=True, right=False, left=False, labelleft=False)


plt.ylabel(None)
plt.xlabel(None)

plt.savefig("weight_reset_r50_training_std.pdf", bbox_inches="tight")

# Last expression of the cell: shows the overall mean and maximum in the notebook.
mat.mean(), mat.max()

In [None]:
plt.figure(figsize=(17, 10))

# One row per (category, model_id), one column per layer; each cell holds the
# "mean" column of `df` for that pair.
df_t = df.pivot(index=["category", "model_id"], columns="layer", values="mean")
# Reorder the columns by label. Assigning `sorted(...)` to `.columns` would only
# rename the headers in place and mislabel the data if they were ever unsorted.
df_t = df_t[sorted(df_t.columns)]

# make baseline the first row
# `tv_resnet50` is the id that `model_renamer` labels "Baseline". It is looked up
# by id rather than by a fixed row position, because the subtraction of `mat[0]`
# below relies on row 0 really being the baseline.
c = next(key for key in df_t.index if key[1] == "tv_resnet50")
df_t = df_t.loc[[c] + [x for x in df_t.index if x != c]]

# Criticality in percent: 100 * (1 - mean `pred_cos_sim`) ...
mat = (1 - df_t.values) * 100

# ... expressed relative to the baseline row, so the first row becomes all zeros.
mat = mat - mat[0]

sns.heatmap(mat, cbar_kws={'label': 'Layer Criticality compared to Baseline [%]', 'pad': 0.07}, xticklabels=True, yticklabels=True, vmin=-100, vmax=100, cmap="coolwarm")

# Layer names go on top; the bottom edge is used for the per-layer summary below.
plt.tick_params(axis='both', which='major', labelbottom=False, bottom=False, top=True, labeltop=True)

# Tick positions at the centre of each heatmap cell.
layer_ticks = np.arange(0.5, len(df_t.columns))
model_ticks = np.arange(0.5, len(df_t.index))

plt.xticks(layer_ticks, df_t.columns, rotation=90)
# only use second level for y labels
plt.yticks(model_ticks, [model_renamer(key[1]) for key in df_t.index])

ax = plt.gca()


def add_label_patch(ax, xy, w, h, s, c, **kwargs):
    """Draw a coloured rectangle on ``ax`` with the text ``s`` centred inside it.

    ``xy`` is the rectangle's anchor corner, ``w`` / ``h`` its width and height
    (all in data coordinates) and ``c`` the face colour, which is applied with
    alpha 0.6. The text is rotated by 90 degrees. Extra keyword arguments are
    forwarded to ``plt.Rectangle``.
    """
    from matplotlib.colors import to_rgba

    rect = plt.Rectangle(xy, w, h, facecolor=to_rgba(c, 0.6), edgecolor=to_rgba('black', 0.6), **kwargs)
    ax.add_patch(rect)
    ax.text(xy[0] + w/2, xy[1] + h/2, s, rotation=90, color='black', va='center', ha='center')


def _mean_std_labels(values, axis):
    """Return "(mean±std)" tick labels for ``values`` reduced along ``axis``."""
    # Raw f-string: "\p" is not a valid escape sequence in a normal string literal.
    return [rf"({m:.0f}$\pm${s:.0f})" for m, s in zip(values.mean(axis=axis), values.std(axis=axis))]


# Category bars left of the y tick labels (drawn outside the axes, hence
# clip_on=False). The row spans are hard-coded: 1 baseline row, then 15, 18, 6
# and 10 rows.
x_off = -17

add_label_patch(ax, (x_off, 0), 2, 1, "", 'k', clip_on=False, linewidth = 1, alpha=1)
add_label_patch(ax, (x_off, 1), 2, 15, "Adversarial Training", 'C0', clip_on=False, linewidth = 1)
add_label_patch(ax, (x_off, 16), 2, 18, "Augmentations", 'C1',clip_on=False, linewidth = 1)
add_label_patch(ax, (x_off, 34), 2, 6, "SSL", 'C2', clip_on=False, linewidth = 1)
add_label_patch(ax, (x_off, 40), 2, 10, "Improved Training", 'C3', clip_on=False, linewidth = 1)

# Bottom edge: mean and std of each column (layer) of `mat`.
bottom_topax = ax.secondary_xaxis('bottom')
bottom_topax.set_xticks(layer_ticks, _mean_std_labels(mat, axis=0), rotation=90, c="gray", fontsize=7)
bottom_topax.tick_params(axis='x', which='major', labelbottom=True, bottom=False, top=False, labeltop=False)

# Right edge: mean and std of each row (model) of `mat`.
right_ax = ax.secondary_yaxis('right')
right_ax.set_yticks(model_ticks, _mean_std_labels(mat, axis=1), c="gray", fontsize=7)
right_ax.tick_params(axis='y', which='major', labelright=True, right=False, left=False, labelleft=False)


plt.ylabel(None)
plt.xlabel(None)

plt.savefig("weight_reset_r50_training_delta.pdf", bbox_inches="tight")
