An ensemble is a collection of multiple models, here all sharing the same architecture.
Each model makes predictions independently, and their outputs are typically aggregated, for example by averaging or voting, to produce a more robust final prediction.

model 1 ─┐
model 2 ─┼──> combined → more stable predictions
model 3 ─┘
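The two aggregation rules mentioned above, averaging and voting, can be sketched in JAX as follows. This is a minimal illustration and not part of the ensemble transformation itself. The function names `average_predictions` and `majority_vote` are hypothetical, and the member predictions are assumed to be stacked along a leading axis.

```python
import jax
import jax.numpy as jnp


def average_predictions(preds: jnp.ndarray) -> jnp.ndarray:
    """Average member outputs; preds has shape (n_members, ...)."""
    return jnp.mean(preds, axis=0)


def majority_vote(labels: jnp.ndarray, num_classes: int) -> jnp.ndarray:
    """Majority vote over integer class labels of shape (n_members, batch)."""
    # One-hot encode each member's vote: (n_members, batch, num_classes)
    votes = jax.nn.one_hot(labels, num_classes)
    # Count votes per class, then pick the most frequent class per input
    # (ties resolve to the lowest class index, as jnp.argmax does)
    return jnp.argmax(votes.sum(axis=0), axis=-1)


# Three members, each predicting a real value for a batch of two inputs
preds = jnp.array([[1.0, 2.0], [1.2, 1.8], [0.8, 2.2]])
mean_pred = average_predictions(preds)

# Three members, each predicting a class label (out of 3) for two inputs
labels = jnp.array([[0, 1], [0, 2], [1, 2]])
voted = majority_vote(labels, num_classes=3)
```

Averaging suits regression outputs or class probabilities, while voting applies when each member emits a hard label.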

This transformation duplicates a Flax model and combines the copies into an ensemble.
Given a Flax model such as:

from flax.nnx import Module, Param
import jax.numpy as jnp


class SimpleModel(Module):
    def __init__(self) -> None:
        """Initialize the model parameters."""
        self.w = Param(jnp.array([1.0, 2.0, 3.0]))
        self.b = Param(jnp.array(0.5))

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        """Compute the linear transformation w * x + b."""
        return self.w * x + self.b

we can create an ensemble containing three copies of it by calling the transformation as follows:

from probly.transformation.ensemble.flax import generate_flax_ensemble

model = SimpleModel()

# Build an ensemble consisting of three copies of `model`
ensemble = generate_flax_ensemble(model, 3)