# Fully-connected neural network retrieval

This notebook implements a precipitation retrieval for the IPWG SPR dataset based on a fully-connected neural network (a multi-layer perceptron, MLP) implemented using PyTorch and Lightning.

> **NOTE**: This notebook can be run on Google Colab. To install the necessary dependencies, uncomment the following cell and execute it.

In [1]:
#!pip install ipwgml[complete]@git+https://github.com/simonpf/ipwgml

## The training data

The ``ipwgml`` package provides dataset classes that take care of downloading, preprocessing, and loading the SPR training data. The ``ipwgml.pytorch.datasets.SPRTabular`` class implements the PyTorch dataset interface for the SPR data in tabular format, i.e., for pixel-based retrievals.

The retrieval input data to load and the quality criteria for the reference surface precipitation data can be configured using the ``InputConfig`` classes from the ``ipwgml.input`` module and the ``TargetConfig`` class from the ``ipwgml.target`` module, respectively.

For this example notebook, we use only GMI observations as input. We configure the ``GMI`` input to not include the earth-incidence angles. Moreover, we enable min-max normalization of the input and replace NaN values with -1.5.

> **NOTE**: Both input and reference data in the SPR dataset may contain NaNs because observations or precipitation estimates may be missing or of insufficient quality. It is up to the user to handle those.
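The normalization and NaN handling configured here can be sketched in NumPy as follows. This is a minimal illustration, not the ``ipwgml`` implementation: it assumes that min-max normalization maps each channel to the range [-1, 1] using per-channel minimum and maximum values, so that the fill value -1.5 lies outside the normalized range and missing values stay distinguishable. The function name ``minmax_normalize`` and the example values are hypothetical.

```python
import numpy as np

def minmax_normalize(x, x_min, x_max, nan=-1.5):
    """
    Hypothetical sketch: scale each channel (last axis) of 'x' to [-1, 1]
    and replace missing values with the 'nan' fill value.
    """
    x = np.asarray(x, dtype=float)
    x_norm = 2.0 * (x - x_min) / (x_max - x_min) - 1.0
    # NaNs propagate through the arithmetic above, so replace them at the end.
    return np.where(np.isfinite(x_norm), x_norm, nan)

# Two pixels with two channels (e.g. brightness temperatures in K); one value is missing.
tbs = np.array([[150.0, 250.0], [np.nan, 300.0]])
x_min = np.array([100.0, 200.0])
x_max = np.array([300.0, 300.0])
print(minmax_normalize(tbs, x_min, x_max))
```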

We choose the ``on_swath`` geometry for the retrieval, which is a natural choice for pixel-based retrievals. We also set up batching in the dataset, which is more efficient for tabular data than leaving the batching to the PyTorch data loader.
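The benefit of batching inside the dataset can be illustrated with a small sketch: if each dataset index already returns a full batch, the data loader (used with ``batch_size=None``) fetches one item per batch instead of collating ``batch_size`` individual samples. The class ``BatchedTable`` is hypothetical and only mimics the idea; it is not how ``SPRTabular`` is implemented, and for brevity it shuffles only once at construction rather than every epoch.

```python
import numpy as np

class BatchedTable:
    """
    Hypothetical sketch of a dataset that returns whole batches of tabular
    samples instead of single samples.
    """
    def __init__(self, inputs, targets, batch_size, shuffle=True, seed=42):
        self.inputs = np.asarray(inputs)
        self.targets = np.asarray(targets)
        self.batch_size = batch_size
        self.indices = np.arange(len(self.inputs))
        if shuffle:
            np.random.default_rng(seed).shuffle(self.indices)

    def __len__(self):
        # Number of batches; the last batch may be smaller than batch_size.
        return -(-len(self.inputs) // self.batch_size)

    def __getitem__(self, ind):
        if ind < 0 or ind >= len(self):
            raise IndexError(ind)
        # A single vectorized indexing operation gathers the whole batch.
        batch = self.indices[ind * self.batch_size:(ind + 1) * self.batch_size]
        return self.inputs[batch], self.targets[batch]

data = BatchedTable(np.arange(20).reshape(10, 2), np.arange(10), batch_size=4)
print(len(data), [x.shape for x, _ in data])
```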

In [2]:
from ipwgml.input import GMI, Ancillary, Geo, GeoIR
from ipwgml.target import TargetConfig

target_config = TargetConfig(min_rqi=0.5)
inputs = [GMI(normalize="minmax", nan=-1.5, include_angles=False)]
geometry = "on_swath"
batch_size = 1024

With these settings, we can instantiate the training dataset. By setting ``stack=True``, we also tell the dataset to stack all input tensors into a single tensor instead of loading the input data as a dictionary.
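Conceptually, stacking turns a dictionary of per-input feature arrays into a single array by concatenating along the feature axis, which is the input format a plain MLP expects. The sketch below uses hypothetical input names and feature counts; the concatenation order and axis used internally by ``ipwgml`` are assumptions here.

```python
import numpy as np

# Hypothetical per-input feature arrays for a batch of 1024 pixels.
batch = {
    "obs_gmi": np.zeros((1024, 13), dtype=np.float32),
    "ancillary": np.ones((1024, 5), dtype=np.float32),
}
# Concatenate along the last (feature) axis to obtain one input array.
stacked = np.concatenate([batch[name] for name in batch], axis=-1)
print(stacked.shape)  # (1024, 18)
```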

> **Note**: The current implementation of the dataset loads all training data into memory upon instantiation. Therefore, executing the cell below may take some time.

In [None]:
from torch.utils.data import DataLoader
from ipwgml.pytorch.datasets import SPRTabular

training_data = SPRTabular(
    reference_sensor="gmi",
    geometry=geometry,
    split="training",
    retrieval_input=inputs,
    batch_size=batch_size,
    target_config=target_config,
    stack=True,
    download=True,
)
training_loader = DataLoader(training_data, shuffle=True, batch_size=None, num_workers=4)

With the above configuration, the training data loaded by the ``training_loader`` has the following dimensions.

In [None]:
inpt, target = next(iter(training_loader))
print("Input tensor shape: ", inpt.shape)
print("Target tensor shape: ", target.shape)

We also create a validation loader with the same configuration.

In [None]:
validation_data = SPRTabular(
    reference_sensor="gmi",
    geometry=geometry,
    split="validation",
    retrieval_input=inputs,
    batch_size=batch_size,
    target_config=target_config,
    stack=True,
    download=True,
    shuffle=False
)
validation_loader = DataLoader(validation_data, shuffle=False, batch_size=None)

In [None]:
from typing import Any, Callable, Dict

import torch
from torch import optim
from torch import nn
from torch.nn.functional import binary_cross_entropy_with_logits
import lightning as L

OUTPUTS = [
    "surface_precip",
    "probability_of_precip",
    "probability_of_heavy_precip"
]

N_EPOCHS = 20

class MLP(L.LightningModule):
    """
    Lightning module implementing a multi-layer perceptron (MLP) for retrieving precipitation from satellite
    observations.
    """
    def __init__(
        self,
        n_input_features: int,
        n_hidden_layers: int,
        n_neurons: int,
        activation_fn: Callable[[], nn.Module] = nn.GELU,
        normalization_layer: Callable[[int], nn.Module] = nn.LayerNorm,
        n_epochs: int = N_EPOCHS
    ):
        """
        Args:
            n_input_features: The number of features in the input
            n_hidden_layers: The number of hidden layers in the MLP
            n_neurons: The number of neurons in the hidden layers
            activation_fn: A callable to create activation function layers.
            normalization_layer: A callable to create normalization layers.
            n_epochs: The number of epochs the model will be trained for.
        """
        super().__init__()
        blocks = [
            nn.Linear(n_input_features, n_neurons),
            normalization_layer(n_neurons),
            activation_fn(),
        ]
        for _ in range(n_hidden_layers):
            blocks += [
                nn.Linear(n_neurons, n_neurons),
                normalization_layer(n_neurons),
                activation_fn(),
            ]
        self.body = nn.Sequential(*blocks)

        heads = {}
        for output in OUTPUTS:
            heads[output] =  nn.Sequential(
                nn.Linear(n_neurons, n_neurons),
                normalization_layer(n_neurons),
                activation_fn(),
                nn.Linear(n_neurons, 1)
            )
        self.heads = nn.ModuleDict(heads)
        self.n_epochs = n_epochs

    def forward(self, retrieval_input: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Forward retrieval input through network and produce dictionary with predictions.

        Args:
            retrieval_input: The retrieval input as a single torch.Tensor.

        Return:
            A dictionary containing the predictions for 'surface_precip', 'probability_of_precip',
            and 'probability_of_heavy_precip'.
        """
        y = self.body(retrieval_input)
        return {
            name: module(y) for name, module in self.heads.items()
        }
        
    def training_step(self, batch, batch_idx) -> torch.Tensor:
        """
        Calculates the training loss for the MLP. Lightning performs the backward pass on the returned loss.

        The loss is calculated as the sum of the MSE for 'surface_precip' and the binary cross-entropy loss
        for precipitation detection and heavy precipitation detection.

        Args:
            batch: A tuple containing the training data loaded from the data loader.
            batch_idx: The index of the batch in the current epoch. Not used.

        Return:
            A scalar torch.Tensor containing the total loss.
        """
        inpt, surface_precip = batch
        pred = self(inpt)

        valid = torch.isfinite(surface_precip)
        surface_precip = surface_precip[valid]
        precip_mask = (surface_precip > 1e-3).to(dtype=torch.float32)
        heavy_precip_mask = (surface_precip > 10).to(dtype=torch.float32)
        surface_precip_pred = pred["surface_precip"][valid]
        pop = pred["probability_of_precip"][valid]
        pohp = pred["probability_of_heavy_precip"][valid]
        
        # MSE loss for QPE
        loss_quant = ((surface_precip_pred[..., 0] - surface_precip) ** 2).mean()
        # BCE loss for detection targets
        loss_detect = binary_cross_entropy_with_logits(pop[..., 0], precip_mask)
        loss_detect_heavy = binary_cross_entropy_with_logits(pohp[..., 0], heavy_precip_mask)
        tot_loss =  loss_quant + loss_detect + loss_detect_heavy
        return tot_loss

    def validation_step(self, batch, batch_idx) -> None:
        """
        Calculates the loss-function values on validation data.

        Args:
            batch: A tuple containing the validation data loaded from the data loader.
            batch_idx: The index of the batch in the current epoch. Not used.
        """
        inpt, surface_precip = batch
        pred = self(inpt)

        valid = torch.isfinite(surface_precip)
        surface_precip = surface_precip[valid]
        precip_mask = (surface_precip > 1e-3).to(dtype=torch.float32)
        heavy_precip_mask = (surface_precip > 10).to(dtype=torch.float32)
        surface_precip_pred = pred["surface_precip"][valid]
        pop = pred["probability_of_precip"][valid]
        pohp = pred["probability_of_heavy_precip"][valid]
        
        # MSE loss for QPE
        loss_quant = ((surface_precip_pred[..., 0] - surface_precip) ** 2).mean()
        # BCE loss for detection targets
        loss_detect = binary_cross_entropy_with_logits(pop[..., 0], precip_mask)
        loss_detect_heavy = binary_cross_entropy_with_logits(pohp[..., 0], heavy_precip_mask)
        tot_loss =  loss_quant + loss_detect + loss_detect_heavy

        opt = self.optimizers()
        learning_rate = opt.param_groups[0]['lr']
        
        self.log_dict(
            {
                "val_loss": loss_quant + loss_detect + loss_detect_heavy,
                "val_loss_quant": loss_quant,
                "val_loss_detect": loss_detect,
                "val_loss_detect_heavy": loss_detect_heavy,
                "learning_rate": learning_rate
            },
            on_epoch=True,
            prog_bar=True
        )
    
    def configure_optimizers(self) -> Dict[str, Any]:
        """
        We use the Adam optimizer with a cosine annealing learning rate schedule.
        """
        optimizer = optim.Adam(self.parameters(), lr=5e-4)
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=self.n_epochs)
        return {
            "optimizer": optimizer,
            "lr_scheduler": scheduler
        }
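The loss used in ``training_step`` and ``validation_step`` can be reproduced outside PyTorch to check the masking and thresholds. The NumPy sketch below mirrors the code above: NaN reference values are dropped, precipitation and heavy-precipitation labels are derived with the thresholds 1e-3 and 10 (in the units of the reference data, presumably mm/h), and the total loss is the MSE plus two binary cross-entropies computed from logits. The helper names ``bce_with_logits`` and ``retrieval_loss`` are hypothetical.

```python
import numpy as np

def bce_with_logits(logits, labels):
    """Mean binary cross-entropy computed in a numerically stable way from logits."""
    # max(z, 0) - z * y + log(1 + exp(-|z|)) equals -[y log(s) + (1 - y) log(1 - s)], s = sigmoid(z).
    return np.mean(np.maximum(logits, 0) - logits * labels + np.log1p(np.exp(-np.abs(logits))))

def retrieval_loss(sp_pred, pop_logits, pohp_logits, surface_precip):
    """Hypothetical NumPy re-implementation of the MLP's total loss."""
    valid = np.isfinite(surface_precip)
    sp = surface_precip[valid]
    precip_mask = (sp > 1e-3).astype(np.float32)
    heavy_precip_mask = (sp > 10).astype(np.float32)
    loss_quant = np.mean((sp_pred[valid] - sp) ** 2)
    loss_detect = bce_with_logits(pop_logits[valid], precip_mask)
    loss_detect_heavy = bce_with_logits(pohp_logits[valid], heavy_precip_mask)
    return loss_quant + loss_detect + loss_detect_heavy

sp_true = np.array([0.0, 2.0, np.nan, 12.0])
sp_pred = np.array([0.0, 1.0, 5.0, 12.0])
zeros = np.zeros(4)  # Zero logits correspond to probabilities of 0.5.
print(retrieval_loss(sp_pred, zeros, zeros, sp_true))
```

With zero logits, each cross-entropy term equals log(2), so the example isolates the effect of the MSE term on the three valid pixels.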

In [None]:
from ipwgml.input import calculate_input_features
input_features = calculate_input_features(inputs, stack=True)
mlp = MLP(n_input_features=input_features, n_hidden_layers=4, n_neurons=256)

In [None]:
trainer = L.Trainer(max_epochs=N_EPOCHS)
trainer.fit(model=mlp, train_dataloaders=training_loader, val_dataloaders=validation_loader)
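With ``CosineAnnealingLR`` and ``T_max=self.n_epochs``, the learning rate decays from 5e-4 towards zero along half a cosine period. PyTorch documents the closed form eta_t = eta_min + (eta_max - eta_min)(1 + cos(pi t / T_max)) / 2, and Lightning steps learning-rate schedulers once per epoch by default, so epoch t (counting from 0) is trained with approximately the rate computed below. The sketch assumes ``eta_min=0``, the PyTorch default; the function name is hypothetical.

```python
import math

def cosine_annealing_lr(epoch, base_lr=5e-4, t_max=20, eta_min=0.0):
    """Closed-form cosine annealing learning rate for a given epoch."""
    return eta_min + 0.5 * (base_lr - eta_min) * (1.0 + math.cos(math.pi * epoch / t_max))

for epoch in (0, 5, 10, 15, 19):
    print(epoch, f"{cosine_annealing_lr(epoch):.2e}")
```

The last training epoch (t = 19) still uses a small, nonzero learning rate; the rate only reaches zero at t = T_max.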

## Evaluating the retrieval

In order to evaluate the fully-connected precipitation retrieval and compare it to conventional precipitation retrievals, we can use the ``ipwgml.evaluation.Evaluator``. The evaluator evaluates the retrieval using exactly the same data used to evaluate the [IMERG](evaluate_imerg.ipynb) and [GPROF](evaluate_gprof.ipynb) retrievals and thus ensures that the results are comparable. To ensure consistency between training and evaluation data, we instantiate the evaluator with the same values for ``geometry`` and ``retrieval_input`` as the training and validation dataset objects.

In [None]:
from ipwgml.evaluation import Evaluator
evaluator = Evaluator(
    sensor="gmi",
    geometry=geometry,
    retrieval_input=inputs,
    download=True
)

## The retrieval callback function

Evaluating the MLP using the ``ipwgml.evaluation.Evaluator`` requires implementing a retrieval callback function that the evaluator can call to obtain the retrieval results for a given collocation scene. The ``ipwgml.pytorch`` module provides a wrapper class that turns a given PyTorch retrieval into such a callback function. The ``PytorchRetrieval`` class simply converts the data from the evaluator to ``torch.Tensor`` objects, runs the model, and puts the results back into an ``xarray.Dataset``.
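To illustrate what such a wrapper has to do, here is a framework-agnostic sketch with a hypothetical ``make_retrieval_fn`` helper that operates on a dictionary of NumPy arrays instead of an ``xarray.Dataset``; the actual ``PytorchRetrieval`` interface may differ. Because the two detection heads of the MLP above output logits, the sketch applies a sigmoid explicitly to turn them into probabilities.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def make_retrieval_fn(model_fn, input_names):
    """
    Hypothetical helper: wrap a model that maps a stacked (n_pixels, n_features)
    array to a dict of (n_pixels, 1) outputs into a retrieval callback.
    """
    def retrieval_fn(inputs):
        # Stack the input features along the last axis, as with stack=True.
        x = np.concatenate([inputs[name] for name in input_names], axis=-1)
        pred = model_fn(x)
        return {
            "surface_precip": pred["surface_precip"][..., 0],
            # Detection heads output logits; convert them to probabilities.
            "probability_of_precip": sigmoid(pred["probability_of_precip"][..., 0]),
            "probability_of_heavy_precip": sigmoid(pred["probability_of_heavy_precip"][..., 0]),
        }
    return retrieval_fn

def dummy_model(x):
    """Stand-in 'model' that returns zeros for every output head."""
    zeros = np.zeros((x.shape[0], 1))
    names = ("surface_precip", "probability_of_precip", "probability_of_heavy_precip")
    return {name: zeros for name in names}

retrieval_fn = make_retrieval_fn(dummy_model, ["obs_gmi"])
results = retrieval_fn({"obs_gmi": np.zeros((8, 13))})
print({name: arr.shape for name, arr in results.items()})
```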

In [None]:
from ipwgml.pytorch import PytorchRetrieval
mlp_retrieval = PytorchRetrieval(
    mlp,
    retrieval_input=inputs,
    stack=True,
    # Fall back to the CPU if no GPU is available.
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
)

## Case study

In [None]:
fig = evaluator.plot_retrieval_results(86, mlp_retrieval, input_data_format="tabular", batch_size=1024)

## Evaluation 

In [None]:
evaluator.evaluate(retrieval_fn=mlp_retrieval, input_data_format="tabular", batch_size=4048, n_processes=1)

## Results

### Precipitation quantification

In [None]:
evaluator.get_precip_quantification_results(name="MLP (GMI)").T
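The quantification results summarize how well the retrieved precipitation rates match the reference. As a rough illustration of the kind of scores involved, the sketch below computes bias, mean squared error, and linear correlation over pixels where both retrieved and reference values are finite. These metrics and the function name are chosen for illustration only; the exact set of metrics reported is defined by the ``ipwgml`` evaluator.

```python
import numpy as np

def quantification_scores(retrieved, reference):
    """Bias, MSE, and Pearson correlation over pixels where both values are finite."""
    retrieved = np.asarray(retrieved, dtype=float)
    reference = np.asarray(reference, dtype=float)
    valid = np.isfinite(retrieved) & np.isfinite(reference)
    ret, ref = retrieved[valid], reference[valid]
    return {
        "bias": np.mean(ret - ref),
        "mse": np.mean((ret - ref) ** 2),
        "correlation": np.corrcoef(ret, ref)[0, 1],
    }

scores = quantification_scores([0.0, 1.0, 3.0, np.nan], [0.0, 2.0, 4.0, 1.0])
print(scores)
```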

In [None]:
sc = evaluator.precip_quantification_metrics[-1].compute()

### Precipitation detection

In [None]:
evaluator.get_precip_detection_results(name="MLP (GMI)").T
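Deterministic detection scores are derived from the contingency table of predicted and reference precipitation occurrence. The sketch below computes the probability of detection (POD), false-alarm ratio (FAR), and critical success index (CSI) from boolean masks; these standard categorical scores are shown for illustration, and the evaluator's metric list may differ. The same function applies to heavy precipitation when the masks are built with the heavy-precipitation threshold.

```python
import numpy as np

def detection_scores(predicted, reference):
    """POD, FAR, and CSI from boolean precipitation masks."""
    predicted = np.asarray(predicted, dtype=bool)
    reference = np.asarray(reference, dtype=bool)
    hits = np.sum(predicted & reference)
    misses = np.sum(~predicted & reference)
    false_alarms = np.sum(predicted & ~reference)
    return {
        "pod": hits / (hits + misses),
        "far": false_alarms / (hits + false_alarms),
        "csi": hits / (hits + misses + false_alarms),
    }

predicted = [True, True, False, False, True]
reference = [True, False, True, False, True]
print(detection_scores(predicted, reference))
```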

### Probabilistic precipitation detection

In [None]:
evaluator.get_prob_precip_detection_results(name="MLP (GMI)").T
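For probabilistic detection, the predicted probabilities are compared with the binary reference occurrence. One common score is the Brier score, the mean squared difference between probability and outcome. It is shown here only as an example of a probabilistic metric, not necessarily one that the evaluator reports.

```python
import numpy as np

def brier_score(probability, occurred):
    """Mean squared difference between predicted probabilities and binary outcomes."""
    probability = np.asarray(probability, dtype=float)
    occurred = np.asarray(occurred, dtype=float)
    return np.mean((probability - occurred) ** 2)

print(brier_score([0.9, 0.2, 0.5, 0.0], [1, 0, 1, 0]))
```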

### Heavy precipitation detection

In [None]:
evaluator.get_heavy_precip_detection_results(name="MLP (GMI)").T

### Heavy probabilistic precipitation detection

In [None]:
evaluator.get_prob_heavy_precip_detection_results(name="MLP (GMI)").T