### Can an ensemble of three (or five) PreResNet-20 models be distilled into a single PreResNet-20?

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import math

from upcycle.plotting.credible_regions import get_arm, draw_arm_comparison
from pathlib import Path

# Global seaborn theme applied to every figure drawn below: white-grid style,
# font sizes scaled by 1.25.
sns.set(font_scale=1.25, style='whitegrid')

In [None]:
exp_dir = '../data/experiments/image_classification'
dataset = 'cifar10'
run_version = 'v0.0.7'

# Arms to compare, as (number of teacher models, synth_aug ratio). Each arm
# distills that many preresnet20 teachers into one preresnet20 student.
# This list fixes the key order of the arm dicts below (and presumably the
# legend/color order in draw_arm_comparison), so keep it stable.
arm_specs = [
    (1, 0.0),
    (3, 0.0),
    (1, 0.2),
    (3, 0.2),
    # (4, 0.0),  # 4-teacher runs, currently excluded
    # (4, 0.2),
    (5, 0.0),
    (5, 0.2),
]
# Legend text for each synth_aug ratio.
synth_names = {0.0: 'no synth', 0.2: '1:5 synth'}


def build_arms(label_tag=''):
    """Return {legend name: run directory relative to exp_dir} for every arm.

    Both the hard- and soft-label dicts come from the same ``arm_specs`` so they
    cannot drift apart.

    Args:
        label_tag: text inserted between the dataset name and the version tag
            in the experiment directory, e.g. '_hard_labels'.
    """
    return {
        f'{n_teachers} --> 1 ({synth_names[ratio]})': (
            f'synth_aug_{ratio}_{dataset}{label_tag}_{run_version}/'
            f'preresnet20_{n_teachers}-preresnet20_1'
        )
        for n_teachers, ratio in arm_specs
    }


soft_label_arms = build_arms()
hard_label_arms = build_arms('_hard_labels')

In [None]:
# Settings shared by every draw_arm_comparison call below. Later cells set
# y_key / ylim (and, for the error plots, ylabel / transform) on this dict in place.
plot_config = {
    'table_name': 'student_train_metrics',
    'x_key': 'epoch',
    'window': 4,
    'xlim': (100, 200),
}

In [None]:
# The train/test error cells set 'ylabel' and a percent->error 'transform' on
# the shared plot_config. Drop them here so that re-running this cell after
# those cells still plots the raw loss rather than 1 - loss / 100 under an
# "error" label. When run top-to-bottom the keys are absent, so this is a no-op.
for stale_key in ('ylabel', 'transform'):
    plot_config.pop(stale_key, None)
plot_config['y_key'] = 'train_loss'
plot_config['ylim'] = (0, 1)
fig = plt.figure(figsize=(12, 5))
fig.suptitle(dataset.upper())

ax_1 = fig.add_subplot(1, 2, 1)
ax_1.set_title('Hard Labels')
ax_1 = draw_arm_comparison(ax_1, exp_dir, hard_label_arms, **plot_config)

ax_2 = fig.add_subplot(1, 2, 2)
ax_2.set_title('Soft Labels')
ax_2 = draw_arm_comparison(ax_2, exp_dir, soft_label_arms, **plot_config)
ax_2.legend(loc='upper right')

plt.tight_layout()

In [None]:
plot_config['y_key'] = 'train_acc'
plot_config['ylabel'] = 'train error'
plot_config['ylim'] = (0, 0.06)
# Plots 1 - acc / 100, i.e. accuracy (presumably logged in percent) as an error fraction.
plot_config['transform'] = lambda x: 1 - x / 100

# Dashed reference lines: the last point of the `mean` series that get_arm
# returns for 'teacher_train_acc' (same window and transform as the curves),
# drawn across that series' x-range. Lines are drawn for the 1- and 3-teacher
# no-synth arms only; the 5 --> 1 arms get none.
# NOTE(review): blue/orange presumably match the first two arm curves; confirm
# against the palette draw_arm_comparison uses.
teacher_refs = (('1 --> 1 (no synth)', 'blue'), ('3 --> 1 (no synth)', 'orange'))
metrics_file = f"{plot_config['table_name']}.csv"

fig = plt.figure(figsize=(12, 5))
fig.suptitle(dataset.upper())

axes = []
panels = (('Hard Labels', hard_label_arms), ('Soft Labels', soft_label_arms))
for position, (title, arms) in enumerate(panels, start=1):
    ax = fig.add_subplot(1, 2, position)
    ax.set_title(title)
    ax = draw_arm_comparison(ax, exp_dir, arms, **plot_config)
    for arm_name, color in teacher_refs:
        x_range, mean, _, _ = get_arm(exp_dir, arms[arm_name], metrics_file, plot_config['x_key'],
                                      'teacher_train_acc', plot_config['window'],
                                      transform=plot_config['transform'])
        ax.hlines(mean[-1], x_range[0], x_range[-1], color=color, linestyle='--')
    axes.append(ax)
ax_1, ax_2 = axes

ax_2.legend(loc='upper right')

plt.tight_layout()
# plt.savefig(f'figures/image_classification/sngan_preresnet20_{dataset}_train_err.pdf')

In [None]:
plot_config['y_key'] = 'test_acc'
plot_config['ylabel'] = 'test error'
plot_config['ylim'] = (0.05, 0.15)
# Plots 1 - acc / 100, i.e. accuracy (presumably logged in percent) as an error fraction.
plot_config['transform'] = lambda x: 1 - x / 100

# Dashed reference lines: the last point of the `mean` series that get_arm
# returns for 'teacher_test_acc' (same window and transform as the curves),
# drawn across that series' x-range. Lines are drawn for the 1- and 3-teacher
# no-synth arms only; the 5 --> 1 arms get none.
# NOTE(review): blue/orange presumably match the first two arm curves; confirm
# against the palette draw_arm_comparison uses.
teacher_refs = (('1 --> 1 (no synth)', 'blue'), ('3 --> 1 (no synth)', 'orange'))
metrics_file = f"{plot_config['table_name']}.csv"

fig = plt.figure(figsize=(12, 5))
fig.suptitle(dataset.upper())

axes = []
panels = (('Hard Labels', hard_label_arms), ('Soft Labels', soft_label_arms))
for position, (title, arms) in enumerate(panels, start=1):
    ax = fig.add_subplot(1, 2, position)
    ax.set_title(title)
    ax = draw_arm_comparison(ax, exp_dir, arms, **plot_config)
    for arm_name, color in teacher_refs:
        x_range, mean, _, _ = get_arm(exp_dir, arms[arm_name], metrics_file, plot_config['x_key'],
                                      'teacher_test_acc', plot_config['window'],
                                      transform=plot_config['transform'])
        ax.hlines(mean[-1], x_range[0], x_range[-1], color=color, linestyle='--')
    axes.append(ax)
ax_1, ax_2 = axes

ax_2.legend(loc='upper right')

plt.tight_layout()