In [2]:
from abc import ABC, abstractmethod
import jax.numpy as jnp
from functools import partial

# Loss functions

Loss functions map a vector of model predictions and a corresponding vector of ground truths to a scalar:
$$
\mathcal{L}: \mathbb{R}^N \times \mathbb{R}^N \rightarrow \mathbb{R}
$$

The greater the output of the loss function, the worse the model performs. A good loss function should encourage "high-quality" responses from the model, perhaps including:
 - Matching the ground truth in any given instance
 - Matching the ground truth in a particular, minority class (i.e., imbalanced learning)
 - Matching the ground truth for "critical" observations (e.g., fraud detection)

In [3]:
class Loss(ABC):
    """Abstract base class for loss functions.

    Subclasses implement ``forward``. Instances are callable:
    ``loss(y_est, y)`` delegates to ``loss.forward(y_est, y)``.
    """

    @abstractmethod
    def forward(self, y_est: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
        """Compute the loss of predictions ``y_est`` against targets ``y``.

        Args:
            y_est: Model predictions.
            y: Ground-truth values.

        Returns:
            The loss value.
        """

    def __call__(self, y_est: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
        """Evaluate the loss; shorthand for ``self.forward(y_est, y)``."""
        # ``forward`` is looked up on the instance, so subclasses may provide it
        # as a regular method, a staticmethod, or an instance-bound callable.
        return self.forward(y_est, y)

In [4]:
## Accuracy
class Accuracy(Loss):
    """Fraction of elements where the prediction equals the target exactly.

    Note: unlike the other losses in this notebook, a larger value here means
    more matching elements, i.e. a better fit.
    """

    def forward(self, y_est, y):
        """Return (number of positions with ``y == y_est``) / ``y.size``."""
        hits = y == y_est
        n_hits = hits.sum()
        return n_hits / y.size

## Mean Square Error

The **Mean Squared Error (MSE)** is a convex loss function defined as the second moment of the error distribution. Given a set of true labels $ y \in \mathbb{R}^N $ and estimated predictions $ \hat{y} \in \mathbb{R}^N $, the loss is computed as:  

$$
\mathcal{L}_{\text{MSE}}(y, \hat{y}) = \frac{1}{N} \sum_{i=1}^{N} (y_i - \hat{y}_i)^2
$$

where $ N $ is the number of samples. The gradient of MSE with respect to $ \hat{y} $ is given by:

$$
\frac{\partial \mathcal{L}_{\text{MSE}}}{\partial \hat{y}_i} = -\frac{2}{N} (y_i - \hat{y}_i)
$$

which follows from differentiating the squared error term. Since the function is quadratic, it ensures convexity, making it suitable for gradient-based optimization. However, due to the squared term, MSE assigns higher penalty to large deviations, increasing its sensitivity to outliers.

In [None]:
class MSE(Loss):
    """Mean squared error: the average of the squared residuals ``y - y_est``."""

    @staticmethod
    def forward(y_est, y):
        """Return ``sum((y - y_est)**2)`` divided by the number of elements."""
        residual = y - y_est
        sum_of_squares = (residual ** 2).sum()
        return sum_of_squares / residual.size
    
class RMSE(Loss):
    """Root mean squared error: the square root of the mean squared residual."""

    @staticmethod
    def forward(y_est, y):
        """Return ``sqrt(sum((y - y_est)**2) / n)`` with ``n`` the element count."""
        deviation = y - y_est
        n_elements = deviation.size
        sum_of_squares = (deviation ** 2).sum()
        return jnp.sqrt(sum_of_squares / n_elements)

### **Cross-Entropy Loss**  
The **Cross-Entropy Loss** is derived from the Kullback-Leibler (KL) divergence, measuring the difference between two probability distributions. Given true class labels $ y \in \{0,1\}^N $ and predicted probabilities $ \hat{y} \in [0,1]^N $, the **binary cross-entropy** loss is:

$$
\mathcal{L}_{\text{BCE}}(y, \hat{y}) = - \sum_{i=1}^{N} \left[ y_i \log(\hat{y}_i) + (1 - y_i) \log(1 - \hat{y}_i) \right]
$$

which is derived from the likelihood function of a Bernoulli distribution under maximum likelihood estimation (MLE). For multi-class classification with $ C $ possible classes, where $ y $ is one-hot encoded, the **categorical cross-entropy** extends to:

$$
\mathcal{L}_{\text{CCE}}(y, \hat{y}) = - \frac{1}{N} \sum_{i=1}^{N} \sum_{j=1}^{C} y_{ij} \log(\hat{y}_{ij})
$$

where $y_{ij} $ is the true probability (1 for the correct class, 0 otherwise), and $ \hat{y}_{ij} $ is the softmax-normalized predicted probability for class $ j $:

$$
\hat{y}_{ij} = \frac{e^{z_{ij}}}{\sum_{k=1}^{C} e^{z_{ik}}}
$$

The gradient with respect to $ z_{ij} $ simplifies to:

$$
\frac{\partial \mathcal{L}_{\text{CCE}}}{\partial z_{ij}} = \frac{1}{N} \left( \hat{y}_{ij} - y_{ij} \right)
$$

which results in a stable gradient for optimizing classification models. Unlike MSE, cross-entropy loss is well-suited for probabilistic interpretations and avoids issues of slow convergence when used with softmax activation in multi-class problems.

In [6]:
class CrossEntropy(Loss):
    """Categorical cross-entropy between target and predicted probabilities.

    Args:
        eps: Predictions are clipped to ``[eps, 1 - eps]`` before taking the log.
        axis: Axis that is summed over for each sample (the class axis).
    """

    def __init__(self, eps=1e-8, axis=-1):
        self.eps = eps
        self.axis = axis
        # Shadow the static ``forward`` on this instance with a version that has
        # the configured eps/axis bound in, so callers pass only (y_est, y).
        # The values are captured here; reassigning self.eps / self.axis later
        # does not affect the bound function.
        self.forward = partial(self.forward, eps=self.eps, axis=self.axis)

    @staticmethod
    def forward(y_est, y, eps=1e-8, axis=-1):
        """Return the mean over samples of ``-sum(y * log(y_est), axis)``.

        ``eps`` and ``axis`` default to the constructor defaults so the static
        form can be called as ``CrossEntropy.forward(y_est, y)``, matching the
        two-argument ``Loss.forward`` contract.

        Args:
            y_est: Predicted probabilities.
            y: Target probabilities (e.g. one-hot rows).
            eps: Clipping bound that keeps ``log`` finite.
            axis: Axis summed over for each sample.

        Returns:
            Scalar loss.
        """
        y_est = jnp.clip(y_est, eps, 1 - eps)
        per_sample = jnp.sum(y * jnp.log(y_est), axis=axis)
        return -jnp.mean(per_sample)

### **Huber Loss**  
Huber loss is a **robust loss function** that is less sensitive to outliers than Mean Squared Error (MSE) while maintaining smooth gradients. It combines **L1 loss** (absolute error) and **L2 loss** (squared error) based on a threshold $ \delta $. It is formally defined as:

$$
\mathcal{L}_{\text{Huber}}(y, \hat{y}) =
\begin{cases} 
\frac{1}{2} (y - \hat{y})^2, & \text{if } |y - \hat{y}| \leq \delta \\
\delta \left( |y - \hat{y}| - \frac{1}{2} \delta \right), & \text{otherwise}
\end{cases}
$$

where $ \delta $ is a hyperparameter controlling the transition point between quadratic and linear behavior. The gradient with respect to $ \hat{y} $ is:

$$
\frac{\partial \mathcal{L}_{\text{Huber}}}{\partial \hat{y}} =
\begin{cases} 
-(y - \hat{y}), & |y - \hat{y}| \leq \delta \\
-\delta \cdot \text{sign}(y - \hat{y}), & \text{otherwise}
\end{cases}
$$

Huber loss behaves like **MSE** for small residuals and like **MAE (Mean Absolute Error)** for large residuals, providing robustness against outliers.


In [6]:
class Huber(Loss):
    """Huber loss: quadratic for small residuals, linear for large ones.

    Args:
        delta: Absolute-residual threshold at which the loss switches from the
            quadratic to the linear branch.
    """

    def __init__(self, delta=1.0):
        self.delta = delta

    def forward(self, y_est, y):
        """Return the mean Huber loss of ``y_est`` against ``y``.

        Implemented as ``forward`` because ``Loss`` declares it abstract;
        without this override the class cannot be instantiated.
        """
        err = y - y_est
        abs_err = jnp.abs(err)
        quadratic = 0.5 * (err**2)
        linear = self.delta * (abs_err - 0.5 * self.delta)
        # Element-wise branch selection; both branches agree at |err| == delta.
        loss = jnp.where(abs_err <= self.delta, quadratic, linear)
        return loss.mean()

    def __call__(self, y_est, y):
        """Evaluate the loss; equivalent to ``self.forward(y_est, y)``."""
        return self.forward(y_est, y)

### **Hinge Loss (SVM Loss)**  
Hinge loss is primarily used in **Support Vector Machines (SVMs)** for classification. Given a set of labels $ y \in \{-1, 1\} $ and predictions $ f(x) $, the **hinge loss** is defined as:
$$
\mathcal{L}_{\text{Hinge}}(y, f(x)) = \sum_{i=1}^{N} \max(0, 1 - y_i f(x_i))
$$

This loss function enforces a margin of at least 1 for correct classifications. If $ y_i f(x_i) \geq 1 $, the loss is zero (correct classification with sufficient margin); otherwise, it penalizes the incorrect or weakly correct prediction.

The gradient with respect to $ f(x) $ is:

$$
\frac{\partial \mathcal{L}_{\text{Hinge}}}{\partial f(x_i)} =
\begin{cases} 
- y_i, & y_i f(x_i) < 1 \\
0, & \text{otherwise}
\end{cases}
$$

Since hinge loss is **non-differentiable** at $ y_i f(x_i) = 1 $, subgradient methods are used in optimization.

For multi-class classification, hinge loss is extended as:

$$
\mathcal{L}_{\text{Multi-Hinge}}(y, f(x)) = \sum_{i=1}^{N} \sum_{j \neq y_i} \max(0, f_j(x_i) - f_{y_i}(x_i) + 1)
$$

where $ f_j(x_i) $ is the raw score for class $ j $, and $ f_{y_i}(x_i) $ is the score for the correct class.


In [None]:
class Hinge(Loss):
    """Binary hinge loss ``max(0, 1 - y * y_est)``, averaged over elements.

    Per the accompanying text, labels ``y`` are expected in ``{-1, 1}`` and
    ``y_est`` are raw scores.
    """

    def forward(self, y_est, y):
        """Return the mean hinge loss of scores ``y_est`` against labels ``y``.

        Implemented as ``forward`` because ``Loss`` declares it abstract;
        without this override the class cannot be instantiated.
        """
        margin = 1 - y * y_est
        # Element-wise maximum against 0. ``jnp.max`` is a reduction whose
        # second positional argument is ``axis``, so it cannot be used here.
        loss = jnp.maximum(0, margin)
        return loss.mean()

    def __call__(self, y_est, y):
        """Evaluate the loss; equivalent to ``self.forward(y_est, y)``."""
        return self.forward(y_est, y)

### **Kullback-Leibler (KL) Divergence Loss**  
KL divergence is a measure of how one probability distribution $ P(y) $ differs from another distribution $ Q(y) $ (e.g., a model's predicted distribution). It is given by:

$$
D_{\text{KL}}(P || Q) = \sum_{i=1}^{N} P(y_i) \log \frac{P(y_i)}{Q(y_i)}
$$

In the context of deep learning, $ P(y) $ represents the true probability distribution (e.g., one-hot encoded ground truth labels), and $ Q(y) $ represents the predicted probability distribution (e.g., softmax outputs). The loss function is thus:

$$
\mathcal{L}_{\text{KL}} = \sum_{i=1}^{N} \left[ y_i \log y_i - y_i \log \hat{y}_i \right]
$$

Since the first term is independent of the model, KL divergence loss is often computed as:

$$
\mathcal{L}_{\text{KL}} = -\sum_{i=1}^{N} y_i \log \hat{y}_i
$$

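which is exactly the **categorical cross-entropy**. When $ y $ is one-hot encoded, the dropped term $ \sum_i y_i \log y_i $ is zero (with the convention $ 0 \log 0 = 0 $), so the two losses coincide.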

The gradient with respect to $ \hat{y} $ is:

$$
\frac{\partial \mathcal{L}_{\text{KL}}}{\partial \hat{y}_i} = - \frac{y_i}{\hat{y}_i}
$$

making it suitable for classification tasks, particularly in **variational inference** and **reinforcement learning**.

In [None]:
class KLDivergence(Loss):
    """KL divergence ``sum(y * log(y / y_est))`` over the last axis, averaged
    over the remaining axes.

    Entries with ``y == 0`` contribute zero (the ``0 * log 0 = 0`` convention),
    so one-hot targets give a finite result instead of ``nan``.
    """

    def forward(self, y_est, y):
        """Return the mean KL divergence of ``y_est`` from ``y``.

        Implemented as ``forward`` because ``Loss`` declares it abstract;
        without this override the class cannot be instantiated.

        Args:
            y_est: Predicted probabilities, same shape as ``y``.
            y: Target probabilities.

        Returns:
            Scalar loss. It is ``inf`` if ``y_est`` is 0 where ``y`` is positive.
        """
        support = y > 0
        # Substitute 1 outside the support *before* the log so neither the
        # value nor the gradient of the discarded branch is nan/inf.
        safe_y = jnp.where(support, y, 1.0)
        safe_est = jnp.where(support, y_est, 1.0)
        terms = jnp.where(support, y * (jnp.log(safe_y) - jnp.log(safe_est)), 0.0)
        loss = terms.sum(axis=-1)
        return loss.mean()

    def __call__(self, y_est, y):
        """Evaluate the loss; equivalent to ``self.forward(y_est, y)``."""
        return self.forward(y_est, y)

## An example
MSE heavily penalizes large errors, making it highly sensitive to outliers (**2.107 → 8.013**), while Huber grows more slowly (**0.807 → 2.233**), making it more robust. Hinge and Cross-Entropy effectively penalize misclassifications, with Cross-Entropy rising sharply (**0.105 → 2.072**) when confident predictions are incorrect.

In [None]:
# Loss objects to compare, keyed by the label printed in the report.
losses = {
    "MSE": MSE(),
    "Huber (δ=1)": Huber(delta=1.0),
    "Hinge": Hinge(),
    "KL Divergence": KLDivergence(),
    "Cross Entropy": CrossEntropy()
}

# Targets, one array per kind of task.
y_reg = jnp.array([2.0, -1.0, 3.0])  # real-valued targets (MSE, Huber)
y_cls = jnp.array([1, -1, 1])        # labels in {-1, +1} (Hinge)
y_prob = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 0.0]])  # one-hot rows (KL, Cross Entropy)

# Predicted two-class probabilities for three scenarios.
predictions = {
    "Good Predictions": jnp.array([[0.9, 0.1], [0.1, 0.9], [0.9, 0.1]]),   # correct, low noise
    "Noisy Predictions": jnp.array([[0.7, 0.3], [0.3, 0.7], [0.6, 0.4]]),  # correct, noisier
    "Outlier Predictions": jnp.array([[0.1, 0.9], [0.9, 0.1], [0.2, 0.8]])  # confidently wrong
}

# Print one block of results per loss, one line per scenario.
for loss_name, loss_fn in losses.items():
    header = f"\n---  {loss_name} Loss Results:  ---"
    print(header)
    uses_distributions = loss_name in ["KL Divergence", "Cross Entropy"]
    for scenario, y_est in predictions.items():
        if uses_distributions:
            # These losses take the full probability rows; clip to avoid log(0).
            y_est = jnp.clip(y_est, 1e-7, 1.0)
            result = loss_fn(y_est, y_prob)
        else:
            # Map the first-class probability from [0, 1] to a score in [-1, 1].
            scores = y_est[:, 0] * 2 - 1
            targets = y_cls if loss_name == "Hinge" else y_reg
            result = loss_fn(scores, targets)
        print(f"{scenario}: {result:.3f}")
    print("-" * len(header))

UsageError: Line magic function `%ignore` not found.
