In [1]:
%load_ext autoreload
%autoreload 2

import os
import sys
# Make the project's src/ directory importable (config, datasets, classification, ...).
# Guarded so re-running this cell does not append a duplicate entry.
module_path = os.path.abspath('../../../src/')
if module_path not in sys.path:
    sys.path.append(module_path)

In [2]:
import config
from datasets.datasethandler import DatasetHandler
# Shared loader; `objective` uses it to load the 'training' and 'validation' datasets into each model.
datasetHandler = DatasetHandler()

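# Tutorial: Hyperparameter Search

Tunes the hyperparameters of the `SpectrogramCNN` classifier with Optuna. Unpromising trials are stopped early by a median pruner, and the study is stored in `spectrogram_study.db`.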

In [3]:
import torch
import pytorch_lightning as pl
import optuna
from optuna.integration import PyTorchLightningPruningCallback

from classification.models.M5 import M5PLModule
from classification.models.SpectrogramCNN import SpectrogramCNNPLModule

from classification.trainer.HyperParamSearch import MetricsCallback, save_model

In [4]:
from pytorch_lightning.callbacks import Callback

class SaveBestCallback(Callback):
    """Save the LightningModule whenever its latest validation accuracy improves.

    At the end of every epoch the callback reads
    ``pl_module.val_results_history[-1]["val_acc"]``. If that value beats the best
    one seen so far by this callback instance, it calls ``pl_module.save(path)``.

    Args:
        model_name: Base path passed to ``pl_module.save``.
        add_v_number: If True, the path is
            ``"<model_name>_v<trainer.logger.version>_best.p"``. If False,
            ``model_name`` is used verbatim, with no suffix or extension added.

    Note:
        ``best_val_acc`` lives on the instance. Reusing one instance across several
        ``Trainer`` runs (e.g. Optuna trials) therefore only saves a model when it
        beats the best accuracy of *all* earlier runs.
    """

    def __init__(self, model_name="newest_model", add_v_number=True):
        super().__init__()
        self.model_name = model_name
        # Best validation accuracy seen so far; None until the first comparison.
        self.best_val_acc = None
        self.add_v_number = add_v_number

    def _save_path(self, trainer):
        """Return the path the current best model is saved to."""
        if not self.add_v_number:
            return self.model_name
        return "{}_v{}_best.p".format(self.model_name, trainer.logger.version)

    def on_epoch_end(self, trainer, pl_module):
        """Save ``pl_module`` if its most recent validation accuracy is a new best."""
        history = pl_module.val_results_history
        if not history:
            # No validation result recorded yet: nothing to compare against.
            return
        val_acc = history[-1]["val_acc"]
        # Compare with None explicitly. A previous best of 0.0 is a real value, and
        # `not self.best_val_acc` would wrongly treat it as "no best yet".
        if self.best_val_acc is None or val_acc > self.best_val_acc:
            print("new best val acc", val_acc)
            self.best_val_acc = val_acc
            save_path = self._save_path(trainer)
            pl_module.save(save_path)
            print("Saved checkpoint at epoch {} at \"{}\"".format((trainer.current_epoch + 1), save_path))

# One callback instance is passed to every trial's Trainer (see `objective`), so its best
# accuracy persists across trials: it saves the best model of the whole study. All saves
# go to the same path because no version suffix is added.
cb = SaveBestCallback(model_name="optuna_spectro_2", add_v_number=False)

### Objective

In [5]:
def objective(trial):
    """Optuna objective: train one SpectrogramCNN with hyperparameters sampled from ``trial``.

    Samples batch_size, learning_rate, p_drop, lr_decay and weight_decay. It then builds a
    ``SpectrogramCNNPLModule``, loads the 'training' and 'validation' datasets into it and
    trains for at most 30 epochs. Finally it saves the model and returns the last recorded
    validation accuracy.

    Relies on the notebook globals ``datasetHandler`` (data loading) and ``cb``
    (the ``SaveBestCallback`` shared by all trials).

    Args:
        trial: The Optuna trial supplied by ``study.optimize``.

    Returns:
        The ``"validation_acc"`` of the last epoch. The study is created with
        ``direction="maximize"``, so higher is better.
    """
    metrics_callback = MetricsCallback()

    # Sample the hyperparameters (replaces the earlier hand-rolled random search).
    trial_hparams = {
        "batch_size": trial.suggest_categorical("batch_size", [2, 4, 8, 16, 32, 64]),
        # Log scale, because the range spans five orders of magnitude. This is the same
        # distribution the deprecated suggest_loguniform produced.
        "learning_rate": trial.suggest_float("learning_rate", 1e-6, 1e-1, log=True),
        # NOTE(review): low == high, so p_drop is always 0.0. Confirm this is intended
        # (dropout deliberately disabled) and not a typo for a real range.
        "p_drop": trial.suggest_float("p_drop", 0, 0.),
        "lr_decay": trial.suggest_float("lr_decay", 0.75, 1),
        # NOTE(review): sampled on a linear scale, so about 90% of draws land above 1e-2.
        # log=True is usually preferable, but Optuna rejects a different distribution
        # for a parameter name already recorded in a study, so change it only together
        # with a new study_name.
        "weight_decay": trial.suggest_float("weight_decay", 1e-6, 1e-1),
    }

    model = SpectrogramCNNPLModule(trial_hparams)

    trainer = pl.Trainer(
        logger=pl.loggers.TensorBoardLogger(config.LOG_DIR, name=type(model.model).__name__),
        max_epochs=30,
        gpus=1 if torch.cuda.is_available() else None,
        # metrics_callback collects the logged metrics (the return value is read from it).
        # cb saves the best model seen across all trials.
        callbacks=[metrics_callback, cb],
        # Lets the study's pruner stop unpromising trials early, based on "validation_acc".
        early_stop_callback=PyTorchLightningPruningCallback(trial, monitor="validation_acc"),
    )

    datasetHandler.load(model, 'training')
    datasetHandler.load(model, 'validation')
    trainer.fit(model)

    # NOTE(review): this produces e.g. "saved_models0.p" (no path separator). Confirm
    # whether a "saved_models/" directory was intended.
    model.save("saved_models" + '{}.p'.format(trial.number))

    # Objective value: validation accuracy of the *last* epoch (not the best epoch),
    # which the study maximizes.
    return metrics_callback.metrics[-1]["validation_acc"]

# Median pruning: stop a trial whose intermediate result is worse than the median of
# earlier trials at the same step.
pruner = optuna.pruners.MedianPruner()
# The study is persisted in SQLite. load_if_exists=True lets the notebook be re-run
# (Restart & Run All) and resume the stored study, instead of failing because a study
# named 'spectrogram_study' already exists in the database.
study = optuna.create_study(
    direction="maximize",
    pruner=pruner,
    study_name='spectrogram_study',
    storage='sqlite:///spectrogram_study.db',
    load_if_exists=True,
)

[32m[I 2020-06-24 20:14:03,741][0m A new study created with name: spectrogram_study[0m


### Run Search

In [None]:
# Run the search: stop after 1000 trials or 6 hours, whichever comes first.
study.optimize(objective, n_trials=1000, timeout=6 * 60 * 60)

# Summarize the study.
n_trials_run = len(study.trials)
print("Number of finished trials: {}".format(n_trials_run))
print("Best trial:")
best_trial = study.best_trial
print("  Value: {}".format(best_trial.value))
print("  Params: ")
for param_name, param_value in best_trial.params.items():
    print("    {}: {}".format(param_name, param_value))

GPU available: True, used: True
No environment variable for node rank defined. Set as 0.
CUDA_VISIBLE_DEVICES: [0]


Load: /nfs/students/summer-term-2020/project-4/data/dataset1/dataset_resampled/training.p
Load: /nfs/students/summer-term-2020/project-4/data/dataset1/dataset_resampled/validation.p


Set SLURM handle signals.

   | Name        | Type           | Params
-------------------------------------------
0  | model       | SpectrogramCNN | 338 K 
1  | model.bn0   | BatchNorm2d    | 2     
2  | model.conv1 | Conv2d         | 505   
3  | model.bn1   | BatchNorm2d    | 10    
4  | model.conv2 | Conv2d         | 2 K   
5  | model.bn2   | BatchNorm2d    | 10    
6  | model.conv3 | Conv2d         | 20 K  
7  | model.bn3   | BatchNorm2d    | 20    
8  | model.conv4 | Conv2d         | 60 K  
9  | model.bn4   | BatchNorm2d    | 30    
10 | model.fc1   | Linear         | 255 K 
11 | model.fc2   | Linear         | 102   


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…

Val-Acc=0.03852993479549496




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Val-Acc=0.8435091879075282
new best val acc 0.8435091879075282
Saved model to "optuna_spectro_2"
Saved checkpoint at epoch 1 at "optuna_spectro_2"


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Val-Acc=0.7842323651452282


In [None]:
# TODO: check out v-num 5

# Tomorrow: quickly train a few more times with num_iter.
#trial 10-1: 17 epochs -> 89.3