In [None]:
DARK = False
import matplotlib.pyplot as plt
if DARK:
    plt.style.use('dark_background')
import torch
from main import Trainer
from itertools import product
from argparse import Namespace as ns
from vis import results, print_results

## Explore the results
# Query the results by passing, for each argument, one of:
# - a single value to match,
# - a list of accepted values, or
# - a function that returns True for the values to keep.
# out = results(model='lenet', dataset='cifar10', sigma=lambda x: x <= 0.5, aug=[0, 1], exp=0)

# print_results(out)  # print the matching results, one per line
# print_results(out, plot=True)  # or plot them

In [None]:
# Accuracy tolerance per (model, dataset), measured from the clean baseline accuracy.
# The fair comparison below keeps only runs whose test accuracy exceeds
# (baseline accuracy - tolerance). Raise a value if some training sigma is left
# with no qualifying run.
tol = dict([
    (('lenet', 'mnist'), 0),
    (('lenet', 'cifar10'), 0.0039),
    (('lenet', 'cifar100'), 0.0075),
    (('alexnet', 'cifar10'), 0.0168),
    (('alexnet', 'cifar100'), 0.0483),
    (('vgg16', 'cifar10'), 0.0236),
    (('vgg16', 'cifar100'), 0),
])
def get(model, dataset, **kwargs):
    """Return the stored runs for ``(model, dataset)`` as flat Namespace records.

    Extra keyword arguments are forwarded unchanged to ``results`` as filters
    (e.g. ``aug=0`` or ``exp=lambda x: x > 0``). Each record carries ``sig``,
    ``aug`` and ``exp`` copied from the run, plus ``epochs``
    (``summary.last_epoch``), ``time`` (sum of ``summary.time``), ``acc``
    (``summary.test.accuracy``) and ``rob`` (``summary.test.robustness``).
    """
    def flatten(run):
        summary = run['summary']
        scores = summary.test
        return ns(sig=run['sig'], aug=run['aug'], exp=run['exp'],
                  epochs=summary.last_epoch,
                  time=sum(summary.time),
                  acc=scores.accuracy,
                  rob=scores.robustness)

    matching = results(model=model, dataset=dataset, **kwargs)
    return [flatten(run) for run in matching]

## Performance comparison: training with augmentation vs. training with the expectation

In [None]:
def _plot_results_table(table, baseline, augs, sigs, robustness, fontsize, figsize):
    """Draw one `results_table` metric: dashed clean baseline plus one line per row of `table`."""
    finite = table[~torch.isnan(table)]
    lo, hi = baseline, baseline
    if finite.numel():  # an all-NaN table would make .min()/.max() raise
        lo = min(finite.min().item(), baseline)
        hi = max(finite.max().item(), baseline)
    ys = torch.linspace(lo, hi, 8)
    # raw strings: '\R' / '\s' are invalid escape sequences (SyntaxWarning on Python 3.12+)
    title = r'Robustness $\Re$' if robustness else 'Testing Accuracy'
    plt.figure(figsize=figsize)
    plt.axhline(y=baseline, linewidth=2, color='k', linestyle='--')
    plt.plot(table.t().numpy(), '8-')
    # legend order follows artist order: baseline line, then row 0 ("Ours"), then augs[1:]
    plt.legend(['Baseline', 'Ours'] + [f'aug = {a}' for a in augs[1:]], fontsize=fontsize * .9)
    plt.xlabel(r'Training $\sigma$', fontsize=fontsize)
    plt.ylabel(title, fontsize=fontsize)
    plt.yticks([round(float(y), 4) for y in ys], fontsize=fontsize)
    plt.xticks(range(len(sigs)), [str(s) for s in sigs], fontsize=fontsize)
    plt.grid()
    plt.show()


def results_table(model, dataset, plot=False, fontsize=15, figsize=(8, 5)):
    """Tabulate accuracy and robustness of augmentation runs vs. expectation runs.

    Each table has one row per augmentation count in ``augs`` and one column per
    training sigma in ``sigs``. Row 0 is ``aug == 0``: the expectation-trained
    runs, labelled "Ours". In that row, each sigma cell holds the expectation run
    with the highest test accuracy at that sigma. Cells with no matching run are NaN.

    Args:
        model, dataset: forwarded to ``get`` to select runs.
        plot: if True, draw one figure per metric with the clean baseline dashed.
        fontsize, figsize: matplotlib styling for those figures.

    Returns:
        ``((baseline_acc, acc_table), (baseline_rob, rob_table))``. The baselines
        come from the first clean run (``aug == 0`` and ``exp == 0``). Each table
        is a float tensor of shape ``(len(augs), len(sigs))``.

    Raises:
        IndexError: if there is no clean run for this (model, dataset).
    """
    clean = get(model, dataset, aug=0, exp=0)[0]
    aug_runs = get(model, dataset, aug=lambda x: x > 0, exp=0)
    exp_runs = get(model, dataset, aug=0, exp=lambda x: x > 0)
    # Always keep 0 as a row so row 0 really is the "Ours" row (augmentation
    # runs only contribute positive counts), even when exp_runs is empty.
    augs = sorted({0} | {r.aug for r in aug_runs} | {r.aug for r in exp_runs})
    sigs = sorted({r.sig for r in aug_runs} | {r.sig for r in exp_runs})

    def extract(robustness):
        metric = (lambda r: r.rob) if robustness else (lambda r: r.acc)
        out = torch.full((len(augs), len(sigs)), float('nan'))
        for r in aug_runs:
            out[augs.index(r.aug), sigs.index(r.sig)] = metric(r)
        for col, s in enumerate(sigs):
            sub_exp = [r for r in exp_runs if r.sig == s]
            if not sub_exp:
                # No expectation run at this sigma: leave NaN instead of
                # crashing on max() of an empty sequence.
                continue
            # Most accurate expectation run at this sigma (first one on ties).
            out[0, col] = metric(max(sub_exp, key=lambda r: r.acc))
        baseline = metric(clean)
        if plot:
            _plot_results_table(out, baseline, augs, sigs, robustness, fontsize, figsize)
        return baseline, out

    acc = extract(False)
    rob = extract(True)
    return acc, rob
        
# Accuracy and robustness tables, plotted for every configured (model, dataset) pair.
for model, dataset in list(tol):
    print(model, dataset)
    acc, rob = results_table(model=model, dataset=dataset, plot=True)

## Fair comparisons: only runs whose accuracy is within `tol` of the clean baseline

In [None]:
def comparable_results(model, dataset, plot=False, fontsize=15, figsize=(8, 5), tol=None):
    """Robustness of accuracy-matched runs: expectation ("Ours") vs. best augmentation.

    Only runs whose test accuracy exceeds ``clean.acc - tol[model, dataset]`` are
    kept. Robustness is therefore compared only among runs that give up at most
    that much accuracy relative to the clean run.

    Args:
        model, dataset: forwarded to ``get`` to select runs.
        plot: if True, plot "Ours" per sigma against the baseline and best-aug lines.
        fontsize, figsize: matplotlib styling for the figure.
        tol: mapping ``(model, dataset) -> accuracy tolerance``. Defaults to the
            notebook-level ``tol`` dict, looked up at call time. A ``tol=tol.copy()``
            default would freeze whatever ``tol`` held when this cell was run.

    Returns:
        ``(baseline, best, out)``:
        - ``baseline``: robustness of the clean run.
        - ``best``: the highest robustness among the kept augmentation runs,
          or NaN if none was kept.
        - ``out``: a float tensor with one entry per training sigma. Each entry is
          the robustness of the most accurate kept expectation run at that sigma,
          or NaN where none was kept.

    Raises:
        IndexError: if there is no clean run.
        KeyError: if ``(model, dataset)`` is missing from ``tol``.
    """
    import math
    import warnings

    if tol is None:
        tol = globals()['tol']
    clean = get(model, dataset, aug=0, exp=0)[0]
    threshold = clean.acc - tol[model, dataset]

    def within_tol(runs):
        return [r for r in runs if r.acc > threshold]

    aug_runs = within_tol(get(model, dataset, aug=lambda x: x > 0, exp=0))
    exp_runs = within_tol(get(model, dataset, aug=0, exp=lambda x: x > 0))
    sigs = sorted({r.sig for r in aug_runs} | {r.sig for r in exp_runs})

    out = torch.full((len(sigs),), float('nan'))
    for i, s in enumerate(sigs):
        candidates = [r for r in exp_runs if r.sig == s]
        if not candidates:
            # Previously max() of an empty list aborted the whole loop; keep going
            # and tell the user which tolerance is too tight.
            warnings.warn(f'{model}/{dataset}: no expectation run at sigma={s} is within '
                          f'tol of the baseline accuracy; consider increasing '
                          f'tol[{model!r}, {dataset!r}]')
            continue
        # Most accurate qualifying expectation run (first one on ties).
        out[i] = max(candidates, key=lambda r: r.acc).rob

    baseline = clean.rob
    if aug_runs:
        best = max(r.rob for r in aug_runs)
    else:
        warnings.warn(f'{model}/{dataset}: no augmentation run is within tol of the '
                      f'baseline accuracy; consider increasing tol[{model!r}, {dataset!r}]')
        best = float('nan')

    if plot:
        # y-range over every finite value that is drawn (NaN cells/lines are skipped).
        drawn = [v for v in out.tolist() + [baseline, best] if not math.isnan(v)]
        ys = torch.linspace(min(drawn), max(drawn), 5)
        plt.figure(figsize=figsize)
        plt.plot(out.numpy(), '8-', label='Ours')
        plt.axhline(baseline, linewidth=2, color='k', linestyle='--', label='Baseline')
        if not math.isnan(best):
            plt.axhline(best, linewidth=2, color='r', linestyle='--', label='Best aug')
        # raw strings: '\R' / '\s' are invalid escape sequences (SyntaxWarning on Python 3.12+)
        plt.ylabel(r'Robustness $\Re$', fontsize=fontsize)
        plt.xlabel(r'Training $\sigma$', fontsize=fontsize)
        plt.yticks([round(float(y), 4) for y in ys], fontsize=fontsize)
        plt.xticks(range(len(sigs)), [str(s) for s in sigs], fontsize=fontsize)
        plt.legend(fontsize=fontsize * .9)
        plt.grid()
        plt.show()
    return baseline, best, out

# Accuracy-matched robustness comparison, plotted for every configured (model, dataset) pair.
for model, dataset in list(tol):
    print(model, dataset)
    baseline, best, out = comparable_results(model=model, dataset=dataset, plot=True)