This repository was archived by the owner on Aug 21, 2025. It is now read-only.

runtime performance gain on model ensembling #1061

@xuyxu

Hi,

After noticing this nice package in the PyTorch release notes, we are working to include it in our repo Ensemble-Pytorch, a member of the PyTorch ecosystem focusing on state-of-the-art ensemble methods.

Following the introduction on model ensembling, here is our runtime-benchmarking snippet. It trains 5 simple LeNet5 models on CIFAR-10, then compares the inference runtime on test_loader between functorch and the original forward method.

import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchensemble.voting import VotingClassifier
from torchensemble.utils.logging import set_logger

from functorch import vmap
from memory_profiler import profile


class LeNet5(nn.Module):
    def __init__(self):
        super(LeNet5, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(400, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))  # 3x32x32 -> 6x28x28 -> 6x14x14
        x = self.pool(F.relu(self.conv2(x)))  # 6x14x14 -> 16x10x10 -> 16x5x5
        x = x.view(-1, 400)  # flatten: 16 * 5 * 5 = 400
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


# @profile
def functorch_inference(data_loader, fmodel, params, buffers, device):
    # One vmapped call evaluates all stacked estimators at once:
    # in_dims=(0, 0, None) maps over params/buffers and broadcasts the data.
    for idx, (data, target) in enumerate(data_loader):
        data = data.to(device)
        vmap(fmodel, in_dims=(0, 0, None))(params, buffers, data)


# @profile
def pytorch_inference(data_loader, model):
    # Baseline: the ensemble's original forward, one pass per estimator.
    for idx, (data, target) in enumerate(data_loader):
        data = data.to(model.device)
        model(data)


if __name__ == "__main__":

    # Hyper-parameters
    n_estimators = 5
    lr = 1e-3
    weight_decay = 5e-4
    epochs = 5
    n_trials = 10

    # Utils
    batch_size = 128
    data_dir = "../../Dataset/cifar"  # MODIFY THIS IF YOU WANT
    records = []
    torch.manual_seed(0)

    # Load data
    train_transformer = transforms.Compose(
        [
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(32, 4),
            transforms.ToTensor(),
            transforms.Normalize(
                (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)
            ),
        ]
    )

    test_transformer = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize(
                (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)
            ),
        ]
    )

    train_loader = DataLoader(
        datasets.CIFAR10(
            data_dir, train=True, download=True, transform=train_transformer
        ),
        batch_size=batch_size,
        shuffle=True,
    )

    test_loader = DataLoader(
        datasets.CIFAR10(data_dir, train=False, transform=test_transformer),
        batch_size=batch_size,
        shuffle=True,
    )

    logger = set_logger("functorch_benchmark", use_tb_logger=True)
    
    # VotingClassifier
    model = VotingClassifier(
        estimator=LeNet5, n_estimators=n_estimators, cuda=False  # set True for the GPU run below
    )

    # Set the optimizer
    model.set_optimizer("Adam", lr=lr, weight_decay=weight_decay)

    # Training
    tic = time.time()
    model.fit(train_loader, epochs=epochs)
    toc = time.time()
    training_time = toc - tic

    fmodel, params, buffers = model.vectorize()  # Internally: fmodel, params, buffers = combine_state_for_ensemble(self.estimators_)

    tic = time.time()
    for _ in range(n_trials):
        functorch_inference(test_loader, fmodel, params, buffers, model.device)
    toc = time.time()
    print("functorch: {:.3f}s".format(toc - tic))

    tic = time.time()
    for _ in range(n_trials):
        pytorch_inference(test_loader, model)
    toc = time.time()
    print("pytorch: {:.3f}s".format(toc - tic))

The results are somewhat strange:

  • CPU (Overclocked)
    • functorch: 34.716s
    • pytorch: 30.519s
  • GPU
    • functorch: 43.139s
    • pytorch: 43.905s

The performance gain is marginal compared to the official documentation. I would really appreciate it if anyone could tell me where this goes wrong. Thanks!
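
One caveat I am aware of (my own assumption, not from the docs): since CUDA calls are asynchronous and the loop re-reads from a live DataLoader, time.time() above may mostly be measuring data loading and host-side overhead rather than the forward passes themselves. A more careful harness might pre-load a batch, disable autograd, and synchronize before reading the clock, roughly:

import time
import torch

def timed_forward(fn, use_cuda, n_trials=10):
    fn()  # warm-up: exclude one-time costs (cudnn autotune, allocator growth)
    if use_cuda:
        torch.cuda.synchronize()  # flush pending kernels before starting the clock
    tic = time.time()
    with torch.no_grad():  # inference only, no autograd bookkeeping
        for _ in range(n_trials):
            fn()
    if use_cuda:
        torch.cuda.synchronize()  # wait for all kernels before stopping the clock
    return time.time() - tic

# Hypothetical usage with a single pre-loaded batch:
# data = next(iter(test_loader))[0].to(model.device)
# print(timed_forward(lambda: vmap(fmodel, in_dims=(0, 0, None))(params, buffers, data),
#                     use_cuda=torch.cuda.is_available()))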
