# Pytorch CNN image classification

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import torchvision
import torchvision.transforms as transforms

In [2]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

## Import data

In [3]:
path_data = '/glade/u/home/ksha/torch_data/'
batch_size = 64
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

In [4]:
trainset = torchvision.datasets.CIFAR10(root=path_data, train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root=path_data, train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)

Files already downloaded and verified
Files already downloaded and verified


In [5]:
# dataiter = iter(trainloader)
# images, labels = next(dataiter)
# imshow(torchvision.utils.make_grid(images))

## Define model

In [6]:
class model_maker(nn.Module):
    def __init__(self, input_num, class_nums):
        super().__init__()
        
        self.filter_nums = [64, 128, 256]
        self.dense_nums = [256, 64]
        
        self.conv2d_layer0 = nn.Conv2d(input_num, self.filter_nums[0], kernel_size=3, padding='same')
        self.conv2d_layer1 = nn.Conv2d(self.filter_nums[0], self.filter_nums[1], kernel_size=3, padding='same')
        self.conv2d_layer2 = nn.Conv2d(self.filter_nums[1], self.filter_nums[2], kernel_size=3, padding='same')
        
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        
        self.dense_layer0 = nn.LazyLinear(self.dense_nums[0])
        self.dense_layer1 = nn.LazyLinear(self.dense_nums[1])
        
        self.dense_out = nn.LazyLinear(class_nums)

    def forward(self, x):

        x = self.conv2d_layer0(x)
        x = F.relu(x)
        x = self.pool(x)

        x = self.conv2d_layer1(x)
        x = F.relu(x)
        x = self.pool(x)

        x = self.conv2d_layer2(x)
        x = F.relu(x)
        x = self.pool(x)
        
        x = torch.flatten(x, 1)
        
        x = self.dense_layer0(x)
        x = F.relu(x)

        x = self.dense_layer1(x)
        x = F.relu(x)
        
        x = self.dense_out(x)
        
        return x
        
model = model_maker(input_num=3, class_nums=10)



In [7]:
model

model_maker(
  (conv2d_layer0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=same)
  (conv2d_layer1): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=same)
  (conv2d_layer2): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=same)
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (dense_layer0): LazyLinear(in_features=0, out_features=256, bias=True)
  (dense_layer1): LazyLinear(in_features=0, out_features=64, bias=True)
  (dense_out): LazyLinear(in_features=0, out_features=10, bias=True)
)

**Configuration**

In [12]:
loss_func = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)

## Training

In [13]:
save_dir = '/glade/work/ksha/torch_models/cifar_net.pth'

**Validation set**

In [14]:
L_valid = len(testset)
input_, y_true_ = testset[0]

grid_shape = np.array(input_).shape
input_test = np.empty((L_valid,)+grid_shape)
y_true = np.empty(L_valid)

for i in range(L_valid):
    input_, y_true_ = testset[i]
    input_test[i, ...] = np.array(input_)
    y_true[i] = y_true_

y_true = torch.from_numpy(y_true).long()
input_test = torch.from_numpy(input_test).float()

In [15]:
min_del = 0.0
max_tol = 3 # early stopping with 2-epoch patience
tol = 0

y_pred = model(input_test)
record = loss_func(y_pred, y_true).detach().numpy()
print('Initial loss: {}'.format(record))

for epoch in range(10):  # loop over the dataset multiple times

    running_loss = 0.0
    
    for i, data in enumerate(trainloader):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = model(inputs)
        loss = loss_func(outputs, labels)
        loss.backward()
        optimizer.step()
        
    # on-epoch-end validation
    y_pred = model(input_test)
    record_temp = loss_func(y_pred, y_true).detach().numpy()
    print('Validation loss: {}'.format(record_temp))
    
    if record - record_temp > min_del:
        print('Validation loss improved from {} to {}'.format(record, record_temp))
        record = record_temp
        print("Save to {}".format(save_dir))
        torch.save(model.state_dict(), save_dir)
        
    else:
        print('Validation loss {} NOT improved'.format(record_temp))

Initial loss: 2.3033981323242188
Validation loss: 2.247927188873291
Validation loss improved from 2.3033981323242188 to 2.247927188873291
Save to /glade/work/ksha/torch_models/cifar_net.pth
Validation loss: 1.853750228881836
Validation loss improved from 2.247927188873291 to 1.853750228881836
Save to /glade/work/ksha/torch_models/cifar_net.pth
Validation loss: 1.6193350553512573
Validation loss improved from 1.853750228881836 to 1.6193350553512573
Save to /glade/work/ksha/torch_models/cifar_net.pth
Validation loss: 1.488877773284912
Validation loss improved from 1.6193350553512573 to 1.488877773284912
Save to /glade/work/ksha/torch_models/cifar_net.pth
Validation loss: 1.3921916484832764
Validation loss improved from 1.488877773284912 to 1.3921916484832764
Save to /glade/work/ksha/torch_models/cifar_net.pth
Validation loss: 1.3202134370803833
Validation loss improved from 1.3921916484832764 to 1.3202134370803833
Save to /glade/work/ksha/torch_models/cifar_net.pth
Validation loss: 1.256

## Load model

In [14]:
model = model_maker()
model.load_state_dict(torch.load(save_dir))

<All keys matched successfully>

In [15]:
loss_func = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=5e-4, momentum=0.9)

## Fine tuning

In [16]:
save_dir = '/glade/work/ksha/torch_models/cifar_net_tune.pth'

In [16]:
min_del = 0.0
max_tol = 3 # early stopping with 2-epoch patience
tol = 0

y_pred = model(input_test)
record = loss_func(y_pred, y_true).detach().numpy()
print('Initial loss: {}'.format(record))

for epoch in range(30):  # loop over the dataset multiple times

    running_loss = 0.0
    
    for i, data in enumerate(trainloader):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = model(inputs)
        loss = loss_func(outputs, labels)
        loss.backward()
        optimizer.step()
        
    # on-epoch-end validation
    y_pred = model(input_test)
    record_temp = loss_func(y_pred, y_true).detach().numpy()
    print('Validation loss: {}'.format(record_temp))
    
    if record - record_temp > min_del:
        print('Validation loss improved from {} to {}'.format(record, record_temp))
        record = record_temp
        print("Save to {}".format(save_dir))
        torch.save(model.state_dict(), save_dir)
        
    else:
        print('Validation loss {} NOT improved'.format(record_temp))