In [1]:
# import libraries
import os

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data.sampler import SubsetRandomSampler
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from utils import device, get_num_correct
from vgg16modified import Network

In [2]:
# Preprocessing pipelines, one per split.
# Both splits finish with the same tensor conversion + per-channel normalisation;
# only the training split gets random augmentation in front of it.
normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)
tensor_steps = [transforms.ToTensor(), normalize]

# random augmentations, applied to the image before it is converted to a tensor
train_augmentations = [
    transforms.ColorJitter(brightness=0.25, saturation=0.1),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(10),
]

data_transforms = {
    'train': transforms.Compose(train_augmentations + tensor_steps),
    'test': transforms.Compose(list(tensor_steps)),  # copy: the pipelines do not share a list
}

In [3]:
# CIFAR-10 train/test datasets; each split gets its own transform pipeline.
# Arguments common to both splits are kept in one place.
cifar10_options = {'root': './data/', 'download': True}

train_set = torchvision.datasets.CIFAR10(
    train=True,
    transform=data_transforms['train'],
    **cifar10_options
)
test_set = torchvision.datasets.CIFAR10(
    train=False,
    transform=data_transforms['test'],
    **cifar10_options
)

Files already downloaded and verified
Files already downloaded and verified


In [4]:
batch_size = 256
valid_size = 0.5  # fraction (0-1) of test_set that is held out as the validation split
split_seed = 42   # fixed seed: the validation/test partition must be identical on every run

# Partition the indices of test_set (not the training set) into a validation part and
# a test part. A dedicated, seeded RandomState is used so that
#   * the partition is reproducible across kernel restarts -- otherwise samples that
#     were used for checkpoint selection in one run could end up in the "test" part
#     of the next run, and
#   * the global numpy random state is left untouched.
num_test = len(test_set)
indices = list(range(num_test))
np.random.RandomState(split_seed).shuffle(indices)
split = int(np.floor(valid_size * num_test))
test_idx, valid_idx = indices[split:], indices[:split]

# samplers that draw only from their own index subset (in random order)
valid_sampler = SubsetRandomSampler(valid_idx)
test_sampler = SubsetRandomSampler(test_idx)

# prepare the data loaders; validation and test both read from test_set, restricted
# to disjoint index subsets by their samplers
train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=1)
valid_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, sampler=valid_sampler, num_workers=1)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, sampler=test_sampler, num_workers=1)

In [5]:
# Load torchvision's VGG-16 with pretrained weights (pretrained=True); it is the
# backbone that gets passed to `Network` in a later cell.
vgg16 = torchvision.models.vgg16(pretrained=True)
vgg16  # bare last expression: show the architecture as the cell output

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1

In [6]:
# replace the vgg16 classifier
# `Network` comes from the local vgg16modified module and is built from the pretrained
# vgg16 object; the printed repr shows it exposing the VGG `features` stack. The
# replacement classifier is defined in vgg16modified and is not visible in this notebook.
model = Network(vgg16)
model  # bare last expression: show the resulting architecture as the cell output

Network(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilati

In [7]:
# Transfer learning: the transferred weights of the early feature layers are kept fixed.
# Children 0-18 of model.features are frozen (the "first 8 layers of vgg16" -- presumably
# the first 8 conv layers; the printed repr is truncated, TODO confirm). Everything from
# index 19 onwards is left trainable.
num_frozen = 19
for frozen_param in model.features[:num_frozen].parameters():
    frozen_param.requires_grad_(False)

In [8]:
def evaluate(model, loader, criterion):
    """Run one gradient-free pass of `model` over `loader`.

    The model is switched to evaluation mode for the pass.

    Args:
        model: network to evaluate (already on `device`).
        loader: DataLoader yielding (images, labels) batches.
        criterion: loss returning the batch mean (as nn.CrossEntropyLoss does by default).

    Returns:
        (avg_loss, accuracy), both averaged over the samples the loader produced.
    """
    model.eval()
    total_loss, total_correct, num_seen = 0.0, 0, 0
    with torch.no_grad():  # no gradients are needed for evaluation
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            preds = model(images)
            loss = criterion(preds, labels)

            # criterion yields the batch mean, so weight it by the batch size
            total_loss += loss.item() * labels.size(0)
            total_correct += get_num_correct(preds, labels)
            num_seen += labels.size(0)
    num_seen = max(num_seen, 1)  # guard against an empty loader
    return total_loss / num_seen, total_correct / num_seen


model.to(device)

criterion = nn.CrossEntropyLoss()  # loss function (categorical cross-entropy)
optimizer = optim.SGD(
    [      # parameter groups which need optimization
        {'params': model.features[19:].parameters(), 'lr': 0.001},  # unfrozen feature layers: smaller lr
        {'params': model.classifier.parameters()}  # classifier: uses the default lr given below
    ], lr=0.01, momentum=0.9, weight_decay=1e-3)

# lr scheduler: multiplies the learning rates by 1/3 once the validation loss has
# stopped improving for more than `patience` epochs
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=1/3, patience=5, verbose=True)

comment = '-transferlr_vgg16'  # will be used for naming the run
tb = SummaryWriter(comment=comment)

# make sure the checkpoint directory exists, so the first save cannot fail on a fresh checkout
checkpoint_dir = './models'
os.makedirs(checkpoint_dir, exist_ok=True)
checkpoint_path = f'{checkpoint_dir}/model{comment}.pth'

# initialize tracker for minimum validation loss
valid_loss_min = np.inf  # set initial minimum to infinity (np.Inf was removed in NumPy 2.0)
num_epochs = 30  # number of epochs used for training

for epoch in range(num_epochs):
    # running totals for this epoch: summed loss, correct predictions, samples processed
    train_loss, train_correct, train_seen = 0.0, 0, 0
    ###################
    # train the model #
    ###################
    train_loop = tqdm(train_loader)
    train_loop.set_description(f'Epoch [{epoch+1:2d}/{num_epochs}]')
    model.train()  # set the model to train mode
    for batch in train_loop:
        images, labels = batch[0].to(device), batch[1].to(device)  # load the batch to the available device (cpu/gpu)
        optimizer.zero_grad()  # clear accumulated gradients from the previous pass
        preds = model(images)  # forward pass
        loss = criterion(preds, labels)  # calculate loss
        loss.backward()  # backward pass
        optimizer.step()  # perform a single optimization step

        train_loss += loss.item() * labels.size(0)  # update the running loss
        train_correct += get_num_correct(preds, labels)  # update running num correct
        train_seen += labels.size(0)

        # accuracy over the samples processed so far (not over the whole dataset),
        # so the displayed value is meaningful throughout the epoch
        train_loop.set_postfix(loss=loss.item(), acc=train_correct/train_seen)

    # per-sample averages for the epoch
    train_seen = max(train_seen, 1)  # guard against an empty loader
    train_loss = train_loss/train_seen
    train_acc = train_correct/train_seen

    # add train loss and train accuracy for the current epoch to tensorboard
    # (averages, so the curves are comparable with the validation/test curves)
    tb.add_scalar('Train Loss', train_loss, epoch)
    tb.add_scalar('Train Accuracy', train_acc, epoch)

    ######################
    # validate the model #
    ######################
    valid_loss, valid_acc = evaluate(model, valid_loader, criterion)

    # add validation loss and validation accuracy for the current epoch to tensorboard
    tb.add_scalar('Validation Loss', valid_loss, epoch)
    tb.add_scalar('Validation Accuracy', valid_acc, epoch)

    # print training/validation statistics (average loss over the epoch)
    train_loop.write(f'\t\tAvg training loss: {train_loss:.6f}\tAvg validation loss: {valid_loss:.6f}')
    scheduler.step(valid_loss)

    # save model if validation loss has decreased
    if valid_loss <= valid_loss_min:
        train_loop.write(f'\t\tvalid_loss decreased ({valid_loss_min:.6f} --> {valid_loss:.6f})  saving model...')
        torch.save(model.state_dict(), checkpoint_path)
        valid_loss_min = valid_loss

    ##################
    # test the model #
    ##################
    # NOTE: this evaluates the current weights, which are not necessarily the saved best checkpoint
    test_loss, test_acc = evaluate(model, test_loader, criterion)

    # add test loss and test accuracy for the current epoch to tensorboard
    tb.add_scalar('Test Loss', test_loss, epoch)
    tb.add_scalar('Test Accuracy', test_acc, epoch)

tb.close()  # flush pending events to disk and release the event file


Epoch [ 1/30]: 100%|██████████| 196/196 [00:42<00:00,  4.58it/s, acc=0.555, loss=0.899]


		Avg training loss: 1.258526	Avg validation loss: 0.798657
		valid_loss decreased (inf --> 0.798657)  saving model...


Epoch [ 2/30]: 100%|██████████| 196/196 [00:42<00:00,  4.61it/s, acc=0.716, loss=0.674]


		Avg training loss: 0.808466	Avg validation loss: 0.695511
		valid_loss decreased (0.798657 --> 0.695511)  saving model...


Epoch [ 3/30]: 100%|██████████| 196/196 [00:42<00:00,  4.60it/s, acc=0.749, loss=0.709]


		Avg training loss: 0.716494	Avg validation loss: 0.607289
		valid_loss decreased (0.695511 --> 0.607289)  saving model...


Epoch [ 4/30]: 100%|██████████| 196/196 [00:41<00:00,  4.68it/s, acc=0.767, loss=0.586]


		Avg training loss: 0.663401	Avg validation loss: 0.590412
		valid_loss decreased (0.607289 --> 0.590412)  saving model...


Epoch [ 5/30]: 100%|██████████| 196/196 [00:41<00:00,  4.69it/s, acc=0.782, loss=0.676]


		Avg training loss: 0.621336	Avg validation loss: 0.545776
		valid_loss decreased (0.590412 --> 0.545776)  saving model...


Epoch [ 6/30]: 100%|██████████| 196/196 [00:43<00:00,  4.54it/s, acc=0.796, loss=0.637]


		Avg training loss: 0.581307	Avg validation loss: 0.540139
		valid_loss decreased (0.545776 --> 0.540139)  saving model...


Epoch [ 7/30]: 100%|██████████| 196/196 [00:43<00:00,  4.55it/s, acc=0.802, loss=0.393]


		Avg training loss: 0.555703	Avg validation loss: 0.506904
		valid_loss decreased (0.540139 --> 0.506904)  saving model...


Epoch [ 8/30]: 100%|██████████| 196/196 [00:42<00:00,  4.62it/s, acc=0.81, loss=0.413]


		Avg training loss: 0.540095	Avg validation loss: 0.504078
		valid_loss decreased (0.506904 --> 0.504078)  saving model...


Epoch [ 9/30]: 100%|██████████| 196/196 [00:42<00:00,  4.62it/s, acc=0.819, loss=0.516]


		Avg training loss: 0.511801	Avg validation loss: 0.499604
		valid_loss decreased (0.504078 --> 0.499604)  saving model...


Epoch [10/30]: 100%|██████████| 196/196 [00:42<00:00,  4.65it/s, acc=0.823, loss=0.419]


		Avg training loss: 0.498117	Avg validation loss: 0.485179
		valid_loss decreased (0.499604 --> 0.485179)  saving model...


Epoch [11/30]: 100%|██████████| 196/196 [00:42<00:00,  4.64it/s, acc=0.831, loss=0.484]


		Avg training loss: 0.475245	Avg validation loss: 0.477229
		valid_loss decreased (0.485179 --> 0.477229)  saving model...


Epoch [12/30]: 100%|██████████| 196/196 [00:42<00:00,  4.60it/s, acc=0.835, loss=0.579]


		Avg training loss: 0.464521	Avg validation loss: 0.477346


Epoch [13/30]: 100%|██████████| 196/196 [00:42<00:00,  4.60it/s, acc=0.841, loss=0.462]


		Avg training loss: 0.446889	Avg validation loss: 0.463188
		valid_loss decreased (0.477229 --> 0.463188)  saving model...


Epoch [14/30]: 100%|██████████| 196/196 [00:42<00:00,  4.59it/s, acc=0.843, loss=0.351]


		Avg training loss: 0.445140	Avg validation loss: 0.451549
		valid_loss decreased (0.463188 --> 0.451549)  saving model...


Epoch [15/30]: 100%|██████████| 196/196 [00:42<00:00,  4.61it/s, acc=0.849, loss=0.343]


		Avg training loss: 0.425136	Avg validation loss: 0.450946
		valid_loss decreased (0.451549 --> 0.450946)  saving model...


Epoch [16/30]: 100%|██████████| 196/196 [00:42<00:00,  4.58it/s, acc=0.855, loss=0.377]


		Avg training loss: 0.411279	Avg validation loss: 0.452803


Epoch [17/30]: 100%|██████████| 196/196 [00:42<00:00,  4.58it/s, acc=0.857, loss=0.362]


		Avg training loss: 0.399843	Avg validation loss: 0.453718


Epoch [18/30]: 100%|██████████| 196/196 [00:42<00:00,  4.58it/s, acc=0.86, loss=0.502]


		Avg training loss: 0.392495	Avg validation loss: 0.435896
		valid_loss decreased (0.450946 --> 0.435896)  saving model...


Epoch [19/30]: 100%|██████████| 196/196 [00:42<00:00,  4.58it/s, acc=0.865, loss=0.334]


		Avg training loss: 0.378240	Avg validation loss: 0.444111


Epoch [20/30]: 100%|██████████| 196/196 [00:42<00:00,  4.63it/s, acc=0.867, loss=0.592]


		Avg training loss: 0.373910	Avg validation loss: 0.432857
		valid_loss decreased (0.435896 --> 0.432857)  saving model...


Epoch [21/30]: 100%|██████████| 196/196 [00:42<00:00,  4.65it/s, acc=0.869, loss=0.319]


		Avg training loss: 0.363137	Avg validation loss: 0.427644
		valid_loss decreased (0.432857 --> 0.427644)  saving model...


Epoch [22/30]: 100%|██████████| 196/196 [00:42<00:00,  4.64it/s, acc=0.875, loss=0.406]


		Avg training loss: 0.350984	Avg validation loss: 0.448720


Epoch [23/30]: 100%|██████████| 196/196 [00:42<00:00,  4.63it/s, acc=0.875, loss=0.373]


		Avg training loss: 0.347980	Avg validation loss: 0.439049


Epoch [24/30]: 100%|██████████| 196/196 [00:42<00:00,  4.66it/s, acc=0.883, loss=0.419]


		Avg training loss: 0.326903	Avg validation loss: 0.424339
		valid_loss decreased (0.427644 --> 0.424339)  saving model...


Epoch [25/30]: 100%|██████████| 196/196 [00:42<00:00,  4.62it/s, acc=0.885, loss=0.2]


		Avg training loss: 0.319745	Avg validation loss: 0.440760


Epoch [26/30]: 100%|██████████| 196/196 [00:42<00:00,  4.61it/s, acc=0.886, loss=0.363]


		Avg training loss: 0.322240	Avg validation loss: 0.426965


Epoch [27/30]: 100%|██████████| 196/196 [00:42<00:00,  4.59it/s, acc=0.889, loss=0.302]


		Avg training loss: 0.307315	Avg validation loss: 0.415051
		valid_loss decreased (0.424339 --> 0.415051)  saving model...


Epoch [28/30]: 100%|██████████| 196/196 [00:43<00:00,  4.56it/s, acc=0.893, loss=0.279]


		Avg training loss: 0.298390	Avg validation loss: 0.443596


Epoch [29/30]: 100%|██████████| 196/196 [00:42<00:00,  4.57it/s, acc=0.895, loss=0.263]


		Avg training loss: 0.294750	Avg validation loss: 0.443056


Epoch [30/30]: 100%|██████████| 196/196 [00:43<00:00,  4.55it/s, acc=0.895, loss=0.193]


		Avg training loss: 0.290253	Avg validation loss: 0.441863
