In [23]:
# System
import os
import sys
import inspect
import tabulate
import time

# Data processing
import numpy as np
import math as m

# Results presentation
from tqdm import tqdm_notebook as tqdm
from IPython.display import clear_output
import matplotlib
import matplotlib.pyplot as plt

# NN related stuff
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
# from torch.autograd import Variable

# Make the directory above the notebook importable, so the project-local modules
# imported below (`data`, `models`, `utils`, `correlation`) resolve.
# Inside a notebook the code object's filename is an IPython-generated name (in recent
# ipykernel versions a file under a temporary ipykernel directory), so it says nothing
# about where the notebook lives. The kernel's working directory is the notebook's
# folder by default, so derive the location from that instead.
current_dir = os.path.abspath(os.getcwd())
parent_dir = os.path.dirname(current_dir)
if parent_dir not in sys.path:  # keep the cell idempotent: no duplicate entries on re-run
    sys.path.insert(0, parent_dir)

import data
import models
import utils
import correlation


%matplotlib inline

In [27]:
class GlobalArguments():
    """Notebook stand-in for the FGE script's command-line arguments.

    Declares every option the later cells read, so the notebook can run top to
    bottom on a fresh kernel. This includes ``momentum`` and ``wd``, which the
    optimizer cell needs, and ``epochs``, which the training loop needs. Any option
    can be overridden by keyword, e.g. ``GlobalArguments(epochs=4, ckpt='...')``.

    Raises:
        TypeError: if a keyword does not name a declared option. This stops a typo
            from silently creating an option that nothing reads.
    """

    def __init__(self, **overrides):
        self.model       = 'vgg5_bn'       # attribute name looked up in the project `models` module
        self.dataset     = 'CIFAR100'
        self.data_path   = 'Data/'
        self.batch_size  = 128
        self.num_workers = 4
        self.transform   = 'VGG'
        self.use_test    = False           # with False the loader reported a train/validation split
        self.models_path = 'Checkpoints/'
        self.n_models    = 15
        self.cycle       = 8               # cyclic-schedule length in epochs; must be even
        self.dir         = 'Checkpoints/'  # output directory for fge.sh and FGE checkpoints
        # NOTE(review): this path names a VGG16-BN / CIFAR-10 run, but `model` is
        # 'vgg5_bn' and `dataset` is 'CIFAR100', so its weights likely will not
        # load into this model (the file was also missing in the last run).
        # Confirm the intended checkpoint.
        self.ckpt        = 'Checkpoints/VGG16BN_CIFAR10_0/checkpoint-200.pt'
        self.lr_1        = 0.05            # first LR passed to utils.cyclic_learning_rate (also SGD's initial lr)
        self.lr_2        = 0.0001          # second LR passed to utils.cyclic_learning_rate
        # Read by the optimizer cell. The checkpoint's optimizer state, loaded right
        # after SGD is constructed, restores its own param-group values.
        # TODO confirm these match the run being reproduced.
        self.momentum    = 0.9
        self.wd          = 1e-4
        self.epochs      = 20              # number of epochs the FGE training loop runs

        # Apply caller overrides, rejecting names that are not declared above.
        for name, value in overrides.items():
            if not hasattr(self, name):
                raise TypeError('Unknown argument: {!r}'.format(name))
            setattr(self, name, value)
args = GlobalArguments()

In [28]:
# The training loop uses `cycle // 2` both to choose the ensemble-collection epoch
# and as the checkpoint interval, so the cycle length must split evenly in half.
assert args.cycle % 2 == 0, 'Cycle length should be even'

os.makedirs(args.dir, exist_ok=True)
# Record the run configuration next to the checkpoints. Inside a Jupyter kernel,
# sys.argv holds the kernel launcher's arguments, not anything describing this run.
# So write the settings held in `args`, formatted as command-line flags.
run_flags = ' '.join(
    '--{}={}'.format(key, value) for key, value in sorted(vars(args).items())
)
with open(os.path.join(args.dir, 'fge.sh'), 'w') as f:
    f.write(run_flags)
    f.write('\n')

In [26]:
# Build the data loaders via the project's `data.loaders` helper.
# - `loaders` is indexed by split name; later cells use 'train' and 'test'.
# - `num_classes` sizes the model head and the ensemble prediction buffer.
# With use_test=False the helper reported a train/validation split in its output,
# so 'test' presumably holds the held-out validation examples.
loader_settings = (
    args.dataset, args.data_path, args.batch_size,
    args.num_workers, args.transform, args.use_test,
)
loaders, num_classes = data.loaders(*loader_settings)

Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to Data/cifar100/cifar-100-python.tar.gz


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Extracting Data/cifar100/cifar-100-python.tar.gz to Data/cifar100
Using train (45000) + validation (5000)
Files already downloaded and verified


In [29]:
# Let cuDNN benchmark and pick convolution algorithms (faster, but may be non-deterministic).
torch.backends.cudnn.benchmark = True

# `args.model` names an architecture descriptor in the project `models` module,
# exposing a `base` constructor and its default `kwargs`.
architecture = getattr(models, args.model)
model = architecture.base(num_classes=num_classes, **architecture.kwargs)
criterion = torch.nn.CrossEntropyLoss()

# Resume from the pre-trained checkpoint.
# Load onto the CPU first, so the file can be read regardless of which GPU it was
# saved from. The model is moved to the GPU afterwards, and
# `optimizer.load_state_dict` casts the restored state to the parameters' device.
checkpoint = torch.load(args.ckpt, map_location='cpu')
start_epoch = checkpoint['epoch'] + 1
model.load_state_dict(checkpoint['model_state'])
model.cuda()

# Build the optimizer on the GPU-resident parameters, then restore its saved state.
# `load_state_dict` also restores the saved param-group settings (lr, momentum,
# weight_decay), so the values passed here are replaced by what the checkpoint holds.
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=args.lr_1,
    momentum=args.momentum,
    weight_decay=args.wd
)
optimizer.load_state_dict(checkpoint['optimizer_state'])

FileNotFoundError: [Errno 2] No such file or directory: 'Checkpoints/VGG16BN_CIFAR10_0/checkpoint-200.pt'

In [13]:
# Fast Geometric Ensembling.
# Keep training from the checkpoint with a cyclic learning rate. Once per cycle, add
# the model's test-set predictions to a running ensemble and report its accuracy.
# Requires args.epochs and the model/optimizer/loaders created in the earlier cells.
ensemble_size = 0
# Running sum of every ensemble member's predictions, shape (n_test_examples, num_classes).
# The argmax over the class axis gives the ensemble's predicted label.
predictions_sum = np.zeros((len(loaders['test'].dataset), num_classes))

columns = ['ep', 'lr', 'tr_loss', 'tr_acc', 'te_nll', 'te_acc', 'ens_acc', 'time']
HEADER_EVERY = 40  # re-print the table header every HEADER_EVERY epochs

for epoch in range(args.epochs):
    time_ep = time.time()
    lr_schedule = utils.cyclic_learning_rate(epoch, args.cycle, args.lr_1, args.lr_2)
    train_res = utils.train(loaders['train'], model, optimizer, criterion, lr_schedule=lr_schedule)
    test_res = utils.test(loaders['test'], model, criterion)
    time_ep = time.time() - time_ep

    ens_acc = None
    # Collect one ensemble member per cycle, at the epoch that ends halfway through it.
    # This is presumably where utils.cyclic_learning_rate reaches its low LR;
    # confirm in utils.
    # Predictions are only needed here, so the extra pass over the test set is
    # skipped on all other epochs.
    if (epoch % args.cycle + 1) == args.cycle // 2:
        predictions, targets = utils.predictions(loaders['test'], model)
        ensemble_size += 1
        predictions_sum += predictions
        ens_acc = 100.0 * np.mean(np.argmax(predictions_sum, axis=1) == targets)

    # Save a checkpoint every half cycle.
    if (epoch + 1) % (args.cycle // 2) == 0:
        utils.save_checkpoint(
            args.dir,
            start_epoch + epoch,
            name='fge',
            model_state=model.state_dict(),
            optimizer_state=optimizer.state_dict()
        )

    values = [epoch, lr_schedule(1.0), train_res['loss'], train_res['accuracy'], test_res['nll'],
              test_res['accuracy'], ens_acc, time_ep]
    table = tabulate.tabulate([values], columns, tablefmt='simple', floatfmt='9.4f')
    # With tablefmt='simple' the rendered table is: header, separator, data row.
    rows = table.split('\n')
    if epoch % HEADER_EVERY == 0:
        table = '\n'.join([rows[1]] + rows)
    else:
        table = rows[2]
    print(table)

AttributeError: 'GlobalArguments' object has no attribute 'ckpt'