# Utility Functions

## Introduction

This notebook provides a practical introduction to the core utility functions in `probly`.
These helpers are essential building blocks for training probabilistic models and quantifying uncertainty.

We will focus on two main categories:

- **Model traversal functions**, which walk through a model's layers and aggregate per-layer quantities such as KL divergence
- **Uncertainty quantification functions**, which compute uncertainty scores from model predictions or evidence

---

## Key Utility Functions in `probly`

### 1. `collect_kl_divergence` (for BNNs)

**What it does:**
Automatically traverses a Bayesian Neural Network and sums, over all Bayesian layers, the KL divergence between each layer's approximate weight posterior and its prior.

**Why it’s useful:**
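Training a BNN means minimizing the negative **ELBO** (evidence lower bound), which adds this total KL term, as a regularizer, to the data-fit loss. `collect_kl_divergence` supplies that KL term in a single call.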
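To make the traversal and its role in the ELBO concrete, here is a minimal NumPy sketch. `ToyModule`, `sum_kl`, and `elbo_loss` are hypothetical names used only for illustration; they are not part of `probly`, whose `collect_kl_divergence` operates on PyTorch modules. The loss follows one common convention (mean negative log-likelihood over the batch plus the total KL divided by the training-set size); other weightings, such as a tunable KL coefficient, are also used in practice.

```python
from dataclasses import dataclass, field

import numpy as np


@dataclass
class ToyModule:
    """Hypothetical stand-in for a network module (not a probly class)."""

    kl: float = 0.0  # KL(q || p) of this module's own weights; 0.0 for non-Bayesian layers
    children: list = field(default_factory=list)


def sum_kl(module):
    # Depth-first traversal: this module's KL plus the KL of every submodule
    return module.kl + sum(sum_kl(child) for child in module.children)


def elbo_loss(logits, targets, total_kl, num_train_samples):
    """Negative ELBO for a classifier, per training example (one common convention)."""
    # Numerically stable log-softmax over the class axis
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    # Data-fit term: mean negative log-likelihood of the true labels
    nll = -log_probs[np.arange(len(targets)), targets].mean()
    # Complexity term: total KL spread over the whole training set
    return nll + total_kl / num_train_samples


# Usage: a toy "Linear -> ReLU -> Linear" network where only the Linear layers are Bayesian
model = ToyModule(children=[ToyModule(kl=12.5), ToyModule(), ToyModule(kl=3.0)])
total_kl = sum_kl(model)  # 12.5 + 0.0 + 3.0

logits = np.array([[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]])
targets = np.array([0, 2])
loss = elbo_loss(logits, targets, total_kl, num_train_samples=1000)
print(f"Total KL: {total_kl}, loss: {loss:.4f}")
```

Dividing the KL by the dataset size keeps the regularizer on the same per-example scale as the averaged likelihood term, so the two can be added directly inside a minibatch loop.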

---

### 2. `total_entropy`, `conditional_entropy`, `mutual_information`

**What they do:**
These functions take several predicted class distributions per instance (for example, one from each ensemble member) and decompose the total predictive uncertainty into two parts.

**Why they’re useful:**
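They let you measure two sources of uncertainty separately:

- **Aleatoric uncertainty** (inherent randomness in the data), given by `conditional_entropy`
- **Epistemic uncertainty** (uncertainty due to limited model knowledge), given by `mutual_information`

`total_entropy` captures both at once and equals the sum of the other two.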
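For intuition about what these quantities measure, the sketch below computes them from scratch with NumPy using the standard information-theoretic decomposition: the total entropy is the entropy of the averaged prediction, H(mean_m p_m); the aleatoric part is the average entropy of the individual predictions, mean_m H(p_m); and their difference is the mutual information. The helper names `entropy` and `decompose_uncertainty` are hypothetical, and the sketch uses natural logarithms (nats). `probly`'s functions may use a different logarithm base, so absolute values can differ by a constant factor.

```python
import numpy as np


def entropy(p, axis=-1):
    # Shannon entropy in nats; zero-probability terms contribute 0 (log 1 = 0 below)
    safe_p = np.where(p > 0, p, 1.0)
    return -np.sum(p * np.log(safe_p), axis=axis)


def decompose_uncertainty(probs):
    """probs has shape (num_instances, num_samples, num_classes)."""
    total = entropy(probs.mean(axis=1))       # entropy of the averaged prediction
    aleatoric = entropy(probs).mean(axis=1)   # average entropy of each member's prediction
    epistemic = total - aleatoric             # mutual information (disagreement between members)
    return total, aleatoric, epistemic


probs = np.array([
    [[0.8, 0.1, 0.1], [0.75, 0.15, 0.1], [0.85, 0.1, 0.05]],  # members agree
    [[0.7, 0.2, 0.1], [0.2, 0.7, 0.1], [0.1, 0.2, 0.7]],      # members disagree
])
total, aleatoric, epistemic = decompose_uncertainty(probs)
print("epistemic:", epistemic.round(4))
```

When the members agree, averaging barely changes the prediction, so the total and aleatoric terms nearly coincide and the mutual information is close to zero; disagreement pushes the averaged prediction toward uniform and raises the epistemic part.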

---

### 3. `evidential_uncertainty` (for Evidential Models)

**What it does:**
Computes an uncertainty score directly from the **evidence vector** produced by an evidential model.

**Why it’s useful:**
Because it needs only a single forward pass (no weight sampling and no ensemble), it is a cheap way to flag predictions the model is uncertain about.
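A common definition of such a score comes from subjective logic as used in evidential deep learning: with evidence e_k >= 0 for each of K classes, the Dirichlet parameters are alpha_k = e_k + 1, the Dirichlet strength is S = sum_k alpha_k, and the uncertainty mass is u = K / S. It equals 1 when there is no evidence and shrinks toward 0 as evidence grows. The sketch below implements that formula under the assumption that `evidential_uncertainty` follows it; `dirichlet_uncertainty_mass` is a hypothetical name. If the function's inputs are already Dirichlet parameters (the example cell's comment refers to them as alpha values), the `+ 1.0` step would not apply.

```python
import numpy as np


def dirichlet_uncertainty_mass(evidence):
    """evidence: non-negative array of shape (num_instances, num_classes)."""
    alpha = evidence + 1.0            # Dirichlet parameters (assumes alpha = evidence + 1)
    strength = alpha.sum(axis=1)      # Dirichlet strength S
    num_classes = evidence.shape[1]   # K
    return num_classes / strength     # u = K / S, in (0, 1]


evidence = np.array([
    [100.0, 2.0, 3.0],  # strong evidence for class 0
    [1.0, 1.0, 1.0],    # weak evidence, spread evenly
    [0.0, 0.0, 0.0],    # no evidence at all
])
print(dirichlet_uncertainty_mass(evidence))
```

For these rows the scores are 3/108 (about 0.028), 3/6 = 0.5, and 3/3 = 1.0, so the score depends on the total amount of evidence rather than on which class receives it.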


In [None]:
# Example 1: collect_kl_divergence for Bayesian Neural Networks
import torch
from torch import nn

from probly.train.bayesian.torch import collect_kl_divergence
from probly.transformation import bayesian

# Build a standard network, then convert it into a Bayesian one
model = nn.Sequential(
    nn.Linear(10, 32),
    nn.ReLU(),
    nn.Linear(32, 3)
)
bnn_model = bayesian(model)

# Create dummy input
inputs = torch.randn(4, 10)

# Forward pass (this samples weights from distributions)
outputs = bnn_model(inputs)

# Collect KL divergence from all Bayesian layers
kl = collect_kl_divergence(bnn_model)

print(f"Total KL Divergence: {kl.item():.4f}")

In [None]:
# Example 2: Decomposing uncertainty with entropy functions
import numpy as np

from probly.quantification.classification import (
    conditional_entropy,
    mutual_information,
    total_entropy,
)

# Simulated predictions from a 3-member ensemble for 2 instances, 3 classes
# Shape: (num_instances, num_samples, num_classes)
ensemble_predictions = np.array([
    # Instance 1: High agreement (low epistemic uncertainty)
    [[0.8, 0.1, 0.1],
     [0.75, 0.15, 0.1],
     [0.85, 0.1, 0.05]],
    
    # Instance 2: High disagreement (high epistemic uncertainty)
    [[0.7, 0.2, 0.1],
     [0.2, 0.7, 0.1],
     [0.1, 0.2, 0.7]]
])

# Compute uncertainty metrics
total_ent = total_entropy(ensemble_predictions)
cond_ent = conditional_entropy(ensemble_predictions)  # Aleatoric uncertainty
mutual_info = mutual_information(ensemble_predictions)  # Epistemic uncertainty

print("Instance 1 (models agree):")
print(f"  Total Entropy: {total_ent[0]:.4f}")
print(f"  Aleatoric Uncertainty: {cond_ent[0]:.4f}")
print(f"  Epistemic Uncertainty: {mutual_info[0]:.4f}")

print("\nInstance 2 (models disagree):")
print(f"  Total Entropy: {total_ent[1]:.4f}")
print(f"  Aleatoric Uncertainty: {cond_ent[1]:.4f}")
print(f"  Epistemic Uncertainty: {mutual_info[1]:.4f}")

In [None]:
# Example 3: evidential_uncertainty for Evidential Models
import numpy as np

from probly.quantification.classification import evidential_uncertainty

# Simulated evidence vectors (alpha values) from an evidential model
# High evidence = confident, low evidence = uncertain

# Confident prediction: lots of evidence for class 0
confident_evidence = np.array([[100.0, 2.0, 3.0]])

# Uncertain prediction: little evidence for any class
uncertain_evidence = np.array([[1.0, 1.0, 1.0]])

# Compute uncertainty scores
confident_uncertainty = evidential_uncertainty(confident_evidence)
uncertain_uncertainty = evidential_uncertainty(uncertain_evidence)

print(f"Confident prediction evidence: {confident_evidence[0]}")
print(f"  Uncertainty score: {confident_uncertainty[0]:.4f}")

print(f"\nUncertain prediction evidence: {uncertain_evidence[0]}")
print(f"  Uncertainty score: {uncertain_uncertainty[0]:.4f}")