In [None]:
import wandb
import matplotlib as mpl
from matplotlib import pyplot as plt
import seaborn as sns
import pandas as pd
from tqdm import tqdm
import numpy as np

# Apply seaborn's default theme to all matplotlib figures created below.
sns.set_theme()

# wandb public-API client; the cells below use it to fetch runs and sweeps.
api = wandb.Api()


### Init Distill

In [None]:
# Probe accuracy curves of the distillation run.
probe_keys = ['probe/student', 'probe/teacher']
distill_hist = api.run('safelix/DINO/runs/t951d1ss').history(keys=probe_keys)
s_acc, t_acc = (distill_hist[key] for key in probe_keys)

# Reference level: mean over every student and teacher probe value of the baseline run.
base_hist = api.run('safelix/DINO/runs/1jexlm4j').history(keys=probe_keys)
base = pd.concat([base_hist[key] for key in probe_keys]).mean()

# Teacher first, then student, so the colour cycle assigns them in that order.
fig, ax = plt.subplots(figsize=(7, 3.5))
for curve, curve_label in ((t_acc, 'teacher'), (s_acc, 'student')):
    ax.plot(curve.index, curve, label=curve_label)
ax.hlines(base, 0, 49, colors='gray', linestyles='--', label='linear')
ax.legend()

# Axis ranges and labels; limits are set before the minor ticks are placed.
ax.set(xlim=(0, 50), ylim=(-0.0, 1), xlabel='epoch', ylabel='probing accuracy')

# Thin white minor grid lines halfway between the major ticks.
ax.set_xticks([5 + 10*i for i in range(5)], minor=True)
ax.set_yticks([0.1 + 0.2*i for i in range(5)], minor=True)
ax.grid(visible=True, which='minor', color='w', linewidth=0.25)

fig.savefig('radiant-wildflower_probe_v2.pdf', bbox_inches='tight')



In [None]:
# workaround to get runs associated with sweep https://github.com/wandb/wandb/issues/3347
# Rather than using sweep.runs as the run list, the config of the sweep's first
# run is turned into a query filter and matching runs are fetched from the project.
sweep = api.sweep('safelix/DINO/sweeps/am04wjt4')
# Config keys removed from that first run's config before it is used as a filter.
non_defining_config_keys = ['logdir', 'from_json', 'seed', 'out_dim']
# Sort key handed to api.runs via its `order` argument.
order_by = 'config.out_dim.value'

def filter_from_config(config, exclude_keys):
    """Turn a run config into a list of single-key equality filters.

    Args:
        config: Mapping of config key to value. It is left unmodified.
        exclude_keys: Keys to leave out of the filters. Keys that do not
            occur in `config` are ignored.

    Returns:
        A list holding one ``{'config.<key>': value}`` dict per remaining
        key, in the iteration order of `config`.
    """
    excluded = set(exclude_keys)
    # Filter while building the result instead of deleting from `config`:
    # the argument is typically a live run config that must stay intact.
    return [{f'config.{key}': val} for key, val in config.items() if key not in excluded]

# Filters are derived from the first sweep run's config minus the non-defining keys.
filters = filter_from_config(sweep.runs[0].config, non_defining_config_keys)
# Query the whole project for runs matching every filter, sorted by `order_by`.
runs = api.runs('safelix/DINO', filters={"$and": filters}, order=order_by)

print('Found', len(runs), 'runs:')
# NOTE(review): `.objects` presumably holds only what the paginator has loaded so
# far, and the len() call above likely triggers that load — confirm before reordering.
print(runs.objects)
# Keep at most the first 12 entries; `runs` is rebound to that slice from here on.
runs = runs.objects[:12]

In [None]:
from matplotlib.cm import ScalarMappable, viridis, viridis_r
from matplotlib.colors import BoundaryNorm, NoNorm

fig, (ax_curves, ax_best) = plt.subplots(nrows=1, ncols=2, sharey=True, figsize=(10, 3.5))

# One colour bin per `out_dim` value, drawn from the reversed viridis map.
dims = [run.config['out_dim'] for run in runs]
cmap = ScalarMappable(norm=BoundaryNorm(dims, viridis_r.N, extend='max'), cmap=viridis_r)

# Baseline level: mean over every student and teacher probe value of the reference run.
probe_keys = ['probe/student', 'probe/teacher']
base_hist = api.run('safelix/DINO/runs/1jexlm4j').history(keys=probe_keys)
base = pd.concat([base_hist[key] for key in probe_keys]).mean()
baseline_style = dict(colors='gray', linestyles='--', label='linear')
ax_curves.hlines(base, 0, 49, **baseline_style)
ax_best.hlines(base, 2**4.5, 2**16.5, **baseline_style)

# Left panel: student probe curve per run.
# Right panel: best student accuracy of each run against its out_dim.
max_acc = []
for run, out_dim in zip(runs, dims):
    s_acc = run.history(keys=probe_keys)['probe/student']
    best, colour = s_acc.max(), cmap.to_rgba(out_dim)
    max_acc.append(best)

    ax_curves.plot(s_acc.index, s_acc, label=out_dim, color=colour)
    ax_best.scatter(out_dim, best, label=out_dim, color=colour)
ax_best.plot(dims, max_acc, color='gray', linestyle='-', linewidth=1)

# Format left plot
ax_curves.set(xlim=(0, 50), ylim=(0.3, 0.5), xlabel='epoch', ylabel='probing accuracy')

# Format right plot; its legend reuses the left panel's entries in reverse order.
handles, labels = ax_curves.get_legend_handles_labels()
ax_best.legend(handles[::-1], labels[::-1], bbox_to_anchor=(1, 0.5), loc='center left',
               frameon=False, borderpad=0)
ax_best.set_xscale('log', base=2)
ax_best.set_xlim(2**4.5, 2**16.5)
ax_best.set_ylim(0.3, 0.5)
ax_best.set_xlabel('size')
ax_best.set_ylabel('best accuracy')

fig.subplots_adjust(wspace=0.1, hspace=0)

fig.savefig('init-distill-outdim.pdf', bbox_inches='tight')


In [None]:

def plot(enc, cfgs):
    """Plot mean probing accuracy per L2-bottleneck config for one encoder.

    Fetches the runs of sweep 'bqs03n6g', keeps those whose config 'enc'
    equals `enc` and groups them by their 'l2bot_cfg' value. For every entry
    of `cfgs` that has runs with a non-empty history, the mean
    'probe/student' curve over those runs is drawn with a band of +/- one
    standard deviation. A table-style legend is built in a second axes and the
    figure is saved as 'init-distill-l2bot-{enc}.pdf'.

    Args:
        enc: Encoder name compared against run.config['enc'].
        cfgs: Config strings made of six '/'-separated fields, in plotting
            order. A field equal to '-' gets no check mark in the legend.

    Returns:
        None.
    """

    sweep = api.sweep('safelix/DINO/sweeps/bqs03n6g') # L2Bottleneck

    # gather groups: l2bot_cfg -> runs of the requested encoder
    groups = {}
    for run in sweep.runs:
        if run.config['enc'] != enc:
            continue
        groups.setdefault(run.config['l2bot_cfg'], []).append(run)


    # plot groups
    f, (ax, legax) = plt.subplots(nrows=1, ncols=2, width_ratios=[0.7, 0.3], figsize=(10, 3.5))
    for cfg in cfgs:
        if cfg not in groups:
            continue

        hists = []
        for run in groups[cfg]:
            hist:pd.DataFrame = run.history(keys=['probe/student'])
            if not hist.empty:
                s_acc = hist['probe/student']
                s_acc.name = run.config['enc_seed']  # becomes the column name after concat
                hists.append(s_acc)

        print(f'{cfg}: {len(hists)} runs')
        print(f'=> seeds : {[h.name for h in hists]}')
        if len(hists) == 0:
            continue
        hists = pd.concat(hists, axis='columns')

        # plot mean and std
        hists_mean = hists.mean(axis=1, skipna=True)
        hists_std = hists.std(axis=1, skipna=True).fillna(0.0)  # std is NaN where only one run has a value
        ax.plot(hists.index, hists_mean, label=cfg)
        ax.fill_between(hists.index, hists_mean - hists_std, hists_mean + hists_std, alpha=0.2)

    # make the legend table
    # An invisible line provides the handle used for every cell without a marker.
    ax.plot([0], [0], visible=False, label='empty')
    handles, labels = ax.get_legend_handles_labels()
    empty = handles[-1]
    titles = ['wn1', 'lin1', 'fn1', 'wn2', 'lin2', 'fn2']

    # One header row plus one row per curve that was actually drawn. `labels`
    # ends with the 'empty' sentinel, so its length is exactly that row count.
    # Sizing by len(cfgs) instead would drop the last curve when every cfg is
    # plotted and would index into the sentinel when cfgs were skipped.
    nrow, ncol = len(labels), len(titles)
    cells = nrow*ncol*[None]
    marks = nrow*ncol*[None]
    for row in range(nrow):
        for col in range(ncol):
            idx = row + nrow*col  # column-major position of this table cell

            cells[idx] = ''  # make cells
            if row>=1 and labels[row-1].split('/')[col] != '-':
                # raw string: '\c' is not a valid escape sequence
                cells[idx] = r' $\checkmark$ '

            marks[idx] = empty  # make marks
            if col == 0 and row>=1: # ignore header
                marks[idx] = handles[row-1]

            if row==0:  # make titles
                cells[idx] = titles[col]

    legax.axis('off')
    legend = legax.legend(marks, cells, ncol=ncol, columnspacing=-2, labelspacing=1, #handletextpad=1.5,
                        loc='center', frameon=True, facecolor='w', borderpad=1, alignment='center')

    for text in legend.texts:
        if 'checkmark' in text.get_text():
            text.set_backgroundcolor(ax.get_facecolor())
        if text.get_text() in titles:
            text.set_fontweight('bold')
            text.set_family('monospace')

    # Encoder-specific y-range; any other encoder keeps matplotlib's automatic limits.
    if enc == 'resnet18':
        ax.set_ylim(0.19, 0.51)
    if enc == 'vgg11':
        ax.set_ylim(0.08, 0.53)
    ax.set_ylabel('probing accuracy')
    ax.set_xlabel('epoch')
    ax.set_title(enc)

    plt.subplots_adjust(hspace=0)
    plt.savefig(f'init-distill-l2bot-{enc}.pdf', bbox_inches='tight')

# L2-bottleneck configurations to compare, in plotting order.
cfgs = [
    '-/lb/fn/wn/l/-',
    'wn/lb/-/wn/l/-',
    '-/lb/fn/-/l/fn',
    '-/lb/-/-/l/fn',
    '-/lb/-/-/l/-',
    '-/-/-/-/l/fn',
    '-/-/-/-/l/-',
]
# One figure per encoder, resnet18 first.
for enc_name in ('resnet18', 'vgg11'):
    plot(enc=enc_name, cfgs=cfgs)

In [None]:
sweep = api.sweep('safelix/DINO/sweeps/mgv2qemx')   # NormLayer

# Normalization layers in the order they are plotted and listed in the legend.
norm_names = ['BatchNorm', 'InstanceNorm', 'GroupNorm8', 'LayerNorm', 'Identity']

# gather groups: encoder -> norm layer -> runs.
# An encoder stays None until its first run is seen; the key order fixes the panel order.
groups = {'resnet18':None, 'vgg11':None}
for run in sweep.runs:
    enc_key = run.config['enc']
    if groups[enc_key] is None:
        groups[enc_key] = {}
    groups[enc_key].setdefault(run.config['enc_norm_layer'], []).append(run)

# plot groups: one panel per encoder, one curve per norm layer
fig, axes = plt.subplots(nrows=1, ncols=2, sharey=True, figsize=(10, 3.5))
for panel, (enc_name, enc_group) in zip(axes, groups.items()):
    for norm_name in norm_names:
        norm_runs = (enc_group or {}).get(norm_name)
        if norm_runs is None:
            continue

        # collect one student probe curve per run, named after the run's seed
        seed_curves = []
        for run in norm_runs:
            assert(run.config['enc_norm_layer'] == norm_name)
            hist:pd.DataFrame = run.history(keys=['probe/student'])
            if hist.empty:
                continue
            s_acc = hist['probe/student']
            s_acc.name = run.config['enc_seed']
            seed_curves.append(s_acc)

        print(f'{enc_name}, {norm_name}: {len(seed_curves)} runs' )
        print(f'=> seeds : {[h.name for h in seed_curves]}')
        if not seed_curves:
            continue
        curves = pd.concat(seed_curves, axis='columns')

        # mean curve with a +/- one std band across seeds
        curve_mean = curves.mean(axis=1, skipna=True)
        curve_std = curves.std(axis=1, skipna=True).fillna(0.0)
        panel.plot(curves.index, curve_mean, label=norm_name)
        panel.fill_between(curves.index, curve_mean - curve_std, curve_mean + curve_std, alpha=0.2)
    panel.set_xlabel('epoch')
    panel.set_title(enc_name)


# Only the right panel carries the legend; the y-axis is shared.
axes[1].legend()
axes[0].set_ylabel('probing accuracy')
axes[0].set_ylim(0.08, 0.53)

plt.subplots_adjust(wspace=0.1, hspace=0)
plt.savefig('init-distill-normlayer.pdf', bbox_inches='tight')


In [None]:
# NOTE(review): the original comment here read 'NormLayer', same as the previous
# sweep; the output file name suggests this is an MSE variant — confirm.
sweep = api.sweep('safelix/DINO/sweeps/cdig23bv')

# Normalization layers in the order they are plotted and listed in the legend.
norm_names = ['BatchNorm', 'Identity']

# gather groups: encoder -> norm layer -> runs.
# An encoder stays None until its first run is seen; the key order fixes the panel order.
groups = {'vgg11':None, 'resnet18':None}
for run in sweep.runs:
    enc_key = run.config['enc']
    if groups[enc_key] is None:
        groups[enc_key] = {}
    groups[enc_key].setdefault(run.config['enc_norm_layer'], []).append(run)

# plot groups: one panel per encoder, one curve per norm layer
fig, axes = plt.subplots(nrows=1, ncols=2, sharey=True, figsize=(10, 3.5))
for panel, (enc_name, enc_group) in zip(axes, groups.items()):
    for norm_name in norm_names:
        norm_runs = (enc_group or {}).get(norm_name)
        if norm_runs is None:
            continue

        # collect one student probe curve per run, named after the run's seed
        seed_curves = []
        for run in norm_runs:
            assert(run.config['enc_norm_layer'] == norm_name)
            hist:pd.DataFrame = run.history(keys=['probe/student'])
            if hist.empty:
                continue
            s_acc = hist['probe/student']
            s_acc.name = run.config['enc_seed']
            seed_curves.append(s_acc)

        print(f'{enc_name}, {norm_name}: {len(seed_curves)} runs' )
        print(f'=> seeds : {[h.name for h in seed_curves]}')
        if not seed_curves:
            continue
        curves = pd.concat(seed_curves, axis='columns')

        # mean curve with a +/- one std band across seeds
        curve_mean = curves.mean(axis=1, skipna=True)
        curve_std = curves.std(axis=1, skipna=True).fillna(0.0)
        panel.plot(curves.index, curve_mean, label=norm_name)
        panel.fill_between(curves.index, curve_mean - curve_std, curve_mean + curve_std, alpha=0.2)
    panel.set_xlabel('epoch')
    panel.set_title(enc_name)


# Only the right panel carries the legend; y-limits are left to matplotlib.
axes[1].legend()
axes[0].set_ylabel('probing accuracy')

plt.subplots_adjust(wspace=0.1, hspace=0)
plt.savefig('init-distill-mse.pdf', bbox_inches='tight')


In [None]:
# Hand-picked run ids per encoder; the key order fixes the panel order.
groups = {
    'resnet18': ['1gqyr8sa', '170lqug8', 'qz5j8dva', '26qjvyrz', '3bkx4c2f', '15j3cwn5'],
    'vgg11': ['1ozaihai', '1owvg6tb', '2mdna4pc', 'eiur72q2', '3uim3e5y', '2lrm7pai'],
}


# plot groups: one panel per encoder, one student probe curve per run
fig, axes = plt.subplots(nrows=1, ncols=2, sharey=False, figsize=(10, 3.5))
for panel, (enc_name, run_ids) in zip(axes, groups.items()):

    for run_id in run_ids:
        run = api.run(f'safelix/DINO/runs/{run_id}')
        hist:pd.DataFrame = run.history(keys=['probe/student'])
        if hist.empty:
            continue

        s_acc = hist['probe/student']
        s_acc.name = run.config['enc_seed']

        # Legend label: optimizer and weight decay, betas for adamw only,
        # plus a marker when wn_freeze_epochs is positive.
        label_parts = [run.config['opt'] + ': wd=' + str(run.config['opt_wd'])]
        if run.config['opt'] == 'adamw':
            label_parts.append(f"betas=({run.config['opt_beta1']}, {run.config['opt_beta2']})")
        if run.config['wn_freeze_epochs'] > 0:
            label_parts.append('freeze_l2bot')

        panel.plot(s_acc.index, s_acc, label=', '.join(label_parts))
    panel.set_xlabel('epoch')
    panel.set_title(enc_name)


# A single legend, placed to the right of the second panel.
axes[1].legend(bbox_to_anchor=(1, 0.5), loc='center left', frameon=True)

# The y-axes are independent, so both panels get their own label.
for panel in axes:
    panel.set_ylabel('probing accuracy')

plt.subplots_adjust(hspace=0)
plt.savefig('init-distill-opt.pdf', bbox_inches='tight')


In [None]:
from itertools import takewhile

# Upper epoch bound used when truncating the histories plotted below.
max_epoch = 25

# Run ids per encoder, resolved to run objects in listing order.
run_ids = {
    'resnet18': ['1gqyr8sa', '170lqug8', 'qz5j8dva', '26qjvyrz', '3bkx4c2f', '15j3cwn5'],
    'vgg11': ['1ozaihai', '1owvg6tb', '2mdna4pc', 'eiur72q2', '3uim3e5y', '2lrm7pai'],
}
groups = {
    enc_name: [api.run(f'safelix/DINO/runs/{run_id}') for run_id in ids]
    for enc_name, ids in run_ids.items()
}

# Step-level metrics, one figure row each, in row order.
shist_metrics = [
    'train/KL',                            # kl divergence row
    'params/norm(stud - init)',            # distance-from-init row
    'train/feat/embed/s_x.corr().rank()',  # embedding-ranks row
]

# The i-th resnet18 run and the i-th vgg11 run are expected to share their
# optimizer settings. Compare run1 against run2: comparing run1 with itself
# is always true and checks nothing.
for run1, run2 in zip(groups['resnet18'], groups['vgg11']):
    for key in ('opt', 'opt_beta1', 'opt_beta2', 'wn_freeze_epochs'):
        assert run1.config[key] == run2.config[key], (
            f'{key} differs: {run1.config[key]!r} != {run2.config[key]!r}')


# plot groups: one column per encoder; row 0 is the probing accuracy and the
# rows below follow the order of `shist_metrics`.
f, ax = plt.subplots(nrows=1+len(shist_metrics), ncols=2, sharex=True, sharey='row', figsize=(10, 14))
for idx, (enc_name, enc_group) in enumerate(groups.items()):

    for run in tqdm(enc_group):
        # legend label: optimizer and weight decay, betas for adamw only,
        # plus a marker when wn_freeze_epochs is positive
        name = ''
        name += run.config['opt'] + ':'
        name += ' wd=' + str(run.config['opt_wd'])
        if run.config['opt'] == 'adamw':
            name += ', betas=(' + str(run.config['opt_beta1']) + ', ' + str(run.config['opt_beta2']) + ')'

        if run.config['wn_freeze_epochs'] > 0:
            name += ', freeze'

        # plot probing from epoch history
        ehist = run.history(keys=['probe/student'], x_axis='epoch')
        ehist = ehist.rename(columns={'probe/student':'accs'})
        ehist = ehist.head(max_epoch + 1)
        # Shift every probe after the first one by one epoch (presumably the first
        # probe is taken at initialization and later ones at the end of an epoch).
        # Work on a private array: chained assignment into ehist['epoch'] warns on
        # older pandas and silently does nothing under Copy-on-Write.
        probe_epochs = ehist['epoch'].to_numpy(copy=True)
        probe_epochs[1:] += 1

        ax[0][idx].plot(probe_epochs, ehist['accs'], label=name)

        for m_idx, metric in enumerate(shist_metrics):

            # get metric: stream the step history and stop at the first row
            # whose epoch reaches max_epoch
            shist = run.scan_history(keys=['epoch', metric])
            shist = pd.DataFrame(takewhile(lambda elem: elem['epoch'] < max_epoch, shist))
            # spread the logged steps evenly over [0, last epoch + 1]
            shist['epoch'] = np.linspace(0, shist['epoch'].iloc[-1] + 1, len(shist['epoch']))

            # remove outliers (rolling median) and smooth (exponential moving average)
            smoothed = shist[metric].rolling(3).median().ewm(alpha=0.1).mean()

            # plot metric
            ax[1+m_idx][idx].plot(shist['epoch'], smoothed, label=name)

    ax[-1][idx].set_xlabel('epoch')
    # Title the column being drawn; a fixed column index would leave the
    # first column untitled.
    ax[0][idx].set_title(enc_name)


ax[0][0].set_ylim(-0.01, 0.51)
ax[1][0].set_ylim(-1e-6, 2e-5)
ax[1][0].ticklabel_format(axis='y', style='sci')

# One legend for both columns, drawn in the first column.
ax[0][0].legend()

# Row labels; rows 1-3 correspond to the entries of `shist_metrics` in order.
ax[0][0].set_ylabel('probing accuracy')
ax[1][0].set_ylabel('kl divergence')
ax[2][0].set_ylabel('distance from init')
ax[3][0].set_ylabel('embedding ranks')


plt.subplots_adjust(hspace=0.1, wspace=0.1)
plt.savefig('init-distill-opt-details.pdf', bbox_inches='tight')


In [None]:
# Sweep whose encoder names end in '_<width>_<depth>'; a depth token may carry a trailing 'e'.
sweep = api.sweep('safelix/DINO/sweeps/muerfvuw')

# gather groups: width -> depth -> runs
groups = {}
for run in sweep.runs:
    width_tok, depth_tok = run.config['enc'].split('_')[-2:]
    width = int(width_tok)
    if depth_tok[-1] == 'e':
        # a trailing 'e' is encoded as an extra half step of depth
        depth = int(depth_tok[:-1]) + 0.5
    else:
        depth = int(depth_tok)
    groups.setdefault(width, {}).setdefault(depth, []).append(run)

# plot groups: one panel per width (reverse insertion order), one curve per depth
fig, axes = plt.subplots(nrows=1, ncols=2, sharey=True, figsize=(10, 3.5))
for idx, width in enumerate(reversed(list(groups))):
    panel = axes[idx]
    for depth, depth_runs in groups[width].items():

        # collect one student probe curve per run, named after the run's seed
        seed_curves = []
        for run in depth_runs:
            hist:pd.DataFrame = run.history(keys=['probe/student'])
            if hist.empty:
                continue
            s_acc = hist['probe/student']
            s_acc.name = run.config['enc_seed']
            seed_curves.append(s_acc)

        print(f'{width}, {depth}: {len(seed_curves)} runs' )
        print(f'=> seeds : {[h.name for h in seed_curves]}')
        if not seed_curves:
            continue
        curves = pd.concat(seed_curves, axis='columns')

        # mean curve with a +/- one std band, both in the line's colour
        curve_mean = curves.mean(axis=1, skipna=True)
        curve_std = curves.std(axis=1, skipna=True).fillna(0.0)
        (line,) = panel.plot(curves.index, curve_mean, label=depth)
        colour = line.get_color()
        panel.fill_between(curves.index, curve_mean - curve_std, curve_mean + curve_std, alpha=0.2, color=colour)
        # value at index label 0, repeated as an error bar left of the curve at x=-3
        panel.errorbar(-3, curve_mean[0], curve_std[0], capsize=2.5, marker='.', color=colour)
    panel.set_xlabel('epoch')
    panel.set_title(f'convnet width {width}')


# Only the left panel carries the legend; the y-axis is shared.
axes[0].legend(title='depth')
axes[0].set_ylabel('probing accuracy')

plt.subplots_adjust(wspace=0.1, hspace=0)
plt.savefig('init-distill-convnets.pdf', bbox_inches='tight')


In [None]:
from matplotlib.cm import ScalarMappable, viridis, viridis_r
from matplotlib.colors import BoundaryNorm, NoNorm

runs = api.sweep('safelix/DINO/sweeps/5dvw7a8c').runs

# t_init_alpha values to plot; also the bin boundaries of the colour norm.
alphas = [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]

# gather groups: t_init_alpha -> runs
groups = {}
for run in runs:
    groups.setdefault(run.config['t_init_alpha'], []).append(run)


cmap = ScalarMappable(norm=BoundaryNorm(alphas, viridis.N, extend='max'), cmap=viridis)

f, ax = plt.subplots(nrows=1, ncols=2, sharey=True, figsize=(10, 3.5))

# Make plots
# Left: mean student probe curve per alpha with a +/- one std band.
# Right: mean and std over runs of each run's best accuracy, against alpha.
hists_max_means = {}
for alpha in reversed(alphas):

    hists = []
    # An alpha without any run is treated like one whose runs have no history:
    # it is reported as '0 runs' and skipped instead of raising KeyError.
    for run in groups.get(alpha, []):
        assert(run.config['t_init_alpha'] == alpha)
        hist:pd.DataFrame = run.history(keys=['probe/student'])
        if not hist.empty:
            s_acc = hist['probe/student']
            s_acc.name = run.config['enc_seed']  # becomes the column name after concat
            hists.append(s_acc)

    print(f'{alpha}: {len(hists)} runs' )
    print(f'=> seeds : {[h.name for h in hists]}')
    if len(hists) == 0:
        continue

    hists = pd.concat(hists, axis='columns')
    hists_mean = hists.mean(axis=1, skipna=True)
    hists_std = hists.std(axis=1, skipna=True).fillna(0.0)
    hists_max = hists.max(axis=0, skipna=True)  # best accuracy of each run
    hists_max_mean = hists_max.mean(skipna=True)
    hists_max_std = np.nan_to_num(hists_max.std(skipna=True))  # NaN std (single run) -> 0

    ax[0].plot(hists.index, hists_mean, label=alpha, color=cmap.to_rgba(alpha))
    ax[0].fill_between(hists.index, hists_mean - hists_std, hists_mean + hists_std, alpha=0.2, color=cmap.to_rgba(alpha))

    ax[1].errorbar(alpha, hists_max_mean, hists_max_std, marker='.', ms=15, ecolor='k', capsize=2.5, color=cmap.to_rgba(alpha))

    hists_max_means[alpha] = hists_max_mean
# Gray line through the per-alpha means; the dict views are passed as plain lists.
ax[1].plot(list(hists_max_means.keys()), list(hists_max_means.values()), color='gray', linestyle='-', linewidth=1)


# Format Left Plot
ax[0].set_xlabel('epoch')
ax[0].set_ylabel('probing accuracy')

# Format Right Plot; its legend reuses the left plot's entries in reverse
# order, i.e. in ascending alpha.
handles, labels = ax[0].get_legend_handles_labels()
ax[1].legend(handles[::-1], labels[::-1], bbox_to_anchor=(1, 0.5), loc='center left', frameon=False,  borderpad=0)
ax[1].set_xlabel('alpha')
ax[1].set_ylabel('best accuracy')


plt.subplots_adjust(wspace=0.1, hspace=0)

plt.savefig('init-distill-interpolate.pdf', bbox_inches='tight')


#### Guillotine Regularization

In [None]:
# (legend name, sweep, line colour) per head variant
sweeps = [
    ('linear head', api.sweep('safelix/DINO/sweeps/qmhuwk3y'), 'C0'),
    ('512-512 head', api.sweep('safelix/DINO/sweeps/473v5acw'), 'C1'),
    ('2048-2048 head', api.sweep('safelix/DINO/sweeps/y1qqvpw1'), 'C2'),
]


ax = plt.figure(figsize=(10, 5)).gca()
h_probe, h_valid = [], []   # first logged probe / validation accuracy of every run
noise_ratios = set()        # every label-noise ratio seen, used for the x ticks
for name, sweep, color in sweeps:
    # Sort by noise ratio so the connecting lines run left to right whatever
    # order the API returns the runs in.
    sweep_runs = sorted(sweep.runs, key=lambda run: run.config['label_noise_ratio'])

    xs, ys_probe, ys_valid = [], [], []
    for run in sweep_runs:
        xs.append(run.config['label_noise_ratio'])
        ys_probe.append(run.summary['probe/student']['max'])
        ys_valid.append(run.summary['valid/s_acc']['max'])

        hist = run.history(keys=['probe/student', 'valid/s_acc'])
        h_probe.append(hist['probe/student'].iloc[0])
        h_valid.append(hist['valid/s_acc'].iloc[0])
    noise_ratios.update(xs)
    ax.plot(xs, ys_probe, '-o', c=color, label=f'{name} probe')
    ax.plot(xs, ys_valid, '-D', c=color, label=f'{name} valid')

# Reference levels: first logged accuracies averaged over all runs of all sweeps.
init_probe = sum(h_probe) / len(h_probe)
init_valid = sum(h_valid) / len(h_valid)
ax.hlines(init_probe, -0.05, 1.05, colors='gray', linestyles='--', label='random init')
ax.hlines(init_valid, -0.05, 1.05, colors='gray', linestyles='--')


ax.set_ylim(-0.05,0.75)
ax.set_xlim(-0.05,1.05)
# Tick every noise ratio that occurs in any sweep, not just those of the last one.
ax.set_xticks(sorted(noise_ratios))

plt.legend()
plt.title('Guillotine Regularization under Label Noise')
plt.savefig('label_flipping_gr.pdf')
