# Finetuning AlexNet #

## Imports ##


In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torch.backends.cudnn as cudnn
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
from PIL import Image

cudnn.benchmark = True # enable the inbuilt cudnn auto-tuner to find the best algorithm to use, works best when the input size doesn't change
plt.ion()   # interactive mode

  from .autonotebook import tqdm as notebook_tqdm


<matplotlib.pyplot._IonContext at 0x7f178811d0d0>

In [18]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

  return torch._C._cuda_getDeviceCount() > 0


In [None]:
def evaluate_model(model, criterion):
    # Training phase
    model.eval()
    
    total_loss = 0.0
    total_correct_preds = 0

    with torch.no_grad():  
        for inputs, labels in test_dataloader:
            inputs = inputs.to(device)
            labels = labels.to(device)
    
            # get output of model
            outputs = model(inputs)
            # the prediction is the class with the highest probability
            _, preds = torch.max(outputs, 1)
            # calculate the loss
            loss = criterion(outputs, labels)
    
            # stats
            total_loss += loss.item() * inputs.size(0)
            total_correct_preds += torch.sum(preds == labels.data)
            
    test_loss = total_loss / len(test_dataloader.dataset)
    test_acc = total_correct_preds.double() / len(test_dataloader.dataset)

    print(f'Test Loss: {test_loss:.4f} Accuracy: {test_acc:.4f}')

    return test_loss, test_acc


def train_model(model, criterion, optimizer, scheduler, epochs=25, save_path='results/', save_name='best_model_params.pt', eval_interval=1):
    start_time = time.time()
    
    if not os.path.isdir(save_path):
        os.mkdir(save_path) 
    model_params_path = os.path.join(save_path, save_name)

    torch.save(model.state_dict(), model_params_path)
    best_acc = 0.0

    for epoch in range(1, epochs + 1):
        print(f'Epoch {epoch} / {epochs}')
        print('*' * 20)

        epoch_start = time.time()
        # Training phase
        model.train()
        
        total_loss = 0.0
        total_correct_preds = 0

        for inputs, labels in train_dataloader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            # zero the parameter gradients
            optimizer.zero_grad()

            # get output of model
            outputs = model(inputs)
            # the prediction is the class with the highest probability
            _, preds = torch.max(outputs, 1)
            # calculate the loss
            loss = criterion(outputs, labels)

            # backward + optimize the weights
            loss.backward()
            optimizer.step()

            # stats
            total_loss += loss.item() * inputs.size(0)
            total_correct_preds += torch.sum(preds == labels.data)
            
            scheduler.step()


        # Epoch stats
        epoch_loss = total_loss / len(train_dataloader.dataset)
        epoch_acc = total_correct_preds.double() / len(train_dataloader.dataset)

        print(f'Training Loss: {epoch_loss:.4f} Accuracy: {epoch_acc:.4f}')
        

        if epoch % eval_interval == 0:
            test_loss, test_acc = evaluate_model(model, criterion)

            # save the model
            if test_acc > best_acc:
                best_acc = test_acc
                torch.save(model.state_dict(), model_params_path)

        print()
        time_elapsed = time.time() - epoch_start
        print(f'Epoch complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
        print()
    
    time_elapsed = time.time() - start_time
    print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
    print(f'Best evaluation Accuracy: {best_acc:4f}')

    # load best model weights
    model.load_state_dict(torch.load(best_model_params_path))
    return model

In [None]:
models.list_models()

In [10]:
model = models.alexnet(weights='DEFAULT') # equivalent to ``models.alexnet(weights='IMAGENET1K_V1')``
print(model.children)

<bound method Module.children of AlexNet(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(6, 6))
  (classifier): Sequential(
    (0): Dropout(p=0.5, inplace=False)
    (1): Linear(in_features=9216

In [19]:
# We take the last layer in the network, which is the output layer, in this case, is the 6-th layer inside ``classifier``
print(model.classifier[6])
in_features = model.classifier[6].in_features
print(in_features)

# We set the output size to 2 since we have 2 classes
model.classifier[6] = nn.Linear(in_features, 2)

Linear(in_features=4096, out_features=2, bias=True)
4096


In [None]:
model = model.to(device)

criterion = nn.CrossEntropyLoss()

# All parameters are being optimized
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# Decay LR by a factor of 0.1 every 7 epochs
exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

In [None]:
model_conv = train_model(model, criterion, optimizer, exp_lr_scheduler, num_epochs=25, 
                         save_path='results/', save_name='best_model_params.pt', eval_interval=1)