In [6]:
import torch
import torch.nn as nn
import torch.optim as optim

import torchvision
import torchvision.transforms as transforms
import numpy as np

from torch.utils.data.sampler import SubsetRandomSampler
from torch.utils.tensorboard import SummaryWriter
from utils import device, get_num_correct, RunBuilder
from network import Network

In [7]:
train_set = torchvision.datasets.MNIST(
    root='./data/',
    train=True,
    download=True,
    transform=transforms.ToTensor()
)
test_set = torchvision.datasets.MNIST(
    root='./data/',
    train=False,
    download=True,
    transform=transforms.ToTensor()
)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=1000)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz




HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw
Processing...
Done!


  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


In [10]:
valid_size = 0.2

# obtain training indices that will be used for validation
num_train = len(train_set)
indices = list(range(num_train))
np.random.shuffle(indices)
split = int(np.floor(valid_size * num_train))
train_idx, valid_idx = indices[split:], indices[:split]

# define samplers for obtaining training and validation batches
train_sampler = SubsetRandomSampler(train_idx)
valid_sampler = SubsetRandomSampler(valid_idx)

In [11]:
from collections import OrderedDict

params = OrderedDict(
    lr = [0.01, 0.003, 0.001, 0.0003],
    batch_size = [256, 512]
)

In [30]:
criterion = nn.CrossEntropyLoss()

for run in RunBuilder.get_runs(params):
    print(f'{run}')
    network = Network().to(device)
    train_loader = torch.utils.data.DataLoader(
        train_set,
        batch_size=run.batch_size,
        sampler=train_sampler,
        num_workers=1
        )
    valid_loader = torch.utils.data.DataLoader(
        train_set,
        batch_size=run.batch_size,
        sampler=valid_sampler,
        num_workers=1
        )
    optimizer = optim.Adam(network.parameters(), lr=run.lr)

    comment = f'-val-{run}'
    tb = SummaryWriter(comment=comment)

    valid_loss_min = np.Inf

    for epoch in range(10):

        train_loss = 0
        train_correct = 0

        ###################
        # train the model #
        ###################
        network.train()
        for batch in train_loader:
            images, labels = batch[0].to(device), batch[1].to(device)
            preds = network(images)
            loss = criterion(preds, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss += loss.item() * run.batch_size
            train_correct += get_num_correct(preds, labels)

        tb.add_scalar('Train Loss', train_loss, epoch)
        tb.add_scalar('Train Accuracy', train_correct / len(train_loader.sampler), epoch)

        network.eval()
        with torch.no_grad():

            valid_loss = 0
            valid_correct = 0

            ######################
            # validate the model #
            ######################
            for batch in valid_loader:
                images, labels = batch[0].to(device), batch[1].to(device)
                preds = network(images)
                loss = criterion(preds, labels)

                valid_loss += loss.item() * run.batch_size
                valid_correct += get_num_correct(preds, labels)

            tb.add_scalar('Valid Loss', valid_loss, epoch)
            tb.add_scalar('Valid Accuracy', valid_correct / len(valid_loader.sampler), epoch)

            # print training/validation statistics
            # calculate average loss over an epoch
            train_loss = train_loss/len(train_loader.sampler)
            valid_loss = valid_loss/len(valid_loader.sampler)
            print(f'Epoch {epoch+1}: Training Loss: {train_loss:.6f} Validation Loss: {valid_loss:.6f}')

            # save model if validation loss has decreased
            if valid_loss <= valid_loss_min:
                print(f'\t valid_loss decreased ({valid_loss_min:.6f} --> {valid_loss:.6f})  saving model...')
                torch.save(network.state_dict(), f'./models/with_validation/model-{run}.ckpt')
                valid_loss_min = valid_loss

            # load the model with least validation loss
            network.load_state_dict(
                torch.load(
                    f'models/with_validation/model-{run}.ckpt',
                    map_location=device
                    )
                )

            test_loss = 0
            test_correct = 0

            ##################
            # test the model #
            ##################
            for batch in test_loader:
                images, labels = batch[0].to(device), batch[1].to(device)
                preds = network(images)
                loss = criterion(preds, labels)

                test_loss += loss.item() * images.size(0)
                test_correct += get_num_correct(preds, labels)

            tb.add_scalar('Test Loss', test_loss, epoch)
            tb.add_scalar('Test Accuracy', test_correct / len(test_set), epoch)


        for name, weight in network.named_parameters():
            tb.add_histogram(name, weight, epoch)
            tb.add_histogram(f'{name}.grad', weight.grad, epoch)

run(lr=0.01, batch_size=256)
Epoch 1: Training Loss: 0.520361 Validation Loss: 0.076252
	 valid_loss decreased (inf --> 0.076252)  saving model...
Epoch 2: Training Loss: 0.090637 Validation Loss: 0.052912
	 valid_loss decreased (0.076252 --> 0.052912)  saving model...
Epoch 3: Training Loss: 0.064374 Validation Loss: 0.054400
Epoch 4: Training Loss: 0.067970 Validation Loss: 0.052701
	 valid_loss decreased (0.052912 --> 0.052701)  saving model...
Epoch 5: Training Loss: 0.052571 Validation Loss: 0.036333
	 valid_loss decreased (0.052701 --> 0.036333)  saving model...
Epoch 6: Training Loss: 0.039477 Validation Loss: 0.037052
Epoch 7: Training Loss: 0.045599 Validation Loss: 0.035199
	 valid_loss decreased (0.036333 --> 0.035199)  saving model...
Epoch 8: Training Loss: 0.035710 Validation Loss: 0.033657
	 valid_loss decreased (0.035199 --> 0.033657)  saving model...
Epoch 9: Training Loss: 0.031703 Validation Loss: 0.037189
Epoch 10: Training Loss: 0.031481 Validation Loss: 0.039418
r