#Utility functions, parameters, and Google Drive

In [None]:
#@title General parameters { run: "auto", display-mode: "form" }
from torchvision import transforms
import PIL.Image

# Number of training epochs run for each hyper-parameter combination.
_nb_epochs = 10 #@param {type:"integer"}
# Training progress is printed every `_log_interval` batches.
_log_interval = 1 #@param {type:"integer"}
# Selects which dataset cell builds trainset/validset/testset; the other
# dataset cells print 'Disabled'.
_dataset = "CIFAR-10" #@param ["ImageNet-16", "CIFAR-10", "CIFAR-100", "CINIC-10"]

# Disabled: fixed-size resizing with a selectable interpolation method,
# superseded by the RandomResizedCrop in `_transformations` below.
# _interpolation_method = "nearest" #@param ["nearest", "bilinear", "cubic"]
# if "nearest" == _interpolation_method:
#     _interpolation_method = PIL.Image.NEAREST
# elif "bilinear" == _interpolation_method:
#     _interpolation_method = PIL.Image.BILINEAR
# elif "cubic" == _interpolation_method:
#     _interpolation_method = PIL.Image.CUBIC

# Preprocessing pipeline passed as `transform=` to every dataset in this
# notebook (training, validation and test alike): random crop resized to
# 64x64, random horizontal flip, then conversion to a tensor.
# NOTE(review): the random augmentations are therefore also applied to the
# validation and test sets -- confirm this is intended.
_transformations = transforms.Compose([
    transforms.RandomResizedCrop(64),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor()
])

In [None]:
#@title AlexNet parameters

# When False, the grid search neither builds nor trains AlexNet.
_alex_enabled = False #@param {type:"boolean"}
if _alex_enabled:
    # Disabled learning-rate scheduler settings, kept for reference.
#     _alex_scheduler="Step" #@param ["Step", "Adaptive"]
#     _alex_step_size=15 #@param {type:"number"}
#     _alex_gamma = 0.1 #@param {type:"number"}
    # Read by AlexNet.__init__; only defined when AlexNet is enabled, so
    # AlexNet() must not be instantiated while `_alex_enabled` is False.
    _alex_pretrained = False #@param {type:"boolean"}

In [None]:
#@title SqueezeNet parameters

# Disabled per-model optimizer/scheduler settings, kept for reference; the
# grid-search cell supplies learning rate, momentum and weight decay instead.
# _squeeze_learning_rate = 0.01 #@param {type:"number"}
# _squeeze_momentum = 0.5 #@param {type:"number"}
# _squeeze_weight_decay=0.0002 #@param {type:"number"}
# _squeeze_scheduler="Step" #@param ["Step", "Adaptive"]
# _squeeze_step_size=15 #@param {type:"number"}
# _squeeze_gamma = 0.1 #@param {type:"number"}
# Read by SqueezeNet.__init__ and stored with each saved grid-search result.
_squeeze_pretrained = True #@param {type:"boolean"}

In [None]:
#@title Google Drive {run: "auto"}
from google.colab import drive

# Mount the user's Google Drive at /gdrive; the dataset cells read their
# archives from '/gdrive/My Drive/...' and the grid search saves results there.
drive.mount('/gdrive', force_remount=True)
# NOTE(review): `data_folder` is not referenced by any other cell visible in
# this notebook -- confirm whether it is still needed.
data_folder = '/gdrive/My Drive/COMP551_Assignment3_data/'


In [None]:
#@title Utility functions

def log_progress(sequence, every=None, size=None, name='Items'):
    """Yield the items of ``sequence`` while updating a notebook progress bar.

    Args:
        sequence: Iterable to walk through.
        every: Refresh the widgets every ``every`` items (the first item
            always refreshes). Defaults to 1 for up to 200 items and to
            ``int(size / 200)`` beyond that. Mandatory when the length of
            ``sequence`` cannot be determined.
        size: Total number of items; taken from ``len(sequence)`` if omitted.
        name: Caption shown in front of the counter.

    Yields:
        Every item of ``sequence``, unchanged and in order.

    Raises:
        AssertionError: If the length is unknown and ``every`` was not given.
    """
    from ipywidgets import IntProgress, HTML, VBox
    from IPython.display import display

    # Work out whether a total is available to scale the bar against.
    unknown_length = False
    if size is None:
        try:
            size = len(sequence)
        except TypeError:
            unknown_length = True

    if size is None:
        assert every is not None, 'sequence is iterator, set every'
    elif every is None:
        # Roughly one refresh per 0.5% of the items.
        every = 1 if size <= 200 else int(size / 200)

    if unknown_length:
        # No total: show a full bar in "info" style and only count items.
        progress = IntProgress(min=0, max=1, value=1)
        progress.bar_style = 'info'
    else:
        progress = IntProgress(min=0, max=size, value=0)
    label = HTML()
    display(VBox(children=[label, progress]))

    index = 0
    try:
        for index, record in enumerate(sequence, 1):
            refresh = index == 1 or index % every == 0
            if refresh and unknown_length:
                label.value = '{name}: {index} / ?'.format(name=name, index=index)
            elif refresh:
                progress.value = index
                label.value = u'{name}: {index} / {size}'.format(
                    name=name, index=index, size=size
                )
            yield record
    except BaseException:
        # Anything that interrupts the iteration turns the bar red before
        # the exception continues to propagate.
        progress.bar_style = 'danger'
        raise
    else:
        progress.bar_style = 'success'
        progress.value = index
        label.value = "{name}: {index}".format(name=name, index=str(index or '?'))

def try_makedirs(d):
    """Create directory ``d`` and any missing parents.

    Does nothing when the directory already exists.

    Args:
        d: Path of the directory to create.

    Raises:
        FileExistsError: If ``d`` exists but is not a directory, so that a
            blocking regular file is reported instead of silently ignored.
    """
    import os

    os.makedirs(d, exist_ok=True)

    
def load_image(filename):
    """Load an image file and return it as an RGB ``PIL.Image``.

    Used as the ``loader`` of the dataset folders in this notebook.

    Args:
        filename: Path of the image file to read.

    Returns:
        A new image in 'RGB' mode.
    """
    # The context manager closes the underlying file once the pixel data
    # has been converted, instead of leaving one open handle per image
    # until the garbage collector gets to it.
    with PIL.Image.open(filename) as image:
        return image.convert('RGB')

    
# def load_image(filename):
#     return \
#         PIL.Image.open(filename)\
#           .resize((64,64), _interpolation_method)\
#           .convert('RGB')

# def load_image(filename):
#     return np.array(
#         PIL.Image.open(filename)\
#           .resize((64,64), _interpolation_method)\
#           .convert('RGB'),
#         dtype=np.float32
#     )/255


def plot_results(
    results_1, model_1_name, model_1_color,
    results_2, model_2_name, model_2_color
):
    """Plot the accuracy curves of two models on the current matplotlib figure.

    Column 1 of each results array is drawn against the 1-based row number,
    once as a line and once as round markers.

    Args:
        results_1: 2-D array for the first model; column 1 is plotted as
            accuracy in percent.
        model_1_name: Legend entry for the first model.
        model_1_color: Matplotlib colour code (e.g. ``'r'``) for the first model.
        results_2: 2-D array for the second model. The x axis is derived
            from ``results_1`` only, so both arrays presumably need the
            same number of rows -- verify against callers.
        model_2_name: Legend entry for the second model.
        model_2_color: Matplotlib colour code for the second model.
    """
    import matplotlib.pyplot as plt
    # Imported locally so the function does not depend on some other
    # notebook cell having imported numpy as `np` first.
    import numpy as np

    x = range(1, 1+results_1.shape[0])
    one_plt = plt.plot(x, results_1[:,1], model_1_color + '-', x, results_1[:,1], model_1_color + 'o')
    two_plt = plt.plot(x, results_2[:,1], model_2_color + '-', x, results_2[:,1], model_2_color + 'o')

    # Only the line artists (first element of each plot call) go in the legend.
    plt.legend((one_plt[0], two_plt[0]), (model_1_name, model_2_name))
    plt.title('Model accuracy improvement over time')
    plt.xlabel('# of epochs')
    plt.xticks(np.linspace(1, results_1.shape[0], 11))
    plt.ylabel('Accuracy (%)')
    plt.ylim(0, 100)
    plt.yticks(np.linspace(0, 100, 11))
    plt.draw()
    
    
def plot_results_to_grid(
    grid, where,
    results_1, model_1_name, model_1_color,
    results_2, model_2_name, model_2_color
):
    """Redraw the two-model accuracy plot inside one cell of a widget grid.

    Args:
        grid: Grid object providing ``output_to(row, col)`` and ``clear_cell()``.
        where: ``(row, column)`` pair selecting the grid cell to draw into.
        results_1, model_1_name, model_1_color: Forwarded to ``plot_results``.
        results_2, model_2_name, model_2_color: Forwarded to ``plot_results``.
    """
    with grid.output_to(where[0], where[1]):
        # Drop the previous plot so the cell shows only the latest one.
        grid.clear_cell()
        plot_results(results_1, model_1_name, model_1_color, results_2, model_2_name, model_2_color)

        
def elapsed_time(model_name, i):
    """Start a timer when ``i`` is 0, otherwise print the time since it started.

    The start time is kept on the function object itself
    (``elapsed_time.start``), so a call with ``i == 0`` must come first.

    Args:
        model_name: Name printed in braces in front of the message.
        i: 0 to (re)start the timer; any other value to report.
    """
    from time import time as now

    if i == 0:
        elapsed_time.start = now()
        return

    message = "{{{}}} The last training epoch took {} seconds.\n\n"
    print(message.format(model_name, now() - elapsed_time.start))
        

class CatchIO:
    """Context manager that captures text written to ``sys.stdout``.

    While the ``with`` block runs, ``sys.stdout`` is replaced by an
    in-memory buffer. On exit the original stream is restored and the
    captured text is available as ``buffer``. Exceptions raised inside the
    block propagate unchanged.

    Example::

        with CatchIO() as captured:
            print("hello")
        captured.buffer  # "hello\\n"
    """

    def __init__(self):
        # Stream to restore on exit.
        self._stdout = None
        # In-memory stream installed while the block runs.
        self._capture = None
        # Captured text; None until the first `with` block has exited.
        self.buffer = None

    def __enter__(self):
        import sys
        import io

        self._stdout = sys.stdout
        self._capture = io.StringIO()
        sys.stdout = self._capture
        # Returning self makes `with CatchIO() as c:` bind the manager.
        return self

    def __exit__(self, type_, value, traceback):
        import sys

        # Read from our own buffer rather than from whatever sys.stdout is
        # now, in case the block replaced sys.stdout itself.
        self.buffer = self._capture.getvalue()
        sys.stdout = self._stdout
        # False lets any in-flight exception propagate with its original
        # traceback; re-raising it by hand is not needed.
        return False


#Models and datasets

In [None]:
#@title AlexNet{display-mode: "form"}
from torchvision.models.alexnet import alexnet
import torch.nn as nn
import torch.nn.functional as F
 
class AlexNet(nn.Module):
    """torchvision AlexNet wrapped so that it outputs log-probabilities.

    The wrapped network is built with 1000 output classes and, depending on
    the notebook parameter ``_alex_pretrained``, with pretrained weights.
    The log-softmax output is what ``F.nll_loss`` in the training and
    evaluation functions expects.
    """

    def __init__(self):
        super(AlexNet, self).__init__()
        # NOTE(review): the head always has 1000 outputs, whatever dataset
        # is selected -- confirm this is intended for the smaller datasets.
        self.alex = alexnet(pretrained=_alex_pretrained, num_classes=1000)

    def forward(self, x):
        """Return the log-softmax over the class dimension (dim 1) for ``x``."""
        # Call the module itself rather than .forward() so that any hooks
        # registered on the wrapped network are run.
        return F.log_softmax(self.alex(x), dim=1)

In [None]:
#@title SqueezeNet {display-mode: "form"}
import torchvision.models.squeezenet
import torch.nn as nn
import torch.nn.functional as F

class SqueezeNet(nn.Module):
    """torchvision SqueezeNet 1.0 wrapped so that it outputs log-probabilities.

    The wrapped network is built with 1000 output classes and, depending on
    the notebook parameter ``_squeeze_pretrained``, with pretrained weights.
    The log-softmax output is what ``F.nll_loss`` in the training and
    evaluation functions expects.
    """

    def __init__(self):
        super(SqueezeNet, self).__init__()
        # NOTE(review): the head always has 1000 outputs, whatever dataset
        # is selected -- confirm this is intended for the smaller datasets.
        self.squeeze = torchvision.models.squeezenet.squeezenet1_0(
            pretrained=_squeeze_pretrained, num_classes=1000
        )

    def forward(self, x):
        """Return the log-softmax over the class dimension (dim 1) for ``x``."""
        # Call the module itself rather than .forward() so that any hooks
        # registered on the wrapped network are run.
        return F.log_softmax(self.squeeze(x), dim=1)

In [None]:
#@title ImageNet-16 _(warning: very slow)_
from torchvision.datasets import DatasetFolder

from zipfile import ZipFile
import os
import pandas
import PIL.Image
import numpy as np

# Folder (relative to '/gdrive/My Drive/') containing Imagenet16_train.zip
# and Imagenet16_val.zip.
path_to_ImageNet16 = "COMP551_Assignment4/data/" #@param {type:"string"}

if "ImageNet-16" == _dataset:
    IMGSZ = (16,16,3)
    NUM_IMAGES = 1275273
    NUM_CLASSES = 1000

    train_pickle_folder = 'data/16/train/pickle'
    train_image_folder = 'data/16/train/images'
    valid_image_folder = 'data/16/validation/images'
    test_pickle_folder = 'data/16/test/pickle'
    test_image_folder = 'data/16/test/images'

    # Re-run the extraction even if the training image folder already exists.
    force = False #@param {type:"boolean"}
    # Fraction of the training archive kept for training; the rest becomes
    # the validation set.
    train_valid_split = 0.9 #@param {type:"slider", min:0, max:1, step:0.05}

    # Number of images written to the training folder so far, per class.
    ct = np.zeros(NUM_CLASSES)
    # Per-class cap on training images; anything beyond goes to validation.
    max_train_per_class = np.floor(train_valid_split * NUM_IMAGES / NUM_CLASSES)

    # NOTE(review): only the existence of the training folder is checked, so
    # an interrupted extraction is not repeated unless `force` is set.
    if not os.path.exists(train_image_folder) or force:
        # Create folders (class folders are named class_1 .. class_1000,
        # matching the labels used as folder names below)
        try_makedirs(train_pickle_folder)
        try_makedirs(train_image_folder)
        try_makedirs(valid_image_folder)
        try_makedirs(test_pickle_folder)
        try_makedirs(test_image_folder)
        for ix in range(NUM_CLASSES):
            try_makedirs(os.path.join(train_image_folder, "class_{}".format(1+ix)))
            try_makedirs(os.path.join(valid_image_folder, "class_{}".format(1+ix)))
            try_makedirs(os.path.join(test_image_folder, "class_{}".format(1+ix)))
        # Extract training data zipfile (closed even if extraction fails)
        with ZipFile(
            os.path.join(
                '/gdrive/My Drive/', path_to_ImageNet16, 'Imagenet16_train.zip'
            )
        ) as z:
            z.extractall(train_pickle_folder)
        # Extract training images to folders
        for ix in log_progress(
            range(1, 11), every=1, name="Converting pickle files to images"
        ):
            batch_file = os.path.join(
                train_pickle_folder, 'train_data_batch_{}'.format(ix)
            )
            # Read pickle
            tmp = pandas.read_pickle(batch_file)
            # CIFAR format (flat, channel-first rows) to array of HxWxC images
            train_images = tmp['data'] \
              .reshape(-1, IMGSZ[2], IMGSZ[0], IMGSZ[1]) \
              .transpose([0, 2, 3, 1])
            # Save images to png (lossless format)
            for jx, image in enumerate(train_images):
                label = tmp['labels'][jx]
                # Split training set into training and validation sets
                if ct[label-1] >= max_train_per_class:
                    # validation
                    out_folder = valid_image_folder
                else:
                    # training
                    out_folder = train_image_folder
                    ct[label-1] = 1 + ct[label-1]
                # Save. The batch number is part of the file name because jx
                # restarts at 0 for every batch; without it, images of later
                # batches would overwrite earlier ones of the same class.
                PIL.Image.fromarray(image).save(
                    os.path.join(
                        out_folder,
                        'class_{}'.format(label),
                        '{}_{}.png'.format(ix, jx)
                    )
                )
            # Remove pickle file since we're done with it
            os.remove(batch_file)

        # Extract test data zipfile
        with ZipFile(
            os.path.join(
                '/gdrive/My Drive/', path_to_ImageNet16, 'Imagenet16_val.zip'
            )
        ) as z:
            z.extractall(test_pickle_folder)
        # Extract test images to folders
        #   Read pickle
        test_file = os.path.join(test_pickle_folder, 'val_data')
        tmp = pandas.read_pickle(test_file)
        #   CIFAR format to array of images
        test_images = tmp['data'] \
          .reshape(-1, IMGSZ[2], IMGSZ[0], IMGSZ[1]) \
          .transpose([0, 2, 3, 1])
        #   Save images to png (lossless format); a single file, so the
        #   running index alone is unique
        for jx, image in enumerate(test_images):
            PIL.Image.fromarray(image).save(
                os.path.join(
                    test_image_folder,
                    'class_{}'.format(tmp['labels'][jx]),
                    '{}.png'.format(jx)
                )
            )
        #   Remove pickle file since we're done with it
        os.remove(test_file)

    trainset = DatasetFolder(
        train_image_folder, load_image, ['png'],
        transform=_transformations
    )
    validset = DatasetFolder(
        valid_image_folder, load_image, ['png'],
        transform=_transformations
    )
    testset = DatasetFolder(
        test_image_folder, load_image, ['png'],
        transform=_transformations
    )

else:
    print('Disabled')


In [None]:
#@title CIFAR-10
from torchvision.datasets import DatasetFolder
import tarfile
import os
import pandas
import PIL.Image
import numpy as np

# Folder (relative to '/gdrive/My Drive/') containing cifar-10-python.tar.gz.
path_to_CIFAR10 = "COMP551_Assignment4/data/" #@param {type:"string"}

if "CIFAR-10" == _dataset:
    IMGSZ = (32, 32, 3)
    NUM_IMAGES = 50000
    NUM_CLASSES = 10

    pickle_folder = 'data/CIFAR10/pickle'
    train_image_folder = 'data/CIFAR10/train/images'
    valid_image_folder = 'data/CIFAR10/validation/images'
    test_image_folder = 'data/CIFAR10/test/images'

    # Re-run the extraction even if the training image folder already exists.
    force = False #@param {type:"boolean"}
    # Fraction of the training archive kept for training; the rest becomes
    # the validation set.
    train_valid_split = 0.9 #@param {type:"slider", min:0, max:1, step:0.05}

    # Number of images written to the training folder so far, per class.
    ct = np.zeros(NUM_CLASSES)
    # Per-class cap on training images; anything beyond goes to validation.
    max_train_per_class = np.floor(train_valid_split * NUM_IMAGES / NUM_CLASSES)

    # NOTE(review): only the existence of the training folder is checked, so
    # an interrupted extraction is not repeated unless `force` is set.
    if not os.path.exists(train_image_folder) or force:
        # Create folders (class folders are named class_0 .. class_9)
        try_makedirs(pickle_folder)
        try_makedirs(train_image_folder)
        try_makedirs(valid_image_folder)
        try_makedirs(test_image_folder)
        for ix in range(NUM_CLASSES):
            try_makedirs(os.path.join(train_image_folder, "class_{}".format(ix)))
            try_makedirs(os.path.join(valid_image_folder, "class_{}".format(ix)))
            try_makedirs(os.path.join(test_image_folder, "class_{}".format(ix)))
        # Extract training data archive (closed even if extraction fails)
        with tarfile.open(
            os.path.join(
                '/gdrive/My Drive/', path_to_CIFAR10, 'cifar-10-python.tar.gz'
            )
        ) as tar:
            tar.extractall(pickle_folder)
        batches_folder = os.path.join(pickle_folder, 'cifar-10-batches-py')
        # Extract training images to folders
        for ix in log_progress(
            range(1, 6), every=1, name="Converting pickle files to images"
        ):
            batch_file = os.path.join(batches_folder, 'data_batch_{}'.format(ix))
            # Read pickle
            tmp = pandas.read_pickle(batch_file)
            # CIFAR format (flat, channel-first rows) to array of HxWxC images
            train_images = tmp['data'] \
              .reshape(-1, IMGSZ[2], IMGSZ[0], IMGSZ[1]) \
              .transpose([0, 2, 3, 1])
            # Save images to png (lossless format)
            for jx, image in enumerate(train_images):
                label = tmp['labels'][jx]
                # Split training set into training and validation sets.
                # Labels index the counter directly: they match the 0-based
                # class_<n> folder names created above.
                if ct[label] >= max_train_per_class:
                    # validation
                    out_folder = valid_image_folder
                else:
                    # training
                    out_folder = train_image_folder
                    ct[label] = 1 + ct[label]
                # Save. The batch number is part of the file name because jx
                # restarts at 0 for every batch; without it, images of later
                # batches would overwrite earlier ones of the same class.
                PIL.Image.fromarray(image).save(
                    os.path.join(
                        out_folder,
                        'class_{}'.format(label),
                        '{}_{}.png'.format(ix, jx)
                    )
                )
            # Remove pickle file since we're done with it
            os.remove(batch_file)

        # Extract test images to folders
        #   Read pickle
        test_file = os.path.join(batches_folder, 'test_batch')
        tmp = pandas.read_pickle(test_file)
        #   CIFAR format to array of images
        test_images = tmp['data']\
          .reshape(-1, IMGSZ[2], IMGSZ[0], IMGSZ[1]) \
          .transpose([0, 2, 3, 1])
        #   Save images to png (lossless format); a single file, so the
        #   running index alone is unique
        for jx, image in enumerate(test_images):
            PIL.Image.fromarray(image).save(
                os.path.join(
                    test_image_folder,
                    'class_{}'.format(tmp['labels'][jx]),
                    '{}.png'.format(jx)
                )
            )
        #   Remove pickle file since we're done with it
        os.remove(test_file)
    trainset = DatasetFolder(
        train_image_folder, load_image, ['png'],
        transform=_transformations
    )
    validset = DatasetFolder(
        valid_image_folder, load_image, ['png'],
        transform=_transformations
    )
    testset = DatasetFolder(
        test_image_folder, load_image, ['png'],
        transform=_transformations
    )

else:
    print('Disabled')

In [None]:
#@title CIFAR-100
from torchvision.datasets import DatasetFolder
import tarfile
import os
import pandas
import PIL.Image
import numpy as np

# Folder (relative to '/gdrive/My Drive/') containing cifar-100-python.tar.gz.
path_to_CIFAR100 = "COMP551_Assignment4/data/" #@param {type:"string"}

if "CIFAR-100" == _dataset:
    print('Loading CIFAR-100... ', end='')

    IMGSZ = (32, 32, 3)
    NUM_IMAGES = 50000
    NUM_CLASSES = 100

    pickle_folder = 'data/CIFAR100/pickle'
    train_image_folder = 'data/CIFAR100/train/images'
    valid_image_folder = 'data/CIFAR100/validation/images'
    test_image_folder = 'data/CIFAR100/test/images'

    # Translate the form choice into the key of the label list in the pickle.
    label_type = "fine (100)" #@param ["fine (100)", "coarse (20)"]
    if label_type == "fine (100)":
        label_type = "fine_labels"
    elif label_type == "coarse (20)":
        label_type = "coarse_labels"
        # The class folders, the per-class counters and the per-class
        # training cap must follow the 20 coarse classes; with 100 the cap
        # would be a fifth of what each coarse class needs.
        NUM_CLASSES = 20

    # Re-run the extraction even if the training image folder already exists.
    # NOTE(review): both label types share the same output folders, so
    # switching label type presumably also requires `force` -- confirm.
    force = False #@param {type:"boolean"}
    # Fraction of the training archive kept for training; the rest becomes
    # the validation set.
    train_valid_split = 0.9 #@param {type:"slider", min:0, max:1, step:0.05}

    # Number of images written to the training folder so far, per class.
    ct = np.zeros(NUM_CLASSES)
    # Per-class cap on training images; anything beyond goes to validation.
    max_train_per_class = np.floor(train_valid_split * NUM_IMAGES / NUM_CLASSES)

    if not os.path.exists(train_image_folder) or force:
        # Create folders (class folders are named class_0 .. class_<N-1>)
        try_makedirs(pickle_folder)
        try_makedirs(train_image_folder)
        try_makedirs(valid_image_folder)
        try_makedirs(test_image_folder)
        for ix in range(NUM_CLASSES):
            try_makedirs(os.path.join(train_image_folder, "class_{}".format(ix)))
            try_makedirs(os.path.join(valid_image_folder, "class_{}".format(ix)))
            try_makedirs(os.path.join(test_image_folder, "class_{}".format(ix)))
        # Extract training data archive (closed even if extraction fails)
        with tarfile.open(
            os.path.join(
                '/gdrive/My Drive/', path_to_CIFAR100, 'cifar-100-python.tar.gz'
            )
        ) as tar:
            tar.extractall(pickle_folder)
        # Extract training images to folders
        # Read pickle
        train_file = os.path.join(pickle_folder, 'cifar-100-python', 'train')
        tmp = pandas.read_pickle(train_file)
        # CIFAR format (flat, channel-first rows) to array of HxWxC images
        train_images = tmp['data'] \
          .reshape(-1, IMGSZ[2], IMGSZ[0], IMGSZ[1]) \
          .transpose([0, 2, 3, 1])
        # Save images to png (lossless format)
        for jx, image in enumerate(train_images):
            label = tmp[label_type][jx]
            # Split training set into training and validation sets.
            # Labels index the counter directly: they match the 0-based
            # class_<n> folder names created above.
            if ct[label] >= max_train_per_class:
                # validation
                out_folder = valid_image_folder
            else:
                # training
                out_folder = train_image_folder
                ct[label] = 1 + ct[label]
            # Save
            PIL.Image.fromarray(image).save(
                os.path.join(
                    out_folder,
                    'class_{}'.format(label),
                    '{}.png'.format(jx)
                )
            )
        # Remove pickle file since we're done with it
        os.remove(train_file)

        # Extract test images to folders
        #   Read pickle
        test_file = os.path.join(pickle_folder, 'cifar-100-python', 'test')
        tmp = pandas.read_pickle(test_file)
        #   CIFAR format to array of images
        test_images = tmp['data']\
          .reshape(-1, IMGSZ[2], IMGSZ[0], IMGSZ[1]) \
          .transpose([0, 2, 3, 1])
        #   Save images to png (lossless format)
        for jx, image in enumerate(test_images):
            PIL.Image.fromarray(image).save(
                os.path.join(
                    test_image_folder,
                    'class_{}'.format(tmp[label_type][jx]),
                    '{}.png'.format(jx)
                )
            )
        #   Remove pickle file since we're done with it
        os.remove(test_file)

    trainset = DatasetFolder(
        train_image_folder, load_image, ['png'],
        transform=_transformations
    )
    validset = DatasetFolder(
        valid_image_folder, load_image, ['png'],
        transform=_transformations
    )
    testset = DatasetFolder(
        test_image_folder, load_image, ['png'],
        transform=_transformations
    )

    print('Done.')

else:
    print('Disabled')

In [None]:
#@title CINIC-10
from torchvision.datasets import ImageFolder
import tarfile
import os

# Path of the CINIC-10 tar archive, relative to '/gdrive/My Drive/'.
# NOTE(review): the default value (including its trailing space) does not
# look like an archive path -- confirm before running.
path_to_CINIC = "NordVPN216856 " #@param {type:"string"}
# Re-extract the archive even if the image folder already exists.
force = False #@param {type:"boolean"}
image_folder = 'data/CINIC-10'

if "CINIC-10" == _dataset:
    print('Loading CINIC-10... ', end='')

    if not os.path.exists(image_folder) or force:
        # Open the archive before creating the folder, so that a missing
        # archive does not leave an empty folder that would make the next
        # run skip the extraction. The archive is closed even on failure.
        with tarfile.open(os.path.join('/gdrive/My Drive/', path_to_CINIC)) as tar:
            try_makedirs(image_folder)
            tar.extractall(image_folder)

    # The archive is expected to provide train/, valid/ and test/ folders.
    trainset = ImageFolder(
        os.path.join(image_folder, 'train'),
        loader=load_image,
        transform=_transformations
    )
    validset = ImageFolder(
        os.path.join(image_folder, 'valid'),
        loader=load_image,
        transform=_transformations
    )
    testset = ImageFolder(
        os.path.join(image_folder, 'test'),
        loader=load_image,
        transform=_transformations
    )

    print('done')

else:
    print('Disabled')

In [None]:
#@title Dataloader and model initialization
import torch.nn.functional as F
import torch

def train(model, model_name, device, train_loader, optimizer, epoch):
    """Run one training epoch with negative log-likelihood loss.

    Args:
        model: Network returning log-probabilities.
        model_name: Name printed in braces in front of each progress line.
        device: Device the batches are moved to before the forward pass.
        train_loader: Loader yielding ``(data, target)`` batches.
        optimizer: Optimizer stepped once per batch.
        epoch: Epoch number, used for display only.
    """
    progress_line = '{{{}}} Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'

    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data = data.to(device)
        target = target.to(device)

        optimizer.zero_grad()
        loss = F.nll_loss(model(data), target)
        loss.backward()
        optimizer.step()

        # Report only every `_log_interval`-th batch.
        if batch_idx % _log_interval != 0:
            continue
        print(progress_line.format(
            model_name,
            epoch,
            batch_idx * len(data),
            len(train_loader.dataset),
            100. * batch_idx / len(train_loader),
            loss.item()
        ))


def validate(model, model_name, device, validation_loader):
    model.eval()
    validation_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in validation_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)

            # sum up batch loss
            validation_loss += F.nll_loss(output, target, reduction='sum').item()
            # get the index of the max log-probability
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    validation_loss /= len(validation_loader.dataset)

    print(
        '\n{{{}}} Validation set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'
        .format(
            model_name, validation_loss, correct,
            len(validation_loader.dataset),
            100. * correct / len(validation_loader.dataset)
        )
    )

    return validation_loss, 100. * correct / len(validation_loader.dataset)


def test(model, model_name, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)

            # sum up batch loss
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            # get the index of the max log-probability
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print(
        '\n{{{}}} Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'
        .format(
            model_name, test_loss, correct, len(test_loader.dataset),
            100. * correct / len(test_loader.dataset)
        )
    )

    return test_loss, 100. * correct / len(test_loader.dataset)


def train_nonan(model, model_name, device, train_loader, optimizer, epoch):
    """Run one training epoch and report whether the loss stayed finite.

    Same training loop as ``train``; in addition the loss of the last batch
    is checked for NaN so that the caller can stop training a diverged model.

    Args:
        model: Network returning log-probabilities.
        model_name: Name printed in braces in front of each progress line.
        device: Device the batches are moved to before the forward pass.
        train_loader: Loader yielding ``(data, target)`` batches.
        optimizer: Optimizer stepped once per batch.
        epoch: Epoch number, used for display only.

    Returns:
        ``False`` if the loss of the last batch is NaN, ``True`` otherwise
        (including when the loader yields no batch at all).
    """
    model.train()
    # Stays None when the loader is empty, so the check below cannot hit an
    # unbound variable.
    loss = None
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)

        loss.backward()
        optimizer.step()
        if batch_idx % _log_interval == 0:
            print('{{{}}} Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    model_name, epoch, batch_idx * len(data),
                    len(train_loader.dataset),
                    100. * batch_idx / len(train_loader), loss.item()
                )
            )

    # The epoch is always run to the end (rather than aborted at the first
    # NaN) so that the loader is reset for the next model/epoch.
    if loss is None:
        return True
    # Checked with torch directly: no numpy round trip and no reliance on a
    # global `np` imported by another cell.
    return not bool(torch.isnan(loss))


# Seed PyTorch's random number generator so that runs are repeatable.
torch.manual_seed(1)



In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Sanity check: print the sizes of the three sets and the fraction of the
# train+validation images that ended up in the training set.
print(
    len(trainset),
    len(validset),
    len(trainset) / (len(trainset) + len(validset)),
    len(testset)
)

# Show the first image of each set side by side. Each sample is an
# (image tensor, label) pair; the axes are reordered with transpose([1, 2, 0])
# before being handed to imshow.
plt.subplot(1, 3, 1)
plt.imshow(trainset[0][0].numpy().transpose([1, 2, 0]))
plt.title('training set')

plt.subplot(1, 3, 2)
plt.imshow(validset[0][0].numpy().transpose([1, 2, 0]))
plt.title('validation set')

plt.subplot(1, 3, 3)
plt.imshow(testset[0][0].numpy().transpose([1, 2, 0]));
plt.title('test set')

#Training and results

In [None]:
#@title Learning parameter grid search
from torch.utils.data import DataLoader
import torch.optim as optim
import torch
import numpy as np
from google.colab import widgets
import matplotlib.pyplot as plt
import joblib
import datetime

# Imported here so this cell does not rely on other cells for them.
import itertools
import os

# Folder (relative to '/gdrive/My Drive/') where results are saved.
drive_savefolder = "COMP551_Assignment4/" #@param {type:"string"}

# Two stacked output cells: (0, 0) shows the accuracy plot, (1, 0) the
# training/validation log.
grid = widgets.Grid(2, 1)

# Alternative search spaces used for other runs, kept for reference.
# BatchSize = [64, 512, 1024]
# LearningRate = [0.05, 0.01, 0.001]
# Momentum = [0.9, 0.5, 0.1]
# WeightDecay = [0, 0.0002]

BatchSize = [64]
LearningRate = [0.05, 0.01, 0.001]
Momentum = [0.9, 0.5, 0.1]
WeightDecay = [0]

# BatchSize = [1024]
# LearningRate = [0.01, 0.001]
# Momentum = [0.9, 0.5, 0.1]
# WeightDecay = [0, 0.0002]

# BatchSize = [512]
# LearningRate = [0.001]
# Momentum = [0.9, 0.1]
# WeightDecay = [0.0002]


space = [BatchSize, LearningRate, Momentum, WeightDecay]
# One entry per hyper-parameter combination, keyed by the tuple
# (batch size, learning rate, momentum, weight decay); the value is replaced
# by the test accuracies once the combination has been trained.
search = {combination: 0 for combination in itertools.product(*space)}

# Iterate over a snapshot of the keys, since the values are updated in the loop.
for params in list(search):
    alex_good = True
    squeeze_good = True

    # Parameters
    _batch_size, _learning_rate, _momentum, _weight_decay = params

    # Initialize training, validation, and test loaders
    train_loader = DataLoader(trainset, batch_size=_batch_size, shuffle=True)
    valid_loader = DataLoader(validset, batch_size=_batch_size, shuffle=True)
    test_loader = DataLoader(testset, batch_size=_batch_size, shuffle=True)

    # Initialize AlexNet
    if _alex_enabled:
        alex = AlexNet().to("cuda")
        alex_opt = optim.SGD(
            alex.parameters(), lr=_learning_rate, momentum=_momentum,
            weight_decay=_weight_decay
        )
    # Initialize SqueezeNet
    squeeze = SqueezeNet().to("cuda")
    squeeze_opt = optim.SGD(
        squeeze.parameters(), lr=_learning_rate, momentum=_momentum,
        weight_decay=_weight_decay
    )

    # Initialize result arrays: one (loss, accuracy) row per epoch plus a
    # final row for the test set. When AlexNet is disabled its array is
    # created once and stays filled with NaN.
    if _alex_enabled or 'alex_results' not in locals():
        alex_results = np.full((1+_nb_epochs, 2), np.nan)
    squeeze_results = np.full((1+_nb_epochs, 2), np.nan)

    # Train
    for epoch in range(1, _nb_epochs + 1):
        # Display training output
        with grid.output_to(1, 0):
            if _alex_enabled:
                # No point in training AlexNet if the weights have exploded to NaN values
                if alex_good:
                    alex_good = train_nonan(alex, "AlexNet", "cuda", train_loader, alex_opt, epoch)
                # No point in validating AlexNet either (will inevitably give ~1/num_classes accuracy)
                if alex_good:
                    alex_results[epoch-1, :] = validate(alex, "AlexNet", "cuda", valid_loader)

            if squeeze_good:
                squeeze_good = train_nonan(squeeze, "SqueezeNet", "cuda", train_loader, squeeze_opt, epoch)
            if squeeze_good:
                squeeze_results[epoch-1, :2] = validate(squeeze, "SqueezeNet", "cuda", valid_loader)

        plot_results_to_grid(grid, (0, 0), alex_results, "AlexNet", "r", squeeze_results, "SqueezeNet", "b")

        # Stop early once no model that is being trained is still healthy.
        if not (_alex_enabled and alex_good) and not squeeze_good:
            print('\n=================\nInterrupting training for set of parameters (NaN)\n=================\n\n')
            break


    # The last row of each result array holds the test-set metrics.
    if _alex_enabled and alex_good:
        alex_results[-1, :2] = test(alex, "AlexNet", "cuda", test_loader)
    if squeeze_good:
        squeeze_results[-1, :2] = test(squeeze, "SqueezeNet", "cuda", test_loader)

    params_to_save = \
    {
        'squeeze_pretrained': _squeeze_pretrained,
        'nb_epochs': _nb_epochs,
        'dataset': _dataset,
        'batch_size': _batch_size,
        'learning_rate': _learning_rate,
        'momentum': _momentum,
        'weight_decay': _weight_decay
    }

    # Save this combination's results straight away, in a timestamped file,
    # so that they survive an interrupted search.
    joblib.dump(
        [params_to_save, alex_results, squeeze_results],
        os.path.join(
            '/gdrive/My Drive/', drive_savefolder, 'alex_squeeze_grid_{}.joblib'.format(
                datetime.datetime.now().strftime("%Y-%m-%d_%H%M%S")
            )
        )
    )

    # Test accuracies (column 1 of the last row) of both models.
    search[params] = [alex_results[-1,1], squeeze_results[-1,1]]

# Save the summary of the whole search.
joblib.dump(search, os.path.join('/gdrive/My Drive/', drive_savefolder, 'search_{}.joblib'.format(
    datetime.datetime.now().strftime("%Y-%m-%d_%H%M%S")
)))