# Inference

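This notebook runs each version of the ML heating-rate emulator on the test data. It saves the predictions and computes error statistics separately for clear-sky and cloudy pixels.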

In [1]:
%load_ext autoreload
%autoreload 2
import os
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import xarray as xr
from tqdm import tqdm

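## Models and Datasets

Each model version is paired with the file of its trained model and with a test dataset in the input format used by that version.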

In [17]:
from pytorch_retrieve import load_model
from hrem.datasets import HREMDataset0, HREMDataset1, HREMDataset2, HREMDataset3, HREMDataset4, HREMDataset5

# Root directories of the trained models and of the test data. Each can be
# overridden through an environment variable so the notebook can run on other
# machines; the defaults are the paths that were originally hard-coded here.
model_path = Path(os.environ.get("HREM_MODEL_PATH", "/home/simon/src/hrem/models/"))
data_path = Path(os.environ.get("HREM_DATA_PATH", "/home/simon/data/heating_rates/4Simon/"))

# Test-data store shared by model versions V2.3 through V2.5.
split_bands_tau_path = data_path / "HR_test_patches_with_split_bands_and_tau.zarr/"

# Maps each model version to a ``(model file, test dataset)`` tuple. All
# datasets are instantiated (with ``validation=True``) when this cell runs.
models = {
    "V2.0": (
        model_path / "v2.0" / "hrem_v2_0.pt",
        HREMDataset0(data_path / "HR_test_patches.zarr/", validation=True),
    ),
    "V2.1": (
        model_path / "v2.1" / "hrem_v2_1.pt",
        HREMDataset1(data_path / "HR_test_patches_withLWPReff.zarr/", validation=True),
    ),
    "V2.2": (
        model_path / "v2.2" / "hrem_v2_2.pt",
        HREMDataset2(data_path / "HR_test_patches_with_split_bands.zarr/", validation=True),
    ),
    "V2.3": (
        model_path / "v2.3" / "hrem_v2_3.pt",
        HREMDataset3(split_bands_tau_path, validation=True),
    ),
    "V2.4": (
        model_path / "v2.4" / "hrem_v2_4.pt",
        HREMDataset4(split_bands_tau_path, validation=True),
    ),
    "V2.4s": (
        model_path / "v2.4s" / "hrem_v2_4_s.pt",
        HREMDataset4(split_bands_tau_path, validation=True),
    ),
    "V2.5": (
        model_path / "v2.5" / "hrem_v2_5.pt",
        HREMDataset5(
            split_bands_tau_path,
            data_path / "HR_test_CKD_level_vars.nc",
            validation=True
        ),
    ),
}

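## Inference

Each model version is run on its test dataset. The emulator output is written to `results_{version}.zarr` as the array `CNN_output`.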

In [14]:
import zarr
from hrem.utils import run_emulator
for version, (model_file, test_data) in models.items():
    # Load this version's emulator in eval mode and run it over its test data.
    # Only the emulator output is stored; the returned reference is not written.
    emulator = load_model(model_file).eval()
    cnn_output, hr_reference = run_emulator(emulator, test_data, device="cuda:0")
    # mode='w' (re)creates the group, replacing results from earlier runs.
    zarr.open_group(f"results_{version}.zarr", mode='w')["CNN_output"] = cnn_output

100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 384/384 [03:57<00:00,  1.61it/s]
  0%|                                                                                                              | 0/384 [00:00<?, ?it/s]


ValueError: found the following matches with the input file in xarray's IO backends: ['netcdf4', 'h5netcdf']. But their dependencies may not be installed, see:
https://docs.xarray.dev/en/stable/user-guide/io.html 
https://docs.xarray.dev/en/stable/getting-started-guide/installing.html

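## Error Statistics

Error statistics are computed separately for clear-sky and cloudy pixels and are written to `results_{version}.nc`.

**Note:** reading and writing NetCDF files with xarray requires the `netCDF4` or `h5netcdf` package. The `ValueError`s recorded in this notebook's outputs suggest that neither was installed when the notebook was last run.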


In [10]:
import torch

from hrem.utils import evaluate_scene

def _new_error_accumulator(n_cases, n_rel_err_bins, n_err_bins):
    """Return zero-initialized running sums, counts and histograms for one pixel class."""
    acc = {
        name: 0.0
        for name in (
            "err", "abs_err", "squared_err", "rel_err", "rel_abs_err", "hr_ref", "cts", "cts_rel"
        )
    }
    acc["rel_err_dist"] = np.zeros((n_cases, n_rel_err_bins))
    acc["err_dist"] = np.zeros((n_cases, n_err_bins))
    return acc


def _accumulate_errors(acc, case, hr_pred, hr_ref, mask, rel_err_bins, err_bins):
    """Add the errors of the pixels selected by ``mask`` in scene ``case`` to ``acc``."""
    hr_pred = hr_pred[mask]
    hr_ref = hr_ref[mask]

    err = hr_pred - hr_ref
    abs_err = np.abs(err)

    # Relative errors are undefined where the reference is zero (they would be
    # +-inf or NaN and poison the sums), so those pixels are left out of the
    # relative statistics. Normalizing by |hr_ref| keeps the relative absolute
    # error non-negative when the reference is negative.
    valid = hr_ref != 0
    rel_err = err[valid] / hr_ref[valid]
    rel_abs_err = abs_err[valid] / np.abs(hr_ref[valid])

    # float() makes the running sums across scenes accumulate in double precision.
    acc["err"] += float(err.sum())
    acc["abs_err"] += float(abs_err.sum())
    acc["squared_err"] += float((err ** 2).sum())
    acc["rel_err"] += float(rel_err.sum())
    acc["rel_abs_err"] += float(rel_abs_err.sum())
    acc["hr_ref"] += float(hr_ref.sum())
    acc["cts"] += float(mask.sum())
    acc["cts_rel"] += float(valid.sum())

    acc["rel_err_dist"][case] += np.histogram(100.0 * rel_err, bins=rel_err_bins)[0]
    acc["err_dist"][case] += np.histogram(err, bins=err_bins)[0]


def _merge_error_accumulators(acc_a, acc_b):
    """Return a new accumulator holding the element-wise sum of two accumulators."""
    return {key: acc_a[key] + acc_b[key] for key in acc_a}


def _ratio(numerator, denominator):
    """Return ``numerator / denominator``, or NaN if ``denominator`` is zero."""
    return numerator / denominator if denominator != 0 else float("nan")


def _summarize_errors(acc):
    """Turn the running sums of an accumulator into mean error statistics."""
    bias = _ratio(acc["err"], acc["cts"])
    hr_ref_mean = _ratio(acc["hr_ref"], acc["cts"])
    return {
        "bias": bias,
        "rel_bias": _ratio(100.0 * bias, hr_ref_mean),
        "abs_err": _ratio(acc["abs_err"], acc["cts"]),
        "squared_err": _ratio(acc["squared_err"], acc["cts"]),
        "rel_err": _ratio(acc["rel_err"], acc["cts_rel"]),
        "rel_abs_err": _ratio(acc["rel_abs_err"], acc["cts_rel"]),
    }


def evaluate_model(model, dataset, n_cases=16, device="cuda:0", dtype=torch.bfloat16):
    """
    Calculate error statistics for model configuration.

    The model is run on scenes ``0, ..., n_cases - 1`` of ``dataset`` via
    ``evaluate_scene``. Errors (prediction minus reference) are accumulated
    separately for clear-sky pixels (``cloud_mask`` False, suffix ``_cs``) and
    cloudy pixels (``cloud_mask`` True, suffix ``_ic``). Variables without a
    suffix combine both pixel classes.

    Relative errors are only defined where the reference heating rate is
    non-zero. Pixels with a zero reference are therefore excluded from the
    ``rel_err*`` and ``rel_abs_err*`` statistics and histograms, but from no
    other statistic. Relative absolute errors are normalized by the magnitude
    of the reference. Any mean over an empty pixel class is NaN.

    Args:
        model: The trained model instance.
        dataset: The dataset to load the input data.
        n_cases: The number of scenes of ``dataset`` to evaluate.
        device: The device passed to ``evaluate_scene``.
        dtype: The dtype passed to ``evaluate_scene``.

    Returns:
        An xarray.Dataset containing the error statistics. Scalar variables are
        means: bias, relative bias (percent), mean absolute error, mean squared
        error, mean relative error and mean relative absolute error. The
        ``*_dist*`` variables hold per-scene histograms of relative errors in
        percent (dims ``(case, rel_err_bin)``) and of errors (dims
        ``(case, err_bin)``).
    """
    rel_err_bins = np.linspace(-50, 50, 101)
    err_bins = np.linspace(-100, 100, 101)
    n_rel_err_bins = len(rel_err_bins) - 1
    n_err_bins = len(err_bins) - 1

    acc_cs = _new_error_accumulator(n_cases, n_rel_err_bins, n_err_bins)
    acc_ic = _new_error_accumulator(n_cases, n_rel_err_bins, n_err_bins)

    for case in tqdm(range(n_cases)):
        hr_pred, hr_ref, cloud_mask = evaluate_scene(
            model, dataset, case, device=device, dtype=dtype
        )
        _accumulate_errors(acc_cs, case, hr_pred, hr_ref, ~cloud_mask, rel_err_bins, err_bins)
        _accumulate_errors(acc_ic, case, hr_pred, hr_ref, cloud_mask, rel_err_bins, err_bins)

    acc_all = _merge_error_accumulators(acc_cs, acc_ic)
    stats_cs = _summarize_errors(acc_cs)
    stats_ic = _summarize_errors(acc_ic)
    stats_all = _summarize_errors(acc_all)

    rel_err_bin = 0.5 * (rel_err_bins[1:] + rel_err_bins[:-1])
    err_bin = 0.5 * (err_bins[1:] + err_bins[:-1])

    return xr.Dataset({
        "bias_cs": stats_cs["bias"],
        "bias_ic": stats_ic["bias"],
        "bias": stats_all["bias"],
        "rel_bias_cs": stats_cs["rel_bias"],
        "rel_bias_ic": stats_ic["rel_bias"],
        "rel_bias": stats_all["rel_bias"],
        "abs_err_ic": stats_ic["abs_err"],
        "abs_err_cs": stats_cs["abs_err"],
        "abs_err": stats_all["abs_err"],
        "squared_err_ic": stats_ic["squared_err"],
        "squared_err_cs": stats_cs["squared_err"],
        "squared_err": stats_all["squared_err"],
        "rel_err_ic": stats_ic["rel_err"],
        "rel_err_cs": stats_cs["rel_err"],
        "rel_err": stats_all["rel_err"],
        "rel_abs_err_ic": stats_ic["rel_abs_err"],
        "rel_abs_err_cs": stats_cs["rel_abs_err"],
        "rel_abs_err": stats_all["rel_abs_err"],
        "rel_err_bin": (("rel_err_bin",), rel_err_bin),
        "rel_err_dist_cs": (("case", "rel_err_bin",), acc_cs["rel_err_dist"]),
        "rel_err_dist_ic": (("case", "rel_err_bin",), acc_ic["rel_err_dist"]),
        "rel_err_dist": (("case", "rel_err_bin",), acc_all["rel_err_dist"]),
        "err_bin": (("err_bin",), err_bin),
        "err_dist_cs": (("case", "err_bin",), acc_cs["err_dist"]),
        "err_dist_ic": (("case", "err_bin",), acc_ic["err_dist"]),
        "err_dist": (("case", "err_bin",), acc_all["err_dist"]),
    })

In [16]:
from tqdm import tqdm

# Evaluate every model version on its test dataset and save the error
# statistics of each version to results_<version>.nc.
for version, (model_file, dataset) in tqdm(models.items(), desc="models"):
    mdl = load_model(model_file).eval()
    res = evaluate_model(mdl, dataset)
    # Record which model version the statistics belong to.
    res["version"] = version
    res.to_netcdf(f"results_{version}.nc")
    

  0%|                                                                                                                | 0/2 [00:00<?, ?it/s]
  0%|                                                                                                               | 0/16 [00:00<?, ?it/s][A
  6%|██████▍                                                                                                | 1/16 [00:27<06:58, 27.91s/it][A
 12%|████████████▉                                                                                          | 2/16 [00:45<05:06, 21.91s/it][A
 19%|███████████████████▎                                                                                   | 3/16 [01:03<04:18, 19.90s/it][A
 25%|█████████████████████████▊                                                                             | 4/16 [01:20<03:47, 18.93s/it][A
 31%|████████████████████████████████▏                                                                      | 5/16 [01:37<03:21, 18.34s/it][A
 3

ValueError: found the following matches with the input file in xarray's IO backends: ['netcdf4', 'h5netcdf']. But their dependencies may not be installed, see:
https://docs.xarray.dev/en/stable/user-guide/io.html 
https://docs.xarray.dev/en/stable/getting-started-guide/installing.html