# CNN architectures



In [30]:
import os
import torch
import torchvision
import numpy as np
from torchvision import datasets, transforms, models
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from sklearn.model_selection import train_test_split

In [31]:
# Pin the random number generators so repeated runs draw the same numbers.
# PyTorch is seeded first, NumPy second (NumPy's seed() returns None, so the
# cell produces no output).
_SEED = 42
torch.manual_seed(_SEED)
np.random.seed(_SEED)

In [32]:
NETWORK = 'resnet' # 'resnet' selects the pretrained ResNet-18; any other value selects the custom Net

# Custom network params (only read by the custom Net class)
INPUT_CHANNELS = 3      # channels of the input images
CONV1_CHANNELS = 6      # output channels of the first convolution
CONV1_KERNEL_SIZE = 5   # kernel size of the first convolution
CONV2_CHANNELS = 16     # output channels of the second convolution
CONV2_KERNEL_SIZE = 5   # kernel size of the second convolution
HIDDEN_LAYER1_SIZE = 256  # output width of the first fully connected layer
HIDDEN_LAYER2_SIZE = 120  # output width of the second fully connected layer
HIDDEN_LAYER3_SIZE = 84   # output width of the third fully connected layer

# Resnet params
# True: every resnet parameter is handed to the optimizer.
# False: only the replaced final layer (net.fc) is handed to the optimizer.
FINETUNING = True

# Number of Pokémon to classify: all 1st generation Pokémon
OUTPUT_LAYER_SIZE = 151

OPTIM = 'SGD' # 'Adam' selects Adam; any other value selects SGD with momentum
LR = 0.001      # learning rate, used by both optimizers
MOMENTUM = 0.9  # momentum, used by SGD only

MINIBATCH_SIZE = 32  # batch size of every DataLoader
EPOCHS = 8           # number of training epochs passed to train_model

In [33]:
# Simple CNN based on MNIST example of the PyTorch documentation
# Input images are of size 3x224x224
class Net(nn.Module):
    """Small CNN: two conv + ReLU + max-pool stages followed by four linear layers.

    Layer widths and kernel sizes come from the module-level configuration
    constants. The size of the flattened feature map fed to the first linear
    layer is derived from those constants and from ``INPUT_SIZE`` instead of
    being hard-coded, so changing a constant keeps the layers consistent.
    With the default configuration it equals 16*53*53.

    ``forward`` takes a batch of images and returns one row of
    ``OUTPUT_LAYER_SIZE`` raw scores (no softmax) per image.
    """

    # Side length of the (square) input images the linear head is sized for.
    INPUT_SIZE = 224

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(INPUT_CHANNELS, CONV1_CHANNELS, CONV1_KERNEL_SIZE)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(CONV1_CHANNELS, CONV2_CHANNELS, CONV2_KERNEL_SIZE)

        # Each stage is an unpadded, stride-1 convolution (side shrinks by
        # kernel_size - 1) followed by a 2x2 max-pool (side is halved, floored).
        side = (self.INPUT_SIZE - CONV1_KERNEL_SIZE + 1) // 2
        side = (side - CONV2_KERNEL_SIZE + 1) // 2
        self.flat_features = CONV2_CHANNELS * side * side

        self.fc1 = nn.Linear(self.flat_features, HIDDEN_LAYER1_SIZE)
        self.fc2 = nn.Linear(HIDDEN_LAYER1_SIZE, HIDDEN_LAYER2_SIZE)
        self.fc3 = nn.Linear(HIDDEN_LAYER2_SIZE, HIDDEN_LAYER3_SIZE)
        self.fc4 = nn.Linear(HIDDEN_LAYER3_SIZE, OUTPUT_LAYER_SIZE)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Flatten per sample. Keeping the batch dimension explicit means an
        # input of the wrong spatial size fails loudly in fc1 instead of being
        # silently folded into the batch dimension.
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        return self.fc4(x)


# Run on the first CUDA GPU when one is available, otherwise on the CPU
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")

if NETWORK != 'resnet':
    # Any value other than 'resnet' selects the custom network
    net = Net()
else:
    # Pretrained ResNet-18 whose final fully connected layer is swapped for a
    # fresh one with one output per class of our 151-class problem
    net = models.resnet18(pretrained=True)
    net.fc = nn.Linear(
        in_features=net.fc.in_features,
        out_features=OUTPUT_LAYER_SIZE,
        bias=True,
    )

# Move the parameters to the selected device, then show the architecture
net.to(device)
print(net)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace)
      (conv2): Co

In [34]:
# Colorwise means and std of ImageNet, used to train resnet
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])

# Transforms to apply: resize to 224x224, convert to a tensor, then normalize
# each colour channel with the statistics above
data_transforms = transforms.Compose([
    transforms.Resize([224, 224]),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std)
])

# The dataset should be in ./pokemon-generation-one
# It can be downloaded from
# https://www.kaggle.com/thedagger/pokemon-generation-one
image_directory = "./pokemon-generation-one"


def make_split_loader(samples, shuffle):
    """Build an ImageFolder restricted to `samples` plus a DataLoader over it.

    Args:
        samples: list of (image path, class index) pairs, i.e. a subset of
            ``ImageFolder.samples``.
        shuffle: whether the DataLoader reshuffles the samples on every pass.

    Returns:
        (dataset, loader) tuple.
    """
    dataset = datasets.ImageFolder(image_directory, data_transforms)
    # Restrict the folder dataset to the requested subset. `imgs` and
    # `targets` are kept in sync with `samples` so every view of the
    # dataset describes the same subset.
    dataset.samples = samples
    dataset.imgs = samples
    dataset.targets = [class_index for _, class_index in samples]
    loader = torch.utils.data.DataLoader(dataset,
                                         batch_size=MINIBATCH_SIZE,
                                         shuffle=shuffle,
                                         num_workers=4)
    return dataset, loader


dataset_full = datasets.ImageFolder(image_directory, data_transforms)
loader_full = torch.utils.data.DataLoader(dataset_full,
                                          batch_size=MINIBATCH_SIZE,
                                          shuffle=True,
                                          num_workers=4)

# Split dataset into 3 sets, for train, val and test.
# train_test_split is called with its default split ratio and no explicit
# random_state, so the split is driven by NumPy's global RNG.
samples_train, samples_test = train_test_split(dataset_full.samples)
samples_train, samples_val = train_test_split(samples_train)

# Sanity check: the three subsets together cover the full dataset
assert (len(samples_train) + len(samples_val) + len(samples_test)
        == len(dataset_full.samples)), "split sizes do not add up"

print(f"Training images: {len(samples_train)}")
print(f"Validation images: {len(samples_val)}")
print(f"Test images: {len(samples_test)}")


# Only the training data needs reshuffling between epochs. Accuracy does not
# depend on sample order, so the validation and test loaders keep a fixed
# order and do not draw from the global RNG.
dataset_train, loader_train = make_split_loader(samples_train, shuffle=True)
dataset_val, loader_val = make_split_loader(samples_val, shuffle=False)
dataset_test, loader_test = make_split_loader(samples_test, shuffle=False)

Training images: 5994
Validation images: 1998
Test images: 2665


In [35]:
# Loss function: cross-entropy computed on the raw network outputs
criterion = nn.CrossEntropyLoss()

# Parameters handed to the optimizer: everything, except when a resnet is used
# without fine-tuning, in which case only the replaced final layer is
# optimized. Only the optimizer's parameter list is restricted here; the
# requires_grad flags are left untouched.
params = net.parameters() if (NETWORK != 'resnet' or FINETUNING) else net.fc.parameters()

# 'Adam' selects Adam; every other value of OPTIM selects SGD with momentum
optimizer = (
    optim.Adam(params, lr=LR)
    if OPTIM == 'Adam'
    else optim.SGD(params, lr=LR, momentum=MOMENTUM)
)

In [36]:
# Training function
# Calculates validation accuracy after each epoch
def train_model(model, loader_train, loader_val, optimizer, criterion, n_epochs=5):
    """Train `model` and print its validation accuracy after every epoch.

    Args:
        model: network to optimize; it is called on every input batch.
        loader_train: iterable yielding (inputs, labels) training mini-batches.
        loader_val: iterable of validation mini-batches, forwarded to `evaluate`.
        optimizer: optimizer whose `zero_grad()` and `step()` are called once
            per mini-batch.
        criterion: callable returning the training loss for (outputs, labels).
        n_epochs: number of passes over `loader_train`.

    Returns:
        None. Progress is reported with `print`. The model is left in eval
        mode when the function returns.

    Note:
        Batches are moved to the module-level `device`. The validation pass is
        delegated to `evaluate`, which computes its loss with the module-level
        `criterion` rather than the one passed here; only the accuracy it
        returns is used.
    """
    for epoch in range(n_epochs):  # one pass over the training set per epoch
        model.train()
        running_loss = 0.0

        print(f"======= EPOCH {epoch+1} =======")

        for i, (inputs, labels) in enumerate(loader_train):
            inputs, labels = inputs.to(device), labels.to(device)

            # Standard optimization step: reset gradients, forward, backward, update
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            if i % 10 == 9: # Print every 10 mini-batches
                print(f"Batches {i-8}-{i+1} loss: {running_loss/10:.3f}")
                running_loss = 0.0

        # Validation pass on the model being trained (the `model` argument,
        # not a global), in eval mode and without building autograd graphs.
        model.eval()
        with torch.no_grad():
            _, accuracy = evaluate(model, loader_val)
        print(f"Accuracy: {100*accuracy:.1f}%")

# Evaluation function
def evaluate(model, loader, loss_fn=None, target_device=None):
    """Compute the average loss and the accuracy of `model` over `loader`.

    Args:
        model: callable mapping a batch of inputs to per-class scores of shape
            (batch, n_classes). Its train/eval mode is not changed here; the
            caller is expected to have called `model.eval()`.
        loader: iterable yielding (inputs, labels) mini-batches.
        loss_fn: loss callable returning the *mean* loss of a batch. Defaults
            to the module-level `criterion`.
        target_device: device the batches are moved to. Defaults to the
            module-level `device`.

    Returns:
        (average loss per sample, accuracy in [0, 1]) as Python floats.

    Raises:
        ValueError: if `loader` yields no samples.
    """
    if loss_fn is None:
        loss_fn = criterion
    if target_device is None:
        target_device = device

    total_loss = 0.0
    n_correct = 0
    n_total = 0

    # No gradients are needed for evaluation; disabling autograd avoids
    # building a graph for every batch.
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(target_device), labels.to(target_device)
            outputs = model(inputs)

            batch_size = outputs.size(0)
            # loss_fn returns the mean over the batch: weight it by the batch
            # size so that dividing by n_total yields a true per-sample
            # average, even when the last batch is smaller.
            total_loss += loss_fn(outputs, labels).item() * batch_size

            _, preds = torch.max(outputs, 1)
            n_correct += torch.sum(preds == labels).item()
            n_total += batch_size

    if n_total == 0:
        raise ValueError("evaluate() received a loader without any sample")

    return total_loss / n_total, n_correct / n_total


In [39]:
# Launch training for EPOCHS epochs; the validation accuracy is printed after
# each epoch
train_model(
    model=net,
    loader_train=loader_train,
    loader_val=loader_val,
    optimizer=optimizer,
    criterion=criterion,
    n_epochs=EPOCHS,
)

Batches 1-10 loss: 0.213
Batches 11-20 loss: 0.226
Batches 21-30 loss: 0.222
Batches 31-40 loss: 0.223
Batches 41-50 loss: 0.209
Batches 51-60 loss: 0.238
Batches 61-70 loss: 0.239
Batches 71-80 loss: 0.221


  'to RGBA images')


Batches 81-90 loss: 0.250


  'to RGBA images')


Batches 91-100 loss: 0.202
Batches 101-110 loss: 0.230
Batches 111-120 loss: 0.233
Batches 121-130 loss: 0.229
Batches 131-140 loss: 0.221
Batches 141-150 loss: 0.219


  'to RGBA images')


Batches 151-160 loss: 0.254
Batches 161-170 loss: 0.221
Batches 171-180 loss: 0.246
Accuracy: 79.2%
Batches 1-10 loss: 0.151
Batches 11-20 loss: 0.197


  'to RGBA images')


Batches 21-30 loss: 0.200
Batches 31-40 loss: 0.189
Batches 41-50 loss: 0.184
Batches 51-60 loss: 0.150
Batches 61-70 loss: 0.189
Batches 71-80 loss: 0.176
Batches 81-90 loss: 0.192
Batches 91-100 loss: 0.194
Batches 101-110 loss: 0.195


  'to RGBA images')


Batches 111-120 loss: 0.156
Batches 121-130 loss: 0.180
Batches 131-140 loss: 0.192
Batches 141-150 loss: 0.203
Batches 151-160 loss: 0.199
Batches 161-170 loss: 0.200
Batches 171-180 loss: 0.215
Accuracy: 79.4%
Batches 1-10 loss: 0.158


  'to RGBA images')


Batches 11-20 loss: 0.144
Batches 21-30 loss: 0.163
Batches 31-40 loss: 0.140
Batches 41-50 loss: 0.146
Batches 51-60 loss: 0.165
Batches 61-70 loss: 0.161
Batches 71-80 loss: 0.157
Batches 81-90 loss: 0.135


  'to RGBA images')


Batches 91-100 loss: 0.168
Batches 101-110 loss: 0.156
Batches 111-120 loss: 0.153
Batches 121-130 loss: 0.152
Batches 131-140 loss: 0.170
Batches 141-150 loss: 0.167
Batches 151-160 loss: 0.156
Batches 161-170 loss: 0.150


  'to RGBA images')


Batches 171-180 loss: 0.168
Accuracy: 79.5%
Batches 1-10 loss: 0.147


  'to RGBA images')


Batches 11-20 loss: 0.130
Batches 21-30 loss: 0.132
Batches 31-40 loss: 0.120
Batches 41-50 loss: 0.132
Batches 51-60 loss: 0.134
Batches 61-70 loss: 0.131
Batches 71-80 loss: 0.121
Batches 81-90 loss: 0.134
Batches 91-100 loss: 0.126
Batches 101-110 loss: 0.139
Batches 111-120 loss: 0.159
Batches 121-130 loss: 0.148
Batches 131-140 loss: 0.130
Batches 141-150 loss: 0.132
Batches 151-160 loss: 0.116
Batches 161-170 loss: 0.160
Batches 171-180 loss: 0.121


  'to RGBA images')


Accuracy: 79.5%
Batches 1-10 loss: 0.113
Batches 11-20 loss: 0.112
Batches 21-30 loss: 0.099
Batches 31-40 loss: 0.102
Batches 41-50 loss: 0.112


  'to RGBA images')


Batches 51-60 loss: 0.114
Batches 61-70 loss: 0.108


  'to RGBA images')


Batches 71-80 loss: 0.095
Batches 81-90 loss: 0.126
Batches 91-100 loss: 0.108
Batches 101-110 loss: 0.113
Batches 111-120 loss: 0.111
Batches 121-130 loss: 0.141
Batches 131-140 loss: 0.112
Batches 141-150 loss: 0.113
Batches 151-160 loss: 0.129
Batches 161-170 loss: 0.127


  'to RGBA images')


Batches 171-180 loss: 0.108
Accuracy: 79.5%
Batches 1-10 loss: 0.104
Batches 11-20 loss: 0.081
Batches 21-30 loss: 0.107
Batches 31-40 loss: 0.104
Batches 41-50 loss: 0.101


  'to RGBA images')


Batches 51-60 loss: 0.109
Batches 61-70 loss: 0.096
Batches 71-80 loss: 0.112
Batches 81-90 loss: 0.080


  'to RGBA images')


Batches 91-100 loss: 0.112
Batches 101-110 loss: 0.121


  'to RGBA images')


Batches 111-120 loss: 0.119
Batches 121-130 loss: 0.102
Batches 131-140 loss: 0.114
Batches 141-150 loss: 0.104
Batches 151-160 loss: 0.085
Batches 161-170 loss: 0.099
Batches 171-180 loss: 0.122
Accuracy: 79.7%


  'to RGBA images')
  'to RGBA images')


Batches 1-10 loss: 0.085
Batches 11-20 loss: 0.104
Batches 21-30 loss: 0.088
Batches 31-40 loss: 0.082
Batches 41-50 loss: 0.078
Batches 51-60 loss: 0.090
Batches 61-70 loss: 0.108
Batches 71-80 loss: 0.106


  'to RGBA images')


Batches 81-90 loss: 0.095
Batches 91-100 loss: 0.089
Batches 101-110 loss: 0.077
Batches 111-120 loss: 0.101
Batches 121-130 loss: 0.121
Batches 131-140 loss: 0.090
Batches 141-150 loss: 0.110
Batches 151-160 loss: 0.086
Batches 161-170 loss: 0.093
Batches 171-180 loss: 0.095
Accuracy: 80.7%


  'to RGBA images')


Batches 1-10 loss: 0.077
Batches 11-20 loss: 0.075
Batches 21-30 loss: 0.094
Batches 31-40 loss: 0.082
Batches 41-50 loss: 0.074
Batches 51-60 loss: 0.098
Batches 61-70 loss: 0.079
Batches 71-80 loss: 0.082
Batches 81-90 loss: 0.082
Batches 91-100 loss: 0.075
Batches 101-110 loss: 0.099
Batches 111-120 loss: 0.077
Batches 121-130 loss: 0.084
Batches 131-140 loss: 0.078
Batches 141-150 loss: 0.084
Batches 151-160 loss: 0.078


  'to RGBA images')


Batches 161-170 loss: 0.093
Batches 171-180 loss: 0.106
Accuracy: 79.9%


In [28]:
# Evaluate the model on test set
# NOTE(review): this cell's execution count is lower than the training cell's,
# so the stored output presumably comes from an earlier run - re-run it after
# training to confirm the number.
net.eval()
# Evaluation needs no gradients: skip building autograd graphs
with torch.no_grad():
    loss, accuracy = evaluate(net, loader_test)
print(f"Accuracy: {100*accuracy:.1f}%")

Accuracy: 81.5%
