In [14]:
%load_ext autoreload
%autoreload 2

import os
import sys
module_path = os.path.abspath(os.path.join('../../../src/'))
if module_path not in sys.path:
    sys.path.append(module_path)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [15]:
import config
from datasets.datasethandler import DatasetHandler
datasetHandler = DatasetHandler()

# Tutorial: Hyperparameter Search

In [None]:
import torch
import pytorch_lightning as pl
import optuna
from optuna.integration import PyTorchLightningPruningCallback

from classification.models.M5 import M5PLModule
from classification.models.SpectrogramCNN import SpectrogramCNNPLModule
from classification.models.DeepRecursiveCNN import DeepRecursiveCNNPLModule
from classification.models.MelSpectrogramCNN_8K import MelSpectrogramCNN_8KPLModule

from classification.models.SpectrogramCNN_8K_Dataset2 import SpectrogramCNN_8K_Dataset2_PLModule
from classification.models.CRNN8k_D2 import CRNN8k_D2_PLModule
from classification.trainer.HyperParamSearch import MetricsCallback, save_model

In [9]:
from pytorch_lightning.callbacks import Callback

class SaveBestCallback(Callback): 
    def __init__(self, model_name = "newest_model", add_v_number = True):
        super().__init__()
        self.model_name = model_name
        self.best_val_acc = None
        self.add_v_number = add_v_number

    def on_epoch_end(self, trainer, pl_module):
        if not self.best_val_acc or pl_module.val_results_history[-1]["val_acc"] > self.best_val_acc:
            print("new best val acc", pl_module.val_results_history[-1]["val_acc"])
            self.best_val_acc = pl_module.val_results_history[-1]["val_acc"]
            #save_path = self.model_name + (( "_v{}".format(trainer.logger.version) +  "_best.p") if self.add_v_number else "")
            save_path = self.model_name + str(pl_module.val_results_history[-1]["val_acc"]) + ".p"
            pl_module.save(save_path, overwrite_if_exists=True)
            print("Saved checkpoint at epoch {} at \"{}\"".format((trainer.current_epoch + 1), save_path))

cb = SaveBestCallback("optuna_crnn8kd2_", add_v_number = False)

In [10]:
class Print(nn.Module):
    def __init__(self):
        super(Print, self).__init__()

    def forward(self, x):
        print(x.shape)
        return x

### Objective

In [18]:
def objective(trial):
    metrics_callback = MetricsCallback()  

    # here we sample the hyper params, similar as in our old random search
    trial_hparams = {"batch_size": 64,
                     "learning_rate": trial.suggest_loguniform("learning_rate", 1e-6, 1e-1),
                     "p_dropout": trial.suggest_float("p_drop", 0, 1),
                     "lr_decay": trial.suggest_float("lr_decay", 0.7, 1),
                     "n_hidden": trial.suggest_int("n_hidden", 10, 1000),
                     "lstm_hidden_size": trial.suggest_int("n_hidden", 10, 1000),
                     "weight_decay": trial.suggest_float("weight_decay", 1e-6, 1e-1)
                    }    

    model = CRNN8k_D2_PLModule(trial_hparams)
    
    # create a trainer
    trainer = pl.Trainer(
        logger=pl.loggers.TensorBoardLogger(config.LOG_DIR, name=type(model.model).__name__),
        max_epochs=15,                                                               
        gpus=1 if torch.cuda.is_available() else None,
        callbacks=[metrics_callback, cb],                                                   # save latest accuracy
        early_stop_callback=PyTorchLightningPruningCallback(trial, monitor="validation_acc"), # early stopping
    )
    
    datasetHandler.load(model, 'training', dataset_id=config.DATASET_CONTROL)
    datasetHandler.load(model, 'validation', dataset_id=config.DATASET_CONTROL)
    trainer.fit(model)

    # save model
    model.save("saved_modelsq/"+'{}.p'.format(trial.number))

    # return validation accuracy from latest model, as that's what we want to minimize by our hyper param search
    return metrics_callback.metrics[-1]["validation_acc"]

pruner = optuna.pruners.MedianPruner()
study = optuna.create_study(direction="maximize", pruner=pruner, study_name='CRNN_8K_D2', storage='sqlite:///CRNN_8K_D2_study.db', load_if_exists=True)

[I 2020-07-14 23:32:02,763] A new study created with name: CRNN_8K_D2


### Run Search

In [None]:
study.optimize(objective, n_trials=2000, timeout=1.5*21600) #9h

GPU available: True, used: True
No environment variable for node rank defined. Set as 0.
CUDA_VISIBLE_DEVICES: [0]


Loading cached training data of dataset 1 from /nfs/students/summer-term-2020/project-4/data/dataset2/dataset_8k/
Loading cached validation data of dataset 1 from /nfs/students/summer-term-2020/project-4/data/dataset2/dataset_8k/


Set SLURM handle signals.

   | Name                      | Type                  | Params
----------------------------------------------------------------
0  | model                     | CRNN8k_D2             | 2 M   
1  | model.spec                | MelspectrogramStretch | 0     
2  | model.spec.spectrogram    | Spectrogram           | 0     
3  | model.spec.mel_scale      | MelScale              | 0     
4  | model.spec.stft           | Spectrogram           | 0     
5  | model.spec.random_stretch | RandomTimeStretch     | 0     
6  | model.spec.complex_norm   | ComplexNorm           | 0     
7  | model.spec.norm           | SpecNormalization     | 0     
8  | model.convs               | Sequential            | 213 K 
9  | model.convs.0             | Conv2d                | 320   
10 | model.convs.1             | BatchNorm2d           | 64    
11 | model.convs.2             | ELU                   | 0     
12 | model.convs.3             | Conv2d                | 9 K   
13 | model.c

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…

Val-Acc=0.03793716656787196




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Val-Acc=0.5103734439834025
Train-Acc=0.7027667984189724
new best val acc 0.5103734439834025




Saved model to "optuna_crnn8kd2_0.5103734439834025.p"
Saved checkpoint at epoch 1 at "optuna_crnn8kd2_0.5103734439834025.p"


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Val-Acc=0.7611144042679312
Train-Acc=0.7618577075098815
new best val acc 0.7611144042679312
Saved model to "optuna_crnn8kd2_0.7611144042679312.p"
Saved checkpoint at epoch 2 at "optuna_crnn8kd2_0.7611144042679312.p"


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Detected KeyboardInterrupt, attempting graceful shutdown...



Saved model to "saved_modelsq/0.p"


[I 2020-07-14 23:32:32,895] Finished trial#0 with value: 0.7611144042679312 with parameters: {'learning_rate': 0.035720295454924424, 'p_drop': 0.7375927486897605, 'lr_decay': 0.9443988985897781, 'n_hidden': 677, 'weight_decay': 0.00255158147362303}. Best is trial#0 with value: 0.7611144042679312.
GPU available: True, used: True
No environment variable for node rank defined. Set as 0.
CUDA_VISIBLE_DEVICES: [0]
Set SLURM handle signals.

   | Name                      | Type                  | Params
----------------------------------------------------------------
0  | model                     | CRNN8k_D2             | 1 M   
1  | model.spec                | MelspectrogramStretch | 0     
2  | model.spec.spectrogram    | Spectrogram           | 0     
3  | model.spec.mel_scale      | MelScale              | 0     
4  | model.spec.stft           | Spectrogram           | 0     
5  | model.spec.random_stretch | RandomTimeStretch     | 0     
6  | model.spec.complex_norm   | ComplexNorm   

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…

Val-Acc=0.03793716656787196


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Val-Acc=0.5103734439834025
Train-Acc=0.6956521739130435


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Val-Acc=0.5103734439834025
Train-Acc=0.7608695652173914


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Val-Acc=0.5080023710729105
Train-Acc=0.7618577075098815


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Val-Acc=0.5097806757557795
Train-Acc=0.7630434782608696


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Val-Acc=0.5091879075281565
Train-Acc=0.766600790513834


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Val-Acc=0.5749851807943094
Train-Acc=0.7660079051383399


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Val-Acc=0.5305275637225845
Train-Acc=0.7654150197628459


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Val-Acc=0.5411973918197984
Train-Acc=0.766798418972332


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Val-Acc=0.7302904564315352
Train-Acc=0.766798418972332


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Val-Acc=0.5785417901600475
Train-Acc=0.7713438735177865


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Val-Acc=0.5275637225844695
Train-Acc=0.766205533596838


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Val-Acc=0.6283343212803794
Train-Acc=0.7709486166007905


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Val-Acc=0.6366330764671013
Train-Acc=0.7691699604743083


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Val-Acc=0.5649081209247184
Train-Acc=0.7782608695652173


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Val-Acc=0.5085951393005335
Train-Acc=0.76699604743083

Saved model to "saved_modelsq/1.p"


In [16]:
#!mkdir saved_modelsq

In [10]:
print("Number of finished trials: {}".format(len(study.trials)))
print("Best trial:")
best_trial = study.best_trial
print("  Value: {}".format(best_trial.value))
print("  Params: ")
for key, value in best_trial.params.items():
    print("    {}: {}".format(key, value))

Number of finished trials: 666
Best trial:
  Value: 0.8974510966212211
  Params: 
    learning_rate: 0.0002514789609647403
    lr_decay: 0.7403277844353311
    n_hidden: 938
    p_drop: 0.11472402894609018
    weight_decay: 0.08362983139422733
