In [1]:
%load_ext autoreload
%autoreload 2

import os
import sys

# Make the project's `src/` directory importable. The relative path is resolved
# against the kernel's current working directory, so this cell assumes the
# notebook is run from its own folder.
module_path = os.path.abspath('../../../src/')
# Insert at the front (not append) so project packages such as `datasets` and
# `config` take precedence over any same-named packages installed in site-packages.
if module_path not in sys.path:
    sys.path.insert(0, module_path)

In [2]:
from datasets.datasethandler import DatasetHandler
# Shared data loader; the training cell uses it to attach the cached
# 'training' and 'validation' splits to the model via `datasetHandler.load(model, ...)`.
datasetHandler = DatasetHandler()

In [3]:
import torch
import config
import pytorch_lightning as pl
from classification.models.SpectrogramCNN import SpectrogramCNNPLModule
from classification.models.DeepRecursiveCNN import DeepRecursiveCNNPLModule
from classification.models.CRNN import CRNNPLModule

In [4]:
from pytorch_lightning.callbacks import Callback

class SaveCallback(Callback):
    """Save the module whenever its latest validation accuracy is a new best.

    Reads the most recent entry of ``pl_module.val_results_history`` (each entry is
    expected to be a dict with a ``"val_acc"`` key) and calls ``pl_module.save(path)``
    when that accuracy exceeds every value seen so far by this callback instance.

    Args:
        model_name: Filename prefix. Checkpoints are written to
            ``model_name + str(best_val_acc) + "best.p"`` relative to the working directory.
    """

    def __init__(self, model_name):
        super().__init__()
        self.model_name = model_name
        # Best validation accuracy saved so far; None until the first checkpoint.
        self.best_val_acc = None

    def on_epoch_end(self, trainer, pl_module):
        history = pl_module.val_results_history
        if not history:
            # No validation result recorded yet, so there is nothing to compare.
            return
        val_acc = history[-1]["val_acc"]
        # Compare against None explicitly so a stored best of 0.0 is not mistaken
        # for "no checkpoint yet".
        if self.best_val_acc is None or val_acc > self.best_val_acc:
            print("new best val acc", val_acc)
            self.best_val_acc = val_acc
            save_path = self.model_name + str(self.best_val_acc) + "best.p"
            pl_module.save(save_path)
            print("Saved checkpoint at epoch {} at \"{}\"".format(trainer.current_epoch + 1, save_path))


cb = SaveCallback("new_best_")

In [5]:
# Hyperparameters passed to the CRNN Lightning module.
hparams = {
    "batch_size": 32,
    "learning_rate": 0.002,
    "weight_decay": 0.01,
    "lr_decay": 1
}
MAX_EPOCHS = 30
SEED = 0

# Seed torch's RNG so weight initialisation and torch-based shuffling repeat across runs.
# NOTE(review): Python `random` / NumPy RNGs used by the data pipeline (if any) are not seeded here.
torch.manual_seed(SEED)

model = CRNNPLModule(hparams)
model.prepare_data()
datasetHandler.load(model, 'training')
datasetHandler.load(model, 'validation')

# Build a fresh checkpoint callback for every run. Reusing a callback from an earlier
# cell would carry over its `best_val_acc`, so a re-run would not save checkpoints
# until the new model beat the previous run's best accuracy.
cb = SaveCallback("new_best_")

trainer = pl.Trainer(
    max_epochs=MAX_EPOCHS,
    # For TensorBoard logging, add `from pytorch_lightning import loggers` and pass:
    # logger=loggers.TensorBoardLogger(config.LOG_DIR, name=type(model).__name__),
    gpus=1 if torch.cuda.is_available() else None,
    callbacks=[cb],
)

# Trailing semicolon suppresses the return value of `fit` in the cell output.
trainer.fit(model);

Loading cached training data of dataset 0 from /nfs/students/summer-term-2020/project-4/data/dataset1/dataset_48k/
Loading cached validation data of dataset 0 from /nfs/students/summer-term-2020/project-4/data/dataset1/dataset_48k/


GPU available: True, used: True
No environment variable for node rank defined. Set as 0.
CUDA_VISIBLE_DEVICES: [0]
Set SLURM handle signals.

   | Name                          | Type                  | Params
--------------------------------------------------------------------
0  | model                         | CRNN                  | 255 K 
1  | model.spec                    | MelspectrogramStretch | 0     
2  | model.spec.spectrogram        | Spectrogram           | 0     
3  | model.spec.mel_scale          | MelScale              | 0     
4  | model.spec.stft               | Spectrogram           | 0     
5  | model.spec.random_stretch     | RandomTimeStretch     | 0     
6  | model.spec.complex_norm       | ComplexNorm           | 0     
7  | model.spec.norm               | SpecNormalization     | 0     
8  | model.net                     | ModuleDict            | 255 K 
9  | model.net.convs               | Sequential            | 56 K  
10 | model.net.convs.conv2d_0      | Conv

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…

Val-Acc=0.01956135151155898




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Val-Acc=0.7913455838767042
new best val acc 0.7913455838767042
Saved model to "new_best_0.7913455838767042best.p"
Saved checkpoint at epoch 1 at "new_best_0.7913455838767042best.p"


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Val-Acc=0.8411381149970362
new best val acc 0.8411381149970362
Saved model to "new_best_0.8411381149970362best.p"
Saved checkpoint at epoch 2 at "new_best_0.8411381149970362best.p"


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Val-Acc=0.8506224066390041
new best val acc 0.8506224066390041
Saved model to "new_best_0.8506224066390041best.p"
Saved checkpoint at epoch 3 at "new_best_0.8506224066390041best.p"


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Val-Acc=0.8565500889152341
new best val acc 0.8565500889152341
Saved model to "new_best_0.8565500889152341best.p"
Saved checkpoint at epoch 4 at "new_best_0.8565500889152341best.p"


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Val-Acc=0.7368109069353883


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Val-Acc=0.8559573206876111


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Val-Acc=0.8612922347362182
new best val acc 0.8612922347362182
Saved model to "new_best_0.8612922347362182best.p"
Saved checkpoint at epoch 7 at "new_best_0.8612922347362182best.p"


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Val-Acc=0.8583283935981031


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Val-Acc=0.8524007113218731


Detected KeyboardInterrupt, attempting graceful shutdown...





1

In [6]:
# Optional: persist the model in its current state. Best-validation checkpoints are
# already written by SaveCallback during training.
# NOTE(review): this path points at a SpectrogramCNN directory, but the model trained
# above is a CRNNPLModule — update the path before uncommenting.
# model.save("/nfs/students/summer-term-2020/project-4/SAVED_MODELS/SpectrogramCNN/vanilla.p")