# Custom Loss Functions

This notebook provides a practical introduction to the specialized loss functions in `probly`. While standard losses like `nn.CrossEntropyLoss` are sufficient for deterministic models, probabilistic models often require custom loss functions to handle uncertainty.

We will cover three key types of custom losses:
-   **Negative Log-Likelihood (NLL) Losses:** Adaptations for probabilistic outputs.
-   **Evidential Losses:** Specialized functions for models that learn "evidence."
-   **Calibration-Aware Losses:** Losses that directly optimize for model calibration.

---

## 1. Negative Log-Likelihood (NLL) Losses
NLL losses are a foundational concept in training probabilistic models. Instead of just penalizing wrong predictions, they evaluate how well the entire predicted *distribution* explains the true target.
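As a minimal illustration of an NLL loss on a probabilistic output, consider a regressor that predicts a Gaussian (a mean and a variance) for each target. The sketch below uses PyTorch's built-in `nn.GaussianNLLLoss` (not a `probly` class) and checks it against the formula written out by hand. The numbers are arbitrary and chosen only for illustration.

```python
import torch
from torch import nn

# Hypothetical predictions from a probabilistic regressor: a mean AND a variance per target.
target = torch.tensor([1.0, 2.0, 3.0])
mean = torch.tensor([1.5, 2.0, 2.0])
var = torch.tensor([0.5, 0.1, 2.0])

# Built-in Gaussian NLL; full=False omits the constant 0.5 * log(2 * pi).
gaussian_nll = nn.GaussianNLLLoss(full=False, reduction="mean")
builtin = gaussian_nll(mean, target, var)

# The same quantity written out: mean over samples of 0.5 * (log var + (y - mu)^2 / var).
manual = (0.5 * (torch.log(var) + (target - mean) ** 2 / var)).mean()

print(f"GaussianNLLLoss: {builtin.item():.4f}")
print(f"Manual NLL:      {manual.item():.4f}")

# Same one-unit error, different stated confidence: the overconfident prediction pays far more.
y0, mu0 = torch.tensor([0.0]), torch.tensor([1.0])
overconfident = gaussian_nll(mu0, y0, torch.tensor([0.01]))
hedged = gaussian_nll(mu0, y0, torch.tensor([1.0]))
print(f"Overconfident miss: {overconfident.item():.4f}, hedged miss: {hedged.item():.4f}")
```

Unlike a plain squared error, the NLL depends on the predicted variance, so the model is rewarded for reporting uncertainty that matches its actual errors.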

### Example: The ELBO Loss for Bayesian Neural Networks
A Bayesian Neural Network (BNN) is trained with a loss that balances two goals:
1.  **Fit the data:** Make accurate predictions (similar to a standard loss).
2.  **Stay simple:** Keep the weight distributions close to a simple prior distribution.

The **Evidence Lower Bound (ELBO)** loss achieves this.
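Written out, the objective computed by the `ELBOLoss` class below is, as a sketch, the following, with $\lambda$ standing for `kl_penalty`, $B$ for the batch size, and $w$ for weights sampled from the variational posterior $q(w)$ during the forward pass:

$$
\mathcal{L}(w) = \underbrace{-\frac{1}{B}\sum_{i=1}^{B} \log p(y_i \mid x_i, w)}_{\text{cross-entropy (fit the data)}} \;+\; \lambda \, \underbrace{\mathrm{KL}\big(q(w)\,\|\,p(w)\big)}_{\text{stay close to the prior}}, \qquad w \sim q(w).
$$

Dividing the negative ELBO of a training set with $N$ examples by $N$ gives exactly this form with $\lambda = 1/N$; treating $\lambda$ as a free hyperparameter, as `kl_penalty` does, is a common practical variant.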

In [None]:
import torch
from torch import nn
import torch.nn.functional as F

from probly.train.bayesian.torch import collect_kl_divergence
from probly.transformation import bayesian


class ELBOLoss(nn.Module):
    """Evidence Lower Bound (ELBO) loss for Bayesian neural networks."""

    def __init__(self, kl_penalty: float = 1e-5) -> None:
        """Initialize the loss.

        Args:
            kl_penalty: The penalty weight for the KL divergence term.

        """
        super().__init__()
        self.kl_penalty = kl_penalty

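    def forward(self, inputs: torch.Tensor, targets: torch.Tensor, kl: torch.Tensor) -> torch.Tensor:
        """Compute the ELBO loss.

        Args:
            inputs: The model's output logits, shape (batch_size, num_classes).
            targets: The integer class labels, shape (batch_size,).
            kl: The total KL divergence of the Bayesian layers (a scalar tensor).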

        Returns:
            The calculated loss.

        """
        # 1. Standard Cross-Entropy
        cross_entropy_loss = F.cross_entropy(inputs, targets)

        # 2. KL Divergence Regularizer
        kl_divergence = self.kl_penalty * kl

        return cross_entropy_loss + kl_divergence

`ELBOLoss` combines a standard cross-entropy loss with a KL divergence term, which penalizes the model when its weight distributions become too complex, i.e., drift too far from the prior.
For more information on how Bayesian models work, see the [Bayesian Transformation](../bayesian_transformation.ipynb) tutorial.
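To make the KL term concrete, assume a mean-field Gaussian posterior $q(w) = \mathcal{N}(\mu, \sigma^2)$ for each weight and a standard-normal prior $p(w) = \mathcal{N}(0, 1)$. These are common choices, not necessarily what `probly`'s Bayesian layers use. Under these assumptions the KL divergence has a closed form, which the sketch below checks against `torch.distributions.kl_divergence`. Summing such per-weight terms gives a quantity analogous to what `collect_kl_divergence` returns.

```python
import torch
from torch.distributions import Normal, kl_divergence

# Hypothetical variational parameters for 5 weights (mean and standard deviation).
mu = torch.tensor([0.0, 0.5, -1.0, 2.0, 0.1])
sigma = torch.tensor([1.0, 0.5, 0.2, 1.5, 0.05])

posterior = Normal(mu, sigma)
prior = Normal(torch.zeros_like(mu), torch.ones_like(sigma))  # assumed N(0, 1) prior

# Closed form of KL(N(mu, sigma^2) || N(0, 1)) for each weight.
kl_closed = 0.5 * (sigma**2 + mu**2 - 1.0 - torch.log(sigma**2))

# PyTorch's registered Normal-vs-Normal KL gives the same per-weight values.
kl_torch = kl_divergence(posterior, prior)

print(kl_closed)
print(kl_torch)

# Total KL: the sum over all weights, i.e. the kind of term ELBOLoss scales by kl_penalty.
total_kl = kl_closed.sum()
print(f"Total KL: {total_kl.item():.4f}")
```

The first weight already matches the prior (mean 0, standard deviation 1) and contributes zero; weights with large means or very small standard deviations contribute the most.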


## 2. Evidential Losses

Evidential Deep Learning models do not output probabilities directly.
Instead, they output **evidence** for each class, which requires specialized loss functions.

### Example: Evidential Losses for Classification and Regression

The `probly` library provides custom loss functions for evidential learning, based on the original research papers:

- **EvidentialLogLoss (Classification)**
  Adapts the standard log loss to the Dirichlet concentration parameters (`alpha`) derived from the model's evidence, instead of to class probabilities.

- **EvidentialNIGNLLLoss (Regression)**
  A negative log-likelihood (NLL) loss over the four parameters of the Normal-Inverse-Gamma (NIG) distribution that an evidential regression model predicts:
  `gamma`, `nu`, `alpha`, and `beta`.
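To make the two losses concrete, here is a minimal sketch written from the formulations of Sensoy et al. (2018) for classification and Amini et al. (2020) for regression. It is *not* `probly`'s implementation: the function names `evidential_log_loss` and `nig_nll`, the use of `softplus` to turn logits into evidence, and the `alpha = evidence + 1` parameterization are illustrative assumptions.

```python
import math

import torch
import torch.nn.functional as F


def evidential_log_loss(logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch of the evidential log loss, averaged over the batch."""
    evidence = F.softplus(logits)           # assumption: non-negative evidence via softplus
    alpha = evidence + 1.0                  # Dirichlet concentration parameters
    strength = alpha.sum(dim=-1)            # S = sum_j alpha_j
    alpha_true = alpha.gather(-1, targets.unsqueeze(-1)).squeeze(-1)
    # sum_j y_j * (log S - log alpha_j) reduces to log S - log alpha_y for one-hot y.
    return (torch.log(strength) - torch.log(alpha_true)).mean()


def nig_nll(
    y: torch.Tensor,
    gamma: torch.Tensor,
    nu: torch.Tensor,
    alpha: torch.Tensor,
    beta: torch.Tensor,
) -> torch.Tensor:
    """Hypothetical sketch of the Normal-Inverse-Gamma NLL, averaged over the batch."""
    omega = 2.0 * beta * (1.0 + nu)
    nll = (
        0.5 * torch.log(math.pi / nu)
        - alpha * torch.log(omega)
        + (alpha + 0.5) * torch.log(nu * (y - gamma) ** 2 + omega)
        + torch.lgamma(alpha)
        - torch.lgamma(alpha + 0.5)
    )
    return nll.mean()


# Classification: 4 samples, 3 classes (dummy logits).
logits = torch.randn(4, 3)
labels = torch.randint(0, 3, (4,))
print(f"Evidential log loss: {evidential_log_loss(logits, labels).item():.4f}")

# Regression: the NIG parameters must satisfy nu > 0, alpha > 1, beta > 0.
y = torch.tensor([0.5, -1.0])
gamma = torch.tensor([0.4, -0.2])
nu = torch.tensor([1.0, 2.0])
alpha = torch.tensor([2.0, 3.0])
beta = torch.tensor([1.0, 0.5])
print(f"NIG NLL: {nig_nll(y, gamma, nu, alpha, beta).item():.4f}")
```

The NIG NLL is the negative log-density of the Student-t marginal implied by the four parameters, which is why it can be checked against `torch.distributions.StudentT`.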

### Training an Evidential Model

The training loop for an evidential model typically combines:

1. An evidential NLL loss (classification or regression), and
2. A regularization term that encourages the model to remain uncertain on out-of-distribution data.

The total loss is a weighted sum of these two components.
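One common choice of regularizer, from Sensoy et al. (2018), is the KL divergence between a Dirichlet whose true-class evidence has been removed and the uniform Dirichlet, so that only evidence for the *wrong* classes is penalized. The sketch below combines it with the evidential log loss using a linearly annealed weight. The function names, the `softplus` evidence, and the annealing schedule are illustrative assumptions, not `probly`'s API.

```python
import torch
import torch.nn.functional as F
from torch.distributions import Dirichlet, kl_divergence


def evidential_regularizer(alpha: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch: KL(Dir(alpha_tilde) || Dir(1)), averaged over the batch."""
    one_hot = F.one_hot(targets, num_classes=alpha.shape[-1]).to(alpha.dtype)
    # Reset the true class to 1 so that only misleading evidence is penalized.
    alpha_tilde = one_hot + (1.0 - one_hot) * alpha
    uniform = Dirichlet(torch.ones_like(alpha_tilde))
    return kl_divergence(Dirichlet(alpha_tilde), uniform).mean()


def total_evidential_loss(logits: torch.Tensor, targets: torch.Tensor, reg_weight: float) -> torch.Tensor:
    """Hypothetical sketch: evidential log loss + weighted regularizer."""
    alpha = F.softplus(logits) + 1.0  # assumption: softplus evidence, alpha = evidence + 1
    strength = alpha.sum(dim=-1)
    alpha_true = alpha.gather(-1, targets.unsqueeze(-1)).squeeze(-1)
    nll = (torch.log(strength) - torch.log(alpha_true)).mean()
    return nll + reg_weight * evidential_regularizer(alpha, targets)


logits = torch.randn(8, 3)
targets = torch.randint(0, 3, (8,))
for epoch in (0, 5, 20):
    # Hypothetical schedule: ramp the regularizer weight from 0 up to 1 over 10 epochs.
    reg_weight = min(1.0, epoch / 10)
    loss = total_evidential_loss(logits, targets, reg_weight)
    print(f"epoch {epoch:2d}  reg_weight={reg_weight:.1f}  loss={loss.item():.4f}")
```

Annealing the weight lets the model first learn to fit the data before the regularizer starts shrinking unsupported evidence toward the uniform Dirichlet.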

In [None]:
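# Example: Training with ELBO Loss (continues the Bayesian example from Section 1;
# assumes the cell defining ELBOLoss has already been run)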
import torch
from torch import nn

# Create a simple model and transform it to Bayesian
model = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 3))
bnn_model = bayesian(model)

# Create the ELBO loss
criterion = ELBOLoss(kl_penalty=1e-5)

# Dummy data (8 samples, 10 features, 3 classes)
inputs = torch.randn(8, 10)
targets = torch.randint(0, 3, (8,))

# Forward pass
outputs = bnn_model(inputs)

# Collect KL divergence from all Bayesian layers
kl = collect_kl_divergence(bnn_model)

# Compute loss
loss = criterion(outputs, targets, kl)

print(f"ELBO Loss: {loss.item():.4f}")
print(f"  - Cross-Entropy component: {nn.functional.cross_entropy(outputs, targets).item():.4f}")
print(f"  - KL Divergence component: {(criterion.kl_penalty * kl).item():.6f}")
print(f"\nTotal KL from all layers: {kl.item():.4f}")

For full implementations, see the [**Evidential Classification**](../train_evidential_classification.ipynb) and [**Evidential Regression**](../train_evidential_regression.ipynb) tutorials.

## 3. Calibration-Aware Losses

Calibration can also be targeted during training by building a calibration objective directly into the loss function.
The model then optimizes for calibration alongside accuracy, rather than having it corrected only after training.

### Example: Label Relaxation

**Label Relaxation** is a simple but effective technique for reducing over-confidence and improving model calibration.
Instead of training against a hard one-hot target (e.g., `[0, 0, 1]`), the target is relaxed into a *set* of acceptable distributions:

- The true class must receive at least `1 - alpha` of the probability mass (e.g., `0.9` for `alpha = 0.1`).
- The remaining mass (at most `alpha`, e.g., `0.1`) may be distributed across the other classes.

Predictions that already fall inside this set incur no loss, so the model is not pushed toward extreme, over-confident outputs.
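For intuition, here is a minimal sketch of the label-relaxation objective following Lienen and Hüllermeier (2021). If the predicted distribution already gives the true class at least `1 - alpha`, the loss is zero. Otherwise the prediction is projected onto that set (true class at exactly `1 - alpha`, the remaining mass redistributed in proportion to the prediction), and the loss is the KL divergence from this projection to the prediction. The function `label_relaxation_loss` is written from that description and is not taken from `probly`; `LabelRelaxationLoss` may differ in details such as reduction or numerical safeguards.

```python
import torch
import torch.nn.functional as F


def label_relaxation_loss(logits: torch.Tensor, targets: torch.Tensor, alpha: float = 0.1) -> torch.Tensor:
    """Hypothetical sketch of label relaxation, averaged over the batch."""
    log_probs = F.log_softmax(logits, dim=-1)
    probs = log_probs.exp()
    one_hot = F.one_hot(targets, num_classes=probs.shape[-1]).to(probs.dtype)
    p_true = (probs * one_hot).sum(dim=-1)

    # Projection: true class gets exactly 1 - alpha; the mass alpha is spread over the
    # other classes in proportion to the model's own probabilities for them.
    other = probs * (1.0 - one_hot)
    other = other / other.sum(dim=-1, keepdim=True).clamp_min(1e-12)
    projected = (1.0 - alpha) * one_hot + alpha * other

    # KL(projected || prediction); xlogy treats 0 * log(0) as 0.
    kl = (torch.xlogy(projected, projected) - projected * log_probs).sum(dim=-1)

    # Predictions already inside the set (p_true >= 1 - alpha) incur no loss.
    inside = p_true >= 1.0 - alpha
    return torch.where(inside, torch.zeros_like(kl), kl).mean()


target = torch.tensor([0])
confident = torch.tensor([[4.0, 0.0, 0.0]])  # p(true) is about 0.96: inside the set for alpha=0.1
uncertain = torch.tensor([[0.5, 0.2, 0.1]])  # p(true) is about 0.41: outside the set

print(label_relaxation_loss(confident, target).item())  # 0.0
print(label_relaxation_loss(uncertain, target).item())
# For contrast, label smoothing still penalizes the confident, correct prediction.
print(F.cross_entropy(confident, target, label_smoothing=0.1).item())
```

The contrast with label smoothing is the key design choice: smoothing always pulls predictions toward a fixed soft target, whereas relaxation stops pushing once the prediction is confident enough.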

The `probly` library provides a direct implementation of this approach through the **`LabelRelaxationLoss`**.
It takes the same `(logits, targets)` inputs as `nn.CrossEntropyLoss`, so it can serve as a drop-in replacement in existing training pipelines.

In [None]:
# Example: Training with Label Relaxation
import torch
from torch import nn

from probly.train.calibration.torch import LabelRelaxationLoss

# Create a simple classifier
model = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 3))

# Use Label Relaxation instead of CrossEntropyLoss
# alpha=0.1 means: the true class should receive at least 0.9 of the probability mass
criterion = LabelRelaxationLoss(alpha=0.1)

# Standard CrossEntropyLoss for comparison
standard_criterion = nn.CrossEntropyLoss()

# Dummy data
inputs = torch.randn(8, 10)
targets = torch.randint(0, 3, (8,))

# Forward pass
outputs = model(inputs)

# Compare losses
relaxed_loss = criterion(outputs, targets)
standard_loss = standard_criterion(outputs, targets)

print(f"Standard CrossEntropy Loss: {standard_loss.item():.4f}")
print(f"Label Relaxation Loss:      {relaxed_loss.item():.4f}")

By optimizing this "softer" objective, the model tends to produce better-calibrated probability estimates.

For a full implementation, see the [**Label Relaxation Calibration**](../label_relaxation_calibration.ipynb) tutorial.