# **Batch Ensemble Networks**

In this notebook, we:

A) Explain the idea behind **Batch Ensembles**.

B) Create a small Multi-Layer-Perceptron (MLP).

C) Train the **MLP** as an Ensemble and BatchEnsemble network on CIFAR10.

D) Compare **accuracy** and **speed**.


## A) Introduction: What are Batch Ensembles?

**Batch Ensembles** are a way to efficiently approximate an ensemble of neural networks. Traditional ensembles require training and storing multiple independent networks, which is expensive in both memory and computation.

Key ideas:
- Use **shared base weights (and biases)** for all ensemble members.
- Introduce **rank-1 multiplicative factors** for each member.
- Much **faster and memory-efficient** than classic ensembles.

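Mathematically, the classic forward pass of a linear layer

$$y = W x + b$$

becomes, for ensemble member $i$,

$$y_i = (W (x \circ s_i)) \circ r_i + b$$

where $s_i$ and $r_i$ are the rank-1 vectors of ensemble member $i$, $W x$ is the usual matrix-vector product, and $\circ$ denotes element-wise multiplication. Equivalently, member $i$ uses the weight matrix $W \circ (r_i s_i^T)$, which is why $r_i$ and $s_i$ are called rank-1 factors.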
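The two ways of writing a member's forward pass (scaling the input and output vectors, or scaling the shared weight matrix by a rank-1 matrix) can be checked numerically. The following is a minimal sketch with made-up dimensions and random values; the variable names are chosen for illustration and are not part of *probly*.

```python
import torch

torch.manual_seed(0)
in_dim, out_dim = 6, 4

# Shared base parameters (float64 keeps the comparison numerically tight).
W = torch.randn(out_dim, in_dim, dtype=torch.float64)
b = torch.randn(out_dim, dtype=torch.float64)
x = torch.randn(in_dim, dtype=torch.float64)

# Rank-1 vectors of one hypothetical ensemble member.
s_i = torch.randn(in_dim, dtype=torch.float64)   # scales the input
r_i = torch.randn(out_dim, dtype=torch.float64)  # scales the output

# Cheap form: scale the input, apply the shared weights, scale the output.
y_fast = (W @ (x * s_i)) * r_i + b

# Explicit form: build the member's own weight matrix W * (r_i s_i^T).
W_i = W * torch.outer(r_i, s_i)
y_explicit = W_i @ x + b

print(torch.allclose(y_fast, y_explicit))
```

The cheap form never materializes a per-member weight matrix, which is where the memory savings come from.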


## What does the transformation do?

### The parameters

- num_members: The number of ensemble members to create.
- s_mean: The mean used to initialize the input modulation factor s.
- s_std: The standard deviation used to initialize the input modulation factor s.
- r_mean: The mean used to initialize the output modulation factor r.
- r_std: The standard deviation used to initialize the output modulation factor r.

### The layers

With these parameters we can transform:
- Linear layer into BatchEnsembleLinear layer
- Conv2d layer into BatchEnsembleConv2d layer

The transformation keeps the dimensions and base weights, while adding rank-1 factors:
- **s:** scales the input dimension per member
- **r:** scales the output dimension per member

The base weights (weight) and bias (bias) are shared across all members, keeping memory usage minimal. The differences between members arise solely from their individual scaling factors **s** and **r**.
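To make the description above concrete, here is a minimal sketch of what such a layer could look like. It is not *probly*'s implementation: the class name `TinyBatchEnsembleLinear`, the default initialization values, the use of a normal distribution for initialization, and the `[num_members, batch, features]` output layout are all assumptions made for illustration.

```python
import torch
from torch import nn


class TinyBatchEnsembleLinear(nn.Module):
    """Illustrative BatchEnsemble linear layer (hypothetical, not probly's class)."""

    def __init__(
        self,
        in_features: int,
        out_features: int,
        num_members: int,
        s_mean: float = 1.0,  # assumed default
        s_std: float = 0.01,  # assumed default
        r_mean: float = 1.0,  # assumed default
        r_std: float = 0.01,  # assumed default
    ) -> None:
        super().__init__()
        self.num_members = num_members
        # Shared base weight and bias, stored once for all members.
        self.base = nn.Linear(in_features, out_features)
        # One input scaling vector s and one output scaling vector r per member.
        self.s = nn.Parameter(torch.empty(num_members, in_features).normal_(s_mean, s_std))
        self.r = nn.Parameter(torch.empty(num_members, out_features).normal_(r_mean, r_std))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # A plain batch [B, in] is shared by all members: expand it to [E, B, in].
        if x.dim() == 2:
            x = x.unsqueeze(0).expand(self.num_members, -1, -1)
        x = x * self.s.unsqueeze(1)                      # scale inputs per member
        out = nn.functional.linear(x, self.base.weight)  # shared weights, [E, B, out]
        return out * self.r.unsqueeze(1) + self.base.bias  # scale outputs, shared bias


layer = TinyBatchEnsembleLinear(in_features=8, out_features=4, num_members=3)
out = layer(torch.randn(5, 8))
print(out.shape)
```

Stacking two such layers works because the second one receives an input that already has a leading member dimension and therefore skips the expansion step.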

## B) Setup of the MLP

**Standard Imports and PyTorch Setup**

In [1]:
import time

import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import CIFAR10

device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", device)

Using device: cpu


**Import CIFAR10 Dataset**

In [2]:
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        # add more transforms if desired
    ],
)

train_data = CIFAR10(root="./data", train=True, transform=transform, download=True)
val_data = CIFAR10(root="./data", train=False, transform=transform, download=True)

train_loader = DataLoader(train_data, batch_size=32, shuffle=True, num_workers=2)
val_loader = DataLoader(val_data, batch_size=32, shuffle=False, num_workers=2)

print(f"Train samples: {len(train_data)},  Val samples: {len(val_data)}")

Train samples: 50000,  Val samples: 10000


**The MLP Class**

We create an MLP that inherits its basic functionality from the nn.Module parent class. The MLP has two hidden layers that use the ReLU activation function.


In [3]:
class MLP(nn.Module):
    def __init__(self, in_dim: int = 3072, hidden: int = 128, out_dim: int = 10) -> None:
        """Initialize the MLP model with two hidden layers.

        Args:
            in_dim (int): Dimension of the input features. Default is 3072 (32x32x3 for CIFAR-10).
            hidden (int): Number of neurons in the hidden layers. Default is 128.
            out_dim (int): Dimension of the output features. Default is 10 (number of classes in CIFAR-10).
        """
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, out_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass of the MLP model.

        Before passing the input through the network, it flattens the input tensor.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, ...); it is flattened to (batch_size, in_dim).

        Returns:
            torch.Tensor: Output tensor of shape (batch_size, out_dim).
        """
        x = x.view(x.size(0), -1)
        return self.net(x)

**The Ensemble MLP**

In [4]:
from probly.transformation.ensemble import ensemble

in_dim = 3 * 32 * 32
hidden = 128
num_members = 5

ensemble_mlp = ensemble(
    base=MLP(in_dim=in_dim, hidden=hidden, out_dim=10),
    num_members=num_members,
)

**The BatchEnsemble MLP**

In [5]:
from probly.transformation import batchensemble

batch_ensemble_mlp = batchensemble(
    base=MLP(in_dim=in_dim, hidden=hidden, out_dim=10),
    num_members=num_members,
)

Let's compare the different models now.

We start with a comparison of the base MLP and the BatchEnsemble MLP:

In [6]:
print(f"Base MLP:\n{MLP(in_dim=in_dim, hidden=hidden, out_dim=10)}\n")
print(f"BatchEnsemble MLP:\n{batch_ensemble_mlp}")

Base MLP:
MLP(
  (net): Sequential(
    (0): Linear(in_features=3072, out_features=128, bias=True)
    (1): ReLU()
    (2): Linear(in_features=128, out_features=128, bias=True)
    (3): ReLU()
    (4): Linear(in_features=128, out_features=10, bias=True)
  )
)

BatchEnsemble MLP:
MLP(
  (net): Sequential(
    (0): BatchEnsembleLinear()
    (1): ReLU()
    (2): BatchEnsembleLinear()
    (3): ReLU()
    (4): BatchEnsembleLinear()
  )
)


Then we compare the Ensemble MLP and the BatchEnsemble MLP:

In [7]:
print(f"Ensemble MLP:\n{ensemble_mlp}\n")
print(f"BatchEnsemble MLP:\n{batch_ensemble_mlp}")

Ensemble MLP:
ModuleList(
  (0-4): 5 x MLP(
    (net): Sequential(
      (0): Linear(in_features=3072, out_features=128, bias=True)
      (1): ReLU()
      (2): Linear(in_features=128, out_features=128, bias=True)
      (3): ReLU()
      (4): Linear(in_features=128, out_features=10, bias=True)
    )
  )
)

BatchEnsemble MLP:
MLP(
  (net): Sequential(
    (0): BatchEnsembleLinear()
    (1): ReLU()
    (2): BatchEnsembleLinear()
    (3): ReLU()
    (4): BatchEnsembleLinear()
  )
)
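The printed structures hide how different the two models are in size. The following back-of-the-envelope count is a sketch that assumes each BatchEnsemble layer stores one shared weight and bias plus one **s** vector (input size) and one **r** vector (output size) per member, as described in section A; the exact storage layout inside *probly* is not verified here.

```python
# (in_features, out_features) of the three Linear layers in the MLP above.
layer_dims = [(3072, 128), (128, 128), (128, 10)]
num_members = 5

# Weights plus biases of a single MLP.
base_params = sum(i * o + o for i, o in layer_dims)

# Classical ensemble: every member owns a full copy.
ensemble_params = num_members * base_params

# BatchEnsemble (assumed layout): one shared copy plus s and r per member and layer.
rank1_params = num_members * sum(i + o for i, o in layer_dims)
batch_ensemble_params = base_params + rank1_params

print(f"Single MLP:         {base_params:>9,}")
print(f"Classical ensemble: {ensemble_params:>9,}")
print(f"BatchEnsemble:      {batch_ensemble_params:>9,}")
```

Under these assumptions the classical ensemble needs roughly 2.06 million parameters, while the BatchEnsemble needs roughly 0.43 million, only about 4% more than a single MLP.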


## C) Training

**Training Methods**

Since *probly* currently provides no training functionality, we define the training methods below.

**Base Training Method**

In [None]:
def train_model(
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    loss_function: nn.CrossEntropyLoss,
    train_loader: DataLoader,
    epochs: int = 10,
    num_members: int | None = None,
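) -> nn.Module:
    """Train a model and report the time and mean loss of every epoch.

    If num_members is a positive integer, the model output is expected to have
    shape [num_members, batch_size, num_classes] and the loss is averaged over
    the members. Otherwise the standard loss on the raw output is used.
    """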
    for epoch in range(epochs):
        t0 = time.perf_counter()
        total_loss = 0.0
        model.train()
        for xb, yb in train_loader:
            x = xb.to(device).float()
            y = yb.to(device).long()
            optimizer.zero_grad()
            out = model(x)

            if isinstance(num_members, int) and num_members > 0:
                # BatchEnsemble: out[e] holds the logits of member e,
                # so we average the per-member losses.
                loss = 0.0
                for e in range(num_members):
                    loss += loss_function(out[e], y)
                loss = loss / num_members
            else:
                # fallback to standard loss computation
                loss = loss_function(out, y)
            loss.backward()
            total_loss += loss.item()
            optimizer.step()
        avg_loss = total_loss / len(train_loader)
        t1 = time.perf_counter()
        print(f"Epoch {epoch + 1}/{epochs} trained in {t1 - t0:.2f} seconds.")
        print(f"> Loss: {avg_loss:.4f}")
    return model
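The per-member loop in `train_model` can also be written without a Python loop. The sketch below uses random logits in place of a real model output and assumes the `[num_members, batch, classes]` layout; with the default mean reduction of `nn.CrossEntropyLoss`, averaging the member losses gives the same value as one loss call on the flattened logits with the targets repeated once per member.

```python
import torch
from torch import nn

torch.manual_seed(0)
E, B, C = 5, 32, 10  # members, batch size, classes

out = torch.randn(E, B, C)      # stand-in for a BatchEnsemble output
y = torch.randint(0, C, (B,))   # one label per sample, shared by all members
loss_fn = nn.CrossEntropyLoss()

# Loop version, as in train_model.
loop_loss = sum(loss_fn(out[e], y) for e in range(E)) / E

# Vectorized version: flatten members into the batch and repeat the targets.
flat_loss = loss_fn(out.reshape(E * B, C), y.repeat(E))

print(loop_loss.item(), flat_loss.item())
```

The equality holds because every member sees the same number of samples, so the mean of the per-member means equals the mean over all `E * B` entries.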

In [None]:
def train_ensemble(
    ensemble: nn.ModuleList,
    train_loader: DataLoader,
    epochs: int = 10,
    lr: float = 1e-3,
) -> nn.ModuleList:
    model = nn.ModuleList()

    for i, member in enumerate(ensemble):
        print(f"\nTraining ensemble member {i + 1}/{len(ensemble)}")
        member_i = member.to(device)
        optimizer = optim.Adam(member_i.parameters(), lr=lr)
        train_model(
            member_i,
            optimizer=optimizer,
            loss_function=nn.CrossEntropyLoss(),
            train_loader=train_loader,
            epochs=epochs,
        )
        model.append(member_i)
    return model


def train_batchensemble(
    base_cls: nn.Module,
    num_members: int,
    train_loader: DataLoader,
    epochs: int = 10,
    lr: float = 1e-3,
) -> nn.Module:
    model = base_cls.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)

    model = train_model(
        model,
        optimizer=optimizer,
        loss_function=nn.CrossEntropyLoss(),
        train_loader=train_loader,
        epochs=epochs,
        num_members=num_members,
    )

    return model

In [None]:
epochs = 1
lr = 1e-3

t0_batch_ensemble = time.perf_counter()
trained_batch_ensemble = train_batchensemble(
    base_cls=batch_ensemble_mlp,
    num_members=num_members,
    train_loader=train_loader,
    epochs=epochs,
    lr=lr,
)
t1_batch_ensemble = time.perf_counter()
print(f"\nTrained BatchEnsemble model of size {num_members} in {t1_batch_ensemble - t0_batch_ensemble:.2f}s")

In [None]:
t0_ensemble = time.perf_counter()
trained_ensemble = train_ensemble(
    ensemble=ensemble_mlp,
    train_loader=train_loader,
    epochs=epochs,
    lr=lr,
)
t1_ensemble = time.perf_counter()
print(f"\nTrained classical ensemble of size {num_members} in {t1_ensemble - t0_ensemble:.2f}s")
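When the notebook runs on a GPU, wall-clock timings taken with `time.perf_counter` can be misleading because CUDA kernels are launched asynchronously. The following is a sketch of a small timing helper that synchronizes before and after the measured block; the name `timed` and the `results` dictionary are made up for this example.

```python
import time
from contextlib import contextmanager

import torch


@contextmanager
def timed(label: str, results: dict):
    """Store the wall-clock duration of the enclosed block in results[label]."""
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # finish pending GPU work before starting the clock
    start = time.perf_counter()
    try:
        yield
    finally:
        if torch.cuda.is_available():
            torch.cuda.synchronize()  # wait for the timed GPU work to finish
        results[label] = time.perf_counter() - start


timings = {}
with timed("matmul", timings):
    _ = torch.randn(256, 256) @ torch.randn(256, 256)

print(f"matmul took {timings['matmul']:.4f}s")
```

On a CPU-only machine the helper reduces to a plain `time.perf_counter` measurement.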

In [None]:
class Evaluator:
    def __init__(self, data_loader: torch.utils.data.DataLoader, device: str) -> None:
        """Initialize the Evaluator with a data loader and device.

        Args:
            data_loader (torch.utils.data.DataLoader): DataLoader for evaluation data.
            device (str): Device to run the evaluation on ('cpu' or 'cuda').
        """
        self.data_loader = data_loader
        self.device = device

    def _setup(self) -> None:
        self.correct = 0
        self.total = 0
        self.member_predictions = []

    def evaluate_batchensemble(self, model: nn.Module, num_members: int) -> tuple[float, torch.Tensor]:
        """Evaluate a BatchEnsemble model."""
        self._setup()
        model.to(self.device)
        model.eval()

        with torch.no_grad():
            for xb, yb in self.data_loader:
                x = xb.to(self.device).float()
                y = yb.to(self.device).long()

                out = model(x)  # [E, B, out_dim]
                preds = torch.argmax(out, dim=2)  # [E, B]

                self.correct += (preds == y.unsqueeze(0)).sum().item()
                self.total += y.size(0) * num_members
                self.member_predictions.append(preds.cpu())

        accuracy = self.correct / self.total
        all_member_preds = torch.cat(self.member_predictions, dim=1)

        return accuracy, all_member_preds

    def evaluate_classical_ensemble(self, models: nn.ModuleList) -> tuple[float, torch.Tensor]:
        """Evaluate a classical ensemble of models."""
        self._setup()
        for m in models:
            m.to(self.device)
            m.eval()

        with torch.no_grad():
            for xb, yb in self.data_loader:
                x = xb.to(self.device).float()
                y = yb.to(self.device).long()

                batch_member_preds = []
                for m in models:
                    out = m(x)  # [B, out_dim]
                    preds = torch.argmax(out, dim=1)  # [B]
                    batch_member_preds.append(preds.cpu().unsqueeze(0))  # [1, B]

                batch_member_preds = torch.cat(batch_member_preds, dim=0)  # [E, B]
                self.correct += (batch_member_preds == y.unsqueeze(0).cpu()).sum().item()
                self.total += y.size(0) * len(models)
                self.member_predictions.append(batch_member_preds)

        accuracy = self.correct / self.total
        all_member_preds = torch.cat(self.member_predictions, dim=1)
        return accuracy, all_member_preds

In [None]:
# Evaluate BatchEnsemble
evaluator = Evaluator(val_loader, device)
be_acc, be_member_preds = evaluator.evaluate_batchensemble(trained_batch_ensemble, num_members)
print(f"BatchEnsemble Accuracy: {be_acc:.4f}")

In [None]:
for m in trained_ensemble:
    m.to(device)
ce_acc, ce_member_preds = evaluator.evaluate_classical_ensemble(trained_ensemble)
print(f"Classical Ensemble Accuracy: {ce_acc:.4f}")
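Note that the accuracy returned by both evaluation methods is the mean accuracy of the individual members, not the accuracy of a combined ensemble prediction. The member predictions that are returned alongside it can be combined afterwards, for example by majority vote. The sketch below uses a small hand-written tensor in place of `be_member_preds` or `ce_member_preds`; majority voting is one possible aggregation and is not prescribed by the notebook.

```python
import torch

# Toy stand-in for member predictions of shape [num_members, num_samples].
member_preds = torch.tensor(
    [
        [0, 1, 2, 3],
        [0, 1, 1, 3],
        [1, 1, 2, 3],
    ]
)
labels = torch.tensor([0, 1, 2, 0])

# Mean accuracy over members, which is what the Evaluator reports.
mean_member_acc = (member_preds == labels.unsqueeze(0)).float().mean().item()

# Majority vote: the most frequent class per sample across members.
voted = torch.mode(member_preds, dim=0).values
vote_acc = (voted == labels).float().mean().item()

print(f"Mean member accuracy:   {mean_member_acc:.4f}")
print(f"Majority-vote accuracy: {vote_acc:.4f}")
```

In this toy example the vote corrects the individual mistakes on the first and third sample, so the voted accuracy is higher than the mean member accuracy.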