In [1]:
# import libraries
import os

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data.sampler import SubsetRandomSampler
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from utils import device, get_num_correct
from vgg16modified import Network

In [2]:
# declare the transforms
def _to_normalized_tensor():
    """Return fresh ToTensor + Normalize steps, the common tail of both pipelines.

    ToTensor produces tensors with values in [0, 1]; normalizing every channel
    with mean 0.5 and std 0.5 then rescales them to [-1, 1].
    """
    return [
        transforms.ToTensor(),
        transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5)),
    ]


# random augmentations, applied to training images only
_train_augmentations = [
    transforms.ColorJitter(brightness=0.25, saturation=0.1),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(10),
]

# 'train' = augmentations followed by tensor conversion + normalization;
# 'test'  = tensor conversion + normalization only (no randomness)
data_transforms = dict(
    train=transforms.Compose(_train_augmentations + _to_normalized_tensor()),
    test=transforms.Compose(_to_normalized_tensor()),
)

In [3]:
# build the CIFAR-10 training and test datasets (downloaded into ./data/ when
# missing), each paired with its matching pipeline from data_transforms;
# the training set is constructed first, then the test set
train_set, test_set = (
    torchvision.datasets.CIFAR10(
        root='./data/',
        train=(split_name == 'train'),
        download=True,
        transform=data_transforms[split_name],
    )
    for split_name in ('train', 'test')
)

Files already downloaded and verified
Files already downloaded and verified


In [4]:
batch_size = 256
valid_size = 0.5  # fraction (0-1) of test_set that is held out as the validation split
split_seed = 0    # fixed seed: the validation/test partition must be identical on every run

# Shuffle the *test-set* indices and cut them into a validation part and a test part.
# A dedicated seeded generator is used (instead of the unseeded global NumPy RNG) so that
# a checkpoint selected on the validation half is never later evaluated on a "test" half
# that contains the same images because the kernel was restarted.
num_test = len(test_set)
indices = list(range(num_test))
np.random.default_rng(split_seed).shuffle(indices)
split = int(np.floor(valid_size * num_test))
test_idx, valid_idx = indices[split:], indices[:split]

# samplers that restrict a loader over test_set to one of the two disjoint index subsets
valid_sampler = SubsetRandomSampler(valid_idx)
test_sampler = SubsetRandomSampler(test_idx)

# prepare the data loaders:
#   train_loader iterates the whole (shuffled) training set,
#   valid_loader / test_loader both read from test_set, through their respective samplers
train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=1)
valid_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, sampler=valid_sampler, num_workers=1)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, sampler=test_sampler, num_workers=1)

In [5]:
# load torchvision's VGG16 architecture with its pretrained weights
# NOTE(review): newer torchvision releases are understood to prefer a `weights=` argument
# over `pretrained=` — confirm against the installed version before changing this call
vgg16 = torchvision.models.vgg16(pretrained=True)
vgg16  # bare expression: displays the module structure as the cell output

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1

In [6]:
# replace the vgg16 classifier
# Network comes from the project module vgg16modified and is built from the pretrained
# vgg16 instance; its source is not shown here, so the classifier replacement described
# above is presumably done inside Network — verify in vgg16modified.py
model = Network(vgg16)
model  # bare expression: displays the module structure as the cell output

Network(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilati

In [7]:
# transfer learning: the first 19 modules of model.features (indices 0-18) keep
# their transferred weights; turning off requires_grad on their parameters
# excludes them from gradient computation so they are not trained
for frozen_module in model.features[:19]:
    for weight in frozen_module.parameters():
        weight.requires_grad_(False)

In [8]:
model.to(device)

criterion = nn.CrossEntropyLoss()  # loss function (categorical cross-entropy)
optimizer = optim.SGD(
    [      # parameters which need optimization
        {'params':model.features[19:].parameters(), 'lr':0.03},  # group 0: feature modules from index 19 on, own lr
        {'params':model.classifier.parameters()}                 # group 1: classifier, uses the default lr below
    ], lr=0.1, momentum=0.9, weight_decay=1e-3)

# lr scheduler: driven by the average validation loss passed to scheduler.step() each epoch
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=1/3, patience=5, verbose=True)

comment = '-transferlr_vgg16'  # will be used for naming the run
tb = SummaryWriter(comment=comment)

# make sure the checkpoint directory exists up front, so the first save cannot fail
# after a full epoch of training has already been spent
checkpoint_path = f'./models/model{comment}.pth'
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

# initialize tracker for minimum validation loss
valid_loss_min = np.inf  # set initial minimum to infinity (np.Inf no longer exists in NumPy 2.x)
num_epochs = 30  # number of epochs used for training


def _evaluate(loader):
    """Run the model over `loader` without gradient tracking.

    Puts the model in evaluation mode, accumulates the per-sample loss and the
    number of correct predictions over every batch, and normalizes the loss by
    the number of samples the loader's sampler yields.

    Args:
        loader: a DataLoader whose `.sampler` length is the number of samples
            it iterates (as is the case for valid_loader and test_loader).

    Returns:
        (avg_loss, num_correct): average loss per sample and the total count
        of correct predictions as reported by get_num_correct.
    """
    model.eval()  # set the model to evaluation mode
    total_loss, total_correct = 0, 0
    with torch.no_grad():  # turn off grad tracking, as we don't need gradients for evaluation
        for batch in loader:
            images, labels = batch[0].to(device), batch[1].to(device)  # load the batch to the available device
            preds = model(images)  # forward pass
            loss = criterion(preds, labels)  # calculate the loss

            # criterion returns the batch mean, so weight it by the batch size
            total_loss += loss.item() * labels.size(0)
            total_correct += get_num_correct(preds, labels)
    return total_loss / len(loader.sampler), total_correct


for epoch in range(num_epochs):
    # running sums for this epoch: weighted loss, correct predictions, samples processed
    train_loss, train_correct, train_seen = 0, 0, 0
    ###################
    # train the model #
    ###################
    train_loop = tqdm(train_loader, total=len(train_loader))
    model.train()  # set the model to train mode
    for batch in train_loop:
        images, labels = batch[0].to(device), batch[1].to(device)  # load the batch to the available device (cpu/gpu)
        preds = model(images)  # forward pass
        loss = criterion(preds, labels)  # calculate loss
        optimizer.zero_grad()  # clear accumulated gradients from the previous pass
        loss.backward()  # backward pass
        optimizer.step()  # perform a single optimization step

        train_loss += loss.item() * labels.size(0)  # update the running loss
        train_correct += get_num_correct(preds, labels)  # update running num correct
        train_seen += labels.size(0)  # samples processed so far in this epoch

        train_loop.set_description(f'Epoch [{epoch+1:2d}/{num_epochs}]')
        # running accuracy over the samples seen so far (not over the whole dataset),
        # so the number shown mid-epoch is meaningful
        train_loop.set_postfix(loss=loss.item(), acc=train_correct/train_seen)

    # average the training loss BEFORE logging, so TensorBoard shows the same
    # per-sample value that is printed below
    train_loss = train_loss/len(train_set)
    tb.add_scalar('Train Loss', train_loss, epoch)
    tb.add_scalar('Train Accuracy', train_correct/len(train_set), epoch)

    ######################
    # validate the model #
    ######################
    valid_loss, valid_correct = _evaluate(valid_loader)
    tb.add_scalar('Validation Loss', valid_loss, epoch)
    tb.add_scalar('Validation Accuracy', valid_correct/len(valid_loader.sampler), epoch)

    # print training/validation statistics (both are per-sample averages over the epoch)
    train_loop.write(f'\t\tAvg training loss: {train_loss:.6f}\tAvg validation loss: {valid_loss:.6f}')
    scheduler.step(valid_loss)

    # save model if validation loss has decreased
    if valid_loss <= valid_loss_min:
        train_loop.write(f'\t\tvalid_loss decreased ({valid_loss_min:.6f} --> {valid_loss:.6f})  saving model...')
        torch.save(model.state_dict(), checkpoint_path)
        valid_loss_min = valid_loss

    ##################
    # test the model #
    ##################
    # test metrics are only logged for monitoring; they take no part in
    # checkpoint selection or learning-rate scheduling
    test_loss, test_correct = _evaluate(test_loader)
    tb.add_scalar('Test Loss', test_loss, epoch)
    tb.add_scalar('Test Accuracy', test_correct/len(test_loader.sampler), epoch)

# push any buffered events to disk so the last epochs show up in TensorBoard
tb.flush()


Epoch [ 1/30]: 100%|██████████| 196/196 [00:42<00:00,  4.56it/s, acc=0.639, loss=0.904]


		Avg training loss: 1.060830	Avg validation loss: 0.725085
		valid_loss decreased (inf --> 0.725085)  saving model...


Epoch [ 2/30]: 100%|██████████| 196/196 [00:42<00:00,  4.61it/s, acc=0.767, loss=0.804]


		Avg training loss: 0.687660	Avg validation loss: 0.712803
		valid_loss decreased (0.725085 --> 0.712803)  saving model...


Epoch [ 3/30]: 100%|██████████| 196/196 [00:42<00:00,  4.64it/s, acc=0.802, loss=0.633]


		Avg training loss: 0.581167	Avg validation loss: 0.486661
		valid_loss decreased (0.712803 --> 0.486661)  saving model...


Epoch [ 4/30]: 100%|██████████| 196/196 [00:41<00:00,  4.71it/s, acc=0.82, loss=0.595]


		Avg training loss: 0.528874	Avg validation loss: 0.513192


Epoch [ 5/30]: 100%|██████████| 196/196 [00:41<00:00,  4.76it/s, acc=0.836, loss=0.466]


		Avg training loss: 0.485242	Avg validation loss: 0.468349
		valid_loss decreased (0.486661 --> 0.468349)  saving model...


Epoch [ 6/30]: 100%|██████████| 196/196 [00:41<00:00,  4.77it/s, acc=0.847, loss=0.422]


		Avg training loss: 0.447901	Avg validation loss: 0.532781


Epoch [ 7/30]: 100%|██████████| 196/196 [00:41<00:00,  4.78it/s, acc=0.851, loss=0.498]


		Avg training loss: 0.435389	Avg validation loss: 0.461138
		valid_loss decreased (0.468349 --> 0.461138)  saving model...


Epoch [ 8/30]: 100%|██████████| 196/196 [00:41<00:00,  4.77it/s, acc=0.862, loss=0.287]


		Avg training loss: 0.408889	Avg validation loss: 0.445688
		valid_loss decreased (0.461138 --> 0.445688)  saving model...


Epoch [ 9/30]: 100%|██████████| 196/196 [00:40<00:00,  4.80it/s, acc=0.866, loss=0.488]


		Avg training loss: 0.390850	Avg validation loss: 0.432318
		valid_loss decreased (0.445688 --> 0.432318)  saving model...


Epoch [10/30]: 100%|██████████| 196/196 [00:41<00:00,  4.76it/s, acc=0.874, loss=0.63]


		Avg training loss: 0.371673	Avg validation loss: 0.436332


Epoch [11/30]: 100%|██████████| 196/196 [00:41<00:00,  4.74it/s, acc=0.878, loss=0.281]


		Avg training loss: 0.356769	Avg validation loss: 0.432596


Epoch [12/30]: 100%|██████████| 196/196 [00:41<00:00,  4.73it/s, acc=0.882, loss=0.247]


		Avg training loss: 0.343391	Avg validation loss: 0.432879


Epoch [13/30]: 100%|██████████| 196/196 [00:41<00:00,  4.75it/s, acc=0.885, loss=0.264]


		Avg training loss: 0.332500	Avg validation loss: 0.435752


Epoch [14/30]: 100%|██████████| 196/196 [00:41<00:00,  4.74it/s, acc=0.891, loss=0.275]


		Avg training loss: 0.322740	Avg validation loss: 0.433673


Epoch [15/30]: 100%|██████████| 196/196 [00:41<00:00,  4.75it/s, acc=0.895, loss=0.301]


		Avg training loss: 0.309547	Avg validation loss: 0.494816
Epoch    15: reducing learning rate of group 0 to 1.0000e-02.
Epoch    15: reducing learning rate of group 1 to 3.3333e-02.


Epoch [16/30]: 100%|██████████| 196/196 [00:41<00:00,  4.72it/s, acc=0.932, loss=0.151]


		Avg training loss: 0.197079	Avg validation loss: 0.358085
		valid_loss decreased (0.432318 --> 0.358085)  saving model...


Epoch [17/30]: 100%|██████████| 196/196 [00:41<00:00,  4.78it/s, acc=0.944, loss=0.133]


		Avg training loss: 0.161911	Avg validation loss: 0.386529


Epoch [18/30]: 100%|██████████| 196/196 [00:40<00:00,  4.84it/s, acc=0.95, loss=0.24]


		Avg training loss: 0.143179	Avg validation loss: 0.407683


Epoch [19/30]: 100%|██████████| 196/196 [00:40<00:00,  4.84it/s, acc=0.953, loss=0.123]


		Avg training loss: 0.137749	Avg validation loss: 0.397495


Epoch [20/30]: 100%|██████████| 196/196 [00:39<00:00,  4.92it/s, acc=0.956, loss=0.0684]


		Avg training loss: 0.129278	Avg validation loss: 0.417324


Epoch [21/30]: 100%|██████████| 196/196 [00:40<00:00,  4.88it/s, acc=0.958, loss=0.11]


		Avg training loss: 0.122153	Avg validation loss: 0.428293


Epoch [22/30]: 100%|██████████| 196/196 [00:39<00:00,  4.93it/s, acc=0.962, loss=0.0788]


		Avg training loss: 0.114404	Avg validation loss: 0.401970
Epoch    22: reducing learning rate of group 0 to 3.3333e-03.
Epoch    22: reducing learning rate of group 1 to 1.1111e-02.


Epoch [23/30]: 100%|██████████| 196/196 [00:40<00:00,  4.87it/s, acc=0.973, loss=0.0356]


		Avg training loss: 0.078829	Avg validation loss: 0.394227


Epoch [24/30]: 100%|██████████| 196/196 [00:39<00:00,  4.91it/s, acc=0.977, loss=0.0903]


		Avg training loss: 0.067433	Avg validation loss: 0.414836


Epoch [25/30]: 100%|██████████| 196/196 [00:40<00:00,  4.88it/s, acc=0.979, loss=0.0569]


		Avg training loss: 0.059944	Avg validation loss: 0.407773


Epoch [26/30]: 100%|██████████| 196/196 [00:39<00:00,  4.97it/s, acc=0.982, loss=0.0328]


		Avg training loss: 0.053937	Avg validation loss: 0.416741


Epoch [27/30]: 100%|██████████| 196/196 [00:39<00:00,  4.92it/s, acc=0.983, loss=0.0129]


		Avg training loss: 0.049537	Avg validation loss: 0.423936


Epoch [28/30]: 100%|██████████| 196/196 [00:39<00:00,  4.96it/s, acc=0.984, loss=0.0377]


		Avg training loss: 0.048497	Avg validation loss: 0.440331
Epoch    28: reducing learning rate of group 0 to 1.1111e-03.
Epoch    28: reducing learning rate of group 1 to 3.7037e-03.


Epoch [29/30]: 100%|██████████| 196/196 [00:39<00:00,  4.92it/s, acc=0.987, loss=0.0769]


		Avg training loss: 0.039893	Avg validation loss: 0.435346


Epoch [30/30]: 100%|██████████| 196/196 [00:39<00:00,  4.93it/s, acc=0.988, loss=0.0275]


		Avg training loss: 0.038943	Avg validation loss: 0.436046
