# Wrap metrics/compute binary metrics

Binary metrics like the CSI (Critical Success Index) expect their inputs to be in binary format already. To compute them from real-valued forecasts, the forecasts have to be thresholded first, so we wrap the metrics with input transforms that do this.
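
As a reminder of what the metric measures, here is a minimal pure-Python sketch of the standard CSI definition, hits / (hits + misses + false alarms). The function name `critical_success_index` and the toy data are illustrative assumptions, and this is not the weatherbenchX implementation, which may treat edge cases such as a zero denominator differently.

```python
def critical_success_index(forecast, observed):
    """CSI = hits / (hits + misses + false alarms) for two binary sequences."""
    pairs = list(zip(forecast, observed))
    hits = sum(1 for f, o in pairs if f and o)
    false_alarms = sum(1 for f, o in pairs if f and not o)
    misses = sum(1 for f, o in pairs if not f and o)
    denominator = hits + false_alarms + misses
    # Assumption: return NaN when the event is neither forecast nor observed.
    return hits / denominator if denominator else float('nan')


# Hypothetical binary forecast and observation at five points.
forecast = [1, 1, 0, 0, 1]
observed = [1, 0, 1, 0, 1]
csi = critical_success_index(forecast, observed)  # 2 hits, 1 false alarm, 1 miss
```

Correct negatives (neither forecast nor observed) do not enter the score, which is why the CSI is often used for rare events such as heavy precipitation.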

Let's take an example of an ensemble forecast that we want to compute the CSI for. Doing this requires several transforms on the prediction and target data.

1. The continuous, real-valued forecasts need to be converted to binary forecasts based on a threshold value (here for `total_precipitation_6hr`).
2. Then the binary ensemble members have to be averaged to produce a probability forecast for each of the thresholds.
3. Finally, the probability forecasts have to be thresholded by probability values to produce a binary output that we can compute the CSI for.

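
The three steps above can be sketched in plain NumPy on a toy ensemble. This only illustrates the idea and is not how the wrappers are implemented: the array values are made up, and the use of a strict `>` comparison is an assumption.

```python
import numpy as np

# Hypothetical toy data: 4 ensemble members at 3 grid points, precipitation in m.
ensemble = np.array([
    [0.0000, 0.0020, 0.0060],
    [0.0005, 0.0030, 0.0070],
    [0.0008, 0.0000, 0.0080],
    [0.0000, 0.0004, 0.0020],
])
precip_threshold = 1 / 1000  # 1 mm expressed in m
prob_threshold = 0.25

# Step 1: threshold every member to get binary exceedance forecasts.
binary_members = ensemble > precip_threshold

# Step 2: average over the member axis to get an exceedance probability.
probability = binary_members.mean(axis=0)

# Step 3: threshold the probability to get one binary forecast per grid point.
binary_forecast = probability > prob_threshold
```

In this toy case no member exceeds 1 mm at the first point, half of them do at the second, and all of them do at the third, so only the last two points end up as "yes" forecasts.
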
Let's load the data and apply all the wrappers around the CSI metric.

In [1]:
# IMPORTANT: If you are running this on Colab, uncomment the two lines below to access the cloud datasets.
# from google.colab import auth
# auth.authenticate_user()

In [2]:
import numpy as np
from weatherbenchX import aggregation
from weatherbenchX.data_loaders import xarray_loaders
from weatherbenchX.metrics import categorical
from weatherbenchX.metrics import wrappers

In [3]:
prediction_path = 'gs://weatherbench2/datasets/ifs_ens/2018-2022-64x32_equiangular_conservative.zarr'
target_path = 'gs://weatherbench2/datasets/era5/1959-2022-6h-64x32_equiangular_conservative.zarr'

In [4]:
variables = ['total_precipitation_6hr']
target_data_loader = xarray_loaders.TargetsFromXarray(
    path=target_path,
    variables=variables,
)
prediction_data_loader = xarray_loaders.PredictionsFromXarray(
    path=prediction_path,
    variables=variables,
)

In [5]:
init_times = np.array(['2020-01-01T00'], dtype='datetime64[ns]')
lead_times = np.array([6], dtype='timedelta64[h]').astype('timedelta64[ns]')   # To silence xr warnings.

In [6]:
target_chunk = target_data_loader.load_chunk(init_times, lead_times)
prediction_chunk = prediction_data_loader.load_chunk(init_times, lead_times)

In [7]:
target_chunk

In [8]:
prediction_chunk

Note that the transforms are applied in the order in which they appear in the list, so in this case the first `ContinuousToBinary` (the precipitation thresholds) is applied first.

In [9]:
wrapped_csi = wrappers.WrappedMetric(
    metric=categorical.CSI(),
    transforms=[
        wrappers.ContinuousToBinary(
            which='both',
            threshold_value=[1/1000, 5/1000],   # Raw values are in m
            threshold_dim='threshold_precipitation'
        ),
        wrappers.EnsembleMean(
            which='predictions', ensemble_dim='number'
        ),
        wrappers.ContinuousToBinary(
            which='predictions',
            threshold_value=[0.25, 0.75],
            threshold_dim='threshold_probability'
        ),
    ],
)
metrics = {'csi': wrapped_csi}
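
To see why the order of the transforms matters, here is a small NumPy sketch with a made-up four-member ensemble at a single grid point. Thresholding before averaging, as in the list above, gives an exceedance probability, while averaging first and thresholding afterwards only says whether the ensemble mean exceeds the threshold. The values are hypothetical and the strict `>` comparison is an assumption.

```python
import numpy as np

# Hypothetical ensemble at one grid point: one wet member out of four, in m.
members = np.array([0.0, 0.0, 0.0, 0.008])
threshold = 1 / 1000  # 1 mm expressed in m

# Threshold first, then average: fraction of members exceeding the threshold.
prob_of_exceedance = (members > threshold).mean()

# Average first, then threshold: does the ensemble mean exceed the threshold?
mean_exceeds = members.mean() > threshold
```

Here only a quarter of the members exceed 1 mm, yet the ensemble mean of 2 mm does, so the two orderings answer different questions.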

In [10]:
aggregator = aggregation.Aggregator(
    reduce_dims=['init_time', 'latitude', 'longitude'],
)
aggregation.compute_metric_values_for_single_chunk(
    metrics,
    aggregator,
    prediction_chunk,
    target_chunk
)

As we can see, the final result has two additional dimensions, `threshold_precipitation` and `threshold_probability`, one for each list of threshold values passed to the `ContinuousToBinary` transforms.
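
A rough NumPy sketch of where these two dimensions come from: each list of thresholds adds one axis through broadcasting, and those axes survive the reduction over the spatial dimensions. The random toy data and the positional axis layout are assumptions made for illustration. The actual wrappers operate on labelled xarray objects rather than positional axes.

```python
import numpy as np

rng = np.random.default_rng(0)
# Hypothetical toy ensemble: 10 members on an 8 x 4 grid, precipitation in m.
ensemble = rng.gamma(shape=0.5, scale=0.004, size=(10, 8, 4))
precip_thresholds = np.array([1 / 1000, 5 / 1000])
prob_thresholds = np.array([0.25, 0.75])

# One binary field per precipitation threshold: (threshold_precip, member, y, x).
binary = ensemble[None] > precip_thresholds[:, None, None, None]

# Average over members: (threshold_precip, y, x).
probability = binary.mean(axis=1)

# One binary field per probability threshold:
# (threshold_precip, threshold_prob, y, x).
final = probability[:, None] > prob_thresholds[None, :, None, None]

# Reducing over the spatial axes leaves only the two threshold axes.
fraction_yes = final.mean(axis=(-2, -1))
```

The reduced array has shape `(2, 2)`, one entry per combination of precipitation and probability threshold, mirroring the two extra dimensions in the result above.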