# GPROF validation

This notebook assesses the performance of the GPROF GMI retrieval against ground radar data.

In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import numpy as np
import xarray as xr
from pathlib import Path
import seaborn as sns
sns.reset_orig()

In [None]:
import seaborn as sns
sns.reset_orig()

In [None]:
import sys
sys.path.insert(0, "../..")
from gprof_nn.plotting import set_style
set_style(latex=True)

## Load collocated retrieval results

To prepare the validation collocations for evaluation, the retrieval results must first be collocated with the reference/validation data using the ``gprof_nn combine_validation_data`` command. Following this, the valid validation samples can be extracted using the ``extract_results`` function provided by the ``gprof_nn.validation`` module.

The collocated retrieval results are stored in separate files for every product and contain only pixels in which MRMS, the product itself, and GPM CMB all have valid values. Samples from each year are combined into one file. We load all files into dictionaries ``results_db``, ``results_db_1``, ... ``results_db_3`` containing the samples from the database year + x. Additionally, we load the collocations with the a priori database into ``results_db_ref``.

In [None]:
import pickle
DATA_PATH = Path("/home/simonpf/data/gprof_nn/validation/gmi")
groups = ["gprof_v5", "gprof_v7", "gprof_nn_1d", "gprof_nn_3d", "combined", "gprof_nn_hr"]
results_db = {group: xr.load_dataset(DATA_PATH / f"results_{group}_db.nc") for group in groups}

In [None]:
results_db = {group: xr.load_dataset(DATA_PATH / f"results_{group}_db.nc") for group in groups}

In [None]:
results_db_1 = {group: xr.load_dataset(DATA_PATH / f"results_{group}_db_1.nc") for group in groups}

In [None]:
results_db_2 = {group: xr.load_dataset(DATA_PATH / f"results_{group}_db_2.nc") for group in groups}

In [None]:
results_db_2["gprof_v5"] = xr.load_dataset("/home/simonpf/data_3/gprof_nn/validation/gmi/results_gprof_v5_db_2.nc")

In [None]:
results_db_3 = {group: xr.load_dataset(DATA_PATH / f"results_{group}_db_3.nc") for group in groups}

In [None]:
results_db_ref = {group: xr.load_dataset(DATA_PATH / f"results_{group}_db_ref.nc") for group in groups}

## Observation frequencies

This section displays the observation frequencies throughout the validation period.

In [None]:
# Monthly bin edges from 2018-10 up to (excluding) 2022-11, cast to second resolution.
time_bins = np.arange("2018-10-01", "2022-11-01", dtype="datetime64[M]").astype("datetime64[s]")
times = []
# Gather the sample times of the 'gprof_nn_3d' results from all four MRMS periods.
for results in [results_db, results_db_1, results_db_2, results_db_3]:
    data = results["gprof_nn_3d"]
    times.append(data.time.data)
times = np.concatenate(times)
# Number of samples falling into each monthly bin.
counts, _ = np.histogram(times, bins=time_bins)
# Width of each bin; not used by the plotting cell that follows.
width = (time_bins[1:] - time_bins[:-1])
# Bin centers, used as x-coordinates when plotting the counts.
x = time_bins[:-1] + 0.5 * (time_bins[1:] - time_bins[:-1])

In [None]:
np.histogram?

In [None]:
# Monthly sample counts over the validation period.
f, ax = plt.subplots(1, 1, figsize=(12, 4))
ax.plot(x, counts)
# Label the axes so the figure can be read on its own.
ax.set_xlabel("Time")
ax.set_ylabel("Samples per month")
for label in ax.get_xticklabels():
    label.set_rotation(45)

## Scatter plots

In [None]:
from matplotlib.colors import LogNorm, Normalize
from matplotlib.cm import ScalarMappable
from matplotlib.gridspec import GridSpec
from gprof_nn.validation import NAMES
import warnings
warnings.filterwarnings("ignore")
from gprof_nn.validation import calculate_scatter_plot, calculate_conditional_mean


def make_scatter_plots(
    no_frozen=True,
    no_snow_sfc=True,
    no_orographic=True
):
    """
    Create scatter plots showing the conditional distribution of retrieved precipitation.

    The figure has one row per retrieval and one column per time period. Each
    panel shows the filled contours returned by ``calculate_scatter_plot``
    together with the curve returned by ``calculate_conditional_mean`` and
    the diagonal. The panel for ``gprof_v5`` in the last period is left empty.

    Args:
        no_frozen: Whether or not to exclude precipitation classified as frozen by MRMS
        no_snow_sfc: Whether or not to exclude precipitation estimates of snow surfaces.
        no_orographic: Whether or not to exclude mountain precipitation.

    Return:
        A matplotlib figure containing the scatter plots.
    """
    f = plt.figure(figsize=(20, 24))
    # Six panel rows, followed by a spacer row and the colorbar row. The first
    # column holds the row labels.
    gs = GridSpec(
        8, 6,
        width_ratios=[0.4]  + [1.0] * 5,
        height_ratios=[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.2, 0.1],
        hspace=0.15,
        wspace=0.2
    )
    axs = np.array([
        [f.add_subplot(gs[i, 1 + j]) for j in range(5)]
        for i in range(6)
    ])
    PERIODS = ["A priori (2019)", "MRMS (2019)", "MRMS (2020)", "MRMS (2021)", "MRMS (2022)"]

    groups = ["gprof_v5", "gprof_v7", "gprof_nn_1d", "gprof_nn_3d", "gprof_nn_hr", "combined"]
    last_row = len(groups) - 1
    period_results = [results_db_ref, results_db, results_db_1, results_db_2, results_db_3]

    for i, group in enumerate(groups):

        # Row label in the narrow first column.
        ax = f.add_subplot(gs[i, 0])
        ax.set_axis_off()
        ax.text(0, 0, NAMES[group], rotation=90, ha="center", va="center")
        ax.set_ylim([-2, 2])

        for j, result in enumerate(period_results):

            ax = axs[i, j]
            # No panel is drawn for GPROF V5 in the last period; only the
            # column title is kept.
            if group == "gprof_v5" and j == 4:
                ax.set_title(f"{PERIODS[j]}", loc="center", pad=15)
                ax.set_axis_off()
                continue
            x, y = calculate_scatter_plot(
                result,
                group,
                no_frozen=no_frozen,
                no_snow_sfc=no_snow_sfc,
                no_orographic=no_orographic
            )

            # Convert bin edges to bin centers for contouring.
            x = 0.5 * (x[1:] + x[:-1])
            levels = np.logspace(-4, 0, 9)
            mappable = ax.contourf(x, x, y.T, norm=LogNorm(1e-4, 1e0), levels=levels, extend="both")

            x, y = calculate_conditional_mean(
                result,
                group,
                no_frozen=no_frozen,
                no_snow_sfc=no_snow_sfc,
                no_orographic=no_orographic
            )
            ax.plot(x, y, ls="--", c="k", label="Conditional mean")

            ax.set_xscale("log")
            ax.set_yscale("log")
            ax.set_aspect(1.0)

            if i == 0:
                ax.set_title(f"{PERIODS[j]}", loc="center", pad=15)
            ax.set_xlim([1e-1, 1e2 + 1])
            ax.set_ylim([1e-1, 1e2 + 1])

            # Diagonal marking perfect agreement.
            ax.plot(x, x, ls="--", c="grey")

            # Only the first column carries y-axis decorations.
            if j == 0:
                ax.set_ylabel(r"$P_\text{Retrieved}$ [$\si{\milli \meter \per \hour}$]")
            else:
                ax.set_yticklabels([])
                for m in ax.yaxis.get_majorticklines():
                    m.set_visible(False)
                for m in ax.yaxis.get_minorticklines():
                    m.set_visible(False)
                ax.spines["left"].set_visible(False)

            # Only the last row carries x-axis decorations.
            if i == last_row:
                if j == 0:
                    ax.set_xlabel(r"$P_\text{A priori}$ [$\si{\milli \meter \per \hour}$]")
                else:
                    ax.set_xlabel(r"$P_\text{MRMS}$ [$\si{\milli \meter \per \hour}$]")

            else:
                ax.set_xticklabels([])
                for m in ax.xaxis.get_majorticklines():
                    m.set_visible(False)
                for m in ax.xaxis.get_minorticklines():
                    m.set_visible(False)
                # Hide the x-axis spine, mirroring the handling of the left
                # spine for the y-axis above.
                ax.spines["bottom"].set_visible(False)

    # Colorbar, based on the last panel that was drawn.
    ax = f.add_subplot(gs[-1, 1:])
    cb = plt.colorbar(mappable, label=r"$p(P_\text{Retrieved}|P_\text{A priori / MRMS})$ [\si{(\milli \meter \, \hour^{-1})^{-1}}]", cax=ax, orientation="horizontal")
    return f

#### No snow surfaces, no mountains

In [None]:
fig = make_scatter_plots(no_frozen=True, no_snow_sfc=True, no_orographic=True);
fig.savefig("../../plots/validation/scatter.pdf", bbox_inches="tight")

In [None]:
fig = make_scatter_plots(no_frozen=True, no_snow_sfc=True, no_orographic=True);
fig.savefig("../../plots/validation/scatter.pdf", bbox_inches="tight")

#### All samples

In [None]:
# Keep the returned figure so that the unfiltered plot, and not the figure
# from the previous section, is written to scatter_all.pdf.
fig = make_scatter_plots(no_frozen=False, no_snow_sfc=False, no_orographic=False);
fig.savefig("../../plots/validation/scatter_all.pdf", bbox_inches="tight")

### Error metrics

This code calculates the quantitative error metrics for every product and year and displays them as bar plots.

In [None]:
import pandas as pd
from gprof_nn.validation import calculate_error_metrics
groups = ["gprof_v5", "gprof_v7", "gprof_nn_1d", "gprof_nn_3d", "gprof_nn_hr"]

def calculate_stats(
    no_frozen: bool,
    no_snow_sfc: bool,
    no_orographic: bool = True,
    no_ocean: bool = True
):
    """
    Calculate error metrics for all retrievals and time periods.

    ``calculate_error_metrics`` is evaluated once per time period with the
    given sample filters. The last period is evaluated without the first
    entry of the module-level ``groups`` list.

    Args:
        no_frozen: Forwarded to ``calculate_error_metrics``.
        no_snow_sfc: Forwarded to ``calculate_error_metrics``.
        no_orographic: Forwarded to ``calculate_error_metrics``.
        no_ocean: Forwarded to ``calculate_error_metrics``.

    Return:
        A data frame holding the concatenated metrics of all periods with
        the additional columns 'Algorithm' and 'Time period'.
    """
    # (period label, collocation results, number of leading groups to skip)
    period_specs = [
        ("A Priori (2019)", results_db_ref, 0),
        ("MRMS (2019)", results_db, 0),
        ("MRMS (2020)", results_db_1, 0),
        ("MRMS (2021)", results_db_2, 0),
        ("MRMS (2022)", results_db_3, 1),
    ]

    # Compute the metrics for all periods first ...
    period_metrics = [
        calculate_error_metrics(
            period_results,
            groups[n_skip:] + ["combined"],
            no_frozen=no_frozen,
            no_snow_sfc=no_snow_sfc,
            no_orographic=no_orographic,
            no_ocean=no_ocean,
        )
        for _, period_results, n_skip in period_specs
    ]

    # ... then tag each table with its period and expose the index as a column.
    frames = []
    for (label, _, _), frame in zip(period_specs, period_metrics):
        frame["Time period"] = label
        frame = frame.reset_index().rename(columns={"index": "Algorithm"})
        frames.append(frame)

    return pd.concat(frames)

stats = calculate_stats(no_frozen=True, no_snow_sfc=True, no_orographic=True)
stats_all = calculate_stats(no_frozen=False, no_snow_sfc=False, no_orographic=False)

In [None]:
stats[stats["Time period"] == "MRMS (2020)"].to_csv("gmi_2020_conus.csv")

In [None]:
stats[stats["Time period"] == "MRMS (2020)"]

In [None]:
from matplotlib.gridspec import GridSpec
from gprof_nn.validation import get_colors, NAMES

def plot_stats(stats):
    """
    Plot error metrics as grouped bar plots.

    One panel is drawn per metric with the time period on the x-axis and one
    bar per algorithm. A shared legend is placed in the bottom row.

    Args:
        stats: Data frame with the columns 'Time period', 'Algorithm' and
            one column for each of the plotted metrics.

    Return:
        The matplotlib figure containing the bar plots.
    """
    f = plt.figure(figsize=(15, 10))
    # Two rows of panels and a shorter bottom row for the legend.
    gs = GridSpec(3, 3, height_ratios= [1.0] * 2 + [0.5], wspace=0.3, hspace=0.2)
    colors = get_colors()

    axs = np.array([
        [f.add_subplot(gs[i, j]) for j in range(3)]
        for i in range(2)
    ])

    y_lims = {
        "Bias": [-15, 15]
    }

    # Raw strings keep the LaTeX backslashes without relying on Python
    # passing unknown escape sequences through unchanged.
    y_labels = {
        "Bias": r"Bias [$\si{\percent}$]",
        "MAE": r"MAE [$\si{\milli \meter \per \hour}$]",
        "MSE": r"MSE [$(\si{\milli \meter \per \hour})^2$]",
        "SMAPE": r"SMAPE$_{0.1}$ [$\si{\percent}$]",
    }

    bar_palette = list(colors.values())

    metrics = ["Bias", "MAE", "MSE", "Correlation coeff.", "SMAPE"]
    for i, metric in enumerate(metrics):
        ax = axs[i // 3, i % 3]
        g = sns.barplot(
            x="Time period",
            y=metric,
            hue="Algorithm",
            data=stats,
            ax=ax,
            saturation=1.0,
            palette=bar_palette
        )
        # The per-panel legend is replaced by the shared legend below.
        g.legend_.set_visible(False)

        # Zero line for the bias panel.
        if i == 0:
            ax.axhline(y=0, ls="--", c="k")

        # The first two panels of the top row have a panel below them and
        # therefore drop their x tick labels.
        if i // 3 == 0 and i % 3 < 2:
            ax.set_xticklabels([])
            ax.set_xlabel(None)
        else:
            for label in ax.get_xticklabels():
                label.set_rotation(45)

        if metric in y_labels:
            ax.set_ylabel(y_labels[metric])

        if metric in y_lims:
            ax.set_ylim(y_lims[metric])

    # Only five metrics are plotted; the sixth panel stays empty.
    axs[-1, -1].set_axis_off()

    # Shared legend built from the handles of the last panel drawn.
    lax = f.add_subplot(gs[-1, :])
    lax.set_axis_off()
    lax.legend(*ax.get_legend_handles_labels(), ncol=6, loc="lower center")

    return f


#### No frozen, no snow surfaces, no orographic

In [None]:
f = plot_stats(stats)
f.savefig("../../plots/validation/metrics.pdf", bbox_inches="tight")
f.savefig("../../plots/validation/metrics.svg", bbox_inches="tight")

#### All samples

In [None]:
f = plot_stats(stats_all)
f.savefig("../../plots/validation/metrics_all.pdf", bbox_inches="tight")
f.savefig("../../plots/validation/metrics_all.svg", bbox_inches="tight")

## Error maps

In [None]:
import seaborn as sns
sns.reset_orig()
cmap = sns.color_palette("vlag_r", as_cmap=True)

In [None]:
from gprof_nn.validation import gridded_stats, CONUS

groups = ["gprof_v5", "gprof_v7", "gprof_nn_1d", "gprof_nn_3d", "gprof_nn_hr", "combined"]
results = {
    "A priori (2019)": results_db_ref,
    "MRMS (2019)": results_db,
    "MRMS (2020)": results_db_1,
    "MRMS (2021)": results_db_2,
    "MRMS (2022)": results_db_3
}
bins = (np.arange(CONUS[0], CONUS[2], 5.0), np.arange(CONUS[1], CONUS[3], 5.0))

def calculate_spatial_stats(
    no_frozen: bool,
    no_snow_sfc: bool,
    no_ocean: bool = True,
    no_orographic: bool = True
):
    """
    Calculate gridded biases and correlations for all retrievals and periods.

    ``gridded_stats`` is evaluated on the module-level ``bins`` for every
    entry of the module-level ``groups`` and ``results``.

    Args:
        no_frozen: Forwarded to ``gridded_stats``.
        no_snow_sfc: Forwarded to ``gridded_stats``.
        no_ocean: Forwarded to ``gridded_stats``.
        no_orographic: Forwarded to ``gridded_stats``.

    Return:
        A tuple ``(biases, corrs)`` of nested dicts indexed first by
        retrieval name and then by period name.
    """
    sample_filters = dict(
        no_ocean=no_ocean,
        no_orographic=no_orographic,
        no_frozen=no_frozen,
        no_snow_sfc=no_snow_sfc,
    )
    bias_maps = {}
    corr_maps = {}

    for retrieval in groups:
        for period, period_results in results.items():
            # Only bias and correlation are kept; MAE and MSE are discarded.
            bias, _mae, _mse, corr = gridded_stats(
                period_results[retrieval],
                bins,
                min_samples=1000,
                **sample_filters
            )
            if retrieval not in bias_maps:
                bias_maps[retrieval] = {}
            bias_maps[retrieval][period] = bias
            if retrieval not in corr_maps:
                corr_maps[retrieval] = {}
            corr_maps[retrieval][period] = corr

    return bias_maps, corr_maps

In [None]:
biases, corrs = calculate_spatial_stats(no_frozen=True, no_snow_sfc=True, no_orographic=True, no_ocean=True)

In [None]:
biases_all, corrs_all = calculate_spatial_stats(no_frozen=False, no_snow_sfc=False, no_orographic=False, no_ocean=True)

In [None]:
set_style(latex=True)

In [None]:
from gprof_nn.data.validation import CONUS
from gprof_nn.validation import REGIONS, gridded_stats
from matplotlib.patches import Rectangle
from matplotlib.gridspec import GridSpec
from matplotlib.colors import Normalize
from gprof_nn.validation import NAMES


def make_bias_plots(biases):
    """
    Plots maps of spatial biases for multiple retrievals and time periods.

    The figure has one row per retrieval and one column per time period. The
    panel for the first retrieval in the last period is left empty.
    
    Args:
        biases: Nested dict mapping retrieval names and period names to maps of
             biases.
             
    Return:
        A matplotlib.Figure containing the rendered bias maps.
    """
    selection = ["gprof_v5", "gprof_v7", "gprof_nn_1d", "gprof_nn_3d", "gprof_nn_hr", "combined"]
    M = len(selection)
    N = 5
    f = plt.figure(figsize=(N * 4.4 + 1, M * 2.4))
    # One extra narrow column for the row labels and one extra short row for the colorbar.
    gs = GridSpec(M + 1, N + 1, width_ratios= [0.1] + [1.0] * N, height_ratios=[1.0] * M + [0.1], hspace=0.05, wspace=0.05)
    axs = np.array([
        [f.add_subplot(gs[i, j + 1], projection=ccrs.PlateCarree()) for j in range(N)]
        for i in range(M)
    ])

    periods = ["A priori (2019)", "MRMS (2019)", "MRMS (2020)", "MRMS (2021)", "MRMS (2022)"]
    # Edges of the 5-degree grid; passed to pcolormesh as cell boundaries.
    bins = (np.arange(CONUS[0], CONUS[2], 5.0), np.arange(CONUS[1], CONUS[3], 5.0))
    # NOTE: c_lon, c_lat, levels and norm are computed but not used below.
    c_lon = 0.5 * (bins[0][1:] + bins[0][:-1])
    c_lat = 0.5 * (bins[1][1:] + bins[1][:-1])
    levels = np.linspace(-0.055, 0.055, 12)

    # Bias
    norm = Normalize(-0.025, 0.025)
    for i, group in enumerate(selection):
        # Row label in the narrow first column.
        ax = f.add_subplot(gs[i, 0])
        ax.set_axis_off()
        ax.text(0, 0, NAMES[group], rotation=90, va="center", ha="center")
        ax.set_ylim([-2, 2])

        for j, period in enumerate(periods):

            ax = axs[i, j]
            
            # Skip GPROF V5 for year 2022
            if (i == 0) and (j == 4):
                ax.set_title(periods[4])
                ax.axis("off")
                continue

            bias = biases[group][period]
            m = ax.pcolormesh(
                bins[0],
                bins[1],
                bias.T,
                vmin=-0.05,
                vmax=0.05,
                cmap="coolwarm_r",
                rasterized=True
            )

            ax.coastlines()
            # Crop one degree on each side of the CONUS bounds.
            ax.set_xlim([CONUS[0] + 1, CONUS[2] - 1])
            ax.set_ylim([CONUS[1] + 1, CONUS[3] - 1])
            if i == 0:
                ax.set_title(periods[j])

    # Shared colorbar, based on the last panel that was drawn.
    ax = f.add_subplot(gs[-1, 1:])
    plt.colorbar(m, label=r"$\overline{P_\text{Retrieved} - P_\text{A priori/MRMS}}$ [$\si{\milli \meter \per \hour}$]", cax=ax, orientation="horizontal")
    plt.tight_layout()
    return f

#### No snow, no snow surfaces, no orographic

In [None]:
fig = make_bias_plots(biases);
fig.savefig("../../plots/validation/bias_maps.pdf", bbox_inches="tight")

#### All samples

In [None]:
make_bias_plots(biases_all);

## Corrected biases

In [None]:
from gprof_nn.data.validation import CONUS
from gprof_nn.validation import REGIONS, gridded_stats
from matplotlib.patches import Rectangle
from matplotlib.gridspec import GridSpec
from matplotlib.colors import Normalize
from gprof_nn.validation import NAMES

def make_corrected_bias_plots(biases):
    """
    Plots maps of spatial biases corrected for GPM CMB retrievals.

    For every retrieval and period, the bias map of the 'combined' entry for
    the same period is subtracted before plotting. The input ``biases`` is
    not modified, so the function can be called repeatedly and the original
    maps remain usable by other plots.

    Args:
        biases: Nested dict mapping retrieval names and period names to maps of
             biases. Must contain an entry for 'combined'.

    Return:
        A matplotlib.Figure containing the rendered bias maps.
    """
    selection = ["gprof_v5", "gprof_v7", "gprof_nn_1d", "gprof_nn_3d", "gprof_nn_hr"]
    M = len(selection)
    N = 5
    f = plt.figure(figsize=(N * 4.4 + 1, M * 2.4))
    # One extra narrow column for the row labels and one extra short row for the colorbar.
    gs = GridSpec(M + 1, N + 1, width_ratios= [0.1] + [1.0] * N, height_ratios=[1.0] * M + [0.1], hspace=0.05, wspace=0.05)
    axs = np.array([
        [f.add_subplot(gs[i, j + 1], projection=ccrs.PlateCarree()) for j in range(N)]
        for i in range(M)
    ])

    periods = ["A priori (2019)", "MRMS (2019)", "MRMS (2020)", "MRMS (2021)", "MRMS (2022)"]
    # Edges of the 5-degree grid; passed to pcolormesh as cell boundaries.
    bins = (np.arange(CONUS[0], CONUS[2], 5.0), np.arange(CONUS[1], CONUS[3], 5.0))

    for i, group in enumerate(selection):
        # Row label in the narrow first column.
        ax = f.add_subplot(gs[i, 0])
        ax.set_axis_off()
        ax.text(0, 0, NAMES[group], rotation=90, va="center", ha="center")
        ax.set_ylim([-2, 2])

        for j, period in enumerate(periods):

            ax = axs[i, j]

            # Skip GPROF V5 for year 2022
            if (i == 0) and (j == 4):
                ax.set_title(periods[4])
                ax.axis("off")
                continue

            # Build a new array instead of subtracting in place: an in-place
            # update would alter the maps stored in ``biases`` and apply the
            # correction again on every call.
            bias = biases[group][period] - biases["combined"][period]

            m = ax.pcolormesh(bins[0], bins[1],  bias.T, vmin=-0.05, vmax=0.05, cmap="coolwarm_r", rasterized=True)

            ax.coastlines()
            # Crop one degree on each side of the CONUS bounds.
            ax.set_xlim([CONUS[0] + 1, CONUS[2] - 1])
            ax.set_ylim([CONUS[1] + 1, CONUS[3] - 1])
            if i == 0:
                ax.set_title(periods[j])

    # Shared colorbar, based on the last panel that was drawn.
    ax = f.add_subplot(gs[-1, 1:])
    plt.colorbar(m, label=r"$\overline{P_\text{Retrieved} - P_\text{A priori/MRMS}}$ [$\si{\milli \meter \per \hour}$]", cax=ax, orientation="horizontal")
    plt.tight_layout()
    return f

In [None]:
# Keep the returned figure so that the corrected bias maps, and not a figure
# left over from an earlier cell, are written to bias_maps_corrected.pdf.
f = make_corrected_bias_plots(biases);
f.savefig("../../plots/validation/bias_maps_corrected.pdf", bbox_inches="tight")

## Correlation coefficient maps

In [None]:
from gprof_nn.data.validation import CONUS
from matplotlib.gridspec import GridSpec
from matplotlib.colors import Normalize, BoundaryNorm
from matplotlib.cm import get_cmap
from gprof_nn.validation import NAMES


def make_correlation_plots(corrs):
    """
    Plots maps of correlation coefficients for multiple retrievals and time periods.

    The figure has one row per retrieval and one column per time period. The
    panel for the first retrieval in the last period is left empty.

    Args:
        corrs: Nested dict mapping retrieval names and period names to maps of
            correlation coefficients.

    Return:
        A matplotlib.Figure containing the rendered correlation maps.
    """
    selection = ["gprof_v5", "gprof_v7", "gprof_nn_1d", "gprof_nn_3d", "gprof_nn_hr", "combined"]
    M = len(selection)
    N = 5
    f = plt.figure(figsize=(N * 4.4 + 1, M * 2.4))
    # One extra narrow column for the row labels and one extra short row for the colorbar.
    gs = GridSpec(M + 1, N + 1, width_ratios= [0.1] + [1.0] * N, height_ratios=[1.0] * M + [0.1], hspace=0.05, wspace=0.05)
    axs = np.array([
        [f.add_subplot(gs[i, j + 1], projection=ccrs.PlateCarree()) for j in range(N)]
        for i in range(M)
    ])

    periods = ["A priori (2019)", "MRMS (2019)", "MRMS (2020)", "MRMS (2021)", "MRMS (2022)"]
    # Edges of the 5-degree grid; passed to pcolormesh as cell boundaries.
    bins = (np.arange(CONUS[0], CONUS[2], 5.0), np.arange(CONUS[1], CONUS[3], 5.0))
    # NOTE: c_lon and c_lat are computed but not used below.
    c_lon = 0.5 * (bins[0][1:] + bins[0][:-1])
    c_lat = 0.5 * (bins[1][1:] + bins[1][:-1])
    # Discrete color scale with boundaries every 0.05 between 0 and 1.
    levels = np.linspace(0.0, 1.0, 21)
    cmap = get_cmap("magma", len(levels))
    norm = BoundaryNorm(levels, len(levels))

    for i, group in enumerate(selection):

        # Row label in the narrow first column.
        ax = f.add_subplot(gs[i, 0])
        ax.set_axis_off()
        ax.text(0, 0, NAMES[group], rotation=90, va="center", ha="center")
        ax.set_ylim([-2, 2])

        for j, period in enumerate(periods):
            ax = axs[i, j]

            # Skip GPROF V5 for year 2022
            if (i == 0) and (j == 4):
                ax.set_title(periods[4])
                ax.axis("off")
                continue

            corr = corrs[group][period]
            m = ax.pcolormesh(bins[0], bins[1],  corr.T, norm=norm, cmap=cmap, rasterized=True)

            ax.coastlines(color="grey")
            ax.set_xlim([CONUS[0], CONUS[2]])
            ax.set_ylim([CONUS[1], CONUS[3]])
            if i == 0:
                ax.set_title(periods[j])

    # Shared colorbar, based on the last panel that was drawn.
    ax = f.add_subplot(gs[-1, 1:])
    plt.colorbar(m, label="Correlation coeff.", cax=ax, orientation="horizontal")
    plt.tight_layout()
    return f

In [None]:
f = make_correlation_plots(corrs);
f.savefig("../../plots/validation/correlation_maps.pdf", bbox_inches="tight")

## Precision-recall curves

In [None]:
from gprof_nn.validation import get_colors
COLORS = get_colors()

In [None]:
from gprof_nn.data.validation import CONUS
from matplotlib.gridspec import GridSpec
from matplotlib.colors import Normalize
from gprof_nn.validation import NAMES, calculate_pr_curve, get_colors

def plot_pr_curves(
    no_frozen: bool,
    no_snow_sfc: bool,
    no_ocean: bool = True,
    no_orographic: bool = True,
    fpa: bool = False
):
    """
    Plot precision-recall curves of all retrievals for every time period.

    One panel is drawn per time period on a 2 x 3 grid; the last grid cell
    holds the legend. The curves are obtained from ``calculate_pr_curve``.

    Args:
        no_frozen: Forwarded to ``calculate_pr_curve``.
        no_snow_sfc: Forwarded to ``calculate_pr_curve``.
        no_ocean: Forwarded to ``calculate_pr_curve``.
        no_orographic: Forwarded to ``calculate_pr_curve``.
        fpa: Forwarded to ``calculate_pr_curve``.

    Return:
        The matplotlib figure containing the precision-recall curves.
    """
    selection = ["gprof_v5", "gprof_v7", "gprof_nn_1d", "gprof_nn_3d", "gprof_nn_hr"]
    colors = get_colors()

    N = 5
    f = plt.figure(figsize=(3 * 5, 2 * 5))
    gs = GridSpec(2, 3, hspace=0.2, wspace=0.20)
    # Array of shape (1, N): panels are addressed as axs[0, i].
    axs = np.array([
        [f.add_subplot(gs[j // 3, j % 3]) for j in range(N)]
    ])

    periods = ["A priori (2019)", "MRMS (2019)", "MRMS (2020)", "MRMS (2021)", "MRMS (2022)"]
    for i, results in enumerate([results_db_ref, results_db, results_db_1, results_db_2, results_db_3]):

        # The handle list is reset for each of the first four periods and left
        # untouched for the last one, so the legend ends up with the handles
        # of the fourth period, which includes all retrievals.
        if i < 4:
            handles = []

        for retrieval in selection:
            
            # No curve is drawn for GPROF V5 in the last period.
            if i == 4 and retrieval == "gprof_v5":
                continue

            data = results[retrieval]
            precision, recall, thresholds = calculate_pr_curve(
                data,
                fpa=fpa,
                no_ocean=no_ocean,
                no_orographic=no_orographic,
                no_frozen=no_frozen,
                no_snow_sfc=no_snow_sfc
            )

            ax = axs[0, i] 
            period = periods[i]

            handles_i = ax.plot(recall, precision, c=colors[retrieval], label=NAMES[retrieval])
            if i < 4:
                handles += handles_i

            # The first two panels have another panel directly below them and
            # therefore drop their x-axis labels.
            if (i < 2):
                ax.set_xticklabels([])
                ax.set_xlabel("")
            else:
                ax.set_xlabel("Recall")

            ax.set_xlim([0, 1])
            ax.set_ylim([0, 1])
            # Only the left-most panel of each row carries y-axis labels.
            if i % 3 == 0:
                ax.set_ylabel("Precision")
            else:
                ax.set_yticklabels([])

            ax.set_title(f"({chr(ord('a') + i)}) {period}", loc="left", pad=15)

    # Legend in the unused bottom-right grid cell.
    ax = f.add_subplot(gs[1, -1])
    ax.set_axis_off()
    ax.legend(handles=handles, loc="center")
    return f


In [None]:
f = plot_pr_curves(no_frozen=True, no_snow_sfc=True, no_ocean=True, no_orographic=True);
f.savefig("../../plots/validation/pr_curves.pdf", bbox_inches="tight")
f.savefig("../../plots/validation/pr_curves.svg", bbox_inches="tight")

In [None]:
f = plot_pr_curves(no_frozen=True, no_snow_sfc=False, no_ocean=False, no_orographic=False);

# Regional analysis

In [None]:
DATA_PATH = Path("/home/simonpf/data_3/gmi")
groups = ["gprof_v5", "gprof_v7", "gprof_nn_1d", "gprof_nn_3d", "gprof_nn_hr", "combined"]

# The Kwajalein collocations are stored in a sibling directory of DATA_PATH.
kwaj_path = DATA_PATH / ".." / "kwaj_gmi"

results_db_ref_kwaj = {
    name: xr.load_dataset(kwaj_path / f"results_{name}_db_ref.nc") for name in groups
}
results_db_kwaj = {
    name: xr.load_dataset(kwaj_path / f"results_{name}_db.nc") for name in groups
}
results_db_1_kwaj = {
    name: xr.load_dataset(kwaj_path / f"results_{name}_db_1.nc") for name in groups
}
results_db_2_kwaj = {
    name: xr.load_dataset(kwaj_path / f"results_{name}_db_2.nc") for name in groups
}
# The last period is loaded without the first group (gprof_v5).
results_db_3_kwaj = {
    name: xr.load_dataset(kwaj_path / f"results_{name}_db_3.nc") for name in groups[1:]
}

In [None]:
groups = ["gprof_v5", "gprof_v7", "gprof_nn_1d", "gprof_nn_3d", "gprof_nn_hr", "combined"]


def merge_with_kwaj(conus_results, kwaj_results, keys, drop_vars=()):
    """
    Concatenate CONUS and Kwajalein results along the 'samples' dimension.

    To make the two sets of datasets compatible, the Kwajalein data receives
    constant 'rqi' and 'mask' variables (999.0) and the CONUS data receives
    a 'range' variable filled with NaN. Neither input is modified; both
    datasets are copied before variables are added.

    Args:
        conus_results: Dict mapping retrieval names to CONUS datasets.
        kwaj_results: Dict mapping retrieval names to Kwajalein datasets.
        keys: The retrieval names to merge.
        drop_vars: Names of variables to remove from the CONUS datasets
            before concatenation.

    Return:
        A dict mapping each retrieval name in ``keys`` to the concatenated
        dataset.
    """
    merged = {}
    for key in keys:
        conus = conus_results[key].copy()
        if drop_vars:
            conus = conus.drop_vars(list(drop_vars))
        kwaj = kwaj_results[key].copy()

        kwaj_precip = kwaj.surface_precip_avg.data
        kwaj["rqi"] = ("samples", 999.0 * np.ones_like(kwaj_precip))
        kwaj["mask"] = ("samples", 999.0 * np.ones_like(kwaj_precip))
        conus["range"] = ("samples", np.nan * np.ones_like(conus.surface_precip.data))

        merged[key] = xr.concat([conus, kwaj], dim="samples")
    return merged


# The per-region inputs are deleted after merging to free memory.
results_db_ref_all = merge_with_kwaj(
    results_db_ref, results_db_ref_kwaj, groups, drop_vars=["airmass_type"]
)
del results_db_ref

results_db_all = merge_with_kwaj(results_db, results_db_kwaj, groups)
del results_db

results_db_1_all = merge_with_kwaj(results_db_1, results_db_1_kwaj, groups)
del results_db_1

results_db_2_all = merge_with_kwaj(results_db_2, results_db_2_kwaj, groups)
del results_db_2

# The Kwajalein results of the last period do not include the first group.
results_db_3_all = merge_with_kwaj(results_db_3, results_db_3_kwaj, groups[1:])
del results_db_3

In [None]:
from gprof_nn.validation import REGIONS, calculate_error_metrics
import pandas as pd

def calculate_region_metrics(
    no_frozen: bool,
    no_snow_sfc: bool,
    no_orographic: bool = True,
    no_ocean: bool = True
):
    """
    Calculate error metrics for every region and time period.

    ``calculate_error_metrics`` is evaluated on the merged CONUS/Kwajalein
    results for each entry of ``REGIONS`` and each time period. The last
    period is evaluated without the first entry of the module-level
    ``groups`` list.

    Args:
        no_frozen: Forwarded to ``calculate_error_metrics``.
        no_snow_sfc: Forwarded to ``calculate_error_metrics``.
        no_orographic: Forwarded to ``calculate_error_metrics``.
        no_ocean: Accepted for interface compatibility but currently not
            applied: ocean samples are never excluded (see note in the body).

    Return:
        A data frame holding the concatenated metrics with the additional
        columns 'Algorithm', 'Region' and 'Time period'.
    """
    # Range limit per region: none for the first five regions, 80e3 for the sixth.
    ranges = [None,] * 5 + [80e3]

    # (period label, merged results, retrievals to evaluate)
    period_specs = [
        ("A priori (2019)", results_db_ref_all, groups),
        ("MRMS (2019)", results_db_all, groups),
        ("MRMS (2020)", results_db_1_all, groups),
        ("MRMS (2021)", results_db_2_all, groups),
        ("MRMS (2022)", results_db_3_all, groups[1:]),
    ]

    # NOTE(review): the original code first set the flag to
    # ``region != "KWAJ"`` and then unconditionally overrode it with False,
    # so the ``no_ocean`` argument never had any effect. That effective
    # behavior is kept here; confirm whether the region-dependent value or
    # the caller's argument should be used instead.
    exclude_ocean = False

    region_metrics = []
    for region, rng in zip(REGIONS, ranges):
        print(region)
        for period, period_results, selection in period_specs:
            metrics = calculate_error_metrics(
                period_results,
                selection,
                region=region,
                ranges=rng,
                no_orographic=no_orographic,
                no_ocean=exclude_ocean,
                no_frozen=no_frozen,
                # Previously accepted by this function but never forwarded.
                no_snow_sfc=no_snow_sfc
            ).reset_index()
            metrics = metrics.rename(columns={"index": "Algorithm"})
            metrics["Region"] = region
            metrics["Time period"] = period
            region_metrics.append(metrics)

    return pd.concat(region_metrics)

In [None]:
region_metrics = calculate_region_metrics(no_frozen=True, no_snow_sfc=True, no_orographic=True)
region_metrics_all = calculate_region_metrics(no_frozen=False, no_snow_sfc=False, no_orographic=False)

In [None]:
from gprof_nn.validation import calculate_precip_contribution, PRECIP_TYPES
import pandas as pd

precip_names = list(PRECIP_TYPES.keys())


def calculate_precip_contribs(
    no_frozen,
    no_snow_sfc,
    no_ocean=False,
    no_orographic=True
):
    """
    Calculate precipitation contributions per region, period and precipitation type.

    For every period and region, ``calculate_precip_contribution`` is
    evaluated once for the total and once for each of the types
    'Stratiform', 'Convective' and 'Snow'. The per-type values are based on
    the 'gprof_nn_1d' results and are set to 0.0 when the dataset has no
    'mask' variable.

    Args:
        no_frozen: Forwarded to ``calculate_precip_contribution`` for the
            per-type contributions.
        no_snow_sfc: Forwarded likewise.
        no_ocean: Forwarded likewise.
        no_orographic: Forwarded likewise.

    Return:
        A data frame with the columns 'Region', 'Contribution',
        'Precip type' and 'Period'.
    """
    contribs = []
    precip_types = []
    regions = []
    periods = []
    results = {
        # NOTE(review): this entry uses the same dataset as
        # "Ground radar (2019)"; results_db_ref_all may have been intended.
        # Confirm before changing.
        "A priori (2019)": results_db_all["gprof_nn_1d"],
        "Ground radar (2019)": results_db_all["gprof_nn_1d"],
        "Ground radar (2020)": results_db_1_all["gprof_nn_1d"],
        "Ground radar (2021)": results_db_2_all["gprof_nn_1d"],
        "Ground radar (2022)": results_db_3_all["gprof_nn_1d"],
    }

    for period, result in results.items():
        for region in REGIONS.keys():
            # NOTE(review): the total is computed without the sample filters
            # that are applied to the per-type contributions below.
            contribs.append(calculate_precip_contribution(result, region=region, absolute=True))
            regions.append(region)
            precip_types.append("Total")
            periods.append(period)

            for precip_type in ["Stratiform", "Convective", "Snow"]:
                if "mask" in result:
                    contribs.append(
                        calculate_precip_contribution(
                            result,
                            precip_type,
                            region=region,
                            absolute=True,
                            no_frozen=no_frozen,
                            no_snow_sfc=no_snow_sfc,
                            no_ocean=no_ocean,
                            no_orographic=no_orographic
                        )
                    )
                else:
                    # Without a 'mask' variable the contribution is recorded as
                    # zero. This must be a call: assigning to ``contribs.append``
                    # raises an AttributeError.
                    contribs.append(0.0)
                regions.append(region)
                precip_types.append(precip_type)
                periods.append(period)

    return pd.DataFrame({
        "Region": regions,
        "Contribution": contribs,
        "Precip type": precip_types,
        "Period": periods
    })


In [None]:
precip_contribs = calculate_precip_contribs(no_frozen=True, no_snow_sfc=True, no_orographic=True)
precip_contribs_all = calculate_precip_contribs(no_frozen=False, no_snow_sfc=False, no_orographic=False)

In [None]:
precip_contribs = precip_contribs[precip_contribs["Precip type"] != "Snow"]
precip_contribs_all = precip_contribs_all[precip_contribs_all["Precip type"] != "Snow"]

### Regional statistics

In [None]:
from matplotlib.gridspec import GridSpec
from gprof_nn.validation import get_colors
# Plot styling and palettes used by the figure functions in this notebook.
set_style(latex=True)

# Colours keyed by retrieval group; their values, in order, form the palette
# of the per-algorithm bar plots.
COLORS = get_colors()
bar_palette = [*COLORS.values()]

# Four grey shades: the five-step "Greys" palette without its first entry.
colors = sns.color_palette("Greys", 5)[1:]

def plot_regional_metrics(
    region_metrics,
    precip_contribs
):
    """
    Plot regional precipitation contributions and retrieval error metrics.

    The figure has one column per entry of ``REGIONS``. The top row shows the
    ``Contribution`` column of ``precip_contribs`` as bars grouped by
    ``Period`` and coloured by ``Precip type``. Below a legend row, one row
    per metric shows ``region_metrics`` as bars grouped by ``Time period``
    and coloured by ``Algorithm``. A second legend fills the bottom row.

    Args:
        region_metrics: Data frame with the columns ``Region``,
            ``Time period``, ``Algorithm`` and one column for each of
            ``Bias``, ``MAE``, ``MSE``, ``Correlation coeff.`` and ``SMAPE``.
        precip_contribs: Data frame with the columns ``Region``, ``Period``,
            ``Precip type`` and ``Contribution``.

    Return:
        The matplotlib figure holding all panels.
    """
    metrics = ["Bias", "MAE", "MSE", "Correlation coeff.", "SMAPE"]
    n_metrics = len(metrics)
    n_cols = len(REGIONS)

    f = plt.figure(figsize=(20, 20))
    # Rows: contributions, legend, one row per metric, legend.
    gs = GridSpec(
        n_metrics + 3,
        n_cols,
        height_ratios=[1.0, 0.2] + [1.0] * n_metrics + [1.5],
        hspace=0.2,
        wspace=0.05
    )

    axs = np.array([
        [f.add_subplot(gs[2 + i, j]) for j in range(n_cols)]
        for i in range(n_metrics)
    ])

    # Keys must match the entries of ``metrics`` for the limits to apply.
    y_lims = {
        "Bias": [-35, 35],
        "MAE": [0, 0.4],
        "MSE": [0, 4.0],
        "Correlation coeff.": [0, 1.0],
        "SMAPE": [0, 120],
    }

    # Raw strings: the LaTeX macros contain backslashes that are not valid
    # Python escape sequences.
    y_labels = {
        "Bias": r"Bias [$\si{\percent}$]",
        "MAE": r"MAE [$\si{\milli \meter \per \hour}$]",
        "MSE": r"MSE [$(\si{\milli \meter \per \hour})^2$]",
        "SMAPE": r"SMAPE$_{0.1}$ [$\si{\percent}$]",
    }

    #
    # Top row: precipitation contributions per region.
    #

    for j, region in enumerate(REGIONS):
        ax = f.add_subplot(gs[0, j])
        m = precip_contribs[precip_contribs["Region"] == region]
        g = sns.barplot(x="Period", y="Contribution", hue="Precip type", data=m,
                        ax=ax, saturation=0.9, palette=colors)
        g.legend_.set_visible(False)
        ax.set_title(region)

        ax.set_ylim(0, 0.25)
        # Only the left-most panel keeps its y-axis annotation.
        if j > 0:
            ax.set_ylabel("")
            ax.set_yticklabels([])
        else:
            ax.set_ylabel(r"Mean precip [$\si{\milli \meter \per \hour}$]")
        ax.set_xticklabels([])
        ax.set_xlabel("")

    # Shared legend for the contribution panels, taken from the last of them.
    lax = f.add_subplot(gs[1, :])
    lax.set_axis_off()
    lax.legend(*ax.get_legend_handles_labels(), ncol=8, loc="center")

    #
    # Metric rows: one panel per metric and region.
    #

    for i, metric in enumerate(metrics):
        for j, region in enumerate(REGIONS):
            ax = axs[i, j]
            m = region_metrics[region_metrics["Region"] == region]
            g = sns.barplot(x="Time period", y=metric, hue="Algorithm", data=m,
                            ax=ax, saturation=0.9, palette=bar_palette)
            g.legend_.set_visible(False)

            if j > 0:
                ax.set_yticklabels([])
                ax.set_ylabel(None)

            # Only the last metric row shows the time-period tick labels.
            if i < n_metrics - 1:
                ax.set_xlabel(None)
                ax.set_xticklabels([])
            else:
                for l in ax.get_xticklabels():
                    l.set_rotation(90)

            # Zero line as visual reference in the first (bias) row.
            if i == 0:
                ax.axhline(y=0, ls="--", c="k")

            if j == 0 and metric in y_labels:
                ax.set_ylabel(y_labels[metric])

            if metric in y_lims:
                ax.set_ylim(y_lims[metric])

    # Shared legend for the metric panels, taken from the last of them.
    lax = f.add_subplot(gs[-1, :])
    lax.set_axis_off()
    lax.legend(*ax.get_legend_handles_labels(), ncol=6, loc="lower center")
    return f

In [None]:
# Regional metrics figure for ``region_metrics`` / ``precip_contribs``,
# written to disk as PDF.
f = plot_regional_metrics(
    region_metrics=region_metrics,
    precip_contribs=precip_contribs,
);
f.savefig("../../plots/validation/region_metrics.pdf", bbox_inches="tight")

In [None]:
# Same figure for the ``*_all`` tables; shown inline only, not saved.
f = plot_regional_metrics(
    region_metrics=region_metrics_all,
    precip_contribs=precip_contribs_all,
);

In [None]:
groups = ["gprof_v5", "gprof_v7", "gprof_nn_1d", "gprof_nn_3d", "gprof_nn_hr", "combined"]

# For every group, stack the ``results_db_1_all`` and ``results_db_2_all``
# collocations along the "samples" dimension.
results = dict(
    (
        group,
        xr.concat(
            [results_db_1_all[group], results_db_2_all[group]],
            dim="samples",
        ),
    )
    for group in groups
)

In [None]:
# Drop the notebook-level references to the per-period collocation
# dictionaries (presumably to free memory once ``results`` has been built).
# NOTE(review): the "Error contributions" cell further down reads
# results_db_all, results_db_1_all and results_db_2_all again, so on a
# top-to-bottom run it fails with a NameError after these deletions. Confirm
# the intended cell order. results_db_3_all is not deleted here.
del results_db_all
del results_db_1_all
del results_db_2_all

In [None]:
from matplotlib.gridspec import GridSpec
from gprof_nn.validation import NAMES, REGIONS, calculate_seasonal_cycles
# Number of validation regions; plot_seasonal_cycles sizes its figure and
# panel grid from it.
n_regions = len(REGIONS)

# Retrieval groups compared against the reference.
groups = ["gprof_v5", "gprof_v7", "gprof_nn_1d", "gprof_nn_3d", "gprof_nn_hr", "combined"]

def plot_seasonal_cycles(
    no_frozen=True,
    no_snow_sfc=True,
    no_orographic=True
):
    """
    Plot seasonal cycles of reference and retrieved surface precipitation.

    One panel is drawn per entry of ``REGIONS`` on a grid with two rows and
    ``n_regions // 2`` columns. Each panel shows the reference cycle (total,
    convective and stratiform) as filled areas and the cycle of every
    retrieval group as a line. All curves of a panel are divided by the mean
    of that region's total reference cycle.

    Args:
        no_frozen: Forwarded to ``calculate_seasonal_cycles``.
        no_snow_sfc: Forwarded to ``calculate_seasonal_cycles``.
        no_orographic: Forwarded to ``calculate_seasonal_cycles``.

    Return:
        The matplotlib figure holding the panels and the two legends.
    """
    # Kept local so that later cells rebinding the global ``groups`` cannot
    # change which retrievals this function plots.
    groups = ["gprof_v5", "gprof_v7", "gprof_nn_1d", "gprof_nn_3d", "gprof_nn_hr", "combined"]
    n_cols = n_regions // 2

    f = plt.figure(figsize=(5 * n_regions // 2, 11))
    gs = GridSpec(4, n_cols, height_ratios=[1.0, 1.0, 0.3, 0.1], hspace=0.3)
    axs = np.array([
        [f.add_subplot(gs[i, j]) for j in range(n_cols)]
        for i in range(2)
    ])

    for i, region in enumerate(REGIONS):

        # ``no_ocean`` is switched off only for the "KWAJ" region.
        no_ocean = region != "KWAJ"

        # Selection arguments shared by all cycle calculations of this region.
        filters = dict(
            region=region,
            no_frozen=no_frozen,
            no_snow_sfc=no_snow_sfc,
            no_ocean=no_ocean,
            no_orographic=no_orographic,
        )

        x, precip_strat = calculate_seasonal_cycles(results, "reference", precip_type="Stratiform", **filters)
        x, precip_conv = calculate_seasonal_cycles(results, "reference", precip_type="Convective", **filters)
        x, mean_ref = calculate_seasonal_cycles(results, "reference", **filters)

        # Panels fill the grid row by row.
        row, col = divmod(i, n_cols)
        ax = axs[row, col]

        norm = mean_ref.mean()
        # Stratiform is stacked on top of convective; total is drawn behind.
        handles_ref = [
            ax.fill_between(x, 0, mean_ref / norm, facecolor=colors[0], label="MRMS (Total)"),
            ax.fill_between(x, 0, precip_conv / norm, facecolor=colors[2], label="MRMS (Convective)"),
            ax.fill_between(x, precip_conv / norm, (precip_conv + precip_strat) / norm, facecolor=colors[1], label="MRMS (Stratiform)"),
        ]

        #
        # Retrieved cycles, one line per retrieval group.
        #

        handles_ret = []
        for group in groups:
            x, mean = calculate_seasonal_cycles(results, group, **filters)
            handles_ret += ax.plot(x, mean / norm, c=COLORS[group], label=NAMES[group])

        #
        # Axes formatting, applied once per panel.
        #

        # Only the left-most column keeps its y-axis annotation.
        if col == 0:
            ax.set_ylabel(r"Surface precip. [ $\overline{P_\text{MRMS}}$ ]")
        else:
            ax.set_ylabel(None)
            ax.set_yticklabels([])
        ax.set_title(f"({chr(ord('a') + i)}) {region}", loc="left")

        ax.set_xticks(np.arange(1, 13))
        # Only the bottom row shows the month initials.
        if row == 0:
            ax.set_xlabel(None)
            ax.set_xticklabels([])
        else:
            ax.set_xticklabels(
                ["J", "F", "M", "A", "M", "J", "J", "A", "S", "O", "N", "D"]
            )
            ax.set_xlabel("Month")
        ax.set_xlim([1, 12])
        ax.set_ylim([0, 3])

    # Legends use the handles of the last region drawn.
    lax = f.add_subplot(gs[-2, :])
    lax.set_axis_off()
    lax.legend(handles=handles_ref, loc="lower center", ncol=6)

    lax = f.add_subplot(gs[-1, :])
    lax.set_axis_off()
    lax.legend(handles=handles_ret, loc="lower center", ncol=6)

    return f

In [None]:
# Seasonal cycles with all three exclusion flags enabled; written to disk.
f = plot_seasonal_cycles(
    no_frozen=True,
    no_snow_sfc=True,
    no_orographic=True,
);
f.savefig("../../plots/validation/seasonal_cycles.pdf", bbox_inches="tight")

In [None]:
# Variant with ``no_orographic`` switched off; shown inline only.
f = plot_seasonal_cycles(
    no_frozen=True,
    no_snow_sfc=True,
    no_orographic=False,
);

In [None]:
# Variant with ``no_frozen`` switched off; shown inline only.
f = plot_seasonal_cycles(
    no_frozen=False,
    no_snow_sfc=True,
    no_orographic=True,
);

In [None]:
# Variant with all three exclusion flags switched off; shown inline only.
f = plot_seasonal_cycles(
    no_frozen=False,
    no_snow_sfc=False,
    no_orographic=False,
);

In [None]:
from matplotlib.gridspec import GridSpec
from gprof_nn.validation import NAMES, REGIONS, calculate_diurnal_cycles
# Number of validation regions; plot_diurnal_cycles sizes its figure and
# panel grid from it.
n_regions = len(REGIONS)

def plot_diurnal_cycles(
    no_frozen=True,
    no_snow_sfc=True,
    no_orographic=True
):
    """
    Plot diurnal cycles of reference and retrieved surface precipitation.

    One panel is drawn per entry of ``REGIONS`` on a grid with two rows and
    ``n_regions // 2`` columns. Each panel shows the reference cycle (total,
    convective and stratiform) as filled areas and the cycle of every
    retrieval group as a line. All curves of a panel are divided by the mean
    of that region's total reference cycle.

    Args:
        no_frozen: Forwarded to ``calculate_diurnal_cycles``.
        no_snow_sfc: Forwarded to ``calculate_diurnal_cycles``.
        no_orographic: Forwarded to ``calculate_diurnal_cycles``.

    Return:
        The matplotlib figure holding the panels and the two legends.
    """
    n_cols = n_regions // 2

    f = plt.figure(figsize=(5 * n_regions // 2, 11))
    gs = GridSpec(4, n_cols, height_ratios=[1.0, 1.0, 0.3, 0.1], hspace=0.3)
    axs = np.array([
        [f.add_subplot(gs[i, j]) for j in range(n_cols)]
        for i in range(2)
    ])

    groups = ["gprof_v5", "gprof_v7", "gprof_nn_1d", "gprof_nn_3d", "gprof_nn_hr", "combined"]

    for i, region in enumerate(REGIONS):

        # ``no_ocean`` is switched off for every region here, whereas
        # plot_seasonal_cycles switches it off only for "KWAJ".
        no_ocean = False

        # Selection arguments shared by all cycle calculations of this region.
        filters = dict(
            region=region,
            no_ocean=no_ocean,
            no_frozen=no_frozen,
            no_snow_sfc=no_snow_sfc,
            no_orographic=no_orographic,
        )

        # Panels fill the grid row by row.
        row, col = divmod(i, n_cols)
        ax = axs[row, col]

        x, precip_strat = calculate_diurnal_cycles(results, "reference", precip_type="Stratiform", **filters)
        x, precip_conv = calculate_diurnal_cycles(results, "reference", precip_type="Convective", **filters)
        x, mean_ref = calculate_diurnal_cycles(results, "reference", **filters)
        norm = mean_ref.mean()

        # Stratiform is stacked on top of convective; total is drawn behind.
        handles_ref = [
            ax.fill_between(x, 0, mean_ref / norm, facecolor=colors[1], label="Ground radar (Total)"),
            ax.fill_between(x, 0, precip_conv / norm, facecolor=colors[2], label="Ground radar (Convective)"),
            ax.fill_between(x, precip_conv / norm, (precip_conv + precip_strat) / norm, facecolor=colors[0], label="Ground radar (Stratiform)"),
        ]

        #
        # Retrieved cycles, one line per retrieval group.
        #

        handles_ret = []
        for group in groups:
            x, mean = calculate_diurnal_cycles(results, group, **filters)
            handles_ret += ax.plot(x, mean / norm, c=COLORS[group], label=NAMES[group])

        #
        # Axes formatting, applied once per panel.
        #

        # Only the left-most column keeps its y-axis annotation.
        if col == 0:
            ax.set_ylabel(r"Surface precip. [ $\overline{P_\text{MRMS}}$ ]")
        else:
            ax.set_ylabel(None)
            ax.set_yticklabels([])
        ax.set_title(f"({chr(ord('a') + i)}) {region}", loc="left")

        # Only the bottom row is annotated along the x axis.
        if row == 0:
            ax.set_xlabel(None)
            ax.set_xticklabels([])
        else:
            # Raw string: the LaTeX macros are not valid Python escapes.
            ax.set_xlabel(r"Local time [$\si{\hour}$]")
        ax.set_xlim([0, 23.0])
        ax.set_ylim([0, 2])

    # Legends use the handles of the last region drawn.
    lax = f.add_subplot(gs[-2, :])
    lax.set_axis_off()
    lax.legend(handles=handles_ref, loc="lower center", ncol=6)

    lax = f.add_subplot(gs[-1, :])
    lax.set_axis_off()
    lax.legend(handles=handles_ret, loc="lower center", ncol=6)

    return f

In [None]:
# Diurnal cycles with all three exclusion flags enabled; written to disk.
f = plot_diurnal_cycles(
    no_frozen=True,
    no_snow_sfc=True,
    no_orographic=True,
);
f.savefig("../../plots/validation/diurnal_cycles.pdf", bbox_inches="tight")

In [None]:
# Variant with ``no_frozen`` switched off.
# NOTE(review): this saves to the same path as the previous cell, so on a
# top-to-bottom run the no_frozen=False figure overwrites the no_frozen=True
# one. Confirm which variant diurnal_cycles.pdf is meant to contain.
f = plot_diurnal_cycles(no_frozen=False, no_snow_sfc=True, no_orographic=True);
f.savefig("../../plots/validation/diurnal_cycles.pdf", bbox_inches="tight")

In [None]:
# Variant with all three exclusion flags switched off; shown inline only.
f = plot_diurnal_cycles(
    no_frozen=False,
    no_snow_sfc=False,
    no_orographic=False,
);

## Error contributions

In [None]:
# Retrieved and reference surface precipitation of the "gprof_nn_3d" group,
# pooled over the four ``results_db*_all`` dictionaries.
sp = np.concatenate(
    [
        collocations["gprof_nn_3d"].surface_precip.data
        for collocations in (results_db_all, results_db_1_all, results_db_2_all, results_db_3_all)
    ],
    0,
)
sp_ref = np.concatenate(
    [
        collocations["gprof_nn_3d"].surface_precip_ref.data
        for collocations in (results_db_all, results_db_1_all, results_db_2_all, results_db_3_all)
    ],
    0,
)

# Keep only samples where both values are non-negative.
valid = np.logical_and(sp >= 0.0, sp_ref >= 0.0)
sp, sp_ref = sp[valid], sp_ref[valid]

# Per-sample squared and absolute errors.
mae = np.abs(sp - sp_ref)
mse = (sp - sp_ref) ** 2

# The symmetric relative error is evaluated only where the reference exceeds 0.1.
rel_valid = sp_ref > 0.1
smape = mae[rel_valid] / (0.5 * (sp[rel_valid] + sp_ref[rel_valid]))

In [None]:
# 30 logarithmically spaced reference-precipitation bins between 0.05 and 150.
bins = np.logspace(np.log10(5e-2), np.log10(150), 31)

# Error summed per reference-precipitation bin (histogram weighted by the
# per-sample error).
y_mse, y_mae = (
    np.histogram(sp_ref, bins=bins, weights=errors)[0]
    for errors in (mse, mae)
)
y_smape = np.histogram(sp_ref[rel_valid], bins=bins, weights=smape)[0]

In [None]:
# Relative contribution of each reference-precipitation bin to the total MSE,
# MAE and SMAPE.
fig, ax = plt.subplots(1, 1)
# Arithmetic bin centres; together with width=np.diff(bins) every bar spans
# exactly its bin.
x = 0.5 * (bins[1:] + bins[:-1])

# Each histogram is normalised by its own sum and shown in percent.
ax.bar(x, 100 * y_mse / np.sum(y_mse), width=np.diff(bins), facecolor="none", edgecolor="C0", label="MSE", lw=1)
ax.bar(x, 100 * y_mae / np.sum(y_mae), width=np.diff(bins), facecolor="none", edgecolor="C1", label="MAE", lw=1)
ax.bar(x, 100 * y_smape / np.sum(y_smape), width=np.diff(bins), facecolor="none", edgecolor="C2", label="SMAPE", lw=1)

ax.set_xscale("log")
ax.set_xlim([5e-2, 150])
ax.legend(loc="center left", bbox_to_anchor=(1.0, 0.5))
ax.set_title("Contribution to total error")
# Raw string: "\%" is not a valid Python escape sequence.
ax.set_ylabel(r"Relative contribution [\%]")
ax.set_xlabel(r"$P_\text{Reference}$ [$\si{\milli \meter \per \hour}$]")
fig.savefig("../../plots/error_contributions.pdf", bbox_inches="tight")

## Precipitation contributions

Excluded pixels:
- Snow-covered, non-mountain
- MRMS snow
- Mountain

In [None]:
groups = ["reference", "gprof_v5", "gprof_v7", "gprof_nn_1d", "gprof_nn_3d", "gprof_nn_hr", "combined"]
biases = {}
totals = {}
for group in groups:

    # The "reference" pseudo-group compares the reference field stored with
    # the gprof_v5 collocations against itself.
    is_reference = group == "reference"
    collocations = results["gprof_v5"] if is_reference else results[group]
    sp = collocations["surface_precip_ref" if is_reference else "surface_precip"].data
    sp_ref = collocations["surface_precip_ref"].data
    sfc = collocations["surface_type"].data
    mask = collocations["mask"].data

    bias = sp - sp_ref

    valid = (mask <= 100) & (sp >= 0.0) & (sp_ref >= 0.0) & (sfc > 1)

    # Building blocks of the category masks. The category names follow the
    # dictionary keys below; the meaning of the numeric codes is presumably
    # defined by the collocation files. TODO confirm.
    snow_covered = (sfc >= 8) & (sfc <= 11)
    mask_3_or_4 = np.isclose(mask, 3) | np.isclose(mask, 4)

    slct = valid & ~snow_covered & (sfc < 17) & ~mask_3_or_4
    mtn = valid & (sfc >= 17)
    snow = valid & mask_3_or_4
    snow_sfc = valid & snow_covered

    bias_g = biases.setdefault(group, {})
    total_g = totals.setdefault(group, {})

    # Relative bias uses the reference mean of the same category; totals are
    # expressed relative to the reference total over all valid samples.
    # The insertion order of the categories is relied upon when labelling
    # the bars of the figure below.
    reference_total = sp_ref[valid].sum()
    for category, selection in (
        ("all", valid),
        ("selected", slct),
        ("mountain", mtn),
        ("mrms_snow", snow),
        ("snow_sfc", snow_sfc),
    ):
        bias_g[category] = bias[selection].mean() / sp_ref[selection].mean() * 100.0
        total_g[category] = sp[selection].sum() / reference_total * 100.0
    


In [None]:
# Display name for the pseudo-group that holds the reference itself.
NAMES["reference"] = "MRMS"

# One row per (group, category) pair, in the insertion order of ``biases``.
pairs = [
    (group_key, category)
    for group_key in biases
    for category in biases[group_key]
]

data = pd.DataFrame({
    "Total precipitation": [totals[group_key][category] for group_key, category in pairs],
    "Bias": [biases[group_key][category] for group_key, category in pairs],
    "Precipitation type": [category for _, category in pairs],
    "Product": [NAMES[group_key] for group_key, _ in pairs],
})

In [None]:
# Two bar charts (total precipitation and relative bias per category and
# product) plus a separate legend panel.
fig = plt.figure(figsize=(16, 4))
gs = GridSpec(1, 3, width_ratios=[1.0, 1.0, 0.5], wspace=0.4)
# NOTE(review): this rebinds the global ``colors`` that plot_seasonal_cycles
# and plot_diurnal_cycles read as the grey palette; re-running those
# functions after this cell changes their fill colours.
colors = ["grey"] + list(COLORS.values())
# Tick labels in the insertion order of the category keys:
# all, selected, mountain, mrms_snow, snow_sfc.
labels = ["All collocations", "Snow-free surface, \n no mountain, \n liquid precipitation", "Mountain", "Snow", "Snow-covered surface", ]

ax = fig.add_subplot(gs[0, 0])
# The target axes is passed explicitly instead of relying on the current axes.
g = sns.barplot(x="Precipitation type", y="Total precipitation", hue="Product", data=data, palette=colors, saturation=1.0, ax=ax)
g.legend().set_visible(False)
ax.set_ylabel(r"Total precipitation [\%]")
ax.set_xticklabels(labels, rotation=90)
ax.set_title("(a) Total precipitation", loc="left")

ax = fig.add_subplot(gs[0, 1])
g = sns.barplot(x="Precipitation type", y="Bias", hue="Product", data=data, palette=colors, saturation=1.0, ax=ax)
g.legend().set_visible(False)
# Raw string: "\%" is not a valid Python escape sequence.
ax.set_ylabel(r"Bias [\%]")
ax.set_xticklabels(labels, rotation=90)
ax.set_title("(b) Relative biases", loc="left")

# Legend panel reusing the handles of the bias plot.
lax = fig.add_subplot(gs[0, 2])
lax.legend(*ax.get_legend_handles_labels(), ncol=1, loc="center")
lax.set_axis_off()

fig.savefig("../../plots/precip_contributions.pdf", bbox_inches="tight")

## Explained error

In [None]:
from gprof_nn.validation import calculate_explained_error

groups = ["gprof_v7", "gprof_nn_1d", "gprof_nn_3d", "gprof_nn_hr"]

# Options shared by every explained-error calculation below.
explained_error_options = dict(fpa=True, no_orographic=True)

# One statistics table per collocation set; "combined" is evaluated in
# addition to the groups listed above.
stats_db = calculate_explained_error(
    results_db_ref, groups + ["combined"], **explained_error_options
)
stats_db_0 = calculate_explained_error(
    results_db, groups + ["combined"], **explained_error_options
)
stats_db_1 = calculate_explained_error(
    results_db_1, groups + ["combined"], **explained_error_options
)
stats_db_2 = calculate_explained_error(
    results_db_2, groups + ["combined"], **explained_error_options
)
stats_db_3 = calculate_explained_error(
    results_db_3, groups + ["combined"], **explained_error_options
)

In [None]:
# Show the "gprof_v5" and "gprof_nn_1d" entries of ``results_db`` for
# inspection (rendered as the cell output).
results_db["gprof_v5"], results_db["gprof_nn_1d"]

In [None]:
def _with_period(stats, period):
    """
    Add a constant ``Time period`` column to ``stats`` (in place) and return
    a copy whose index has become the ``Algorithm`` column.
    """
    stats["Time period"] = period
    return stats.reset_index().rename(columns={"index": "Algorithm"})


stats_db = _with_period(stats_db, "Database")
stats_db_0 = _with_period(stats_db_0, "2019")
stats_db_1 = _with_period(stats_db_1, "2020")
stats_db_2 = _with_period(stats_db_2, "2021")
stats_db_3 = _with_period(stats_db_3, "2022")

In [None]:
import pandas as pd
# Stack the per-period tables, then pivot to one row per algorithm and one
# column block per time period, and print the result as LaTeX.
stats = pd.concat([stats_db, stats_db_0, stats_db_1, stats_db_2, stats_db_3])
stats_by_period = stats.pivot(index="Algorithm", columns="Time period")
print(stats_by_period.to_latex())