In [None]:
import torch
from utils.pytorch_dataset import data_loaders
import yaml
import optuna
from torch.utils.tensorboard import SummaryWriter


In [None]:
# Width of the input expected by the first Linear layer built in define_model.
IN_FEATURES = 8
# Width of the final Linear layer built in define_model (one output per class).
CLASSES = 9
# Number of train/validate epochs each Optuna trial runs in train().
EPOCHS = 100

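The code below uses Optuna to tune the hyperparameters of a neural network classifier: the number of hidden layers, the width and dropout rate of each layer, the optimizer, and the learning rate. Each trial reports its validation accuracy after every epoch so that unpromising trials can be pruned early.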

In [None]:

def define_model(trial):
  """Assemble a feed-forward classifier whose architecture is sampled from ``trial``.

  The trial chooses the number of hidden layers (1 to 3) and, for every hidden
  layer, its width (2 to 18 units) and its dropout probability (0.2 to 0.8).
  Each hidden layer is Linear -> ReLU -> Dropout. The network ends with a
  Linear projection to ``CLASSES`` outputs followed by LogSoftmax over dim 1.

  Args:
    trial: Object providing ``suggest_int`` and ``suggest_float`` (an Optuna trial).

  Returns:
    A ``torch.nn.Sequential`` holding the assembled layers.
  """
  depth = trial.suggest_int("n_layers", 1, 3)
  modules = []
  width_in = IN_FEATURES

  for layer_idx in range(depth):
    width_out = trial.suggest_int(f"n_units_l{layer_idx}", 2, 18)
    linear = torch.nn.Linear(width_in, width_out)
    drop_p = trial.suggest_float(f"dropout_l{layer_idx}", 0.2, 0.8)
    modules.extend([linear, torch.nn.ReLU(), torch.nn.Dropout(drop_p)])
    # The next layer consumes what this one produced.
    width_in = width_out

  # Output head: log-probabilities over the classes.
  modules.extend([
    torch.nn.Linear(width_in, CLASSES),
    torch.nn.LogSoftmax(dim=1),
  ])
  return torch.nn.Sequential(*modules)

def train(trial):
  """Optuna objective: train a sampled model and return its validation accuracy.

  Builds a model with ``define_model``, samples an optimizer (Adam, RMSprop or
  SGD) and a log-uniform learning rate in [1e-5, 1e-1], then runs ``EPOCHS``
  rounds of training on ``data_loaders['train']`` followed by evaluation on
  ``data_loaders['val']``. The accuracy of each epoch is reported to the trial
  so that a pruner can stop the trial early.

  Args:
    trial: Optuna trial used to sample hyperparameters and report progress.

  Returns:
    Validation accuracy (fraction of correctly classified samples) after the
    last epoch, or 0.0 if ``EPOCHS`` is 0.

  Raises:
    optuna.exceptions.TrialPruned: if the trial reports that it should be pruned.
  """
  model = define_model(trial)

  optimizer_name = trial.suggest_categorical("optimizer", ['Adam', 'RMSprop', 'SGD'])
  lr = trial.suggest_float('lr', 1e-5, 1e-1, log=True)
  # Look up the optimizer class on torch.optim by its sampled name, then
  # instantiate it with the model parameters and the sampled learning rate.
  optimizer = getattr(torch.optim, optimizer_name)(model.parameters(), lr=lr)

  # Bound up front so the final return is defined even when EPOCHS is 0.
  accuracy = 0.0

  for epoch in range(EPOCHS):
    # Training pass. Batches are used exactly as the loader yields them; no
    # device transfer or subsampling is applied here.
    model.train()
    for data, target in data_loaders['train']:
      optimizer.zero_grad()
      output = model(data)
      loss = torch.nn.functional.nll_loss(output, target)
      loss.backward()
      optimizer.step()

    # Validation pass.
    model.eval()
    correct = 0
    seen = 0
    with torch.no_grad():
      for data, target in data_loaders['val']:
        output = model(data)
        pred = output.argmax(dim=1, keepdim=True)
        correct += pred.eq(target.view_as(pred)).sum().item()
        seen += pred.size(0)

    # Divide by the number of samples actually evaluated rather than by
    # len(loader.dataset): the two differ when the loader draws a subset of the
    # dataset (sampler-based split) or drops the last incomplete batch.
    accuracy = correct / seen

    trial.report(accuracy, epoch)

    if trial.should_prune():
      raise optuna.exceptions.TrialPruned()

  return accuracy

In [None]:
# Search for the hyperparameters that maximise the value returned by train(),
# bounded by 100 trials and a timeout of 600.
study = optuna.create_study(direction='maximize')
study.optimize(train, n_trials=100, timeout=600)

# Split the recorded trials by final state for the summary below.
complete_trials = study.get_trials(states=[optuna.trial.TrialState.COMPLETE], deepcopy=False)
pruned_trials = study.get_trials(states=[optuna.trial.TrialState.PRUNED], deepcopy=False)

print("Study statistics: ")
print("  Number of finished trials: ", len(study.trials))
print("  Number of pruned trials: ", len(pruned_trials))
print("  Number of complete trials: ", len(complete_trials))

# Report the best trial found: its objective value and sampled hyperparameters.
print("Best trial:")
trial = study.best_trial

print("  Value: ", trial.value)

print("  Params: ")
for key, value in trial.params.items():
  print("    {}: {}".format(key, value))