In [8]:
from typing import Tuple, Union

import optuna
from optuna.trial import TrialState
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data
from torchvision import datasets
from torchvision import transforms

In [9]:
def try_gpu(x: Union[nn.Module, torch.Tensor]) -> Union[nn.Module, torch.Tensor]:
    if torch.cuda.is_available():
        return x.cuda()
    else:
        return x

In [18]:
class NNModel(nn.Module):
    def __init__(self, trial: optuna.trial.Trial, n_input: int=28*28, n_output: int=10):
        super(NNModel, self).__init__()

        n_layers = trial.suggest_int("n_layers", 1, 3)
        layers = []
        for i in range(n_layers):
            n_hidden = trial.suggest_int(f"n_units_l{i}", 4, 128)
            p = trial.suggest_float(f"dropout_l{i}", 0.2, 0.5)
            layers.append(nn.Linear(n_input, n_hidden))
            layers.append(nn.ReLU())
            layers.append(nn.Dropout(p))
            n_input = n_hidden
        layers.append(nn.Linear(n_input, n_output))
        
        self.model = nn.Sequential(*layers)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x)

In [19]:
def load_mnist(mnist_path: str, batch_size: int=128) -> Tuple[torch.utils.data.DataLoader, torch.utils.data.DataLoader]:
    train_loader = torch.utils.data.DataLoader(
        datasets.FashionMNIST(mnist_path, train=True, download=True, transform=transforms.ToTensor()),
        batch_size=batch_size,
        shuffle=True
    )
    valid_loader = torch.utils.data.DataLoader(
        datasets.FashionMNIST(mnist_path, train=False, transform=transforms.ToTensor()),
        batch_size=batch_size,
        shuffle=True
    )

    return train_loader, valid_loader

In [20]:
def objective(trial: optuna.trial.Trial, epochs: int=10, batch_size: int=128):
    model = try_gpu(NNModel(trial))

    optimizer_name = trial.suggest_categorical("optimizer", ["Adam", "RMSprop", "SGD"])
    lr = trial.suggest_float("lr", 1e-5, 1e-1, log=True)
    optimizer = getattr(optim, optimizer_name)(model.parameters(), lr=lr)

    train_loader, valid_loader = load_mnist("/mnt/d/dataset/mnist", batch_size=batch_size)
    for epoch in range(epochs):
        model.train()
        for batch_index, (data, target) in enumerate(train_loader):
            if batch_index >= 30:
                break

            data = try_gpu(data.view(data.size(0), -1))
            target = try_gpu(target)

            optimizer.zero_grad()
            y_pred = model(data)
            loss = F.nll_loss(y_pred, target)
            loss.backward()
            optimizer.step()

        model.eval()
        correct = 0
        with torch.no_grad():
            for batch_index, (data, target) in enumerate(valid_loader):
                if batch_index >= 10:
                    break

                data = try_gpu(data.view(data.size(0), -1))
                target = try_gpu(target)

                y_pred = model(data)
                y_label = y_pred.argmax(dim=1, keepdim=True)
                correct += y_label.eq(target.view_as(y_label)).sum().item()
        
        accuracy = correct / min(len(valid_loader.dataset), batch_size * 10)
        trial.report(accuracy, epoch)

        if trial.should_prune():
            raise optuna.exceptions.TrialPruned()
    
    return accuracy

In [21]:
study = optuna.create_study(direction="maximize")
study.optimize(objective, n_trials=100, timeout=600)

[32m[I 2022-01-30 19:03:15,864][0m A new study created in memory with name: no-name-7071fb18-3d70-4e2b-8551-2ec300099668[0m


Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to /mnt/d/dataset/mnist/FashionMNIST/raw/train-images-idx3-ubyte.gz


|          | 0/? [00:00<?, ?it/s]

Extracting /mnt/d/dataset/mnist/FashionMNIST/raw/train-images-idx3-ubyte.gz to /mnt/d/dataset/mnist/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to /mnt/d/dataset/mnist/FashionMNIST/raw/train-labels-idx1-ubyte.gz


|          | 0/? [00:00<?, ?it/s]

Extracting /mnt/d/dataset/mnist/FashionMNIST/raw/train-labels-idx1-ubyte.gz to /mnt/d/dataset/mnist/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to /mnt/d/dataset/mnist/FashionMNIST/raw/t10k-images-idx3-ubyte.gz


|          | 0/? [00:00<?, ?it/s]

Extracting /mnt/d/dataset/mnist/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to /mnt/d/dataset/mnist/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to /mnt/d/dataset/mnist/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz


|          | 0/? [00:00<?, ?it/s]

Extracting /mnt/d/dataset/mnist/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to /mnt/d/dataset/mnist/FashionMNIST/raw
Processing...


  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


Done!


[32m[I 2022-01-30 19:03:46,570][0m Trial 0 finished with value: 0.09140625 and parameters: {'n_layers': 1, 'n_units_l0': 100, 'dropout_l0': 0.3398319065112406, 'optimizer': 'Adam', 'lr': 0.0006419330992797266}. Best is trial 0 with value: 0.09140625.[0m
[32m[I 2022-01-30 19:03:52,786][0m Trial 1 finished with value: 0.14296875 and parameters: {'n_layers': 2, 'n_units_l0': 44, 'dropout_l0': 0.3418626212899899, 'n_units_l1': 59, 'dropout_l1': 0.2838261521998486, 'optimizer': 'SGD', 'lr': 0.00018626623865679637}. Best is trial 1 with value: 0.14296875.[0m
[32m[I 2022-01-30 19:03:58,613][0m Trial 2 finished with value: 0.1171875 and parameters: {'n_layers': 2, 'n_units_l0': 127, 'dropout_l0': 0.2791741961606878, 'n_units_l1': 28, 'dropout_l1': 0.3497017831376982, 'optimizer': 'SGD', 'lr': 0.008033568928213733}. Best is trial 1 with value: 0.14296875.[0m
[32m[I 2022-01-30 19:04:04,592][0m Trial 3 finished with value: 0.09453125 and parameters: {'n_layers': 3, 'n_units_l0': 71, 'd

In [22]:
prune_trials = study.get_trials(deepcopy=False, states=[TrialState.PRUNED])
complete_trials = study.get_trials(deepcopy=False, states=[TrialState.COMPLETE])

In [23]:
print(f"number of finished trials: {len(study.trials)}")
print(f"number of pruned trials  : {len(prune_trials)}")
print(f"number of complete trials: {len(complete_trials)}")

number of finished trials: 100
number of pruned trials  : 78
number of complete trials: 22


In [24]:
print("best trial:")
trial = study.best_trial
print(f"  value: {trial.value}")
print("  param:")
for key, value in trial.params.items():
    print(f"     {key}: {value}")

best trial:
  value: 0.14296875
  param:
     n_layers: 2
     n_units_l0: 44
     dropout_l0: 0.3418626212899899
     n_units_l1: 59
     dropout_l1: 0.2838261521998486
     optimizer: SGD
     lr: 0.00018626623865679637
