# PyTorch CNN image classification

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import torchvision
import torchvision.transforms as transforms

In [2]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

## Import data

In [4]:
# Where the CIFAR-10 files live on disk, and the mini-batch size handed to the loaders.
path_data = '/glade/u/home/ksha/torch_data/'
batch_size = 64

# Human-readable label names (presumably indexed by the integer class label;
# not referenced by the cells shown here).
classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)

# Preprocessing applied to every image: convert to a tensor, then normalise
# each of the three channels with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

In [5]:
# Training split: downloaded into `path_data` when missing, reshuffled on every pass.
trainset = torchvision.datasets.CIFAR10(
    root=path_data,
    train=True,
    download=True,
    transform=transform,
)
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
)

# Held-out split: same preprocessing, served in a fixed order.
testset = torchvision.datasets.CIFAR10(
    root=path_data,
    train=False,
    download=True,
    transform=transform,
)
testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=2,
)

Files already downloaded and verified
Files already downloaded and verified


In [6]:
# Optional preview of one training batch (disabled: `imshow` is not defined in the cells above).
# dataiter = iter(trainloader)
# images, labels = next(dataiter)
# imshow(torchvision.utils.make_grid(images))

## Define model

In [7]:
# nn.Conv2d(in_channels, out_channels, kernel_size,
#           stride=1, padding=0, dilation=1, groups=1, bias=True,
#           padding_mode='zeros', device=None, dtype=None)

# nn.MaxPool2d(kernel_size, stride=None, 
#              padding=0, dilation=1, return_indices=False, ceil_mode=False)

# nn.Linear(in_features, out_features, bias=True, device=None, dtype=None)

class model_maker(nn.Module):
    """Small CNN: two conv/ReLU/max-pool stages followed by three linear layers.

    The first linear layer expects 16*5*5 features per sample, which is what
    the two unpadded 5x5 convolutions and 2x2 poolings produce for a
    3-channel 32x32 input. The last layer emits 10 raw scores; no softmax is
    applied inside the model.
    """

    def __init__(self):
        super().__init__()
        # Attribute names and creation order are kept as-is so that saved
        # state_dicts still load and default initialisation is unchanged.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        self.fc1 = nn.Linear(in_features=16 * 5 * 5, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=10)

    def forward(self, x):
        """Return the 10 class scores for each sample in the batch ``x``."""
        # Convolutional stages; the single (parameter-free) pooling layer is reused.
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))

        # Collapse every dimension except the batch dimension.
        x = x.flatten(start_dim=1)

        # Fully connected head; the output layer has no activation.
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)


model = model_maker()

**Configuration**

In [8]:
# Multi-class objective, applied to the raw scores the network returns.
loss_func = nn.CrossEntropyLoss()

# SGD with momentum over every parameter of `model`.
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=0.001,
    momentum=0.9,
)

## Training

In [9]:
# Checkpoint location for the initial training run. Despite the name this is a
# file path, not a directory: the training loop passes it to torch.save().
save_dir = '/glade/work/ksha/torch_models/cifar_net.pth'

**Validation set** (built from the CIFAR-10 test split and held in memory as tensors)

In [10]:
# Materialise the whole evaluation split as two tensors so the validation loss
# can be computed with a single forward pass at the end of each epoch.
L_valid = len(testset)

# Index explicitly so that exactly L_valid samples are read; each item is an
# (image, label) pair with the dataset transform already applied.
samples_ = [testset[i] for i in range(L_valid)]

# Stack the per-sample tensors directly instead of copying them one by one
# through a float64 NumPy buffer; the resulting dtypes match the old code
# (float images, long labels).
input_test = torch.stack([image_ for image_, _ in samples_]).float()
y_true = torch.tensor([label_ for _, label_ in samples_], dtype=torch.long)

# Per-sample shape, i.e. everything after the leading sample dimension.
grid_shape = tuple(input_test.shape[1:])

# The list of per-sample tensors is no longer needed once stacked.
del samples_

In [13]:
# Early-stopping settings.
min_del = 0.0  # smallest drop in validation loss that counts as an improvement
max_tol = 3    # stop at the 3rd consecutive non-improving epoch (2 are tolerated)
tol = 0        # consecutive non-improving epochs seen so far

# Baseline validation loss before any update made by this cell. Evaluation
# needs no gradients, so avoid building an autograd graph for the full set.
# NOTE(review): input_test / y_true are built from `testset`, so the data used
# for checkpoint selection is the test split.
model.eval()
with torch.no_grad():
    y_pred = model(input_test)
    record = loss_func(y_pred, y_true).item()
print('Initial loss: {}'.format(record))

for epoch in range(2):  # upper bound on the number of passes over the training set

    model.train()
    for inputs, labels in trainloader:
        # zero the parameter gradients left over from the previous step
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = model(inputs)
        loss = loss_func(outputs, labels)
        loss.backward()
        optimizer.step()

    # on-epoch-end validation
    model.eval()
    with torch.no_grad():
        y_pred = model(input_test)
        record_temp = loss_func(y_pred, y_true).item()
    print('Validation loss: {}'.format(record_temp))

    if record - record_temp > min_del:
        print('Validation loss improved from {} to {}'.format(record, record_temp))
        record = record_temp
        tol = 0  # an improvement resets the patience counter
        print("Save to {}".format(save_dir))
        torch.save(model.state_dict(), save_dir)

    else:
        print('Validation loss {} NOT improved'.format(record_temp))
        tol += 1
        if tol >= max_tol:
            # The best weights seen so far are already on disk at save_dir.
            print('Early stopping: no improvement for {} consecutive epochs'.format(tol))
            break

Initial loss: 2.282742738723755
Validation loss: 1.9610371589660645
Validation loss improved from 2.282742738723755 to 1.9610371589660645
Save to /glade/work/ksha/torch_models/cifar_net.pth
Validation loss: 1.7789337635040283
Validation loss improved from 1.9610371589660645 to 1.7789337635040283
Save to /glade/work/ksha/torch_models/cifar_net.pth


## Load model

In [14]:
# Build a fresh network and restore the weights stored at `save_dir`.
model = model_maker()
saved_weights = torch.load(save_dir)
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
model.load_state_dict(saved_weights)

<All keys matched successfully>

In [15]:
# Same objective as before, re-created for the fine-tuning stage.
loss_func = nn.CrossEntropyLoss()

# New SGD optimizer bound to the reloaded model, with a smaller step size.
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=5e-4,
    momentum=0.9,
)

## Fine tuning

In [16]:
# Checkpoint location for the fine-tuning run. Despite the name this is a file
# path, not a directory: the fine-tuning loop passes it to torch.save().
save_dir = '/glade/work/ksha/torch_models/cifar_net_tune.pth'

In [17]:
# Early-stopping settings.
min_del = 0.0  # smallest drop in validation loss that counts as an improvement
max_tol = 3    # stop at the 3rd consecutive non-improving epoch (2 are tolerated)
tol = 0        # consecutive non-improving epochs seen so far

# Validation loss of the reloaded model before fine-tuning starts. Evaluation
# needs no gradients, so avoid building an autograd graph for the full set.
# NOTE(review): input_test / y_true are built from `testset`, so the data used
# for checkpoint selection is the test split.
model.eval()
with torch.no_grad():
    y_pred = model(input_test)
    record = loss_func(y_pred, y_true).item()
print('Initial loss: {}'.format(record))

for epoch in range(30):  # upper bound; early stopping may end the run sooner

    model.train()
    for inputs, labels in trainloader:
        # zero the parameter gradients left over from the previous step
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = model(inputs)
        loss = loss_func(outputs, labels)
        loss.backward()
        optimizer.step()

    # on-epoch-end validation
    model.eval()
    with torch.no_grad():
        y_pred = model(input_test)
        record_temp = loss_func(y_pred, y_true).item()
    print('Validation loss: {}'.format(record_temp))

    if record - record_temp > min_del:
        print('Validation loss improved from {} to {}'.format(record, record_temp))
        record = record_temp
        tol = 0  # an improvement resets the patience counter
        print("Save to {}".format(save_dir))
        torch.save(model.state_dict(), save_dir)

    else:
        print('Validation loss {} NOT improved'.format(record_temp))
        tol += 1
        if tol >= max_tol:
            # The best weights seen so far are already on disk at save_dir.
            print('Early stopping: no improvement for {} consecutive epochs'.format(tol))
            break

Initial loss: 1.7789337635040283
Validation loss: 1.6968176364898682
Validation loss improved from 1.7789337635040283 to 1.6968176364898682
Save to /glade/work/ksha/torch_models/cifar_net_tune.pth
Validation loss: 1.617132544517517
Validation loss improved from 1.6968176364898682 to 1.617132544517517
Save to /glade/work/ksha/torch_models/cifar_net_tune.pth
Validation loss: 1.5685003995895386
Validation loss improved from 1.617132544517517 to 1.5685003995895386
Save to /glade/work/ksha/torch_models/cifar_net_tune.pth
Validation loss: 1.5218746662139893
Validation loss improved from 1.5685003995895386 to 1.5218746662139893
Save to /glade/work/ksha/torch_models/cifar_net_tune.pth
Validation loss: 1.4906188249588013
Validation loss improved from 1.5218746662139893 to 1.4906188249588013
Save to /glade/work/ksha/torch_models/cifar_net_tune.pth
Validation loss: 1.4542299509048462
Validation loss improved from 1.4906188249588013 to 1.4542299509048462
Save to /glade/work/ksha/torch_models/cifar