In [2]:
%load_ext autoreload
%autoreload 2

import os
import sys
module_path = os.path.abspath(os.path.join('../../../src/'))
if module_path not in sys.path:
    sys.path.append(module_path)

In [5]:
import config
from datasets.datasethandler import DatasetHandler
datasetHandler = DatasetHandler()

In [6]:
import torch
import pytorch_lightning as pl
from classification.models.SpectrogramCNN import SpectrogramCNNPLModule
from classification.models.DeepRecursiveCNN import DeepRecursiveCNNPLModule
from classification.models.CRNN import CRNNPLModule

from classification.models.SpectrogramCNN_8K import SpectrogramCNN_8KPLModule

In [5]:
from pytorch_lightning.callbacks import Callback

class SaveCallback(Callback): 
    def __init__(self, model_name):
        super().__init__()
        self.model_name = model_name
        self.best_val_acc = None

    def on_epoch_end(self, trainer, pl_module):
        if not self.best_val_acc or pl_module.val_results_history[-1]["val_acc"] > self.best_val_acc:
            print("new best val acc", pl_module.val_results_history[-1]["val_acc"])
            self.best_val_acc = pl_module.val_results_history[-1]["val_acc"]
            save_path = self.model_name + str(self.best_val_acc) + "best.p"
            pl_module.save(save_path)
            print("Saved checkpoint at epoch {} at \"{}\"".format((trainer.current_epoch + 1), save_path))
            
cb = SaveCallback("new_best_")

In [36]:
hparams = {'batch_size': 16, 
            'learning_rate': 0.0009471138112165006, 
            'p_dropout': 0.3394112556659779, 
            'n_hidden': 711, 
            'lr_decay': 0.7514824092200452, 
            'weight_decay': 0.003018912473366329}

model = SpectrogramCNN_8KPLModule(hparams)
model.prepare_data()
datasetHandler.load(model, 'training', dataset_id=config.DATASET_CONTROL)
datasetHandler.load(model, 'validation', dataset_id=config.DATASET_CONTROL)

trainer = pl.Trainer(
    max_epochs=20,
    #logger= loggers.TensorBoardLogger(config.LOG_DIR, name=type(model)._name_),
    gpus=1 if torch.cuda.is_available() else None,
    callbacks=[cb]
)

trainer.fit(model)

GPU available: True, used: True
No environment variable for node rank defined. Set as 0.
CUDA_VISIBLE_DEVICES: [0]
Set SLURM handle signals.

   | Name           | Type                 | Params
----------------------------------------------------
0  | model          | MelSpectrogramCNN_8K | 1 M   
1  | model.convs    | Sequential           | 41 K  
2  | model.convs.0  | BatchNorm2d          | 2     
3  | model.convs.1  | Conv2d               | 1 K   
4  | model.convs.2  | BatchNorm2d          | 20    
5  | model.convs.3  | PReLU                | 1     
6  | model.convs.4  | MaxPool2d            | 0     
7  | model.convs.5  | Dropout              | 0     
8  | model.convs.6  | Conv2d               | 20 K  
9  | model.convs.7  | BatchNorm2d          | 40    
10 | model.convs.8  | PReLU                | 1     
11 | model.convs.9  | MaxPool2d            | 0     
12 | model.convs.10 | Dropout              | 0     
13 | model.convs.11 | Conv2d               | 20 K  
14 | model.convs.12 | Bat

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…

Val-Acc=0.007113218731475993


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Val-Acc=0.7587433313574392
Train-Acc=0.7316205533596838


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Val-Acc=0.7943094250148192
Train-Acc=0.7818181818181819


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Val-Acc=0.7919383521043272
Train-Acc=0.8067193675889328


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Val-Acc=0.7800829875518672
Train-Acc=0.8132411067193676


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Val-Acc=0.8085358624777712
Train-Acc=0.824505928853755


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Val-Acc=0.7747480735032602
Train-Acc=0.8318181818181818


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Val-Acc=0.8079430942501482
Train-Acc=0.8355731225296442


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Val-Acc=0.8138707765263782
Train-Acc=0.8436758893280633


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Val-Acc=0.8280972139893301
Train-Acc=0.841699604743083
new best val acc 0.8280972139893301
Saved model to "new_best_0.8280972139893301best.p"
Saved checkpoint at epoch 9 at "new_best_0.8280972139893301best.p"


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Val-Acc=0.8197984588026082
Train-Acc=0.8492094861660079


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Val-Acc=0.8286899822169532
Train-Acc=0.8450592885375494
new best val acc 0.8286899822169532
Saved model to "new_best_0.8286899822169532best.p"
Saved checkpoint at epoch 11 at "new_best_0.8286899822169532best.p"


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Val-Acc=0.8328393598103142
Train-Acc=0.8488142292490118
new best val acc 0.8328393598103142
Saved model to "new_best_0.8328393598103142best.p"
Saved checkpoint at epoch 12 at "new_best_0.8328393598103142best.p"


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Val-Acc=0.8192056905749852
Train-Acc=0.8517786561264822


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Val-Acc=0.7599288678126852
Train-Acc=0.8523715415019762


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Val-Acc=0.8310610551274452
Train-Acc=0.8488142292490118


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Val-Acc=0.8441019561351512
Train-Acc=0.8547430830039525
new best val acc 0.8441019561351512
Saved model to "new_best_0.8441019561351512best.p"
Saved checkpoint at epoch 16 at "new_best_0.8441019561351512best.p"


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Val-Acc=0.8245406046235921
Train-Acc=0.8539525691699604


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Val-Acc=0.7925311203319502
Train-Acc=0.8513833992094861


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Val-Acc=0.8292827504445762
Train-Acc=0.8559288537549408


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Val-Acc=0.8227622999407231
Train-Acc=0.857509881422925



1

In [6]:
# model.save("/nfs/students/summer-term-2020/project-4/SAVED_MODELS/SpectrogramCNN/vanilla.p")