IMPORTS

In [None]:
import os
import time
import copy
import numpy as np
from matplotlib import pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import datasets, models, transforms

DATA IMPORT

In [None]:
# Root folder that holds the 'train' and 'val' image sub-directories.
# NOTE(review): absolute, machine-specific path -- adjust before running elsewhere.
data_dir = (
    '/Users/pepijnschouten/Desktop/Python_Scripts/'
    'Python_Scripts_Books/Deep_Learning/Mastering_Pytorch/'
    'Datasets/BEES_ANTS_JPG'
)

# Pre-processing pipelines, one per split: random crop + flip for training,
# fixed resize + centre crop for validation. Both end with the same
# per-channel normalisation.
data_transforms = dict(
    train=transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.490, 0.449, 0.411],
                             [0.231, 0.221, 0.230]),
    ]),
    val=transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.490, 0.449, 0.411],
                             [0.231, 0.221, 0.230]),
    ]),
)

# One ImageFolder dataset per split, read from <data_dir>/<split>.
img_data = {
    split: datasets.ImageFolder(
        root=os.path.join(data_dir, split),
        transform=data_transforms[split],
    )
    for split in ('train', 'val')
}

# Batched loaders over those datasets (both splits are shuffled).
data_loaders = {
    split: torch.utils.data.DataLoader(
        img_data[split],
        batch_size=8,
        shuffle=True,
        num_workers=4,
    )
    for split in ('train', 'val')
}

# Number of images per split; used to average loss and accuracy per epoch.
dataset_sizes = {split: len(split_data) for split, split_data in img_data.items()}

# Class names as discovered by ImageFolder for the training split.
classes = img_data['train'].classes

# Select Apple's MPS backend when it is available, otherwise the CPU ...
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
# ... but the assignment below unconditionally overrides that choice,
# so everything runs on the CPU.
# NOTE(review): confirm whether this override is still intended.
device = torch.device('cpu')

DATA INSPECTION

In [None]:
def image_show(img, text=None):
    """Undo the dataset normalisation on an image tensor and display it.

    Args:
        img: CPU tensor whose axis 0 holds the three colour channels
            (it is converted with ``.numpy()`` and axis 0 is moved last).
        text: optional title for the figure; nothing is drawn when ``None``.

    A new matplotlib figure is opened and shown on every call.
    """
    # Same per-channel statistics that the Normalize transforms use.
    channel_mean = np.array([0.490, 0.449, 0.411])
    channel_sd = np.array([0.231, 0.221, 0.230])

    # Move the channel axis to the end, which is the layout imshow expects.
    pixels = img.numpy().transpose((1, 2, 0))
    # Invert (x - mean) / sd, then clamp into the displayable [0, 1] range.
    pixels = np.clip(channel_sd * pixels + channel_mean, 0, 1)

    plt.figure(dpi=100, tight_layout=True)
    plt.imshow(pixels)
    plt.axis('off')
    if text is not None:
        plt.title(text)
    plt.show()

# Pull a single mini-batch (images and their class indices) from the training loader.
imgs, cls = next(iter(data_loaders['train']))

# Tile the images of that batch into one grid image.
grid = torchvision.utils.make_grid(imgs)

# Display the grid; the title is the list of class names, one per image.
image_show(grid, text=[classes[label] for label in cls])
    

DUAL TRAINING LOOP

In [None]:
def train_pretrained_model(pretrained_model, criterion,
                optimizer, scheduler, num_epochs=10):
    """Fine-tune a model, alternating a training and a validation pass per epoch.

    Relies on the notebook-level ``device``, ``data_loaders`` and
    ``dataset_sizes`` (each of the latter two keyed by 'train' and 'val').

    Args:
        pretrained_model: the ``nn.Module`` to train; it is moved to ``device``.
        criterion: loss function called as ``criterion(outputs, labels)``.
        optimizer: optimizer already bound to the model's parameters.
        scheduler: optional learning-rate scheduler, or ``None``. When given,
            ``scheduler.step()`` is called without arguments once per epoch,
            right after the training pass.
        num_epochs: number of train + validation epochs to run.

    Returns:
        The model with the weights of the epoch that reached the highest
        validation accuracy. If validation accuracy never rises above 0,
        the weights the model had on entry are restored.
    """
    start_time = time.time()

    pretrained_model = pretrained_model.to(device)

    # Snapshot the incoming weights so there is always a state to restore,
    # even when no epoch improves on the initial accuracy of 0.
    best_weights = copy.deepcopy(pretrained_model.state_dict())
    accuracy = 0.

    for epoch in range(num_epochs):
        print(f"Epoch: {epoch+1:3}/{num_epochs}")
        print("="*30)

        for dataset in ['train', 'val']:
            is_training = dataset == 'train'
            # train(True) enables training mode, train(False) is eval mode
            pretrained_model.train(is_training)

            running_loss = 0.
            running_successes = 0

            for inputs, labels in data_loaders[dataset]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                optimizer.zero_grad()

                # gradients are only tracked during the training pass
                with torch.set_grad_enabled(is_training):
                    outputs = pretrained_model(inputs)
                    _, predictions = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    if is_training:
                        loss.backward()
                        optimizer.step()

                # The criterion value is weighted by the batch size so that
                # a differently sized batch contributes proportionally.
                running_loss += loss.item() * inputs.size(0)
                # .item() keeps the counter a plain Python number
                running_successes += torch.sum(predictions == labels).item()

            if is_training and scheduler is not None:
                scheduler.step()

            # Average over the whole split (not just the last batch's loss).
            epoch_loss = running_loss / dataset_sizes[dataset]
            epoch_accuracy = running_successes / dataset_sizes[dataset]

            print(f'{dataset} loss in this epoch: {epoch_loss:.4f}')
            print(f'{dataset} accuracy in this epoch: {100*epoch_accuracy:.2f}%')

            # keep a copy of the weights whenever validation accuracy improves
            if dataset == 'val' and epoch_accuracy > accuracy:
                accuracy = epoch_accuracy
                best_weights = copy.deepcopy(pretrained_model.state_dict())
        print()

    elapsed_time = time.time() - start_time
    print(f"Training complete in {elapsed_time//60:.0f}m {elapsed_time%60:.0f}s")
    print(f"Best val accuracy: {100*accuracy:.2f}%")

    pretrained_model.load_state_dict(best_weights)

    return pretrained_model

START TRAINING

In [None]:
# Load AlexNet with torchvision's default pretrained weights.
pretrained_alexnet = models.alexnet(weights='DEFAULT')

# Print the whole architecture, then the classifier head on its own.
print(pretrained_alexnet)
print(pretrained_alexnet.classifier)

# Swap the last classifier layer for one with an output per dataset class.
pretrained_alexnet.classifier[6] = nn.Linear(in_features=4096,
                                             out_features=len(classes))

# Loss and optimizer; every parameter of the network is passed to SGD.
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(params=pretrained_alexnet.parameters(), lr=0.0001)

In [None]:
# Fine-tune for 10 epochs without a learning-rate scheduler; the returned
# model replaces the previous reference.
pretrained_alexnet = train_pretrained_model(
    pretrained_model=pretrained_alexnet,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=None,
    num_epochs=10,
)

VISUALIZE PREDICTIONS

In [None]:
def visualize_predictions(pretrained_model, max_num_imgs=4):
    """Plot validation images in a two-column grid with predicted and true class.

    Relies on the notebook-level ``data_loaders``, ``device`` and ``classes``.

    Args:
        pretrained_model: model used to classify the validation images. It is
            switched to eval mode for the duration of the call and put back
            into whatever mode it was in before returning.
        max_num_imgs: maximum number of images to draw. Fewer are drawn when
            the validation loader runs out first; nothing happens for values
            <= 0.
    """
    if max_num_imgs <= 0:
        return

    # Re-seeds torch's global RNG -- presumably so the shuffled validation
    # loader yields the same images on every call (TODO confirm intent).
    torch.manual_seed(1)
    was_model_training = pretrained_model.training
    pretrained_model.eval()

    # Same per-channel statistics that the Normalize transforms use; needed
    # to bring the tensors back into a displayable [0, 1] range.
    mean = np.array([0.490, 0.449, 0.411])
    sd = np.array([0.231, 0.221, 0.230])

    # Round up so an odd max_num_imgs still fits into the two-column grid.
    num_rows = (max_num_imgs + 1) // 2
    fig = plt.figure()
    imgs_counter = 0

    try:
        with torch.no_grad():
            for inputs, labels in data_loaders['val']:
                inputs = inputs.to(device)
                labels = labels.to(device)
                ops = pretrained_model(inputs)
                _, predictions = torch.max(ops, 1)

                for j in range(inputs.size(0)):
                    imgs_counter += 1
                    ax = fig.add_subplot(num_rows, 2, imgs_counter)
                    ax.axis('off')
                    ax.set_title(f'pred: {classes[predictions[j]]} || target: {classes[labels[j]]}')

                    # Draw straight onto this subplot: move the channel axis
                    # last and undo the normalisation.
                    img = inputs[j].cpu().numpy().transpose((1, 2, 0))
                    ax.imshow(np.clip(sd * img + mean, 0, 1))

                    if imgs_counter == max_num_imgs:
                        return
    finally:
        # Runs on the early return, on loader exhaustion and on errors alike:
        # restore the model's previous mode and show what has been drawn.
        pretrained_model.train(mode=was_model_training)
        plt.show()
        
# Plot a few validation images with the fine-tuned network's predictions.
visualize_predictions(pretrained_model=pretrained_alexnet)