Training a convolutional neural network
===============

Trains a small CNN to distinguish airplanes from birds (two classes taken from CIFAR-10). This notebook follows the book starting at section 8.4.

In [1]:
%matplotlib inline
from matplotlib import pyplot as plt
import numpy as np
import collections

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Print only two leading/trailing entries per dimension for large tensors
torch.set_printoptions(edgeitems=2)
# Fixed seed for torch's default RNG so weight initialisation is reproducible
SEED = 123
torch.manual_seed(SEED)

<torch._C.Generator at 0x117a7fd10>

For ASCII plots of training progress

In [2]:
import asciichartpy
from IPython.display import clear_output

`datetime` is used to timestamp the per-epoch training log.

In [3]:
import datetime 

## Image dataset

In [4]:
# All ten CIFAR-10 class names, in label-index order (0 = airplane, 2 = bird, ...)
class_names = ('airplane automobile bird cat deer '
               'dog frog horse ship truck').split()

In [5]:
from torchvision import datasets, transforms

data_path = '../data-unversioned/p1ch6/'

# Per-channel normalisation statistics applied after ToTensor()
cifar10_mean = (0.4915, 0.4823, 0.4468)
cifar10_std = (0.2470, 0.2435, 0.2616)

cifar10 = datasets.CIFAR10(
    data_path,
    train=True,
    download=True,
    transform=transforms.Compose([transforms.ToTensor(),
                                  transforms.Normalize(cifar10_mean, cifar10_std)]),
)

Files already downloaded and verified


In [6]:
# Validation split, normalised with the same statistics as the training split
val_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.4915, 0.4823, 0.4468),
                         std=(0.2470, 0.2435, 0.2616)),
])
cifar10_val = datasets.CIFAR10(data_path, train=False, download=True,
                               transform=val_transform)

Files already downloaded and verified


In [7]:
# Keep only airplanes (CIFAR-10 label 0) and birds (label 2); remap them to 0 and 1
label_map = {0: 0, 2: 1}
class_names = ['airplane', 'bird']
cifar2 = [
    (img, label_map[label]) for img, label in cifar10 if label in label_map
]
cifar2_val = [
    (img, label_map[label]) for img, label in cifar10_val if label in label_map
]

## Defines the CNN architecture 

The pooling and activation steps use the functional API (`F.max_pool2d`, `torch.tanh`) rather than separate `nn` modules.

This is the architecture:

![img description](fig8.10.png)

In [8]:
import torch.nn.functional as F

class Net(nn.Module):
    """Small CNN that maps a batch of 3x32x32 images to two class scores.

    Two stages of (3x3 conv -> tanh -> 2x2 max-pool) are followed by two fully
    connected layers. Each pooling halves the spatial size, so a 32x32 input
    reaches fc1 as 8 channels of 8x8 = 512 features per sample.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 8, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(8 * 8 * 8, 32)
        self.fc2 = nn.Linear(32, 2)

    def forward(self, x):
        """Return unnormalised scores (logits) of shape (batch, 2).

        x: image batch, expected shape (batch, 3, 32, 32).
        """
        out = F.max_pool2d(torch.tanh(self.conv1(x)), 2)
        out = F.max_pool2d(torch.tanh(self.conv2(out)), 2)
        # Flatten each sample while keeping the batch dimension, so an input of
        # the wrong spatial size fails loudly in fc1 instead of being reshaped
        # into extra, meaningless "samples".
        out = torch.flatten(out, start_dim=1)
        out = torch.tanh(self.fc1(out))
        return self.fc2(out)

Quick sanity check: a single image passed through an untrained model should yield one row of two scores (one per class).

In [9]:
# Pick the first (image, label) pair of the two-class training set. The label
# gets a real name: assigning to `_` would shadow IPython's last-output
# variable for the rest of the session.
img, img_label = cifar2[0]

In [10]:
# Fresh, untrained model; index with None to add a batch dimension of size 1
model = Net()
model(img[None])

tensor([[0.0908, 0.0938]], grad_fn=<AddmmBackward0>)

## Training (CPU)

Remember what the loop is doing: 

![img description](IMG_4B8DBAE41433-1.jpeg)

![img description](IMG_5AE714E751EB-1.jpeg)

This is the training loop from the book, modified to redraw an ASCII plot of the training loss (on a log10 scale) after every epoch, as in the previous chapter.

Improved code

In [11]:
def training_loop(n_epochs, optimizer, model, loss_fn, train_loader):
    """Train `model` with mini-batch gradient descent and chart the loss in ASCII.

    After every epoch the cell output is cleared and replaced by a timestamped
    status line plus an ASCII chart of log10(mean training loss) per epoch.

    Args:
        n_epochs: number of full passes over `train_loader`.
        optimizer: optimizer built over `model.parameters()`.
        model: the network to train; called as `model(imgs)`.
        loss_fn: called as `loss_fn(outputs, labels)`; must return a scalar tensor.
        train_loader: iterable of `(imgs, labels)` batches that supports `len()`.

    Raises:
        ValueError: if epochs are requested but `train_loader` has no batches
            (the mean epoch loss would be a division by zero).
    """
    n_batches = len(train_loader)
    if n_epochs > 0 and n_batches == 0:
        raise ValueError("train_loader yields no batches")

    model.train()  # make sure layers such as dropout/batch-norm are in training mode
    log_losses = []  # log10 of the mean loss of each epoch, for the chart

    for epoch in range(1, n_epochs + 1):
        loss_train = 0.0

        for imgs, labels in train_loader:
            outputs = model(imgs)
            loss = loss_fn(outputs, labels)

            optimizer.zero_grad()  # gradients accumulate, so clear the previous step's
            loss.backward()
            optimizer.step()

            # .item() yields a Python float, so the autograd graph is not kept alive
            loss_train += loss.item()

        mean_loss = loss_train / n_batches
        log_losses.append(np.log10(mean_loss))
        clear_output(wait=True)
        timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        print("              ", timestamp, "| Epoch", epoch, "| Loss", round(mean_loss, 4))
        print(asciichartpy.plot(log_losses, {'height': 10}))

Starts the training

In [12]:
%%time
train_loader = torch.utils.data.DataLoader(cifar2, batch_size=64,
                                           shuffle=True)  # <1>

model = Net()  #  <2>
optimizer = optim.SGD(model.parameters(), lr=1e-2)  #  <3>
loss_fn = nn.CrossEntropyLoss()  #  <4>

training_loop(  # <5>
    n_epochs = 100,
    optimizer = optimizer,
    model = model,
    loss_fn = loss_fn,
    train_loader = train_loader,
)

               2024-10-30 10:09:04.434106 | Epoch 100 | Loss 0.1488
   -0.25  ┼╮
   -0.30  ┤╰╮
   -0.36  ┤ ╰─╮
   -0.41  ┤   ╰─╮
   -0.46  ┤     ╰─────╮
   -0.51  ┤           ╰──────────────╮
   -0.57  ┤                          ╰─────────────╮
   -0.62  ┤                                        ╰────────────╮
   -0.67  ┤                                                     ╰───────────────╮
   -0.72  ┤                                                                     ╰──────────────╮╭╮
   -0.78  ┤                                                                                    ╰╯╰────────────
   -0.83  ┤
CPU times: user 12min 19s, sys: 2min 9s, total: 14min 28s
Wall time: 3min 34s


### Accuracy

In [13]:
# Fixed-order loaders for evaluation; shuffling is pointless when only counting hits
loader_kwargs = dict(batch_size=64, shuffle=False)
train_loader = torch.utils.data.DataLoader(cifar2, **loader_kwargs)
val_loader = torch.utils.data.DataLoader(cifar2_val, **loader_kwargs)

Define a function that prints the classification accuracy on the training and validation sets.

In [14]:
def validate(model, train_loader, val_loader):
    """Print the accuracy of `model` on the training and validation loaders.

    The predicted class is the index of the largest output score. Each loader
    must yield `(imgs, labels)` batches. Nothing is returned.
    """
    splits = (("train", train_loader), ("val", val_loader))
    for split_name, split_loader in splits:
        n_seen = n_hits = 0
        # Inference only: no autograd graph is needed
        with torch.no_grad():
            for batch_imgs, batch_labels in split_loader:
                predictions = model(batch_imgs).argmax(dim=1)
                n_seen += batch_labels.shape[0]
                n_hits += int((predictions == batch_labels).sum())
        print("Accuracy {}: {:.2f}".format(split_name, n_hits / n_seen))

In [15]:
validate(model=model, train_loader=train_loader, val_loader=val_loader)

Accuracy train: 0.95
Accuracy val: 0.89


For comparison, the fully connected model reached only 79% accuracy, so the CNN's 89% validation accuracy is a clear improvement.

### Saves the model to a file

In [16]:
# Persist only the learned parameters (the state_dict), not the module object
checkpoint_path = data_path + 'birds_vs_airplanes.pt'
torch.save(model.state_dict(), checkpoint_path)

When loading the model back, keep the class definition unchanged: `load_state_dict` matches the saved tensors to the model's parameters by attribute name and shape.

```python
loaded_model = Net()  # <1>
loaded_model.load_state_dict(torch.load(data_path
                                        + 'birds_vs_airplanes.pt'))
```

## Training on the GPU

Select the best available device: CUDA if present, otherwise Apple's MPS backend, otherwise the CPU.

In [17]:
# Prefer CUDA, then Apple's Metal backend (MPS), and fall back to the CPU
device = torch.device(
    'cuda' if torch.cuda.is_available()
    else 'mps' if torch.backends.mps.is_available()
    else 'cpu'
)

In [18]:
# Display the selected device (its repr is this cell's output)
device

device(type='mps')

Train on the GPU

In [26]:
def training_loop(n_epochs, optimizer, model, loss_fn, train_loader, device=None):
    """Train `model` on an accelerator and chart the loss in ASCII.

    Each batch is moved to `device` before the forward pass. After every epoch
    the cell output is cleared and replaced by a timestamped status line plus an
    ASCII chart of log10(mean training loss) per epoch.

    Args:
        n_epochs: number of full passes over `train_loader`.
        optimizer: optimizer built over `model.parameters()`.
        model: the network to train; called as `model(imgs)`.
        loss_fn: called as `loss_fn(outputs, labels)`; must return a scalar tensor.
        train_loader: iterable of `(imgs, labels)` batches that supports `len()`.
        device: where batches are sent. Defaults to the device holding the
            model's first parameter (CPU if the model has none), so the data
            always lands next to the weights instead of depending on a
            notebook-global `device` variable.

    Raises:
        ValueError: if epochs are requested but `train_loader` has no batches
            (the mean epoch loss would be a division by zero).
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    n_batches = len(train_loader)
    if n_epochs > 0 and n_batches == 0:
        raise ValueError("train_loader yields no batches")

    model.train()  # make sure layers such as dropout/batch-norm are in training mode
    log_losses = []  # log10 of the mean loss of each epoch, for the chart

    for epoch in range(1, n_epochs + 1):
        loss_train = 0.0
        for imgs, labels in train_loader:
            imgs = imgs.to(device=device)
            labels = labels.to(device=device)
            outputs = model(imgs)
            loss = loss_fn(outputs, labels)

            optimizer.zero_grad()  # gradients accumulate, so clear the previous step's
            loss.backward()
            optimizer.step()

            # .item() copies the scalar back to the host as a Python float
            loss_train += loss.item()

        mean_loss = loss_train / n_batches
        log_losses.append(np.log10(mean_loss))
        clear_output(wait=True)
        timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        print("              ", timestamp, "| Epoch", epoch, "| Loss", round(mean_loss, 4))
        print(asciichartpy.plot(log_losses, {'height': 10}))

In [None]:
%%time
train_loader = torch.utils.data.DataLoader(cifar2, batch_size=64,
                                           shuffle=True)

model = Net().to(device=device)  # moves model and parameters to GPU
optimizer = optim.SGD(model.parameters(), lr=1e-2)
loss_fn = nn.CrossEntropyLoss()

training_loop(
    n_epochs = 100,
    optimizer = optimizer,
    model = model,
    loss_fn = loss_fn,
    train_loader = train_loader,
)

               2024-10-30 10:14:04 | Epoch 29 | Loss 0.2737
   -0.24  ┤
   -0.27  ┼╮
   -0.30  ┤│
   -0.33  ┤╰─╮
   -0.36  ┤  ╰─╮
   -0.39  ┤    ╰╮
   -0.42  ┤     ╰╮
   -0.45  ┤      ╰─╮
   -0.48  ┤        ╰──╮
   -0.50  ┤           ╰──────╮
   -0.53  ┤                  ╰───────╮
   -0.56  ┤                          ╰─


### Accuracy

In [21]:
# Fixed-order loaders for evaluating the GPU-trained model
train_loader, val_loader = (
    torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=False)
    for dataset in (cifar2, cifar2_val)
)

In [22]:
def validate(model, train_loader, val_loader, device=None):
    """Print and return the accuracy of `model` on the training and validation loaders.

    The predicted class is the index of the largest output score.

    Args:
        model: classifier called as `model(imgs)`.
        train_loader, val_loader: iterables of `(imgs, labels)` batches.
        device: where batches are sent before the forward pass. Defaults to the
            device holding the model's first parameter (CPU if the model has
            none) instead of a notebook-global `device` variable.

    Returns:
        dict mapping "train" and "val" to the fraction of correct predictions.

    Raises:
        ValueError: if a loader yields no samples (accuracy would be undefined).
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    accdict = {}
    for name, loader in [("train", train_loader), ("val", val_loader)]:
        correct = 0
        total = 0

        with torch.no_grad():  # inference only: no autograd graph is needed
            for imgs, labels in loader:
                imgs = imgs.to(device=device)
                labels = labels.to(device=device)
                predicted = model(imgs).argmax(dim=1)
                total += labels.shape[0]
                correct += int((predicted == labels).sum())

        if total == 0:
            raise ValueError(f"{name} loader yields no samples")
        accuracy = correct / total
        print("Accuracy {}: {:.2f}".format(name, accuracy))
        accdict[name] = accuracy
    return accdict

### Saving and loading

#### Results

Create an ordered dictionary to store the accuracy of the different models, starting with this baseline.

In [23]:
# Accuracy per model variant, keyed by name; this notebook's CNN is the baseline
baseline_acc = validate(model, train_loader, val_loader)
all_acc_dict = collections.OrderedDict([("baseline", baseline_acc)])

Accuracy train: 0.94
Accuracy val: 0.90


Save the dictionary to disk so that the next notebook can load it.

In [24]:
import pickle

# Binary pickle of the accuracy table, written relative to the working directory
with open('model-accuracy-dict.pkl', mode='wb') as acc_file:
    pickle.dump(all_acc_dict, acc_file)

#### Model

Load the model parameters saved earlier (`birds_vs_airplanes.pt`) back from disk, placing them on the selected device.

In [25]:
# Rebuild the architecture and load the parameters saved earlier.
# map_location places the tensors directly on `device`. weights_only=True
# restricts unpickling to tensors and plain containers, so a tampered
# checkpoint cannot run arbitrary code; recent PyTorch versions also warn
# when this argument is omitted.
checkpoint_path = data_path + 'birds_vs_airplanes.pt'
loaded_model = Net().to(device=device)
loaded_model.load_state_dict(torch.load(checkpoint_path,
                                        map_location=device,
                                        weights_only=True))

  loaded_model.load_state_dict(torch.load(data_path


<All keys matched successfully>