In [None]:
# Imports and IPython setup for the figure-generation notebook.
import numpy as np
import cmocean as cmo
import os
import yaml  # NOTE(review): not used in the visible cells.
import pandas as pd
import copy
import functools
from collections import OrderedDict
from matplotlib.lines import Line2D  # NOTE(review): not used in the visible cells.
import copy  # NOTE(review): duplicate of the `import copy` above; harmless.

from matplotlib import pyplot as plt
# Re-import edited local modules automatically before each cell executes.
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
import seaborn as sns
from matplotlib import rc

# Render all figure text through LaTeX with a sans-serif math font.
rc('text', usetex=True)
# `text.latex.preamble` must be a single string: a list of lines was deprecated
# in Matplotlib 3.3 and is rejected by later releases, so join the lines here.
rc('text.latex', preamble="\n".join([r'\usepackage{sansmath}', r'\sansmath']))  # r'\usepackage{DejaVuSans}'
rc('font', **{'family': 'sans-serif', 'sans-serif': ['DejaVu Sans']})

# Extra padding between ticks and their labels, and thicker grid lines.
rc('xtick.major', pad=12)
rc('ytick.major', pad=12)
rc('grid', linewidth=1.3)

palette = sns.color_palette("colorblind")

# Default output directory for figures saved via `save_dir`.
save_dir = './figures'
# save_dir = '../../neurips_2021/figures'

In [None]:
# Custom RGB colors (channels scaled to [0, 1]) shared by the figures below.
color_1 = [channel / 255 for channel in (33, 120, 68)]    # dark green
color_2 = [channel / 255 for channel in (0, 170, 212)]    # cyan-blue
color_3 = [channel / 255 for channel in (135, 222, 170)]  # pale green
color_4 = [channel / 255 for channel in (233, 175, 221)]  # pink

## Initialization ablation

### ResNet-20 LayerNorm

In [None]:
# LayerNorm ResNet-20 initialization ablation: three runs per initialization
# location (the plots below use lambda = 1 - location).
train_kls = dict(zip(
    (0., 0.125, 0.25, 0.5, 0.625, 0.75, 1.),
    (
        [0., 0., 0.],
        [1e-5, 1e-5, 1e-5],
        [5e-3, 1e-3, 6e-3],
        [4e-4, 5e-4, 1e-4],
        [1e-2, 1.7e-2, 1.4e-2],
        [0.46, 0.44, 0.4],
        [0.496, 0.496, 0.498],
    ),
))

train_ts_agree = dict(zip(
    (0., 0.125, 0.25, 0.5, 0.625, 0.75, 1.),
    (
        [100., 100., 100],
        [99.9, 99.9, 99.94],
        [99.53, 99.35, 99.82],
        [99.59, 99.59, 99.76],
        [97.8, 97.25, 97.6],
        [79.76, 80.32, 81.1],
        [78.9, 79.3, 78.67],
    ),
))

train_loss = dict(zip(
    (0., 0.125, 0.25, 0.5, 0.625, 0.75, 1.),
    (
        [51.16, 51.66, 51.34],
        [51.14, 51.67, 51.34],
        [51.17, 51.7, 51.37],
        [51.15, 51.69, 51.35],
        [51.19, 51.75, 51.42],
        [53.08, 53.6, 53.15],
        [53.3, 53.8, 53.5],
    ),
))

In [None]:
# LayerNorm init ablation: train loss (left axis) and train teacher-student
# agreement (right axis) versus lambda = 1 - init location. Bands are
# mean +/- std over the runs stored in each dict entry.
fig, ax1 = plt.subplots(figsize=(4, 3))

lambdas = np.array(list(train_loss.keys()))

# Left axis: train loss.
color = color_1
text_color = color_1

train_loss_arr = np.stack(list(train_loss.values()))
mean, std = train_loss_arr.mean(-1), train_loss_arr.std(-1)
lb, ub = mean - std, mean + std

ax1.set_xlabel(r'$\lambda$', fontsize=16)
ax1.set_ylabel('Train Loss', color=text_color, fontsize=18)
ax1.plot(1 - lambdas, mean, "-o",
         color=color, lw=3, ms=12, markeredgewidth=1., markeredgecolor="k")
ax1.fill_between(1 - lambdas, lb, ub, color=color, alpha=0.35)
ax1.tick_params(axis='y', labelcolor=text_color)

# Right axis: train agreement on a twin axes sharing the x-axis.
ax2 = ax1.twinx()

color = color_2
text_color = [0/255, 120/255, 252/255]

train_agree_arr = np.stack(list(train_ts_agree.values()))
mean, std = train_agree_arr.mean(-1), train_agree_arr.std(-1)
# Band recomputed from the agreement statistics before it is used. The old
# debug print ran before this line and showed the train-loss band instead.
lb, ub = mean - std, mean + std

ax2.set_ylabel('Train Agreement', color=text_color, fontsize=18)
ax2.plot(1 - lambdas, mean, "-o", color=color, lw=3, ms=12, markeredgewidth=1., markeredgecolor="k")
ax2.fill_between(1 - lambdas, lb, ub, color=color, alpha=0.35)

ax2.tick_params(axis='both', which='major', labelsize=14)
ax1.tick_params(axis='both', which='major', labelsize=14)
ax1.grid(True, axis="both")

# Invisible title; presumably keeps the panel height aligned with titled figures.
plt.title(r"\phantom{a}")
save_path = os.path.join(save_dir, "init_ablation_layernorm.pdf")
plt.savefig(save_path, bbox_inches="tight")

In [None]:
# Init-ablation results used for the BatchNorm figure (one row per loc_param).
bn_results = pd.read_csv(filepath_or_buffer="self_distillation_init_ablation.csv")

In [None]:
# Inspect the raw init-ablation results table.
bn_results

In [None]:
# Distinct initialization locations present in the results (in row order).
bn_results["loc_param"].unique()

In [None]:
# Sanity check: `.item()` raises unless exactly one row has loc_param == 0.
bn_results[bn_results["loc_param"] == 0.]["train_loss"].item()

In [None]:
# One train-loss / train-agreement value per initialization location.
# Locations are sorted so the line plots connect points in x-order regardless
# of row order in the CSV (`unique()` preserves appearance order).
# `.item()` raises if a location does not have exactly one row.
train_loss = {
    val: bn_results[bn_results["loc_param"] == val]["train_loss"].item()
    for val in sorted(bn_results["loc_param"].unique())
}
train_ts_agree = {
    val: bn_results[bn_results["loc_param"] == val]["train_ts_agree"].item()
    for val in sorted(bn_results["loc_param"].unique())
}

In [None]:
# BatchNorm init ablation: train loss (left) and train agreement (right) vs.
# lambda = 1 - init location. One value per location, so no error bands.
fig, ax1 = plt.subplots(figsize=(4, 3))
lambdas = np.array(list(train_loss.keys()))

# --- left axis: loss ---
color = color_1
text_color = color_1
ax1.set_xlabel(r'$\lambda$', fontsize=16)
ax1.set_ylabel('Train Loss', color=text_color, fontsize=18)
ax1.plot(1 - lambdas, list(map(np.mean, train_loss.values())), "-o",
         color=color, lw=3, ms=12, markeredgewidth=1., markeredgecolor="k")
ax1.tick_params(axis='y', labelcolor=text_color)

# --- right axis: agreement (twin axes sharing x) ---
ax2 = ax1.twinx()
color = color_2
text_color = [0/255, 120/255, 252/255]
ax2.set_ylabel('Train Agreement', color=text_color, fontsize=18)
ax2.plot(1 - lambdas, list(map(np.mean, train_ts_agree.values())), "-o",
         color=color, lw=3, ms=12, markeredgewidth=1., markeredgecolor="k")

ax2.tick_params(axis='both', which='major', labelsize=14)
ax1.tick_params(axis='both', which='major', labelsize=14)
ax1.grid(True, axis="both")

# Invisible title, applied to the current (twin) axes.
plt.title(r"\phantom{a}")
save_path = os.path.join(save_dir, "init_ablation_batchnorm.pdf")
plt.savefig(save_path, bbox_inches="tight")

## Loss Surface visualizations

In [None]:
from matplotlib import colors

# Discrete colormap from cmocean's "deep" map. The inputs are the integers
# 0, 10, ..., 90; Matplotlib treats integer inputs as lookup-table indices,
# not as positions in [0, 1].
xs = np.arange(0, 100, 10)
cmap_colors = list(map(cmo.cm.deep, xs))
my_cmap = colors.ListedColormap(cmap_colors)

In [None]:
# Inspect the RGBA entries that make up the custom colormap.
cmap_colors

In [None]:
npz_dir = '.'
# npz_dir = '../../../../gnosis/data/experiments/image_classification'

# Saved loss-plane archives, keyed by initialization location (lambda = 1 - key).
plane_npzs = dict(
    (loc_key, np.load(os.path.join(npz_dir, rel_path)))
    for loc_key, rel_path in (
        (1., "loss_visualization/init_loc_1.0/trial_0/2021-04-10_15-52-39/plane.npz"),
        (.75, "loss_visualization/init_loc_0.75/trial_0/2021-04-10_17-13-45/plane.npz"),
        (.625, "loss_visualization/init_loc_0.625/trial_0/2021-04-10_18-44-26/plane.npz"),
    )
)

In [None]:
def make_loss_surface_plot(plane_npz, levels=(48, 50, 56, 64, 68, 72, 74, 76, 80, 85), cmap=None):
    """Draw contour lines and filled contours of the loss over a 2-D plane.

    The two plane coordinates in ``grid`` are multiplied by the stored
    ``v_norm`` / ``u_norm``. Three reference points are marked at
    (0, 0), (v_norm, 0) and (0.5 * v_norm, u_norm), and the text labels
    "Teacher", "Student" and "Init" are placed near them.

    Args:
        plane_npz: mapping (e.g. a loaded ``.npz``) with keys ``"grid"``
            (array whose last axis holds the two plane coordinates),
            ``"loss"`` (2-D values on that grid), ``"v_norm"`` and ``"u_norm"``.
        levels: contour levels shared by the line and filled contours.
        cmap: colormap for the filled contours; defaults to the notebook's
            ``my_cmap``.

    Returns:
        The contour set from ``plt.contourf`` (usable for a colorbar).
    """
    if cmap is None:
        cmap = my_cmap
    levels = list(levels)

    # Copy before rescaling so the caller's array is never modified in place
    # (repeated calls on a dict input would otherwise compound the scaling),
    # and so integer grids are promoted to float.
    grid = np.array(plane_npz["grid"], dtype=float)
    loss = plane_npz["loss"]
    v_norm = plane_npz['v_norm']
    u_norm = plane_npz['u_norm']

    grid[:, :, 0] *= v_norm
    grid[:, :, 1] *= u_norm

    x_pos = np.array([0., 1., 0.5]) * v_norm
    y_pos = np.array([0., 0., 1.]) * u_norm
    teacher_text_pos = (-0.2 * v_norm, -0.2 * u_norm)
    student_text_pos = (0.7 * v_norm, -0.2 * u_norm)
    init_text_pos = (0.4 * v_norm, 0.8 * u_norm)

    plt.figure(figsize=(3, 3))
    plt.contour(grid[:, :, 0], grid[:, :, 1], loss, zorder=1, levels=levels, colors="k", alpha=0.15)
    mpb = plt.contourf(grid[:, :, 0], grid[:, :, 1], loss, zorder=0, alpha=0.9, levels=levels, cmap=cmap)
    plt.plot(x_pos, y_pos, "ro", ms=10, markeredgecolor="k")
    plt.text(*teacher_text_pos, "Teacher", fontsize=18)
    plt.text(*student_text_pos, "Student", fontsize=18)
    plt.text(*init_text_pos, "Init", fontsize=18)
    plt.xticks(fontsize=14)
    plt.yticks(fontsize=14)
    return mpb

In [None]:
# Loss surface for init location 1.0, i.e. lambda = 0.
loc = 1.
make_loss_surface_plot(plane_npzs[loc])
plt.title(rf"$\lambda={1 - loc}$", fontsize=16)
# White (invisible) x-label; presumably keeps panel sizes equal to labelled plots.
plt.xlabel(r"$\lambda$", fontsize=16, color="white")
save_path = os.path.join(save_dir, "init_ablation_lambda0.pdf")
plt.savefig(save_path, bbox_inches="tight")

In [None]:
# Loss surface for init location 0.625, i.e. lambda = 0.375 (hence the file name).
loc = 0.625
make_loss_surface_plot(plane_npzs[loc])
plt.title(rf"$\lambda={1 - loc}$", fontsize=16)
plt.xlabel(r"$\lambda$", fontsize=16, color="white")
save_path = os.path.join(save_dir, "init_ablation_lambda0375.pdf")
plt.savefig(save_path, bbox_inches="tight")

In [None]:
# Loss surface for init location 0.75, i.e. lambda = 0.25.
loc = 0.75
make_loss_surface_plot(plane_npzs[loc])
plt.title(rf"$\lambda={1 - loc}$", fontsize=16)
plt.xlabel(r"$\lambda$", fontsize=16, color="white")
save_path = os.path.join(save_dir, "init_ablation_lambda025.pdf")
plt.savefig(save_path, bbox_inches="tight")

In [None]:
# Re-draw the surface for the current `loc` (set by the cell above) only to
# keep the returned contour set `mpb` for a colorbar.
# NOTE(review): depends on execution order through the global `loc`.
mpb = make_loss_surface_plot(plane_npzs[loc])

In [None]:
# fig,ax = plt.subplots(figsize=(1, 3))
# plt.colorbar(mpb, ax=ax, aspect=15, label='Train Loss')
# ax.remove()
# save_path = os.path.join(save_dir, "init_ablation_cbar.pdf")
# plt.savefig(save_path, bbox_inches="tight")

In [None]:
# List the arrays stored in the loss-plane archive.
list(plane_npzs[loc].keys())

In [None]:
# Norm used by make_loss_surface_plot to rescale the first plane coordinate.
plane_npzs[1.]['v_norm']

## How well are we solving the optimization problem?

In [None]:
# NOTE(review): this v5 table is overwritten by the v4 load in the next cell
# before it is used, so this read has no effect on the results below.
big_df = pd.read_csv("results_combined_v5.csv")
# big_df["gan_ratio"].unique()

In [None]:
# Combined CIFAR-100 distillation results used by the tables/figures below.
big_df = pd.read_csv(filepath_or_buffer="results_combined_v4.csv")


# Row filters for the CIFAR-100 experiments: display name -> {column: value}.
# Each entry is a fresh dict derived from the shared depth-56 baseline settings.
_base = {
    "teacher_depth": 56,
    "classifier_depth": 56,
    "mixup_used": False,
    "temperature": 1.,
    "augmentations": "crop,horizontal_flip",
    "dataset": "cifar100",
    "tensor_augmentations": "[]",
}
_all_augs = "rotation,vertical_flip,colorjitter,crop,horizontal_flip"

masks = OrderedDict([
    ("Baseline", {**_base, "gan_ratio": "Not Specified"}),
    (r"Baseline($T=4$)", {**_base, "temperature": 4.}),
    ("Rotation", {**_base, "augmentations": "rotation,crop,horizontal_flip"}),
    ("Vertical Flip", {**_base, "augmentations": "vertical_flip,crop,horizontal_flip"}),
    ("ColorJitter", {**_base, "augmentations": "colorjitter,crop,horizontal_flip"}),
    ("Combined Augs", {**_base, "augmentations": _all_augs, "num_epochs": 300}),
    ("CombAug2k", {**_base, "augmentations": _all_augs, "num_epochs": 2000}),
])

# Disabled experiments (uncomment to include):
# masks["MixUp"] = {**_base, "mixup_used": True}
# masks[r"MixUp($T=4$)"] = {**_base, "mixup_used": True, "temperature": 4.}
# masks[r"GAN($T=4$)"] = {"teacher_depth": 56, "classifier_depth": 56, "gan_ratio": '0.2'}
# masks["GAN(0.4)"] = {"teacher_depth": 56, "classifier_depth": 56, "gan_ratio": '0.4'}
# masks["GAN(0.8)"] = {"teacher_depth": 56, "classifier_depth": 56, "gan_ratio": '0.8'}
# masks["OOD"] = {"teacher_depth": 56, "classifier_depth": 56,
#                 "dataset": "cifar100_svhn", "tensor_augmentations": "[]"}
# masks["Noise"] = {"teacher_depth": 56, "classifier_depth": 56,
#                   "tensor_augmentations": "replace_with_uniform"}

del _base, _all_augs

In [None]:
# Metrics to tabulate: display name -> big_df column.
metrics = OrderedDict()
metrics["Accuracy"] = "student_test_accuracy"
# metrics["NLL"] = "student_test_nll"
metrics["Agreement"] = "test_matching"
metrics[r"$KL(T \vert\vert S)$"] = "test_kl_ts"
metrics["Train Agreement"] = "train_matching"
metrics["Aug Train Agreement"] = "aug_train_matching"
metrics["Teacher Size"] = "teacher_num_components"
# metrics["TAT Accuracy"] = "teacher_aug_train_accuracy"

# Every row gets the same columns; previously only the first row had
# teacher_test_nll, leaving it NaN everywhere else.
summary_columns = ["name", "teacher_test_accuracy", "teacher_test_nll", *metrics.values()]

# For each experiment and teacher-ensemble size: one mean row and one std row.
summary_rows = []
for exp_name, mask_dict in masks.items():
    for num_components in 1, 3, 5:
        # Build a new filter instead of calling .update(), so `masks` is not mutated.
        run_filter = {**mask_dict, "teacher_num_components": num_components}
        mask = functools.reduce(
            lambda a, b: a & b, [big_df[key] == value for key, value in run_filter.items()],
            True)
        masked_results = big_df[mask]
        print("Number of trials:", len(masked_results))
        # numeric_only=True: string columns cannot be averaged (pandas >= 2 raises otherwise).
        mean = masked_results.mean(numeric_only=True).to_frame().T
        mean["name"] = exp_name
        std = masked_results.std(numeric_only=True).to_frame().T
        std["name"] = exp_name + " std"
        std["teacher_num_components"] = mean["teacher_num_components"]
        summary_rows += [mean[summary_columns], std[summary_columns]]

# DataFrame.append was removed in pandas 2.0; concatenate all rows once instead.
df = pd.concat(summary_rows).reset_index()

In [None]:
# Inspect the summary rows for the 2000-epoch combined-augmentation runs.
df[(df["name"] == "CombAug2k")]

In [None]:


# Train agreement on augmented training data for each augmentation policy,
# with one curve per teacher-ensemble size.
fig = plt.figure(figsize=(6.5, 2))
ax = fig.add_subplot(111)
augs = ["Baseline", "Rotation", "Vertical Flip", "ColorJitter", "Combined Augs"]
xs = np.arange(len(augs))

# (ensemble size, line color, legend label). The original called
# `np.array(sigmas)` here on an undefined name, which failed on a fresh kernel.
for n_teachers, line_color, label in [(1, color_2, "1 Teacher"),
                                      (3, color_4, "3 Teachers"),
                                      (5, color_3, "5 Teachers")]:
    ys = np.array([
        df[(df["name"] == aug) & (df["teacher_num_components"] == n_teachers)]["aug_train_matching"].tolist()[0]
        for aug in augs
    ])
    # Scaled by 100 for display (presumably fraction -> percent).
    ax.plot(xs, ys * 100, "-o", ms=12, mec="k", lw=3, color=line_color, label=label)

ax.set_xticks(xs)
ax.set_xticklabels(augs, fontsize=14, rotation=0)
ax.tick_params(axis='both', labelsize=14)

# Small margin around the first and last category.
ax.set_xlim(-0.15, len(augs) - 1 + 0.15)

ax.set_yticks([25, 50, 75, 100])
ax.set_ylim(15, 110)

ax.grid(True)
plt.ylabel("Train Agreement", fontsize=18)
plt.title("Data Augmentation", fontsize=16)

plt.savefig("../../neurips_2021/figures/train_agreement_augs_combined.pdf",
            bbox_inches="tight")

## GAN Augmentation

In [None]:
# Summary of the GAN-augmentation distillation runs (filtered by num_synth / n_teach below).
df = pd.read_csv(filepath_or_buffer="distillation_synthetic_augmentation_cifar100.csv")
df

In [None]:
# Train agreement vs. amount of GAN-generated data, one curve per ensemble size.
fig = plt.figure(figsize=(6.5, 2))
ax = fig.add_subplot(111)

augs = df["num_synth"].unique()
xs = np.arange(len(augs))

# (ensemble size, line color, legend label). The original called
# `np.array(sigmas)` here on an undefined name, which failed on a fresh kernel.
for n_teachers, line_color, label in [(1, color_2, "1 Teacher"),
                                      (3, color_4, "3 Teachers"),
                                      (5, color_3, "5 Teachers")]:
    ys = np.array([
        df[(df["num_synth"] == aug) & (df["n_teach"] == n_teachers)]["train_ts_agree_mean"].tolist()[0]
        for aug in augs
    ])
    ax.plot(xs, ys, "-o", ms=12, mec="k", lw=3, color=line_color, label=label)

ax.set_xticks(xs)
# NOTE(review): hard-coded labels assume five num_synth values in this order.
ax.set_xticklabels(["CIFAR-100", "+12.5k GAN", "+25k GAN", "+37.5k GAN", "+50k GAN"], fontsize=14, rotation=0)
ax.tick_params(axis='both', labelsize=14)

ax.set_yticks([90, 95, 100])
ax.set_ylim(87, 103)

ax.grid(True)
plt.ylabel("Train Agreement", fontsize=18)
plt.title("GAN-Generated Data", fontsize=16)

# (handles, labels) pair reused by the standalone legend cell.
handles = ax.get_legend_handles_labels()

plt.savefig("../../neurips_2021/figures/train_agreement_gan.pdf",
            bbox_inches="tight")

In [None]:
# Standalone legend figure built from the (handles, labels) pair captured
# from the GAN train-agreement plot.
figlegend = plt.figure(figsize=(2, 2))
plt.legend(handles[0], handles[1], loc='upper left', fontsize=22, ncol=3)
plt.axis("off")
plt.savefig("../../neurips_2021/figures/train_agreement_legend.pdf", bbox_inches="tight")

## Intro figure

In [None]:
# Reload the synthetic-augmentation summary for the intro figures.
df = pd.read_csv(filepath_or_buffer="distillation_synthetic_augmentation_cifar100.csv")

In [None]:
# Column names available in the synthetic-augmentation summary.
df.keys()

In [None]:
# Intro figure, single teacher: student-teacher test agreement (left axis) and
# student test accuracy (right axis) vs. amount of GAN data. The dashed line is
# the mean teacher test accuracy.
n_teachers = 1
augs = df["num_synth"].unique()
xs = np.arange(len(augs))

# Rows for each synthetic-data size at this ensemble size; the first match is used.
_rows = [df[(df["num_synth"] == aug) & (df["n_teach"] == n_teachers)] for aug in augs]
ys1, ys1_ub, ys1_lb, ys2, ys2_ub, ys2_lb = (
    [row[col].tolist()[0] for row in _rows]
    for col in ("test_ts_agree_mean", "test_ts_agree_ub", "test_ts_agree_lb",
                "test_acc_mean", "test_acc_ub", "test_acc_lb")
)
del _rows

teacher_acc = df[(df["n_teach"] == n_teachers)]["teacher_test_acc_mean"].mean()

fig, ax1 = plt.subplots(figsize=(6.5, 3))

# --- left axis: agreement ---
color = color_2
text_color = [0/255, 120/255, 252/255]
ax1.set_ylabel('Test Agreement', color=text_color, fontsize=18)
ax1.plot(xs, ys1, "-o", color=color, lw=3, ms=12, markeredgewidth=1., markeredgecolor="k")
ax1.fill_between(xs, ys1_lb, ys1_ub, color=color, alpha=0.5, ec="k")
ax1.tick_params(axis='y', labelcolor=text_color)
ax1.set_xticks(xs)
ax1.set_xticklabels(["CIFAR-100", "+12.5k GAN", "+25k GAN", "+37.5k GAN", "+50k GAN"], fontsize=14, rotation=0)
ax1.set_yticks(list(range(70, 76)))
ax1.set_ylim([70, 76])

# --- right axis: accuracy, same y-range as the left axis ---
color = color_3
text_color = color_1
ax2 = ax1.twinx()
ax2.plot(xs, ys2, "-o", color=color, lw=3, ms=12, markeredgewidth=1., markeredgecolor="k")
ax2.fill_between(xs, ys2_lb, ys2_ub, color=color, alpha=0.35, ec="k")
ax2.set_yticks(list(range(70, 76)))
ax2.set_ylim([70, 76])
ax2.tick_params(axis='y', labelcolor=text_color)
ax2.set_ylabel('Test Accuracy', color=text_color, fontsize=18)
ax2.plot([xs[0], xs[-1]], [teacher_acc, teacher_acc], "--", color=color, lw=3)

ax2.tick_params(axis='both', which='major', labelsize=14)
ax1.tick_params(axis='both', which='major', labelsize=14)
ax1.grid(True, axis="both")

plt.savefig("../../neurips_2021/figures/motivation_self_distillation_cifar100.pdf", bbox_inches="tight")

In [None]:
# Intro figure, 3-teacher ensemble: student-teacher test agreement (left axis)
# and student test accuracy (right axis) vs. amount of GAN data. Lines carry
# labels for the standalone legend cell.
n_teachers = 3
augs = df["num_synth"].unique()
xs = np.arange(len(augs))

# Rows for each synthetic-data size at this ensemble size; the first match is used.
_rows = [df[(df["num_synth"] == aug) & (df["n_teach"] == n_teachers)] for aug in augs]
ys1, ys1_ub, ys1_lb, ys2, ys2_ub, ys2_lb = (
    [row[col].tolist()[0] for row in _rows]
    for col in ("test_ts_agree_mean", "test_ts_agree_ub", "test_ts_agree_lb",
                "test_acc_mean", "test_acc_ub", "test_acc_lb")
)
del _rows

teacher_acc = df[(df["n_teach"] == n_teachers)]["teacher_test_acc_mean"].mean()

fig, ax1 = plt.subplots(figsize=(6.5, 3))

# --- left axis: agreement ---
color = color_2
text_color = [0/255, 120/255, 252/255]
ax1.set_ylabel('Test Agreement', color=text_color, fontsize=18)
ax1.plot(xs, ys1, "-o", color=color, lw=3, ms=12, markeredgewidth=1., markeredgecolor="k",
         label="Teacher-Student Agreement")
ax1.fill_between(xs, ys1_lb, ys1_ub, color=color, alpha=0.35, ec="k")
ax1.tick_params(axis='y', labelcolor=text_color)
ax1.set_xticks(xs)
ax1.set_xticklabels(["CIFAR-100", "+12.5k GAN", "+25k GAN", "+37.5k GAN", "+50k GAN"], fontsize=14, rotation=0)
ax1.set_yticks(list(range(70, 81, 2)))
ax1.set_ylim([70, 81])

# --- right axis: accuracy, same y-range as the left axis ---
color = color_3
text_color = color_1
ax2 = ax1.twinx()
ax2.plot(xs, ys2, "-o", color=color, lw=3, ms=12, markeredgewidth=1., markeredgecolor="k",
         label="Student Accuracy")
ax2.fill_between(xs, ys2_lb, ys2_ub, color=color, alpha=0.5, ec="k")
ax2.plot([xs[0], xs[-1]], [teacher_acc, teacher_acc], "--", color=color, lw=3,
         label="Teacher Accuracy")
ax2.set_yticks(list(range(70, 81, 2)))
ax2.set_ylim([70, 81])
ax2.tick_params(axis='y', labelcolor=text_color)
ax2.set_ylabel('Test Accuracy', color=text_color, fontsize=18)

ax2.tick_params(axis='both', which='major', labelsize=14)
ax1.tick_params(axis='both', which='major', labelsize=14)
ax1.grid(True, axis="both")

plt.savefig("../../neurips_2021/figures/motivation_3_ensemble_distillation_cifar100.pdf", bbox_inches="tight")

In [None]:
# Standalone legend combining the agreement (left) and accuracy (right) entries
# of the 3-teacher intro figure.
figlegend = plt.figure(figsize=(0, 0))
handles, labels = ax1.get_legend_handles_labels()
handles2, labels2 = ax2.get_legend_handles_labels()
handles = handles + handles2
labels = labels + labels2
plt.axis("off")
legend = plt.legend(handles, labels, bbox_to_anchor=(0.6, -.17), loc='lower center', ncol=3, fontsize=14)
plt.savefig("../../neurips_2021/figures/motivation_distillation_cifar100_legend.pdf", bbox_inches="tight")

### MNIST

In [None]:
def mask_from_dict(mask_dict, frame=None):
    """Return a boolean row mask: True where every listed column equals its value.

    Args:
        mask_dict: mapping of column name -> required value.
        frame: DataFrame to filter. Defaults to the notebook-global ``big_df``,
            looked up at call time so it follows later reloads.

    Returns:
        Boolean Series aligned with ``frame.index``. An empty ``mask_dict``
        selects every row.
    """
    if frame is None:
        frame = big_df
    # Start from an all-True Series rather than the bare ``True``. Otherwise an
    # empty mask_dict returns ``True``, and ``frame[True]`` raises KeyError.
    initial = pd.Series(True, index=frame.index)
    return functools.reduce(
        lambda a, b: a & b,
        [frame[key] == value for key, value in mask_dict.items()],
        initial)

In [None]:
# MNIST / EMNIST distillation results.
big_df = pd.read_csv(filepath_or_buffer="results_combined_mnist.csv")

# Configurations compared in the MNIST/EMNIST figures, all with one teacher:
# MNIST, then EMNIST with subsample_ratio 0.25, 0.5 and 1.0.
masks = [{"dataset": "mnist", "teacher_num_components": 1}] + [
    {"dataset": "emnist", "teacher_num_components": 1, "subsample_ratio": ratio}
    for ratio in (0.25, 0.5, 1.)
]

In [None]:
def get_mean_std(masks, metric):
    """Return (means, stds) of `metric` over the big_df rows matched by each mask.

    Both results are numpy arrays with one entry per mask. The std is the
    pandas default (sample std, ddof=1).
    """
    # Select each subset once, then reduce it twice.
    subsets = [big_df[mask_from_dict(mask)][metric] for mask in masks]
    means = np.array([subset.mean() for subset in subsets])
    stds = np.array([subset.std() for subset in subsets])
    return means, stds


# Per-configuration summary statistics used by the MNIST/EMNIST figures.
acc_mean, acc_std = get_mean_std(masks, "student_test_accuracy")
agg_mean, agg_std = get_mean_std(masks, "test_matching")
kl_mean, kl_std = get_mean_std(masks, "test_kl_ts")
b_agg_mean, b_agg_std = get_mean_std(masks, "baseline_test_matching")

In [None]:
# Mean baseline test agreement per configuration (for reference; not plotted in the cells below).
b_agg_mean

In [None]:
# Student-teacher test agreement (mean +/- std) across MNIST/EMNIST configurations.
fig = plt.figure(figsize=(4, 3))
ax = fig.add_subplot(111)
xs = np.arange(len(agg_mean))

ax.plot(xs, agg_mean, "-o", ms=12, mec="k", lw=3, color=color_2, label="Student")
ax.fill_between(xs, agg_mean + agg_std, agg_mean - agg_std,
                color=color_2, alpha=0.35, ec="k")

ax.set_xticks([0, 1, 2, 3])
ax.set_xticklabels(["MNIST 60k", "EMNIST 175k", "EMNIST 350k", "EMNIST 700k"], fontsize=14, rotation=45)
ax.tick_params(axis='both', labelsize=14)
ax.grid(True)
plt.ylabel("Test Agreement", fontsize=18)
plt.savefig("../../neurips_2021/figures/mnist_agreement.pdf", bbox_inches="tight")

In [None]:
# Student test accuracy (percent, mean +/- std) across MNIST/EMNIST
# configurations, with a dashed reference line at the hard-coded teacher accuracy.
fig = plt.figure(figsize=(4, 3))
ax = fig.add_subplot(111)
xs = np.arange(len(agg_mean))

ax.plot(xs, acc_mean * 100, "-o", ms=12, mec="k", lw=3, color=color_3, label="Student")
ax.fill_between(xs, (acc_mean + acc_std) * 100, (acc_mean - acc_std) * 100,
                color=color_3, alpha=0.5, ec="k")
ax.hlines(85.34, -1, 5, "k", linestyle="dashed", lw=2)

ax.set_xticks([0, 1, 2, 3])
ax.set_yticks(list(range(83, 88)))
ax.set_xticklabels(["MNIST 60k", "EMNIST 175k", "EMNIST 350k", "EMNIST 700k"], fontsize=14, rotation=45)
ax.tick_params(axis='both', labelsize=14)
ax.grid(True)

plt.ylabel("Test Accuracy", fontsize=18)
ax.set_xlim(-0.15, 3.15)
plt.savefig("../../neurips_2021/figures/mnist_acc.pdf", bbox_inches="tight")

In [None]:
# MNIST/EMNIST summary: student-teacher test agreement (left axis) and student
# test accuracy (right axis), both in percent with +/- 1 std bands. The dashed
# line marks the hard-coded teacher accuracy (85.34).
xs = np.arange(len(agg_mean))

fig, ax1 = plt.subplots(figsize=(4, 3))

# --- left axis: agreement ---
color = color_2
text_color = [0/255, 120/255, 252/255]
ax1.set_ylabel('Test Agreement', color=text_color, fontsize=18)
ax1.plot(xs, agg_mean * 100, "-o",
         color=color, lw=3, ms=12, markeredgewidth=1., markeredgecolor="k", label="Agreement")
# Both band edges use the agreement std. The upper edge previously used acc_std.
ax1.fill_between(xs, (agg_mean - agg_std) * 100, (agg_mean + agg_std) * 100,
                 color=color, alpha=0.5, ec="k")
ax1.tick_params(axis='y', labelcolor=text_color)
ax1.set_xticks(xs)
ax1.set_ylim([80, 101])
ax1.set_xticklabels(["MNIST 60k", r"EMNIST 175k", r"EMNIST 350k", r"EMNIST 700k"], fontsize=14, rotation=15)

# --- right axis: accuracy, same y-range as the left axis ---
color = color_3
text_color = color_1

ax2 = ax1.twinx()
ax2.plot(xs, acc_mean * 100, "-o",
         color=color, lw=3, ms=12, markeredgewidth=1., markeredgecolor="k", label="Student Accuracy")
ax2.fill_between(xs, (acc_mean - acc_std) * 100, (acc_mean + acc_std) * 100, color=color, alpha=0.35, ec="k")
ax2.set_ylim([80, 101])
ax2.tick_params(axis='y', labelcolor=text_color)
ax2.set_ylabel('Test Accuracy', color=text_color, fontsize=18)

# Hard-coded teacher accuracy; same value as the accuracy-only MNIST plot.
ax2.plot([xs[0], xs[-1]], [85.34, 85.34], "--", color=text_color, lw=3, label="Teacher Accuracy")

ax2.tick_params(axis='both', which='major', labelsize=14)
ax1.tick_params(axis='both', which='major', labelsize=14)
ax1.grid(True, axis="both")

plt.savefig("../../neurips_2021/figures/mnist_combined.pdf", bbox_inches="tight")

In [None]:
# Standalone legend for the MNIST combined figure: accuracy-axis entries first,
# then the agreement entry.
figlegend = plt.figure(figsize=(0, 0))
handles, labels = ax2.get_legend_handles_labels()
handles2, labels2 = ax1.get_legend_handles_labels()
handles = handles + handles2
labels = labels + labels2
plt.axis("off")
legend = plt.legend(handles, labels, bbox_to_anchor=(0.6, -.17), loc='lower center', ncol=2, fontsize=14)
plt.savefig("../../neurips_2021/figures/mnist_combined_legend.pdf", bbox_inches="tight")

## Optimizer ablation

In [None]:
# Scratch check: mean train agreement of the SGD 1k-epoch runs listed below.
np.mean([81.11, 81.41, 81.29])

In [None]:
# Train agreement of three runs for each (optimizer, number of epochs) setting,
# flattened to (optimizer, epochs label, runs) tuples in plotting order.
agreements = [
    (optimizer, epochs, runs)
    for optimizer, settings in (
        ("SGD", (("300", [78.9, 79.3, 78.67]),
                 ("1k", [81.11, 81.41, 81.29]),
                 ("5k", [83.43, 83.42, 83.26]))),
        ("Adam", (("1k", [80.57, 80.64, 80.19]),
                  ("5k", [82.65, 83.19, 82.78]))),
    )
    for epochs, runs in settings
]

In [None]:
# Bar chart of train agreement (mean +/- population std over runs) per
# optimizer and epoch budget, with brackets labelling the SGD and Adam groups.
fig, ax = plt.subplots(figsize=(3, 3))

ax.set_xlabel('Number of Epochs', fontsize=18)
ax.set_ylabel('Train Agreement', fontsize=18)

# Bar positions: three SGD bars, a gap, then two Adam bars.
xs = [0, 1, 2, 3.5, 4.5]
for x, (name, epochs, res) in zip(xs, agreements):
    color = color_2 if name == "SGD" else color_4
    ax.bar(x, np.mean(res), yerr=np.std(res), capsize=4, width=1., ec="k", color=color, alpha=0.8)

ax.grid(axis="y")
ax.set_ylim(78, 85.99)
ax.set_xticks(xs)
ax.set_xticklabels([str(e) for _, e, _ in agreements], rotation=0, fontsize=14)
ax.tick_params(axis='both', which='major', labelsize=14)

# Group brackets: horizontal bar at 84.5 with short down-ticks at both ends.
for left, right, label_x, label in ((3., 5, 3.4, "Adam"), (-0.5, 2.5, 0.4, "SGD")):
    ax.plot([left, right], [84.5, 84.5], "-k", lw=3)
    ax.plot([left, left], [84., 84.5], "-k", lw=3)
    ax.plot([right, right], [84., 84.5], "-k", lw=3)
    ax.text(label_x, 85.1, label, color="k", fontsize=16, rotation=0)
del left, right, label_x, label

ax.set_title(r"\phantom{a}")

save_path = os.path.join(save_dir, "optimizer_ablation.pdf")
plt.savefig(save_path, bbox_inches="tight")