## Introduction to Ensembles in Flax


An **ensemble** is a collection of multiple models that share the same architecture but have different parameter initializations.
Each model makes its predictions independently, and the results are then usually aggregated, most commonly by averaging.
Instead of a single neural network with one fixed set of parameters, the ensemble therefore contains several independent copies of the same model, each with its own freshly initialized parameters.
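As a minimal sketch of the aggregation step, assume each member's predictions have already been stacked into one array with the member index on the first axis. The hypothetical helper `aggregate_predictions` then reduces over that axis to obtain the ensemble mean and the variance across members, a common (but not the only) way to summarize how much the members disagree.

```python
import jax.numpy as jnp


def aggregate_predictions(member_preds: jnp.ndarray) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Reduce stacked member predictions of shape (n_members, ...) to mean and variance."""
    mean = jnp.mean(member_preds, axis=0)  # ensemble prediction
    var = jnp.var(member_preds, axis=0)    # spread (disagreement) across members
    return mean, var


# Three members, each predicting 2 outputs for a single input (made-up numbers)
preds = jnp.array([[1.0, 2.0],
                   [3.0, 2.0],
                   [2.0, 2.0]])
mean, var = aggregate_predictions(preds)
```

The mean serves as the ensemble's point prediction, while the variance is larger wherever the members disagree.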


```text
      ┌→ [Model 1] →┐
Input ┼→ [Model 2] →┼→ Mean/Var → Output
      └→ [Model N] →┘
```


**Flax** is a neural network library for JAX that provides flexible and composable tools to build and train deep learning models.

The **`param_reset`** traverser searches the model for every **`Param`** object and replaces its value with newly sampled, normally distributed values drawn with a fresh RNG key, so that each ensemble copy has completely independent parameters.

The code below creates this traverser with `lazydispatch_traverser` and registers a handler that applies only to `Param` objects. For each parameter the traverser finds, the handler draws a new random key, replaces the value with normally distributed random numbers of the same shape, and returns the updated parameter. The traverser is defined for Flax NNX models as follows:

```python
from __future__ import annotations

from typing import List

from jax import random
from flax.nnx import Module, Param, Rngs
from pytraverse import traverse, CLONE, lazydispatch_traverser
from probly.traverse_nn import nn_compose

# Example: traverser that resets every Param it encounters
param_reset = lazydispatch_traverser[object]("param_reset")


@param_reset.register
def _(p: Param, rngs: Rngs) -> Param:
    # Draw a fresh key from the "params" stream of the RNG container
    key = rngs.params()
    # Overwrite the value with standard-normal samples of the same shape
    p.value = random.normal(key, p.value.shape)
    return p
```

The **`_reset_copy`** function first creates a new RNG container (`Rngs`) seeded with the given seed; this container supplies the keys used for reinitialization. It then uses `traverse` to create a deep copy of the model while applying the `param_reset` traverser to every parameter, so that each copy receives new, randomly initialized weights.

The **`generate_flax_ensemble`** function creates an ensemble by copying the given model multiple times, reinitializing each copy with `_reset_copy` and a unique seed. The result is a list of independent model instances with different parameter values. Calling `register(Module, generate_flax_ensemble)` registers this function with probly for Flax NNX `Module` objects, so that it is used automatically for Flax modules.

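```python
# Ensemble generation (`_reset_copy` and `register` are assumed to be provided
# by the surrounding probly module; they are not defined in this notebook)
def generate_flax_ensemble(model: Module, n_members: int) -> List[Module]:
    # One reset copy per member; seed i gives each copy its own parameters
    return [_reset_copy(model, seed=i) for i in range(n_members)]

register(Module, generate_flax_ensemble)
```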

## Different Network Architectures

The ensemble mechanism is model-agnostic: it can be applied to any network architecture implemented as a Flax NNX module.  
Below, we show two simple examples:
1. A **Linear (MLP)** model for low-dimensional data  
2. A **Convolutional Neural Network (CNN)** for image data  

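```python
from typing import Optional

from flax import nnx


# --- 1. Simple MLP ---
class MLP(nnx.Module):
    # NNX layers need the input size up front; in_dim=2 is an illustrative default.
    # rngs defaults to seed 0 so the class can be built without arguments;
    # ensemble members receive new parameters anyway.
    def __init__(self, in_dim: int = 2, hidden_dim: int = 64, output_dim: int = 10,
                 *, rngs: Optional[nnx.Rngs] = None):
        rngs = rngs if rngs is not None else nnx.Rngs(0)
        self.hidden = nnx.Linear(in_dim, hidden_dim, rngs=rngs)
        self.out = nnx.Linear(hidden_dim, output_dim, rngs=rngs)

    def __call__(self, x):
        x = nnx.relu(self.hidden(x))
        return self.out(x)


# --- 2. Simple CNN (expects NHWC inputs) ---
class SmallCNN(nnx.Module):
    # in_channels=1 (e.g. grayscale images) is an illustrative default
    def __init__(self, in_channels: int = 1, num_classes: int = 10,
                 *, rngs: Optional[nnx.Rngs] = None):
        rngs = rngs if rngs is not None else nnx.Rngs(0)
        self.conv1 = nnx.Conv(in_channels, 32, kernel_size=(3, 3), rngs=rngs)
        self.conv2 = nnx.Conv(32, 64, kernel_size=(3, 3), rngs=rngs)
        self.head = nnx.Linear(64, num_classes, rngs=rngs)

    def __call__(self, x):
        x = nnx.relu(self.conv1(x))
        x = nnx.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = nnx.relu(self.conv2(x))
        x = x.mean(axis=(1, 2))  # global average pooling
        return self.head(x)
```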


Each of these architectures can be combined with the `generate_flax_ensemble()` function
to create ensembles of independent models.

For example:

```python
# Example: create an ensemble of CNNs
cnn_model = SmallCNN(num_classes=10)
ensemble_cnn = generate_flax_ensemble(cnn_model, n_members=3)

# Example: create an ensemble of MLPs
mlp_model = MLP(hidden_dim=64, output_dim=3)
ensemble_mlp = generate_flax_ensemble(mlp_model, n_members=5)
```

## Example Usage: A Small Flax Model

```python
from flax import nnx


def flax_model_small_2d_2d(flax_rngs: nnx.Rngs) -> nnx.Module:
    """Return a small linear model with 2 input and 2 output neurons."""
    # Three stacked 2 -> 2 linear layers without activations
    model = nnx.Sequential(
        nnx.Linear(2, 2, rngs=flax_rngs),
        nnx.Linear(2, 2, rngs=flax_rngs),
        nnx.Linear(2, 2, rngs=flax_rngs),
    )
    return model
```

Creating an ensemble of 3 copies of this model:

```python
model = flax_model_small_2d_2d(Rngs(0))  # base model; Rngs was imported from flax.nnx above
ensemble = generate_flax_ensemble(model, n_members=3)
```