The three extensions below are optional. For more information, see:
- `watermark`: https://github.com/rasbt/watermark
- `pycodestyle_magic`: https://github.com/mattijn/pycodestyle_magic
- `nb_black`: https://github.com/dnanhkhoa/nb_black

In [None]:
%load_ext watermark
%watermark -p torch,pytorch_lightning,torchmetrics,matplotlib

In [None]:
%load_ext pycodestyle_magic
%flake8_on --ignore W291,W293,E703

In [None]:
%load_ext nb_black

<a href="https://pytorch.org"><img src="https://raw.githubusercontent.com/pytorch/pytorch/master/docs/source/_static/img/pytorch-logo-dark.svg" width="90"/></a> &nbsp; &nbsp;&nbsp;&nbsp;<a href="https://www.pytorchlightning.ai"><img src="https://raw.githubusercontent.com/PyTorchLightning/pytorch-lightning/master/docs/source/_static/images/logo.svg" width="150"/></a>

# TITLE

- DESCRIPTION


### References

- ???

## General settings and hyperparameters

- Here, we specify some hyperparameter values and general settings.

In [None]:
BATCH_SIZE = 256
NUM_EPOCHS = 10
LEARNING_RATE = 0.005
NUM_WORKERS = 4

- Note that using multiple workers can sometimes cause a "too many open files" error in PyTorch for small datasets. If we run into problems with the data loader later, we can try setting `NUM_WORKERS = 0` and rerun the notebook.

## Implementing a Neural Network using PyTorch Lightning's `LightningModule`

- In this section, we set up the main model architecture using the `LightningModule` from PyTorch Lightning.
- In essence, `LightningModule` is a wrapper around a PyTorch module.
- We start with defining our neural network model in pure PyTorch, and then we use it in the `LightningModule` to get all the extra benefits that PyTorch Lightning provides.

In [None]:
# UNIQUE MODEL CODE

In [None]:
%load ../code_lightningmodule/lightningmodule_classifier_basic.py

## Setting up the dataset

- In this section, we are going to set up our dataset.

### Inspecting the dataset

In [None]:
%load ../code_dataset/dataset_???_check.py

### Performance baseline

- Especially for imbalanced datasets, it's pretty helpful to compute a performance baseline.
- In classification contexts, a useful baseline is to compute the accuracy for a scenario where the model always predicts the majority class -- we want our model to be better than that!

In [None]:
%load ../code_dataset/performance_baseline.py
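- As a minimal illustration of the idea (not the contents of `performance_baseline.py`), the majority-class baseline can be computed in plain Python. The helper `majority_class_baseline` and the toy labels are hypothetical:

```python
from collections import Counter


def majority_class_baseline(labels):
    """Return (majority_label, accuracy) for always predicting the most frequent label."""
    counts = Counter(labels)
    majority_label, majority_count = counts.most_common(1)[0]
    return majority_label, majority_count / len(labels)


# Hypothetical toy labels (not from any real dataset): 6 of 10 belong to class 0
toy_labels = [0, 0, 1, 0, 2, 0, 1, 0, 0, 2]
label, baseline_acc = majority_class_baseline(toy_labels)
print(f"Majority class: {label}, baseline accuracy: {baseline_acc:.2f}")
```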

### A quick visual check

In [None]:
%load plot_visual-check_basic.py

### Setting up a `DataModule`

- There are three main ways we can prepare the dataset for Lightning. We can
  1. make the dataset part of the model;
  2. set up the data loaders as usual and feed them to the fit method of a Lightning Trainer -- the Trainer is introduced in the following subsection;
  3. create a LightningDataModule.
- Here, we will use approach 3, which is the most organized approach. The `LightningDataModule` consists of several self-explanatory methods, as we can see below:

In [None]:
%load ../code_lightningmodule/datamodule_???_basic.py

- Note that the `prepare_data` method is usually used for steps that only need to be executed once, for example, downloading the dataset. The `setup` method defines how the dataset is loaded; if we run our code in a distributed setting, `setup` is called in every process (i.e., on each GPU).
- Next, let's initialize the `DataModule`. We set a random seed for reproducibility, so that the dataset is shuffled the same way whenever we re-execute this code:

In [None]:
torch.manual_seed(1) 
data_module = DataModule(data_path='./data')
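- To illustrate why fixing the seed makes the data pipeline reproducible, here is a small sketch that uses Python's built-in `random` module as a stand-in for `torch.manual_seed`. The `split_indices` helper is hypothetical and not part of the `DataModule` code loaded above:

```python
import random


def split_indices(num_examples, valid_fraction, seed):
    """Shuffle example indices with a fixed seed and split them into train/valid."""
    rng = random.Random(seed)  # local RNG, so the global random state is untouched
    indices = list(range(num_examples))
    rng.shuffle(indices)
    num_valid = int(num_examples * valid_fraction)
    return indices[num_valid:], indices[:num_valid]


# Re-running with the same seed reproduces exactly the same split
train_a, valid_a = split_indices(100, 0.1, seed=1)
train_b, valid_b = split_indices(100, 0.1, seed=1)
print(train_a == train_b and valid_a == valid_b)  # True
```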

## Training the model using the PyTorch Lightning Trainer class

- Next, we initialize our model.
- Also, we define a callback to obtain the model with the best validation set performance after training.
- PyTorch Lightning offers [many advanced logging services](https://pytorch-lightning.readthedocs.io/en/latest/extensions/logging.html) like Weights & Biases. However, here, we will keep things simple and use the `CSVLogger`:

In [None]:
pytorch_model = PyTorchModel(
    ???
)

In [None]:
%load ../code_lightningmodule/logger_csv_acc_basic.py

- Now it's time to train our model:

In [None]:
%load ../code_lightningmodule/trainer_nb_basic.py

## Evaluating the model

- After training, let's plot our training and validation accuracy using pandas, which, in turn, uses matplotlib for plotting (PS: you may want to check out [more advanced loggers](https://pytorch-lightning.readthedocs.io/en/latest/extensions/logging.html) later on, which take care of this for us):

In [None]:
%load ../code_lightningmodule/logger_csv_plot_basic.py
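- Conceptually, plotting per-epoch curves requires averaging the step-level entries of the logger's `metrics.csv` file per epoch, skipping the empty cells of metrics that were not logged in a given row. Below is a minimal stdlib sketch of that aggregation; the column names `train_acc` and `valid_acc` and all values are assumptions made up for illustration:

```python
import csv
import io
from collections import defaultdict

# Hypothetical excerpt of a metrics.csv file; the values are made up and the
# column names depend on the names passed to self.log() in the LightningModule
METRICS_CSV = """epoch,step,train_acc,valid_acc
0,49,0.75,
0,99,0.875,
0,99,,0.875
1,149,0.875,
1,199,0.9375,
1,199,,0.9375
"""


def per_epoch_means(csv_text, columns):
    """Average each metric per epoch, skipping empty cells (NaN in pandas)."""
    sums = defaultdict(float)
    counts = defaultdict(int)
    for row in csv.DictReader(io.StringIO(csv_text)):
        epoch = int(row["epoch"])
        for col in columns:
            value = row.get(col) or ""  # missing or empty cell -> skip
            if value != "":
                sums[(epoch, col)] += float(value)
                counts[(epoch, col)] += 1
    means = defaultdict(dict)
    for (epoch, col), n in counts.items():
        means[epoch][col] = sums[(epoch, col)] / n
    return dict(means)


results = per_epoch_means(METRICS_CSV, ["train_acc", "valid_acc"])
for epoch in sorted(results):
    print(epoch, results[epoch])
```

- With pandas, `groupby("epoch").mean()` on the loaded DataFrame performs the same kind of aggregation, since `mean` skips NaN values by default.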

- The `trainer` automatically saves the model with the best validation accuracy for us, which we can load from the checkpoint via the `ckpt_path='best'` argument. Below, we use the `trainer` instance to evaluate the best model on the test set:

In [None]:
trainer.test(model=lightning_model, datamodule=data_module, ckpt_path='best')
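- `ckpt_path='best'` refers to the checkpoint whose monitored metric was best during training. The sketch below mimics that bookkeeping for a checkpoint callback that maximizes validation accuracy (`mode="max"`) or minimizes a loss (`mode="min"`). The `best_epoch` helper and the accuracy values are hypothetical, and Lightning's `ModelCheckpoint` internals may differ in details such as tie handling:

```python
def best_epoch(valid_scores, mode="max"):
    """Return (epoch, score) of the best validation score seen during training."""
    best_idx, best_score = None, None
    for epoch, score in enumerate(valid_scores):
        is_better = (
            best_score is None
            or (mode == "max" and score > best_score)
            or (mode == "min" and score < best_score)
        )
        if is_better:  # strict comparison: the earliest epoch wins on ties
            best_idx, best_score = epoch, score
    return best_idx, best_score


# Hypothetical validation accuracies for 5 epochs; epoch 2 is the best one
print(best_epoch([0.81, 0.86, 0.90, 0.89, 0.88]))  # (2, 0.9)
```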

## Predicting labels of new data

- We can use the `trainer.predict` method either on a new `DataLoader` (`trainer.predict(dataloaders=...)`) or `DataModule` (`trainer.predict(datamodule=...)`) to apply the model to new data.
- Alternatively, we can also manually load the best model from a checkpoint as shown below:

In [None]:
path = trainer.checkpoint_callback.best_model_path
print(path)

In [None]:
lightning_model = LightningModel.load_from_checkpoint(path, model=pytorch_model)
lightning_model.eval();

- For simplicity, we reused our existing `pytorch_model` above. However, we could also reinitialize the `pytorch_model`, and the `.load_from_checkpoint` method would load the corresponding model weights for us from the checkpoint file.
- Below is an example of applying the model manually. Here, we pretend that `test_dataloader` is a new data loader.

In [None]:
%load ../code_lightningmodule/datamodule_testloader.py
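- Applying a classifier manually boils down to computing the logits batch by batch, taking the index of the largest logit as the predicted label, and accumulating the number of correct predictions. A framework-free sketch with hypothetical logits and labels (plain lists stand in for tensors):

```python
def batched_accuracy(batches):
    """Accumulate correct predictions over (logits, labels) batches, then divide once."""
    correct, total = 0, 0
    for logits, labels in batches:
        # Predicted label = index of the largest logit in each row
        preds = [max(range(len(row)), key=row.__getitem__) for row in logits]
        correct += sum(p == y for p, y in zip(preds, labels))
        total += len(labels)
    return correct / total


# Two hypothetical batches of unequal size (3 classes each)
toy_batches = [
    ([[2.0, 0.1, -1.0], [0.3, 1.5, 0.2], [0.0, 0.2, 0.9]], [0, 1, 1]),
    ([[1.2, 0.4, 0.0]], [0]),
]
print(batched_accuracy(toy_batches))  # 0.75
```

- Accumulating counts and dividing once yields the accuracy over all examples (3/4 here); simply averaging the per-batch accuracies would give (2/3 + 1)/2 ≈ 0.83 instead, because the batches have different sizes.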

- As an internal check: if the model was loaded correctly, the test accuracy below should be identical to the test accuracy we obtained in the previous section.

In [None]:
test_acc = acc.compute()
print(f'Test accuracy: {test_acc:.4f} ({test_acc*100:.2f}%)')

## Inspecting failure cases

- In practice, it is often informative to look at failure cases, such as wrong predictions for particular training instances, as they can give us insights into the model's behavior and the dataset.
- Inspecting failure cases can sometimes reveal interesting patterns and even highlight dataset and labeling issues.

In [None]:
# In the case of ???, the class label mapping
# ???
class_dict = {???}

In [None]:
%load ../code_lightningmodule/plot_failurecases_basic.py
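- At its core, finding failure cases means collecting the examples whose predicted label differs from the true label. A minimal sketch, assuming a `class_dict` that maps integer labels to class names; the mapping, predictions, and the `failure_cases` helper below are hypothetical:

```python
def failure_cases(predictions, labels, class_dict, max_cases=5):
    """Collect (index, true class name, predicted class name) for wrong predictions."""
    cases = []
    for idx, (pred, true) in enumerate(zip(predictions, labels)):
        if pred != true:
            cases.append((idx, class_dict[true], class_dict[pred]))
            if len(cases) == max_cases:
                break
    return cases


# Hypothetical label mapping and predictions, for illustration only
toy_class_dict = {0: "cat", 1: "dog", 2: "bird"}
toy_preds = [0, 1, 1, 2, 0, 2]
toy_labels = [0, 2, 1, 2, 1, 2]
for idx, true_name, pred_name in failure_cases(toy_preds, toy_labels, toy_class_dict):
    print(f"Example {idx}: true={true_name}, predicted={pred_name}")
```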

- In addition to inspecting failure cases visually, it is also informative to look at which classes the model confuses the most via a confusion matrix:

In [None]:
%load ../code_lightningmodule/plot_confusion-matrix_basic.py
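- For intuition, a confusion matrix simply counts (true label, predicted label) pairs. A plain-Python sketch with hypothetical labels, using the common convention (also used by scikit-learn) of rows for true classes and columns for predicted classes:

```python
def confusion_matrix(labels, predictions, num_classes):
    """Count (true, predicted) pairs; rows are true classes, columns are predictions."""
    matrix = [[0] * num_classes for _ in range(num_classes)]
    for true, pred in zip(labels, predictions):
        matrix[true][pred] += 1
    return matrix


# Hypothetical true labels and predictions for a 3-class problem
toy_labels = [0, 2, 1, 2, 1, 2]
toy_preds = [0, 1, 1, 2, 0, 2]
for row in confusion_matrix(toy_labels, toy_preds, num_classes=3):
    print(row)
```

- The diagonal holds the correct predictions, and large off-diagonal entries point to class pairs the model confuses most often; here, for example, one example of class 1 was predicted as class 0.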