# Import packages

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import os, os.path
import random
from skimage import io, transform
import torchvision.models as models
from PIL import Image
import time
import copy
import matplotlib.pyplot as plt
%matplotlib inline

# Define classes and functions

In [2]:
# Labels used throughout: 0 means cat, 1 means dog.

# Image count per breed, 12 breeds per species; entry b-1 belongs to breed b.
NUM_CATS = [99, 98, 100, 100, 100, 92, 100, 100, 99, 100, 100, 100]
NUM_DOGS = [100] * 12

In [3]:

# Dataset class that streamlines training, validation and testing with
# 3-fold cross validation.
# The data is split into three parts A, B and C:
# A&B -> C, then repeat for A&C -> B and B&C -> A (according to an email from Bob).
# The two non-test parts are pooled, shuffled, and split into a training set
# (the first 1200 images) and a validation set (the remainder).

class CatDogDataset(Dataset):
    """Cat/dog image dataset exposing one train/val/test subset of a 3-fold split.

    The split is rebuilt deterministically (fixed seeds) on every construction,
    so datasets created with different ``set_name`` values are consistent with
    each other.

    Args:
        root_dir: directory that contains the ``CATS/`` and ``DOGS/`` folders
            (usually ``'../catdog/'``).
        set_name: ``'<part>_<role>'`` where ``<part>`` (``A``, ``B`` or ``C``)
            is the part held out as test set and ``<role>`` is ``train``,
            ``val`` or ``test``, e.g. ``'C_train'``.
        transform: optional callable applied to the PIL image in
            ``__getitem__``.

    Raises:
        ValueError: if ``set_name`` is not one of the nine supported names.
    """

    # images sampled per breed for parts A and B together, and for part A alone
    _AB_PER_BREED = 66
    _A_PER_BREED = 33
    # number of images of the shuffled train+val pool that form the training set
    _TRAIN_SIZE = 1200
    # parts pooled (in this order) for train/val when the key is the test part
    _TRAINVAL_PARTS = {'A': ('B', 'C'), 'B': ('A', 'C'), 'C': ('A', 'B')}
    _ROLES = ('train', 'val', 'test')

    def __init__(self, root_dir, set_name, transform=None):
        self.root_dir = root_dir
        # set_name identifies which part (A, B or C) is the test set and
        # whether this dataset serves training, validation or testing
        self.set_name = set_name
        self.transform = transform

        # fail fast on an unknown name instead of leaving self.data undefined
        test_part, _, role = str(set_name).partition('_')
        if test_part not in self._TRAINVAL_PARTS or role not in self._ROLES:
            raise ValueError(
                "unknown set_name {!r}; expected '<A|B|C>_<train|val|test>'"
                .format(set_name))

        # a seeded generator keeps the A/B/C split identical in every run
        rng = np.random.RandomState(seed=0)

        # [relative image path, label] entries per part, kept per species so
        # that each part lists all of its cats before all of its dogs
        cats = {'A': [], 'B': [], 'C': []}
        dogs = {'A': [], 'B': [], 'C': []}

        # sample per breed so that A, B and C share the same breed distribution
        for breed in range(1, 13):
            # NOTE: for each breed the cats are sampled before the dogs; the
            # order of rng calls determines the split and must not change
            for entries, prefix, label, counts in (
                    (cats, 'CATS/cat_', 0, NUM_CATS),
                    (dogs, 'DOGS/dog_', 1, NUM_DOGS)):
                split = self._split_breed(rng, counts[breed - 1])
                for part, indices in split.items():
                    entries[part] += [
                        [prefix + str(breed) + '_' + str(i) + '.png', label]
                        for i in indices
                    ]

        # combine cats and dogs per part and reshuffle with a fixed seed
        parts = {}
        for part in ('A', 'B', 'C'):
            combined = cats[part] + dogs[part]
            random.Random(0).shuffle(combined)
            parts[part] = combined

        if role == 'test':
            self.data = parts[test_part]
        else:
            first, second = self._TRAINVAL_PARTS[test_part]
            pool = parts[first] + parts[second]
            # same seed for 'train' and 'val' so the two slices never overlap
            random.Random(1).shuffle(pool)
            if role == 'train':
                self.data = pool[:self._TRAIN_SIZE]
            else:
                self.data = pool[self._TRAIN_SIZE:]

    @classmethod
    def _split_breed(cls, rng, num_images):
        """Split the image indices ``0..num_images-1`` of one breed into A, B, C.

        Draws 66 indices for A and B together, then 33 of those for A; B gets
        the remaining sampled indices and C every index that was not sampled.
        """
        a_and_b = rng.choice(a=list(range(num_images)),
                             size=cls._AB_PER_BREED, replace=False)
        part_a = rng.choice(a=a_and_b, size=cls._A_PER_BREED, replace=False)

        # sets give constant-time membership tests; list order is unchanged
        in_a = set(part_a.tolist())
        in_a_or_b = set(a_and_b.tolist())
        part_b = [i for i in a_and_b if i not in in_a]
        part_c = [i for i in range(num_images) if i not in in_a_or_b]
        return {'A': part_a, 'B': part_b, 'C': part_c}

    def __len__(self):
        """Return the number of images in the selected subset."""
        return len(self.data)

    def __getitem__(self, index):
        """Load one image and return ``(image, np.array([label]))``."""
        image_name, target = self.data[index][0], self.data[index][1]
        # os.path.join also works when root_dir has no trailing separator
        img = io.imread(os.path.join(self.root_dir, image_name))

        # there is a grayscale image in the cat images;
        # convert it to a colour image by duplicating it across RGB channels
        if len(img.shape) == 2:
            img = np.repeat(img[:, :, np.newaxis], 3, axis=2)

        img = Image.fromarray(img)

        if self.transform:
            img = self.transform(img)

        return img, np.array([target])

In [4]:
# Data augmentation is limited to translation, rotation, flipping and scale variations

# Per-channel mean and std used for normalisation; the numbers are the
# pre-calculated ones from https://pytorch.org/docs/stable/torchvision/models.html
_CHANNEL_MEAN = (0.485, 0.456, 0.406)
_CHANNEL_STD = (0.229, 0.224, 0.225)

data_transforms = {
    # training: random flips and rotations make the network more robust
    'train': transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.RandomRotation(degrees=(-90, 90)),
        transforms.ToTensor(),
        transforms.Normalize(_CHANNEL_MEAN, _CHANNEL_STD)
    ]),
}

# validation and test use the same pipeline: tensor conversion + normalisation,
# without any random augmentation
data_transforms.update({
    phase: transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(_CHANNEL_MEAN, _CHANNEL_STD)
    ])
    for phase in ('val', 'test')
})


In [5]:
def load_data(setname, batch_size=4, root_dir='../catdog/'):
    """Build the train/val/test data loaders for one cross-validation fold.

    Args:
        setname: the part used as test set ('A', 'B' or 'C'); it is combined
            with each phase name into the dataset's ``set_name``.
        batch_size: batch size for every loader.
        root_dir: root directory of the image data.

    Returns:
        A tuple ``(dataloaders, dataset_sizes)`` of two dicts keyed by
        'train', 'val' and 'test': the DataLoader and the number of images
        of each phase.
    """
    dataloaders = {}
    dataset_sizes = {}

    for phase in ['train', 'val', 'test']:
        dataset = CatDogDataset(root_dir=root_dir,
                                set_name=setname + '_' + phase,
                                transform=data_transforms[phase])
        # every phase is shuffled and loaded with two worker processes
        dataloaders[phase] = DataLoader(dataset,
                                        batch_size=batch_size,
                                        shuffle=True,
                                        num_workers=2)
        dataset_sizes[phase] = len(dataset)

    return dataloaders, dataset_sizes

In [6]:
def train_model(model, criterion, optimizer, model_path, num_epochs=50):
    """Train ``model`` and keep the weights with the best validation accuracy.

    Reads the module-level ``dataloaders``, ``dataset_sizes`` and ``device``.

    Args:
        model: network to train; expected to be on ``device`` already.
        criterion: loss function called as ``criterion(outputs, labels)``.
        optimizer: optimizer stepping the model's trainable parameters.
        model_path: where the best state dict is saved with ``torch.save``.
        num_epochs: number of train+validation epochs.

    Returns:
        Tuple ``(model, best_acc, best_running_corrects, val_size,
        acc_list_train, acc_list_val)``: the model with the best weights
        loaded, the best validation accuracy, the number of correct validation
        predictions in that epoch, ``dataset_sizes['val']``, and the per-epoch
        training / validation accuracies as Python floats.
    """
    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())

    best_acc = 0.0
    best_running_corrects = 0

    acc_list_train = []
    acc_list_val = []

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0

            # Iterate over data.
            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                # the dataset yields labels with a trailing axis of length 1
                labels = labels.squeeze(1)
                # zero the parameter gradients
                optimizer.zero_grad()

                # track gradients only while training
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # loss.item() is multiplied by the batch size so that dividing
                # by the dataset size below weights every batch by its length
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels)

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects.double() / dataset_sizes[phase]

            # store plain floats: the accuracy tensor lives on `device`, and
            # tensors on a GPU cannot be plotted or converted to numpy directly
            if phase == 'train':
                acc_list_train.append(float(epoch_acc))
            elif phase == 'val':
                acc_list_val.append(float(epoch_acc))

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(
                phase, epoch_loss, epoch_acc))

            # remember the weights of the best validation epoch so far
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_running_corrects = running_corrects
                best_model_wts = copy.deepcopy(model.state_dict())

        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))

    # make sure the target directory exists so the finished training run is
    # not lost to a missing folder when saving
    if isinstance(model_path, (str, os.PathLike)):
        save_dir = os.path.dirname(model_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
    torch.save(best_model_wts, model_path)

    model.load_state_dict(best_model_wts)

    return model, best_acc, best_running_corrects, dataset_sizes['val'], acc_list_train, acc_list_val

In [7]:
def test_model(model):
    """Evaluate ``model`` on the test loader and print loss and accuracy.

    Reads the module-level ``dataloaders``, ``dataset_sizes``, ``device`` and
    ``criterion``.

    Returns:
        Tuple ``(accuracy, correct)``: the test accuracy and the number of
        correctly classified test images.
    """
    model.eval()

    total_loss = 0.0
    num_correct = 0

    for batch_inputs, batch_labels in dataloaders['test']:
        batch_inputs = batch_inputs.to(device)
        # drop the trailing length-1 axis of the labels
        batch_labels = batch_labels.to(device).squeeze(1)

        # no gradients are needed for evaluation
        with torch.no_grad():
            logits = model(batch_inputs)
            predicted = torch.max(logits, 1)[1]
            batch_loss = criterion(logits, batch_labels)

        # scale by the batch size so the final division by the dataset size
        # weights every batch by its length
        total_loss += batch_loss.item() * batch_inputs.size(0)
        num_correct += torch.sum(predicted == batch_labels)

    test_size = dataset_sizes['test']
    mean_loss = total_loss / test_size
    accuracy = num_correct.double() / test_size

    print('{} Loss: {:.4f} Acc: {:.4f}'.format(
        'test', mean_loss, accuracy))

    return accuracy, num_correct

# Run the model

In [8]:
# hyper-parameters
learning_rate = 1e-3
weight_decay = 1e-2
batch_size = 16
num_epochs = 25

# use gpu if it is available
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


# three-fold validation

# use different set (A, B or C) as test set
testsets = ['A', 'B', 'C']

# per-fold counts, aggregated into overall accuracies after the loop
corrects_val = []
val_sizes = []
corrects_test = []
test_sizes = []

# make sure the output folders exist before any training time is spent
os.makedirs('./models', exist_ok=True)
os.makedirs('./graphs', exist_ok=True)

for testset in testsets:

    print('#' * 20)
    print('Use ' + testset + ' as test set')
    print('#' * 20)

    # train_model/test_model read these two names at module level
    dataloaders, dataset_sizes = load_data(setname=testset, batch_size=batch_size)

    resnet18 = models.resnet18(pretrained=True)

    # 'freezing' the convolutional layers except for the fully connected layer
    for param in resnet18.parameters():
        param.requires_grad = False

    # when assigning a new layer, requires_grad will be set to True automatically
    resnet18.fc = nn.Linear(512, 2)

    criterion = nn.CrossEntropyLoss()
    # the optimizer is handed all parameters; only the new fc layer has
    # requires_grad=True
    optimizer = torch.optim.Adam(resnet18.parameters(),
                                 lr=learning_rate, weight_decay=weight_decay)
    model_path = './models/' + testset + '_astest'

    resnet18 = resnet18.to(device)

    # training
    resnet18, acc_val, correct_val, val_size, acc_list_train, acc_list_val = train_model(
        resnet18, criterion, optimizer,
        model_path=model_path, num_epochs=num_epochs)

    # test
    acc_test, correct_test = test_model(resnet18)

    corrects_val.append(correct_val)
    corrects_test.append(correct_test)
    val_sizes.append(val_size)
    test_sizes.append(dataset_sizes['test'])

    # plot and store the graphs

    # the accuracies may be tensors on `device`; matplotlib needs plain numbers
    acc_list_train = [float(acc) for acc in acc_list_train]
    acc_list_val = [float(acc) for acc in acc_list_val]

    epoch_plot = [e for e in range(1, num_epochs+1)]
    plt.plot(epoch_plot, acc_list_train, label='train')
    plt.plot(epoch_plot, acc_list_val, label='val')
    plt.title('Training and Validation Accuracy')
    plt.legend()
    plt.xticks([x for x in epoch_plot if x%5==0 ])
    plt.xlabel('epoch')
    plt.ylabel('accuracy')
    plt.savefig('./graphs/' + testset + '_astest.png')
    plt.show()


# overall accuracy = total correct predictions / total images over all folds
average_acc_val = sum(corrects_val).double() / sum(val_sizes)
# the three test sets together cover every image once, so the denominator is
# taken from the actual test-set sizes instead of a hard-coded total
average_acc_test = sum(corrects_test).double() / sum(test_sizes)

print('val average acc {:.4f} '.format(average_acc_val))
print('test average acc {:.4f} '.format(average_acc_test))

####################
Use A as test set
####################
####################
Use B as test set
####################
####################
Use C as test set
####################


In [None]:
# Rebuild the network and reload the saved best weights of every fold.
# NOTE: `model` is rebound on each iteration, so after the loop it holds the
# weights of the last entry of `testsets` only.
for testset in testsets:
    # same architecture as during training: resnet18 with a 2-class fc layer
    model = models.resnet18()
    model.fc = nn.Linear(512, 2)

    model_path = './models/' + testset + '_astest'
    # map_location='cpu' lets weights saved from a GPU run load on a machine
    # without CUDA; the model itself is never moved off the CPU here
    model.load_state_dict(torch.load(model_path, map_location='cpu'))

In [None]:
# Print the name and the size of every parameter tensor in the model.
for layer_name, weights in model.named_parameters():
    print(layer_name, weights.size())