In [1]:
# 1. Transfer learning

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copy

torch.manual_seed(1)

def train(fineTune=False):

    mean = np.array([0.5, 0.5, 0.5])
    std = np.array([0.25, 0.25, 0.25])

    #############################################################################################
    momentum = 0
    loss_threshold = 0
    batch_size = 20
    learning_rate = 0.01
    num_epochs = 10
    debug_steps = 10
    #############################################################################################

    data_transforms = {
        'train': transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ]),
        'val': transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ]),
    }

    data_dir = '/kaggle/input/hymenoptera-data/hymenoptera_data/hymenoptera_data'
    image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
                                              data_transforms[x]) 
                      for x in ['train', 'val']}
    dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size,
                                                 shuffle=True, num_workers=0)
                  for x in ['train', 'val']}
    dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
    class_names = image_datasets['train'].classes

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    new_output_features = len(class_names)
    alexnet = models.alexnet(pretrained=True)
    
    if fineTune:
        #freeze the parameters
        for param in alexnet.parameters():
            param.requires_grad = False
    
    alexnet_in_features = alexnet.classifier[6].in_features
    alexnet.classifier[6] = nn.Linear(alexnet_in_features, new_output_features)
    alexnet = alexnet.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(alexnet.parameters(), lr=learning_rate, momentum=momentum)

    train_loader = dataloaders['train']
    test_loader = dataloaders['val']

    t0 = time.time()
    n_total_steps = len(train_loader)

    for epoch in range(num_epochs):
        for i, (images, labels) in enumerate(train_loader):
            # origin shape: [4, 3, 32, 32] = 4, 3, 1024
            images = images.to(device)
            labels = labels.to(device)

            # Forward pass
            outputs = alexnet(images)
            loss = criterion(outputs, labels)

            # Backward and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if (i+1) % debug_steps == 0:
                print (f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{n_total_steps}], Loss: {loss.item():.4f}')

            if loss_threshold != 0 and loss.item() < loss_threshold:
                print(f'\nMin loss threshold reached. Stopping at epoch: {epoch}')
                break

        if loss_threshold != 0 and loss.item() < loss_threshold:
            print(f'\nMin loss threshold reached. Stopping at epoch: {epoch}')
            break

    print('Finished Training')
    PATH = './cnn.pth'
    torch.save(alexnet.state_dict(), PATH)

    # Evaluate transfer learning
    with torch.no_grad():
        n_correct = 0
        n_samples = 0
        n_class_correct = [0 for i in range(new_output_features)]
        n_class_samples = [0 for i in range(new_output_features)]
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = alexnet(images)
            # max returns (value ,index)
            _, predicted = torch.max(outputs, 1)
            n_samples += labels.size(0)
            n_correct += (predicted == labels).sum().item()

            for i in range(batch_size):
                if i < list(labels.shape)[0]:
                    label = labels[i]
                    pred = predicted[i]
                    if (label == pred):
                        n_class_correct[label] += 1
                    n_class_samples[label] += 1

        acc = 100.0 * n_correct / n_samples
        print(f'Accuracy of the network: {acc} %')

        for i in range(new_output_features):
            acc1 = 100.0 * n_class_correct[i] / n_class_samples[i]
            print(f'Accuracy of {class_names[i]}: {acc1} %')

    ttime = int(time.time() - t0)
    return acc, ttime

In [2]:
acc_without_fineTune, ttime_without_fineTune = train(fineTune=False)
print(f'Without fine tunning\nAccuracy: {acc_without_fineTune} and training time: {ttime_without_fineTune}')

Downloading: "https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth" to /root/.cache/torch/hub/checkpoints/alexnet-owt-4df8aa71.pth


  0%|          | 0.00/233M [00:00<?, ?B/s]

Epoch [1/10], Step [10/13], Loss: 0.3559
Epoch [2/10], Step [10/13], Loss: 0.1122
Epoch [3/10], Step [10/13], Loss: 0.2282
Epoch [4/10], Step [10/13], Loss: 0.4148
Epoch [5/10], Step [10/13], Loss: 0.1594
Epoch [6/10], Step [10/13], Loss: 0.3260
Epoch [7/10], Step [10/13], Loss: 0.1921
Epoch [8/10], Step [10/13], Loss: 0.1510
Epoch [9/10], Step [10/13], Loss: 0.1581
Epoch [10/10], Step [10/13], Loss: 0.0738
Finished Training
Accuracy of the network: 85.62091503267973 %
Accuracy of ants: 92.85714285714286 %
Accuracy of bees: 79.51807228915662 %
Without fine tunning
Accuracy: 85.62091503267973 and training time: 23


In [3]:
acc_fineTune, ttime_fineTune = train(fineTune=True)
print(f'Fine tunning\nAccuracy: {acc_fineTune} and training time: {ttime_fineTune}')

Epoch [1/10], Step [10/13], Loss: 0.9377
Epoch [2/10], Step [10/13], Loss: 0.1957
Epoch [3/10], Step [10/13], Loss: 0.3523
Epoch [4/10], Step [10/13], Loss: 0.1115
Epoch [5/10], Step [10/13], Loss: 0.0700
Epoch [6/10], Step [10/13], Loss: 0.3822
Epoch [7/10], Step [10/13], Loss: 0.1679
Epoch [8/10], Step [10/13], Loss: 0.0988
Epoch [9/10], Step [10/13], Loss: 0.0990
Epoch [10/10], Step [10/13], Loss: 0.2056
Finished Training
Accuracy of the network: 88.23529411764706 %
Accuracy of ants: 91.42857142857143 %
Accuracy of bees: 85.5421686746988 %
Fine tunning
Accuracy: 88.23529411764706 and training time: 19
