In [None]:
import os
import sys
# Put the parent of the current working directory at the front of sys.path,
# presumably so the `unitization` package imported below resolves when the
# notebook is run from a sub-folder of the repository (verify the layout).
root = os.path.split(os.getcwd())[0]
sys.path.insert(0, root)

import json
import keras.backend as K
from unitization.resnets import ResNets
from unitization.train import (
    reset, train_on_cifar
)

def initialize(model, dataset, seed):
    """Give ``model`` a starting point that is shared through a file on disk.

    The weights live in ``<cwd>/weights/<dataset>-<model.name>.h5``; the
    ``weights`` folder is created on first use.

    * If that file already exists, it is loaded into ``model`` with
      ``by_name=True``.
    * Otherwise ``reset(seed)`` is called and the model's current weights are
      written to the file, so a later call with the same dataset and model
      name takes the loading branch instead.
      NOTE(review): what ``reset`` does is defined in ``unitization.train``
      and is not visible here - confirm it has the intended effect on an
      already-constructed model.

    Parameters
    ----------
    model : object exposing ``name``, ``load_weights`` and ``save_weights``
    dataset : str
        Used only as the prefix of the weights file name.
    seed : int
        Forwarded to ``reset`` when no weights file exists yet.

    Returns
    -------
    The same ``model`` object that was passed in.
    """
    weights_dir = os.path.join(os.getcwd(), 'weights')
    if not os.path.isdir(weights_dir):
        os.mkdir(weights_dir)

    checkpoint = os.path.join(
        weights_dir, '{}-{}.h5'.format(dataset, model.name)
    )

    # First run for this (dataset, model name): persist the initial weights.
    if not os.path.exists(checkpoint):
        reset(seed)
        model.save_weights(checkpoint)
        return model

    # Later runs: start from the stored weights.
    model.load_weights(checkpoint, by_name=True)
    return model

def get_accuracy(history, model, dataset, normalization):
    """Append a run's final validation accuracy to ``<cwd>/test_accuracy.json``.

    Parameters
    ----------
    history : dict or object with a ``history`` dict attribute
        Per-epoch metrics of the run. The last entry of ``'val_acc'`` is
        used; if that key is absent, ``'val_accuracy'`` (the name newer Keras
        releases report) is used instead.
    model : object with a ``name`` attribute
        The last three characters of ``model.name`` select the row label:
        ``'-18'`` becomes ``'ResNet-18'``, anything else becomes
        ``'ResNet-<those three characters>'``.
    dataset : str
        Top-level key of the results table.
    normalization : str
        Label stored next to the accuracy.

    Returns
    -------
    dict
        The whole table, ``{dataset: {model label: [(normalization, acc), ...]}}``.
        Entries read back from the JSON file are 2-element lists; the entry
        added by this call is a tuple.

    Raises
    ------
    KeyError
        If neither ``'val_acc'`` nor ``'val_accuracy'`` is present.
    """
    # A Keras ``History`` keeps its metrics in ``.history``; a plain dict is
    # used as-is. This replaces a bare ``except:`` that hid unrelated errors.
    metrics = getattr(history, 'history', history)
    for key in ('val_acc', 'val_accuracy'):
        if key in metrics:
            # ``float`` keeps the value JSON-serialisable even when the
            # metric is a NumPy scalar.
            accuracy = float(metrics[key][-1])
            break
    else:
        raise KeyError(
            "history has neither 'val_acc' nor 'val_accuracy'"
        )

    accuracy_file = os.path.join(os.getcwd(), 'test_accuracy.json')
    if os.path.exists(accuracy_file):
        with open(accuracy_file, 'r') as f:
            table = json.load(f)
    else:
        table = dict()

    suffix = model.name[-3:]
    name = 'ResNet-{}'.format({'-18': '18'}.get(suffix, suffix))
    table.setdefault(dataset, dict()).setdefault(name, []).append(
        (normalization, accuracy)
    )

    with open(accuracy_file, 'w') as f:
        json.dump(table, f)
    return table

In [None]:
if __name__ == '__main__':
    # Settings shared by the training cells below: `batch_size` and `epochs`
    # are passed to train_on_cifar(); `seed` is passed to both initialize()
    # and train_on_cifar().
    batch_size = 128
    epochs = 200
    seed = 9408

In [None]:
    # Dataset name handed to initialize(), train_on_cifar() and get_accuracy()
    # by the cells that follow.
    dataset = 'CIFAR-10'

In [None]:
    # ResNet-18 / 'batchnorm' variant on the currently selected dataset.
    print('ResNet-18 with Batch Normalization on CIFAR-10:')
    K.clear_session()  # start from a clean Keras backend session
    model = initialize(ResNets.cifar_resnet_18('batchnorm', 10), dataset, seed)
    history = train_on_cifar(dataset, model, seed, batch_size, epochs)
    # Persist the last-epoch validation accuracy under this run's label.
    table = get_accuracy(history, model, dataset, 'Batch Normalization')

In [None]:
    # ResNet-18 / 'unitization' variant on the currently selected dataset.
    print('ResNet-18 with the unitization on CIFAR-10:')
    K.clear_session()  # start from a clean Keras backend session
    model = initialize(ResNets.cifar_resnet_18('unitization', 10), dataset, seed)
    history = train_on_cifar(dataset, model, seed, batch_size, epochs)
    # Persist the last-epoch validation accuracy under this run's label.
    table = get_accuracy(history, model, dataset, 'Unitization')

In [None]:
    # ResNet-110 / 'batchnorm' variant on the currently selected dataset.
    print('ResNet-110 with Batch Normalization on CIFAR-10:')
    K.clear_session()  # start from a clean Keras backend session
    model = initialize(ResNets.cifar_resnet_110('batchnorm', 10), dataset, seed)
    history = train_on_cifar(dataset, model, seed, batch_size, epochs)
    # Persist the last-epoch validation accuracy under this run's label.
    table = get_accuracy(history, model, dataset, 'Batch Normalization')

In [None]:
    # ResNet-110 / 'unitization' variant on the currently selected dataset.
    print('ResNet-110 with the unitization on CIFAR-10:')
    K.clear_session()  # start from a clean Keras backend session
    model = initialize(ResNets.cifar_resnet_110('unitization', 10), dataset, seed)
    history = train_on_cifar(dataset, model, seed, batch_size, epochs)
    # Persist the last-epoch validation accuracy under this run's label.
    table = get_accuracy(history, model, dataset, 'Unitization')

In [None]:
    # The ResNet-164 cells that follow train with a batch size of 64 instead
    # of the 128 used so far.
    batch_size = 64

In [None]:
    # ResNet-164 / 'batchnorm' variant on the currently selected dataset.
    print('ResNet-164 with Batch Normalization on CIFAR-10:')
    K.clear_session()  # start from a clean Keras backend session
    model = initialize(ResNets.cifar_resnet_164('batchnorm', 10), dataset, seed)
    history = train_on_cifar(dataset, model, seed, batch_size, epochs)
    # Persist the last-epoch validation accuracy under this run's label.
    table = get_accuracy(history, model, dataset, 'Batch Normalization')

In [None]:
    # ResNet-164 / 'unitization' variant on the currently selected dataset.
    print('ResNet-164 with the unitization on CIFAR-10:')
    K.clear_session()  # start from a clean Keras backend session
    model = initialize(ResNets.cifar_resnet_164('unitization', 10), dataset, seed)
    history = train_on_cifar(dataset, model, seed, batch_size, epochs)
    # Persist the last-epoch validation accuracy under this run's label.
    table = get_accuracy(history, model, dataset, 'Unitization')

In [None]:
    # Switch the following cells to CIFAR-100 and go back to the batch size
    # of 128 (it was lowered to 64 for the preceding ResNet-164 runs).
    dataset = 'CIFAR-100'
    batch_size = 128

In [None]:
    # ResNet-18 / 'batchnorm' variant on the currently selected dataset.
    print('ResNet-18 with Batch Normalization on CIFAR-100:')
    K.clear_session()  # start from a clean Keras backend session
    model = initialize(ResNets.cifar_resnet_18('batchnorm', 100), dataset, seed)
    history = train_on_cifar(dataset, model, seed, batch_size, epochs)
    # Persist the last-epoch validation accuracy under this run's label.
    table = get_accuracy(history, model, dataset, 'Batch Normalization')

In [None]:
    # ResNet-18 / 'unitization' variant on the currently selected dataset.
    print('ResNet-18 with the unitization on CIFAR-100:')
    K.clear_session()  # start from a clean Keras backend session
    model = initialize(ResNets.cifar_resnet_18('unitization', 100), dataset, seed)
    history = train_on_cifar(dataset, model, seed, batch_size, epochs)
    # Persist the last-epoch validation accuracy under this run's label.
    table = get_accuracy(history, model, dataset, 'Unitization')

In [None]:
    # ResNet-110 / 'batchnorm' variant on the currently selected dataset.
    print('ResNet-110 with Batch Normalization on CIFAR-100:')
    K.clear_session()  # start from a clean Keras backend session
    model = initialize(ResNets.cifar_resnet_110('batchnorm', 100), dataset, seed)
    history = train_on_cifar(dataset, model, seed, batch_size, epochs)
    # Persist the last-epoch validation accuracy under this run's label.
    table = get_accuracy(history, model, dataset, 'Batch Normalization')

In [None]:
    # ResNet-110 / 'unitization' variant on the currently selected dataset.
    print('ResNet-110 with the unitization on CIFAR-100:')
    K.clear_session()  # start from a clean Keras backend session
    model = initialize(ResNets.cifar_resnet_110('unitization', 100), dataset, seed)
    history = train_on_cifar(dataset, model, seed, batch_size, epochs)
    # Persist the last-epoch validation accuracy under this run's label.
    table = get_accuracy(history, model, dataset, 'Unitization')

In [None]:
    # As with CIFAR-10, the ResNet-164 cells that follow use a batch size of
    # 64 rather than 128.
    batch_size = 64

In [None]:
    # ResNet-164 / 'batchnorm' variant on the currently selected dataset.
    print('ResNet-164 with Batch Normalization on CIFAR-100:')
    K.clear_session()  # start from a clean Keras backend session
    model = initialize(ResNets.cifar_resnet_164('batchnorm', 100), dataset, seed)
    history = train_on_cifar(dataset, model, seed, batch_size, epochs)
    # Persist the last-epoch validation accuracy under this run's label.
    table = get_accuracy(history, model, dataset, 'Batch Normalization')

In [None]:
    # ResNet-164 / 'unitization' variant on the currently selected dataset.
    print('ResNet-164 with the unitization on CIFAR-100:')
    K.clear_session()  # start from a clean Keras backend session
    model = initialize(ResNets.cifar_resnet_164('unitization', 100), dataset, seed)
    history = train_on_cifar(dataset, model, seed, batch_size, epochs)
    # Persist the last-epoch validation accuracy under this run's label.
    table = get_accuracy(history, model, dataset, 'Unitization')

In [None]:
    with open(os.path.join(os.getcwd(), 'test_accuracy.json'), 'r') as f:
        table = json.load(f)
    try:
        !{sys.executable} -m pip install tabulate
        from IPython.display import HTML, display
        import tabulate
        display_table = [['dataset', 'model', 'normalization', 'accuracy']]
        for dataset, _table in table.items():
            model_names = list(_table.keys())
            model_names.sort(key=lambda x: {'-18': 18, '110': 110, '164': 164}[x[-3:]])
            for model_name in model_names:
                for normalization, accuracy in _table[model_name]:
                    display_table.append([dataset, model_name, normalization, '{:.2f}%'.format(accuracy * 100)])
        display(HTML(tabulate.tabulate(display_table, tablefmt='html')))
    except:
        form = '{:10s} {:10s} {:20s} {:10s}'
        display = form.format('dataset', 'model', 'normalization', 'accuracy')
        for dataset, _table in table.items():
            model_names = list(_table.keys())
            model_names.sort(key=lambda x: {'-18': 18, '110': 110, '164': 164}[x[-3:]])
            for model_name in model_names:
                for normalization, accuracy in _table[model_name]:
                    display += '\n' + form.format(
                        dataset, model_name, normalization, '{:.2f}%'.format(accuracy * 100)
                    )
        print(display)