# Show that the two metric implementations give the same result

Create some example tensors: two (4, 2) integer tensors, one weight per row (`o`) and one weight per column (`a`).

In [25]:
import torch  # `from torch import Tensor` in a later cell does not bind the name `torch`

# x, y: a 2x4 literal transposed, i.e. (4, 2) integer tensors -- 4 rows, 2 columns
x = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 4]]).T
y = torch.tensor([[0, 1, 2, 1], [2, 3, 4, 4]]).T
# o: one weight per row of x / y (passed as `omegas` below)
o = torch.tensor([0.25, 0.25, 0.3, 0.2])
# a: one weight per column of x / y (passed as `area` below)
a = torch.tensor([0.25, 0.25])

Absolute error metric implementation (note: despite the name, it squares the absolute difference, i.e. it is a weighted *squared* error)

In [35]:
from torch import Tensor
from torchmetrics.utilities.checks import _check_same_shape
from torchmetrics import Metric

def _absolute_error_update(preds: Tensor, target: Tensor, area: Tensor) -> Tensor:
    """Update step: area-weighted squared error, reduced over dim 1.

    Despite the name, every element's absolute difference is squared
    before being multiplied by ``area`` and summed along dim 1.

    Args:
        preds: predicted values; must have the same shape as ``target``.
        target: reference values.
        area: weights broadcast against each row before the reduction.

    Returns:
        Tensor of per-row sums (dim 1 reduced away).
    """
    _check_same_shape(preds, target)
    magnitude = torch.abs(preds - target)
    weighted = magnitude * magnitude * area
    return weighted.sum(dim=1)


def _absolute_error_compute(sum_absolute_error: Tensor, omegas: Tensor) -> Tensor:
    """Compute step: weight the per-row sums by ``omegas`` and add them up.

    ``omegas`` is squeezed first, so singleton dimensions (e.g. a column of
    shape ``(N, 1)``) broadcast against the per-row sums.

    Returns:
        A 0-dim tensor holding the weighted total.
    """
    flat_weights = omegas.squeeze()
    weighted_rows = sum_absolute_error * flat_weights
    return weighted_rows.sum()


def absolute_error(
    preds: Tensor, target: Tensor, omegas: Tensor, area: Tensor
) -> Tensor:
    """
    Computes the omega- and area-weighted squared absolute error.

    For each index ``i`` along dim 0, the squared differences
    ``|preds[i] - target[i]| ** 2`` are multiplied by ``area`` and summed
    over dim 1; those per-row sums are then weighted by ``omegas`` and
    summed into a single value.

    Args:
        preds: estimated values; must have the same shape as ``target``
        target: ground-truth values
        omegas: per-row weights; squeezed before use, so a ``(N, 1)``
            column works as well as a flat ``(N,)`` vector
        area: weights applied to each element of a row before summing
    Return:
        0-dim tensor with the weighted squared error
    Example:
        >>> preds = torch.tensor([[0., 1.], [1., 2.], [2., 3.], [3., 4.]])
        >>> target = torch.tensor([[0., 2.], [1., 3.], [2., 4.], [1., 4.]])
        >>> omegas = torch.tensor([0.25, 0.25, 0.3, 0.2])
        >>> area = torch.tensor([0.25, 0.25])
        >>> absolute_error(preds, target, omegas, area)
        tensor(0.4000)
    """
    # Split into update/compute steps (torchmetrics style) so partial results
    # from separate batches can be computed independently and then added.
    sum_abs_error = _absolute_error_update(preds, target, area)
    return _absolute_error_compute(sum_abs_error, omegas)


Criterion AE implementation

In [36]:
def criterion_ae(F_pred, F_obs, omegas, area):
    """Reference criterion: omega-weighted sum of area-weighted squared errors.

    Args:
        F_pred: predicted values.
        F_obs: observed values (broadcast against ``F_pred`` when subtracted).
        omegas: per-row weights; squeezed, so an ``(N, 1)`` column also works.
        area: weights applied to each element along dim 1 before summing.

    Returns:
        A 0-dim tensor with the weighted total.
    """
    squared_misfit = torch.abs(F_pred - F_obs) ** 2
    per_instance = (squared_misfit * area).sum(dim=1)
    row_weights = omegas.squeeze()
    return (per_instance * row_weights).sum()

In [40]:
from numpy.testing import assert_almost_equal

This shows that both implementations produce the same result:

In [49]:
# Evaluate each implementation once, show both values, then check agreement.
ae_result = absolute_error(x, y, o, a)
criterion_result = criterion_ae(x, y, o, a)
print(ae_result)
print(criterion_result)
assert_almost_equal(ae_result, criterion_result, decimal=4)

tensor(0.4000)
tensor(0.4000)


Now split the data into 2 batches (rows 0-1 and rows 2-3)

In [50]:
# Rows 0-1 and 2-3 form the two batches; `a` weights columns, so it is shared.
batch_1 = slice(0, 2)
batch_2 = slice(2, 4)
sum_ae_1 = _absolute_error_update(x[batch_1], y[batch_1], a)
sum_ae_2 = _absolute_error_update(x[batch_2], y[batch_2], a)
sum_ae = (
    _absolute_error_compute(sum_ae_1, o[batch_1])
    + _absolute_error_compute(sum_ae_2, o[batch_2])
)

This shows that we can compute the absolute error per batch and then sum the per-batch results:

In [52]:
# Full-data result must match the sum of the per-batch results (4 decimals).
assert_almost_equal(
    actual=absolute_error(x, y, o, a),
    desired=sum_ae,
    decimal=4,
)

However, we still have to refactor the code to implement this properly.