# Metrics and losses

> Metrics and losses for training and evaluation.

In [None]:
# | default_exp utils.metrics

In [None]:
# | export
import jax
import jax.numpy as jnp
import numpy as np

In [None]:
# | export


def squared_error(y_true, y_pred):
    """Element-wise squared error between targets and predictions.

    Args:
        y_true: Ground-truth values (scalar or array).
        y_pred: Predicted values, broadcast-compatible with ``y_true``.

    Returns:
        ``(y_true - y_pred) ** 2`` with no reduction applied.
    """
    residual = y_true - y_pred
    return residual ** 2


def absolute_error(y_true, y_pred):
    """Element-wise absolute error between targets and predictions.

    Args:
        y_true: Ground-truth values (scalar or array).
        y_pred: Predicted values, broadcast-compatible with ``y_true``.

    Returns:
        ``|y_true - y_pred|`` as a JAX array, with no reduction applied.
    """
    residual = y_true - y_pred
    return jnp.abs(residual)


def mse(y_true, y_pred, axis=None):
    """Mean squared error.

    Args:
        y_true: Ground-truth values.
        y_pred: Predicted values, broadcast-compatible with ``y_true``.
        axis: Axis (or axes) forwarded to ``jnp.mean``; ``None`` averages
            over every element.

    Returns:
        The mean of ``(y_true - y_pred) ** 2`` along ``axis``.
    """
    residual = y_true - y_pred
    return jnp.mean(residual ** 2, axis=axis)


def mae(y_true, y_pred, axis=None):
    """Mean absolute error.

    Args:
        y_true: Ground-truth values.
        y_pred: Predicted values, broadcast-compatible with ``y_true``.
        axis: Axis (or axes) forwarded to ``jnp.mean``; ``None`` averages
            over every element.

    Returns:
        The mean of ``|y_true - y_pred|`` along ``axis``.
    """
    abs_residual = jnp.abs(y_true - y_pred)
    return jnp.mean(abs_residual, axis=axis)


def mse_relative(y_true, y_pred, axis=None):
    """Mean squared error normalised by the mean square of the targets.

    Computes ``mean((y_true - y_pred) ** 2) / mean(y_true ** 2)``, with both
    means taken along the same ``axis``.

    Args:
        y_true: Ground-truth values.
        y_pred: Predicted values, broadcast-compatible with ``y_true``.
        axis: Axis (or axes) forwarded to both ``jnp.mean`` calls; ``None``
            averages over every element.

    Returns:
        The ratio of error power to target power along ``axis``. The
        denominator is not guarded, so all-zero targets yield inf/nan.
    """
    error_power = jnp.mean((y_true - y_pred) ** 2, axis=axis)
    target_power = jnp.mean(y_true**2, axis=axis)
    return error_power / target_power


def mae_relative(y_true, y_pred, axis=None):
    """Mean absolute error normalised by the mean magnitude of the targets.

    Computes ``mean(|y_true - y_pred|) / mean(|y_true|)``, with both means
    taken along the same ``axis``.

    Args:
        y_true: Ground-truth values.
        y_pred: Predicted values, broadcast-compatible with ``y_true``.
        axis: Axis (or axes) forwarded to both ``jnp.mean`` calls; ``None``
            averages over every element.

    Returns:
        The ratio of mean error magnitude to mean target magnitude along
        ``axis``. The denominator is not guarded, so all-zero targets yield
        inf/nan.
    """
    error_magnitude = jnp.mean(jnp.abs(y_true - y_pred), axis=axis)
    target_magnitude = jnp.mean(jnp.abs(y_true), axis=axis)
    return error_magnitude / target_magnitude


def accumulate_metrics(metrics):
    """Average a collection of per-step metric dicts into a single dict.

    Every value is first fetched to host memory with ``jax.device_get``.
    Each key is then averaged with ``np.mean`` across all entries.

    Args:
        metrics: An iterable (list, tuple, generator, ...) of dicts that map
            metric names to scalar or array values. The key set of the first
            dict determines which metrics are reported.

    Returns:
        A dict mapping each key of the first entry to the ``np.mean`` of that
        key's values across all entries. Empty input yields an empty dict.

    Raises:
        KeyError: If a later entry lacks a key present in the first entry.
    """
    # Materialize before device_get. jax.device_get leaves a generator
    # untouched, so a generator would then fail at `metrics[0]`; a list can
    # also be iterated repeatedly below.
    metrics = jax.device_get(list(metrics))
    if not metrics:
        # Previously an opaque IndexError at `metrics[0]`.
        return {}
    return {k: np.mean([metric[k] for metric in metrics]) for k in metrics[0]}