In [12]:
from pathlib import Path
from collections import namedtuple

import torch


def valid_result(key, accuracy):
    """Decide whether a trained run should be kept for the results table.

    A run is kept only if all of the following hold:

    * its dataset has a known accuracy floor (CIFAR10: 0.4, MNIST: 0.975);
    * ``accuracy`` is not below that floor;
    * it is not a nominal run (``epsilon == 0``) with a non-zero
      ``temperature``.

    Args:
        key: Object exposing ``dataset``, ``epsilon`` and ``temperature``
            attributes (e.g. the ``Key`` tuples yielded by
            ``iterate_results``).
        accuracy: Clean accuracy, compared directly against the floors
            above (presumably a fraction in [0, 1] — confirm for new
            datasets).

    Returns:
        bool: True if the run passes every filter, False otherwise.
    """
    floor = {'CIFAR10': 0.4, 'MNIST': 0.975}.get(key.dataset)
    # Unknown dataset, or accuracy too low for the known one.
    if floor is None or accuracy < floor:
        return False
    # Nominal (epsilon == 0) runs are only meaningful at temperature 0.
    return not (key.epsilon == 0 and key.temperature != 0)


def iterate_results(resutls_dir=Path.home() / 'results'):
    """Yield ``(key, checkpoint_path)`` for every finished run on disk.

    Expected layout (as parsed below)::

        <resutls_dir>/<dataset>-<model>-<epsilon>/
            <learning_rate>-<factor>-<temperature>/checkpoint.pth

    The following are skipped silently:

    * entries whose name does not split into exactly three '-'-separated
      parts;
    * entries that are not directories (stray files used to crash
      ``iterdir``);
    * run directories without a ``checkpoint.pth``.

    Entries are visited in sorted order, so the output order (and thus
    tie-breaking downstream) is reproducible.

    NOTE(review): a value whose string form contains '-' (e.g. a learning
    rate written as ``1e-05``) makes the name split into more than three
    parts, so that run is silently skipped — confirm the naming scheme.

    Args:
        resutls_dir: Root results directory (str or Path). The misspelled
            name is kept for backward compatibility with keyword callers.
            The default ``~/results`` is evaluated once, at definition time.

    Yields:
        tuple: ``(Key, Path)``, where ``Key`` is a namedtuple
        ``(dataset, model, epsilon, learning_rate, factor, temperature)``.
        ``dataset`` and ``model`` are strings, ``temperature`` is an int,
        and the rest are floats. The Path points to ``checkpoint.pth``.

    Raises:
        ValueError: If a three-part name has a component that ``float``
            or ``int`` cannot parse.
    """
    Key = namedtuple('Key', (
        'dataset', 'model', 'epsilon',
        'learning_rate', 'factor', 'temperature'))
    for run_dir in sorted(Path(resutls_dir).iterdir()):
        run_parts = run_dir.name.split('-')
        if len(run_parts) != 3 or not run_dir.is_dir():
            continue
        dataset, model, epsilon = run_parts
        for hparam_dir in sorted(run_dir.iterdir()):
            hparam_parts = hparam_dir.name.split('-')
            if len(hparam_parts) != 3 or not hparam_dir.is_dir():
                continue
            learning_rate, factor, temperature = hparam_parts
            key = Key(dataset, model, float(epsilon), float(learning_rate),
                      float(factor), int(temperature))
            checkpoint = hparam_dir / 'checkpoint.pth'
            if not checkpoint.exists():
                continue
            yield key, checkpoint


def iterate_models(results_iterator=iterate_results, check=valid_result):
    """Load every run's checkpoint and summarise clean and PGD performance.

    For each ``(key, checkpoint_path)`` pair produced by
    ``results_iterator()``:

    * The training method is derived from the key: 'nominal' when
      ``epsilon == 0``, 'deepmind' when ``learning_rate == 0``, otherwise
      'ours'.
    * Clean accuracy is read from the checkpoint.
    * Runs rejected by ``check(key, accuracy)`` are skipped.
    * Robustness is ``1 - fooling_rate``, averaged over all ``pgd*.pth``
      files next to the checkpoint. Runs without any such file are
      reported with ``print`` and skipped.

    Args:
        results_iterator: Zero-argument callable yielding ``(key, Path)``
            pairs. ``key`` needs ``dataset``, ``model``, ``epsilon`` and
            ``learning_rate`` attributes, plus whatever ``check`` reads.
        check: Callable ``(key, accuracy) -> bool`` used as a filter, or
            None to keep every run.

    Yields:
        tuple: ``(Experiment, Result)``, where ``Experiment`` is
        ``(method, dataset, model, epsilon)`` and ``Result`` is
        ``(state_dict, accuracy, pgd)``.

    NOTE(review): ``torch.load`` unpickles the files; only point this at
    trusted result directories.
    """
    Experiment = namedtuple('Experiment', (
        'method', 'dataset', 'model', 'epsilon'))
    Result = namedtuple('Result', ('state_dict', 'accuracy', 'pgd'))
    for k, p in results_iterator():
        if k.epsilon == 0:
            method = 'nominal'
        elif k.learning_rate == 0:
            method = 'deepmind'
        else:
            method = 'ours'
        experiment = Experiment(method, k.dataset, k.model, k.epsilon)
        checkpoint = torch.load(p, 'cpu')
        # Prefer an explicit 'accuracy' entry. Otherwise fall back to
        # 'best_acc1', divided by 100 (presumably a percentage — confirm).
        accuracy = checkpoint.get(
            'accuracy', checkpoint.get('best_acc1', 0) / 100)
        if check is not None and not check(k, accuracy):
            continue
        # Sorted so the float summation order is deterministic.
        robustness = [1 - torch.load(g, 'cpu')['fooling_rate']
                      for g in sorted(p.parent.glob('pgd*.pth'))]
        if not robustness:
            print(f'No PGD in: {p.parent}')
            continue
        # Plain mean: the guard above guarantees len(robustness) > 0. The
        # previous `len + 1` denominator biased every score downwards.
        pgd = sum(robustness) / len(robustness)
        yield experiment, Result(checkpoint['state_dict'], accuracy, pgd)


def get_results(model_iterator=iterate_models):
    """Select, for each experiment, the run with the highest PGD robustness.

    Args:
        model_iterator: Zero-argument callable yielding
            ``(experiment, result)`` pairs, where ``result`` has a ``pgd``
            attribute (e.g. ``iterate_models``).

    Returns:
        dict: Maps each experiment to its best result. A later result
        replaces an earlier one only if its ``pgd`` is strictly greater, so
        ties keep the first result seen.
    """
    best = {}
    for experiment, result in model_iterator():
        # Keep the incumbent unless the newcomer is strictly more robust.
        if experiment in best and not best[experiment].pgd < result.pgd:
            continue
        best[experiment] = result
    return best

In [48]:
results = get_results()
if not results:
    # Without this guard, next(iter(...)) below would fail with a bare
    # StopIteration that hides the real cause.
    raise RuntimeError('No valid results found; nothing to save to results.pth')
# All keys share one namedtuple type and all values another, so any single
# item is enough to recover both sets of field names.
k, v = next(iter(results.items()))
torch.save({
    'keys': k._fields,
    'values': v._fields,
    # Stored as plain tuples: the namedtuple classes are created inside
    # iterate_models, so instances of them cannot be pickled by reference.
    'experiments': {tuple(experiment): tuple(result)
                    for experiment, result in results.items()},
}, 'results.pth')