# Impact of ancillary variables

This notebook assesses the impact of different ancillary retrieval inputs.

## Data

The evaluation uses the same test dataset that was used to assess GPROF-NN against the original GPROF, i.e., collocations of GMI and GPM-CMB observations from days 1, 2, and 3 of every month of the water year 2019.

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr
# Apply the plotting style shipped with gprof_nn. `latex=False` presumably
# turns off LaTeX text rendering -- TODO confirm against gprof_nn.plotting.
# NOTE(review): the same import and call are repeated in the next cell.
from gprof_nn.plotting import set_style
set_style(latex=False)

In [None]:
# Configure plotting
import seaborn as sns
from gprof_nn.plotting import set_style
# Reset seaborn's changes to the matplotlib settings (presumably restoring the
# original rc parameters -- see the seaborn docs), then re-apply the gprof_nn
# style so that it is the one in effect for the figures below.
sns.reset_orig()
set_style(latex=False)

In [None]:
from pathlib import Path
from gprof_nn.data.training_data import decompress_and_load

def load_results(path, variables):
    """
    Load retrieval results from directory.

    Every ``*.nc`` and ``*nc.gz`` file located directly in ``path`` is read
    with ``decompress_and_load``, reduced to the requested variables and
    concatenated along the ``samples`` dimension.

    Args:
        path: The directory from which to load the results.
        variables: The names of the variables to load from the results
            files. For every requested variable ``v`` the variable
            ``v + "_true"`` is loaded as well whenever a file provides it.
            The passed sequence is never modified.

    Return:
        An xarray.Dataset containing all retrieval results found in the
        given directory. If the results have a ``pixels`` dimension, the
        ``samples``, ``scans`` and ``pixels`` dimensions are flattened into
        a single ``samples`` dimension.

    Raises:
        ValueError: If the directory contains no results files.
    """
    directory = Path(path)
    # Sort the matches so that the order of the samples is reproducible.
    files = sorted(directory.glob("*.nc")) + sorted(directory.glob("*nc.gz"))
    if not files:
        raise ValueError(
            f"No results files (*.nc, *nc.gz) found in '{path}'."
        )

    # Copy the requested names once: the caller's list must not be touched,
    # because the same list is typically reused for several directories.
    requested = list(variables)

    datasets = []
    for f in files:
        data = decompress_and_load(f)
        try:
            # Per-file selection: requested names plus available references.
            names = list(requested)
            for v in requested:
                name_true = v + "_true"
                if name_true in data.variables and name_true not in names:
                    names.append(name_true)
            datasets.append(data[names])
        finally:
            # Release the file even if the selection above fails.
            data.close()

    results = xr.concat(datasets, dim="samples")

    if "pixels" in results.dims:
        # Collapse the spatial dimensions so that every pixel is one sample.
        results_flat = results.stack(samples_new=("samples", "scans", "pixels"))
        results = results_flat.rename_dims({"samples_new": "samples"})
    return results
    

In [None]:
# Variables extracted from every results file.
variables = ["surface_precip", "surface_type"]

# Results using new preprocessor and all ancillary variables
results_all = load_results(
    path="/gdata1/simon/gprof_nn/results/gmi_new/gprof_nn_1d/",
    variables=variables,
)
# Results from the "gmi_dropped_15" directory; later cells label this
# configuration "1D, No t2m".
results_no_t2m = load_results(
    path="/gdata1/simon/gprof_nn/results/gmi_dropped_15/",
    variables=variables,
)
# Results from the "gprof_nn_3d" directory; later cells label this
# configuration "3D, All".
results_3d = load_results(
    path="/gdata1/simon/gprof_nn/results/gmi_new/gprof_nn_3d/",
    variables=variables,
)

In [None]:
import pandas as pd

def calculate_error_statistics(data, configuration):
    """
    Calculate surface precipitation error statistics.

    Samples are excluded from the statistics if the reference value is
    below -999 (treated as a missing-value marker) or if either the
    reference or the retrieved value is not finite.

    Args:
        data: An xarray.Dataset containing the retrieval results in the
            variables ``surface_precip`` (retrieved) and
            ``surface_precip_true`` (reference).
        configuration: A string describing the configuration for which the
            stats are calculated, which will be included in the resulting
            dataframe.

    Return:
        A pandas Dataframe with a single row and the columns ``Bias``
        (mean of retrieved minus reference, i.e. positive values indicate
        overestimation), ``MSE`` (mean squared error), ``Correlation``
        (Pearson correlation coefficient) and ``Configuration``.
    """
    true = data.surface_precip_true.data
    retrieved = data.surface_precip.data

    # A single NaN among the retrieved values would otherwise turn all
    # three statistics into NaN, hence the finiteness checks.
    valid = (true >= -999) & np.isfinite(true) & np.isfinite(retrieved)
    true = true[valid]
    retrieved = retrieved[valid]

    # Standard sign convention: error = estimate - reference.
    error = retrieved - true
    bias = error.mean()
    mse = (error ** 2).mean()
    corr = np.corrcoef(true, retrieved)[0, 1]

    return pd.DataFrame({
        "Bias": [bias],
        "MSE": [mse],
        "Correlation": [corr],
        "Configuration": [configuration]
    })
    

## All surfaces

We start by assessing the accuracy for all samples in the test data.

In [None]:
# Error statistics over all test samples, one single-row dataframe per
# retrieval configuration.
stats_all = calculate_error_statistics(results_all, "1D, All")
# NOTE(review): despite its name, `stats_no_tcwv` holds the statistics of the
# "1D, No t2m" configuration (computed from `results_no_t2m`).
stats_no_tcwv = calculate_error_statistics(results_no_t2m, "1D, No t2m")
stats_3d = calculate_error_statistics(results_3d, "3D, All")

In [None]:
# Combine the per-configuration statistics into a single table. Each input
# frame has exactly one row with index label 0, so a fresh 0..n-1 index is
# generated to keep the row labels unique.
stats = pd.concat(
    [stats_all, stats_no_tcwv, stats_3d],
    ignore_index=True,
)

In [None]:
# Display the combined error statistics as a table.
stats

In [None]:
# One bar chart per error metric, comparing the retrieval configurations
# over all surface types.
f = plt.figure(figsize=(18, 7))
for panel, metric in enumerate(["Bias", "MSE", "Correlation"], start=1):
    ax = f.add_subplot(1, 3, panel)
    # The data frame is passed by keyword because the position of the data
    # argument differs between seaborn versions.
    sns.barplot(data=stats, x="Configuration", y=metric, ax=ax)
    for label in ax.xaxis.get_ticklabels():
        label.set_rotation(45)

# This figure covers every sample of the test data, not only ocean pixels.
f.suptitle("All surfaces", y=1.1)
plt.tight_layout()
f.savefig("metrics_all.png", dpi=200, bbox_inches="tight")

## Ocean surfaces

Evaluation restricted to pixels with ``surface_type = 1``.

In [None]:
# Restrict each set of results to the samples whose surface type is 1 (ocean).
results_all_ocean = results_all.isel(
    samples=results_all.surface_type.data == 1
)
results_no_t2m_ocean = results_no_t2m.isel(
    samples=results_no_t2m.surface_type.data == 1
)
results_3d_ocean = results_3d.isel(
    samples=results_3d.surface_type.data == 1
)

In [None]:
# Error statistics restricted to ocean samples. These assignments rebind the
# names that held the all-surface statistics earlier in the notebook.
stats_all = calculate_error_statistics(results_all_ocean, "1D, All")
# NOTE(review): despite its name, `stats_no_tcwv` holds the statistics of the
# "1D, No t2m" configuration (computed from `results_no_t2m_ocean`).
stats_no_tcwv = calculate_error_statistics(results_no_t2m_ocean, "1D, No t2m")
stats_3d = calculate_error_statistics(results_3d_ocean, "3D, All")

In [None]:
# Combine the per-configuration ocean statistics into a single table. Each
# input frame has exactly one row with index label 0, so a fresh 0..n-1 index
# is generated to keep the row labels unique.
stats = pd.concat(
    [stats_all, stats_no_tcwv, stats_3d],
    ignore_index=True,
)

In [None]:
# One bar chart per error metric, comparing the retrieval configurations
# over ocean pixels only.
f = plt.figure(figsize=(18, 7))
for panel, metric in enumerate(["Bias", "MSE", "Correlation"], start=1):
    ax = f.add_subplot(1, 3, panel)
    # The data frame is passed by keyword because the position of the data
    # argument differs between seaborn versions.
    sns.barplot(data=stats, x="Configuration", y=metric, ax=ax)
    for label in ax.xaxis.get_ticklabels():
        label.set_rotation(45)

f.suptitle("Ocean (surface type 1)", y=1.1)
plt.tight_layout()
# Written to its own file so that the all-surfaces figure ("metrics_all.png")
# is not overwritten.
f.savefig("metrics_ocean.png", dpi=200, bbox_inches="tight")