# **Batch Ensemble Networks**

In this notebook, we:

A) Explain the idea behind **Batch Ensembles**.

B) Create a small Multi-Layer Perceptron (MLP).

C) Train the **MLP** as a classical ensemble and as a BatchEnsemble network on CIFAR-10.

D) Compare **accuracy** and **speed**.


## A) Introduction: What are Batch Ensembles?

**Batch Ensembles** are a way to efficiently approximate an ensemble of neural networks. Traditional ensembles require training and storing multiple independent networks, which is expensive in both memory and computation.

Key ideas:
- Use **shared base weights (and biases)** for all ensemble members.
- Introduce **rank-1 multiplicative factors** for each member.
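- Much **faster and more memory-efficient** than classic ensembles: each member adds only two vectors per layer instead of a full copy of the weights, and all members are trained together in a single model.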
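To make the memory argument concrete, the snippet below counts parameters for the MLP used later in this notebook (3072 → 128 → 128 → 10) with 5 members. It assumes, as stated above, that weights and biases are shared and that each member adds one $r_i$ (output size) and one $s_i$ (input size) per linear layer; probly's actual implementation may store additional per-member parameters.

```python
def linear_params(in_dim: int, out_dim: int) -> int:
    """Parameters of one nn.Linear layer: weight matrix plus bias."""
    return in_dim * out_dim + out_dim


layers = [(3072, 128), (128, 128), (128, 10)]  # the notebook's MLP
num_members = 5

single = sum(linear_params(i, o) for i, o in layers)
classic = num_members * single  # every member stores its own copy
# Assumption: shared W and b, plus rank-1 factors r_i (out_dim) and s_i (in_dim) per layer
batch_ensemble = single + num_members * sum(i + o for i, o in layers)

print(single, classic, batch_ensemble)  # 411146 2055730 429116
```

Under these assumptions the BatchEnsemble needs roughly 4% more parameters than a single MLP, while the classical ensemble needs five times as many.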

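Mathematically, for a linear layer with weight matrix $W \in \mathbb{R}^{m \times n}$ and bias $b \in \mathbb{R}^m$, the classic forward pass

$$y = Wx + b$$

becomes, for ensemble member $i$,

$$y_i = \left(W (x \circ s_i)\right) \circ r_i + b,$$

where $s_i \in \mathbb{R}^n$ and $r_i \in \mathbb{R}^m$ are the rank-1 factors of ensemble member $i$ and $\circ$ denotes element-wise multiplication. This is equivalent to giving member $i$ its own weight matrix $W_i = W \circ (r_i s_i^T)$ without ever storing $W_i$ explicitly.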
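The per-member forward pass can be sketched as a PyTorch layer: scale the input by $s_i$, apply the shared weight matrix $W$, scale the result by $r_i$, and add the shared bias $b$. `BatchEnsembleLinear` is a hypothetical name for this illustration, not probly's implementation, and initialising the factors close to 1 is an assumption. The last lines check one member against an explicit $W_i = W \circ (r_i s_i^T)$.

```python
import torch
from torch import nn


class BatchEnsembleLinear(nn.Module):
    """Hypothetical BatchEnsemble linear layer (illustrative sketch, not probly's code).

    W (out_dim x in_dim) and b are shared by all members; member i owns the
    rank-1 factors s_i (length in_dim) and r_i (length out_dim).
    """

    def __init__(self, in_dim: int, out_dim: int, num_members: int) -> None:
        super().__init__()
        self.shared = nn.Linear(in_dim, out_dim)  # shared W and b
        # Assumption: factors initialised close to 1 with small random noise
        self.s = nn.Parameter(1 + 0.1 * torch.randn(num_members, in_dim))
        self.r = nn.Parameter(1 + 0.1 * torch.randn(num_members, out_dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Accept [B, in_dim] (same input for all members) or [E, B, in_dim]
        if x.dim() == 2:
            x = x.unsqueeze(0)  # [1, B, in_dim], broadcast over members
        h = x * self.s.unsqueeze(1)  # x ∘ s_i -> [E, B, in_dim]
        h = h @ self.shared.weight.T  # W (x ∘ s_i) -> [E, B, out_dim]
        return h * self.r.unsqueeze(1) + self.shared.bias  # (...) ∘ r_i + b


torch.manual_seed(0)
layer = BatchEnsembleLinear(in_dim=6, out_dim=4, num_members=3)
x = torch.randn(5, 6)
y = layer(x)  # one output per member: [3, 5, 4]

# Compare member 1 with an explicit weight matrix W_1 = W ∘ (r_1 s_1^T)
i = 1
W_i = layer.shared.weight * torch.outer(layer.r[i], layer.s[i])
y_dense = x @ W_i.T + layer.shared.bias
print(y.shape, torch.allclose(y[i], y_dense, atol=1e-5))  # torch.Size([3, 5, 4]) True
```

Because all members go through a single batched matrix multiply with the shared $W$, the extra cost per member is only the two element-wise scalings.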
 

## B) Setup of the MLP

**Standard Imports and PyTorch Setup**

In [1]:
import time
import torch
from typing import Any
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import CIFAR10

device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", device)

Using device: cpu


**Import CIFAR10 Dataset**

In [2]:
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        # add more transforms if desired
    ],
)

train_data = CIFAR10(root="./data", train=True, transform=transform, download=True)
val_data = CIFAR10(root="./data", train=False, transform=transform, download=True)

train_loader = DataLoader(train_data, batch_size=32, shuffle=True, num_workers=2)
val_loader = DataLoader(val_data, batch_size=32, shuffle=False, num_workers=2)

print(f"Train samples: {len(train_data)},  Val samples: {len(val_data)}")

Train samples: 50000,  Val samples: 10000
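`transforms.Normalize(mean, std)` computes `(x - mean) / std` for each channel. Since `ToTensor` scales pixel values to [0, 1], the choice mean = std = 0.5 maps every channel to [-1, 1]. The arithmetic for one pixel per channel:

```python
import torch

# Same per-channel statistics as in the transform above
mean = torch.tensor([0.5, 0.5, 0.5]).view(3, 1, 1)
std = torch.tensor([0.5, 0.5, 0.5]).view(3, 1, 1)

img = torch.tensor([0.0, 0.5, 1.0]).view(3, 1, 1)  # darkest, middle, brightest value
normalized = (img - mean) / std  # what Normalize computes per channel
print(normalized.flatten().tolist())  # [-1.0, 0.0, 1.0]
```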


**The MLP Class**

We create an MLP that inherits basic functionality from the `nn.Module` parent class. The MLP has two hidden layers that use the ReLU activation function.
 

In [3]:
class MLP(nn.Module):
    def __init__(self, in_dim: int = 3072, hidden: int = 128, out_dim: int = 10) -> None:
        """Initialize the MLP model with two hidden layers.

        Args:
            in_dim (int): Dimension of the input features. Default is 3072 (32x32x3 for CIFAR-10).
            hidden (int): Number of neurons in the hidden layers. Default is 128.
            out_dim (int): Dimension of the output features. Default is 10 (number of classes in CIFAR-10).
        """
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, out_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass of the MLP model.

        Before passing the input through the network, it flattens the input tensor.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, 3, 32, 32), or any shape whose
                non-batch dimensions flatten to in_dim.

        Returns:
            torch.Tensor: Output tensor of shape (batch_size, out_dim).
        """
        x = x.view(x.size(0), -1)
        return self.net(x)
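
As a quick sanity check, the same architecture can be written with `nn.Flatten` (equivalent to the `x.view(x.size(0), -1)` call in `forward`) and run on a dummy CIFAR-10-shaped batch. This standalone sketch only restates the layer sizes above; it also confirms the parameter count of a single MLP.

```python
import torch
from torch import nn

# Equivalent formulation of the MLP above, using nn.Flatten instead of x.view(...)
mlp = nn.Sequential(
    nn.Flatten(),  # (B, 3, 32, 32) -> (B, 3072)
    nn.Linear(3072, 128),
    nn.ReLU(),
    nn.Linear(128, 128),
    nn.ReLU(),
    nn.Linear(128, 10),
)

x = torch.randn(8, 3, 32, 32)  # dummy batch of 8 CIFAR-10-sized images
logits = mlp(x)
n_params = sum(p.numel() for p in mlp.parameters())
print(logits.shape, n_params)  # torch.Size([8, 10]) 411146
```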

## C) Training

**Training Methods**

Since *probly* does not currently provide training functionality, we define the training methods below.

**Base Training Method**

In [4]:
def train_model(
    model: nn.Module,
    optimizer: optim.Optimizer,
    loss_function: nn.Module,
    train_loader: DataLoader,
    epochs: int = 10,
    num_members: int | None = None,
) -> nn.Module:
    """Train `model` in place and return it.

    If `num_members` is a positive int, the model is expected to return one output
    per member ([num_members, batch, classes]) and the loss is averaged over members.
    """
    for epoch in range(epochs):
        t0 = time.perf_counter()
        total_loss = 0.0
        model.train()
        for xb, yb in train_loader:
            x = xb.to(device).float()
            y = yb.to(device).long()
            optimizer.zero_grad()
            out = model(x)

            if isinstance(num_members, int) and num_members > 0:
                # BatchEnsemble: out[e] holds the logits of member e;
                # every member is trained against the same labels.
                loss = 0.0
                for e in range(num_members):
                    loss += loss_function(out[e], y)
                loss = loss / num_members
            else:
                # Fallback: standard loss for a single model
                loss = loss_function(out, y)
            loss.backward()
            total_loss += loss.item()
            optimizer.step()
        avg_loss = total_loss / len(train_loader)
        t1 = time.perf_counter()
        print(f"Epoch {epoch + 1}/{epochs} trained in {t1-t0} seconds.")
        print(f"> Loss: {avg_loss}")
    return model
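
For a BatchEnsemble, `train_model` averages the per-member losses. With `nn.CrossEntropyLoss`'s default `reduction="mean"` and the same batch for every member, this equals a single cross-entropy over all member outputs flattened into one large batch with the labels repeated per member. The check below uses random logits in place of real model outputs.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
E, B, C = 5, 32, 10
out = torch.randn(E, B, C)  # BatchEnsemble-style logits: [members, batch, classes]
y = torch.randint(0, C, (B,))  # one label per sample, shared by all members

# Loss as computed in train_model: mean over members of each member's mean CE
loop_loss = sum(F.cross_entropy(out[e], y) for e in range(E)) / E

# Same value in one call: flatten members into the batch and repeat the labels
flat_loss = F.cross_entropy(out.reshape(E * B, C), y.repeat(E))
print(torch.allclose(loop_loss, flat_loss))  # True
```

The flattened form avoids the Python loop, which matters little here but keeps the computation vectorised for larger `num_members`.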


In [5]:
def train_ensemble(
    ensemble: list[nn.Module],
    train_loader: DataLoader,
    epochs: int = 10,
    lr: float = 1e-3,
) -> list[nn.Module]:
    """Train each member of a classical ensemble independently, one after another."""
    models = []

    for i, member in enumerate(ensemble):
        print(f"\nTraining ensemble member {i + 1}/{len(ensemble)}")
        model_i = member.to(device)
        optimizer = optim.Adam(model_i.parameters(), lr=lr)  # separate optimizer per member
        trained_model_i = train_model(
            model_i,
            optimizer=optimizer,
            loss_function=nn.CrossEntropyLoss(),
            train_loader=train_loader,
            epochs=epochs,
        )
        models.append(trained_model_i)
    return models


def train_batchensemble(
    base_cls: nn.Module,
    num_members: int,
    train_loader: DataLoader,
    epochs: int = 10,
    lr: float = 1e-3,
) -> nn.Module:
    """Train a BatchEnsemble model; all members are updated jointly in each pass.

    Note: `base_cls` is the already transformed model instance, not a class.
    """
    model = base_cls.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)

    model = train_model(
        model,
        optimizer=optimizer,
        loss_function=nn.CrossEntropyLoss(),
        train_loader=train_loader,
        epochs=epochs,
        num_members=num_members,
    )
    return model

In [6]:
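from probly.transformation import batchensemble

in_dim = 3 * 32 * 32  # flattened CIFAR-10 image
hidden = 128
num_members = 5
epochs = 20
lr = 1e-3

# Classical ensemble: num_members independent MLPs
ensemble = [MLP(in_dim=in_dim, hidden=hidden, out_dim=10) for _ in range(num_members)]
# BatchEnsemble: a single MLP transformed by probly into num_members members
batch_ensemble_mlp = batchensemble(MLP(in_dim=in_dim, hidden=hidden, out_dim=10), num_members=num_members)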

In [7]:
t0_batch_ensemble = time.perf_counter()
trained_batch_ensemble = train_batchensemble(
    base_cls=batch_ensemble_mlp,
    num_members=num_members,
    train_loader=train_loader,
    epochs=epochs,
    lr=lr,
)
t1_batch_ensemble = time.perf_counter()
print(f"\nTrained BatchEnsemble model of size {num_members} in {t1_batch_ensemble - t0_batch_ensemble:.2f}s")

Epoch 1/20 trained in 40.554668200202286 seconds.
> Loss: 1.8352828882324794
Epoch 2/20 trained in 30.627204100135714 seconds.
> Loss: 1.5650636937781472
Epoch 3/20 trained in 29.23773320019245 seconds.
> Loss: 1.4680251480674256
Epoch 4/20 trained in 26.651254000142217 seconds.
> Loss: 1.4031920902483446
Epoch 5/20 trained in 26.01576830027625 seconds.
> Loss: 1.354238022655077
Epoch 6/20 trained in 26.190833599772304 seconds.
> Loss: 1.3148494001694848
Epoch 7/20 trained in 26.76800880022347 seconds.
> Loss: 1.2799603526819539
Epoch 8/20 trained in 29.592886600177735 seconds.
> Loss: 1.2490024723384294
Epoch 9/20 trained in 26.600823500193655 seconds.
> Loss: 1.2213217937938692
Epoch 10/20 trained in 30.731435799971223 seconds.
> Loss: 1.197284526422248
Epoch 11/20 trained in 28.348196799866855 seconds.
> Loss: 1.1747562160723803
Epoch 12/20 trained in 30.84166119992733 seconds.
> Loss: 1.1535143873970706
Epoch 13/20 trained in 28.81197979999706 seconds.
> Loss: 1.134618471504707
...

In [8]:
t0_ensemble = time.perf_counter()
trained_ensemble = train_ensemble(
    ensemble,
    train_loader=train_loader,
    epochs=epochs,
    lr=lr,
)
t1_ensemble = time.perf_counter()
print(f"\nTrained classical ensemble of size {num_members} in {t1_ensemble - t0_ensemble:.2f}s")


Training ensemble member 1/5
Epoch 1/20 trained in 20.00968790007755 seconds.
> Loss: 1.6475496249403316
Epoch 2/20 trained in 20.433384000323713 seconds.
> Loss: 1.4513696497324102
Epoch 3/20 trained in 19.71125359972939 seconds.
> Loss: 1.3573608495299814
Epoch 4/20 trained in 20.114626599941403 seconds.
> Loss: 1.2816591315824712
Epoch 5/20 trained in 20.998445800039917 seconds.
> Loss: 1.2200282713730826
Epoch 6/20 trained in 21.642976799979806 seconds.
> Loss: 1.1680420410991554
Epoch 7/20 trained in 21.16135359974578 seconds.
> Loss: 1.1187594093272721
Epoch 8/20 trained in 21.403368500061333 seconds.
> Loss: 1.0765911825787768
Epoch 9/20 trained in 20.987796700093895 seconds.
> Loss: 1.0347559028760942
Epoch 10/20 trained in 21.356161499861628 seconds.
> Loss: 0.9983648232596087
Epoch 11/20 trained in 21.00062010018155 seconds.
> Loss: 0.9627541714536785
Epoch 12/20 trained in 21.23691130010411 seconds.
> Loss: 0.932914604857726
Epoch 13/20 trained in 25.427497000433505 seconds

In [9]:
class Evaluator:
    def __init__(self, data_loader: torch.utils.data.DataLoader, device: str) -> None:
        """Initialize the Evaluator with a data loader and device.

        Args:
            data_loader (torch.utils.data.DataLoader): DataLoader for evaluation data.
            device (str): Device to run the evaluation on ('cpu' or 'cuda').
        """
        self.data_loader = data_loader
        self.device = device

    def _setup(self) -> None:
        self.correct = 0
        self.total = 0
        self.member_predictions = []

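    def evaluate_batchensemble(self, model: nn.Module, num_members: int) -> tuple[float, torch.Tensor]:
        """Evaluate a BatchEnsemble model.

        Returns the mean per-member accuracy (not the accuracy of a combined ensemble
        prediction) and the member predictions of shape [num_members, num_samples].
        """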
        self._setup()
        model.to(self.device)
        model.eval()

        with torch.no_grad():
            for xb, yb in self.data_loader:
                x = xb.to(self.device).float()
                y = yb.to(self.device).long()

                out = model(x)  # [E, B, out_dim]
                preds = torch.argmax(out, dim=2)  # [E, B]

                self.correct += (preds == y.unsqueeze(0)).sum().item()
                self.total += y.size(0) * num_members
                self.member_predictions.append(preds.cpu())

        accuracy = self.correct / self.total
        all_member_preds = torch.cat(self.member_predictions, dim=1)

        return accuracy, all_member_preds

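    def evaluate_classical_ensemble(self, models: list[MLP]) -> tuple[float, torch.Tensor]:
        """Evaluate a classical ensemble of models.

        Returns the mean per-member accuracy (not the accuracy of a combined ensemble
        prediction) and the member predictions of shape [num_members, num_samples].
        """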
        self._setup()
        for m in models:
            m.to(self.device)
            m.eval()

        with torch.no_grad():
            for xb, yb in self.data_loader:
                x = xb.to(self.device).float()
                y = yb.to(self.device).long()

                batch_member_preds = []
                for m in models:
                    out = m(x)  # [B, out_dim]
                    preds = torch.argmax(out, dim=1)  # [B]
                    batch_member_preds.append(preds.cpu().unsqueeze(0))  # [1, B]

                batch_member_preds = torch.cat(batch_member_preds, dim=0)  # [E, B]
                self.correct += (batch_member_preds == y.unsqueeze(0).cpu()).sum().item()
                self.total += y.size(0) * len(models)
                self.member_predictions.append(batch_member_preds)

        accuracy = self.correct / self.total
        all_member_preds = torch.cat(self.member_predictions, dim=1)
        return accuracy, all_member_preds

In [10]:
# Evaluate BatchEnsemble
evaluator = Evaluator(val_loader, device)
be_acc, be_member_preds = evaluator.evaluate_batchensemble(trained_batch_ensemble, num_members)
print(f"BatchEnsemble Accuracy: {be_acc:.4f}")

BatchEnsemble Accuracy: 0.5110


In [11]:
for m in trained_ensemble:
    m.to(device)
ce_acc, ce_member_preds = evaluator.evaluate_classical_ensemble(trained_ensemble)
print(f"Classical Ensemble Accuracy: {ce_acc:.4f}")

Classical Ensemble Accuracy: 0.5081
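
Both `Evaluator` methods count correct predictions over all members and samples, so the reported numbers are mean per-member accuracies rather than the accuracy of the ensemble's combined prediction. A minimal sketch of a majority vote over the returned member predictions ([members, samples]) follows; `majority_vote_accuracy` and `mean_member_accuracy` are hypothetical helpers, and the toy tensors stand in for real predictions.

```python
import torch
import torch.nn.functional as F


def majority_vote_accuracy(member_preds: torch.Tensor, targets: torch.Tensor, num_classes: int) -> float:
    """Accuracy of the majority vote over members.

    member_preds: [E, N] integer class predictions (the layout returned by Evaluator).
    targets: [N] true labels.
    Ties go to the smallest class index, because argmax returns the first maximum.
    """
    votes = F.one_hot(member_preds, num_classes).sum(dim=0)  # [N, num_classes] vote counts
    ensemble_pred = votes.argmax(dim=1)  # [N]
    return (ensemble_pred == targets).float().mean().item()


def mean_member_accuracy(member_preds: torch.Tensor, targets: torch.Tensor) -> float:
    """What Evaluator reports: accuracy averaged over all members and samples."""
    return (member_preds == targets.unsqueeze(0)).float().mean().item()


preds = torch.tensor([[0, 1, 2, 3],
                      [0, 1, 1, 0],
                      [0, 2, 2, 3]])  # toy example: 3 members, 4 samples
targets = torch.tensor([0, 1, 2, 0])
print(majority_vote_accuracy(preds, targets, num_classes=4))  # 0.75
print(mean_member_accuracy(preds, targets))  # about 0.667
```

Applied to this notebook, something like `majority_vote_accuracy(be_member_preds, torch.tensor(val_data.targets), 10)` should give the vote accuracy of the BatchEnsemble; this relies on `val_loader` not shuffling, so the prediction columns follow the dataset order.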
