# A Brief Introduction to Ensemble Transformation

The goal of this notebook is to showcase the `ensemble` function, which is based on the provided code files (`common.py`, `torch.py`, `__init__.py`). We will demonstrate how to use it to create an ensemble of models from a base PyTorch model.

#### But what is an Ensemble?
An *ensemble* in machine learning is a method that uses a finite set of models. Instead of relying on a single model, it combines the predictions of several models to obtain a better result.
The member models usually differ slightly in their parameters.
This often improves robustness and helps quantify **uncertainty**: when the ensemble members disagree, we know the ensemble is unsure about that input.  
In `probly`, the `ensemble` transformation automates the creation of such model collections directly from a base PyTorch model. <br>
**Why use Ensemble?** <br>
Because it saves you from having to manually copy, reset, and manage multiple model instances yourself. `ensemble` does all of that for you automatically.
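
To make the idea of disagreement concrete, here is a minimal sketch in plain Python. The probabilities are made-up numbers chosen for illustration, not outputs of a real model: each row holds one hypothetical member's predicted probability of the positive class for two inputs.

```python
from statistics import mean, pstdev

# Hypothetical positive-class probabilities: one row per ensemble member,
# one column per input. These numbers are illustrative only.
member_probs = [
    [0.91, 0.35],
    [0.89, 0.70],
    [0.93, 0.52],
]

# Regroup the values so that each tuple holds all member predictions for one input.
per_input = list(zip(*member_probs))

# The ensemble prediction is the average over members.
ensemble_mean = [mean(p) for p in per_input]

# The spread across members is a simple measure of disagreement.
disagreement = [pstdev(p) for p in per_input]

for i, (m, d) in enumerate(zip(ensemble_mean, disagreement)):
    print(f"input {i}: mean={m:.3f}, std={d:.3f}")
```

The members agree closely on the first input and spread widely on the second, so the second prediction should be trusted less.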

In [15]:
from collections.abc import Callable
from typing import TYPE_CHECKING

import torch
from torch import nn

from lazy_dispatch import lazydispatch
from lazy_dispatch.isinstance import LazyType
from probly.transformation.ensemble import ensemble
from pytraverse import singledispatch_traverser, traverse

if TYPE_CHECKING:
    from collections.abc import Callable

    from lazy_dispatch.isinstance import LazyType
    from probly.predictor import Predictor

In [None]:
from __future__ import annotations  # noqa: F404


# Our base model
class SimpleNet(nn.Module):
    def __init__(self) -> None:
        """Initialize the network architecture with two linear layers."""
        super().__init__()
        self.layer1 = nn.Linear(10, 20)
        self.layer2 = nn.Linear(20, 2)

    def forward(self, x):  # noqa: ANN201, ANN001
        """Run a forward pass of the network."""
        x = torch.relu(self.layer1(x))
        return self.layer2(x)

    def reset_parameters(self) -> None:
        """Reset the parameters of all layers."""
        # Important for Deep Ensembles: a custom reset function
        print("Custom reset_parameters of SimpleNet called!")
        self.layer1.reset_parameters()
        self.layer2.reset_parameters()


base_model = SimpleNet()
print(f"Base model weight (layer1): {base_model.layer1.weight.data[0, 0]}")
print("--- Creating ensemble with reset (default) ---")
base_model_2 = SimpleNet()
base_weight = base_model_2.layer1.weight.data[0, 0].item()
print(f"Base model weight: {base_weight}")

# ensemble is the function from common.py
# reset_params=True is the default
model_ensemble = ensemble(base_model_2, num_members=3, reset_params=True)

print(f"\nEnsemble created: {type(model_ensemble)}")
print(f"Number of members: {len(model_ensemble)}")

# Let's compare the weights to see if they are different
weight0 = model_ensemble[0].layer1.weight.data[0, 0].item()
weight1 = model_ensemble[1].layer1.weight.data[0, 0].item()
weight2 = model_ensemble[2].layer1.weight.data[0, 0].item()

print(f"Weight of member 0: {weight0}")
print(f"Weight of member 1: {weight1}")
print(f"Weight of member 2: {weight2}")

# Compare the base weight and all three member weights, not just the first two
if len({base_weight, weight0, weight1, weight2}) == 4:
    print("All weights are different (as expected).")

Base model weight (layer1): 0.053378403186798096
--- Creating ensemble with reset (default) ---
Base model weight: 0.18488982319831848
Custom reset_parameters of SimpleNet called!
Custom reset_parameters of SimpleNet called!
Custom reset_parameters of SimpleNet called!

Ensemble created: <class 'torch.nn.modules.container.ModuleList'>
Number of members: 3
Weight of member 0: -0.11245134472846985
Weight of member 1: -0.2733441889286041
Weight of member 2: 0.1302759349346161
All weights are different (as expected).


## 1. Setup: Dependencies and Code Definitions

Before we can use the `ensemble` function, we'll define it and its components as described in the code files (`common.py`, `torch.py`, `__init__.py`). This ensures this notebook is self-contained.

We will copy the contents of the provided files here and adjust the relative imports.
We use three components: `common.py` for dispatching, `torch.py` for generating a PyTorch ensemble, and `__init__.py`, which connects the generic ensemble logic with the PyTorch implementation. All components are explained further below.

To prevent errors, run the cells in the given order.

### 1.1 How `__init__.py` connects Ensemble Components
The `__init__.py` file defines the public interface of the ensemble module.
It re-exports the main functions `ensemble` and `register` so they can be imported directly from `probly.transformation.ensemble`.
Additionally, it performs a *lazy registration* of the PyTorch backend, which means the Torch implementation is only loaded when a Torch model is actually used.
This design avoids unnecessary imports, prevents circular dependencies, and keeps the package lightweight and modular.

*NOTE:* Since we are in a notebook, we perform the registration explicitly
(it happens at the end of the `torch.py` cell below).
The `delayed_register` logic is not needed here because we load the `torch.py` code
directly.

### 1.2 How `common.py` implements the Main Logic
The `common.py` file defines the core logic of the ensemble module.
It introduces a generic dispatcher called `ensemble_generator`, which dynamically selects the correct ensemble creation function based on the model type.
The `register` function allows developers to link new model types (such as PyTorch or custom predictors) to their specific generator implementations.
Finally, the `ensemble()` function provides a simple, user-facing API that hides the dispatch mechanism and automatically calls the right generator.
Together, these components make the ensemble system flexible and easily extensible to other frameworks.
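
As a rough standard-library analogue of this pattern, the sketch below uses `functools.singledispatch` in place of `lazydispatch`. It is an illustration under that substitution, not the `probly` code; `make_ensemble` and `ToyModel` are hypothetical names.

```python
import copy
from functools import singledispatch
from typing import Any


@singledispatch
def make_ensemble(base: Any, n_members: int) -> list[Any]:
    """Fallback used when no generator is registered for the type of `base`."""
    msg = f"No ensemble generator is registered for type {type(base)}"
    raise NotImplementedError(msg)


class ToyModel:
    """Hypothetical model type, used only for this illustration."""

    def __init__(self, weight: float) -> None:
        self.weight = weight


@make_ensemble.register(ToyModel)
def _toy_generator(base: ToyModel, n_members: int) -> list[ToyModel]:
    """Generator for ToyModel: return independent deep copies of the base."""
    return [copy.deepcopy(base) for _ in range(n_members)]


# A registered type is routed to its generator.
members = make_ensemble(ToyModel(0.5), 3)
print(len(members))

# An unregistered type falls through to the default and raises.
try:
    make_ensemble("not a model", 3)
except NotImplementedError as err:
    print(err)
```

The user-facing function stays the same for every backend; supporting a new model type only requires registering one more generator.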

In [None]:
from __future__ import annotations  # noqa: E402, F404, RUF100

"""Shared ensemble implementation."""

from typing import TYPE_CHECKING  # noqa: E402

if TYPE_CHECKING:
    from collections.abc import Callable

    from lazy_dispatch.isinstance import LazyType
    from probly.predictor import Predictor


@lazydispatch
def ensemble_generator[In, KwIn, Out](base: Predictor[In, KwIn, Out]) -> Predictor[In, KwIn, Out]:
    """Generate an ensemble from a base model."""
    msg = f"No ensemble generator is registered for type {type(base)}"
    raise NotImplementedError(msg)


def register(cls: LazyType, generator: Callable) -> None:
    """Register a class which can be used as a base for an ensemble."""
    ensemble_generator.register(cls=cls, func=generator)


def ensemble[T: Predictor](base: T, n_members: int, reset_params: bool = True) -> T:
    """Create an ensemble predictor from a base predictor.

    Args:
        base: Predictor, The base model to be used for the ensemble.
        n_members: The number of members in the ensemble.
        reset_params: Whether to reset the parameters of each member.

    Returns:
        Predictor, The ensemble predictor.
    """
    return ensemble_generator(base, n_members=n_members, reset_params=reset_params)

### 1.3 The Torch Implementation: `torch.py`
The `torch.py` file implements the ensemble generator specifically for PyTorch models.
In `probly`, it uses the `pytraverse` library to recursively clone neural networks and optionally reset their parameters using each layer’s `reset_parameters()` method. The simplified version in this notebook gets the same effect with `copy.deepcopy` and a loop over `nn.Module.modules()`, so it does not rely on the mocked `nn_traverser`.
The main function, `generate_torch_ensemble()`, creates multiple independent copies of a given base model and returns them as an `nn.ModuleList`.
This ensures that each ensemble member has its own parameters, allowing the ensemble to represent independent model instances.
At the end, the PyTorch generator is registered with the dispatcher using `register(nn.Module, generate_torch_ensemble)`, linking it to the common interface.

In [None]:
# Required third-party dependencies
from collections.abc import Callable  # noqa: TC003
from typing import TYPE_CHECKING

from torch import nn

from lazy_dispatch import lazydispatch
from lazy_dispatch.isinstance import LazyType  # noqa: TC001


# --- Mocks for 'probly' dependencies that were not provided ---
class Predictor:
    """Mock class for probly.predictor.Predictor."""


TORCH_MODULE = nn.Module  # Mock for probly.lazy_types.TORCH_MODULE

# Mock for probly.traverse_nn.nn_traverser
nn_traverser = singledispatch_traverser[nn.Module](name="nn_traverser")


@nn_traverser.register(nn.Module)
def _nn_traverse_default(obj: nn.Module, traverse: traverse) -> nn.Module:  # type: ignore  # noqa: PGH003
    """Default traverser that maps children."""
    return traverse.map_children(obj)

In [None]:
from __future__ import annotations  # noqa: F404

import copy

reset_traverser = singledispatch_traverser[nn.Module](name="reset_traverser")


@reset_traverser.register
def _(obj: nn.Module) -> nn.Module:
    if hasattr(obj, "reset_parameters"):
        obj.reset_parameters()  # type: ignore[operator]
    return obj


def _copy(module: nn.Module) -> nn.Module:
    # simple deep copy without relying on the nn_traverser mock
    return copy.deepcopy(module)


def _reset_copy(module: nn.Module) -> nn.Module:
    cloned = _copy(module)
    for m in cloned.modules():
        if hasattr(m, "reset_parameters"):
            m.reset_parameters()
    return cloned


def generate_torch_ensemble(
    obj: nn.Module,
    n_members: int,
    reset_params: bool = True,
) -> nn.ModuleList:
    """Build a torch ensemble from n_members copies of the base model, optionally resetting each copy's parameters."""
    if reset_params:
        return nn.ModuleList([_reset_copy(obj) for _ in range(n_members)])
    return nn.ModuleList([_copy(obj) for _ in range(n_members)])


register(nn.Module, generate_torch_ensemble)

## 2. The Problem: Manual Ensemble Creation

Let's say we have a base model in PyTorch and want to create a "Deep Ensemble" for uncertainty quantification. For this, we need several copies of this model, each with differently initialized weights.

The naive approach would be to manually copy the model and reset the parameters. This can be tedious, especially with complex, nested models.

In [None]:
# Our base model
class SimpleNet(nn.Module):
    def __init__(self) -> None:
        """Initialize the network architecture with two linear layers."""
        super().__init__()
        self.layer1 = nn.Linear(10, 20)
        self.layer2 = nn.Linear(20, 2)

    def forward(self, x):  # noqa: ANN201, ANN001
        """Run a forward pass of the network."""
        x = torch.relu(self.layer1(x))
        return self.layer2(x)

    def reset_parameters(self) -> None:
        """Reset the parameters of all layers."""
        # Important for Deep Ensembles: a custom reset function
        self.layer1.reset_parameters()
        self.layer2.reset_parameters()


base_model = SimpleNet()


# Manual approach
n_members = 3
manual_ensemble = []
for _ in range(n_members):
    model_copy = copy.deepcopy(base_model)
    # We have to remember to reset the parameters manually
    if hasattr(model_copy, "reset_parameters"):
        model_copy.reset_parameters()
    manual_ensemble.append(model_copy)

This works, but it's cumbersome. We have to use `copy.deepcopy` and manually check for a `reset_parameters` method.

The `torch.py` code automates this. In `probly`, the `_reset_copy` function uses `pytraverse` to recursively traverse the module and call `reset_parameters()` on every submodule that has it; the notebook version above does the same with `copy.deepcopy` and `modules()`. This is a good example of the separation of concerns shown in `pytraverse_tutorial.ipynb` (cell `5ae1e551`): the traversal logic is separate from the reset logic.
In the next section, we will see how the `ensemble()` function automates this entire process, making ensemble creation both cleaner and safer.
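
To illustrate why the recursive reset matters for nested models, here is a small sketch. `Block` and `reset_copy` are hypothetical names; the approach mirrors the `copy.deepcopy` plus `modules()` loop used in this notebook rather than the `pytraverse` version in `probly`.

```python
import copy

import torch
from torch import nn


class Block(nn.Module):
    """Hypothetical nested block that defines no reset_parameters of its own."""

    def __init__(self) -> None:
        super().__init__()
        self.body = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.body(x)


def reset_copy(module: nn.Module) -> nn.Module:
    """Deep-copy a module, then re-initialize every submodule that supports it."""
    cloned = copy.deepcopy(module)
    for sub in cloned.modules():  # yields the module itself and all descendants
        reset = getattr(sub, "reset_parameters", None)
        if callable(reset):
            reset()
    return cloned


base = Block()
member = reset_copy(base)

base_w = base.body[0].weight
member_w = member.body[0].weight
same_storage = base_w.data_ptr() == member_w.data_ptr()
same_values = torch.equal(base_w, member_w)
print(f"shares storage: {same_storage}, identical values: {same_values}")
```

A top-level `hasattr(model, "reset_parameters")` check would skip `Block` entirely, because only its inner `nn.Linear` layers define that method. Note that a container such as `SimpleNet`, whose own `reset_parameters` forwards to its children, has those children reset twice by this loop, which is harmless.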

## 3. The Automated Solution: The ensemble() Function

The `ensemble` function from `common.py` is a thin wrapper around the `lazydispatch` function `ensemble_generator`, which automatically selects the correct generator based on the type of the base model.

Since we registered `generate_torch_ensemble` for `nn.Module` (in the `torch.py` cell above), we can apply the `ensemble` function directly to our `SimpleNet` object.
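
Once an ensemble exists, its members are typically evaluated together. The sketch below assumes only that the ensemble is an `nn.ModuleList` of independently initialized members, which is what `generate_torch_ensemble` returns; the list is built by hand here as a stand-in, and `make_member` is a hypothetical helper.

```python
import torch
from torch import nn


def make_member() -> nn.Module:
    """Hypothetical member with the same layer sizes as SimpleNet."""
    return nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 2))


# Stand-in for the result of `ensemble(...)`: independently initialized members.
members = nn.ModuleList([make_member() for _ in range(3)])

x = torch.randn(5, 10)  # batch of 5 inputs with 10 features each
with torch.no_grad():
    logits = torch.stack([m(x) for m in members])  # shape: (members, batch, classes)
    probs = logits.softmax(dim=-1)

mean_probs = probs.mean(dim=0)   # ensemble prediction per input
disagreement = probs.std(dim=0)  # spread across members per input and class
print(mean_probs.shape, disagreement.shape)
```

Averaging the per-member probabilities gives the ensemble prediction, and the standard deviation across members gives a simple per-input disagreement signal.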

### 3.1 Optional Behavior: Cloning Without Reset

The `ensemble` function also accepts `reset_params=False`.

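In this case, `generate_torch_ensemble` calls the `_copy` function instead of `_reset_copy`. `_copy` clones the module without calling `reset_parameters()` (in this notebook via `copy.deepcopy`; in `probly` via the `nn_traverser`), so every member starts with the same weights as the base model.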

In [None]:
print("--- Creating ensemble without reset ---")
base_model_3 = SimpleNet()
base_weight_3 = base_model_3.layer1.weight.data[0, 0].item()
print(f"Base model weight: {base_weight_3}")

# This time we set reset_params to False
copied_ensemble = ensemble(base_model_3, n_members=2, reset_params=False)

weight0 = copied_ensemble[0].layer1.weight.data[0, 0].item()
weight1 = copied_ensemble[1].layer1.weight.data[0, 0].item()

print(f"\nWeight of member 0: {weight0}")
print(f"Weight of member 1: {weight1}")

assert weight0 == base_weight_3  # noqa: S101
assert weight1 == base_weight_3  # noqa: S101
print("All weights are identical to the base model (as expected).")

--- Creating ensemble without reset ---
Base model weight: 0.016505658626556396

Weight of member 0: 0.016505658626556396
Weight of member 1: 0.016505658626556396
All weights are identical to the base model (as expected).


## 4. Summary and Key Takeaway

The `ensemble` function is a small entry point that hides the complexity of creating model ensembles behind a dispatcher.

By registering type-specific generators (like `generate_torch_ensemble` for `nn.Module`) with the `ensemble_generator`, it provides a clean, extensible API.

Internally, the PyTorch implementation in `probly` uses `pytraverse` to traverse, copy, and optionally reset the parameters of the module structure; the simplified copy in this notebook achieves the same with `copy.deepcopy` and `modules()`. This shows how the abstract concepts from `pytraverse_tutorial.ipynb` (like `singledispatch_traverser` and `traverse` with `{CLONE: True}`) are used in a real-world application to write robust and maintainable code.

## 5. Further Reading and References
- **Probly documentation:** more information on how everything in probly works (https://github.com/pwhofman/probly/tree/main/docs)
- **PyTraverse intro notebook:** an in-depth tutorial on how pytraverse automates traversal (https://github.com/pwhofman/probly/blob/main/notebooks/examples/pytraverse_tutorial.ipynb)
