In [None]:
%load_ext autoreload
%autoreload 2
import os
import torch
import wandb
import matplotlib
import matplotlib.pyplot as plt
import matplotlib
from matplotlib import cm
import seaborn as sns
import pandas as pd
from tqdm import tqdm
import numpy as np
from itertools import takewhile, islice


from typing import Dict, Iterable, Any
from numbers import Number

from wandb.apis.public import Run
# Type alias used in signatures below: any iterable of W&B public-API runs.
Runs = Iterable[Run]
# Single W&B public-API client shared by all cells; `timeout` is forwarded to wandb.Api.
api = wandb.Api(timeout=19)

import dinopl.utils as U

#sns.set_theme()
#sns.set_theme(context='paper', font_scale=0.5)
#sns.set_style()


from tueplots import bundles, axes, figsizes, fonts, fontsizes
# Build the matplotlib rcParams for ICML-2022 style figures, layer by layer:
# base bundle -> font sizes -> half-column figure size -> line widths -> legend frame.
tue_params = bundles.icml2022(usetex=False, family='sans-serif') #sans-serif
tue_params.update(bundles.fontsizes.icml2022(default_smaller=1))
tue_params.update(figsizes.icml2022_half(tight_layout=True, pad_inches=0)) 
tue_params.update(axes.lines(base_width=0.25))
tue_params.update(axes.legend(frameon=True, fancybox=False))
#sns.set_style(style='darkgrid', rc=tue_params)
plt.rcParams.update(tue_params)
# NOTE(review): the style sheet is applied AFTER tue_params, so every rcParam the
# style sheet defines replaces the tueplots value set above - confirm this order is intended.
plt.style.use('seaborn-v0_8-darkgrid')
# Appends to the current preamble, so re-running this cell appends the \usepackage line again.
matplotlib.rcParams['text.latex.preamble'] = matplotlib.rcParams['text.latex.preamble'] + r" \usepackage{amsmath, amssymb}"


# Use the latest version of matplotlib-inline to avoid this Jupyter notebook bug:
#https://github.com/matplotlib/matplotlib/issues/9217
#https://github.com/ipython/ipykernel/issues/267

In [None]:
from cycler import cycler
from matplotlib import cm
from matplotlib.colors import Normalize, BoundaryNorm, Colormap
from warnings import warn

def alphanumsort(keys, reverse_sort=False):
    """Sort mixed keys: all numbers first (sorted), then all strings (sorted).

    Args:
        keys: iterable of keys. It is traversed exactly once, so one-shot
            iterators and generators are handled correctly.
        reverse_sort: reverse the order inside each of the two partitions
            (numbers still come before strings).

    Returns:
        A new list. Keys that are neither ``str`` nor ``numbers.Number``
        (e.g. ``None``) are not part of the result.
    """
    nbr_keys, str_keys = [], []
    for key in keys:
        if isinstance(key, str):
            str_keys.append(key)
        elif isinstance(key, Number):
            nbr_keys.append(key)
    return sorted(nbr_keys, reverse=reverse_sort) + sorted(str_keys, reverse=reverse_sort)

def get_run_attr(run, key:str):
    """Look up an attribute of a W&B run by a prefixed path.

    Supported forms of ``key``:
        'config.<name>'       -> run.config[<name>]
        'summary.<name>'      -> run.summary[<name>]
        'dino_config.<name>'  -> run.config['dino_config'][<name>]
        '<name>' (no prefix)  -> run.config[<name>]

    The unprefixed form used to fall through and return ``None``, which made
    every equality filter on such a key reject all runs. It now reads the run
    config, so a missing key raises ``KeyError`` like the prefixed forms do.
    """
    if key.startswith('config.'):
        return run.config[key[len('config.'):]]
    if key.startswith('summary.'):
        return run.summary[key[len('summary.'):]]
    if key.startswith('dino_config.'):
        return run.config['dino_config'][key[len('dino_config.'):]]
    return run.config[key] # no prefix: treat as a plain config key

def plot_agg(runs:Runs, group_by:str, metric:str, xmetric:str=None, labels:list=None, 
                group_keys=None, reverse_sort=False, filter:dict={}, colors:Colormap=None, 
                center='mean', spread='std', center_kwargs={}, spread_kwargs={}, plot_kwargs={}, 
                scan_step:int=None, max_step=None, ema_alpha:float=0., summarize='max', slice2summarize=slice(1, None),  ax=None,
                ):
    """Plot one metric aggregated over groups of W&B runs and summarise each group.

    Runs are filtered, grouped by the value of one run attribute, and every group
    is drawn as a centre line across its runs plus an optional spread.

    Args:
        runs: iterable of W&B runs.
        group_by: attribute path (see ``get_run_attr``) whose value defines the group.
        metric: history key plotted on the y-axis.
        xmetric: optional history key used as x values (taken from the run with
            the longest history); defaults to the row index of the history.
        labels: list with one legend label per group key; any non-list value is
            repeated for every group; defaults to the group keys.
        group_keys: explicit selection/order of groups; defaults to
            ``alphanumsort`` of the keys found. Keys without runs are reported and skipped.
        reverse_sort: reverse the default ordering of the group keys.
        filter: mapping ``attribute path -> required value``; a run is skipped
            as soon as one attribute differs.
        colors: a single colour (tuple/str/float), an iterable of colours
            (repeated cyclically if there are more groups than colours), or a
            ``Colormap`` that is discretised over the group keys. Default: tab10.
        center: 'mean' or 'median'; any other value draws no centre line.
        spread: 'minmax', 'std' (only drawn together with center='mean'),
            'samples' (one line per run); any other value draws no spread.
        center_kwargs, spread_kwargs: extra kwargs for the respective artists;
            ``plot_kwargs`` provides shared defaults for both.
        scan_step: if set, read the history with ``run.scan_history`` and keep
            ``islice(scan, 0, max_step, scan_step)`` of its rows.
        max_step: stop index of that slice (only used with ``scan_step``).
        ema_alpha: if in (0, 1], smooth the metric with ``ewm(alpha=ema_alpha).mean()``.
        summarize: 'max', 'min', 'first', 'last' or None. One value per run is
            extracted and then aggregated across the runs of the group.
        slice2summarize: row slice considered for 'max'/'min'.
        ax: axes to draw on; a new figure is created if omitted.

    Returns:
        DataFrame indexed by group key with columns 'centers', 'colors',
        'labels' and, unless ``summarize`` is None, 'medians', 'means', 'stds',
        'mins', 'maxs'. ``None`` if no run passes the filter.

    Raises:
        ValueError: labels and group keys differ in length, ``colors`` is an
            empty iterable, or ``summarize`` is not a known key.
    """
    center_kwargs = plot_kwargs | center_kwargs      # use plot_kwargs as default for center_kwargs
    spread_kwargs = plot_kwargs | spread_kwargs      # use plot_kwargs as default for spread_kwargs
    spread_kwargs = dict(alpha=0.15) | spread_kwargs # set default alpha for spread

    if ax is None:
        _, ax = plt.subplots(1, 1)

    # filter runs and group them by the requested attribute
    groups:Dict[Any, Iterable[Run]] = {}
    for run in runs: 
        if any(get_run_attr(run, name) != value for name, value in filter.items()):
            continue
        groups.setdefault(get_run_attr(run, group_by), []).append(run)

    if group_keys is None:
        group_keys = alphanumsort(groups.keys(), reverse_sort=reverse_sort)
    group_keys = list(group_keys)
    
    if len(groups) == 0:
        warn('No runs to plot.')
        return None

    # make labels
    if labels is None:
        labels = group_keys
    if not isinstance(labels, list):
        labels = [labels] * len(group_keys)
    if len(labels) != len(group_keys):
        raise ValueError('Provided labels are not of same length as group keys.')
    
    # make colors
    if colors is None: # set default colorcycle
        colors = list(cm.tab10.colors)

    if isinstance(colors, (tuple, str, float)): # a single colour for all groups
        colors = [colors] * len(group_keys)
    elif isinstance(colors, Iterable):
        # Repeat the palette cyclically so that no group is dropped when there
        # are more groups than colours.
        palette = list(colors)
        if len(palette) == 0:
            raise ValueError('Provided colors are empty.')
        colors = [palette[i % len(palette)] for i in range(len(group_keys))]
    elif isinstance(colors, Colormap):
        color_keys = group_keys
        if not all(isinstance(k, (float, int)) for k in color_keys):
            color_keys = range(len(group_keys)) # non-numeric keys: colour by position
        norm = BoundaryNorm(sorted(color_keys), colors.N, extend='max')    # discrete sorted numeric values
        cmapper = cm.ScalarMappable(norm=norm, cmap=colors)
        colors = [cmapper.to_rgba(k) for k in color_keys]

    history_keys = [metric] + ([] if xmetric is None else [xmetric])

    # iterate through groups
    summary = dict()
    n_aggregates = dict()
    for group_key, label, color in tqdm(list(zip(group_keys, labels, colors))):
        hists, hists_xs = [], [] # gather histories of runs in group
        for run in groups.get(group_key, []): # requested keys may have no runs
            hist:pd.DataFrame = run.history(keys=history_keys)
            if hist.empty:
                continue

            if scan_step is not None:
                scan = run.scan_history(keys=[metric, 'trainer/global_step']+([] if xmetric is None else [xmetric]))
                hist = pd.DataFrame(islice(scan, 0, max_step, scan_step), dtype=float)
                if hist.empty: # nothing scanned: there are no columns to sort by
                    continue
                if xmetric == 'epoch':
                    # replace the logged epoch by a value proportional to the global step
                    hist[xmetric] = hist['trainer/global_step'] * hist[xmetric].iloc[-1] / hist['trainer/global_step'].iloc[-1]
                    hist[xmetric] = hist[xmetric].fillna(0.0)
            
                # Sort by step and renumber the rows: the runs of a group are
                # aligned by row label when concatenated below.
                hist = hist.sort_values('trainer/global_step', axis='index').reset_index(drop=True)

            # skip runs without the metric so that x and y values stay in sync
            if metric not in hist.columns or hist[metric].empty:
                continue

            if 0 < ema_alpha <= 1: # ema smoothing
                hist[metric] = hist[metric].ewm(alpha=ema_alpha).mean()

            hists.append(hist[metric].rename(run.name))
            if xmetric in hist.columns and not hist[xmetric].empty:
                hists_xs.append(hist[xmetric].rename(run.name))
        
        # make one dataframe for this group
        if len(hists) == 0:
            print(f'No run in group {group_key!r} has a history with metric \'{metric}\'.')
            continue
        n_aggregates[f'{label}'] = len(hists)

        hists = pd.concat(hists, axis='columns')
        if len(hists_xs) > 0:
            hists.index = max(hists_xs, key=len) # use longest xmetric as index

        # aggregate metric across runs in group to represent the distribution center
        hists_center = None
        if center == 'mean':
            hists_center = hists.mean(axis=1, skipna=True)
        if center == 'median':
            hists_center = hists.median(axis=1, skipna=True)
        if hists_center is not None:
            ax.plot(hists.index, hists_center, color=color, label=label, **center_kwargs)

        # aggregate metric across runs in group to represent the distribution spread 
        hists_lower, hists_upper = None, None
        if spread == 'minmax':
            hists_lower = hists.min(axis=1, skipna=True)
            hists_upper = hists.max(axis=1, skipna=True)
        if spread == 'std' and center == 'mean':
            hists_std = hists.std(axis=1, skipna=True).fillna(0.0)
            hists_lower = hists_center - hists_std
            hists_upper = hists_center + hists_std
        if hists_lower is not None:
            ax.fill_between(hists.index, hists_lower, hists_upper, color=color, **spread_kwargs)
        elif spread == 'samples':
            ax.plot(hists.index, hists, color=color, **spread_kwargs)

        summary[group_key] = {'centers':hists_center, 'colors':color, 'labels':label}
        if summarize is None:
            continue # process next group

        # Reduce every run (column) to a single value. 'first'/'last' use the
        # first/last valid entry of each run so that NaN padding of shorter
        # runs is ignored, like skipna does for 'max'/'min'.
        if summarize == 'max':
            bests = hists.iloc[slice2summarize, :].max(axis=0, skipna=True)
        elif summarize == 'min':
            bests = hists.iloc[slice2summarize, :].min(axis=0, skipna=True)
        elif summarize == 'first':
            bests = hists.bfill().iloc[0, :]
        elif summarize == 'last':
            bests = hists.ffill().iloc[-1, :]
        else:
            raise ValueError('Unkown key for summarize.')

        # statistics of the per-run values across the runs of this group
        summary[group_key]['medians'] = bests.median(skipna=True)
        summary[group_key]['means'] = bests.mean(skipna=True)
        summary[group_key]['stds']  = bests.std(skipna=True) 
        summary[group_key]['mins']  = bests.min(skipna=True)
        summary[group_key]['maxs']  = bests.max(skipna=True)
          
    print(f'aggregated: {n_aggregates}')
    return pd.DataFrame.from_dict(summary, orient='index')
    

#### Gradient Norm and Variance

In [None]:
# Settings of the gradient-variance figure; read as globals by plot_gradvar below.
max_step = 2500
xmetric = 'trainer/global_step'
metric = 'params/var(grad)'

# Keyword arguments forwarded to every plot_agg call of this figure.
kwargs = {
    'group_by': 'config.enc',
    'xmetric': xmetric,
    'scan_step': 1,
    'max_step': max_step,
    'ema_alpha': 1,
    'center_kwargs': {'alpha': 0.8, 'lw': 0.2},
    'summarize': 'last',
}

nrows, ncols = 1, 1
@plt.rc_context(figsizes.icml2022_half(nrows=nrows, ncols=ncols, height_to_width_ratio=figsizes._GOLDEN_RATIO))
def plot_gradvar(runs, filename):
    """Draw the gradient-variance curve of every run into one log-scaled panel.

    Reads the globals ``nrows``, ``ncols``, ``metric``, ``kwargs`` and
    ``max_step`` at call time. The figure is saved to ``filename`` and shown.
    """
    fig, ax = plt.subplots(nrows=nrows, ncols=ncols, sharex=True)
    ax.set_ylabel('gradient variance trace')

    # one curve per run, coloured in tab10 order and labelled "<enc>: <s_mode>"
    for color, run in zip(cm.tab10.colors, runs):
        run_label = f'{run.config["enc"]}: {run.config["s_mode"]}'
        plot_agg(runs=[run], metric=metric, colors=color, labels=[run_label], ax=ax, **kwargs)
    ax.legend(frameon=True)

    # axis cosmetics: small margin around [0, max_step], log-scaled y
    ax.set_xlim(-max_step // 20, max_step + max_step // 20)
    ax.set_yscale('log')
    ax.set_xlabel('step')

    plt.savefig(filename)
    plt.show()


# The four runs compared in the gradient-variance figure: two encoders, each
# trained once in the supervised and once in the distillation setting.
runs = [ 
    api.run('safelix/DINO/runs/3sg5hiwn'), # vgg11 + head in supervised
    api.run('safelix/DINO/runs/sl426j4i'), # resnet18 + head in supervised
    api.run('safelix/DINO/runs/1omtevqy'), # vgg11 + head in distillation
    api.run('safelix/DINO/runs/3ip3rj1c'), # resnet18 + head in distillation
]
# Writes 'gradvar.pdf' and displays the figure.
plot_gradvar(runs, 'gradvar.pdf')

In [None]:
# Settings of the early-training figures; read as globals by plot_early below.
max_step = 2500
xmetric = 'trainer/global_step'

# panel title -> what to draw in that panel:
#   str   : a single metric
#   list  : several metrics sharing one axis
#   tuple : two metrics, the second one on a twin y-axis
metrics = {
    'kl-divergence': 'train/KL',
    'gradient norm': ['params/norm(grad)', 'params/head/norm(grad)', 'params/enc/norm(grad)'],
    'gradient variance trace': ['params/var(grad)', 'params/head/var(grad)', 'params/enc/var(grad)'],
    'distance from init': ['params/norm(stud - init)', 'params/head/norm(stud - init)', 'params/enc/norm(stud - init)'],
    'linear probe vs rank': ('probe/student', 'train/feat/embed/s_x.rank()'),
}

colors = cm.tab10.colors

# Keyword arguments forwarded to every plot_agg call of these figures.
kwargs = {
    'group_by': 'config.enc',
    'xmetric': xmetric,
    'scan_step': 1,
    'max_step': max_step,
    'ema_alpha': 0.9,
    'center_kwargs': {'alpha': 0.8, 'lw': 0.2},
    'summarize': 'last',
}

nrows, ncols = len(metrics), 1
@plt.rc_context(figsizes.icml2022_half(nrows=nrows, ncols=ncols, height_to_width_ratio=figsizes._GOLDEN_RATIO))
def plot_early(run, filename):
    """Plot the early-training diagnostics of one run, one stacked panel per entry of ``metrics``.

    Reads the globals ``nrows``, ``ncols``, ``metrics``, ``colors``, ``kwargs``
    and ``max_step`` at call time. The figure is saved to ``filename`` and shown.
    """
    fig, panels = plt.subplots(nrows=nrows, ncols=ncols, sharex=True)

    for idx, (name, metric) in enumerate(metrics.items()):
        panel = panels[idx]
        panel.set_ylabel(name)

        if isinstance(metric, tuple) and len(metric) == 2:
            # pair of metrics: the second goes on a twin y-axis, both share one legend
            left_metric, right_metric = metric
            twin = panel.twinx()
            plot_agg(runs=[run], metric=right_metric, labels=[right_metric], colors=[colors[1]], ax=twin, **kwargs)
            plot_agg(runs=[run], metric=left_metric, labels=[left_metric], colors=[colors[0]], ax=panel, **kwargs)
            left_handles, left_labels = panel.get_legend_handles_labels()
            right_handles, right_labels = twin.get_legend_handles_labels()
            twin.legend(left_handles+right_handles, left_labels+right_labels, frameon=True)
        elif isinstance(metric, list):
            # several metrics on one axis; the label drops the leading 'params/'
            for color, submetric in zip(colors, metric):
                short_label = submetric[len('params/'):]
                plot_agg(runs=[run], metric=submetric, colors=color, labels=[short_label], ax=panel, **kwargs)
                panel.legend(frameon=True)
        else:
            plot_agg(runs=[run], metric=metric, ax=panel, **kwargs)
            panel.legend(frameon=True)

    panels[0].set_xlim(-max_step // 20, max_step + max_step // 20)
    for log_idx in (0, 1, 2): # the first three panels are log-scaled
        panels[log_idx].set_yscale('log')
    panels[-1].set_xlabel('step')

    plt.savefig(filename)
    plt.show()
    

# Early-training figure for one run; written to 'early-distillation-vgg11.pdf'.
run = api.run('safelix/DINO/runs/1omtevqy') # vgg11 + head in distillation
plot_early(run, 'early-distillation-vgg11.pdf')


In [None]:
# Same figure as above for another run; relies on plot_early and its globals from the earlier cell.
run = api.run('safelix/DINO/runs/3sg5hiwn') # vgg11 + head in supervised
plot_early(run, 'early-supervised-vgg11.pdf')

In [None]:
# Same figure as above for another run; relies on plot_early and its globals from the earlier cell.
run = api.run('safelix/DINO/runs/3ip3rj1c') # resnet18 + head in distillation
plot_early(run, 'early-distillation-resnet18.pdf')

In [None]:
# Same figure as above for another run; relies on plot_early and its globals from the earlier cell.
run = api.run('safelix/DINO/runs/sl426j4i') # resnet18 + head in supervised
plot_early(run, 'early-supervised-resnet18.pdf')

#### BatchSizes

In [None]:
# x-axis key and {panel title: history key} of the batch-size figures.
# Both are read as globals by plot_batchsize_enc, and this overwrites the
# `metrics` dict of the early-training section.
xmetric = 'trainer/global_step'
metrics = {'kl-divergence':'train/KL', 
            'gradient norm':'params/norm(grad)', 
            'distance from init':'params/norm(stud - init)',
            'linear probing':'probe/student'
            }


nrows, ncols = len(metrics), 1
@plt.rc_context(figsizes.icml2022_half(nrows=nrows, ncols=ncols, height_to_width_ratio=figsizes._GOLDEN_RATIO))
def plot_batchsize_enc(runs, enc, filename=None, ax=None, kwargs=dict()):
    """Plot every entry of the global ``metrics`` for one encoder, grouped by training batch size.

    Args:
        runs: iterable of W&B runs; only runs whose config has ``enc == enc`` are drawn.
        enc: encoder name to select.
        filename: if given, the figure is saved there and shown.
        ax: one axes per entry of ``metrics`` (list, tuple or array) or a single
            axes; if omitted, a new column of ``nrows`` panels is created
            (``nrows``/``ncols`` are read from the globals at call time).
        kwargs: overrides for the plot_agg defaults defined below.
    """
    kwargs_defaults = dict(group_by='config.bs_train',
                xmetric=xmetric,
                colors=cm.viridis_r, 
                scan_step=1, 
                max_step=2500,
                ema_alpha=1,
                center_kwargs=dict(alpha=0.8, lw=0.2),
                summarize='last',
            )
    kwargs_defaults.update(kwargs)
    kwargs = kwargs_defaults

    if ax is None:
        f, ax = plt.subplots(nrows=nrows, ncols=ncols, sharex=True)
    # Wrap a single axes based on the argument itself rather than on the global
    # `nrows`, so that a caller-provided list of axes is never nested.
    if not isinstance(ax, (list, tuple, np.ndarray)):
        ax = [ax]

    for idx, (name, metric) in  enumerate(metrics.items()):
        # use the explicit 'config.' prefix understood by get_run_attr
        plot_agg(runs=runs, metric=metric, filter={'config.enc':enc}, ax=ax[idx], **kwargs)
        ax[idx].set_ylabel(name)
        if name in ['kl-divergence', 'gradient norm']:
            ax[idx].set_yscale('log') 

    if kwargs['max_step'] is not None:
        ax[0].set_xlim(-kwargs['max_step'] // 20, kwargs['max_step'] + kwargs['max_step'] // 20)
    
    # The legend is filled column by column; reorder the entries so that they
    # read row by row.
    legendcols = 6
    handles, labels = ax[0].get_legend_handles_labels()
    labels = [elem for i in range(legendcols) for elem in labels[i::legendcols]]    # transpose
    handles = [elem for i in range(legendcols) for elem in handles[i::legendcols]]  # transpose
    ax[0].legend(handles, labels, frameon=True, loc='upper center', handlelength=1, columnspacing=1, ncols=legendcols)

    if filename is not None:
        plt.savefig(filename)
        plt.show()

#plot_batchsize_enc(runs, enc='vgg11', filename='batchsize-vgg11.pdf')

In [None]:
# Batch-size sweep plus a second sweep that is drawn in red on top of it
# ("full batch training", see below).
runs = api.sweep('safelix/DINO/sweeps/bn9nj3n2').runs
runs_batchaccum = api.sweep('safelix/DINO/sweeps/0wwrnkaj').runs

# Two-column figure: vgg11 on the left, resnet18 on the right, one row per metric.
# `nrows` is also read as a global inside plot_batchsize_enc.
nrows, ncols = len(metrics), 2
with plt.rc_context(figsizes.icml2022_full(nrows=nrows, ncols=ncols, height_to_width_ratio=figsizes._GOLDEN_RATIO)):
    f, ax = plt.subplots(nrows=nrows, ncols=ncols, sharex=True, sharey='row')
    plot_batchsize_enc(runs, enc='vgg11', ax=[a[0] for a in ax])
    plot_batchsize_enc(runs, enc='resnet18', ax=[a[1] for a in ax])

    # plot full batch training
    # NOTE(review): the label is bs_train * 192; the trailing comment suggests 192
    # stands for the number of accumulated batches - confirm against the run config.
    kwargs=dict(colors='red', labels=[runs_batchaccum[1].config['bs_train'] * 192]) # runs_batchaccum[0].config['batchaccum']
    plot_batchsize_enc(runs_batchaccum, enc='vgg11', ax=[a[0] for a in ax], kwargs=kwargs)
    plot_batchsize_enc(runs_batchaccum, enc='resnet18', ax=[a[1] for a in ax], kwargs=kwargs)

    # remove the y-labels of the right column (rows share their y-axis)
    [a[1].set_ylabel('') for a in ax]
    plt.savefig('batchsize.pdf')
    plt.show()

In [None]:
# Reduced version of the batch-size figure: two metrics, vgg11 only.
runs = api.sweep('safelix/DINO/sweeps/bn9nj3n2').runs
runs_batchaccum = api.sweep('safelix/DINO/sweeps/0wwrnkaj').runs

# Overwrites the global `metrics` that plot_batchsize_enc iterates over.
metrics = {'kl-divergence':'train/KL', 
            #'gradient norm':'params/norm(grad)', 
            #'distance from init':'params/norm(stud - init)',
            'linear probing':'probe/student'
            }

nrows, ncols = len(metrics), 2
with plt.rc_context(figsizes.icml2022_full(nrows=nrows, ncols=ncols, height_to_width_ratio=figsizes._GOLDEN_RATIO)):
    f, ax = plt.subplots(nrows=nrows, ncols=ncols, sharex=True, sharey='row')
    plot_batchsize_enc(runs, enc='vgg11', ax=[a[0] for a in ax])
    #plot_batchsize_enc(runs, enc='resnet18', ax=[a[1] for a in ax])

    # plot full batch training
    kwargs=dict(colors='red', labels=[runs_batchaccum[1].config['bs_train'] * 192]) # runs_batchaccum[0].config['batchaccum']
    plot_batchsize_enc(runs_batchaccum, enc='vgg11', ax=[a[0] for a in ax], kwargs=kwargs)
    #plot_batchsize_enc(runs_batchaccum, enc='resnet18', ax=[a[1] for a in ax], kwargs=kwargs)

    # NOTE(review): the figure still has two columns, but nothing is drawn into
    # the right one while the resnet18 calls are commented out.
    [a[1].set_ylabel('') for a in ax]
    plt.savefig('batchsize-vgg11.small.pdf')
    plt.show()

In [None]:
# Runtime per epoch as a function of the training batch size, per encoder.
# Uses `runs` as left behind by the previous cell.
with plt.rc_context(figsizes.icml2022_full()):
    data = {'vgg11':{}, 'resnet18':{}} # encoder -> {bs_train: minutes per epoch}

    for run in runs:
        epochs = run.summary['epoch'] + 1 # presumably the logged epoch is zero-based - TODO confirm
        runtime_min = run.summary['_runtime'] / 60 # assumes '_runtime' is in seconds - TODO confirm
        runtime_hours = run.summary['_runtime'] / 60 / 60 # not used below

        # a later run with the same encoder and batch size overwrites an earlier one
        data[run.config['enc']][run.config['bs_train']] = runtime_min / epochs

    # sorted (batch size, minutes) pairs, unzipped into x and y values
    plt.plot(*zip(*sorted(data['vgg11'].items())), '-o', label='vgg11')
    plt.plot(*zip(*sorted(data['resnet18'].items())), '-o', label='resnet18')
    plt.ylim(0, 5)
    plt.ylabel('minutes per epoch')
    plt.xlabel('training batch size')
    plt.legend()

# Extrapolate the total training time for a fixed batch size (minutes -> days).
bs_train = 256
n_epochs = 2500
print(f'vgg11: {data["vgg11"][bs_train]:.1f} min => {n_epochs} epochs in { n_epochs * data["vgg11"][bs_train] / 60 / 24 :.1f} days ')
print(f'resnet18: {data["resnet18"][bs_train]:.1f} min => {n_epochs} epochs in { n_epochs * data["resnet18"][bs_train] / 60 / 24  :.1f} days')

In [None]:
# Fetch the configs of one representative run per experiment (e.g. to compare
# them with dictdiff from the next cell).
config = api.run('safelix/DINO/runs/wbmnl6y4').config # initinterp
config1 = api.run('safelix/DINO/runs/1omtevqy').config # gradvar
config2 = api.run('safelix/DINO/runs/7t70ov3c').config # batchsizes
config3 = api.run('safelix/DINO/runs/gjynwffs').config # batchsizes accum


In [None]:
def dict2set(d):
    """Return the items of ``d`` as a set of ``(key, value)`` tuples.

    The values must be hashable.
    """
    # pair every value with its own key (not with the literal string 'key')
    return {(k, v) for k, v in d.items()}

def dictdiff(A:dict, B:dict):
    """Return the entries of ``A`` that are missing in ``B`` or have a different value there.

    The comparison is one-sided: keys that only exist in ``B`` are not reported.
    """
    changed = {}
    for key, value in A.items():
        if key not in B or value != B[key]:
            changed[key] = value
    return changed

#### SGD + Learning Rate

In [None]:
# x-axis key and {panel title: history key} of the learning-rate figures.
# Both are read as globals by plot_lr below.
xmetric = 'trainer/global_step'
metrics = {'kl-divergence':'train/KL', 
            'gradient norm':'params/norm(grad)', 
            'distance from init':'params/norm(stud - init)',
            'linear probing':'probe/student'
            }

nrows, ncols = len(metrics), 1
@plt.rc_context(figsizes.icml2022_half(nrows=nrows, ncols=ncols, height_to_width_ratio=figsizes._GOLDEN_RATIO))
def plot_lr(runs, enc, normlayer=None, filename=None, ax=None, kwargs=dict()):
    """Plot every entry of the global ``metrics`` for one encoder, grouped by learning rate.

    Args:
        runs: iterable of W&B runs.
        enc: required value of ``config.enc``.
        normlayer: required value of ``config.enc_norm_layer``. The value is
            compared for equality as is, including the default ``None``.
        filename: if given, the figure is saved there and shown.
        ax: one axes per entry of ``metrics`` (list, tuple or array) or a single
            axes; if omitted, a new column of ``nrows`` panels is created
            (``nrows``/``ncols`` are read from the globals at call time).
        kwargs: overrides for the plot_agg defaults defined below.

    Returns:
        dict mapping each history key to the summary returned by plot_agg.
    """
    kwargs_defaults = dict(group_by='config.opt_lr',
                xmetric=xmetric,
                colors=cm.viridis_r,
                scan_step=1, 
                max_step=2500,
                ema_alpha=1,
                center_kwargs=dict(alpha=0.8, lw=0.2),
                summarize='last',
            )
    kwargs_defaults.update(kwargs)
    kwargs = kwargs_defaults

    if ax is None:
        f, ax = plt.subplots(nrows=nrows, ncols=ncols, sharex=True)
    # Wrap a single axes based on the argument itself rather than on the global
    # `nrows`, so that a caller-provided list of axes is never nested.
    if not isinstance(ax, (list, tuple, np.ndarray)):
        ax = [ax]

    summary = dict()
    for idx, (name, metric) in  enumerate(metrics.items()):
        summary[metric] = plot_agg(runs=runs, metric=metric, filter={'config.enc':enc, 'config.enc_norm_layer':normlayer}, 
                                   ax=ax[idx], **kwargs)
        ax[idx].set_ylabel(name)
        if metric in ['train/KL', 'params/norm(grad)', 'train/MSE']:
            ax[idx].set_yscale('log') 

    if kwargs['max_step'] is not None:
        ax[0].set_xlim(-kwargs['max_step'] // 20, kwargs['max_step'] + kwargs['max_step'] // 20)
    
    # The legend is filled column by column; reorder the entries so that they
    # read row by row.
    legendcols = 5
    handles, labels = ax[0].get_legend_handles_labels()
    labels = [elem for i in range(legendcols) for elem in labels[i::legendcols]]    # transpose
    handles = [elem for i in range(legendcols) for elem in handles[i::legendcols]]  # transpose
    ax[0].legend(handles, labels, frameon=True, loc='upper center', handlelength=1, columnspacing=1, ncols=legendcols)

    if filename is not None:
        plt.savefig(filename)
        plt.show()

    return summary

# Single-column learning-rate figure for vgg11, written to 'sgd-lr-ce-vgg11.pdf'.
# As the last expression of the cell, the returned summary dict is displayed as cell output.
runs = api.sweep('safelix/DINO/runs/eg0tf2oo').runs # SGD LR Gridsearch (CE)
plot_lr(runs, enc='vgg11', normlayer='Identity', filename='sgd-lr-ce-vgg11.pdf')

In [None]:
runs = api.sweep('safelix/DINO/runs/eg0tf2oo').runs # SGD LR Gridsearch (CE)

# Two-column figure: vgg11 (Identity norm layer) on the left, resnet18 (BatchNorm) on the right.
# `nrows` is also read as a global inside plot_lr.
nrows, ncols = len(metrics), 2
with plt.rc_context(figsizes.icml2022_full(nrows=nrows, ncols=ncols, height_to_width_ratio=figsizes._GOLDEN_RATIO)):
    f, ax = plt.subplots(nrows=nrows, ncols=ncols, sharex=True, sharey='row')
    plot_lr(runs, enc='vgg11', normlayer='Identity', ax=[a[0] for a in ax])
    plot_lr(runs, enc='resnet18', normlayer='BatchNorm', ax=[a[1] for a in ax])
    # manual y-limits for the first and third row (rows share their y-axis)
    ax[0][0].set_ylim(None, 1e-3)
    ax[2][0].set_ylim(-10, 200)

    # remove the y-labels of the right column
    [a[1].set_ylabel('') for a in ax]
    plt.savefig('sgd-lr-ce.pdf')
    plt.show()

In [None]:
# Reduced version of the previous figure with two metrics only.
runs = api.sweep('safelix/DINO/runs/eg0tf2oo').runs # SGD LR Gridsearch (CE)

# Overwrites the global `metrics` that plot_lr iterates over.
metrics = {'kl-divergence':'train/KL', 
            #'gradient norm':'params/norm(grad)', 
            #'distance from init':'params/norm(stud - init)',
            'linear probing':'probe/student'
            }

nrows, ncols = len(metrics), 2
with plt.rc_context(figsizes.icml2022_full(nrows=nrows, ncols=ncols, height_to_width_ratio=figsizes._GOLDEN_RATIO)):
    f, ax = plt.subplots(nrows=nrows, ncols=ncols, sharex=True, sharey='row')
    plot_lr(runs, enc='vgg11', normlayer='Identity', ax=[a[0] for a in ax])
    plot_lr(runs, enc='resnet18', normlayer='BatchNorm', ax=[a[1] for a in ax])
    ax[0][0].set_ylim(None, 1e-3)
    #ax[2][0].set_ylim(-10, 200)

    # remove the y-labels of the right column
    [a[1].set_ylabel('') for a in ax]
    plt.savefig('sgd-lr-ce.small.pdf')
    plt.show()

In [None]:
# Same two-column layout as the CE figure, for the MSE grid search.
runs = api.sweep('safelix/DINO/runs/90k67dq4').runs # SGD LR Gridsearch (MSE)

# Overwrites the globals that plot_lr reads; the first panel shows train/MSE here.
xmetric = 'trainer/global_step'
metrics = {'mean squared error':'train/MSE', 
            'gradient norm':'params/norm(grad)', 
            'distance from init':'params/norm(stud - init)',
            'linear probing':'probe/student'
            }

nrows, ncols = len(metrics), 2
with plt.rc_context(figsizes.icml2022_full(nrows=nrows, ncols=ncols, height_to_width_ratio=figsizes._GOLDEN_RATIO)):
    f, ax = plt.subplots(nrows=nrows, ncols=ncols, sharex=True, sharey='row')
    plot_lr(runs, enc='vgg11', normlayer='Identity', ax=[a[0] for a in ax])
    plot_lr(runs, enc='resnet18', normlayer='BatchNorm', ax=[a[1] for a in ax])
    # manual y-limits for the first and third row (rows share their y-axis)
    ax[0][0].set_ylim(None, 1e0)
    ax[2][0].set_ylim(-10, 200)

    # remove the y-labels of the right column
    [a[1].set_ylabel('') for a in ax]
    plt.savefig('sgd-lr-mse.pdf')
    plt.show()

### Investigate large LR with Simplified Head 

#### LR Range Test

In [None]:
nrows, ncols = 1, 1
@plt.rc_context(figsizes.icml2022_half(nrows=nrows, ncols=ncols, height_to_width_ratio=figsizes._GOLDEN_RATIO))
def plot_rangetest(runs, labels=None, colors=None, filename=None, ax=None):
    """Plot train loss against step size for one or several LR-range-test runs.

    The loss key of a run is 'train/<loss>' with <loss> taken from
    ``run.config['loss']`` ('CE' is mapped to 'KL').

    Args:
        runs: a single W&B run or a sequence of runs.
        labels: one label for all runs (str) or a list with one label per run;
            runs without a label fall back to their loss key.
        colors: one colour for all runs (tuple/str/float) or one colour per run.
        filename: if given, the figure is saved there and shown.
        ax: axes to draw on; a new figure is created if omitted.

    Returns:
        DataFrame indexed by 'hparams/lr' with one loss column per run,
        named after the run.

    NOTE: a function with the same name but a different signature is defined
    further down in this notebook and replaces this one once that cell has run.
    """
    if ax is None:
        _, ax = plt.subplots()

    # normalise scalar arguments to per-run lists
    if isinstance(runs, wandb.apis.public.Run):
        runs = [runs]
    if isinstance(labels, str):
        labels = [labels] * len(runs)
    if isinstance(colors, (tuple, str, float)):
        colors = [colors] * len(runs)

    agg = pd.DataFrame()
    for run_idx, run in enumerate(runs):
        loss_name = run.config['loss']
        if loss_name == 'CE': # replace CE -> KL
            loss_name = 'KL'
        metric = f'train/{loss_name}'

        rows = run.scan_history(keys=[metric, 'hparams/lr'])
        curve = pd.DataFrame(rows, dtype=float).set_index('hparams/lr')

        if labels and run_idx < len(labels):
            label = labels[run_idx]
        else:
            label = metric
        if colors and run_idx < len(colors):
            color = colors[run_idx]
        else:
            color = None
        ax.plot(curve.index, curve[metric], label=label, color=color)

        # collect the curve as one column named after the run
        curve = curve.rename(columns={metric:run.name})
        agg = pd.concat([agg, curve], axis=1).sort_index()

    ax.set_xlabel('step size')
    ax.set_xscale('log')

    ax.set_ylabel('train loss')
    ax.set_yscale('log')
    ax.legend(loc='upper left')

    if filename is not None:
        plt.savefig(filename)
        plt.show()

    return agg

# LR range test of the KL and MSE runs in one half-column panel,
# written to 'lrrangetest-kl-vs-mse.pdf'.
nrows, ncols = 1, 1
with plt.rc_context(figsizes.icml2022_half(nrows=nrows, ncols=ncols, height_to_width_ratio=figsizes._GOLDEN_RATIO)):
    runs = api.sweep('safelix/DINO/sweeps/yjeg6huc').runs # KL vs MSE

    # `agg` holds the loss curves of all runs, indexed by step size
    agg = plot_rangetest(runs, filename='lrrangetest-kl-vs-mse.pdf')
    plt.show()

#### LR Gridsearch

In [None]:
from matplotlib import cm

runs = api.sweep('safelix/DINO/runs/xduatpbk').runs # SGD LR Gridsearch: Linear Head, MSE
runs_fn = api.sweep('safelix/DINO/runs/4cu7u7mo').runs # SGD LR Gridsearch: Linear Head, FN+MSE

# Overwrites the globals that plot_lr reads.
xmetric = 'trainer/global_step'
metrics = {'mean squared error':'train/MSE', 
            #'gradient norm':'params/norm(grad)', 
            #'distance from init':'params/norm(stud - init)',
            'linear probing':'probe/student'
            }

# Left column: MSE sweep (viridis), right column: FN+MSE sweep (inferno).
# Both are plotted over epochs without a row limit (max_step=None).
# NOTE(review): normlayer is left at its default None here, so plot_lr keeps only
# runs whose config.enc_norm_layer equals None - confirm that matches these sweeps.
nrows, ncols = len(metrics), 2
with plt.rc_context(figsizes.icml2022_full(nrows=nrows, ncols=ncols, height_to_width_ratio=figsizes._GOLDEN_RATIO)):
    f, ax = plt.subplots(nrows=nrows, ncols=ncols, sharex=True, sharey='row')
    # the summaries are kept: the next cell extracts the per-learning-rate colours from them
    summary = plot_lr(runs, enc='vgg13', ax=[a[0] for a in ax], kwargs=dict(max_step=None, xmetric='epoch', colors=cm.viridis_r))
    summary_fn = plot_lr(runs_fn, enc='vgg13', ax=[a[1] for a in ax], kwargs=dict(max_step=None, xmetric='epoch', colors=cm.inferno_r))
    ax[0][0].set_ylim(None, 1e0)
    #ax[2][0].set_ylim(-10, 200)

    #[a[1].set_ylabel('') for a in ax]
    plt.savefig('sgd-lr-linhead-mse.pdf')
    plt.show()

In [None]:
# plot_lr returns {metric: summary DataFrame}; keep the frame of the first metric.
# The isinstance guard makes the cell idempotent: once `summary` is a DataFrame,
# calling `summary.values()` again would raise a TypeError on re-run.
if isinstance(summary, dict):
    summary = next(iter(summary.values()))
if isinstance(summary_fn, dict):
    summary_fn = next(iter(summary_fn.values()))

# learning rate -> colour used for it in the grid-search figure (string-valued
# group keys are skipped)
colors = {idx: summary['colors'][idx] for idx in summary.index if not isinstance(idx, str)}
colors_fn = {idx: summary_fn['colors'][idx] for idx in summary_fn.index if not isinstance(idx, str)}
colors

In [None]:
from matplotlib import cm

# LR range test of the MSE and NMSE runs in one panel; the learning rates of the
# grid search (`colors` / `colors_fn` from the previous cell) are marked on the curves.
nrows, ncols = 1, 1
with plt.rc_context(figsizes.icml2022_half(nrows=nrows, ncols=ncols, height_to_width_ratio=figsizes._GOLDEN_RATIO)):
    runs = api.sweep('safelix/DINO/sweeps/lsy74rit').runs # FN vs no FN
    run = api.run('safelix/DINO/runs/iprry987')
    run_fn = api.run('safelix/DINO/runs/9wxlksbt')


    f, ax = plt.subplots(nrows, ncols, sharex=True, sharey=True)
    # with a single panel, use the same axes for both curves
    if isinstance(ax, matplotlib.axes.Axes):
        ax = (ax, ax)
    # NOTE: this relies on the first plot_rangetest definition (runs, labels, colors, ...);
    # it fails once the later cell that redefines plot_rangetest has been executed.
    agg = plot_rangetest(run, labels='MSE', colors=tuple(cm.viridis.colors[160]), ax=ax[0])   
    agg_fn = plot_rangetest(run_fn, labels='NMSE', colors=tuple(cm.inferno.colors[160]), ax=ax[1])
    
    ymin, ymax = ax[0].get_ylim()
    # grid-search learning rates and the loss of the range test interpolated at them
    xs = np.array(list(colors.keys()))
    ys = np.interp(xs, agg.index, agg.values.squeeze())

    # label the first learning rate with 2/L and mark all of them with a cross
    ax[0].annotate('$\\frac{2}{L}$', (xs[0],ys[0]), xytext=(-1, 6), textcoords='offset pixels', fontsize=7)
    #ax[0].scatter(xs[0], ys[0], marker='|', s=10, c=list(colors.values())[0], zorder=10)
    ax[0].scatter(xs[0:], ys[0:], marker='x', s=10, c=list(colors.values())[0:], zorder=10)
    #for x, y, c in zip(xs, ys, colors.values()):
        #ax[0].vlines(x, ymin, ymax, color=c)
        #ax[0].annotate(x, (x,y), xytext=(0, 20), textcoords='offset pixels', fontsize=5)

    # same for the NMSE curve
    xs = np.array(list(colors_fn.keys()))
    ys = np.interp(xs, agg_fn.index, agg_fn.values.squeeze())

    ax[1].annotate('$\\frac{2}{L}$', (xs[0],ys[0]), xytext=(-1, -8), textcoords='offset pixels', fontsize=7)
    #ax[1].scatter(xs[0], ys[0], marker='|', s=10, c=list(colors_fn.values())[0], zorder=10)
    ax[1].scatter(xs[0:], ys[0:], marker='x', s=10, c=list(colors_fn.values())[0:], zorder=10)
    #for lr, color in colors_fn.items():
    #    ax[1].vlines(lr, ymin, ymax, color=color)

    ax[0].set_xlim(10**-3, 10**6)
    ax[0].set_ylim(None, 10**10)

    plt.savefig('lrrangetest-mse-vs-nmse.pdf')
    plt.show()

### Large Experiment: MSE vs ND-MSE vs N-MSE for different LRs and schedules

#### Colors

In [None]:
from matplotlib import cm
from matplotlib.colors import LinearSegmentedColormap
# Print the RGBA value of cm.tab10(3) scaled by 255, then display the colormap itself.
print(255*np.array(cm.tab10(3)))
cm.tab10

In [None]:
import numpy as np
from colorsys import hls_to_rgb
from matplotlib.colors import LinearSegmentedColormap, ListedColormap

def make_colorscale(h=223, dh=0, l=0.33, dl=0., s=1., ds=0, beta=1, num=256):
    """Build a ``ListedColormap`` with ``num`` colours defined in HLS space.

    Args:
        h, dh: centre hue in degrees and half-width of the hue sweep; the hue
            runs linearly from ``h-dh`` to ``h+dh``.
        l, dl: centre lightness and half-width; the lightness runs from
            ``l+dl`` down to ``l-dl``, i.e. the scale starts bright.
        s, ds: centre saturation and half-width; the saturation runs from
            ``s-ds`` up to ``s+ds``.
        beta: exponent applied to the [0, 1] ramp of lightness and saturation
            (1 gives a linear ramp).
        num: number of colours.
    """
    ramp = np.linspace(0., 1., num=num) ** beta
    hues = np.linspace(h-dh, h+dh, num=num) / 360 # degrees -> [0, 1]
    lightness = (l-dl) + (dl+dl) * ramp
    saturation = (s-ds) + (ds+ds) * ramp

    # flip luminosity to start with bright
    rgb = np.array([hls_to_rgb(hue, light, sat)
                    for hue, light, sat in zip(hues, lightness[::-1], saturation)])
    return ListedColormap(colors=rgb)

# Preview: displays a scale with a hue sweep (dh=60), a lightness range (dl=0.20) and a non-linear ramp (beta=5).
make_colorscale(h=223, l=0.33, dh=60, dl=0.20, beta=5.)

In [None]:
# Base colour (hue 223 degrees) and the matching lightness scale; the scale is displayed as cell output.
blue = hls_to_rgb(h=223/360,l=0.33, s=1)
bluescale = make_colorscale(h=223, l=0.33, dl=0.2, beta=1.5)
bluescale

In [None]:
# Same as the blue pair with the hue rotated by +120 degrees.
red = hls_to_rgb(h=(223 + 120)/360, l=0.33, s=1)
redscale = make_colorscale(h=223 + 120, l=0.33, dl=0.2, beta=1.5)
redscale

In [None]:
# Same as the blue pair with the hue rotated by -120 degrees.
# NOTE(review): unlike blue/red, the base colour (l=0.25) and the scale (l=0.22, dl=0.15)
# use different lightness values - confirm this is intended.
green = hls_to_rgb(h=(223 - 120)/360, l=0.25, s=1)
greenscale = make_colorscale(h=223 - 120, l=0.22, dl=0.15, beta=1.5)
greenscale

#### LR Range Test

In [None]:
def plot_rangetest(labels, colors, group_keys, colored_dots={}, ax=None):
    """Plot the LR range test of sweep 'DINO/sweeps/hyfq2un6', one median curve per ``l2bot_cfg`` group.

    NOTE: this replaces the function of the same name defined earlier in the
    notebook, which has a different signature.

    Args:
        labels, colors, group_keys: forwarded to plot_agg.
        colored_dots: ``{group_key: {step size: colour}}``; for each step size a
            cross is drawn on the centre curve of that group.
        ax: axes to draw on; a new figure is created if omitted.

    Returns:
        The axes that were drawn on.
    """
    sweep_runs = api.sweep('DINO/sweeps/hyfq2un6').runs

    if ax is None:
        _, ax = plt.subplots()

    # median curve with a min/max band per group, plotted over the step size
    summary = plot_agg(sweep_runs, group_by='config.l2bot_cfg', metric='train/loss', xmetric='hparams/lr',
                       labels=labels, colors=colors, group_keys=group_keys, center='median', spread='minmax',
                       center_kwargs=dict(alpha=0.6), spread_kwargs=dict(alpha=0.15), scan_step=1, ax=ax)

    # mark the requested step sizes on the centre curve of their group
    for group_key, dot_colors in colored_dots.items():
        center_line = summary['centers'][group_key]
        dot_xs = np.array(list(dot_colors.keys()))
        dot_ys = np.interp(dot_xs, center_line.index, center_line.values.squeeze())
        ax.scatter(dot_xs, dot_ys, marker='x', s=10, c=list(dot_colors.values()), zorder=10)

    # x-axis
    ax.set_xlabel('step size')
    ax.set_xscale('log')
    ax.set_xlim(1e-3, 1e6)

    # y-axis
    ax.set_ylabel('train loss')
    ax.set_yscale('log')
    ax.set_ylim(1e-16, 1e6)

    ax.legend(loc='upper left')
    return ax

# One curve per loss variant; labels, colors and group_keys are matched by position.
labels = ['MSE', 'ND-MSE', 'N-MSE']
colors = [blue, red, green]
group_keys = ['-/-/-/-/lb/-', '-/-/-/-/lb/fnd', '-/-/-/-/lb/fn']
ax = plot_rangetest(labels=labels, colors=colors, group_keys=group_keys)
plt.savefig('lrrangetest.pdf')
plt.show()

#### LR Gridsearch

In [None]:
from matplotlib import cm
from dinopl.scheduling import Schedule, ConstSched, LinWarmup, CosWarmup

def make_colors(cmap, xs):
    """Map every value in ``xs`` to an RGBA colour of ``cmap``.

    The sorted values are used as the boundaries of a ``BoundaryNorm``
    (``extend='max'``), so the colours are assigned by rank and not by magnitude.
    """
    boundaries = sorted(xs)
    mappable = cm.ScalarMappable(norm=BoundaryNorm(boundaries, cmap.N, extend='max'), cmap=cmap)

    palette = {}
    for x in xs:
        palette[x] = mappable.to_rgba(x)
    return palette

def run2sortkey(run):
    """Sort key for a sweep run: (position of its loss config in `cfgs`,
    position of its schedule type in `scheds`, logged step size).

    Relies on the notebook globals `cfgs` and `scheds`; a config or schedule type
    that is not listed there raises ValueError from `list.index`.
    """
    cfg_rank = cfgs.index(run.config['l2bot_cfg'])
    sched_type = type(Schedule.parse(run.config['opt_lr']))
    sched_rank = scheds.index(sched_type)
    return cfg_rank, sched_rank, run.summary['hparams/lr']


# load the sweep and prepare hyperparameters and lookup tables for the grid-search plots
sweep = api.sweep('safelix/DINO/sweeps/deoormth')
# keep only the plain numeric step sizes of the sweep's 'opt_lr' values
lrs = [lr for lr in sweep.config['parameters']['opt_lr']['values'] if isinstance(lr, (int, float))]

# schedule type -> line style
scheds = [ConstSched, LinWarmup, CosWarmup]
sched2ls = dict(zip(scheds, ['-','--', '-.']))

# loss config -> display name / base color / colorscale (all matched by position)
losses = ['MSE', 'ND-MSE', 'N-MSE']
cfgs = ['-/-/-/-/lb/-', '-/-/-/-/lb/fnd', '-/-/-/-/lb/fn']
cfg2loss = dict(zip(cfgs, losses))
cfg2color = dict(zip(cfgs, [blue, red, green]))
cfg2cmap = dict(zip(cfgs, [bluescale, redscale, greenscale]))
# loss config -> {step size -> color}
cfglr2color = {cfg: make_colors(cmap, lrs) for cfg, cmap in cfg2cmap.items()}

# runs ordered by (loss config, schedule type, step size)
runs = sorted(sweep.runs, key=run2sortkey)

In [None]:
# Same range test again, now with an 'x' marker on each curve at every grid-searched
# step size, colored by the per-config step-size colors in cfglr2color.
group_keys = cfgs
labels = [cfg2loss[cfg] for cfg in cfgs]
colors = [cfg2color[cfg] for cfg in cfgs]
ax = plot_rangetest(group_keys=group_keys, labels=labels, colors=colors, colored_dots=cfglr2color)
plt.savefig('lrrangetest-dots.pdf')
plt.show()

In [None]:
from IPython.display import clear_output
from dinopl.scheduling import Schedule, ConstSched, LinWarmup, CosWarmup

# Step size over epochs for the warmup schedules, one gray curve per run.
f, ax = plt.subplots()
for run in sorted(sweep.runs, key=run2sortkey):
    sched = type(Schedule.parse(run.config['opt_lr']))

    # only runs of the '-/-/-/-/lb/fn' (N-MSE) config with a warmup schedule are drawn
    if run.config['l2bot_cfg'] != '-/-/-/-/lb/fn' or sched not in (LinWarmup, CosWarmup):
        continue

    ls = sched2ls[sched]
    plot_agg([run], group_by='summary.hparams/lr', metric='hparams/lr', xmetric='epoch',
             colors=['gray'], labels=[None], center='mean', spread=None,
             center_kwargs={'alpha': 0.6, 'ls': ls}, scan_step=100, ax=ax)

# legend proxies: a single point at the origin per schedule, carrying only label and line style
for sched in (LinWarmup, CosWarmup):
    ax.plot(0, 0, color='gray', ls=sched2ls[sched], label=sched.__name__)
ax.legend(frameon=True, loc='lower right')

ax.set_xlabel('epoch')
ax.set_ylabel('stepsize')
# y-range is scaled to the third-last step size of the grid, with 5% margin on both sides
ax.set_ylim(-0.05 * lrs[-3], 1.05 * lrs[-3])

plt.savefig('lrs.pdf')

clear_output()
plt.show()

In [None]:
def plot_gridsearch_epochs(filter_lrs=lrs, filter_losses=['MSE', 'N-MSE', 'ND-MSE'], 
                    filter_sched=(ConstSched, LinWarmup, CosWarmup), pseudolabels=[], ax=None):
    """Plot probing accuracy over epochs for every grid-search run that passes the filters.

    Each run gets its own curve, colored by (loss config, step size) and styled by
    schedule type. Note that the defaults are evaluated when this cell is run, so
    `filter_lrs` defaults to the value `lrs` had at definition time.

    Args:
        filter_lrs: step sizes to include.
        filter_losses: loss names (values of `cfg2loss`) to include.
        filter_sched: schedule types to include.
        pseudolabels: list of kwargs dicts for `ax.plot(0, 0, **kwargs)`, used as
            legend proxies. If empty, every curve is labelled with its run name.
        ax: axes to draw into; a new figure is created if `None`.

    Returns:
        (summary, ax), where `summary[cfg][sched][lr]` holds the value that
        `plot_agg` reports under 'means' for that run (NaN replaced by 0.1).
    """
    if ax is None:
        _, ax = plt.subplots()

    summary = {}
    for run in sorted(sweep.runs, key=run2sortkey):
        lr = run.summary['hparams/lr']
        sched = type(Schedule.parse(run.config['opt_lr']))
        cfg = run.config['l2bot_cfg']
        loss = cfg2loss[cfg]

        # skip the run as soon as any filter rejects it
        if lr not in filter_lrs or loss not in filter_losses or sched not in filter_sched:
            continue

        # curves stay unlabelled when legend proxies are supplied
        label = None if len(pseudolabels) > 0 else run.name
        out = plot_agg([run], group_by='summary.hparams/lr', metric='probe/student', xmetric='epoch',
                       colors=[cfglr2color[cfg][lr]], labels=[label], center='mean', spread=None,
                       center_kwargs={'alpha': 0.6, 'ls': sched2ls[sched]},
                       slice2summarize=slice(10, None), ax=ax)

        lr2acc = summary.setdefault(cfg, {}).setdefault(sched, {})
        if lr in out['means'].index:
            # nan => diverged => 0.1 probing accuracy
            lr2acc[lr] = np.nan_to_num(out['means'][lr], nan=0.1)

    for proxy_kwargs in pseudolabels:
        ax.plot(0, 0, **proxy_kwargs)

    ax.set_xlabel('epoch')
    ax.set_ylabel('probing accuracy')
    ax.set_ylim(0.05, 0.55)
    return summary, ax

# Draw all grid-search runs with the default filters; the returned (summary, ax) is not kept here.
plot_gridsearch_epochs()
#plt.legend(loc='upper left', frameon=True)
#plt.savefig(f'lrgridsearch.pdf')

clear_output() # clear the output produced so far in this cell before showing the figure
plt.show()

In [None]:
from IPython.display import clear_output

def plot_gridsearch_summary(summary, ax=None):
    """Plot the summarized probing accuracy against step size, one line per (config, schedule).

    Args:
        summary: nested dict `summary[cfg][sched][lr] -> accuracy`, as returned by
            `plot_gridsearch_epochs`.
        ax: axes to draw into; a new figure is created if `None`.

    Returns:
        The axes that were drawn into.

    Uses the notebook globals `cfg2color`, `sched2ls`, `cfglr2color` and `lrs`.
    """
    if ax is None:
        f, ax = plt.subplots()

    for cfg, per_sched in summary.items():
        for sched, lr2acc in per_sched.items():
            if not lr2acc: # nothing recorded for this (config, schedule)
                continue
            # Sort by step size: the dict may have been filled in any order (e.g. when
            # merged across several calls), but the line must run left to right.
            xs = sorted(lr2acc)
            ys = [lr2acc[x] for x in xs]
            ax.plot(xs, ys, color=cfg2color[cfg], ls=sched2ls[sched], alpha=0.6)
            ax.scatter(xs, ys, color=[cfglr2color[cfg][x] for x in xs], alpha=0.6, s=8)

    ax.set_xlabel('step size')
    ax.set_xscale('log')
    # Pad the x-range by 5% of its log-extent on both sides; min/max so that the
    # limits do not depend on `lrs` being stored in ascending order.
    lr_min, lr_max = min(lrs), max(lrs)
    xoffset = (lr_max/lr_min)**0.05
    ax.set_xlim(lr_min/xoffset, lr_max*xoffset)

    ax.set_ylabel('probing accuracy')
    ax.set_ylim(0.05, 0.55)
    return ax


# Only the summary is needed here: draw the per-epoch figure, then close it again
# and clear the cell output so that just the summary plot below is shown.
summary, ax = plot_gridsearch_epochs()
plt.close()
clear_output()

plot_gridsearch_summary(summary)
plt.show()

In [None]:
from IPython.display import clear_output

def plot_gridsearch(filter_lrs=lrs, filter_losses=losses, filter_sched=scheds, pseudolabels=[], summary=None):
    """Plot the grid search: accuracy over epochs, optionally next to the per-step-size summary.

    Args:
        filter_lrs, filter_losses, filter_sched, pseudolabels: forwarded to
            `plot_gridsearch_epochs`.
        summary: controls the second panel.
            - `None`/falsy (non-dict): single panel, no summary plot.
            - `True`: two panels; the summary of this call is plotted and returned.
            - dict (may be empty): two panels; the summary of this call is merged
              into the dict in place, and the merged dict is plotted and returned.
              This allows accumulating results over several calls.

    Returns:
        (summary, ax): the summary as described above (the input value if no summary
        was requested) and the list/array of axes, with the epoch plot in `ax[0]`.
    """
    # A dict requests the summary panel even when it is still empty. Testing plain
    # truthiness here would create a single axis and then fail on ax[1] below.
    with_summary = isinstance(summary, dict) or bool(summary)

    nrows, ncols = (1, 2) if with_summary else (1, 1)
    figsize = figsizes.icml2022_full if with_summary else figsizes.icml2022_half
    with plt.rc_context(figsize(nrows=nrows, ncols=ncols, height_to_width_ratio=figsizes._GOLDEN_RATIO)):
        f, ax = plt.subplots(nrows, ncols, sharex=False, sharey=True)
        ax = [ax] if ncols==1 else ax # always index-able, also for a single panel

        # plot gridsearch epochs into ax[0]
        new_summary, _ = plot_gridsearch_epochs(filter_lrs, filter_losses, filter_sched, pseudolabels, ax=ax[0])

        if isinstance(summary, dict):   # merge if summary was dict
            for cfg, per_sched in new_summary.items():
                for sched, lr2acc in per_sched.items():
                    if lr2acc: # only create entries that actually carry values
                        summary.setdefault(cfg, {}).setdefault(sched, {}).update(lr2acc)
        elif summary == True:           # use new summary if was true
            summary = new_summary

        if isinstance(summary, dict): # plot gridsearch summary if is dict 
            plot_gridsearch_summary(summary, ax=ax[1])

    return summary, ax

# Legend proxies: one colored entry per loss and one gray entry per schedule type (by line style).
pseudolabels = [dict(label=cfg2loss[cfg], color=cfg2color[cfg]) for cfg in cfgs]
pseudolabels += [dict(label=sched.__name__, color='gray', ls=ls) for sched, ls in sched2ls.items()]

# summary=True => two panels: epochs in ax[0], summary of this call in ax[1]
summary, ax = plot_gridsearch(summary=True, pseudolabels=pseudolabels)
ax[0].legend(loc='lower right', frameon=True, ncols=2)
plt.savefig(f'lrgridsearch.pdf')
clear_output()
plt.show()

In [None]:
from IPython.display import clear_output
# Legend proxies: one colored entry per loss (only the constant schedule is plotted below).
pseudolabels = [dict(label=cfg2loss[cfg], color=cfg2color[cfg]) for cfg in cfgs]

# One PDF per step size, constant schedule only: frame idx contains all step sizes up to lrs[idx].
# `summary` starts as True and is replaced by the dict returned from the first call,
# which is then passed back in (and merged into) on every later iteration.
summary = True
for idx, lr in enumerate(lrs):
    plt.close() # close the figure of the previous iteration
    summary, ax = plot_gridsearch(filter_lrs=lrs[:idx+1], filter_sched=[ConstSched], pseudolabels=pseudolabels, summary=summary)
    ax[0].text(0.05, 0.95, f'lr={lr:.1f}', ha='left', va='top', transform=ax[0].transAxes)
    ax[0].legend(loc='lower right', frameon=True, ncols=1)
    # NOTE(review): int(lr) truncates, so step sizes sharing an integer part (e.g. 0.1 and 0.5)
    # would be written to the same file — confirm that all lrs are distinct after truncation.
    plt.savefig(f'lrgridsearch-constsched-lr={int(lr):06d}.pdf')
    clear_output()
plt.show() # only the figure of the last iteration is still open here

#!pdfunite lrgridsearch-constsched-lr=*.pdf lrgridsearch-constsched-lr.pdf

In [None]:
from IPython.display import clear_output
# Legend proxies: one colored entry per loss and one gray entry per schedule type (by line style).
pseudolabels = [dict(label=cfg2loss[cfg], color=cfg2color[cfg]) for cfg in cfgs]
pseudolabels += [dict(label=sched.__name__, color='gray', ls=ls) for sched, ls in sched2ls.items()]

# One PDF per step size with all schedules: the epoch panel shows only the current step size,
# while `summary` (True at first, then the returned dict) accumulates over the iterations.
summary = True
for idx, lr in enumerate(lrs):
    plt.close() # close the figure of the previous iteration
    summary, ax = plot_gridsearch(filter_lrs=[lr], pseudolabels=pseudolabels, summary=summary)
    ax[0].text(0.05, 0.95, f'lr={lr:.1f}', ha='left', va='top', transform=ax[0].transAxes)
    ax[0].legend(loc='lower right', frameon=True, ncols=2)
    # NOTE(review): int(lr) truncates, so step sizes sharing an integer part (e.g. 0.1 and 0.5)
    # would be written to the same file — confirm that all lrs are distinct after truncation.
    plt.savefig(f'lrgridsearch-lr={int(lr):06d}.pdf')
    clear_output()
plt.show() # only the figure of the last iteration is still open here

#!pdfunite lrgridsearch-lr=*.pdf lrgridsearch-lr.pdf

In [None]:
from IPython.display import clear_output
# Legend proxies: one colored entry per loss; the schedule entry is added per figure below.
pseudolabels = [dict(label=cfg2loss[cfg], color=cfg2color[cfg]) for cfg in cfgs]

# One PDF per schedule type, each with a fresh summary (summary=True) restricted to that schedule.
for idx, (sched, ls) in enumerate(sched2ls.items()):
    plt.close() # close the figure of the previous iteration
    _, ax = plot_gridsearch(filter_sched=[sched], summary=True,
                            pseudolabels=pseudolabels+[dict(label=sched.__name__, color='gray', ls=ls)])
    ax[0].legend(loc='lower right', frameon=True, ncols=1)
    plt.savefig(f'lrgridsearch-{sched.__name__.lower()}.pdf') # e.g. 'lrgridsearch-constsched.pdf'
    clear_output()
plt.show() # only the figure of the last iteration is still open here
