In [1]:
# import libraries
import os

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data.sampler import SubsetRandomSampler
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from utils import device, get_num_correct

In [2]:
# Preprocessing pipelines, keyed by dataset split.
def _to_normalized_tensor():
    """Return fresh ToTensor + Normalize steps shared by both pipelines.

    ToTensor produces values in [0, 1]; Normalize with mean 0.5 and std 0.5
    on each of the three channels then maps them to [-1, 1].
    """
    return [
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]


data_transforms = {
    # training split: random augmentations first, then tensor conversion + normalization
    'train': transforms.Compose([
        transforms.ColorJitter(brightness=0.25, saturation=0.1),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(10),
        *_to_normalized_tensor(),
    ]),
    # test split: no augmentation, only tensor conversion + normalization
    'test': transforms.Compose(_to_normalized_tensor()),
}

In [3]:
# CIFAR-10 splits stored under ./data/ (download=True fetches them if needed):
# the training split gets the augmented pipeline, the test split the plain one
train_set = torchvision.datasets.CIFAR10(
    './data/', train=True, transform=data_transforms['train'], download=True,
)
test_set = torchvision.datasets.CIFAR10(
    './data/', train=False, transform=data_transforms['test'], download=True,
)

Files already downloaded and verified
Files already downloaded and verified


In [4]:
batch_size = 256
valid_size = 0.5  # fraction of test_set that is held out as the validation split
split_seed = 0  # fixed seed: the validation/test partition is identical on every run

# Shuffle the *test-set* indices and cut them into a validation part and a test part.
# A dedicated, seeded RandomState is used (instead of the unseeded global NumPy RNG) so
# that checkpoints and metrics from different kernel sessions refer to the same samples.
num_test = len(test_set)
indices = list(range(num_test))
np.random.RandomState(split_seed).shuffle(indices)
split = int(np.floor(valid_size * num_test))
test_idx, valid_idx = indices[split:], indices[:split]

# samplers that restrict a loader over test_set to one of the two index subsets
valid_sampler = SubsetRandomSampler(valid_idx)
test_sampler = SubsetRandomSampler(test_idx)

# prepare the data loaders (valid_loader and test_loader both read from test_set,
# each through its own sampler, so they never see each other's samples)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=1)
valid_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, sampler=valid_sampler, num_workers=1)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, sampler=test_sampler, num_workers=1)

In [5]:
# Load torchvision's VGG16 with pretrained weights; a later cell slices its
# `features` stack to build the transfer-learning model.
vgg16 = torchvision.models.vgg16(pretrained=True)
vgg16  # bare expression: the notebook shows the architecture as the cell output

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1

In [6]:
# modify the network as per requirement
class vgg16modified(nn.Module):
    """VGG-style feature stack split into a fixed part and a trainable part, plus a new head.

    The convolutional layers of ``base.features`` are divided at ``split``:
    layers before the index go into ``self.features`` and the remaining ones
    into ``self.need_train``. A three-layer fully connected classifier replaces
    the original head.

    The layer modules are reused, not copied, so their parameters are shared
    with ``base``.

    Args:
        base: model exposing a ``features`` container of layers. When omitted,
            the notebook-level ``vgg16`` is used, as before.
        split: index at which the feature layers are divided (default 19).
        num_classes: size of the output layer (default 10).

    Note:
        The classifier expects 512 flattened features per sample. With VGG16
        and 32x32 inputs the five 2x2 max-pools leave a 1x1x512 map, which
        matches; other bases or input sizes must flatten to 512 as well.
    """

    def __init__(self, base=None, split=19, num_classes=10):
        super().__init__()
        if base is None:
            # fall back to the pretrained network loaded earlier in the notebook
            base = vgg16
        layers = list(base.features.children())
        self.features = nn.Sequential(*layers[:split])
        self.need_train = nn.Sequential(*layers[split:])
        self.classifier = nn.Sequential(
            nn.Linear(in_features=512, out_features=256),
            nn.ReLU(inplace=True),
            nn.Linear(in_features=256, out_features=128),
            nn.ReLU(inplace=True),
            nn.Linear(in_features=128, out_features=num_classes)
            )

    def forward(self, t):
        """Run both feature stages, flatten per sample, and return the class logits."""
        t = self.features(t)
        t = self.need_train(t)
        t = torch.flatten(t, 1)  # keep the batch dimension, flatten everything else
        t = self.classifier(t)

        return t

In [7]:
# Instantiate the modified network defined in the previous cell.
model = vgg16modified()
model  # bare expression: the notebook shows the module tree as the cell output

vgg16modified(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, 

In [8]:
# Freeze the first feature stage: its weights receive no gradients during training
for param in model.features.parameters():
    param.requires_grad_(False)

In [9]:
model.to(device)

lr = 0.1
criterion = nn.CrossEntropyLoss()  # loss function (categorical cross-entropy)
# only the second feature stage and the new classifier head are handed to the optimizer
params = list(model.need_train.parameters()) + list(model.classifier.parameters())
optimizer = optim.SGD(params, lr=lr)  # specify the optimizer
# scale the learning rate by 1/3 once the validation loss has stopped improving
# NOTE(review): `verbose` is reported as deprecated by newer PyTorch releases - confirm against the installed version
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=1/3, patience=5, verbose=True) # lr scheduler

comment = f'-transferlr_vgg16(bsize={batch_size})'  # will be used for naming the run
tb = SummaryWriter(comment=comment)

# torch.save does not create missing folders, so make sure the checkpoint directory exists
os.makedirs('./models', exist_ok=True)

# initialize tracker for minimum validation loss
valid_loss_min = np.inf  # set initial minimum to infinity (np.Inf was removed in NumPy 2.0)
num_epochs = 50  # number of epochs used for training


def _evaluate(loader):
    """Run the model over `loader` without gradients.

    Returns a tuple (summed loss, number of correct predictions); the loss of
    each batch is weighted by its batch size. The caller is responsible for
    putting the model into evaluation mode.
    """
    loss_sum, num_correct = 0, 0
    with torch.no_grad():  # no gradients are needed for evaluation
        for batch in loader:
            images, labels = batch[0].to(device), batch[1].to(device)  # load the batch to the available device
            preds = model(images)  # forward pass
            loss = criterion(preds, labels)  # calculate the loss
            loss_sum += loss.item() * labels.size(0)  # update the running loss
            num_correct += get_num_correct(preds, labels)  # update running num correct
    return loss_sum, num_correct


try:
    for epoch in range(num_epochs):
        # running loss, running correct count and number of samples seen so far
        train_loss, train_correct, train_seen = 0, 0, 0
        ###################
        # train the model #
        ###################
        train_loop = tqdm(train_loader, total=len(train_loader))
        model.train()  # set the model to train mode
        for batch in train_loop:
            images, labels = batch[0].to(device), batch[1].to(device)  # load the batch to the available device (cpu/gpu)
            preds = model(images)  # forward pass
            loss = criterion(preds, labels)  # calculate loss
            optimizer.zero_grad()  # clear accumulated gradients from the previous pass
            loss.backward()  # backward pass
            optimizer.step()  # perform a single optimization step

            train_loss += loss.item() * labels.size(0)  # update the running loss
            train_correct += get_num_correct(preds, labels)  # update running num correct
            train_seen += labels.size(0)

            train_loop.set_description(f'Epoch [{epoch+1}/{num_epochs}]')
            # accuracy over the samples processed so far (not over the whole dataset),
            # so the value shown mid-epoch is meaningful
            train_loop.set_postfix(loss=loss.item(), acc=train_correct/train_seen)

        # log the per-sample average loss so the curves of all splits are comparable
        train_loss = train_loss/len(train_set)
        tb.add_scalar('Train Loss', train_loss, epoch)
        tb.add_scalar('Train Accuracy', train_correct/len(train_set), epoch)

        ######################
        # validate the model #
        ######################
        model.eval()  # set the model to evaluation mode
        valid_loss, valid_correct = _evaluate(valid_loader)
        valid_loss = valid_loss/len(valid_loader.sampler)  # average loss over the validation split

        # add validation loss and validation accuracy for the current epoch to tensorboard
        tb.add_scalar('Validation Loss', valid_loss, epoch)
        tb.add_scalar('Validation Accuracy', valid_correct/len(valid_loader.sampler), epoch)

        # print training/validation statistics
        train_loop.write(f'\t\tAvg training loss: {train_loss:.6f}\tAvg validation loss: {valid_loss:.6f}')
        scheduler.step(valid_loss)

        # save model if validation loss has decreased
        if valid_loss <= valid_loss_min:
            train_loop.write(f'\t\tvalid_loss decreased ({valid_loss_min:.6f} --> {valid_loss:.6f})  saving model...')
            torch.save(model.state_dict(), f'./models/model{comment}.pth')
            valid_loss_min = valid_loss

        ##################
        # test the model #
        ##################
        # test_loss stays the summed loss; only the logged value is averaged
        test_loss, test_correct = _evaluate(test_loader)

        # add test loss and test accuracy for the current epoch to tensorboard
        tb.add_scalar('Test Loss', test_loss/len(test_loader.sampler), epoch)
        tb.add_scalar('Test Accuracy', test_correct/len(test_loader.sampler), epoch)
finally:
    # flush pending events and release the event file, even if training is interrupted
    tb.close()


Epoch [1/50]: 100%|██████████| 196/196 [00:41<00:00,  4.69it/s, acc=0.425, loss=1.05]


		Avg training loss: 1.609961	Avg validation loss: 1.191050
		valid_loss decreased (inf --> 1.191050)  saving model...


Epoch [2/50]: 100%|██████████| 196/196 [00:41<00:00,  4.73it/s, acc=0.696, loss=0.707]


		Avg training loss: 0.895720	Avg validation loss: 0.662040
		valid_loss decreased (1.191050 --> 0.662040)  saving model...


Epoch [3/50]: 100%|██████████| 196/196 [00:41<00:00,  4.74it/s, acc=0.764, loss=0.664]


		Avg training loss: 0.691515	Avg validation loss: 1.184928


Epoch [4/50]: 100%|██████████| 196/196 [00:40<00:00,  4.79it/s, acc=0.788, loss=0.607]


		Avg training loss: 0.614401	Avg validation loss: 0.548118
		valid_loss decreased (0.662040 --> 0.548118)  saving model...


Epoch [5/50]: 100%|██████████| 196/196 [00:41<00:00,  4.76it/s, acc=0.807, loss=0.426]


		Avg training loss: 0.552967	Avg validation loss: 0.568714


Epoch [6/50]: 100%|██████████| 196/196 [00:41<00:00,  4.74it/s, acc=0.824, loss=0.511]


		Avg training loss: 0.507663	Avg validation loss: 0.516603
		valid_loss decreased (0.548118 --> 0.516603)  saving model...


Epoch [7/50]: 100%|██████████| 196/196 [00:41<00:00,  4.74it/s, acc=0.837, loss=0.614]


		Avg training loss: 0.467517	Avg validation loss: 0.601821


Epoch [8/50]: 100%|██████████| 196/196 [00:41<00:00,  4.74it/s, acc=0.849, loss=0.392]


		Avg training loss: 0.430958	Avg validation loss: 0.524161


Epoch [9/50]: 100%|██████████| 196/196 [00:41<00:00,  4.76it/s, acc=0.859, loss=0.259]


		Avg training loss: 0.402474	Avg validation loss: 0.454039
		valid_loss decreased (0.516603 --> 0.454039)  saving model...


Epoch [10/50]: 100%|██████████| 196/196 [00:41<00:00,  4.74it/s, acc=0.867, loss=0.347]


		Avg training loss: 0.379196	Avg validation loss: 0.492217


Epoch [11/50]: 100%|██████████| 196/196 [00:41<00:00,  4.74it/s, acc=0.875, loss=0.333]


		Avg training loss: 0.354135	Avg validation loss: 0.479978


Epoch [12/50]: 100%|██████████| 196/196 [00:41<00:00,  4.75it/s, acc=0.883, loss=0.39]


		Avg training loss: 0.330846	Avg validation loss: 0.577499


Epoch [13/50]: 100%|██████████| 196/196 [00:41<00:00,  4.75it/s, acc=0.892, loss=0.239]


		Avg training loss: 0.310191	Avg validation loss: 0.477957


Epoch [14/50]: 100%|██████████| 196/196 [00:41<00:00,  4.75it/s, acc=0.899, loss=0.376]


		Avg training loss: 0.286275	Avg validation loss: 0.728352


Epoch [15/50]: 100%|██████████| 196/196 [00:41<00:00,  4.73it/s, acc=0.906, loss=0.191]


		Avg training loss: 0.269341	Avg validation loss: 0.488498
Epoch    15: reducing learning rate of group 0 to 3.3333e-02.


Epoch [16/50]: 100%|██████████| 196/196 [00:41<00:00,  4.72it/s, acc=0.934, loss=0.175]


		Avg training loss: 0.187682	Avg validation loss: 0.459692


Epoch [17/50]: 100%|██████████| 196/196 [00:41<00:00,  4.76it/s, acc=0.94, loss=0.143]


		Avg training loss: 0.171454	Avg validation loss: 0.469020


Epoch [18/50]: 100%|██████████| 196/196 [00:41<00:00,  4.75it/s, acc=0.942, loss=0.166]


		Avg training loss: 0.164494	Avg validation loss: 0.451573
		valid_loss decreased (0.454039 --> 0.451573)  saving model...


Epoch [19/50]: 100%|██████████| 196/196 [00:41<00:00,  4.77it/s, acc=0.947, loss=0.148]


		Avg training loss: 0.151451	Avg validation loss: 0.477853


Epoch [20/50]: 100%|██████████| 196/196 [00:41<00:00,  4.76it/s, acc=0.95, loss=0.158]


		Avg training loss: 0.141050	Avg validation loss: 0.501092


Epoch [21/50]: 100%|██████████| 196/196 [00:40<00:00,  4.79it/s, acc=0.952, loss=0.215]


		Avg training loss: 0.136179	Avg validation loss: 0.493665


Epoch [22/50]: 100%|██████████| 196/196 [00:41<00:00,  4.74it/s, acc=0.954, loss=0.0924]


		Avg training loss: 0.128966	Avg validation loss: 0.496725


Epoch [23/50]: 100%|██████████| 196/196 [00:40<00:00,  4.79it/s, acc=0.957, loss=0.159]


		Avg training loss: 0.123065	Avg validation loss: 0.564564


Epoch [24/50]: 100%|██████████| 196/196 [00:41<00:00,  4.77it/s, acc=0.958, loss=0.04]


		Avg training loss: 0.117405	Avg validation loss: 0.504744
Epoch    24: reducing learning rate of group 0 to 1.1111e-02.


Epoch [25/50]: 100%|██████████| 196/196 [00:41<00:00,  4.76it/s, acc=0.967, loss=0.057]


		Avg training loss: 0.097268	Avg validation loss: 0.512511


Epoch [26/50]: 100%|██████████| 196/196 [00:41<00:00,  4.75it/s, acc=0.969, loss=0.0503]


		Avg training loss: 0.088173	Avg validation loss: 0.528223


Epoch [27/50]: 100%|██████████| 196/196 [00:40<00:00,  4.78it/s, acc=0.97, loss=0.0578]


		Avg training loss: 0.084758	Avg validation loss: 0.531362


Epoch [28/50]: 100%|██████████| 196/196 [00:41<00:00,  4.75it/s, acc=0.97, loss=0.0941]


		Avg training loss: 0.086011	Avg validation loss: 0.533336


Epoch [29/50]: 100%|██████████| 196/196 [00:41<00:00,  4.69it/s, acc=0.971, loss=0.0555]


		Avg training loss: 0.082620	Avg validation loss: 0.539577


Epoch [30/50]: 100%|██████████| 196/196 [00:41<00:00,  4.73it/s, acc=0.972, loss=0.0419]


		Avg training loss: 0.080705	Avg validation loss: 0.548233
Epoch    30: reducing learning rate of group 0 to 3.7037e-03.


Epoch [31/50]: 100%|██████████| 196/196 [00:41<00:00,  4.74it/s, acc=0.973, loss=0.059]


		Avg training loss: 0.076825	Avg validation loss: 0.553137


Epoch [32/50]: 100%|██████████| 196/196 [00:41<00:00,  4.73it/s, acc=0.975, loss=0.0886]


		Avg training loss: 0.071920	Avg validation loss: 0.553377


Epoch [33/50]: 100%|██████████| 196/196 [00:41<00:00,  4.75it/s, acc=0.975, loss=0.144]


		Avg training loss: 0.070319	Avg validation loss: 0.555232


Epoch [34/50]: 100%|██████████| 196/196 [00:41<00:00,  4.74it/s, acc=0.975, loss=0.0405]


		Avg training loss: 0.072448	Avg validation loss: 0.558689


Epoch [35/50]: 100%|██████████| 196/196 [00:41<00:00,  4.76it/s, acc=0.976, loss=0.0888]


		Avg training loss: 0.068376	Avg validation loss: 0.564666


Epoch [36/50]: 100%|██████████| 196/196 [00:41<00:00,  4.76it/s, acc=0.977, loss=0.0172]


		Avg training loss: 0.069103	Avg validation loss: 0.564100
Epoch    36: reducing learning rate of group 0 to 1.2346e-03.


Epoch [37/50]: 100%|██████████| 196/196 [00:40<00:00,  4.80it/s, acc=0.977, loss=0.106]


		Avg training loss: 0.066935	Avg validation loss: 0.564717


Epoch [38/50]: 100%|██████████| 196/196 [00:40<00:00,  4.81it/s, acc=0.977, loss=0.0668]


		Avg training loss: 0.065386	Avg validation loss: 0.565944


Epoch [39/50]: 100%|██████████| 196/196 [00:41<00:00,  4.75it/s, acc=0.978, loss=0.0139]


		Avg training loss: 0.067508	Avg validation loss: 0.566170


Epoch [40/50]: 100%|██████████| 196/196 [00:41<00:00,  4.77it/s, acc=0.977, loss=0.0643]


		Avg training loss: 0.065646	Avg validation loss: 0.570024


Epoch [41/50]: 100%|██████████| 196/196 [00:42<00:00,  4.66it/s, acc=0.978, loss=0.0587]


		Avg training loss: 0.065055	Avg validation loss: 0.570301


Epoch [42/50]: 100%|██████████| 196/196 [00:40<00:00,  4.80it/s, acc=0.978, loss=0.0461]


		Avg training loss: 0.064315	Avg validation loss: 0.573000
Epoch    42: reducing learning rate of group 0 to 4.1152e-04.


Epoch [43/50]: 100%|██████████| 196/196 [00:41<00:00,  4.76it/s, acc=0.977, loss=0.032]


		Avg training loss: 0.064879	Avg validation loss: 0.573100


Epoch [44/50]: 100%|██████████| 196/196 [00:41<00:00,  4.76it/s, acc=0.978, loss=0.0368]


		Avg training loss: 0.063979	Avg validation loss: 0.572865


Epoch [45/50]: 100%|██████████| 196/196 [00:41<00:00,  4.74it/s, acc=0.979, loss=0.0785]


		Avg training loss: 0.063747	Avg validation loss: 0.573190


Epoch [46/50]: 100%|██████████| 196/196 [00:40<00:00,  4.78it/s, acc=0.978, loss=0.0301]


		Avg training loss: 0.065863	Avg validation loss: 0.574302


Epoch [47/50]: 100%|██████████| 196/196 [00:42<00:00,  4.57it/s, acc=0.979, loss=0.0343]


		Avg training loss: 0.062233	Avg validation loss: 0.575469


Epoch [48/50]: 100%|██████████| 196/196 [00:43<00:00,  4.56it/s, acc=0.978, loss=0.0723]


		Avg training loss: 0.064981	Avg validation loss: 0.575820
Epoch    48: reducing learning rate of group 0 to 1.3717e-04.


Epoch [49/50]: 100%|██████████| 196/196 [00:42<00:00,  4.59it/s, acc=0.978, loss=0.0271]


		Avg training loss: 0.064553	Avg validation loss: 0.575631


Epoch [50/50]: 100%|██████████| 196/196 [00:42<00:00,  4.58it/s, acc=0.979, loss=0.0458]


		Avg training loss: 0.063950	Avg validation loss: 0.575285
