# Logging training progress using TensorBoard

quantnn now has some limited functionality to log training progress using [TensorBoard](https://www.tensorflow.org/tensorboard/). This notebook shows how this functionality can be used, based on the convolutional rain-rate retrieval from [this notebook](https://github.com/simonpf/quantnn/blob/main/notebooks/convolutional_rain_rate_retrieval.ipynb).

> **Note**: This is still very new functionality, so specific details of the API may change in the future.


In [1]:
%load_ext autoreload
%autoreload 2
import numpy as np
import matplotlib.pyplot as plt
from quantnn.plotting import set_style
set_style()






## Training setup

To set up the example, we load the training and validation data and define a very simple convolutional network.

In [2]:
from quantnn.examples.gprof_conv import download_data
download_data()

In [3]:
data = np.load("data/gprof_conv.npz")
x_train = data["x_train"]
y_train = data["y_train"]
x_val = data["x_val"]
y_val = data["y_val"]

In [4]:
import torch
from torch import nn
from quantnn.qrnn import QRNN

quantiles = np.linspace(0.01, 0.99, 99)

def make_nn_model():
    return nn.Sequential(
        nn.Conv2d(13, 128, 1),
        nn.BatchNorm2d(128),
        nn.ReLU(),
        nn.Conv2d(128, 128, 1),
        nn.BatchNorm2d(128),
        nn.ReLU(),
        nn.Conv2d(128, 128, 1),
        nn.BatchNorm2d(128),
        nn.ReLU(),
        nn.Conv2d(128, quantiles.size, 1)
    )

model = make_nn_model()
qrnn = QRNN(quantiles=quantiles, model=model)

## Activate tensorboard logging

TensorBoard logging is activated by passing a ``quantnn.models.pytorch.logging.TensorBoardLogger`` object as the ``logger`` keyword argument to the QRNN's ``train`` method.

The directory to which the logs are written is controlled by the ``log_directory`` argument of the ``TensorBoardLogger`` constructor. If it is set to ``None``, the logs of each experiment are written to a separate sub-folder of the ``runs`` directory.
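For orientation, PyTorch's `SummaryWriter` names its default log directory `runs/<timestamp>_<hostname>`. The sketch below imitates that naming scheme using only the standard library. Whether `TensorBoardLogger` relies on this default when `log_directory` is `None` is an assumption, and `default_log_directory` is a hypothetical helper, not part of quantnn. Whichever directory is used, the logs can be viewed by running `tensorboard --logdir runs` (or the chosen `log_directory`) in a shell.

```python
import os
import socket
from datetime import datetime


def default_log_directory(base="runs"):
    """
    Hypothetical helper (not part of quantnn): build a per-experiment
    sub-folder name in the style of PyTorch's SummaryWriter default,
    i.e. <base>/<timestamp>_<hostname>.
    """
    current_time = datetime.now().strftime("%b%d_%H-%M-%S")
    return os.path.join(base, current_time + "_" + socket.gethostname())


# Each call made at a different time yields a separate sub-folder of base.
print(default_log_directory())
```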

In [5]:
from quantnn.models.pytorch import BatchedDataset
from quantnn.models.pytorch.logging import TensorBoardLogger
from torch.optim import Adam

training_data = BatchedDataset((x_train, y_train), 4)
validation_data = BatchedDataset((x_val, y_val), 4)
n_epochs = 1

# Set to a directory path to control where the logs are written;
# None writes them to a default sub-folder.
log_directory = None
logger = TensorBoardLogger(n_epochs,
                           log_directory=log_directory)
qrnn.train(training_data=training_data,
           validation_data=validation_data,
           n_epochs=n_epochs,
           mask=-1,
           device="gpu",
           logger=logger,
           optimizer=Adam(qrnn.model.parameters(), lr=0.01));


## Tracking validation metrics

Additional metrics can now also be tracked on the validation set. So far, I have added the mean squared error and the bias of the posterior mean, and the continuous ranked probability score (CRPS). In addition, there are two plot metrics: a calibration plot of the predicted quantiles and a scatter plot of the posterior mean.
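To make the three scalar metrics concrete, the sketch below computes them with NumPy for predictions of shape `(n_samples, n_quantiles)`. It assumes equally spaced quantile levels, approximates the posterior mean by averaging the predicted quantiles, and uses the representation of the CRPS as twice the integral of the quantile (pinball) loss over the quantile levels. quantnn's own implementations (for example, how they treat the tails beyond the outermost quantiles or masked values) may differ.

```python
import numpy as np


def posterior_mean(y_pred):
    # With equally spaced quantile levels, averaging the predicted quantiles
    # approximates the integral of the quantile function, i.e. E[y].
    return y_pred.mean(axis=-1)


def crps(y_pred, y_true, quantiles):
    # CRPS = 2 * integral over tau of the pinball loss, approximated by
    # averaging the pinball loss over the quantile levels.
    y_true = y_true[:, np.newaxis]
    indicator = (y_true < y_pred).astype(float)
    pinball = (indicator - quantiles) * (y_pred - y_true)
    return 2.0 * pinball.mean(axis=-1).mean()


def validation_metrics(y_pred, y_true, quantiles):
    y_mean = posterior_mean(y_pred)
    return {
        "MeanSquaredError": np.mean((y_mean - y_true) ** 2),
        "Bias": np.mean(y_mean - y_true),
        "CRPS": crps(y_pred, y_true, quantiles),
    }


quantiles = np.linspace(0.01, 0.99, 99)
# A degenerate forecast: every quantile of both samples equals 2.
y_pred = np.full((2, quantiles.size), 2.0)
y_true = np.array([1.0, 3.0])
print(validation_metrics(y_pred, y_true, quantiles))
```

For a forecast that puts all quantiles at the same value, the CRPS reduces to the absolute error, which makes a convenient sanity check.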

After training, the `qrnn` object also has a `training_history` attribute, which is an `xarray.Dataset` containing all tracked training statistics. This attribute is saved automatically together with the QRNN, which should make it easier to compare the performance of different QRNNs.

> **Note**: The calibration plots look terrible. The reason is the large number of zeros in the dataset, which make the quantiles ill-defined for many predictions.
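The effect described in the note can be reproduced with synthetic data. The sketch below computes, for each quantile level, the fraction of observations that fall below the predicted quantile, which is assumed here to be what the calibration plot shows (quantnn's `CalibrationPlot` may define it slightly differently). For continuous data and exact quantiles, this fraction matches the quantile level. For rain-like data with 60 % exact zeros, every quantile up to the 0.6 level is zero, so the observed fraction is either 0 or about 0.6, depending on whether `<` or `<=` is used, but never the nominal level.

```python
import numpy as np


def calibration(y_pred, y_true, strict=True):
    """
    Fraction of observations below each predicted quantile. For a
    well-calibrated model this is close to the quantile level tau.
    """
    y_true = y_true[:, np.newaxis]
    below = (y_true < y_pred) if strict else (y_true <= y_pred)
    return below.mean(axis=0)


rng = np.random.default_rng(0)
quantiles = np.linspace(0.01, 0.99, 99)
n = 20_000

# Continuous case: the exact quantiles of U(0, 1) are the levels themselves.
y_uniform = rng.uniform(size=n)
pred_uniform = np.broadcast_to(quantiles, (n, quantiles.size))

# Rain-like case: 60 % zeros, otherwise exponentially distributed, so the
# exact quantile is 0 for tau <= 0.6 and -log(1 - (tau - 0.6) / 0.4) above.
is_dry = rng.uniform(size=n) < 0.6
y_rain = np.where(is_dry, 0.0, rng.exponential(size=n))
exact_rain = np.where(quantiles <= 0.6, 0.0,
                      -np.log1p(-(quantiles - 0.6) / 0.4))
pred_rain = np.broadcast_to(exact_rain, (n, quantiles.size))

print(calibration(pred_uniform, y_uniform)[29])           # level 0.3
print(calibration(pred_rain, y_rain, strict=True)[29])    # level 0.3
print(calibration(pred_rain, y_rain, strict=False)[29])   # level 0.3
```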

In [6]:
from quantnn.metrics import ScatterPlot

# Metrics to be tracked can be defined either via their class name
metrics = ["MeanSquaredError", "Bias", "CRPS", "CalibrationPlot"]
# or by directly providing a metric object (useful when there are configuration parameters to set).
scatter_plot = ScatterPlot(bins=np.logspace(-2, 2, 21), log_scale=True)
metrics.append(scatter_plot)

logger = TensorBoardLogger(n_epochs)
qrnn.train(training_data=training_data,
           validation_data=validation_data,
           n_epochs=n_epochs,
           mask=-1,
           device="gpu",
           logger=logger,
           metrics=metrics,
           optimizer=Adam(qrnn.model.parameters(), lr=0.01));



In [7]:
qrnn.training_history

## Keeping track of hyperparameters

To keep track of hyperparameters, the logger now has an additional method, ``set_attributes``, which takes a ``dict`` of numerical values and strings. These are stored in TensorBoard as well as in the attributes of the QRNN's training history.
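Because ``set_attributes`` expects a ``dict`` of numerical values and strings, it can help to validate a hyperparameter dictionary before starting a run. The helper below, `check_attributes`, is hypothetical and not part of quantnn. It rejects keys that are not strings and values that are neither numbers nor strings, and it is used here to build the attribute dicts for a learning-rate sweep like the one in the next cell.

```python
from numbers import Number


def check_attributes(attributes):
    """
    Hypothetical helper (not part of quantnn): check that a hyperparameter
    dict only maps string keys to numbers or strings, as expected by
    TensorBoardLogger.set_attributes.
    """
    for key, value in attributes.items():
        if not isinstance(key, str):
            raise TypeError(f"Attribute name {key!r} is not a string.")
        if not isinstance(value, (Number, str)):
            raise TypeError(
                f"Attribute {key!r} has unsupported type {type(value).__name__}."
            )
    return dict(attributes)


# One validated attribute dict per run of the learning-rate sweep.
sweep = [check_attributes({"optimizer": "Adam", "learning_rate": lr})
         for lr in [1e-1, 1e-2, 1e-3]]
print(sweep)
```

Inside a training loop, `logger.set_attributes(check_attributes({...}))` would then fail before training starts rather than after.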

In [8]:
# Metrics to be tracked can be defined either via their class name
metrics = ["MeanSquaredError", "Bias", "CRPS", "CalibrationPlot"]
# or by directly providing a metric object (useful when there are configuration parameters to set).
scatter_plot = ScatterPlot(bins=np.logspace(-2, 2, 21), log_scale=True)
metrics.append(scatter_plot)


# Use the same number of epochs for the logger and for training.
n_epochs = 5
for lr in [1e-1, 1e-2, 1e-3]:
    # Use a new model for each training run.
    qrnn.model = make_nn_model()

    # Log hyperparameters.
    logger = TensorBoardLogger(n_epochs)
    logger.set_attributes({"optimizer": "Adam", "learning_rate": lr})

    optimizer = Adam(qrnn.model.parameters(), lr=lr)

    qrnn.train(training_data=training_data,
               validation_data=validation_data,
               n_epochs=n_epochs,
               mask=-1,
               device="gpu",
               logger=logger,
               metrics=metrics,
               optimizer=optimizer)



In [9]:
qrnn.training_history

## Tracking progress on a specific input

By default, the TensorBoard logger logs only the training and validation errors. Custom data can be logged using a callback hook, set through the ``epoch_begin_callback`` attribute of the ``TensorBoardLogger`` class, which is called at the beginning of each epoch.


The ``epoch_begin_callback`` is expected to have the following signature:


````python
def epoch_begin_callback(writer, model, epoch_index):
    ...
````
where the arguments are:
- ``writer``: The [SummaryWriter](https://pytorch.org/docs/stable/tensorboard.html) that is used to log data for the current training session.
- ``model``: The PyTorch model that is being trained, in its current state.
- ``epoch_index``: The (zero-based) index of the current epoch.
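The mechanism itself is simple: the logger stores the callback and, before each epoch starts, calls it with its writer, the model and the epoch index. The pure-Python sketch below illustrates this with hypothetical stand-ins (`CallbackLogger`, `RecordingWriter`, and a dict playing the role of the model); it is not quantnn's implementation. The argument order (writer first) follows the `make_prediction` example in this notebook and is worth double-checking against the installed quantnn version.

```python
class CallbackLogger:
    """Hypothetical, minimal logger with an epoch-begin hook (not quantnn's)."""

    def __init__(self, n_epochs, writer, epoch_begin_callback=None):
        self.n_epochs = n_epochs
        self.writer = writer
        self.epoch_begin_callback = epoch_begin_callback

    def epoch_begin(self, model, epoch_index):
        # Forward the writer, the model in its current state and the epoch index.
        if self.epoch_begin_callback is not None:
            self.epoch_begin_callback(self.writer, model, epoch_index)


class RecordingWriter:
    """Fake writer that records add_scalar calls instead of writing files."""

    def __init__(self):
        self.records = []

    def add_scalar(self, tag, value, step):
        self.records.append((tag, value, step))


def log_weight(writer, model, epoch_index):
    # Custom logging: record the current value of a model parameter.
    writer.add_scalar("weight", model["weight"], epoch_index)


writer = RecordingWriter()
logger = CallbackLogger(3, writer, epoch_begin_callback=log_weight)
model = {"weight": 1.0}
for epoch_index in range(logger.n_epochs):
    logger.epoch_begin(model, epoch_index)  # hook runs before the epoch's updates
    model["weight"] *= 0.5                  # stand-in for one epoch of training
print(writer.records)
# [('weight', 1.0, 0), ('weight', 0.5, 1), ('weight', 0.25, 2)]
```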
        
The example below illustrates how this functionality can be used to track the prediction on a specific input at the beginning of each epoch:

In [10]:
x = x_val[:1]
y = y_val[:1]

def make_prediction(writer, model, epoch_index):
    """
    Predicts the mean precipitation rate on the first sample
    from the validation set.

    Args:
        writer: The SummaryWriter object that is used to log
            to TensorBoard.
        model: The model attribute of the qrnn object being
            trained.
        epoch_index: The index (zero-based) of the current
            epoch.
    """
    # Make prediction.
    y_mean = qrnn.posterior_mean(x)
    # Store output using the add_image function of SummaryWriter.
    writer.add_image("predicted_rain_rate", y_mean, epoch_index)

    # The reference image does not change, so store it only once
    # using the add_image function of SummaryWriter.
    if epoch_index == 0:
        writer.add_image("reference_rain_rate", y, 0)
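`SummaryWriter.add_image` expects, by default, an image of shape `(C, H, W)` with float values in `[0, 1]` (or `uint8` values in `[0, 255]`). Raw rain rates, together with the `-1` fill value passed as `mask`, fall outside that range. Below is a minimal sketch of a hypothetical helper, `rain_rate_to_image`, that maps rain rates logarithmically onto `[0, 1]`. The range of 0.01 to 100 mirrors the scatter-plot bins used above and is an assumption, as is the unit (mm/h).

```python
import numpy as np


def rain_rate_to_image(rain_rate, min_rate=1e-2, max_rate=1e2):
    """
    Hypothetical helper: map a rain-rate field of shape (1, H, W) to
    float32 values in [0, 1] on a logarithmic scale, suitable for
    SummaryWriter.add_image with the default "CHW" data format.

    Values at or below min_rate, including zeros and negative fill
    values such as -1, map to 0; values at or above max_rate map to 1.
    """
    rain_rate = np.asarray(rain_rate, dtype=np.float32)
    clipped = np.clip(rain_rate, min_rate, max_rate)
    log_min, log_max = np.log10(min_rate), np.log10(max_rate)
    scaled = (np.log10(clipped) - log_min) / (log_max - log_min)
    # Final clip guards against rounding just outside [0, 1].
    return np.clip(scaled, 0.0, 1.0).astype(np.float32)


field = np.array([[[-1.0, 0.0], [1.0, 1000.0]]])
print(rain_rate_to_image(field))
```

Assuming `qrnn.posterior_mean(x)` returns a NumPy array of shape `(1, H, W)` here, the callback could then log `rain_rate_to_image(y_mean)` instead of `y_mean`.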

In [11]:
from quantnn.models.pytorch import BatchedDataset
from quantnn.models.pytorch.logging import TensorBoardLogger


training_data = BatchedDataset((x_train, y_train), 4)
validation_data = BatchedDataset((x_val, y_val), 4)
n_epochs = 5
logger = TensorBoardLogger(n_epochs,
                           log_directory=None,
                           epoch_begin_callback=make_prediction)
qrnn.train(training_data=training_data,
           validation_data=validation_data,
           n_epochs=n_epochs,
           mask=-1,
           device="gpu",
           logger=logger);
