### How to train a model with PyTorch Lightning

* Define your Model Class (in `classification/models`)
    * as usual, a subclass of `nn.Module`: define the architecture in the initializer (`__init__`) and implement the forward pass in `forward(self, x)`
    * the output of the forward pass must have shape `[BATCH_SIZE, NUM_CLASSES]`
* Also define a subclass of `PLModule`, i.e., a PyTorch Lightning Module
    * the PyTorch Lightning Module combines your model with the DataLoader and Solver specifics
    * PyTorch Lightning will then run the training loop, log to TensorBoard, save checkpoints, etc.
    * to create your PyTorch Lightning Module, you can simply inherit from `GeneralPLModule` and set `self.model` to your model
    * if you want to customize further, you can override methods of `GeneralPLModule`:
       * e.g. `prepare_data(self)`, which sets `self.dataset["train"]` and `self.dataset["val"]`

In [None]:
%load_ext autoreload
%autoreload 2

import pytorch_lightning as pl
import torch
import numpy as np
from M5 import M5, M5PLModule

In [None]:
# Baseline hyper-parameters for a single, non-searched training run.
# NOTE(review): the Optuna search further down also samples "p_drop", which is
# not set here — presumably M5PLModule falls back to a default; confirm.
hparams = dict(
    batch_size=64,
    learning_rate=3e-5,
    weight_decay=0,
    lr_decay=1,
)

model = M5PLModule(hparams)

In [None]:
# Train the baseline model for 10 epochs; use a single GPU when CUDA is
# available, otherwise pass gpus=None.
use_gpu = torch.cuda.is_available()
trainer = pl.Trainer(max_epochs=10, gpus=1 if use_gpu else None)

trainer.fit(model)

### HyperParameter Search

In [None]:
import optuna
from optuna.integration import PyTorchLightningPruningCallback
from HyperParamSearch import MetricsCallback, save_model

### Objective

In [None]:
def objective(trial):
    """Train one M5 model with hyper-parameters sampled by ``trial``.

    Intended as the Optuna objective passed to ``study.optimize``. It samples
    a hyper-parameter set, trains ``M5PLModule`` for 3 epochs, saves the
    trained model with ``save_model`` and returns the latest validation
    accuracy.

    Args:
        trial: the Optuna trial object supplied by the study; used to sample
            hyper-parameters and to report intermediate values for pruning.

    Returns:
        float: the ``"val_acc"`` entry of the last metrics record collected by
        ``MetricsCallback``. Higher is better, so the study must maximize it.
    """
    # Collects the metrics seen during training; the last entry is returned
    # below (presumably one dict per validation run — see MetricsCallback).
    metrics_callback = MetricsCallback()

    # Reports "val_acc" to Optuna and stops unpromising trials (pruning).
    # Registered via `callbacks`: Trainer's `early_stop_callback` argument was
    # removed in PyTorch Lightning 1.0, and current Optuna versions implement
    # this as a regular Lightning callback.
    pruning_callback = PyTorchLightningPruningCallback(trial, monitor="val_acc")

    trainer = pl.Trainer(
        logger=False,  # deactivate PL logging
        max_epochs=3,
        gpus=1 if torch.cuda.is_available() else None,
        callbacks=[metrics_callback, pruning_callback],
    )

    # Sample the hyper-parameters for this trial; the learning rate is drawn
    # on a log scale because its range spans five orders of magnitude.
    trial_hparams = {
        "batch_size": 64,
        "learning_rate": trial.suggest_float("learning_rate", 1e-6, 1e-1, log=True),
        "p_drop": trial.suggest_float("p_drop", 1e-6, 1e-1),
        # NOTE(review): the baseline run uses lr_decay=1. If M5PLModule treats
        # this as a multiplicative LR factor, values in [1e-6, 0.1] shrink the
        # learning rate drastically — confirm the intended range.
        "lr_decay": trial.suggest_float("lr_decay", 1e-6, 1e-1),
        "weight_decay": trial.suggest_float("weight_decay", 1e-6, 1e-1),
    }

    model = M5PLModule(trial_hparams)
    trainer.fit(model)

    # Persist this trial's model under a name derived from the trial number.
    save_model(model, f"{trial.number}.p", "checkpoints")

    print("metrics:", metrics_callback.metrics)
    # Validation accuracy of the latest model: the quantity the study maximizes.
    # float() ensures Optuna receives a plain number even if a tensor was logged.
    return float(metrics_callback.metrics[-1]["val_acc"])

### Search

In [None]:
# Median pruning stops a trial whose intermediate "val_acc" is worse than the
# median of earlier trials at the same step.
# NOTE(review): per Optuna's docs MedianPruner only starts pruning after
# n_startup_trials (default 5) trials have finished, so with n_trials=3 no
# trial can actually be pruned — confirm whether that is intended.
pruner = optuna.pruners.MedianPruner()
study = optuna.create_study(
    direction="maximize",  # objective() returns a validation accuracy
    pruner=pruner,
)
# Stop after 3 trials or 21600 s (6 hours), whichever comes first.
study.optimize(objective, n_trials=3, timeout=21600)

print(f"Number of finished trials: {len(study.trials)}")

print("Best trial:")
best_trial = study.best_trial

print(f"  Value: {best_trial.value}")

print("  Params: ")
for param_name, param_value in best_trial.params.items():
    print(f"    {param_name}: {param_value}")