In [1]:
# Re-import edited project modules automatically before each cell executes,
# so changes under src/ are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

import os
import sys
# Put the repository's `src/` directory on the import path so that `config`
# and the `classification` package import without being installed; the
# membership guard keeps re-runs of this cell from adding duplicate entries.
module_path = os.path.abspath(os.path.join(os.pardir, os.pardir, os.pardir, "src"))
if module_path not in sys.path:
    sys.path.append(module_path)

### How to train a model with PyTorch Lightning

* Define your Model Class (in `classification/models`)
    * just as usual, as a subclass of `nn.Module`: define the architecture in the initializer and implement the forward pass in `forward(self, x)`
    * the output of the forward pass must have shape `[BATCH_SIZE, NUM_CLASSES]`
* Also define your subclass of `PLModule`, i.e., a PyTorch Lightning Module
    * the PyTorch Lightning Module combines your model with the DataLoader and Solver specifics
    * PyTorch Lightning then runs the training loop, logs to TensorBoard, saves checkpoints, etc.
    * to create your PyTorch Lightning Module, you can simply inherit from `GeneralPLModule` and set `self.model` to your model
    * to customize further, override methods of `GeneralPLModule`:
        * e.g. `prepare_data(self)`, which sets `self.dataset["train"]` and `self.dataset["val"]`

In [2]:
import pytorch_lightning as pl
import torch
import numpy as np
import config

In [3]:
from classification.models.M5 import M5, M5PLModule

# Hyperparameters for the single manual training run in the next cells
# (the Optuna search further down samples its own set per trial).
hparams = dict(
    batch_size=128,
    learning_rate=3e-5,
    weight_decay=0,
    lr_decay=0.8,
)

# Wrap the M5 network in its Lightning module (project class imported from
# classification.models.M5), configured with the hyperparameters above.
model = M5PLModule(hparams)

In [4]:
from pytorch_lightning import loggers

# Train the model for a fixed number of epochs, logging to TensorBoard under
# config.LOG_DIR with run name "M5".
trainer = None  # drop any trainer left over from a previous run of this cell first
use_gpu = torch.cuda.is_available()
trainer = pl.Trainer(
    max_epochs=5,
    logger=loggers.TensorBoardLogger(config.LOG_DIR, name="M5"),
    gpus=1 if use_gpu else None,
    # GPU memory logging is only meaningful with a GPU present (PL versions of
    # this era shell out to nvidia-smi for it); keep the cell runnable on CPU hosts.
    log_gpu_memory='all' if use_gpu else None,
)

trainer.fit(model)

GPU available: True, used: True
No environment variable for node rank defined. Set as 0.
CUDA_VISIBLE_DEVICES: [0]


Loading cached train data from /nfs/students/summer-term-2020/project-4/data/data_8k
Loading cached val data from /nfs/students/summer-term-2020/project-4/data/data_8k


Set SLURM handle signals.

   | Name           | Type         | Params
--------------------------------------------
0  | model          | M5           | 555 K 
1  | model.model    | Sequential   | 555 K 
2  | model.model.0  | Conv1d       | 10 K  
3  | model.model.1  | BatchNorm1d  | 256   
4  | model.model.2  | MaxPool1d    | 0     
5  | model.model.3  | Dropout      | 0     
6  | model.model.4  | Conv1d       | 49 K  
7  | model.model.5  | BatchNorm1d  | 256   
8  | model.model.6  | MaxPool1d    | 0     
9  | model.model.7  | Dropout      | 0     
10 | model.model.8  | Conv1d       | 98 K  
11 | model.model.9  | BatchNorm1d  | 512   
12 | model.model.10 | MaxPool1d    | 0     
13 | model.model.11 | Dropout      | 0     
14 | model.model.12 | Conv1d       | 393 K 
15 | model.model.13 | BatchNorm1d  | 1 K   
16 | model.model.14 | MaxPool1d    | 0     
17 | model.model.15 | AvgPool1d    | 0     
18 | model.model.16 | PermuteLayer | 0     
19 | model.model.17 | Linear       | 1 K   


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…

Val-Acc=0.07705986959098993




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Val-Acc=0.7996443390634262


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Val-Acc=0.7972732661529343


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Val-Acc=0.8192056905749852


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Val-Acc=0.8221695317131001


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Val-Acc=0.8186129223473622



1

### Hyperparameter Search

In [5]:
import optuna
from optuna.integration import PyTorchLightningPruningCallback
from classification.trainer.HyperParamSearch import MetricsCallback, save_model

### Objective

In [6]:
def objective(trial):
    """Optuna objective: train one M5 model with sampled hyperparameters.

    Builds a fresh PyTorch Lightning trainer for ``trial`` and samples the
    hyperparameters from ``trial``. It then trains an ``M5PLModule`` for a
    fixed number of epochs and hands the trained model to ``save_model`` with
    the file name ``"<trial.number>.p"`` and directory argument
    ``"saved_models"``.

    Args:
        trial: the Optuna trial being evaluated.

    Returns:
        float: the most recent ``val_acc`` recorded by ``MetricsCallback``.
        The study in this notebook is created with ``direction="maximize"``,
        so higher is better.

    Raises:
        RuntimeError: if training finished without any recorded metrics.
    """
    # Project callback collecting logged metrics during training; presumably
    # one entry per validation run (see classification.trainer.HyperParamSearch).
    metrics_callback = MetricsCallback()

    trainer = pl.Trainer(
        logger=None,  # no TensorBoard logging for individual search trials
        max_epochs=2,
        gpus=1 if torch.cuda.is_available() else None,
        callbacks=[metrics_callback],
        # Reports val_acc to Optuna so the study's pruner can stop weak trials early.
        early_stop_callback=PyTorchLightningPruningCallback(trial, monitor="val_acc"),
    )

    trial_hparams = {
        # Single-choice categorical: batch size is effectively fixed at 64.
        "batch_size": trial.suggest_categorical("batch_size", [64]),
        "learning_rate": trial.suggest_loguniform("learning_rate", 1e-6, 1e-1),
        # low == high, so dropout is pinned to 0; widen the range to search it.
        "p_drop": trial.suggest_float("p_drop", 0, 0.),
        "lr_decay": trial.suggest_float("lr_decay", 0.8, 1),
        # NOTE(review): sampled linearly, so ~99% of draws fall above 1e-3;
        # consider log=True if small weight decays should be explored.
        "weight_decay": trial.suggest_float("weight_decay", 1e-6, 1e-1),
    }

    model = M5PLModule(trial_hparams)
    trainer.fit(model)

    save_model(model, '{}.p'.format(trial.number), "saved_models")

    if not metrics_callback.metrics:
        raise RuntimeError(
            "trial {}: no validation metrics were recorded".format(trial.number))
    # float() accepts both Python floats and 0-dim tensors, giving Optuna a plain number.
    return float(metrics_callback.metrics[-1]["val_acc"])

### Search

In [None]:
# The pruning callback inside `objective` reports to this pruner, which can
# stop trials early; the study maximizes the returned validation accuracy.
pruner = optuna.pruners.MedianPruner()
study = optuna.create_study(direction="maximize", pruner=pruner)
# Stop after 100 trials or 21600 s (6 h), whichever comes first.
study.optimize(objective, n_trials=100, timeout=21600)

print(f"Number of finished trials: {len(study.trials)}")

print("Best trial:")
best_trial = study.best_trial

print(f"  Value: {best_trial.value}")

print("  Params: ")
for key, value in best_trial.params.items():
    print(f"    {key}: {value}")

GPU available: True, used: True
No environment variable for node rank defined. Set as 0.
CUDA_VISIBLE_DEVICES: [0]


Loading cached train data from /nfs/students/summer-term-2020/project-4/data/data_8k
Loading cached val data from /nfs/students/summer-term-2020/project-4/data/data_8k


Set SLURM handle signals.

   | Name           | Type         | Params
--------------------------------------------
0  | model          | M5           | 555 K 
1  | model.model    | Sequential   | 555 K 
2  | model.model.0  | Conv1d       | 10 K  
3  | model.model.1  | BatchNorm1d  | 256   
4  | model.model.2  | MaxPool1d    | 0     
5  | model.model.3  | Dropout      | 0     
6  | model.model.4  | Conv1d       | 49 K  
7  | model.model.5  | BatchNorm1d  | 256   
8  | model.model.6  | MaxPool1d    | 0     
9  | model.model.7  | Dropout      | 0     
10 | model.model.8  | Conv1d       | 98 K  
11 | model.model.9  | BatchNorm1d  | 512   
12 | model.model.10 | MaxPool1d    | 0     
13 | model.model.11 | Dropout      | 0     
14 | model.model.12 | Conv1d       | 393 K 
15 | model.model.13 | BatchNorm1d  | 1 K   
16 | model.model.14 | MaxPool1d    | 0     
17 | model.model.15 | AvgPool1d    | 0     
18 | model.model.16 | PermuteLayer | 0     
19 | model.model.17 | Linear       | 1 K   


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…

Val-Acc=0.04090100770598696




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Val-Acc=0.7581505631298162


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

[32m[I 2020-06-09 14:14:22,213][0m Finished trial#0 with value: 0.8037937166567872 with parameters: {'batch_size': 64, 'learning_rate': 0.001438454886802962, 'p_drop': 0, 'lr_decay': 0.9157818891250453, 'weight_decay': 0.03254855118047262}. Best is trial#0 with value: 0.8037937166567872.[0m
GPU available: True, used: True
No environment variable for node rank defined. Set as 0.
CUDA_VISIBLE_DEVICES: [0]


Val-Acc=0.8037937166567872

Loading cached train data from /nfs/students/summer-term-2020/project-4/data/data_8k
Loading cached val data from /nfs/students/summer-term-2020/project-4/data/data_8k


Set SLURM handle signals.

   | Name           | Type         | Params
--------------------------------------------
0  | model          | M5           | 555 K 
1  | model.model    | Sequential   | 555 K 
2  | model.model.0  | Conv1d       | 10 K  
3  | model.model.1  | BatchNorm1d  | 256   
4  | model.model.2  | MaxPool1d    | 0     
5  | model.model.3  | Dropout      | 0     
6  | model.model.4  | Conv1d       | 49 K  
7  | model.model.5  | BatchNorm1d  | 256   
8  | model.model.6  | MaxPool1d    | 0     
9  | model.model.7  | Dropout      | 0     
10 | model.model.8  | Conv1d       | 98 K  
11 | model.model.9  | BatchNorm1d  | 512   
12 | model.model.10 | MaxPool1d    | 0     
13 | model.model.11 | Dropout      | 0     
14 | model.model.12 | Conv1d       | 393 K 
15 | model.model.13 | BatchNorm1d  | 1 K   
16 | model.model.14 | MaxPool1d    | 0     
17 | model.model.15 | AvgPool1d    | 0     
18 | model.model.16 | PermuteLayer | 0     
19 | model.model.17 | Linear       | 1 K   


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…

Val-Acc=0.03793716656787196


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Val-Acc=0.6105512744516894


In [None]:
# Full record of the best trial selected by the study above (value, params and other fields).
print(best_trial)

In [None]:
# Save the network weights together with the hyperparameters used to build it.
# NOTE(review): `model` is the module-level M5PLModule from the manual training
# run; `objective` keeps its per-trial models local, so this does NOT save the
# best Optuna trial's weights. Confirm that is intended.
torch.save(
    {
        "state_dict": model.model.state_dict(),
        "hparams": model.hparams,
    },
    "adv_test_4.pt",
)