# Evaluate SatRain retrievals

This notebook evaluates SatRain ML retrievals across all testing datasets.

In [7]:
%load_ext autoreload
%autoreload 2
from pathlib import Path
from typing import Tuple

import h5py
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import xarray as xr

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [8]:
from pytorch_retrieve import load_model

# Root directory holding the trained SatRain models. This is the only path
# that needs to change when the notebook is run on a different machine.
MODEL_DIR = Path("/home/simon/src/ipwgml/models/satrain")

# One model per input type (gmi, geo, geo_ir), matching the sub-directory
# names below. Paths are passed as plain strings, exactly as before.
model_gmi = load_model(str(MODEL_DIR / "gmi" / "ipwgml_gmi.pt"))
model_geo = load_model(str(MODEL_DIR / "geo" / "checkpoints" / "ipwgml_geo-v3.ckpt"))
model_geo_ir = load_model(str(MODEL_DIR / "geo_ir" / "ipwgml_geo_ir.pt"))

In [9]:
from typing import Any, List, Dict
import torch
from torch import nn
import xarray as xr

from satrain.input import calculate_input_features

class PytorchRetrieval:
    """
    Generic retrieval callback for PyTorch-based retrievals.

    A PytorchRetrieval wraps a torch.nn.Module. When called with an
    xarray.Dataset holding the retrieval input, it converts the required input
    variables to PyTorch tensors, feeds them into the module and returns the
    retrieved 'surface_precip' in a new xarray.Dataset.

    The module output must be indexable with the key 'surface_precip' and the
    corresponding entry must provide an 'expected_value()' method returning a
    tensor.

    NOTE: Only 'surface_precip' is currently written to the results. The
    precipitation-probability and flag outputs are disabled, so the
    'precip_threshold', 'heavy_precip_threshold' and 'logits' settings are
    stored on the instance but not used by '__call__'.
    """
    def __init__(
            self,
            model: nn.Module,
            retrieval_input: List[str | Dict[str, Any]],
            precip_threshold: float = 0.5,
            heavy_precip_threshold: float = 0.5,
            stack: bool = False,
            logits: bool = True,
            device: torch.device = torch.device("cpu"),
            dtype: torch.dtype = torch.float32,
    ):
        """
        Args:
            model: The torch.nn.Module implementing the retrieval. It is moved
                to 'device' and put into evaluation mode.
            retrieval_input: A list defining the retrieval input. It is used
                to determine the names of the input variables to extract.
            precip_threshold: Probability threshold for a 'precip_flag'
                output (stored only, see class docstring).
            heavy_precip_threshold: Same as 'precip_threshold' but for the
                heavy precip flag (stored only).
            stack: If 'True', the input tensors are concatenated into a
                single tensor; otherwise the model receives a dictionary
                mapping input names to tensors.
            logits: Whether the model returns logits instead of
                probabilities (stored only).
            device: The torch.device on which to perform inference.
            dtype: The dtype to which the retrieval input is converted.
        """
        self.model = model.to(device=device).eval()
        # Features are always requested un-stacked; stacking, if enabled, is
        # performed on the tensors in __call__.
        self.features = calculate_input_features(retrieval_input, stack=False)
        self.stack = stack
        self.logits = logits
        self.precip_threshold = precip_threshold
        self.heavy_precip_threshold = heavy_precip_threshold
        self.device = device
        self.dtype = dtype

    def __call__(self, input_data: xr.Dataset) -> xr.Dataset:
        """
        Run the retrieval on the given input data.

        Args:
            input_data: An xarray.Dataset containing a variable for each of
                the input names in 'self.features'.

        Return:
            An xarray.Dataset containing the variable 'surface_precip'.
        """
        dataset_dims = input_data.dims

        # Spatial dimensions depend on the geometry of the input data.
        if "scan" in dataset_dims:
            spatial_dims = ("scan", "pixel")
        elif "latitude" in dataset_dims:
            spatial_dims = ("latitude", "longitude")
        else:
            spatial_dims = ()

        # A leading batch dimension shifts the feature axis from 0 to 1.
        batched = "batch" in dataset_dims
        dims = (("batch",) + spatial_dims) if batched else spatial_dims
        feature_dim = 1 if batched else 0

        # Inputs with a single output dimension have their first two axes
        # swapped before being passed to the model.
        swap_axes = len(dims) == 1

        def load_input(name):
            tensor = torch.tensor(input_data[name].data).to(self.device, self.dtype)
            return tensor.transpose(0, 1) if swap_axes else tensor

        inpt = {name: load_input(name) for name in self.features}
        if self.stack:
            inpt = torch.cat(list(inpt.values()), dim=feature_dim)

        with torch.no_grad():
            pred = self.model(inpt)
            surface_precip = pred["surface_precip"].expected_value().float().cpu().numpy()

        # Drop axis 1 of the prediction so that it matches 'dims'.
        results = xr.Dataset()
        results["surface_precip"] = (dims, surface_precip[:, 0])
        return results

In [10]:
import torch

# Select the inference device once so that all three retrievals agree. The
# second GPU ("cuda:1") is used whenever it exists, as before; otherwise fall
# back to the first GPU or the CPU so the notebook still runs on machines
# with fewer devices.
if torch.cuda.device_count() > 1:
    DEVICE = torch.device("cuda:1")
elif torch.cuda.is_available():
    DEVICE = torch.device("cuda:0")
else:
    DEVICE = torch.device("cpu")

# One retrieval callback per model / input type.
gmi_retrieval = PytorchRetrieval(
    model_gmi,
    retrieval_input=["gmi"],
    device=DEVICE,
    dtype=torch.float32,
)
geo_retrieval = PytorchRetrieval(
    model_geo,
    retrieval_input=["geo"],
    device=DEVICE,
    dtype=torch.float32,
)
geo_ir_retrieval = PytorchRetrieval(
    model_geo_ir,
    retrieval_input=["geo_ir"],
    device=DEVICE,
    dtype=torch.float32,
)

## CONUS and Korea

In [None]:
from satrain.evaluation import Evaluator
from satrain.target import TargetConfig

# Channel selection passed to the evaluator for the 'geo' input.
GOES_CHANNELS = [1, 2, 4, 6, 7, 9, 10, 11, 14, 15]

# Retrieval callbacks; the keys also appear in the output file names.
retrievals = {
    "gmi": gmi_retrieval,
    "geo_retrieval": geo_retrieval,
    "geo_ir_retrieval": geo_ir_retrieval,
}

# Evaluator input configuration for each retrieval (same keys as above). All
# inputs share the same normalization and NaN replacement value.
common_options = {"normalize": "minmax", "nan": -2.0}
inputs = {
    "gmi": {"name": "gmi", **common_options},
    "geo_retrieval": {"name": "geo", **common_options, "channels": GOES_CHANNELS},
    "geo_ir_retrieval": {"name": "geo_ir", **common_options},
}

# Evaluate every retrieval on every domain and store one NetCDF file per
# (retrieval, domain) combination.
for domain in ("conus", "korea"):
    for retrieval, retrieval_fn in retrievals.items():
        evaluator = Evaluator(
            domain=domain,
            base_sensor="gmi",
            geometry="gridded",
            retrieval_input=[inputs[retrieval]],
        )
        evaluator.evaluate(retrieval_fn, tile_size=256, batch_size=32)
        results = evaluator.get_results()
        results.to_netcdf(f"results_{retrieval}_{domain}.nc")

Output()

## Austria

In [None]:
from satrain.evaluation import Evaluator
from satrain.target import TargetConfig

# Channel selection passed to the evaluator for the 'seviri' input.
SEVIRI_CHANNELS = [1, 2, 3, 4, 5, 6, 7, 8, 10, 11]

# Retrieval callbacks; the keys also appear in the output file names.
retrievals = {
    "gmi": gmi_retrieval,
    "geo_retrieval": geo_retrieval,
    "geo_ir_retrieval": geo_ir_retrieval,
}
# Evaluator input configuration for each retrieval (same keys as above). For
# Austria the 'geo' retrieval is fed from the 'seviri' input with
# 'remap_obs' enabled.
inputs = {
    "gmi": {"name": "gmi", "normalize": "minmax", "nan": -2.0},
    "geo_retrieval": {"name": "seviri", "normalize": "minmax", "nan": -2.0, "channels": SEVIRI_CHANNELS, "remap_obs": True},
    "geo_ir_retrieval": {"name": "geo_ir", "normalize": "minmax", "nan": -2.0},
}

for domain in ["austria"]:
    for retrieval in retrievals:
        evaluator = Evaluator(
            domain=domain,
            base_sensor="gmi",
            geometry="gridded",
            retrieval_input=[inputs[retrieval]]
        )
        # Evaluation and saving must happen inside the inner loop so that
        # every retrieval is evaluated, not only the last one. The callback
        # and tiling arguments mirror the CONUS/Korea evaluation cell.
        evaluator.evaluate(retrievals[retrieval], tile_size=256, batch_size=32)
        results = evaluator.get_results()
        results.to_netcdf(f'results_{retrieval}_{domain}.nc')

In [2]:
from satrain.input import Ancillary
# Ancillary input restricted to the 'total_precipitation' variable.
# NOTE(review): presumably meant to be passed as retrieval input when
# evaluating 'retrieve_era5'; that usage is not visible here.
anc = Ancillary(variables=["total_precipitation"])

def retrieve_era5(input_data: xr.Dataset) -> xr.Dataset:
    """
    Retrieval callback that returns ancillary data as the precipitation estimate.

    No model is run: the first entry along the second axis of the
    'ancillary' variable is scaled by 1e3 and returned as 'surface_precip'.

    Args:
        input_data: An xarray.Dataset containing the retrieval input data.
            It must provide an 'ancillary' variable.

    Return:
        An xarray.Dataset containing the retrieval results in the variable
        'surface_precip' with dimensions ('batch', 'latitude', 'longitude').
    """
    # assumes 'ancillary' is (batch, variable, latitude, longitude) with
    # ERA5 total precipitation at index 0 -- TODO confirm
    tp = input_data.ancillary.data[:, 0]

    # NOTE(review): the factor 1e3 presumably converts m to mm -- confirm
    # against the units of the ancillary data.
    return xr.Dataset({
        "surface_precip": (("batch", "latitude", "longitude"), tp * 1e3),
    })