# Credal Sets Exploration Notebook

This notebook gives a very simple, beginner-friendly view of credal sets.
We start from standard probability distributions, which are vectors of nonnegative numbers that sum to 1.
An ensemble model can produce multiple probability distributions for the same input.
In this notebook, the **credal set** is the set of all these probability distributions produced by the ensemble.



## Minimal intuition

To summarize an ensemble of probability distributions, we can look at **lower** and **upper** probabilities per class.
For each class, the **lower** value is the minimum probability that any ensemble member gives to that class.
For each class, the **upper** value is the maximum probability that any ensemble member gives to that class.
Together, the lower and upper values form an interval that gives a simple range of belief for each class.
First, we create a small NumPy matrix `P` that holds the ensemble's predictions.


In [1]:
import numpy as np

In [2]:
P = np.array(
    [
        [0.6, 0.3, 0.1],
        [0.5, 0.4, 0.1],
        [0.7, 0.2, 0.1],
        [0.4, 0.4, 0.2],
        [0.55, 0.25, 0.20],
    ],
    dtype=float,
)

Next, we run basic sanity checks on `P`: no probability may be negative, and each row must sum to 1.

In [3]:
nonneg_msg = "Probabilities must be nonnegative."
row_sum_msg = "Each row must sum to 1.0."

if not np.all(P >= 0.0):
    raise ValueError(nonneg_msg)

if not np.allclose(P.sum(axis=1), 1.0):
    raise ValueError(row_sum_msg)

In [4]:
# Lower and upper probability envelopes per class
lower = P.min(axis=0)
upper = P.max(axis=0)

print("Ensemble probabilities P (rows = members, columns = classes):")
print(P)
print()
print("Lower envelope per class:")
print(lower)
print()
print("Upper envelope per class:")
print(upper)

Ensemble probabilities P (rows = members, columns = classes):
[[0.6  0.3  0.1 ]
 [0.5  0.4  0.1 ]
 [0.7  0.2  0.1 ]
 [0.4  0.4  0.2 ]
 [0.55 0.25 0.2 ]]

Lower envelope per class:
[0.4 0.2 0.1]

Upper envelope per class:
[0.7 0.4 0.2]
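Written as formulas, with \(P_{m,k}\) denoting the probability that member \(m\) (row of `P`) assigns to class \(k\) (column of `P`), the envelopes computed above are:

\[
\text{lower}_k = \min_{m} P_{m,k}, \qquad \text{upper}_k = \max_{m} P_{m,k}
\]

For class 0, the column of `P` is \((0.6, 0.5, 0.7, 0.4, 0.55)\), which matches the printed output:

\[
\text{lower}_0 = \min(0.6, 0.5, 0.7, 0.4, 0.55) = 0.4, \qquad \text{upper}_0 = \max(0.6, 0.5, 0.7, 0.4, 0.55) = 0.7
\]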


## Interpreting lower and upper envelopes

The **lower** value for a class is the most conservative belief across all ensemble members.
It says, "even the least confident member does not go below this probability for this class."
The **upper** value for a class is the most optimistic belief across all ensemble members.
It says, "at least one member goes up to this probability for this class."
If the interval between lower and upper is **wide**, the ensemble members disagree a lot, so uncertainty is high.
If the interval is **narrow**, the ensemble members are more in agreement, so uncertainty is lower.


In [5]:
# Simple summary table: lower, upper, width per class
interval_width = upper - lower

for class_idx, (lo, up, width) in enumerate(zip(lower, upper, interval_width, strict=True)):
    print(f"Class {class_idx}: lower={lo:.3f}, upper={up:.3f}, width={width:.3f}")

Class 0: lower=0.400, upper=0.700, width=0.300
Class 1: lower=0.200, upper=0.400, width=0.200
Class 2: lower=0.100, upper=0.200, width=0.100


## What is the credal set here?

Each row of `P` is one probability distribution over the 3 classes.
The whole matrix `P` collects several such distributions from different ensemble members.
The **credal set** in this example is simply the set of all these rows.
It is a small, discrete set of possible beliefs about the same input.
The lower and upper envelopes we computed are a basic summary of this credal set.
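One way to see why the envelopes are a sensible summary: any mixture (convex combination) of the rows of `P` is again a probability distribution, and each of its entries stays between the per-class minimum and maximum. Some treatments of credal sets work with this convex hull of the members rather than the rows alone. The sketch below draws random mixing weights from a Dirichlet distribution and checks the envelope property; the weighting scheme is an illustrative assumption, not something this notebook or Probly prescribes.

```python
import numpy as np

# Same ensemble matrix as above (rows = members, columns = classes).
P = np.array(
    [
        [0.6, 0.3, 0.1],
        [0.5, 0.4, 0.1],
        [0.7, 0.2, 0.1],
        [0.4, 0.4, 0.2],
        [0.55, 0.25, 0.20],
    ],
    dtype=float,
)
lower, upper = P.min(axis=0), P.max(axis=0)

rng = np.random.default_rng(0)
# Illustrative assumption: mixing weights drawn uniformly from the simplex.
weights = rng.dirichlet(np.ones(P.shape[0]), size=1000)  # shape (1000, 5)
mixtures = weights @ P  # each row is a convex combination of members, shape (1000, 3)

# Mixtures are valid distributions...
print("rows sum to 1:", np.allclose(mixtures.sum(axis=1), 1.0))
# ...and every entry stays inside the per-class envelope (tiny tolerance for rounding).
inside = np.all((mixtures >= lower - 1e-12) & (mixtures <= upper + 1e-12))
print("all mixtures inside envelopes:", bool(inside))
```

Because the envelopes of the rows and of all their mixtures coincide, `lower` and `upper` summarize both views of the credal set equally well.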


## Relation to Probly (high-level only)

In Probly, credal sets are organized in a small hierarchy, with classes like `CredalSet`, `DiscreteCredalSet`, and `CategoricalCredalSet`.
The idea is that a credal set is a set of probability distributions, with some structure and operations defined on it.
In the NumPy implementation, these distributions are stored in arrays with shapes that look roughly like `(..., num_members, num_classes)`.
This means we track, for each input, several members (distributions) over a fixed number of classes.
Our simple matrix `P` in this notebook mirrors that idea conceptually, but in a tiny, stand-alone form.
Here we only explain the concept and do not import or depend on Probly code.


## Basic Experimentation

Now we run a few simple experiments to build intuition about credal sets and envelope bounds.
All experiments use small ensembles (3 classes, a few to tens of members) and focus on how disagreement affects the lower/upper intervals.


In [6]:
rng = np.random.default_rng(42)


# Helper function to validate probability distributions
def validate_probs(p: np.ndarray) -> None:
    nonneg_msg = "Probabilities must be nonnegative."
    row_sum_msg = "Each row must sum to 1.0."
    if not np.all(p >= 0.0):
        raise ValueError(nonneg_msg)
    if not np.allclose(p.sum(axis=1), 1.0):
        raise ValueError(row_sum_msg)


# Helper function to compute envelopes
def envelope(p: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    lower = p.min(axis=0)
    upper = p.max(axis=0)
    return lower, upper

### Experiment 1: Low vs High disagreement

We compare two ensembles: one where members agree (low disagreement) and one where they disagree more (high disagreement).
We expect wider uncertainty intervals when disagreement is higher.


In [7]:
# Low disagreement: all members are very similar
P_low = np.array(
    [
        [0.6, 0.3, 0.1],
        [0.61, 0.29, 0.1],
        [0.59, 0.31, 0.1],
        [0.6, 0.3, 0.1],
        [0.6, 0.3, 0.1],
    ],
    dtype=float,
)
validate_probs(P_low)

# High disagreement: members differ more
P_high = np.array(
    [
        [0.8, 0.15, 0.05],
        [0.3, 0.5, 0.2],
        [0.5, 0.3, 0.2],
        [0.2, 0.6, 0.2],
        [0.4, 0.4, 0.2],
    ],
    dtype=float,
)
validate_probs(P_high)

# Compute envelopes and widths
lower_low, upper_low = envelope(P_low)
width_low = upper_low - lower_low
avg_width_low = width_low.mean()

lower_high, upper_high = envelope(P_high)
width_high = upper_high - lower_high
avg_width_high = width_high.mean()

print("Low disagreement ensemble:")
print(f"  Average interval width: {avg_width_low:.4f}")
print(f"  Lower: {lower_low}")
print(f"  Upper: {upper_low}")
print()
print("High disagreement ensemble:")
print(f"  Average interval width: {avg_width_high:.4f}")
print(f"  Lower: {lower_high}")
print(f"  Upper: {upper_high}")
print()
print(f"P_high has wider bounds: {avg_width_high > avg_width_low}")

Low disagreement ensemble:
  Average interval width: 0.0133
  Lower: [0.59 0.29 0.1 ]
  Upper: [0.61 0.31 0.1 ]

High disagreement ensemble:
  Average interval width: 0.4000
  Lower: [0.2  0.15 0.05]
  Upper: [0.8 0.6 0.2]

P_high has wider bounds: True


As expected, the high-disagreement ensemble has much wider intervals (average width 0.4000 versus 0.0133).
When ensemble members disagree more, the gap between lower and upper bounds grows, reflecting higher uncertainty.


### Experiment 2: Effect of number of ensemble members

We generate ensembles with different numbers of members, all centered around the same distribution but with small random variations.
This shows how the number of members affects the bounds we observe.


In [8]:
# Center distribution
c = np.array([0.6, 0.3, 0.1])

# Generate ensembles with different member counts
member_counts = [3, 5, 20]
scale = 0.05  # Noise scale

results = []
for n in member_counts:
    # Generate noise and add to center
    noise = rng.normal(0, scale, size=(n, 3))
    P_n = c + noise
    # Clip to nonnegative and renormalize
    P_n = np.clip(P_n, 0, None)
    P_n = P_n / P_n.sum(axis=1, keepdims=True)
    validate_probs(P_n)

    # Compute envelopes
    lower_n, upper_n = envelope(P_n)
    width_n = upper_n - lower_n
    avg_width_n = width_n.mean()

    results.append((n, avg_width_n, lower_n, upper_n))
    print(f"n={n:2d} members: average interval width = {avg_width_n:.4f}")

print()
print("Summary: More members can reveal more extreme values, affecting bounds.")

n= 3 members: average interval width = 0.0917
n= 5 members: average interval width = 0.0772
n=20 members: average interval width = 0.1119

Summary: More members can reveal more extreme values, affecting bounds.


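The bounds depend on both the variability of the distributions and the sample size (number of members).
With more members, we are more likely to observe extreme values, which tends to widen the lower/upper bounds. Each ensemble here is a single random draw, though, so one run need not follow this trend exactly: in the output above, `n=5` happens to give narrower bounds than `n=3`.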
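To separate the trend from single-draw noise, a minimal sketch can repeat the experiment many times and average the interval width for each member count. It reuses the center `c` and noise `scale` from Experiment 2; the number of trials (500) and the seed are arbitrary choices for illustration.

```python
import numpy as np

c = np.array([0.6, 0.3, 0.1])  # center distribution, as in Experiment 2
scale = 0.05                   # noise scale, as in Experiment 2
num_trials = 500               # illustrative choice
rng = np.random.default_rng(0)


def avg_width_for(n: int) -> float:
    """Average envelope width over many random ensembles with n members."""
    widths = []
    for _ in range(num_trials):
        P_n = np.clip(c + rng.normal(0, scale, size=(n, 3)), 0, None)
        P_n = P_n / P_n.sum(axis=1, keepdims=True)  # renormalize rows to sum to 1
        widths.append((P_n.max(axis=0) - P_n.min(axis=0)).mean())
    return float(np.mean(widths))


mean_widths = {n: avg_width_for(n) for n in [3, 5, 20]}
for n, w in mean_widths.items():
    print(f"n={n:2d} members: mean average width over {num_trials} trials = {w:.4f}")
```

Averaged over many trials, the width grows with the number of members, because the expected range of a sample increases with its size.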


### Experiment 3: Averaging loses information

A common approach is to average ensemble predictions into a single distribution.
However, this hides the disagreement between members.
Because the average is a convex combination of the members, it always lies within the bounds, but on its own it says nothing about how wide the uncertainty range is.


In [None]:
p_mean = P_high.mean(axis=0)


validate_probs(p_mean.reshape(1, -1))


lower_high, upper_high = envelope(P_high)

print("Averaged distribution (p_mean):")
print(p_mean)
print()
print("Credal set bounds (lower, upper):")
for class_idx in range(3):
    lo = lower_high[class_idx]
    up = upper_high[class_idx]
    pm = p_mean[class_idx]
    inside = lo <= pm <= up
    print(f"  Class {class_idx}: [{lo:.3f}, {up:.3f}]")
    print(f"    p_mean[{class_idx}] = {pm:.3f} (inside bounds: {inside})")
print()
print("The average lies within the bounds but does not show the width of uncertainty.")

Averaged distribution (p_mean):
[0.44 0.39 0.17]

Credal set bounds (lower, upper):
  Class 0: [0.200, 0.800]
    p_mean[0] = 0.440 (inside bounds: True)
  Class 1: [0.150, 0.600]
    p_mean[1] = 0.390 (inside bounds: True)
  Class 2: [0.050, 0.200]
    p_mean[2] = 0.170 (inside bounds: True)

The average lies within the bounds but does not show the width of uncertainty.
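To make the information loss concrete, here is a constructed example (not one of the experiments above): an ensemble in which every member equals `p_mean` has exactly the same average as `P_high`, yet its envelopes have zero width. Looking only at the average, the two ensembles are indistinguishable.

```python
import numpy as np

# High-disagreement ensemble from Experiment 1.
P_high = np.array(
    [
        [0.8, 0.15, 0.05],
        [0.3, 0.5, 0.2],
        [0.5, 0.3, 0.2],
        [0.2, 0.6, 0.2],
        [0.4, 0.4, 0.2],
    ],
    dtype=float,
)
p_mean = P_high.mean(axis=0)

# Constructed ensemble: five identical members, each equal to the average of P_high.
P_agree = np.tile(p_mean, (5, 1))

for name, ens in [("P_high", P_high), ("P_agree", P_agree)]:
    width = ens.max(axis=0) - ens.min(axis=0)
    print(f"{name}: mean = {ens.mean(axis=0)}, average width = {width.mean():.4f}")
```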


## ProbabilityIntervals (bounds-based credal set)

A **ProbabilityIntervals** credal set describes our uncertainty using **per-class lower and upper probability bounds** instead of listing all individual distributions.
This representation was discussed in the sprint as a simple next step: we store only the intervals, not the full set of distributions.

For a vector of lower bounds `lower` and upper bounds `upper` (one entry per class), we require:
- **Validity of each bound**: \(0 \leq \text{lower}_i \leq \text{upper}_i \leq 1\)
- **Global mass constraints**: \(\sum_i \text{lower}_i \leq 1 \leq \sum_i \text{upper}_i\)

These conditions ensure that there exists at least one probability distribution that respects all the intervals at the same time.
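The existence claim can be checked constructively. One simple way, sketched below as a hypothetical helper `find_distribution_in_intervals` (not a Probly function), starts every class at its lower bound and then hands out the missing mass \(1 - \sum_i \text{lower}_i\) class by class, never exceeding an upper bound. Since \(\sum_i \text{upper}_i \geq 1\), the total slack \(\sum_i (\text{upper}_i - \text{lower}_i)\) is at least the missing mass, so the loop always places all of it.

```python
import numpy as np


def find_distribution_in_intervals(lower: np.ndarray, upper: np.ndarray) -> np.ndarray:
    """Hypothetical helper: greedily build one p with lower <= p <= upper and sum(p) == 1.

    Assumes the intervals already satisfy 0 <= lower <= upper <= 1 and
    sum(lower) <= 1 <= sum(upper).
    """
    lower = np.asarray(lower, dtype=float)
    upper = np.asarray(upper, dtype=float)
    p = lower.copy()                       # start with every class at its lower bound
    remaining = max(1.0 - p.sum(), 0.0)    # mass still to distribute
    for i in range(p.shape[0]):
        add = min(upper[i] - lower[i], remaining)  # fill at most up to the upper bound
        p[i] += add
        remaining -= add
    return p


lower_bounds = np.array([0.1, 0.2, 0.0], dtype=float)
upper_bounds = np.array([0.6, 0.7, 0.4], dtype=float)
p = find_distribution_in_intervals(lower_bounds, upper_bounds)
print("p =", p, "sum =", p.sum())
```

The result depends on the class order, so it is one valid distribution inside the intervals, not a canonical one.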

In [10]:
# Example ProbabilityIntervals for 3 classes
lower_bounds = np.array([0.1, 0.2, 0.0], dtype=float)
upper_bounds = np.array([0.6, 0.7, 0.4], dtype=float)


def validate_probability_intervals(lower: np.ndarray, upper: np.ndarray) -> None:
    """Validate simple bounds-based credal set conditions.

    Conditions:
    - Shapes match and are 1D.
    - 0 <= lower <= upper <= 1 elementwise.
    - sum(lower) <= 1 <= sum(upper).
    """
    lower = np.asarray(lower, dtype=float)
    upper = np.asarray(upper, dtype=float)

    same_shape_msg = "Lower and upper must have the same shape."
    one_dim_msg = "Lower and upper must be 1D arrays (one entry per class)."
    bounds_range_msg = "Bounds must satisfy 0 <= lower and upper <= 1 for all classes."
    order_msg = "Each lower bound must be <= the corresponding upper bound."
    sum_lower_msg = "Sum of lower bounds must be <= 1."
    sum_upper_msg = "Sum of upper bounds must be >= 1."

    if lower.shape != upper.shape:
        raise ValueError(same_shape_msg)
    if lower.ndim != 1:
        raise ValueError(one_dim_msg)

    if np.any(lower < 0.0) or np.any(upper > 1.0):
        raise ValueError(bounds_range_msg)
    if np.any(lower > upper):
        raise ValueError(order_msg)

    sum_lower = float(lower.sum())
    sum_upper = float(upper.sum())
    if sum_lower > 1.0 + 1e-8:
        raise ValueError(sum_lower_msg)
    if sum_upper < 1.0 - 1e-8:
        raise ValueError(sum_upper_msg)


# Demonstrate one valid and one invalid example
print("Valid intervals example:")
validate_probability_intervals(lower_bounds, upper_bounds)
print("  OK: intervals are valid.\n")

# Invalid example: sum of lower bounds is too large
lower_bad = np.array([0.5, 0.6, 0.2], dtype=float)  # sums to 1.3 (> 1)
upper_bad = np.array([0.7, 0.8, 0.9], dtype=float)

print("Invalid intervals example:")
try:
    validate_probability_intervals(lower_bad, upper_bad)
except ValueError as e:
    print("  Error:", e)

Valid intervals example:
  OK: intervals are valid.

Invalid intervals example:
  Error: Sum of lower bounds must be <= 1.


## Membership check (is a distribution inside the credal set?)

Once we have interval bounds `lower` and `upper`, we can ask whether a candidate probability vector `p` is **compatible** with them.
In this simple view, membership means that `p` is a valid probability distribution and, for every class, it stays inside the interval:

- `p[i]` must satisfy \(\text{lower}_i \leq p[i] \leq \text{upper}_i\) for all classes `i`.

This is the basic "is-in" check discussed in the sprint: a lightweight way to test whether a single distribution lies inside the bounds-based credal set.

In [11]:
def validate_prob_vector(p: np.ndarray) -> None:
    """Validate that p is a single probability vector.

    Conditions:
    - 1D array.
    - Nonnegative entries.
    - Sum equals 1 (within numerical tolerance).
    """
    p = np.asarray(p, dtype=float)

    one_dim_msg = "p must be a 1D probability vector."
    nonneg_msg = "Probabilities must be nonnegative."
    sum_msg = "Probabilities must sum to 1."

    if p.ndim != 1:
        raise ValueError(one_dim_msg)
    if np.any(p < 0.0):
        raise ValueError(nonneg_msg)
    if not np.isclose(p.sum(), 1.0):
        raise ValueError(sum_msg)


def is_in_intervals(p: np.ndarray, lower: np.ndarray, upper: np.ndarray) -> bool:
    """Check if p lies inside the probability intervals credal set.

    This reuses the interval validation and then checks elementwise
    lower <= p <= upper.
    """
    validate_probability_intervals(lower, upper)
    validate_prob_vector(p)

    p = np.asarray(p, dtype=float)
    lower = np.asarray(lower, dtype=float)
    upper = np.asarray(upper, dtype=float)

    return bool(np.all(p >= lower - 1e-8) and np.all(p <= upper + 1e-8))


# Example distributions
p_inside = np.array([0.3, 0.5, 0.2], dtype=float)  # fits inside the example intervals
p_outside = np.array([0.7, 0.2, 0.1], dtype=float)  # first entry exceeds upper bound

print("Membership checks:")
print("  p_inside:", p_inside, "->", is_in_intervals(p_inside, lower_bounds, upper_bounds))
print("  p_outside:", p_outside, "->", is_in_intervals(p_outside, lower_bounds, upper_bounds))

Membership checks:
  p_inside: [0.3 0.5 0.2] -> True
  p_outside: [0.7 0.2 0.1] -> False
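Interval bounds derived from an ensemble's envelopes can admit distributions that no mixture of members produces, so the intervals are an outer summary of the discrete credal set. For the matrix `P` from the start of the notebook, the candidate `[0.6, 0.2, 0.2]` fits every per-class interval. However, a mixture of rows reaches class-2 probability 0.2 (the upper envelope) only if all its weight sits on rows whose class-2 entry is 0.2, and those rows give class 0 at most 0.55. The check below illustrates this one hand-picked case; it is not a general convex-hull membership test.

```python
import numpy as np

P = np.array(
    [
        [0.6, 0.3, 0.1],
        [0.5, 0.4, 0.1],
        [0.7, 0.2, 0.1],
        [0.4, 0.4, 0.2],
        [0.55, 0.25, 0.20],
    ],
    dtype=float,
)
lower, upper = P.min(axis=0), P.max(axis=0)  # per-class envelopes

candidate = np.array([0.6, 0.2, 0.2])

# 1) The candidate is a valid distribution that respects every per-class interval.
in_box = bool(
    np.isclose(candidate.sum(), 1.0) and np.all(candidate >= lower) and np.all(candidate <= upper)
)

# 2) Only rows at the class-2 upper envelope can form a mixture with class-2 mass 0.2;
#    the best class-0 probability among those rows bounds what such a mixture can reach.
rows_at_max = P[np.isclose(P[:, 2], upper[2])]
best_class0 = rows_at_max[:, 0].max()

print("inside intervals:", in_box)
print("max class-0 prob with class-2 prob 0.2:", best_class0)
print("reachable by mixing members:", bool(candidate[0] <= best_class0))
```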


## Sampling for intuition

Before implementing more advanced operations on credal sets, it helps to simply **sample** from them.
Sampling gives a concrete feel for "what is inside" the credal set without requiring full analytical formulas.

In this simple notebook, we:
- Sample a random ensemble member (a random row of `P`) from a **discrete credal set**.
- Sample from **probability intervals** using basic rejection sampling, relying only on NumPy.

In [12]:
def sample_from_discrete_credal_set(p_members: np.ndarray, rng: np.random.Generator) -> np.ndarray:
    """Return a random row of P (one ensemble member)."""
    p_members = np.asarray(p_members, dtype=float)
    two_d_msg = "P must be a 2D array: (num_members, num_classes)."
    if p_members.ndim != 2:
        raise ValueError(two_d_msg)
    num_members = p_members.shape[0]
    idx = rng.integers(0, num_members)
    return p_members[idx]


def sample_from_probability_intervals(
    lower: np.ndarray, upper: np.ndarray, rng: np.random.Generator, max_tries: int = 10000
) -> np.ndarray:
    """Sample a probability vector from the intervals via simple rejection.

    The procedure is:
    1. Validate the intervals.
    2. Repeatedly draw a random probability vector from a Dirichlet(1, ..., 1)
       distribution, which is uniform over the simplex, using NumPy only.
    3. Accept the sample if all coordinates lie within [lower, upper].

    This is not optimized but is simple and good for building intuition.
    """
    validate_probability_intervals(lower, upper)
    lower = np.asarray(lower, dtype=float)
    upper = np.asarray(upper, dtype=float)
    k = lower.shape[0]

    for _ in range(max_tries):
        candidate = rng.dirichlet(np.ones(k, dtype=float))
        if np.all(candidate >= lower - 1e-8) and np.all(candidate <= upper + 1e-8):
            return candidate

    error_msg = "Failed to find a sample inside the intervals. Try loosening bounds or increasing max_tries."
    raise RuntimeError(error_msg)


# Demonstrate sampling (rng defined in Basic Experimentation section)
sampled_discrete = sample_from_discrete_credal_set(P, rng)
sampled_interval = sample_from_probability_intervals(lower_bounds, upper_bounds, rng)

validate_prob_vector(sampled_discrete)
validate_prob_vector(sampled_interval)

print("Sample from discrete credal set (random ensemble member):")
print("  ", sampled_discrete)
print()
print("Sample from ProbabilityIntervals credal set:")
print("  ", sampled_interval)

Sample from discrete credal set (random ensemble member):
   [0.4 0.4 0.2]

Sample from ProbabilityIntervals credal set:
   [0.4080684  0.22291303 0.36901857]
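The docstring of `sample_from_probability_intervals` notes that rejection sampling is not optimized. One way to see the cost is to estimate the acceptance rate, i.e. the fraction of uniform Dirichlet draws that land inside the intervals. In this sketch the "tight" bounds are made up for illustration (they satisfy the validity conditions and sit inside the example `lower_bounds`/`upper_bounds`), so with the same draws their acceptance rate can only be lower.

```python
import numpy as np

rng = np.random.default_rng(0)
num_draws = 20_000
# Uniform draws over the 3-class simplex, shape (num_draws, 3).
draws = rng.dirichlet(np.ones(3), size=num_draws)


def acceptance_rate(lower: np.ndarray, upper: np.ndarray) -> float:
    """Fraction of the shared draws that satisfy lower <= p <= upper elementwise."""
    accepted = np.all((draws >= lower) & (draws <= upper), axis=1)
    return float(accepted.mean())


loose = (np.array([0.1, 0.2, 0.0]), np.array([0.6, 0.7, 0.4]))  # example bounds from above
tight = (np.array([0.3, 0.3, 0.1]), np.array([0.5, 0.5, 0.3]))  # illustrative, narrower bounds

rate_loose = acceptance_rate(*loose)
rate_tight = acceptance_rate(*tight)
print(f"loose: acceptance ~ {rate_loose:.3f} (about {1 / rate_loose:.0f} draws per sample)")
print(f"tight: acceptance ~ {rate_tight:.3f} (about {1 / rate_tight:.0f} draws per sample)")
```

The narrower the intervals, the more draws are rejected, which is why `max_tries` matters for tight bounds.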


## Shape and axis conventions

In the Probly codebase, discrete credal sets are represented using NumPy arrays with shapes that look like `(..., num_members, num_classes)`.
- The **last axis** (size = `num_classes`) holds the class probabilities.
- The axis just before that (size = `num_members`) indexes the different ensemble members (or distributions) for the same input.

The lower and upper envelopes are then computed by reducing (taking `min`/`max`) over the **member axis**, which Probly addresses as `axis=-2`.
Our small matrix `P` is the simplest case of this layout, with no leading axes: `P.shape == (num_members, num_classes)`, so `axis=0` and `axis=-2` refer to the same axis.

In [13]:
print("Shape of P (members x classes):", P.shape)
print("Shape of lower (per-class):", lower.shape)
print("Shape of upper (per-class):", upper.shape)
print()
print("In this notebook, axis 0 of P indexes ensemble members,")
print("and axis 1 indexes classes. Lower/upper are 1D over the class axis.")

Shape of P (members x classes): (5, 3)
Shape of lower (per-class): (3,)
Shape of upper (per-class): (3,)

In this notebook, axis 0 of P indexes ensemble members,
and axis 1 indexes classes. Lower/upper are 1D over the class axis.


In [None]:
P3 = P[None, :, :]
lower3 = P3.min(axis=-2)
upper3 = P3.max(axis=-2)

print("P3 = P[None, :, :] shape:", P3.shape)
print("lower3 = P3.min(axis=-2) shape:", lower3.shape)
print("upper3 = P3.max(axis=-2) shape:", upper3.shape)
print()
print("lower3 (same as lower):", lower3)
print("upper3 (same as upper):", upper3)

P3 = P[None, :, :] shape: (1, 5, 3)
lower3 = P3.min(axis=-2) shape: (1, 3)
upper3 = P3.max(axis=-2) shape: (1, 3)

lower3 (same as lower): [[0.4 0.2 0.1]]
upper3 (same as upper): [[0.7 0.4 0.2]]
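The leading `...` axes matter once several inputs are stacked. A minimal sketch, using a made-up ensemble `P_other` as a hypothetical second input next to `P`, reduces over `axis=-2` and gets one pair of envelopes per input:

```python
import numpy as np

P = np.array(
    [
        [0.6, 0.3, 0.1],
        [0.5, 0.4, 0.1],
        [0.7, 0.2, 0.1],
        [0.4, 0.4, 0.2],
        [0.55, 0.25, 0.20],
    ],
    dtype=float,
)
# Hypothetical second input: a different 5-member ensemble over the same 3 classes.
P_other = np.array(
    [
        [0.2, 0.5, 0.3],
        [0.1, 0.6, 0.3],
        [0.3, 0.4, 0.3],
        [0.2, 0.7, 0.1],
        [0.25, 0.5, 0.25],
    ],
    dtype=float,
)

batch = np.stack([P, P_other])  # shape (num_inputs, num_members, num_classes) = (2, 5, 3)
lower_b = batch.min(axis=-2)    # shape (2, 3): one lower envelope per input
upper_b = batch.max(axis=-2)    # shape (2, 3): one upper envelope per input

print("batch shape:", batch.shape)
print("lower envelopes:\n", lower_b)
print("upper envelopes:\n", upper_b)
```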
