runtime performance gain on model ensembling #1061
Closed
Description
Hi,
After noticing this nice package in the PyTorch release notes, we are working on integrating it into our repo Ensemble-Pytorch, a member of the PyTorch ecosystem focused on state-of-the-art ensemble methods.
Following the introduction on model ensembling, here is our code snippet for runtime benchmarking. The snippet trains 5 simple LeNet5 models on CIFAR-10 and then measures the inference time on test_loader, once with functorch and once with the original forward method.
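For reference, the vectorization pattern we follow is essentially the one from that tutorial (a minimal sketch; here models is a plain Python list of trained LeNet5 instances and x a mini-batch):

from functorch import combine_state_for_ensemble, vmap

# Stack the per-model parameters and buffers along a new leading dimension
# and obtain a stateless, functional version of the model.
fmodel, params, buffers = combine_state_for_ensemble(models)

# Map over the stacked params/buffers (dims 0 and 0) while sharing the same
# mini-batch across all models (None).
outputs = vmap(fmodel, in_dims=(0, 0, None))(params, buffers, x)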
import time

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

from torchensemble.voting import VotingClassifier
from torchensemble.utils.logging import set_logger

from functorch import vmap
from memory_profiler import profile


class LeNet5(nn.Module):

    def __init__(self):
        super(LeNet5, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(400, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 400)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


# @profile
def functorch_inference(data_loader, fmodel, params, buffers, device):
    # Evaluate all ensemble members in a single vmap-ed call per batch.
    for idx, (data, target) in enumerate(data_loader):
        data = data.to(device)
        vmap(fmodel, in_dims=(0, 0, None))(params, buffers, data)


# @profile
def pytorch_inference(data_loader, model):
    # Standard forward pass of the ensemble model.
    for idx, (data, target) in enumerate(data_loader):
        data = data.to(model.device)
        model(data)
if __name__ == "__main__":

    # Hyper-parameters
    n_estimators = 5
    lr = 1e-3
    weight_decay = 5e-4
    epochs = 5
    n_trials = 10

    # Utils
    batch_size = 128
    data_dir = "../../Dataset/cifar"  # MODIFY THIS IF YOU WANT
    records = []
    torch.manual_seed(0)

    # Load data
    train_transformer = transforms.Compose(
        [
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(32, 4),
            transforms.ToTensor(),
            transforms.Normalize(
                (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)
            ),
        ]
    )

    test_transformer = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize(
                (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)
            ),
        ]
    )

    train_loader = DataLoader(
        datasets.CIFAR10(
            data_dir, train=True, download=True, transform=train_transformer
        ),
        batch_size=batch_size,
        shuffle=True,
    )

    test_loader = DataLoader(
        datasets.CIFAR10(data_dir, train=False, transform=test_transformer),
        batch_size=batch_size,
        shuffle=True,
    )

    logger = set_logger("functorch_benchmark", use_tb_logger=True)

    # VotingClassifier
    model = VotingClassifier(
        estimator=LeNet5, n_estimators=n_estimators, cuda=False
    )

    # Set the optimizer
    model.set_optimizer("Adam", lr=lr, weight_decay=weight_decay)

    # Training
    tic = time.time()
    model.fit(train_loader, epochs=epochs)
    toc = time.time()
    training_time = toc - tic

    # Internally: fmodel, params, buffers = combine_state_for_ensemble(self.estimators_)
    fmodel, params, buffers = model.vectorize()

    tic = time.time()
    for _ in range(n_trials):
        functorch_inference(test_loader, fmodel, params, buffers, model.device)
    toc = time.time()
    print("functorch: {:.3f}s".format(toc - tic))

    tic = time.time()
    for _ in range(n_trials):
        pytorch_inference(test_loader, model)
    toc = time.time()
    print("pytorch: {:.3f}s".format(toc - tic))
The results are somewhat surprising:
- CPU (Overclocked)
- functorch: 34.716s
- pytorch: 30.519s
- GPU
- functorch: 43.139s
- pytorch: 43.905s
The performance gain is marginal compared to the numbers in the official documentation. I would appreciate it very much if anyone could tell me what is going wrong. Thanks!
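P.S. For what it is worth, a stricter timing harness I could switch to looks like the sketch below; this is an assumption on my side rather than what the snippet above does, and torch.cuda.synchronize() only matters for the GPU run:

def timed_inference(fn, *args, use_cuda=False, n_trials=10):
    # Time fn(*args) over n_trials runs without autograd bookkeeping.
    with torch.no_grad():
        if use_cuda:
            torch.cuda.synchronize()  # drain pending kernels before starting the clock
        tic = time.time()
        for _ in range(n_trials):
            fn(*args)
        if use_cuda:
            torch.cuda.synchronize()  # CUDA launches are asynchronous; wait before stopping
        toc = time.time()
    return toc - tic

e.g. timed_inference(functorch_inference, test_loader, fmodel, params, buffers, model.device, use_cuda=True).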