# Sub-Ensembles for Fast Uncertainty Estimation

## Introduction

A common approach to obtaining high-quality uncertainty estimates is the **Full Ensemble**, where multiple neural networks are trained independently and their predictions are combined. While Full Ensembles are simple and highly effective, they come with substantial computational and memory costs: each model must be trained and stored in full. This makes them difficult to use in resource-constrained environments or during fast model development cycles.

To address these limitations, recent research
> [Deep Sub-Ensembles for Fast Uncertainty Estimation in Image Classification](https://arxiv.org/pdf/1910.08168)

has proposed **Sub-Ensembles**, a technique that aims to retain most of the predictive benefits of Full Ensembles while substantially reducing training time and inference overhead. Instead of training multiple full models, Sub-Ensembles share a large portion of the network (the backbone) and create several lightweight, partially independent branches (heads). Each branch acts as an ensemble member, providing diversity at a fraction of the computational cost.

The goal of the notebook is to:
- introduce the core idea behind (Sub-)Ensembles,
- evaluate their ability to estimate predictive uncertainty,
- demonstrate how they can be implemented in practice.

## What are Full Ensembles?

### Ensembles

Ensembles offer a method in which multiple models are trained independently and their predictions are combined to improve uncertainty quantification. The idea is that by aggregating the outputs of several models, we can reduce variance, improve robustness and achieve better generalization than with a single model. Ensembles are particularly effective when individual models are prone to overfitting or high variance.
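
As a minimal sketch of the combination step, the snippet below averages the class probabilities of three hypothetical ensemble members for a single input. The probability values are made up for illustration and are not taken from any experiment in this notebook.

```python
import numpy as np

# Hypothetical softmax outputs of three independently trained members
# for one input and three classes (illustrative values only).
member_probs = np.array(
    [
        [0.70, 0.20, 0.10],
        [0.50, 0.30, 0.20],
        [0.60, 0.10, 0.30],
    ]
)

# Ensemble prediction: average the probabilities over the members (axis 0).
ensemble_probs = member_probs.mean(axis=0)

# Disagreement between members: per-class variance across the members.
member_variance = member_probs.var(axis=0)

predicted_class = int(ensemble_probs.argmax())
print("Ensemble probabilities:", ensemble_probs)
print("Per-class variance:", member_variance)
print("Predicted class:", predicted_class)
```

A larger variance for a class means the members disagree more about it, which is the signal an ensemble uses to express model uncertainty.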

In [None]:
from matplotlib import patches
import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(18, 5))
ax.set_xlim(0, 18)
ax.set_ylim(0, 5)
ax.axis("off")


def draw_block(x: float, y: float, width: float, height: float, text: str, color: str = "lightblue") -> None:
    rect = patches.FancyBboxPatch((x, y), width, height, boxstyle="round,pad=0.02", edgecolor="black", facecolor=color)
    ax.add_patch(rect)
    ax.text(x + width / 2, y + height / 2, text, ha="center", va="center", fontsize=10)


def draw_arrow(x1: float, y1: float, x2: float, y2: float) -> None:
    ax.annotate("", xy=(x2, y2), xytext=(x1, y1), arrowprops={"arrowstyle": "->", "lw": 2, "color": "black"})


# -----------------------------
# Standard Model
draw_block(0.5, 2, 1, 1, "Input (Image)", color="skyblue")
draw_block(2.5, 2, 1, 1, "Base Model", color="lightgreen")
draw_block(4.5, 2, 1, 1, "Output", color="lightcoral")
draw_arrow(1.5, 2.5, 2.5, 2.5)
draw_arrow(3.5, 2.5, 4.5, 2.5)
ax.text(3.0, 4, "Standard Model", fontsize=12, fontweight="bold", ha="center")

# -----------------------------
# Full Ensemble
ensemble_x = 12
ax.text(ensemble_x + 0.5, 4.5, "Full Ensemble", fontsize=12, fontweight="bold", ha="center")

# Input
draw_block(ensemble_x - 2.0, 2.0, 1, 1, "Input (Image)", color="skyblue")

# Three independent models
for y in [1.25, 2.5, 3.75]:
    draw_arrow(ensemble_x - 1.0, 2.5, ensemble_x, y)
for i, y in enumerate([3.2, 2, 0.8]):
    draw_block(ensemble_x, y, 1, 1, f"Base Model\nT{i + 1}", color="lightgreen")

# Combination
for y in [1.25, 2.5, 3.75]:
    draw_arrow(ensemble_x + 1.0, y, ensemble_x + 2.0, 2.5)
draw_block(ensemble_x + 2.0, 2.0, 1, 1, "Prediction", color="white")

plt.tight_layout()
plt.show()

## What are Sub-Ensembles?

A **Sub-Ensemble** generates multiple subnetworks from a single base model. This is achieved by dividing a neural network into two parts, the trunk network $T$ and the task network $K$, so that the full output for an input $x$ is $K(T(x))$. The Sub-Ensemble keeps a single trunk and places several task networks $K_1, \dots, K_M$ on top of it.
Since all subnetworks share the same feature extractor (**backbone**), the trunk is trained and evaluated only once, which makes training and the forward pass much faster than with multiple fully independent networks.
Each subnetwork produces its own prediction, and the final output is obtained by averaging across all subnetworks. The **model uncertainty** is estimated by the variance of the predictions across the subnetworks.
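
The following sketch illustrates $K(T(x))$ with one shared trunk and three task networks, using plain NumPy. All weights are random and the layer sizes are arbitrary assumptions; the point is only to show that $T(x)$ is computed once and that the mean and variance are taken across the heads.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical shared trunk T: one linear layer followed by ReLU.
W_trunk = rng.normal(size=(4, 8))


def trunk(x: np.ndarray) -> np.ndarray:
    return np.maximum(x @ W_trunk, 0.0)


def softmax(z: np.ndarray) -> np.ndarray:
    z = z - z.max(axis=-1, keepdims=True)  # shift for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


# Hypothetical task networks K_1..K_3: independent linear heads.
heads = [rng.normal(size=(8, 3)) for _ in range(3)]

x = rng.normal(size=(5, 4))  # batch of 5 inputs with 4 features
features = trunk(x)  # T(x), computed once and shared by all heads

# K_i(T(x)) for every head, stacked to shape (n_heads, batch, classes)
head_probs = np.stack([softmax(features @ W) for W in heads])

mean_probs = head_probs.mean(axis=0)  # final prediction, shape (5, 3)
uncertainty = head_probs.var(axis=0)  # variance across heads, shape (5, 3)

print("Mean prediction:\n", mean_probs)
print("Variance across heads:\n", uncertainty)
```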

Overall, they provide:
- reduced memory and training cost compared to standard Full Ensembles,
- only a small parameter overhead over a single model, while aiming to improve predictive performance and the reliability of uncertainty estimates,
- a practical solution for robust uncertainty quantification at scale.
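
To make the memory argument concrete, the sketch below counts parameters for both variants. The function names and the trunk and head sizes are hypothetical and chosen only for illustration.

```python
def full_ensemble_params(trunk_params: int, head_params: int, n_members: int) -> int:
    """Every member stores its own trunk and its own head."""
    return n_members * (trunk_params + head_params)


def sub_ensemble_params(trunk_params: int, head_params: int, n_heads: int) -> int:
    """One shared trunk plus one head per ensemble member."""
    return trunk_params + n_heads * head_params


# Hypothetical sizes: a 1,000,000-parameter trunk and a 10,000-parameter head.
trunk_params, head_params, n = 1_000_000, 10_000, 5

full = full_ensemble_params(trunk_params, head_params, n)
sub = sub_ensemble_params(trunk_params, head_params, n)

print(f"Full Ensemble: {full:,} parameters")
print(f"Sub-Ensemble:  {sub:,} parameters")
print(f"Ratio: {full / sub:.2f}x")
```

The saving grows with the share of parameters that sits in the trunk, so the choice of split point directly trades memory against the independence of the members.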


In [None]:
from matplotlib import patches
import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(20, 5))
ax.set_xlim(0, 20)
ax.set_ylim(0, 5)
ax.axis("off")


def draw_block(x: float, y: float, width: float, height: float, text: str, color: str = "lightblue") -> None:
    rect = patches.FancyBboxPatch((x, y), width, height, boxstyle="round,pad=0.02", edgecolor="black", facecolor=color)
    ax.add_patch(rect)
    ax.text(x + width / 2, y + height / 2, text, ha="center", va="center", fontsize=10)


def draw_arrow(x1: float, y1: float, x2: float, y2: float) -> None:
    ax.annotate("", xy=(x2, y2), xytext=(x1, y1), arrowprops={"arrowstyle": "->", "lw": 2, "color": "black"})


# -----------------------------
# Full Ensemble
ensemble_x = 2.5
ax.text(ensemble_x + 1.5, 4.5, "Full Ensemble", fontsize=12, fontweight="bold", ha="center")

# Input
draw_block(ensemble_x - 2.0, 2.0, 1, 1, "Input (Image)", color="skyblue")

# Three independent models
for y in [1.25, 2.5, 3.75]:
    draw_arrow(ensemble_x - 1.0, 2.5, ensemble_x, y)
for i, y in enumerate([3.2, 2, 0.8]):
    draw_block(ensemble_x, y, 1, 1, f"Trunk T{i + 1}", color="lightgreen")
    draw_block(ensemble_x + 2.0, y, 1, 1, f"Task K{i + 1}", color="lightcoral")
    draw_arrow(ensemble_x + 1.0, y + 0.5, ensemble_x + 2.0, y + 0.5)

# Combination
for y in [1.25, 2.5, 3.75]:
    draw_arrow(ensemble_x + 3.0, y, ensemble_x + 4.0, 2.5)
draw_block(ensemble_x + 4.0, 2.0, 1, 1, "Combination", color="brown")

# Prediction
draw_arrow(ensemble_x + 5.0, 2.5, ensemble_x + 6.0, 2.5)
draw_block(ensemble_x + 6.0, 2.0, 1, 1, "Prediction", color="white")

# -----------------------------
# SubEnsemble
sub_ensemble_x = 10.0
ax.text(sub_ensemble_x + 4.5, 4.5, "Sub-Ensemble", fontsize=12, fontweight="bold", ha="center")

# Input
draw_block(sub_ensemble_x, 2.0, 1, 1, "Input (Image)", color="skyblue")

# Shared Feature Extractor
draw_arrow(sub_ensemble_x + 1.0, 2.5, sub_ensemble_x + 2.0, 2.5)
draw_block(sub_ensemble_x + 2.0, 2.0, 1, 1, "Shared\nfeature\nextractor", color="lightgreen")

# Three task models
for y in [1.25, 2.5, 3.75]:
    draw_arrow(sub_ensemble_x + 3.0, 2.5, sub_ensemble_x + 4.0, y)
for i, y in enumerate([3.2, 2, 0.8]):
    draw_block(sub_ensemble_x + 4.0, y, 1, 1, f"Task K{i + 1}", color="lightcoral")

# Combination
for y in [1.25, 2.5, 3.75]:
    draw_arrow(sub_ensemble_x + 5.0, y, sub_ensemble_x + 6.0, 2.5)
draw_block(sub_ensemble_x + 6.0, 2.0, 1, 1, "Combination", color="brown")

# Prediction
draw_arrow(sub_ensemble_x + 7.0, 2.5, sub_ensemble_x + 8.0, 2.5)
draw_block(sub_ensemble_x + 8.0, 2.0, 1, 1, "Prediction", color="white")

plt.tight_layout()
plt.show()

## Code Demo

This snippet shows how a simple sequential base model is repurposed into a Sub-Ensemble model. The base model is split into a shared feature extractor and multiple independent heads. The output is obtained by averaging the head predictions during the forward pass.

In [None]:
import copy

import torch
from torch import nn

base_model = nn.Sequential(
    nn.Linear(2, 2),
    nn.Linear(2, 2),
    nn.Linear(2, 2),
)


class SubEnsemble(nn.Module):
    """Sub-Ensemble model that repurposes a base model into multiple heads sharing the same feature extractor.

    Attributes:
        n_heads: int, Number of heads.
        feature_extractor: nn.Sequential, Feature extractor.
        heads: nn.ModuleList, List of heads.
    """

    def __init__(self, model: nn.Module, n_heads: int = 3) -> None:
        """Initializes the SubEnsemble by splitting the provided base model.

        Args:
            model: nn.Module, the sequential base model.
            n_heads: int, optional, number of ensemble heads

        Raises:
            ValueError: If the provided model is not a nn.Sequential or contains fewer than two layers.
        """
        super().__init__()

        self.n_heads = n_heads

        if not isinstance(model, nn.Sequential) or len(model) < 2:
            msg = "Base model must be an nn.Sequential with at least 2 layers."
            raise ValueError(msg)

        # Shared feature extractor = all layers except last
        self.feature_extractor = nn.Sequential(*model[:-1])

        # Last layer of the base model
        last_layer = model[-1]

        # Create multiple head copies
        self.heads = nn.ModuleList()
        for _ in range(n_heads):
            new_head = copy.deepcopy(last_layer)

            # Reinitialize parameters
            for layer in new_head.modules():
                if hasattr(layer, "reset_parameters"):
                    layer.reset_parameters()

            self.heads.append(new_head)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass through the shared feature extractor and all ensemble heads.

        Returns the mean prediction across heads.

        Args:
            x: torch.Tensor, input tensor.

        Returns:
            torch.Tensor, averaged predictions across all ensemble heads.
        """
        # Shared feature Extractor
        features = self.feature_extractor(x)

        # Each head produces a prediction
        outputs = [head(features) for head in self.heads]

        # Stack + Mean
        return torch.stack(outputs).mean(dim=0)


# Transform base_model into Sub-Ensemble with 3 heads
sub_ensemble = SubEnsemble(base_model, n_heads=3)

# Output
print("Base model:", base_model)
print("SubEnsemble:", sub_ensemble)
print("SubEnsemble type:", type(sub_ensemble))
print("Is Sub-Ensemble?", isinstance(sub_ensemble, SubEnsemble))

Note that the heads can also be created with the existing `generate_torch_ensemble` function from probly:

In [None]:
from src.probly.transformation.ensemble.torch import generate_torch_ensemble

head = nn.Linear(2, 2)
num_heads = 3

heads = generate_torch_ensemble(head, num_heads)

print("(heads):", heads)

SubEnsemble can then be defined like this:

In [None]:
import torch
from torch import nn

from src.probly.transformation.ensemble.torch import generate_torch_ensemble

base_model = nn.Sequential(
    nn.Linear(2, 2),
    nn.Linear(2, 2),
    nn.Linear(2, 2),
)


class SubEnsemble(nn.Module):
    """Sub-Ensemble model that repurposes a base model into multiple heads sharing the same feature extractor.

    Attributes:
        n_heads: int, number of heads.
        feature_extractor: nn.Sequential, shared feature extractor.
        heads: nn.ModuleList, list of heads.
    """

    def __init__(self, model: nn.Module, n_heads: int = 3) -> None:
        """Initializes the SubEnsemble by splitting the provided base model.

        Args:
            model: nn.Module, the sequential base model.
            n_heads: int, optional, number of ensemble heads

        Raises:
            ValueError: If the provided model is not a nn.Sequential or contains fewer than two layers.
        """
        super().__init__()

        self.n_heads = n_heads

        if not isinstance(model, nn.Sequential) or len(model) < 2:
            msg = "Base model must be an nn.Sequential with at least 2 layers."
            raise ValueError(msg)

        # Create feature extractor
        self.feature_extractor = nn.Sequential(*model[:-1])

        # Create n_heads heads
        self.heads = generate_torch_ensemble(model[-1], n_heads)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass through the shared feature extractor and all ensemble heads.

        Returns the mean prediction across heads.

        Args:
            x: torch.Tensor, input tensor.

        Returns:
            torch.Tensor, averaged predictions across all ensemble heads.
        """
        features = self.feature_extractor(x)
        outputs = [h(features) for h in self.heads]
        return torch.stack(outputs).mean(dim=0)


# Transform base model
sub_ensemble = SubEnsemble(base_model, n_heads=3)

# Output
print("SubEnsemble:", sub_ensemble)

## Evaluation of Sub-Ensembles

To evaluate the benefits of Sub-Ensembles, we conducted a series of experiments on the CIFAR-10 dataset. We compare a traditionally trained single network with both a Full Ensemble and a Sub-Ensemble.

The experiment reports accuracy, confidence, negative log-likelihood (NLL) and training time. This setup provides a clear, reproducible framework to assess the performance of Sub-Ensembles relative to other probabilistic modelling techniques.

### Understanding the metrics

1. **Accuracy**
    - The fraction of correctly classified samples over the total number of samples.
    - Higher accuracy indicates better overall predictive performance.
2. **Confidence**
    - The average predicted probability assigned to the predicted class.
    - Higher confidence means the model is more certain about its predictions, regardless of whether they are correct.
3. **Negative Log-Likelihood** (NLL)
    - Measures how well the predicted probabilities match the true labels.
    - Lower NLL indicates the model assigns higher probability to the correct labels, capturing both correctness and confidence.
4. **Training Time**
    - The total time (in s) the model takes to complete training.

**Keeping it simple**
- **Accuracy:** how often the model gets the label right,
- **Confidence:** how sure the model is about its predictions,
- **Negative log-likelihood:** how well the model's probabilities match the true label,
- **Training Time:** how long the model took to complete training (in seconds).
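
The first three metrics can be computed directly from an array of predicted probabilities. The sketch below uses a tiny made-up batch; the probabilities and labels are illustrative assumptions, not results from the experiments.

```python
import numpy as np

# Hypothetical predicted probabilities for 4 samples and 3 classes.
probs = np.array(
    [
        [0.7, 0.2, 0.1],
        [0.1, 0.8, 0.1],
        [0.3, 0.3, 0.4],
        [0.6, 0.3, 0.1],
    ]
)
labels = np.array([0, 1, 2, 1])  # hypothetical true labels

preds = probs.argmax(axis=1)

# Accuracy: fraction of samples whose predicted class equals the label.
accuracy = (preds == labels).mean()

# Confidence: mean probability assigned to the predicted class.
confidence = probs.max(axis=1).mean()

# NLL: mean negative log-probability assigned to the true class.
nll = -np.log(probs[np.arange(len(labels)), labels]).mean()

print(f"Accuracy:   {accuracy:.3f}")
print(f"Confidence: {confidence:.3f}")
print(f"NLL:        {nll:.3f}")
```

Note that the last sample is predicted wrongly with a probability of 0.6, which lowers the accuracy and raises the NLL while still contributing a fairly high value to the confidence.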

### The CIFAR experiment

We used a lightweight model designed for CIFAR-10 classification. It consists of two stages:
1. Features:
Two convolutional blocks that map the input images to 32-channel feature maps, extracting low-level features.

2. Classifier:
Two additional convolutional blocks with increasing channel depth (32 $\to$ 64 $\to$ 128) and spatial downsampling via stride 2. The resulting feature maps are globally pooled, flattened and passed through a linear layer to produce predictions for the 10 classes.

Each model was trained for 10 epochs. The Full Ensemble used 5 members and the Sub-Ensemble used 5 heads.
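
The architecture described above could look roughly like the following PyTorch sketch. The kernel size, padding, batch normalization, the helper `conv_block` and the variable names are assumptions made for illustration; the description only fixes the channel widths, the stride-2 downsampling, the global pooling and the final linear layer.

```python
import torch
from torch import nn


def conv_block(in_ch: int, out_ch: int, stride: int = 1) -> nn.Sequential:
    # Assumed block layout: 3x3 convolution, batch norm, ReLU.
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=stride, padding=1),
        nn.BatchNorm2d(out_ch),
        nn.ReLU(),
    )


# Stage 1: two convolutional blocks producing 32-channel feature maps.
features = nn.Sequential(conv_block(3, 32), conv_block(32, 32))

# Stage 2: two stride-2 blocks (32 -> 64 -> 128), global pooling, linear layer.
classifier = nn.Sequential(
    conv_block(32, 64, stride=2),
    conv_block(64, 128, stride=2),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(128, 10),
)

cifar_model = nn.Sequential(features, classifier)

x = torch.randn(4, 3, 32, 32)  # batch of 4 CIFAR-sized RGB images
logits = cifar_model(x)
print(logits.shape)
```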

In [None]:
import matplotlib.pyplot as plt
import numpy as np

results = {
    "Base": {
        "Accuracy": 0.7012999999999999,
        "Confidence": 0.7491816083590189,
        "Negative log-likelihood": 0.8654566344579061,
        "train_time": 111.25705416997273,
    },
    "Full Ensemble": {
        "Accuracy": 0.7285333333333334,
        "Confidence": 0.6695056756337484,
        "Negative log-likelihood": 0.7785670175552367,
        "train_time": 565.2635641892751,
    },
    "Sub-Ensemble": {
        "Accuracy": 0.7333333333333334,
        "Confidence": 0.4354797601699829,
        "Negative log-likelihood": 1.0467662965138753,
        "train_time": 229.0507171948751,
    },
}

colors = ["tab:green", "tab:blue", "tab:purple"]
models = list(results.keys())

# Plot 1: Accuracy, Confidence & NLL Comparison
metrics = ["Accuracy", "Confidence", "Negative log-likelihood"]
x = np.arange(len(metrics))
bar_width = 0.15

fig, ax = plt.subplots(figsize=(8, 5))
for i, m in enumerate(models):
    values = [float(results[m][met]) for met in metrics]
    ax.bar(x + i * bar_width, values, bar_width, color=colors[i], label=m)

ax.set_xticks(x + bar_width * (len(models) - 1) / 2)
ax.set_xticklabels(metrics)
ax.set_ylabel("Value")
ax.set_title("Model Comparison: Accuracy, Confidence & NLL")
ax.legend()
plt.show()

# Plot 2: Training Time Comparison
fig, ax = plt.subplots(figsize=(8, 4))
ax.bar(models, [float(results[m]["train_time"]) for m in models], color=colors)
ax.set_ylabel("Training Time (s)")
ax.set_title("Training Time Comparison")
plt.show()

- The **Base model** reaches an accuracy of 70.1%, with relatively high confidence (0.75) and a negative log-likelihood of 0.87, indicating that while it performs reasonably well, it may be overconfident in its predictions.

- The **Full Ensemble** model achieves higher accuracy (72.9%), with moderate confidence (0.67) and the lowest negative log-likelihood (0.78), showing that combining multiple models improves predictive performance and uncertainty estimation.

- The **Sub-Ensemble** model strikes a balance between efficiency and performance: it reaches an accuracy of 73.3%, but with lower confidence (0.44) and a higher negative log-likelihood (1.05), while requiring much less training time than the Full Ensemble (speedup of ~2.47).
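
The speedup and the gap between confidence and accuracy follow directly from the reported numbers. The sketch below recomputes them from rounded copies of the values in the results dictionary above; the name `results_summary` and the rounding are choices made for this illustration.

```python
# Values copied (rounded) from the results dictionary above.
results_summary = {
    "Base": {"accuracy": 0.7013, "confidence": 0.7492, "train_time": 111.26},
    "Full Ensemble": {"accuracy": 0.7285, "confidence": 0.6695, "train_time": 565.26},
    "Sub-Ensemble": {"accuracy": 0.7333, "confidence": 0.4355, "train_time": 229.05},
}

# Training-time speedup of the Sub-Ensemble over the Full Ensemble.
speedup = results_summary["Full Ensemble"]["train_time"] / results_summary["Sub-Ensemble"]["train_time"]
print(f"Sub-Ensemble speedup over Full Ensemble: {speedup:.2f}x")

# Positive gap: overconfident on average; negative gap: underconfident.
gaps = {name: r["confidence"] - r["accuracy"] for name, r in results_summary.items()}
for name, gap in gaps.items():
    print(f"{name}: confidence - accuracy = {gap:+.3f}")
```

On these numbers the base model is slightly overconfident, while the Sub-Ensemble is clearly underconfident, which is consistent with its higher negative log-likelihood.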

#### Conclusion

Full Ensembles deliver the most robust predictions and the most reliable uncertainty estimates, at a high computational cost.
Sub-Ensembles achieve a strong balance between performance and efficiency. They offer a **practical middle ground**: comparable accuracy at a fraction of the training time, although in this experiment their negative log-likelihood was higher than that of the Full Ensemble, which suggests less reliable uncertainty estimates.

## Live experiment

Using sklearn, we can generate a synthetic dataset for a demonstration. This dataset of `3,000 samples` with `20 features` is generated using `make_classification`, followed by a train-validation split and standardization. A small fully connected neural network serves as the base architecture, which is then transformed into different model variants: Base, Full Ensemble (`3 members`) and Sub-Ensemble (`3 heads`). All models are trained for `100 epochs` and evaluated on the validation set in terms of accuracy, confidence, negative log-likelihood and training time.

#### Imports

Necessary imports:

In [None]:
import time

import matplotlib.pyplot as plt
import numpy as np
from sklearn.datasets import make_classification
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

from src.probly.transformation.ensemble.torch import generate_torch_ensemble

#### Models

Let's begin with a base model:

In [None]:
model = nn.Sequential(
    nn.Linear(20, 16),
    nn.ReLU(),
    nn.Linear(16, 16),
    nn.ReLU(),
    nn.Linear(16, 2),
)

Next, we define the `Ensemble` class using `generate_torch_ensemble` from probly:

In [None]:
class Ensemble(nn.Module):
    """Full Ensemble model that replicates a base model into multiple ensemble members."""

    def __init__(self, model: nn.Module, n_members: int = 3) -> None:
        """Initializes the Ensemble by generating n_members members from the provided base model."""
        super().__init__()
        self.models = generate_torch_ensemble(model, n_members)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass with Stack + Mean for all ensemble members."""
        outputs = [m(x) for m in self.models]
        return torch.stack(outputs).mean(dim=0)

Now we can create the model variants we need:

In [None]:
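import copy

base_model = model
ensemble_model = Ensemble(model=model, n_members=3)
# SubEnsemble reuses the layer objects of the model it receives, so we pass a deep copy.
# Otherwise its feature extractor would share weights with base_model.
sub_ensemble_model = SubEnsemble(model=copy.deepcopy(model), n_heads=3)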

#### Data

On to generating the data!

In [None]:
X, y = make_classification(
    n_samples=3000,
    n_features=20,
    n_informative=5,
    n_redundant=5,
    n_classes=2,
    class_sep=0.5,
    flip_y=0.1,
    random_state=0,
)
X = X.astype("float32")
y = y.astype("int64")

X_train, X_val, y_train, y_val = train_test_split(
    X,
    y,
    test_size=0.2,
    random_state=0,
    stratify=y,
)

scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_val = scaler.transform(X_val)

train_dataset = TensorDataset(torch.from_numpy(X_train), torch.from_numpy(y_train))
val_dataset = TensorDataset(torch.from_numpy(X_val), torch.from_numpy(y_val))

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=256, shuffle=False)

criterion = nn.CrossEntropyLoss()

#### What about training?

We begin with a training function for a single model:

In [None]:
def train_single(
    model: nn.Module,
    loader: DataLoader = train_loader,
    criterion: nn.CrossEntropyLoss = criterion,
    epochs: int = 100,
) -> None:
    model.train()

    single_optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    for _ in range(epochs):
        for xb, yb in loader:
            single_optimizer.zero_grad()
            preds = model(xb)
            loss = criterion(preds, yb)
            loss.backward()
            single_optimizer.step()

To train an Ensemble, we can call `train_single` for each ensemble member like this:

In [None]:
def train_ensemble(
    ensemble: Ensemble,
    loader: DataLoader = train_loader,
    criterion: nn.CrossEntropyLoss = criterion,
    epochs: int = 100,
) -> None:
    # Train every ensemble member independently
    for member in ensemble.models:
        train_single(member, loader, criterion, epochs)

Finally, we train the SubEnsemble. We first train the feature_extractor together with the first head, while the other heads are left untouched. We then freeze the feature_extractor and train the remaining heads.

In [None]:
def train_subensemble(
    subensemble: SubEnsemble,
    loader: DataLoader = train_loader,
    criterion: nn.CrossEntropyLoss = criterion,
    epochs: int = 100,
) -> None:
    # Phase 1: train the feature extractor together with the first head.
    # The other heads are not passed to the optimizer, so they stay unchanged.
    subensemble.feature_extractor.train()
    subensemble.heads[0].train()

    optimizer_fe = torch.optim.Adam(
        list(subensemble.feature_extractor.parameters()) + list(subensemble.heads[0].parameters()),
        lr=0.001,
    )
    for _ in range(epochs):
        for xb, yb in loader:
            optimizer_fe.zero_grad()
            features = subensemble.feature_extractor(xb)
            preds = subensemble.heads[0](features)
            loss = criterion(preds, yb)
            loss.backward()
            optimizer_fe.step()

    # Phase 2: freeze the feature extractor and train the remaining heads.
    subensemble.feature_extractor.eval()
    remaining_heads = subensemble.heads[1:]
    if len(remaining_heads) == 0:
        return
    for head in remaining_heads:
        head.train()

    # The optimizer must cover the parameters of all remaining heads
    optimizer_head = torch.optim.Adam(
        [p for head in remaining_heads for p in head.parameters()],
        lr=0.001,
    )
    for _ in range(epochs):
        for xb, yb in loader:
            optimizer_head.zero_grad()
            # No gradients flow into the frozen feature extractor
            with torch.no_grad():
                features = subensemble.feature_extractor(xb)
            # Sum the losses so that every remaining head receives a gradient
            loss = sum(criterion(head(features), yb) for head in remaining_heads)
            loss.backward()
            optimizer_head.step()

#### Evaluation

For the evaluation, we define:

In [None]:
def evalu(model: nn.Module, loader: DataLoader = val_loader) -> dict[str, float]:
    model.eval()
    all_probs = []
    all_labels = []
    nll = 0.0

    with torch.no_grad():
        for xb, yb in loader:
            preds = model(xb)
            # Accumulate the summed NLL; it is averaged over all samples below
            nll += F.cross_entropy(preds, yb, reduction="sum").item()
            all_probs.append(F.softmax(preds, dim=1))
            all_labels.append(yb)

    all_probs = torch.cat(all_probs)
    all_labels = torch.cat(all_labels)
    all_preds = all_probs.argmax(dim=1)

    acc = accuracy_score(all_labels.cpu(), all_preds.cpu())
    # Mean probability of the predicted class over the whole dataset, not just the last batch
    confidence = all_probs.max(dim=1).values.mean().item()

    return {
        "Accuracy": acc,
        "Confidence": confidence,
        "Negative log-likelihood": nll / len(all_labels),
    }

#### Let's get to it!

We train each model, calling the appropriate function.

In [None]:
criterion = nn.CrossEntropyLoss()

start_time = time.time()
train_single(model, train_loader, criterion, epochs=100)
base_time = time.time() - start_time

start_time = time.time()
train_ensemble(ensemble_model, train_loader, criterion, epochs=100)
ensemble_time = time.time() - start_time

start_time = time.time()
train_subensemble(sub_ensemble_model, train_loader, criterion, epochs=100)
subensemble_time = time.time() - start_time

To evaluation and beyond!

In [None]:
models = [
    ("Base", base_model, base_time),
    ("Full Ensemble", ensemble_model, ensemble_time),
    ("Sub-Ensemble", sub_ensemble_model, subensemble_time),
]

# Evaluate every model on the validation set and attach its training time
results = {}
for name, trained_model, train_time in models:
    res = evalu(trained_model, val_loader)
    res["train_time"] = train_time
    results[name] = res

### The results

In [None]:
colors = ["tab:green", "tab:blue", "tab:purple"]
models = list(results.keys())

# Plot 1: Accuracy & Confidence Comparison
metrics = ["Accuracy", "Confidence"]
x = np.arange(len(metrics))
bar_width = 0.15

fig, ax = plt.subplots(figsize=(8, 5))
for i, m in enumerate(models):
    values = [float(results[m][met]) for met in metrics]
    ax.bar(x + i * bar_width, values, bar_width, color=colors[i], label=m)

ax.set_xticks(x + bar_width * (len(models) - 1) / 2)
ax.set_xticklabels(metrics)
ax.set_ylabel("Value")
ax.set_title("Model Comparison: Accuracy & Confidence")
ax.legend()
plt.show()

# Plot 2: Training Time Comparison
fig, ax = plt.subplots(figsize=(8, 5))
ax.bar(models, [float(results[m]["train_time"]) for m in models], color=colors)
ax.set_ylabel("Training Time (s)")
ax.set_title("Training Time Comparison")
plt.show()

### Disclaimer

Keep in mind that this experiment uses synthetic data, so the results can vary between runs.

Furthermore, this live experiment was included primarily to show how Sub-Ensembles can be created and to give a rough impression of their possible advantages over Full Ensembles.

### Anyways, take with a grain of salt

We should see these trends:
- **Accuracy:** highest for Full Ensemble, with Sub-Ensemble next.
- **Confidence:** lowest for Sub-Ensemble, while Full Ensemble is closer to the base model.
- **Training Time:** highest for Full Ensemble, while Sub-Ensemble achieves a speedup of roughly 1.5x to 3.0x.

Feel free to play around with these values (they were kept low to improve the runtime of this notebook):
- Models: `n_members`, `n_heads`
- Data: `n_samples`, `batch_size`
- Training: `epochs`