# Implement a new metric

```python
import numpy as np
import xarray as xr
from weatherbenchX.metrics import base
from weatherbenchX.metrics import deterministic
```

Metrics in WeatherBench-X are defined by a set of statistics and instructions on how to compute the final metric value from the averaged statistics.
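One motivation for this split is that averaged statistics from separate chunks of data can be combined by simple (weighted) averaging, while final metric values generally cannot. The minimal NumPy sketch below, with made-up errors and two equally sized batches, illustrates this for the RMSE. It does not use the WeatherBench-X API.

```python
import numpy as np

# Hypothetical errors (prediction - target), split into two equally sized batches.
errors_batch_1 = np.array([0.0, 0.0, 0.0, 4.0])
errors_batch_2 = np.array([1.0, 1.0, 1.0, 1.0])

# Average the statistic (squared error) per batch.
mean_se_1 = np.mean(errors_batch_1 ** 2)  # 16 / 4 = 4.0
mean_se_2 = np.mean(errors_batch_2 ** 2)  # 4 / 4 = 1.0

# Combine the averaged statistics, then apply the final step (square root).
rmse_from_statistics = np.sqrt((mean_se_1 + mean_se_2) / 2)

# Reference: RMSE computed over all errors at once.
all_errors = np.concatenate([errors_batch_1, errors_batch_2])
rmse_direct = np.sqrt(np.mean(all_errors ** 2))

# Averaging per-batch RMSE values instead gives a different (wrong) result.
rmse_of_batches = (np.sqrt(mean_se_1) + np.sqrt(mean_se_2)) / 2  # (2 + 1) / 2

print(rmse_from_statistics, rmse_direct, rmse_of_batches)
```

Both `rmse_from_statistics` and `rmse_direct` equal `sqrt(2.5)`, whereas averaging the per-batch RMSEs gives 1.5.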

Statistics are computed element by element from the predictions and targets, before any averaging. Further, statistics are divided into single-variable statistics (computed separately for each variable; the most common use case) and multivariate statistics (computed as a function of several variables).
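The difference between the two kinds of statistics can be sketched with plain NumPy, using a dict of arrays as a stand-in for an xarray `Dataset`. The variable names and the wind-vector statistic are illustrative assumptions, not taken from WeatherBench-X. Note that both statistics keep the shape of their inputs; no averaging happens yet.

```python
import numpy as np

# Toy data: variable name -> array, standing in for an xarray Dataset.
predictions = {
    'u_component_of_wind': np.array([1.0, 2.0]),
    'v_component_of_wind': np.array([0.0, 2.0]),
}
targets = {
    'u_component_of_wind': np.array([1.0, 0.0]),
    'v_component_of_wind': np.array([3.0, 2.0]),
}

# Single-variable statistic: the same function applied to each variable on its own.
squared_error = {
    name: (predictions[name] - targets[name]) ** 2 for name in predictions
}

# Multivariate statistic: one output computed from several variables at once,
# here the squared error of the (u, v) wind vector.
wind_vector_squared_error = (
    (predictions['u_component_of_wind'] - targets['u_component_of_wind']) ** 2
    + (predictions['v_component_of_wind'] - targets['v_component_of_wind']) ** 2
)

print(squared_error)
print(wind_vector_squared_error)
```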

As a simple example, let's take the RMSE. Here, the statistic is the squared error, which is a per-variable computation.

```python
import xarray as xr

from weatherbenchX.metrics import base


class SquaredError(base.PerVariableStatistic):
  """Squared error between predictions and targets."""

  def compute_per_variable(
      self,
      predictions: xr.DataArray,
      targets: xr.DataArray,
  ) -> xr.DataArray:
    # Element-wise; no averaging happens at this stage.
    return (predictions - targets) ** 2
```

The RMSE metric specifies the `SquaredError` statistic and takes the square root of its aggregated (averaged) values.

```python
from typing import Hashable, Mapping

import numpy as np
import xarray as xr

from weatherbenchX.metrics import base


class RMSE(base.PerVariableMetric):
  """Root mean squared error."""

  @property
  def statistics(self) -> Mapping[Hashable, base.Statistic]:
    # SquaredError is the statistic class defined above.
    return {'SquaredError': SquaredError()}

  def _values_from_mean_statistics_per_variable(
      self,
      statistic_values: Mapping[Hashable, xr.DataArray],
  ) -> xr.DataArray:
    """Computes metrics from aggregated statistics."""
    return np.sqrt(statistic_values['SquaredError'])
```
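The order of operations matters: the square root is applied to the averaged squared error, not to each element before averaging. A small NumPy sketch with hypothetical errors shows that taking the root first would instead produce the mean absolute error.

```python
import numpy as np

errors = np.array([1.0, -1.0, 3.0, -3.0])  # hypothetical prediction - target
squared_error = errors ** 2                 # per-element statistic: [1, 1, 9, 9]

# RMSE: average the statistic first, then take the square root.
rmse = np.sqrt(np.mean(squared_error))      # sqrt(20 / 4) = sqrt(5)

# Root first, then average: this is mean(|error|), i.e. the MAE.
mean_of_roots = np.mean(np.sqrt(squared_error))  # 2.0

print(rmse, mean_of_roots)
```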

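```python
# Toy predictions for a single variable on a 2 x 32 x 64 grid.
predictions = xr.Dataset({
    '2m_temperature': xr.DataArray(
        np.ones((2, 32, 64)), dims=['init_time', 'latitude', 'longitude'])
})
# Identical targets, so every error below is zero.
targets = predictions.copy()
predictions
```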

```python
rmse = deterministic.RMSE()
```

```python
# Compute every statistic the metric needs (element-wise, no averaging yet).
statistic_values = {name: statistic.compute(predictions, targets) for name, statistic in rmse.statistics.items()}
statistic_values
```

```
{'SquaredError': {'2m_temperature': <xarray.DataArray '2m_temperature' (init_time: 2, latitude: 32, longitude: 64)> Size: 33kB
  array([[[0., 0., 0., ..., 0., 0., 0.],
          [0., 0., 0., ..., 0., 0., 0.],
          [0., 0., 0., ..., 0., 0., 0.],
          ...,
          [0., 0., 0., ..., 0., 0., 0.],
          [0., 0., 0., ..., 0., 0., 0.],
          [0., 0., 0., ..., 0., 0., 0.]],
  
         [[0., 0., 0., ..., 0., 0., 0.],
          [0., 0., 0., ..., 0., 0., 0.],
          [0., 0., 0., ..., 0., 0., 0.],
          ...,
          [0., 0., 0., ..., 0., 0., 0.],
          [0., 0., 0., ..., 0., 0., 0.],
          [0., 0., 0., ..., 0., 0., 0.]]])
  Dimensions without coordinates: init_time, latitude, longitude}}
```

Next, take the mean. Here we do this explicitly for a single metric; typically, this would be done in `compute_unique_statistics_for_all_metrics`.

```python
# Average each variable's statistic over all dimensions.
statistic_values['SquaredError'] = {k: v.mean() for k, v in statistic_values['SquaredError'].items()}
statistic_values
```

```
{'SquaredError': {'2m_temperature': <xarray.DataArray '2m_temperature' ()> Size: 8B
  array(0.)}}
```
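The mean above reduces over all dimensions. Averaging over only some of them (with xarray, for example `v.mean(dim='init_time')`) keeps the remaining dimensions, so the square root then yields one RMSE value per grid point. A NumPy sketch of this idea, using random stand-in values for the statistic and assuming axis 0 plays the role of `init_time`:

```python
import numpy as np

rng = np.random.default_rng(0)
# Stand-in squared-error statistic with dims (init_time, latitude, longitude).
squared_error = rng.random((2, 32, 64))

# Mean over all dimensions: a single scalar, as in the cell above.
mean_all = squared_error.mean()

# Mean over init_time (axis 0) only: one averaged value per grid point.
mean_over_time = squared_error.mean(axis=0)
rmse_map = np.sqrt(mean_over_time)  # shape (32, 64)

print(rmse_map.shape)
```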

Now we can compute the metric (in this case take the square root) from the averaged statistic.

```python
rmse.values_from_mean_statistics(statistic_values)
```

```
{'2m_temperature': <xarray.DataArray '2m_temperature' ()> Size: 8B
 array(0.)}
```
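Because the targets are an exact copy of the predictions, every value in this walkthrough is zero. The same three steps (statistic, mean, square root) can be checked on a non-trivial case with a plain NumPy sketch; the constant offset of 2.0 is an arbitrary choice for illustration.

```python
import numpy as np

# Hypothetical targets and predictions with a constant +2.0 offset.
targets = np.zeros((2, 32, 64))
predictions = targets + 2.0

squared_error = (predictions - targets) ** 2  # statistic: 4.0 everywhere
mean_squared_error = squared_error.mean()     # averaged statistic: 4.0
rmse_value = np.sqrt(mean_squared_error)      # metric: 2.0

print(rmse_value)
```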

Note: Some metrics can have more than one statistic. See, for example, the ensemble CRPS implementation.
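As an illustration of a metric built from two statistics, the sketch below uses one common ensemble estimator of the CRPS, `CRPS = E|X - y| - 0.5 * E|X - X'|`. Each term is computed as a separate statistic, averaged on its own, and combined only at the end. The function names `crps_skill` and `crps_spread` are hypothetical, and this is not the WeatherBench-X implementation, which may use a different spread estimator.

```python
import numpy as np


def crps_skill(ensemble: np.ndarray, target: np.ndarray) -> np.ndarray:
  """Mean absolute error of the members; the ensemble has a leading member axis."""
  return np.mean(np.abs(ensemble - target), axis=0)


def crps_spread(ensemble: np.ndarray) -> np.ndarray:
  """Mean absolute difference over all member pairs (i, j), including i == j."""
  diffs = np.abs(ensemble[:, None] - ensemble[None, :])
  return np.mean(diffs, axis=(0, 1))


# Hypothetical 3-member ensemble at 2 locations (shape: member, location).
ensemble = np.array([[0.0, 1.0], [1.0, 1.0], [2.0, 1.0]])
target = np.array([1.0, 1.0])

# Average each statistic separately (here over locations), then combine.
mean_skill = crps_skill(ensemble, target).mean()  # 1/3
mean_spread = crps_spread(ensemble).mean()         # 4/9
crps = mean_skill - 0.5 * mean_spread              # 1/9

print(mean_skill, mean_spread, crps)
```

Because the final step is linear in the two statistics, averaging them first gives the same result as averaging per-location CRPS values; for nonlinear final steps such as the RMSE square root, this is not the case.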