# Batch Ensemble Networks

In this notebook, we:

A) Explain **Batch Ensembles**.

B) Implement an **Ensemble MLP**.

C) Implement a **BatchEnsemble MLP**.

D) Train both networks on CIFAR10 and compare **accuracy and speed**.


## Introduction: What are Batch Ensembles?

**Batch Ensembles** are a way to efficiently approximate an ensemble of neural networks. Traditional ensembles require training and storing multiple independent networks, which is expensive in both memory and computation.

Key ideas:
- Use **shared base weights** for all ensemble members.
- Introduce **rank-1 multiplicative factors** for each member.
- Much **faster and more memory-efficient** than classic ensembles, because only the small rank-1 vectors are stored per member.

Mathematically, for a linear layer with weight matrix W:

$$y_i = (W \circ (r_i s_i^T)) x + b$$

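where $r_i \in \mathbb{R}^{m}$ and $s_i \in \mathbb{R}^{n}$ are the rank-1 vectors for ensemble member $i$ (for $W \in \mathbb{R}^{m \times n}$), $\circ$ denotes element-wise multiplication, and the bias $b$ is shared across members.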
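
The formula never requires building a separate weight matrix per member, because $(W \circ (r_i s_i^T))\,x = r_i \circ (W (s_i \circ x))$. The following is a minimal sketch that checks this identity numerically; the tensor sizes and variable names are illustrative assumptions, not values used later in the notebook.

```python
import torch

torch.manual_seed(0)

out_features, in_features = 4, 6  # illustrative sizes (assumption)
W = torch.randn(out_features, in_features)  # shared weight matrix
r = torch.randn(out_features)  # rank-1 factor on the output side
s = torch.randn(in_features)  # rank-1 factor on the input side
x = torch.randn(in_features)  # a single input vector

# Naive form: build the member-specific weight W ∘ (r s^T), then multiply.
W_member = W * torch.outer(r, s)
y_naive = W_member @ x

# Efficient form: scale the input by s, apply the shared W, scale the output by r.
y_fast = (W @ (x * s)) * r

max_abs_diff = (y_naive - y_fast).abs().max().item()
print(f"max abs difference: {max_abs_diff:.2e}")
```

The efficient form only uses element-wise scaling around one shared matrix product, which is what makes it possible to evaluate all members in a single batched pass.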
 

## 1. Quick Setup

In [13]:
import time

import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import CIFAR10

device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", device)

Using device: cpu


## 2. Import CIFAR10 Dataset

In [15]:
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ],
)

train_data = CIFAR10(root="./data", train=True, transform=transform, download=True)
val_data = CIFAR10(root="./data", train=False, transform=transform, download=True)

train_loader = DataLoader(train_data, batch_size=128, shuffle=True, num_workers=2)
val_loader = DataLoader(val_data, batch_size=256, shuffle=False, num_workers=2)

print(f"Train samples: {len(train_data)},  Val samples: {len(val_data)}")

Train samples: 50000,  Val samples: 10000


## 3. Classic MLP

We create an MLP with two hidden layers to use as the base model.
 

In [18]:
class MLP(nn.Module):
    def __init__(self, in_dim: int = 3072, hidden: int = 128, out_dim: int = 10) -> None:
        """Initialize the MLP model with two hidden layers.

        Args:
            in_dim (int): Dimension of the input features. Default is 3072 (32x32x3 for CIFAR-10).
            hidden (int): Number of neurons in the hidden layers. Default is 128.
            out_dim (int): Dimension of the output features. Default is 10 (number of classes in CIFAR-10).
        """
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, out_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass of the MLP model.

        Before passing the input through the network, it flattens the input tensor.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, ...); it is flattened to (batch_size, in_dim).

        Returns:
            torch.Tensor: Output tensor of shape (batch_size, out_dim).
        """
        x = x.view(x.size(0), -1)
        return self.net(x)

Since *probly* currently has no training functionality, we define the training function ourselves:

In [19]:
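from collections.abc import Callable


def train_ensemble(
    base_cls: Callable[[], nn.Module],
    k: int,
    train_loader: DataLoader,
    epochs: int = 10,
    lr: float = 1e-3,
) -> list[nn.Module]:
    """Train k independent models one after another and return them as a list."""
    models = []
    for i in range(k):
        print(f"\nTraining ensemble member {i + 1}/{k}")
        m_k = base_cls().to(device)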
        opt = optim.Adam(m_k.parameters(), lr=lr)
        lossfn = nn.CrossEntropyLoss()
        for epoch in range(epochs):
            t0 = time.perf_counter()
            m_k.train()
            total_loss = 0.0
            for xb, yb in train_loader:
                x = xb.to(device).float()
                y = yb.to(device).long()
                opt.zero_grad()
                out = m_k(x)
                loss = lossfn(out, y)
                loss.backward()
                total_loss += loss.item()
                opt.step()
            avg_loss = total_loss / len(train_loader)
            t1 = time.perf_counter()
            print(f"Epoch {epoch + 1}/{epochs}, Loss: {avg_loss:.4f}, Time: {t1 - t0:.2f}s")
        models.append(m_k)
    return models

## 4. Define Batch Ensemble Linear Layer and Batch Ensemble MLP

In [20]:
import math

import torch
from torch import nn


class BatchEnsembleLinear(nn.Module):
    def __init__(self, in_features: int, out_features: int, ensemble_size: int) -> None:
        """Initialize a BatchEnsemble Linear layer.

        Args:
            in_features (int): Number of input features.
            out_features (int): Number of output features.
            ensemble_size (int): Number of ensemble members.
        """
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.ensemble_size = ensemble_size

        # Shared weight and bias
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        self.bias = nn.Parameter(torch.empty(out_features))

        # Rank-one factors: one (r, s) pair per ensemble member
        self.r = nn.Parameter(torch.empty(ensemble_size, out_features))
        self.s = nn.Parameter(torch.empty(ensemble_size, in_features))

        # Init
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        nn.init.zeros_(self.bias)
        nn.init.normal_(self.r, 1.0, 0.01)
        nn.init.normal_(self.s, 1.0, 0.01)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass of the BatchEnsemble Linear layer.

        Args:
            x (torch.Tensor): Input tensor of shape [B, in_features] or [E, B, in_features],
                              where B is the batch size and E is the ensemble size.

        Returns:
            torch.Tensor: Output tensor of shape [E, B, out_features].
        """
        if x.dim() == 2:
            # First layer: add ensemble dimension
            x = x.unsqueeze(0).expand(self.ensemble_size, -1, -1)  # [E, B, in_features]
        elif x.dim() == 3 and x.size(0) != self.ensemble_size:
            msg = f"Expected first dim={self.ensemble_size}, got {x.size(0)}"
            raise ValueError(msg)

        x = x * self.s.unsqueeze(1)
        y = torch.matmul(x, self.weight.t())
        y = y * self.r.unsqueeze(1) + self.bias
        return y


class BatchEnsembleMLP(nn.Module):
    def __init__(self, in_dim: int = 3072, hidden: int = 128, out_dim: int = 10, ensemble_size: int = 4) -> None:
        """Initialize the BatchEnsemble MLP model with three fully connected layers.

        Args:
            in_dim (int): Dimension of the input features. Default is 3072 (32x32x3 for CIFAR-10).
            hidden (int): Number of neurons in the hidden layers. Default is 128.
            out_dim (int): Dimension of the output features. Default is 10 (number of classes in CIFAR-10).
            ensemble_size (int): Number of ensemble members. Default is 4.
        """
        super().__init__()
        self.ensemble_size = ensemble_size
        self.fc1 = BatchEnsembleLinear(in_dim, hidden, ensemble_size)
        self.fc2 = BatchEnsembleLinear(hidden, hidden, ensemble_size)
        self.fc3 = BatchEnsembleLinear(hidden, out_dim, ensemble_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass of the BatchEnsemble MLP.

        This mirrors the standard MLP forward pass (two hidden layers followed by an
        output layer) but uses BatchEnsembleLinear layers.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, in_dim).

        Returns:
            torch.Tensor: Output tensor of shape (ensemble_size, batch_size, out_dim).
        """
        x = x.view(x.size(0), -1)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x
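
To make the memory claim concrete, the parameter counts of both approaches can be worked out by hand. The sketch below uses the layer sizes from this notebook (3072 → 128 → 128 → 10) and five members; the helper names `linear_params` and `rank1_params` are hypothetical and exist only for this illustration.

```python
def linear_params(in_features: int, out_features: int) -> int:
    """Weights plus bias of one dense layer."""
    return in_features * out_features + out_features


def rank1_params(in_features: int, out_features: int, n_members: int) -> int:
    """Per-member r (out_features) and s (in_features) vectors of one layer."""
    return n_members * (in_features + out_features)


# Layer sizes used in this notebook: 3072 -> 128 -> 128 -> 10.
layers = [(3072, 128), (128, 128), (128, 10)]
n_members = 5

single_mlp = sum(linear_params(i, o) for i, o in layers)
classical_ensemble = n_members * single_mlp
batch_ensemble = single_mlp + sum(rank1_params(i, o, n_members) for i, o in layers)

print(f"single MLP:         {single_mlp:,}")          # 411,146
print(f"classical ensemble: {classical_ensemble:,}")  # 2,055,730
print(f"BatchEnsemble:      {batch_ensemble:,}")      # 429,116
```

With these sizes, the rank-1 factors add roughly 4% on top of a single MLP, whereas the classical ensemble stores five full copies.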

In [21]:
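from collections.abc import Callable


def train_batchensemble(
    base_cls: Callable[[], BatchEnsembleMLP],
    train_loader: DataLoader,
    epochs: int = 10,
    lr: float = 1e-3,
) -> nn.Module:
    model = base_cls().to(device)
    # Read the ensemble size from the model we train instead of building a second one
    ensemble_size = model.ensemble_size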
    optimizer = optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()

    for epoch in range(epochs):
        model.train()
        t0 = time.perf_counter()
        running_loss = 0.0

        for xb, yb in train_loader:
            x = xb.to(device).float()
            y = yb.to(device).long()

            optimizer.zero_grad()

            # Forward pass: [E, B, out_dim]
            out = model(x)

            # Compute loss per ensemble member
            # out: [E, B, out_dim], y: [B]
            loss = 0.0
            for e in range(ensemble_size):
                loss += loss_fn(out[e], y)
            loss = loss / ensemble_size

            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        avg_loss = running_loss / len(train_loader)
        t1 = time.perf_counter()
        print(f"Epoch {epoch + 1}/{epochs}, Loss: {avg_loss:.4f}, Time: {t1 - t0:.2f}s")
    return model
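
The per-member loop in `train_batchensemble` can also be written as a single loss call. Because every member sees the same batch, averaging the per-member mean losses equals the mean over all `E * B` predictions. The sketch below checks this on random stand-in tensors; the sizes and the names `loss_loop` and `loss_vec` are assumptions made for the illustration, and this is an alternative formulation rather than the notebook's implementation.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

E, B, C = 5, 8, 10  # illustrative ensemble size, batch size, number of classes
out = torch.randn(E, B, C)  # stand-in for the model output [E, B, out_dim]
y = torch.randint(0, C, (B,))  # stand-in for the labels [B]

# Loop form, as in train_batchensemble: average of the per-member mean losses.
loss_loop = sum(F.cross_entropy(out[e], y) for e in range(E)) / E

# Vectorized form: flatten [E, B, C] to [E * B, C] and tile the labels E times,
# so that row e * B + b is paired with label y[b].
loss_vec = F.cross_entropy(out.reshape(E * B, C), y.repeat(E))

print(loss_loop.item(), loss_vec.item())
```

Both forms give the same gradients up to floating-point error; the vectorized one avoids a Python loop over members.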

In [22]:
in_dim = 3 * 32 * 32
hidden = 128
ensemble_size = 5
epochs = 20
lr = 1e-3

In [23]:
t0_e = time.perf_counter()
trained_ensemble = train_ensemble(
    base_cls=lambda: MLP(in_dim=in_dim, hidden=hidden, out_dim=10),
    k=ensemble_size,
    train_loader=train_loader,
    epochs=epochs,
    lr=lr,
)
t1_e = time.perf_counter()
print(f"\nTrained classical ensemble of size {ensemble_size} in {t1_e - t0_e:.2f}s")


Training ensemble member 1/5
Epoch 1/20, Loss: 1.6492, Time: 41.37s
Epoch 2/20, Loss: 1.4371, Time: 33.59s
Epoch 3/20, Loss: 1.3390, Time: 33.17s
Epoch 4/20, Loss: 1.2656, Time: 33.15s
Epoch 5/20, Loss: 1.2022, Time: 33.42s
Epoch 6/20, Loss: 1.1451, Time: 33.27s
Epoch 7/20, Loss: 1.0908, Time: 33.25s
Epoch 8/20, Loss: 1.0445, Time: 35.72s
Epoch 9/20, Loss: 1.0010, Time: 35.03s
Epoch 10/20, Loss: 0.9609, Time: 35.97s
Epoch 11/20, Loss: 0.9262, Time: 35.80s
Epoch 12/20, Loss: 0.8893, Time: 36.40s
Epoch 13/20, Loss: 0.8508, Time: 36.47s
Epoch 14/20, Loss: 0.8169, Time: 36.46s
Epoch 15/20, Loss: 0.7888, Time: 35.95s
Epoch 16/20, Loss: 0.7547, Time: 36.60s
Epoch 17/20, Loss: 0.7266, Time: 36.02s
Epoch 18/20, Loss: 0.7035, Time: 36.18s
Epoch 19/20, Loss: 0.6710, Time: 35.96s
Epoch 20/20, Loss: 0.6545, Time: 35.97s

Training ensemble member 2/5
Epoch 1/20, Loss: 1.6517, Time: 35.90s
Epoch 2/20, Loss: 1.4401, Time: 36.27s
Epoch 3/20, Loss: 1.3407, Time: 36.62s
Epoch 4/20, Loss: 1.2668, Time: 

In [24]:
t0_be = time.perf_counter()
trained_be_model = train_batchensemble(
    base_cls=lambda: BatchEnsembleMLP(in_dim=in_dim, hidden=hidden, out_dim=10, ensemble_size=ensemble_size),
    train_loader=train_loader,
    epochs=epochs,
    lr=lr,
)
t1_be = time.perf_counter()
print(f"\nTrained BatchEnsemble model of size {ensemble_size} in {t1_be - t0_be:.2f}s")

Epoch 1/20, Loss: 1.6542, Time: 22.28s
Epoch 2/20, Loss: 1.4417, Time: 22.58s
Epoch 3/20, Loss: 1.3403, Time: 22.56s
Epoch 4/20, Loss: 1.2650, Time: 22.76s
Epoch 5/20, Loss: 1.1972, Time: 22.67s
Epoch 6/20, Loss: 1.1395, Time: 22.87s
Epoch 7/20, Loss: 1.0898, Time: 22.55s
Epoch 8/20, Loss: 1.0481, Time: 22.23s
Epoch 9/20, Loss: 1.0029, Time: 22.52s
Epoch 10/20, Loss: 0.9623, Time: 22.66s
Epoch 11/20, Loss: 0.9255, Time: 22.63s
Epoch 12/20, Loss: 0.8850, Time: 22.84s
Epoch 13/20, Loss: 0.8553, Time: 22.77s
Epoch 14/20, Loss: 0.8216, Time: 23.05s
Epoch 15/20, Loss: 0.7891, Time: 22.46s
Epoch 16/20, Loss: 0.7590, Time: 22.36s
Epoch 17/20, Loss: 0.7283, Time: 24.52s
Epoch 18/20, Loss: 0.7030, Time: 24.87s
Epoch 19/20, Loss: 0.6788, Time: 22.97s
Epoch 20/20, Loss: 0.6597, Time: 22.79s

Trained BatchEnsemble model of size 5 in 457.00s


In [25]:
class Evaluator:
    def __init__(self, data_loader: torch.utils.data.DataLoader, device: str) -> None:
        """Initialize the Evaluator with a data loader and device.

        Args:
            data_loader (torch.utils.data.DataLoader): DataLoader for evaluation data.
            device (str): Device to run the evaluation on ('cpu' or 'cuda').
        """
        self.data_loader = data_loader
        self.device = device

    def _setup(self) -> None:
        self.correct = 0
        self.total = 0
        self.member_predictions = []

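    def evaluate_batchensemble(self, model: BatchEnsembleMLP) -> tuple[float, torch.Tensor]:
        """Evaluate a BatchEnsemble model.

        Returns the accuracy averaged over all ensemble members (not the accuracy of a
        combined ensemble prediction) and the member predictions of shape [E, N].
        """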
        self._setup()
        model.to(self.device)
        model.eval()

        with torch.no_grad():
            for xb, yb in self.data_loader:
                x = xb.to(self.device).float()
                y = yb.to(self.device).long()

                out = model(x)  # [E, B, out_dim]
                preds = torch.argmax(out, dim=2)  # [E, B]

                self.correct += (preds == y.unsqueeze(0)).sum().item()
                self.total += y.size(0) * model.ensemble_size
                self.member_predictions.append(preds.cpu())

        accuracy = self.correct / self.total
        all_member_preds = torch.cat(self.member_predictions, dim=1)

        return accuracy, all_member_preds

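    def evaluate_classical_ensemble(self, models: list[MLP]) -> tuple[float, torch.Tensor]:
        """Evaluate a classical ensemble of models.

        Returns the accuracy averaged over all ensemble members (not the accuracy of a
        combined ensemble prediction) and the member predictions of shape [E, N].
        """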
        self._setup()
        for m in models:
            m.to(self.device)
            m.eval()

        with torch.no_grad():
            for xb, yb in self.data_loader:
                x = xb.to(self.device).float()
                y = yb.to(self.device).long()

                batch_member_preds = []
                for m in models:
                    out = m(x)  # [B, out_dim]
                    preds = torch.argmax(out, dim=1)  # [B]
                    batch_member_preds.append(preds.cpu().unsqueeze(0))  # [1, B]

                batch_member_preds = torch.cat(batch_member_preds, dim=0)  # [E, B]
                self.correct += (batch_member_preds == y.unsqueeze(0).cpu()).sum().item()
                self.total += y.size(0) * len(models)
                self.member_predictions.append(batch_member_preds)

        accuracy = self.correct / self.total
        all_member_preds = torch.cat(self.member_predictions, dim=1)
        return accuracy, all_member_preds
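
Both evaluation methods report the accuracy averaged over individual members. Since they also return the member predictions of shape `[E, N]`, a combined ensemble prediction can be derived from them afterwards. The following is a minimal sketch using a per-sample majority vote; `majority_vote_accuracy` is a hypothetical helper that is not part of the notebook, and the toy tensors are made up.

```python
import torch


def majority_vote_accuracy(member_preds: torch.Tensor, labels: torch.Tensor) -> float:
    """Accuracy of the per-sample majority vote over ensemble members.

    member_preds: [E, N] integer class predictions, one row per member.
    labels: [N] integer ground-truth classes.
    """
    voted = torch.mode(member_preds, dim=0).values  # most frequent class per sample, [N]
    return (voted == labels).float().mean().item()


# Toy example with 3 members and 4 samples (values are made up).
toy_preds = torch.tensor([[0, 1, 2, 3], [0, 1, 1, 3], [1, 1, 2, 0]])
toy_labels = torch.tensor([0, 1, 1, 3])
toy_acc = majority_vote_accuracy(toy_preds, toy_labels)
print(toy_acc)  # votes are [0, 1, 2, 3], so 3 of 4 samples are correct: 0.75
```

In this notebook, the helper could be applied to `be_member_preds` or `ce_member_preds` together with the validation labels, for example `torch.tensor(val_data.targets)`, assuming the dataset exposes its labels that way and the loader is not shuffled.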

In [26]:
# Evaluate BatchEnsemble
evaluator = Evaluator(val_loader, device)
be_acc, be_member_preds = evaluator.evaluate_batchensemble(trained_be_model)
print(f"BatchEnsemble Accuracy: {be_acc:.4f}")

BatchEnsemble Accuracy: 0.5126


In [27]:
for m in trained_ensemble:
    m.to(device)
ce_acc, ce_member_preds = evaluator.evaluate_classical_ensemble(trained_ensemble)
print(f"Classical Ensemble Accuracy: {ce_acc:.4f}")

Classical Ensemble Accuracy: 0.5196
