⚠️ Status: Work in progress, please do not merge yet <br>
!! TODOs (delete before merge) !!
KEY: (Copy and Paste to title, delete before main merge!) <br>
‼️: needs fix<br>
🚧: Started<br>
❌: not finished<br>
✅: finished<br>

- [x] Create rough structure
- [x] Ensemble introduction
- [x] Consistent headings
- [ ] Demonstrate deep copy
- [ ] Note or fix deviations from the code
- [x] API summary (briefly explain how to use ensemble)
- [ ] Show error messages when a type is not registered
- [ ] Consistent sections
- [ ] If time permits: visualizations
- [ ] Set seed (starting value for the random number generator)

# A Brief Introduction to Ensemble Transformation

The goal of this notebook is to showcase the `ensemble` function, which is based on the provided code files (`common.py`, `torch.py`, `__init__.py`). We will demonstrate how to use it to create an ensemble of models from a base PyTorch model.

## 0. What is an Ensemble?
An *ensemble* in machine learning combines the predictions of a finite set of models. Instead of relying on a single model, an ensemble aggregates the outputs of several members (for example, by averaging them), which often yields better predictions than any single member alone.
The members usually share the same architecture but differ in their parameters, for example because each one starts from different random initial weights.
This often improves robustness and helps quantify **uncertainty**: when the ensemble members disagree, this signals that the prediction is uncertain.  
In `probly`, the `ensemble` transformation automates the creation of such model collections directly from a base PyTorch model. <br>
**Why use `ensemble`?** <br>
Because it saves you from having to manually copy, reset, and manage multiple model instances yourself. `ensemble` does all of that for you automatically.
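
To make the aggregation and the "disagreement" idea concrete, here is a minimal sketch in plain PyTorch, independent of `probly`. As an assumption for illustration, the members are simply three independently initialized `nn.Linear` layers; the ensemble prediction is the mean over members, and the variance across members serves as a simple measure of disagreement.

```python
import torch
from torch import nn

torch.manual_seed(0)  # fixed seed so the example is reproducible

# Three independently initialized members with the same architecture.
members = nn.ModuleList([nn.Linear(4, 2) for _ in range(3)])

x = torch.randn(5, 4)  # a batch of 5 inputs with 4 features each

with torch.no_grad():
    # Shape: (num_members, batch_size, num_outputs) = (3, 5, 2)
    preds = torch.stack([member(x) for member in members])

mean_pred = preds.mean(dim=0)    # ensemble prediction, shape (5, 2)
disagreement = preds.var(dim=0)  # variance across members, shape (5, 2)

print(mean_pred.shape, disagreement.shape)
```

For classifiers, one would typically average the members' softmax probabilities rather than raw outputs; larger variance across members indicates inputs on which the ensemble is less certain.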

## 1. Setup: Dependencies and Code Definitions

Before we can use the `ensemble` function, we'll define it and its components as described in the code files (`common.py`, `torch.py`, `__init__.py`). This ensures this notebook is self-contained.

We will copy the contents of the provided files here and adjust the relative imports.
The ensemble logic itself lives in two of these files:
- `common.py`: defines a generic dispatcher (`ensemble_generator`), a `register()` helper, and the user-facing API `ensemble()`.
- `torch.py`: registers a PyTorch-specific generator `generate_torch_ensemble()` that clones a model `num_members` times, optionally resetting the parameters of each member, and returns an `nn.ModuleList`.
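
A minimal, framework-free sketch of this register-then-dispatch pattern follows. It uses `functools.singledispatch` as a stand-in for `lazydispatch` (an assumption: we only rely on dispatching on the type of the first argument). `ToyModel`, `toy_ensemble_generator`, `toy_register`, `toy_ensemble`, and `generate_toy_ensemble` are hypothetical names that mirror `ensemble_generator`, `register`, `ensemble`, and `generate_torch_ensemble`.

```python
import copy
import random
from functools import singledispatch


class ToyModel:
    """Hypothetical stand-in for a framework-specific model type."""

    def __init__(self) -> None:
        self.weight = random.random()

    def reset_parameters(self) -> None:
        self.weight = random.random()  # re-draw the single "parameter"


@singledispatch
def toy_ensemble_generator(base, **kwargs):
    # Fallback: only reached when no generator is registered for type(base).
    msg = f"No ensemble generator is registered for type {type(base)}"
    raise NotImplementedError(msg)


def toy_register(cls, generator) -> None:
    """Register `generator` as the ensemble generator for instances of `cls`."""
    toy_ensemble_generator.register(cls, generator)


def toy_ensemble(base, num_members: int, reset_params: bool = True):
    """User-facing wrapper: forwards all arguments to the dispatcher."""
    return toy_ensemble_generator(base, num_members=num_members, reset_params=reset_params)


def generate_toy_ensemble(base: ToyModel, num_members: int, reset_params: bool = True) -> list[ToyModel]:
    """Type-specific generator: copy the base model, optionally re-initializing each copy."""
    members = [copy.deepcopy(base) for _ in range(num_members)]
    if reset_params:
        for member in members:
            member.reset_parameters()
    return members


toy_register(ToyModel, generate_toy_ensemble)

members = toy_ensemble(ToyModel(), num_members=3)
print(len(members))  # 3
```

Supporting a new framework then only requires another `toy_register(...)` call; the user-facing `toy_ensemble` stays unchanged.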

--- REREAD!!! NOT FINISHED YET!
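**Compatibility note:** In some versions, `common.ensemble(...)` passes the keyword arguments `num_members` and `reset_params`
to the dispatcher. If the base signature of `ensemble_generator` does not accept these keyword arguments (as in the `common.py` cell below, which only takes `base`), calling
`ensemble(...)` may raise a `TypeError` instead of the intended `NotImplementedError` whenever that fallback is reached, i.e. when no generator is registered for the model's type. In that case, call the PyTorch generator directly
(`generate_torch_ensemble(...)`) or update the dispatcher signature in the codebase.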
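
The following sketch reproduces this behavior with `functools.singledispatch` standing in for `lazydispatch` (assumption: both forward keyword arguments unchanged to the selected implementation). A fallback that only accepts `base` turns a call on an unregistered type into a `TypeError`, whereas a fallback that also accepts `**kwargs` raises the intended `NotImplementedError`. All function names here are hypothetical.

```python
from functools import singledispatch


@singledispatch
def strict_generator(base):
    # Fallback that, like the base `ensemble_generator`, only accepts `base`.
    raise NotImplementedError(f"No ensemble generator is registered for type {type(base)}")


@singledispatch
def tolerant_generator(base, **kwargs):
    # Same fallback, but extra keyword arguments are accepted and ignored.
    raise NotImplementedError(f"No ensemble generator is registered for type {type(base)}")


def error_name(generator, base) -> str:
    """Call `generator` the way `ensemble()` does and report which error was raised."""
    try:
        generator(base, num_members=3, reset_params=True)
    except (TypeError, NotImplementedError) as err:
        return type(err).__name__
    return "no error"


print(error_name(strict_generator, "not a model"))    # TypeError
print(error_name(tolerant_generator, "not a model"))  # NotImplementedError
```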

In [None]:
# Required third-party dependencies
import torch
from torch import nn
from typing import TYPE_CHECKING, TypeVar
from collections.abc import Callable
from lazy_dispatch import lazydispatch
from lazy_dispatch.isinstance import LazyType
from pytraverse import CLONE, singledispatch_traverser, traverse, sequential as nn_compose

# --- Mocks for 'probly' dependencies that were not provided ---
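class Predictor[In, KwIn, Out]:
    """Mock class for probly.predictor.Predictor.

    Declared generic (PEP 695, Python 3.12+) so that annotations such as
    `Predictor[In, KwIn, Out]` in the common.py cell can be evaluated.
    """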

TORCH_MODULE = nn.Module  # Mock for probly.lazy_types.TORCH_MODULE

# Mock for probly.traverse_nn.nn_traverser
nn_traverser = singledispatch_traverser[nn.Module](name="nn_traverser")

@nn_traverser.register(nn.Module)
def _nn_traverse_default(obj: nn.Module, traverse: traverse) -> nn.Module:
    """Default traverser that maps children."""
    return traverse.map_children(obj)

print("Library dependencies and mocks loaded.")

In [None]:
# --- Content from common.py ---

if TYPE_CHECKING:
    from collections.abc import Callable
    from lazy_dispatch.isinstance import LazyType
    # from probly.predictor import Predictor # Already mocked above

@lazydispatch
def ensemble_generator[In, KwIn, Out](base: Predictor[In, KwIn, Out]) -> Predictor[In, KwIn, Out]:
    """Generate an ensemble from a base model."""
    msg = f"No ensemble generator is registered for type {type(base)}"
    raise NotImplementedError(msg)

def register(cls: LazyType, generator: Callable) -> None:
    """Register a class which can be used as a base for an ensemble."""
    ensemble_generator.register(cls=cls, func=generator)

def ensemble[T: Predictor](base: T, num_members: int, reset_params: bool = True) -> T:
    """Create an ensemble predictor from a base predictor.

    Args:
        base: Predictor, The base model to be used for the ensemble.
        num_members: The number of members in the ensemble.
        reset_params: Whether to reset the parameters of each member.

    Returns:
        Predictor, The ensemble predictor.
    """
    return ensemble_generator(base, num_members=num_members, reset_params=reset_params)

print("Code from 'common.py' loaded.")

In [None]:
# --- Content from torch.py ---

from __future__ import annotations
# from torch import nn # Already imported

# from probly.traverse_nn import nn_compose, nn_traverser # Already mocked above
# from pytraverse import CLONE, singledispatch_traverser, traverse # Already imported

# from .common import register # Adapted to use the global function

reset_traverser = singledispatch_traverser[nn.Module](name="reset_traverser")

@reset_traverser.register
def _(obj: nn.Module) -> nn.Module:
    if hasattr(obj, "reset_parameters"):
        obj.reset_parameters()  # type: ignore[operator]
    return obj

def _reset_copy(module: nn.Module) -> nn.Module:
    return traverse(module, nn_compose(reset_traverser), init={CLONE: True})

def _copy(module: nn.Module) -> nn.Module:
    return traverse(module, nn_traverser, init={CLONE: True})

def generate_torch_ensemble(
    obj: nn.Module,
    num_members: int,
    reset_params: bool = True,
) -> nn.ModuleList:
    """Build a torch ensemble by copying the base model num_members times, resetting the parameters of each member."""
    if reset_params:
        return nn.ModuleList([_reset_copy(obj) for _ in range(num_members)])
    return nn.ModuleList([_copy(obj) for _ in range(num_members)])

register(nn.Module, generate_torch_ensemble)

print("Code from 'torch.py' loaded and generator registered.")

In [None]:
# --- Content from __init__.py ---

from __future__ import annotations
# from probly.lazy_types import TORCH_MODULE # Already mocked above

# from . import common # Adapted, as 'common' is already global
# ensemble = common.ensemble # Already global
# register = common.register # Already global

## Torch
# @common.ensemble_generator.delayed_register(TORCH_MODULE)
# def _(_: type) -> None:
#     from . import torch as torch_impl  # noqa: PLC0415

# NOTE: Since we are in a notebook, we are performing the registration
# (which already happened in the 'torch.py' cell) explicitly.
# The 'delayed_register' logic is not needed here because we loaded 'torch.py'
# directly.

print("Code from '__init__.py' (logically) executed.")
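
The commented-out `delayed_register` call registers a callback for `TORCH_MODULE` that runs `from . import torch as torch_impl`, presumably so that `torch` only has to be imported once a torch model is actually dispatched. The following is a hypothetical, simplified illustration of that idea on top of `functools.singledispatch`; it is not how `lazy_dispatch` is implemented, and all names in it (`lazy_generator`, `delayed_register`, `dispatch`, `HeavyModel`) are made up for this example.

```python
from collections.abc import Callable
from functools import singledispatch


@singledispatch
def lazy_generator(base, **kwargs):
    # Fallback for types without a (loaded) generator.
    raise NotImplementedError(f"No ensemble generator is registered for type {type(base)}")


# Hypothetical registry: "module.QualName" of a class -> loader that performs the real registration.
_pending: dict[str, Callable[[type], None]] = {}


def delayed_register(qualified_name: str):
    """Remember a loader and run it only when an instance of `qualified_name` is first dispatched."""
    def decorator(loader: Callable[[type], None]) -> Callable[[type], None]:
        _pending[qualified_name] = loader
        return loader
    return decorator


def dispatch(base, **kwargs):
    # Run (and discard) any pending loader that matches a class in the MRO of `base`.
    for klass in type(base).__mro__:
        loader = _pending.pop(f"{klass.__module__}.{klass.__qualname__}", None)
        if loader is not None:
            loader(klass)
    return lazy_generator(base, **kwargs)


class HeavyModel:
    """Hypothetical model type whose generator should only be set up on demand."""


loaded: list[type] = []


@delayed_register(f"{HeavyModel.__module__}.{HeavyModel.__qualname__}")
def _load_heavy(cls: type) -> None:
    # In the commented-out code above, this is where `from . import torch as torch_impl` runs.
    loaded.append(cls)
    lazy_generator.register(cls, lambda base, **kw: [type(base)() for _ in range(kw["num_members"])])


print(loaded)                                      # []
print(len(dispatch(HeavyModel(), num_members=2)))  # 2
print(loaded == [HeavyModel])                      # True
```

Because the loader is removed from `_pending` after its first run, later calls go straight to the registered generator.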

## 2. The Problem: Manual Ensemble Creation

Let's say we have a base model in PyTorch and want to create a "Deep Ensemble" for uncertainty quantification. For this, we need several copies of this model, each with differently initialized weights.

The naive approach would be to manually copy the model and reset the parameters. This can be tedious, especially with complex, nested models.

--- NOTE: Maybe break apart the big code chunk into smaller pieces and add explanation? ---

In [None]:
import copy

# Our base model
class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(10, 20)
        self.layer2 = nn.Linear(20, 2)
    
    def forward(self, x):
        x = torch.relu(self.layer1(x))
        return self.layer2(x)
    
    def reset_parameters(self):
        # Important for Deep Ensembles: a custom reset function
        print("Custom reset_parameters of SimpleNet called!")
        self.layer1.reset_parameters()
        self.layer2.reset_parameters()

base_model = SimpleNet()
print(f"Base model weight (layer1): {base_model.layer1.weight.data[0, 0].item()}")

# Manual approach
num_members = 3
manual_ensemble = []
for _ in range(num_members):
    model_copy = copy.deepcopy(base_model)
    # We have to remember to reset the parameters manually
    if hasattr(model_copy, 'reset_parameters'):
        model_copy.reset_parameters()
    manual_ensemble.append(model_copy)

print("\nManual ensemble created.")
print(f"Weight of member 0: {manual_ensemble[0].layer1.weight.data[0, 0].item()}")
print(f"Weight of member 1: {manual_ensemble[1].layer1.weight.data[0, 0].item()}")

(Example code output)

Base model weight (layer1): 0.20300185680389404
Custom reset_parameters of SimpleNet called!
Custom reset_parameters of SimpleNet called!
Custom reset_parameters of SimpleNet called!

Manual ensemble created.
Weight of member 0: -0.22176861763000488
Weight of member 1: 0.17613589763641357

This works, but it's cumbersome. We have to use `copy.deepcopy` and manually check for a `reset_parameters` method.

The `torch.py` code automates this. The `_reset_copy` function uses `pytraverse` to recursively traverse the module and call `reset_parameters()` on every submodule that has it. This is a perfect example of the separation of concerns shown in `pytraverse_tutorial.ipynb` (cell `5ae1e551`): the traversal logic is separate from the reset logic.
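
`nn.Module` itself does not define `reset_parameters()`; many built-in layers such as `nn.Linear` do. The manual loop above therefore only resets `SimpleNet` because `SimpleNet` defines its own `reset_parameters()`. The following plain-PyTorch sketch (our own illustration, not `probly`'s code) contrasts a top-level-only reset with a recursive one that visits every submodule via `nn.Module.modules()`, which is the behavior `_reset_copy` achieves with `pytraverse`. `nn.Sequential` is used because it has no `reset_parameters()` of its own.

```python
import copy

import torch
from torch import nn


def top_level_reset_copy(module: nn.Module) -> nn.Module:
    """Deep-copy and call reset_parameters() only on the top-level module (if present)."""
    clone = copy.deepcopy(module)
    if hasattr(clone, "reset_parameters"):
        clone.reset_parameters()
    return clone


def recursive_reset_copy(module: nn.Module) -> nn.Module:
    """Deep-copy and call reset_parameters() on every submodule that defines it."""
    clone = copy.deepcopy(module)
    for sub in clone.modules():  # yields the clone itself and all nested submodules
        if hasattr(sub, "reset_parameters"):
            sub.reset_parameters()
    return clone


torch.manual_seed(0)
base = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 2))

top_only = top_level_reset_copy(base)
recursive = recursive_reset_copy(base)

print(torch.equal(top_only[0].weight, base[0].weight))   # True: nothing was reset
print(torch.equal(recursive[0].weight, base[0].weight))  # False: weights re-initialized
```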

## 3. The Automated Solution: The ensemble() Function

The `ensemble` function from `common.py` is a thin wrapper around `ensemble_generator`, a `lazydispatch` function that automatically selects the correct generator based on the type of the base model.

Since we registered `generate_torch_ensemble` for `nn.Module` (in the `torch.py` cell above), we can apply the `ensemble` function directly to our `SimpleNet` object.

In [None]:
# Important: To prevent a NameError, run the cells above first!
print("--- Creating ensemble with reset (default) ---")
base_model_2 = SimpleNet()
base_weight = base_model_2.layer1.weight.data[0, 0].item()
print(f"Base model weight: {base_weight}")

# ensemble is the function from common.py
# reset_params=True is the default
model_ensemble = ensemble(base_model_2, num_members=3, reset_params=True)

print(f"\nEnsemble created: {type(model_ensemble)}")
print(f"Number of members: {len(model_ensemble)}")

# Let's compare the weights to see if they are different
weight0 = model_ensemble[0].layer1.weight.data[0, 0].item()
weight1 = model_ensemble[1].layer1.weight.data[0, 0].item()
weight2 = model_ensemble[2].layer1.weight.data[0, 0].item()

print(f"Weight of member 0: {weight0}")
print(f"Weight of member 1: {weight1}")
print(f"Weight of member 2: {weight2}")

# The base weight and all three member weights should be pairwise distinct.
if len({base_weight, weight0, weight1, weight2}) == 4:
    print("All weights are different (as expected).")


(Example code output)

--- Creating ensemble with reset (default) ---
Base model weight: -0.27976858615875244
Custom reset_parameters of SimpleNet called!
Custom reset_parameters of SimpleNet called!
Custom reset_parameters of SimpleNet called!

Ensemble created: <class 'torch.nn.modules.container.ModuleList'>
Number of members: 3

Weight of member 0: 0.1681283712387085
Weight of member 1: -0.06659793853759766
Weight of member 2: -0.19827675819396973

All weights are different (as expected).
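
The printed weights above change from run to run because `reset_parameters()` draws from PyTorch's global random number generator; this is what the "set seed" TODO refers to. A minimal sketch of making member initialization reproducible with `torch.manual_seed`; `build_members` is a hypothetical helper that creates fresh `nn.Linear` members instead of calling `ensemble`, so that the block runs on its own.

```python
import torch
from torch import nn


def build_members(num_members: int, seed: int) -> nn.ModuleList:
    """Hypothetical helper: seed the global RNG, then create freshly initialized members."""
    torch.manual_seed(seed)  # all subsequent random initializations depend on this seed
    return nn.ModuleList([nn.Linear(10, 20) for _ in range(num_members)])


run_a = build_members(num_members=3, seed=42)
run_b = build_members(num_members=3, seed=42)

# Same seed: member i of run_a equals member i of run_b ...
print(all(torch.equal(a.weight, b.weight) for a, b in zip(run_a, run_b)))  # True
# ... while different members within one run still differ.
print(torch.equal(run_a[0].weight, run_a[1].weight))  # False
```

In this notebook, the analogous step would be calling `torch.manual_seed(...)` immediately before `ensemble(...)`, assuming the copying step itself does not draw random numbers.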

## 4. Optional Behavior: Cloning Without Reset

The `ensemble` function also accepts `reset_params=False`.

In this case, `generate_torch_ensemble` calls the `_copy` function instead of `_reset_copy`. `_copy` simply uses the `nn_traverser` to clone the module without calling `reset_parameters()`.
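
The cell below compares only a single weight entry per member. A stricter check compares every tensor in `state_dict()`. In this sketch, `same_parameters` is a hypothetical helper, and `copy.deepcopy` stands in for `_copy` so that the block runs without `pytraverse`.

```python
import copy

import torch
from torch import nn


def same_parameters(a: nn.Module, b: nn.Module) -> bool:
    """Return True if both modules have the same state_dict keys and identical tensor values."""
    sa, sb = a.state_dict(), b.state_dict()
    return sa.keys() == sb.keys() and all(torch.equal(sa[k], sb[k]) for k in sa)


base = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 2))
copies = nn.ModuleList([copy.deepcopy(base) for _ in range(2)])  # stand-in for reset_params=False

print(all(same_parameters(base, member) for member in copies))  # True
```

With the objects from the next cell, the same helper would be called as `same_parameters(base_model_3, copied_ensemble[0])`.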

In [None]:
print("--- Creating ensemble without reset ---")
base_model_3 = SimpleNet()
base_weight_3 = base_model_3.layer1.weight.data[0, 0].item()
print(f"Base model weight: {base_weight_3}")

# This time we set reset_params to False
copied_ensemble = ensemble(base_model_3, num_members=2, reset_params=False)

weight0 = copied_ensemble[0].layer1.weight.data[0, 0].item()
weight1 = copied_ensemble[1].layer1.weight.data[0, 0].item()

print(f"\nWeight of member 0: {weight0}")
print(f"Weight of member 1: {weight1}")

assert weight0 == base_weight_3
assert weight1 == base_weight_3
print("All weights are identical to the base model (as expected).")

(Example code output)

--- Creating ensemble without reset ---
Base model weight: -0.06256282329559326

Weight of member 0: -0.06256282329559326
Weight of member 1: -0.06256282329559326

All weights are identical to the base model (as expected).

## 5. Verifying Ensemble Properties
**TODO:** Demonstrate that the members are deep copies, using fixed (set) parameters.
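
One possible way to approach this: a deep copy must not only have equal values, it must also not share parameter storage with the base model, so changing a member in place must leave the base untouched. `is_independent_copy` is a hypothetical helper; `copy.deepcopy` and `copy.copy` provide a positive and a negative example so that the block runs on its own.

```python
import copy

import torch
from torch import nn


def is_independent_copy(base: nn.Module, member: nn.Module) -> bool:
    """Hypothetical helper: True if `member` shares no parameter storage with `base`
    and an in-place change to `member` leaves `base` untouched."""
    base_ptrs = {p.data_ptr() for p in base.parameters()}
    if any(p.data_ptr() in base_ptrs for p in member.parameters()):
        return False  # at least one tensor is shared, so this is not a deep copy

    snapshot = [p.detach().clone() for p in base.parameters()]
    first = next(member.parameters())
    saved = first.detach().clone()
    with torch.no_grad():
        first.add_(1.0)  # modify the member in place
        unchanged = all(torch.equal(s, p) for s, p in zip(snapshot, base.parameters()))
        first.copy_(saved)  # restore the member's original values exactly
    return unchanged


base = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 2))
deep = copy.deepcopy(base)
shallow = copy.copy(base)  # shallow copy: reuses the same parameter tensors

print(is_independent_copy(base, deep))     # True
print(is_independent_copy(base, shallow))  # False
```

The same check could be applied to this notebook's objects, e.g. `is_independent_copy(base_model_3, copied_ensemble[0])`.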


## 6. Summary and Key Takeaway

The `ensemble` function is a thin wrapper around a type-based dispatcher that abstracts away the complexity of creating model ensembles.

By registering type-specific generators (like `generate_torch_ensemble` for `nn.Module`) with the `ensemble_generator`, it provides a clean, extensible API.

Internally, the PyTorch implementation uses `pytraverse` to traverse and copy the module structure and, optionally, reset its parameters. This demonstrates how the abstract concepts from `pytraverse_tutorial.ipynb` (like `singledispatch_traverser` and `traverse` with `{CLONE: True}`) are used in a real-world application to write robust and maintainable code.

## 7. Further Reading and References
**TODO:** add references to the pytraverse intro and the probly documentation
- **PyTorch:** `nn.ModuleList`, `reset_parameters()`
- **Probly documentation:** https://github.com/pwhofman/probly/tree/main/docs
- **PyTraverse intro notebook:** https://github.com/pwhofman/probly/blob/main/notebooks/examples/pytraverse_tutorial.ipynb
