# Ensemble Transformation

This notebook is a practical introduction to the Ensemble transformation in `probly`. Deep Ensembles are one of the most robust and high-performing methods for uncertainty quantification.

We will start by explaining the core idea behind Deep Ensembles and how `probly`'s transformation lets you create them. We will then walk through a PyTorch example that trains the ensemble and uses the disagreement between its members to estimate predictive uncertainty.



---
## Part A: Introduction to Ensembles and the Ensemble Transformation
---


## 1. Concept: What is a Deep Ensemble?
### 1.1 The Core Idea: Wisdom of the Crowd

The idea behind an ensemble is simple and powerful: instead of relying on the prediction of a single model, we train multiple models independently
and aggregate their predictions. The core principle is that if we have a diverse set of "experts" (the models), their collective judgment will be
more robust and reliable than any single expert's.

### 1.2 Deep Ensembles for Uncertainty

In Deep Learning, a **Deep Ensemble** consists of multiple neural networks. To create a diverse ensemble, each network is
trained from a different random initialization.

When we give the same input to every model in the ensemble, we get a set of (generally different) predictions.

- The mean of these predictions serves as the ensemble's final prediction, which is typically more robust than any single member's.
- The variance (or, more generally, the disagreement) among these predictions serves as an estimate of the model's uncertainty:
  if all models agree, uncertainty is low; if they disagree significantly, uncertainty is high.
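As a toy illustration, the snippet below summarizes two made-up sets of member predictions for a single input using only Python's standard library. Both sets have the same mean but very different variance, which is the kind of signal the ensemble's uncertainty estimate relies on. It uses the population variance, matching the `unbiased=False` choice in the prediction code later in this notebook.

```python
from statistics import fmean, pvariance

# Made-up predictions of five ensemble members for the same input.
agreeing = [2.0, 2.1, 1.9, 2.0, 2.0]      # members roughly agree
disagreeing = [0.5, 3.5, 2.0, -1.0, 5.0]  # members disagree strongly

summary = {}
for name, preds in [("agreeing", agreeing), ("disagreeing", disagreeing)]:
    mean = fmean(preds)     # ensemble prediction
    var = pvariance(preds)  # spread across members = uncertainty estimate
    summary[name] = (mean, var)
    print(f"{name}: mean={mean:.2f}, variance={var:.3f}")
```

Both means are 2.0, but the variance is about 0.004 for the agreeing members and 4.5 for the disagreeing ones.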

While very effective, creating and managing deep ensembles manually can be cumbersome.

### 1.3 The Ensemble Transformation

The Ensemble transformation in `probly` automates the creation and management of a deep ensemble.

The transformation does the following:

- It takes a single, user-defined base model as a template.
- It creates a specified number of deep copies of this model.
- Crucially, it re-initializes the parameters of each copy, so every model in the ensemble starts from a different random state.
- It packages all these independent models into a single `torch.nn.ModuleList`.

This gives you one container through which all ensemble members can be trained and queried together.
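To make these steps concrete, here is a minimal sketch in plain PyTorch of what such a transformation could look like. It is **not** `probly`'s implementation: the name `make_ensemble` is hypothetical, and the sketch assumes that every layer holding parameters exposes a `reset_parameters()` method (true for built-in layers such as `nn.Linear`, not necessarily for custom modules).

```python
import copy

import torch
from torch import nn


def make_ensemble(base: nn.Module, num_members: int) -> nn.ModuleList:
    """Hypothetical sketch: deep-copy `base` and re-initialize each copy."""
    members = []
    for _ in range(num_members):
        member = copy.deepcopy(base)  # independent copy with its own parameter tensors
        for module in member.modules():
            # Built-in layers like nn.Linear define reset_parameters(),
            # which draws fresh random weights for this copy.
            if hasattr(module, "reset_parameters"):
                module.reset_parameters()
        members.append(member)
    return nn.ModuleList(members)


torch.manual_seed(0)
base = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 1))
members = make_ensemble(base, num_members=5)
print(len(members))
print(torch.equal(members[0][0].weight, members[1][0].weight))
```

The deep copy gives every member its own tensors, so training one member never changes another; the re-initialization is what makes the members start from different random points.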

### 1.4 What That Entails

| Aspect                       | Ensemble Transformation in `probly`                                                        |
|------------------------------|--------------------------------------------------------------------------------------------|
| **Main Idea**                | "Wisdom of the crowd"                                                                      |
| Stochastic Element           | Random initialization of each member; uncertainty comes from their disagreement.          |
| Architectural Change         | Creates an `nn.ModuleList` of cloned and re-initialized models.                            |
| Uncertainty Interpretation   | Spread of the members' predictions; in practice often a strong, robust uncertainty estimate. |
| Training Cost                | High: N models to train, store, and evaluate.                                              |


## 2. Ensemble Quickstart (PyTorch)

Below, we build a small MLP and apply `ensemble(model, num_members=...)` to turn it into a 5-member ensemble.

```python
import torch
from torch import nn

from probly.transformation import ensemble


def build_mlp(in_dim: int = 10, hidden: int = 32, out_dim: int = 1) -> nn.Sequential:
    return nn.Sequential(
        nn.Linear(in_dim, hidden),
        nn.ReLU(),
        nn.Linear(hidden, out_dim),
    )


model = build_mlp()
print("Original model:\n", model)

# Apply the Ensemble transformation
num_models = 5
ensemble_model = ensemble(model, num_members=num_models)
print(f"\nEnsemble with {num_models} members:\n", ensemble_model)
```

```text
Original model:
 Sequential(
  (0): Linear(in_features=10, out_features=32, bias=True)
  (1): ReLU()
  (2): Linear(in_features=32, out_features=1, bias=True)
)

Ensemble with 5 members:
 ModuleList(
  (0-4): 5 x Sequential(
    (0): Linear(in_features=10, out_features=32, bias=True)
    (1): ReLU()
    (2): Linear(in_features=32, out_features=1, bias=True)
  )
)
```


#### Notes on the structure

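- `ensemble_model` is an `nn.ModuleList` containing 5 independent copies of the original `Sequential` model; the printout collapses them into a single entry, `(0-4): 5 x Sequential(...)`.
- By default, `probly` has already re-initialized the weights of each of these 5 models, so they differ from one another (the printout shows only the architecture, not the weights).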
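Because the printout only shows the architecture, it cannot tell you whether the weights actually differ. A small check like the hypothetical helper `members_are_distinct` below compares the flattened parameters of every pair of members. The demo uses hand-built `nn.ModuleList`s rather than `probly`, so it runs on its own.

```python
import copy

import torch
from torch import nn


def members_are_distinct(members: nn.ModuleList) -> bool:
    """Hypothetical helper: True if no two members have identical parameters."""
    # Flatten each member's parameters into a single vector for comparison.
    flat = [torch.cat([p.detach().flatten() for p in m.parameters()]) for m in members]
    return all(
        not torch.equal(flat[i], flat[j])
        for i in range(len(flat))
        for j in range(i + 1, len(flat))
    )


torch.manual_seed(0)
template = nn.Linear(10, 1)
clones = nn.ModuleList([copy.deepcopy(template) for _ in range(3)])  # copied, NOT re-initialized
fresh = nn.ModuleList([nn.Linear(10, 1) for _ in range(3)])  # each initialized independently
print(members_are_distinct(clones))
print(members_are_distinct(fresh))
```

The plain deep copies are reported as identical (`False`), the independently initialized layers as distinct (`True`). Applied to the notebook's `ensemble_model`, the check should return `True` if the members were re-initialized as described.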

## 3. Training and Uncertainty with an Ensemble

Here we train all members together: each training step feeds the same data to every member and averages their losses. At inference time, we aggregate the members' predictions into a final output (their mean) and an uncertainty score (their variance).
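Averaging the members' losses under a single optimizer, as the training loop in this section does, still trains each member independently. The members share no parameters, so the gradient of the averaged loss with respect to one member's parameters is exactly 1/N times the gradient of that member's own loss. The sketch below checks this numerically, using two small `nn.Linear` layers as stand-ins for ensemble members (an illustrative assumption, not `probly` code).

```python
import torch
from torch import nn

torch.manual_seed(0)
x = torch.randn(16, 4)
y = torch.randn(16, 1)
# Two small stand-in members that share no parameters.
members = nn.ModuleList([nn.Linear(4, 1) for _ in range(2)])
loss_fn = nn.MSELoss()

# Gradient of the averaged ensemble loss w.r.t. member 0's parameters.
avg_loss = sum(loss_fn(m(x), y) for m in members) / len(members)
grads_avg = torch.autograd.grad(avg_loss, list(members[0].parameters()))

# Gradient of member 0's own loss, computed in isolation.
own_loss = loss_fn(members[0](x), y)
grads_own = torch.autograd.grad(own_loss, list(members[0].parameters()))

# Identical up to the constant factor 1 / num_members.
matches = all(
    torch.allclose(g_avg, g_own / len(members))
    for g_avg, g_own in zip(grads_avg, grads_own)
)
print(matches)
```

With Adam, the constant 1/N factor has little effect, because Adam rescales each update by a running estimate of the gradient's magnitude. Note also that in this loop every member sees the same full batch at every step, so the members' diversity comes only from their different initializations.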

```python
# Toy regression data: a noisy linear target
torch.manual_seed(0)
n = 128
X = torch.randn(n, 10)
true_w = torch.randn(10, 1)
y = X @ true_w + 0.1 * torch.randn(n, 1)

# Build and transform the model
model = build_mlp(in_dim=10, hidden=64, out_dim=1)
ensemble_model = ensemble(model, num_members=5)

# Training loop for an ensemble.
# A single optimizer can cover all members: they share no parameters,
# so each member's gradients come only from its own loss term.
opt = torch.optim.Adam(ensemble_model.parameters(), lr=1e-2)
loss_fn = nn.MSELoss()

for _step in range(200):
    opt.zero_grad()

    total_loss = 0.0
    # Get a prediction from each member on the same (full) batch
    for member in ensemble_model:
        pred = member(X)
        total_loss = total_loss + loss_fn(pred, y)

    # Average the loss and backpropagate
    avg_loss = total_loss / len(ensemble_model)
    avg_loss.backward()
    opt.step()


# Prediction function for an ensemble: mean and variance across members.
# (The argument is named `members` to avoid shadowing the imported `ensemble`.)
@torch.no_grad()
def ensemble_predict(
    members: nn.ModuleList,
    inputs: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    # Get a prediction from each member; no_grad already disables gradient tracking
    preds = [member(inputs) for member in members]

    stacked = torch.stack(preds, dim=0)  # [n_members, N, out_dim]
    mean = stacked.mean(dim=0)
    var = stacked.var(dim=0, unbiased=False)  # population variance across members
    return mean, var


# Note: X[:5] are training inputs, where the members typically agree closely.
mean_pred, var_pred = ensemble_predict(ensemble_model, X[:5])
print("Predictive mean (first 5):\n", mean_pred.squeeze())
print("\nPredictive variance (first 5):\n", var_pred.squeeze())
```
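The `ensemble_predict` function above loops over the members in Python. For many small members, PyTorch's `torch.func` utilities can evaluate all members in one vectorized call: `stack_module_state` stacks the members' parameters along a new leading dimension, and `torch.vmap` maps `functional_call` over that dimension. This is a sketch following the pattern of PyTorch's model-ensembling tutorial; it assumes PyTorch 2.0 or later and that all members share one architecture (true for an ensemble built from a single template). The name `vectorized_predict` and the toy members are illustrative.

```python
import copy

import torch
from torch import nn
from torch.func import functional_call, stack_module_state


@torch.no_grad()
def vectorized_predict(
    members: nn.ModuleList,
    inputs: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    # Stack every parameter/buffer across members: each gets a leading
    # [n_members] dimension. Assumes all members share one architecture.
    params, buffers = stack_module_state(list(members))
    # Data-free copy of the architecture ("meta" tensors allocate no storage);
    # functional_call swaps in the stacked parameters at call time.
    skeleton = copy.deepcopy(members[0]).to("meta")

    def call_one(p, b, x):
        return functional_call(skeleton, (p, b), (x,))

    # vmap over the member dimension of params and buffers; inputs are shared (None).
    stacked = torch.vmap(call_one, in_dims=(0, 0, None))(params, buffers, inputs)
    return stacked.mean(dim=0), stacked.var(dim=0, unbiased=False)


# Toy check against a plain Python loop over the members.
torch.manual_seed(0)
members = nn.ModuleList(
    [nn.Sequential(nn.Linear(10, 16), nn.ReLU(), nn.Linear(16, 1)) for _ in range(5)]
)
x = torch.randn(8, 10)
mean_vec, var_vec = vectorized_predict(members, x)

with torch.no_grad():
    loop_preds = torch.stack([m(x) for m in members])  # [n_members, N, out_dim]
print(torch.allclose(mean_vec, loop_preds.mean(dim=0), atol=1e-6))
```

Both approaches compute the same mean and variance; the vectorized one trades a little setup for fewer Python-level calls, which mainly pays off when members are small and numerous.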

## 4. Part A Summary

In Part A, we introduced Deep Ensembles as a robust method for uncertainty quantification based on the "wisdom of the crowd."
We saw that `probly`'s `ensemble` transformation simplifies the creation of an ensemble by automatically cloning and re-initializing a base model,
packaging the members into a convenient `nn.ModuleList`. Unlike other transformations that modify a model's internal layers, the ensemble method
creates multiple, independent models. The disagreement between their predictions provides a robust estimate of model uncertainty.

---

## Part B: Applied Ensemble Transformation
---

In **Part A**, we learned what the **Ensemble transformation** in `probly` does.
**Part B** applies it to a classification model: getting predictions from each member and visualizing the resulting uncertainty.


The full, in-depth walkthrough lives in a separate notebook and covers:

- How to apply the `ensemble` function to a real model.

- How to train all the models in the ensemble on a real dataset (FashionMNIST).

- How to use the trained ensemble to make predictions.

- Most importantly, how to use the disagreement between the models to measure uncertainty and detect out-of-distribution (OOD) data; in this case, telling clothing (FashionMNIST) apart from handwritten digits (MNIST).

It can be found here:
[FashionMNIST Out-of-Distribution Example](fashionmnist_ood_ensemble.ipynb).
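For classification, the mean-and-disagreement idea carries over to class probabilities. The sketch below shows one common recipe, not necessarily the one used in the linked notebook: average the members' softmax outputs, take the entropy of that average as the total uncertainty, and subtract the members' average entropy to isolate the part caused by disagreement (often called the mutual information). The helper `ensemble_uncertainty`, the hand-made logits, and the OOD threshold are all illustrative assumptions.

```python
import torch


def ensemble_uncertainty(member_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Hypothetical helper. member_logits: [n_members, N, n_classes].

    Returns (total, disagreement), each of shape [N].
    """
    probs = member_logits.softmax(dim=-1)  # per-member class probabilities
    mean_probs = probs.mean(dim=0)  # ensemble prediction, [N, n_classes]
    # Entropy of the averaged prediction: overall (total) uncertainty.
    total = torch.special.entr(mean_probs).sum(dim=-1)
    # Average entropy of the individual members' predictions.
    mean_member_entropy = torch.special.entr(probs).sum(dim=-1).mean(dim=0)
    # The remainder is due to members disagreeing (mutual-information term).
    disagreement = total - mean_member_entropy
    return total, disagreement


# Made-up logits for 3 members, 2 inputs, 3 classes.
agree = torch.tensor([[8.0, 0.0, 0.0]] * 3)  # every member confidently picks class 0
disagree = torch.tensor([[8.0, 0.0, 0.0], [0.0, 8.0, 0.0], [0.0, 0.0, 8.0]])  # each picks a different class
logits = torch.stack([agree, disagree], dim=1)  # [3 members, 2 inputs, 3 classes]

total, disagreement = ensemble_uncertainty(logits)
print("total:", total)
print("disagreement:", disagreement)

# A possible OOD rule: flag inputs whose disagreement exceeds a threshold
# (0.5 is an arbitrary assumption; in practice it would be tuned on held-out data).
is_ood = disagreement > 0.5
print("flagged as OOD:", is_ood)
```

The first input (all members agree) gets a disagreement near zero, while the second (each member confidently picks a different class) gets a disagreement close to ln 3 ≈ 1.1, so only the second is flagged.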

## Final Summary: Ensemble Transformation Tutorial
You have now learned the fundamentals of the **Ensemble Transformation** in `probly`: how to automatically create an `nn.ModuleList` of independent models and use their disagreement to capture predictive uncertainty. We saw that this gives us both a robust mean prediction and a measure of the model's confidence.

To put this theory into practice, the [FashionMNIST OOD Example](fashionmnist_ood_ensemble.ipynb) will show you how to apply this technique to a real-world problem: building a classifier that knows when it's uncertain and can detect out-of-distribution data.