In [11]:
%load_ext autoreload
%autoreload 2

import os
import sys
# Make the project's packages under <repo>/src importable (presumably the
# `datasets`, `config` and `classification` packages imported below). The path
# is resolved against the kernel's working directory, i.e. this notebook's
# folder when it is opened normally.
# Insert at the FRONT of sys.path instead of appending: `datasets` and `config`
# are also names of third-party packages, and an installed copy would otherwise
# be found first and silently shadow the project module of the same name.
module_path = os.path.abspath('../../../src/')
if module_path not in sys.path:
    sys.path.insert(0, module_path)

In [13]:
from datasets.datasethandler import DatasetHandler
# Shared dataset handler; the training cell passes the model to
# datasetHandler.load(model, split) for the 'training' and 'validation' splits.
datasetHandler = DatasetHandler()

In [14]:
import torch
import config
import pytorch_lightning as pl
from classification.models.SpectrogramCNN import SpectrogramCNNPLModule
from classification.models.DeepRecursiveCNN import DeepRecursiveCNNPLModule

In [15]:
from pytorch_lightning.callbacks import Callback

class SaveCallback(Callback):
    """Save the LightningModule whenever validation accuracy reaches a new best.

    At the end of every epoch the most recent entry of
    ``pl_module.val_results_history`` is read; if its ``"val_acc"`` beats every
    value this callback has seen so far, ``pl_module.save(path)`` is called with
    ``path = model_name + str(val_acc) + "best.p"``.

    Args:
        model_name: Prefix for the checkpoint file name.
    """

    def __init__(self, model_name):
        super().__init__()
        self.model_name = model_name
        # Best validation accuracy seen so far; None until the first result.
        self.best_val_acc = None

    def on_epoch_end(self, trainer, pl_module):
        history = pl_module.val_results_history
        if not history:
            # No validation result recorded yet, so there is nothing to compare.
            return
        val_acc = history[-1]["val_acc"]
        # Compare against None explicitly: a legitimate best of 0.0 is falsy and
        # would otherwise be treated as "no best yet" on every later epoch.
        if self.best_val_acc is not None and val_acc <= self.best_val_acc:
            return
        print("new best val acc", val_acc)
        self.best_val_acc = val_acc
        save_path = self.model_name + str(self.best_val_acc) + "best.p"
        pl_module.save(save_path)
        print("Saved checkpoint at epoch {} at \"{}\"".format((trainer.current_epoch + 1), save_path))

cb = SaveCallback("new_best_")

In [None]:
# Hyperparameters passed to DeepRecursiveCNNPLModule.
hparams = {
    "batch_size": 2,
    "learning_rate": 0.001,
    "weight_decay": 0.00,
    "lr_decay": 0.0001
}

model = DeepRecursiveCNNPLModule(hparams)
model.prepare_data()
datasetHandler.load(model, 'training')
datasetHandler.load(model, 'validation')
# NOTE(review): depending on the Lightning version, trainer.fit() may call
# model.prepare_data() again — confirm that does not discard what
# datasetHandler.load() attached above.

# Create the checkpoint callback in this cell rather than reusing one from an
# earlier cell: SaveCallback keeps best_val_acc between epochs, so a stale
# instance would compare this fresh model against a previous run's best and
# skip saving its early checkpoints when the cell is re-run.
cb = SaveCallback("new_best_")

trainer = pl.Trainer(
    max_epochs=30,
    # To enable TensorBoard logging, import `loggers` from pytorch_lightning and use:
    # logger=loggers.TensorBoardLogger(config.LOG_DIR, name=type(model).__name__),
    gpus=1 if torch.cuda.is_available() else None,
    callbacks=[cb],
)

trainer.fit(model)

GPU available: True, used: True
No environment variable for node rank defined. Set as 0.
CUDA_VISIBLE_DEVICES: [0]
Set SLURM handle signals.

   | Name        | Type             | Params
---------------------------------------------
0  | model       | DeepRecursiveCNN | 2 M   
1  | model.bn0   | BatchNorm1d      | 2     
2  | model.conv1 | Conv1d           | 10 K  
3  | model.bn1   | BatchNorm1d      | 256   
4  | model.pool1 | MaxPool1d        | 0     
5  | model.drop1 | Dropout          | 0     
6  | model.conv2 | Conv1d           | 98 K  
7  | model.bn2   | BatchNorm1d      | 512   
8  | model.pool2 | MaxPool1d        | 0     
9  | model.conv3 | Conv1d           | 393 K 
10 | model.bn3   | BatchNorm1d      | 1 K   
11 | model.pool3 | MaxPool1d        | 0     
12 | model.conv4 | Conv1d           | 1 M   
13 | model.bn4   | BatchNorm1d      | 2 K   
14 | model.pool4 | MaxPool1d        | 0     
15 | model.fc1   | Linear           | 102 K 
16 | model.fcN   | Linear           | 1 K   


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…

Val-Acc=0.0011855364552459987


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Val-Acc=0.7658565500889153


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Val-Acc=0.8061647895672792


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Val-Acc=0.7895672791938352


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Val-Acc=0.8156490812092472
new best val acc 0.8156490812092472
Saved model to "new_best_3best.p"
Saved checkpoint at epoch 4 at "new_best_3best.p"


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Val-Acc=0.8292827504445762
new best val acc 0.8292827504445762
Saved model to "new_best_3best.p"
Saved checkpoint at epoch 5 at "new_best_3best.p"


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Val-Acc=0.8209839952578541


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Val-Acc=0.8280972139893301


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Val-Acc=0.8352104327208062
new best val acc 0.8352104327208062
Saved model to "new_best_3best.p"
Saved checkpoint at epoch 8 at "new_best_3best.p"


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Val-Acc=0.7895672791938352


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Val-Acc=0.8346176644931832


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Val-Acc=0.7190278601066983


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Val-Acc=0.7664493183165383


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Val-Acc=0.7676348547717843


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Val-Acc=0.8097213989330172


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Val-Acc=0.7457024303497333


In [None]:
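# NOTE(review): the model trained above is DeepRecursiveCNN, but this path points into the SpectrogramCNN folder — confirm the target before uncommenting.
# model.save("/nfs/students/summer-term-2020/project-4/SAVED_MODELS/SpectrogramCNN/vanilla.p")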