# Hyperparameter Tuning with `optuna`

Run this notebook with `env_optuna`.

## Basic training script

Classify images from the CIFAR10 dataset. We look at the effect of batch size, learning rate, and momentum.

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torchvision
import torchvision.transforms as transforms

In [2]:
DATA_PATH = "/home/tzhao/Workspace/hyperparameter-tuning-example/data"
NUM_EPOCHS = 4
DEVICE = "cpu"

In [3]:
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1) # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

def train(batch_size, lr, momentum, verbose=True, download=True):
    # Load dataset
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    trainset = torchvision.datasets.CIFAR10(root=DATA_PATH, train=True, download=download, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)

    testset = torchvision.datasets.CIFAR10(root=DATA_PATH, train=False, download=download, transform=transform)
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False)

    # Load nn
    net = Net().to(DEVICE)

    # Load loss and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=lr, momentum=momentum)
    
    # Train
    for epoch in range(NUM_EPOCHS):
        
        # Train single epoch
        for data in trainloader:
            inputs, labels = data[0].to(DEVICE), data[1].to(DEVICE)
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
        
        # Test
        test_loss = 0
        total = 0
        correct = 0
        with torch.no_grad():
            for data in testloader:
                inputs, labels = data[0].to(DEVICE), data[1].to(DEVICE)
                outputs = net(inputs)

                # Get test loss
                loss = criterion(outputs, labels)
                test_loss += loss.item()

                # Get test acc
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        test_acc = correct / total

        if verbose:
            print(f"Epoch {epoch}: Test loss: {test_loss:.2f}")
            print(f"Epoch {epoch}: Test acc: {test_acc:.2f}")

    if verbose:
        print("Training completed.")
        print(f"Final test loss: {test_loss:.2f}")
        print(f"Final test acc: {test_acc:.2f}")
    
    return test_acc

In [4]:
train(batch_size=4, lr=0.001, momentum=0.9, download=False)

Epoch 0: Test loss: 3650.72
Epoch 0: Test acc: 0.47
Epoch 1: Test loss: 3196.43
Epoch 1: Test acc: 0.54
Epoch 2: Test loss: 2992.61
Epoch 2: Test acc: 0.57
Epoch 3: Test loss: 3002.47
Epoch 3: Test acc: 0.58
Training completed.
Final test loss: 3002.47
Final test acc: 0.58


0.5778

In [5]:
train(batch_size=128, lr=0.001, momentum=0.9, download=False)

Epoch 0: Test loss: 181.50
Epoch 0: Test acc: 0.11
Epoch 1: Test loss: 179.16
Epoch 1: Test acc: 0.16
Epoch 2: Test loss: 164.21
Epoch 2: Test acc: 0.24
Epoch 3: Test loss: 151.82
Epoch 3: Test acc: 0.30
Training completed.
Final test loss: 151.82
Final test acc: 0.30


0.3039

## Hyperparameter tuning with `optuna`

In [6]:
import optuna
from optuna.trial import TrialState

In [7]:
def train_optuna(trial):

    #########################################
    ##### OPTUNA: Initialize parameters #####
    #########################################

    batch_size=trial.suggest_int("batch_size", 4, 128)
    lr=trial.suggest_float("lr", 1e-5, 1e-1, log=True) # uniform probability in log space
    momentum=trial.suggest_float("momentum", 0.1, 0.99)
    verbose = False
    download = False

    #########################################
    #########################################
    #########################################

    # Load dataset
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    trainset = torchvision.datasets.CIFAR10(root=DATA_PATH, train=True, download=download, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)

    testset = torchvision.datasets.CIFAR10(root=DATA_PATH, train=False, download=download, transform=transform)
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False)

    # Load nn
    net = Net().to(DEVICE)

    # Load loss and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=lr, momentum=momentum)
    
    # Train
    for epoch in range(NUM_EPOCHS):
        
        # Train single epoch
        for data in trainloader:
            inputs, labels = data[0].to(DEVICE), data[1].to(DEVICE)
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
        
        # Test
        test_loss = 0
        total = 0
        correct = 0
        with torch.no_grad():
            for data in testloader:
                inputs, labels = data[0].to(DEVICE), data[1].to(DEVICE)
                outputs = net(inputs)

                # Get test loss
                loss = criterion(outputs, labels)
                test_loss += loss.item()

                # Get test acc
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        test_acc = correct / total

        if verbose:
            print(f"Epoch {epoch}: Test loss: {test_loss:.2f}")
            print(f"Epoch {epoch}: Test acc: {test_acc:.2f}")

        #######################################################
        ##### OPTUNA: Report test loss and handle pruning #####
        #######################################################

        trial.report(test_acc, epoch)

        # Handle pruning based on the intermediate value.
        if trial.should_prune():
            raise optuna.TrialPruned()

        #######################################################
        #######################################################
        #######################################################

    if verbose:
        print("Training completed.")
        print(f"Final test loss: {test_loss:.2f}")
        print(f"Final test acc: {test_acc:.2f}")
    
    return test_acc

In [8]:
# Maximize accuracy, given some constraints
study = optuna.create_study(direction="maximize")
study.optimize(
    train_optuna,
    n_trials=None, # No limit to number of trials
    timeout=300, # Time limit of 5 minutes
)

[I 2023-07-21 07:43:59,186] A new study created in memory with name: no-name-6db2cdca-c65e-4fec-83eb-fe9ed2ebf217
[I 2023-07-21 07:44:34,193] Trial 0 finished with value: 0.1 and parameters: {'batch_size': 62, 'lr': 0.00010171654322735485, 'momentum': 0.3210127210065065}. Best is trial 0 with value: 0.1.
[I 2023-07-21 07:45:10,864] Trial 1 finished with value: 0.5849 and parameters: {'batch_size': 41, 'lr': 0.0027335863991340766, 'momentum': 0.9534029792195253}. Best is trial 1 with value: 0.5849.
[I 2023-07-21 07:45:50,836] Trial 2 finished with value: 0.1 and parameters: {'batch_size': 28, 'lr': 8.75442190148393e-05, 'momentum': 0.20748796316715568}. Best is trial 1 with value: 0.5849.
[I 2023-07-21 07:46:26,054] Trial 3 finished with value: 0.1 and parameters: {'batch_size': 55, 'lr': 1.3806462702311849e-05, 'momentum': 0.6777786611322322}. Best is trial 1 with value: 0.5849.
[I 2023-07-21 07:47:02,180] Trial 4 finished with value: 0.1 and parameters: {'batch_size': 46, 'lr': 5.2266

In [9]:
# Print some statistics about trials
pruned_trials = study.get_trials(deepcopy=False, states=[TrialState.PRUNED])
complete_trials = study.get_trials(deepcopy=False, states=[TrialState.COMPLETE])

print("Study statistics: ")
print("Number of finished trials: ", len(study.trials))
print("Number of pruned trials: ", len(pruned_trials))
print("Number of complete trials: ", len(complete_trials))

Study statistics: 
Number of finished trials:  8
Number of pruned trials:  0
Number of complete trials:  8


In [10]:
# Print the best trial
trial = study.best_trial
print("Best trial:")
print("Value:", study.best_trial.value)
print("Params:")
for key, value in study.best_trial.params.items():
    print("    {}: {}".format(key, value))

Best trial:
Value: 0.5849
Params:
    batch_size: 41
    lr: 0.0027335863991340766
    momentum: 0.9534029792195253
