# Fully-connected neural network

This notebook implements a precipitation retrieval for the IPWG SPR dataset based on a fully-connected neural network implemented using PyTorch and Lightning.

## The training data

The ``ipwgml`` package provides dataset classes that take care of downloading, preprocessing, and loading the SPR training data. The ``ipwgml.pytorch.datasets.SPRTabular`` class implements the PyTorch dataset interface for the SPR data in tabular format, i.e., for pixel-based retrievals.

The retrieval input data to load and the quality criteria for the reference surface precipitation data can be configured using the ``InputConfig`` classes from the ``ipwgml.input`` module and the ``TargetConfig`` class from the ``ipwgml.target`` module, respectively.

For this example notebook, we use only GMI observations as input. We configure the ``GMI`` input to not include the earth-incidence angles. Moreover, we enable minimum-maximum normalization of the input and replace NaN values with -1.5.

> **NOTE**: Both input and reference data in the SPR dataset may contain NaNs because observations or precipitation estimates may be missing or of insufficient quality. It is up to the user to handle those.
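
The two preprocessing options used here can be illustrated with plain NumPy. The following is a minimal sketch, not the ``ipwgml`` implementation: the helper ``minmax_normalize``, the channel limits, and the target range of [-1, 1] are assumptions made for illustration. Under that assumption, a fill value of -1.5 lies outside the range of valid normalized inputs, so missing values remain distinguishable from real observations.

```python
import numpy as np


def minmax_normalize(x, x_min, x_max, nan=-1.5):
    """Scale x linearly so that [x_min, x_max] maps to [-1, 1] and fill NaNs.

    Hypothetical helper for illustration; ipwgml may normalize differently.
    """
    x = np.asarray(x, dtype=np.float32)
    x_normed = -1.0 + 2.0 * (x - x_min) / (x_max - x_min)
    # Missing observations get a value outside the valid range.
    return np.where(np.isfinite(x), x_normed, nan).astype(np.float32)


# Toy observations with one missing value; the limits are made up.
tbs = np.array([100.0, 200.0, np.nan, 300.0])
tbs_normed = minmax_normalize(tbs, x_min=100.0, x_max=300.0)
print(tbs_normed)
```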

We choose the ``on_swath`` geometry for the retrieval, which is a natural choice for pixel-based retrievals. We also set up batching in the dataset, which is more efficient for tabular data than leaving it to the PyTorch data loader to perform the batching.
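
The idea of batching inside the dataset can be sketched as follows. This is an illustrative example, not the ``SPRTabular`` implementation: the class ``BatchedTable`` and its random data are hypothetical. Each item of the dataset is already a full batch obtained by slicing contiguous arrays, which avoids collating thousands of single-pixel samples in the data loader.

```python
import numpy as np


class BatchedTable:
    """Map-style dataset whose items are whole batches (hypothetical example)."""

    def __init__(self, inputs, targets, batch_size):
        self.inputs = inputs
        self.targets = targets
        self.batch_size = batch_size

    def __len__(self):
        # Number of batches, counting a final, possibly smaller batch.
        return -(-len(self.inputs) // self.batch_size)

    def __getitem__(self, index):
        if not 0 <= index < len(self):
            raise IndexError(index)
        start = index * self.batch_size
        stop = start + self.batch_size
        # A single slice returns the whole batch at once.
        return self.inputs[start:stop], self.targets[start:stop]


rng = np.random.default_rng(0)
table = BatchedTable(
    inputs=rng.normal(size=(2500, 13)).astype(np.float32),
    targets=rng.gamma(1.0, size=2500).astype(np.float32),
    batch_size=1024,
)
x, y = table[0]
print(len(table), x.shape, y.shape)
```

Passing such a dataset to a ``DataLoader`` with ``batch_size=None`` disables the loader's automatic batching, so each item is forwarded as one batch, with NumPy arrays converted to tensors.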

In [None]:
from ipwgml.input import GMI, Ancillary, Geo, GeoIR
from ipwgml.target import TargetConfig

target_config = TargetConfig(min_rqi=0.5)
inputs = [GMI(include_angles=False, normalize="minmax", nan=-1.5)]
geometry = "on_swath"
batch_size = 1024

With these settings, we can instantiate the training dataset. By setting ``stack=True`` we also tell the dataset to stack all input tensors into a single tensor instead of loading the input data as a dictionary.

In [None]:
from torch.utils.data import DataLoader
from ipwgml.pytorch.datasets import SPRTabular

training_data = SPRTabular(
    sensor="gmi",
    geometry=geometry,
    split="training",
    retrieval_input=inputs,
    batch_size=batch_size,
    target_config=target_config,
    stack=True,
    download=True
)
training_loader = DataLoader(training_data, shuffle=True, batch_size=None)
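
The effect of ``stack=True`` can be pictured as concatenating the per-input arrays along the feature axis. The sketch below only illustrates this idea under assumptions: the ``ancillary`` entry, the feature counts, and the axis layout (samples first, features last) are hypothetical and need not match what ``ipwgml`` does internally.

```python
import numpy as np

# Hypothetical dictionary-style input for a batch of 4 pixels; names and
# feature counts are chosen for illustration only.
batch = {
    "obs_gmi": np.zeros((4, 13), dtype=np.float32),
    "ancillary": np.ones((4, 2), dtype=np.float32),
}

# Stacking concatenates all inputs along the feature axis.
stacked = np.concatenate(list(batch.values()), axis=1)
print(stacked.shape)
```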

With the above configuration, the training data loaded by the ``training_loader`` has the following dimensions.

In [None]:
inpt, target = next(iter(training_loader))
print("Input tensor shape: ", inpt.shape)
print("Target tensor shape: ", target.shape)

We also create a validation loader with the same configuration.

In [None]:
validation_data = SPRTabular(
    sensor="gmi",
    geometry="on_swath",
    split="validation",
    retrieval_input=inputs,
    batch_size=batch_size,
    target_config=target_config,
    stack=True,
    download=True
)
validation_loader = DataLoader(validation_data, shuffle=False, batch_size=None)

In [None]:
from typing import Any, Callable, Dict

import torch
from torch import optim
from torch import nn
from torch.nn.functional import binary_cross_entropy_with_logits
import lightning as L

OUTPUTS = [
    "surface_precip",
    "probability_of_precipitation",
    "probability_of_heavy_precipitation"
]

N_EPOCHS = 20

class MLP(L.LightningModule):
    """
    Lightning module implementing a multi-layer perceptron (MLP) for retrieving precipitation from satellite
    observations.
    """
    def __init__(
        self,
        n_input_features: int,
        n_hidden_layers: int,
        n_neurons: int,
        activation_fn: Callable[[], nn.Module] = nn.ReLU,
        normalization_layer: Callable[[int], nn.Module] = nn.LayerNorm
    ):
        """
        Args:
            n_input_features: The number of features in the input
            n_hidden_layers: The number of hidden layers in the MLP
            n_neurons: The number of neurons in the hidden layers
            activation_fn: A callable to create activation function layers.
            normalization_layer: A callable to create normalization layers.
        """
        super().__init__()
        blocks = [
            nn.Linear(n_input_features, n_neurons),
            activation_fn(),
            normalization_layer(n_neurons)
        ]
        for _ in range(n_hidden_layers):
            blocks += [
                nn.Linear(n_neurons, n_neurons),
                activation_fn(),
                normalization_layer(n_neurons)
            ]
        self.body = nn.Sequential(*blocks)

        heads = {}
        for output in OUTPUTS:
            heads[output] =  nn.Sequential(
                nn.Linear(n_neurons, n_neurons),
                activation_fn(),
                normalization_layer(n_neurons),
                nn.Linear(n_neurons, 1)
            )
        self.heads = nn.ModuleDict(heads)

    def forward(self, retrieval_input: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Forward retrieval input through network and produce dictionary with predictions.

        Args:
            retrieval_input: The retrieval input as a single torch.Tensor.

        Return:
            A dictionary containing the predictions for 'surface_precip', 'probability_of_precipitation',
            and 'probability_of_heavy_precipitation'.
        """
        y = self.body(retrieval_input)
        return {
            name: module(y) for name, module in self.heads.items()
        }
        
    def training_step(self, batch, batch_idx) -> torch.Tensor:
        """
        Calculates the training loss for the MLP.

        The loss is calculated as the sum of the MSE for 'surface_precip' and the binary cross-entropy loss
        for precipitation detection and heavy precipitation detection.

        Args:
            batch: A tuple containing the training data loaded from the data loader.
            batch_idx: The index of the batch in the current epoch. Not used.

        Return:
            A scalar torch.Tensor containing the total loss.
        """
        inpt, surface_precip = batch
        pred = self(inpt)

        valid = torch.isfinite(surface_precip)
        surface_precip = surface_precip[valid]
        precip_mask = (surface_precip > 1e-3).to(dtype=torch.float32)
        heavy_precip_mask = (surface_precip > 10).to(dtype=torch.float32)
        surface_precip_pred = pred["surface_precip"][valid]
        pop = pred["probability_of_precipitation"][valid]
        pohp = pred["probability_of_heavy_precipitation"][valid]
        
        # MSE loss for QPE
        loss_estim = ((surface_precip_pred[..., 0] - surface_precip) ** 2).mean()
        # BCE loss for detection targets
        loss_detect = binary_cross_entropy_with_logits(pop[..., 0], precip_mask)
        loss_detect_heavy = binary_cross_entropy_with_logits(pohp[..., 0], heavy_precip_mask)
        tot_loss =  loss_estim + loss_detect + loss_detect_heavy
        self.log("loss", loss_estim, prog_bar=True)
        return tot_loss

    def validation_step(self, batch, batch_idx) -> None:
        """
        Calculates the loss-function values on validation data.

        Args:
            batch: A tuple containing the validation data loaded from the data loader.
            batch_idx: The index of the batch in the current epoch. Not used.
        """
        inpt, surface_precip = batch
        pred = self(inpt)

        valid = torch.isfinite(surface_precip)
        surface_precip = surface_precip[valid]
        precip_mask = (surface_precip > 1e-3).to(dtype=torch.float32)
        heavy_precip_mask = (surface_precip > 10).to(dtype=torch.float32)
        surface_precip_pred = pred["surface_precip"][valid]
        pop = pred["probability_of_precipitation"][valid]
        pohp = pred["probability_of_heavy_precipitation"][valid]
        
        # MSE loss for QPE
        loss_estim = ((surface_precip_pred[..., 0] - surface_precip) ** 2).mean()
        # BCE loss for detection targets
        loss_detect = binary_cross_entropy_with_logits(pop[..., 0], precip_mask)
        loss_detect_heavy = binary_cross_entropy_with_logits(pohp[..., 0], heavy_precip_mask)
        tot_loss =  loss_estim + loss_detect + loss_detect_heavy
        self.log("loss_estim", loss_estim, prog_bar=True)
        self.log("loss_detect", loss_detect, prog_bar=True)
        self.log("loss_detect_heavy", loss_detect_heavy, prog_bar=True)
    
    def configure_optimizers(self) -> Dict[str, Any]:
        """
        We use the Adam optimizer with a cosine annealing learning rate schedule.
        """
        optimizer = optim.Adam(self.parameters(), lr=1e-3)
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max = N_EPOCHS)
        return {
            "optimizer": optimizer,
            "lr_scheduler": scheduler
        }
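
The loss computed in ``training_step`` combines a mean squared error with two binary cross-entropy terms evaluated on logits, after masking out invalid reference values. The NumPy sketch below reproduces this calculation on a few hand-picked values to make the individual steps explicit. The function ``bce_with_logits`` is a hypothetical stand-in for PyTorch's ``binary_cross_entropy_with_logits`` with mean reduction, and all numbers are made up.

```python
import numpy as np


def bce_with_logits(logits, labels):
    """Mean binary cross-entropy on raw logits (numerically stable form)."""
    # Equivalent to -[y * log(s) + (1 - y) * log(1 - s)] with s = sigmoid(logits).
    losses = np.maximum(logits, 0.0) - logits * labels + np.log1p(np.exp(-np.abs(logits)))
    return losses.mean()


# Made-up reference precipitation rates; NaN marks a missing reference value.
surface_precip = np.array([0.0, 2.0, np.nan, 15.0])
# Made-up network outputs: a rain-rate estimate and two detection logits.
surface_precip_pred = np.array([0.1, 1.5, 3.0, 12.0])
pop_logits = np.array([-2.0, 1.0, 0.0, 3.0])
pohp_logits = np.array([-4.0, -1.0, 0.0, 2.0])

# Only samples with a finite reference value contribute to the loss.
valid = np.isfinite(surface_precip)
reference = surface_precip[valid]
precip_mask = (reference > 1e-3).astype(np.float64)
heavy_precip_mask = (reference > 10).astype(np.float64)

loss_estim = ((surface_precip_pred[valid] - reference) ** 2).mean()
loss_detect = bce_with_logits(pop_logits[valid], precip_mask)
loss_detect_heavy = bce_with_logits(pohp_logits[valid], heavy_precip_mask)
tot_loss = loss_estim + loss_detect + loss_detect_heavy
print(loss_estim, loss_detect, loss_detect_heavy, tot_loss)
```

Because the detection heads are trained on logits, a sigmoid has to be applied to their outputs to obtain probabilities at inference time.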

In [None]:
mlp = MLP(n_input_features=13, n_hidden_layers=4, n_neurons=256)

In [None]:
trainer = L.Trainer(max_epochs=N_EPOCHS)
trainer.fit(model=mlp, train_dataloaders=training_loader, val_dataloaders=validation_loader)
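
The learning-rate schedule configured in ``configure_optimizers`` can be written in closed form. The sketch below assumes that the scheduler is stepped once per epoch, which is Lightning's default interval, and that the minimum learning rate is 0, which is the PyTorch default. The helper ``cosine_annealing_lr`` is a hypothetical re-implementation for illustration, not a PyTorch function.

```python
import math

N_EPOCHS = 20


def cosine_annealing_lr(epoch, base_lr=1e-3, t_max=N_EPOCHS, eta_min=0.0):
    """Closed-form cosine annealing: base_lr at epoch 0, eta_min at epoch t_max."""
    return eta_min + 0.5 * (base_lr - eta_min) * (1.0 + math.cos(math.pi * epoch / t_max))


learning_rates = [cosine_annealing_lr(epoch) for epoch in range(N_EPOCHS + 1)]
for epoch in (0, 5, 10, 15, 20):
    print(f"epoch {epoch:2d}: lr = {learning_rates[epoch]:.2e}")
```

With ``T_max`` set to ``N_EPOCHS``, the learning rate falls from its initial value to half of it at the midpoint of training and approaches zero at the end.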

In [None]:
from ipwgml.evaluation import Evaluator
evaluator = Evaluator(
    sensor="gmi",
    geometry="on_swath",
    retrieval_input=[{"name": "gmi", "normalize": "minmax", "nan": -1.5}],
    download=False
)
    

In [None]:
import xarray as xr
mlp = mlp.eval()

def retrieval_fn(retrieval_input: xr.Dataset):
    obs_gmi = torch.tensor(retrieval_input.obs_gmi.data)
    with torch.no_grad():
        pred = mlp(obs_gmi.T)
        dims = ("samples",)
        results = xr.Dataset({
            "surface_precip": (dims, pred["surface_precip"].cpu().numpy()[..., 0]),
            "probability_of_precipitation": (dims, nn.functional.sigmoid(pred["probability_of_precipitation"][..., 0]).cpu().numpy()),
            "probability_of_heavy_precipitation": (dims, nn.functional.sigmoid(pred["probability_of_heavy_precipitation"][..., 0]).cpu().numpy()),
        })
    return results
        

In [None]:
evaluator.plot_retrieval_results(75, retrieval_fn, input_data_format="tabular", batch_size=1024)

In [None]:
evaluator.evaluate(retrieval_fn=retrieval_fn, input_data_format="tabular", batch_size=4048)

In [None]:
evaluator.get_precipitation_estimation_results(name="MLP (GMI)").T