Benchmark amp on Cifar10 (#917)
* [WIP] added scripts and notebook

* autopep8 fix

* Added amp benchmark
- torch.cuda.amp
- nvidia/apex

* Added colab link

Co-authored-by: AutoPEP8 <>
vfdev-5 committed Apr 13, 2020
1 parent efc45a7 commit d62c4e9
Showing 6 changed files with 737 additions and 1 deletion.
4 changes: 3 additions & 1 deletion README.md
@@ -298,7 +298,9 @@ Basic neural network training on MNIST dataset with/without `ignite.contrib` mod
CIFAR100](https://github.com/pytorch/ignite/blob/master/examples/notebooks/EfficientNet_Cifar100_finetuning.ipynb)
- [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/pytorch/ignite/blob/master/examples/notebooks/Cifar10_Ax_hyperparam_tuning.ipynb) [Hyperparameters tuning with
Ax](https://github.com/pytorch/ignite/blob/master/examples/notebooks/Cifar10_Ax_hyperparam_tuning.ipynb)
- [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/pytorch/ignite/blob/master/examples/notebooks/FastaiLRFinder_MNIST.ipynb) [Basic example of LR finder on MNIST](https://github.com/pytorch/ignite/blob/master/examples/notebooks/FastaiLRFinder_MNIST.ipynb)
- [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/pytorch/ignite/blob/master/examples/notebooks/FastaiLRFinder_MNIST.ipynb) [Basic example of LR finder on
MNIST](https://github.com/pytorch/ignite/blob/master/examples/notebooks/FastaiLRFinder_MNIST.ipynb)
- [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/pytorch/ignite/blob/master/examples/notebooks/Cifar100_bench_amp.ipynb) [Benchmark mixed precision training on Cifar100: torch.cuda.amp vs nvidia/apex](https://github.com/pytorch/ignite/blob/master/examples/notebooks/Cifar100_bench_amp.ipynb)

## Distributed CIFAR10 Example

71 changes: 71 additions & 0 deletions examples/contrib/cifar100_amp_benchmark/benchmark_fp32.py
@@ -0,0 +1,71 @@
import fire

import torch
from torch.nn import CrossEntropyLoss
from torch.optim import SGD

from torchvision.models import wide_resnet50_2

from ignite.engine import Events, Engine, create_supervised_evaluator, convert_tensor
from ignite.metrics import Accuracy, Loss
from ignite.handlers import Timer
from ignite.contrib.handlers import ProgressBar

from utils import get_train_eval_loaders


def main(dataset_path, batch_size=256, max_epochs=10):
assert torch.cuda.is_available()
    assert torch.backends.cudnn.enabled, "This benchmark requires the cudnn backend to be enabled."
torch.backends.cudnn.benchmark = True

device = "cuda"

train_loader, test_loader, eval_train_loader = get_train_eval_loaders(dataset_path, batch_size=batch_size)

model = wide_resnet50_2(num_classes=100).to(device)
optimizer = SGD(model.parameters(), lr=0.01)
criterion = CrossEntropyLoss().to(device)

def train_step(engine, batch):
x = convert_tensor(batch[0], device, non_blocking=True)
y = convert_tensor(batch[1], device, non_blocking=True)

optimizer.zero_grad()

y_pred = model(x)
loss = criterion(y_pred, y)
loss.backward()

optimizer.step()

return loss.item()

trainer = Engine(train_step)
timer = Timer(average=True)
timer.attach(trainer, step=Events.EPOCH_COMPLETED)
ProgressBar(persist=True).attach(trainer, output_transform=lambda out: {"batch loss": out})

metrics = {"Accuracy": Accuracy(), "Loss": Loss(criterion)}

evaluator = create_supervised_evaluator(model, metrics=metrics, device=device, non_blocking=True)

def log_metrics(engine, title):
for name in metrics:
print("\t{} {}: {:.2f}".format(title, name, engine.state.metrics[name]))

@trainer.on(Events.COMPLETED)
def run_validation(_):
print("- Mean elapsed time for 1 epoch: {}".format(timer.value()))
print("- Metrics:")
with evaluator.add_event_handler(Events.COMPLETED, log_metrics, "Train"):
evaluator.run(eval_train_loader)

with evaluator.add_event_handler(Events.COMPLETED, log_metrics, "Test"):
evaluator.run(test_loader)

trainer.run(train_loader, max_epochs=max_epochs)


if __name__ == "__main__":
fire.Fire(main)
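Because `main` is exposed through `fire.Fire`, its parameters map directly onto command-line arguments. A minimal sketch of invoking the baseline benchmark, assuming the script's directory is the working directory and a hypothetical dataset location `/data/cifar100`:

```python
# Minimal sketch: call the benchmark entry point directly from Python.
# "/data/cifar100" is a hypothetical dataset path, not part of this commit.
from benchmark_fp32 import main

main("/data/cifar100", batch_size=256, max_epochs=10)
```

The equivalent CLI call would be `python benchmark_fp32.py /data/cifar100 --batch_size=256 --max_epochs=10`.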
78 changes: 78 additions & 0 deletions examples/contrib/cifar100_amp_benchmark/benchmark_nvidia_apex.py
@@ -0,0 +1,78 @@
import fire

import torch
from torch.nn import CrossEntropyLoss
from torch.optim import SGD

from torchvision.models import wide_resnet50_2

from apex import amp

from ignite.engine import Events, Engine, create_supervised_evaluator, convert_tensor
from ignite.metrics import Accuracy, Loss
from ignite.handlers import Timer
from ignite.contrib.handlers import ProgressBar

from utils import get_train_eval_loaders


def main(dataset_path, batch_size=256, max_epochs=10, opt="O1"):
assert torch.cuda.is_available()
assert torch.backends.cudnn.enabled, "NVIDIA/Apex:Amp requires cudnn backend to be enabled."
torch.backends.cudnn.benchmark = True

device = "cuda"

train_loader, test_loader, eval_train_loader = get_train_eval_loaders(dataset_path, batch_size=batch_size)

model = wide_resnet50_2(num_classes=100).to(device)
optimizer = SGD(model.parameters(), lr=0.01)
criterion = CrossEntropyLoss().to(device)

model, optimizer = amp.initialize(model, optimizer, opt_level=opt)

def train_step(engine, batch):
x = convert_tensor(batch[0], device, non_blocking=True)
y = convert_tensor(batch[1], device, non_blocking=True)

optimizer.zero_grad()

y_pred = model(x)
loss = criterion(y_pred, y)

        # Scales the loss and calls backward() on the scaled loss to create scaled gradients.
with amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward()

optimizer.step()

return loss.item()

trainer = Engine(train_step)
timer = Timer(average=True)
timer.attach(trainer, step=Events.EPOCH_COMPLETED)
ProgressBar(persist=True).attach(trainer, output_transform=lambda out: {"batch loss": out})

metrics = {"Accuracy": Accuracy(), "Loss": Loss(criterion)}

evaluator = create_supervised_evaluator(model, metrics=metrics, device=device, non_blocking=True)

def log_metrics(engine, title):
for name in metrics:
print("\t{} {}: {:.2f}".format(title, name, engine.state.metrics[name]))

@trainer.on(Events.COMPLETED)
def run_validation(_):
print("- Mean elapsed time for 1 epoch: {}".format(timer.value()))
print("- Metrics:")
with evaluator.add_event_handler(Events.COMPLETED, log_metrics, "Train"):
evaluator.run(eval_train_loader)

with evaluator.add_event_handler(Events.COMPLETED, log_metrics, "Test"):
evaluator.run(test_loader)

trainer.run(train_loader, max_epochs=max_epochs)


if __name__ == "__main__":
fire.Fire(main)
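The `opt` argument is handed to `amp.initialize` as Apex's `opt_level`: "O1" (the default here) patches torch functions to cast inputs per operation, while "O2" casts the model itself to FP16 and keeps FP32 master weights in the optimizer. A minimal sketch of running the benchmark at "O2", again with a hypothetical dataset path:

```python
# Minimal sketch (hypothetical dataset path): run the Apex benchmark
# with opt_level "O2" instead of the default "O1".
from benchmark_nvidia_apex import main

main("/data/cifar100", batch_size=256, max_epochs=10, opt="O2")
```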
88 changes: 88 additions & 0 deletions (torch.cuda.amp benchmark script)
@@ -0,0 +1,88 @@
import fire

import torch
from torch.nn import CrossEntropyLoss
from torch.optim import SGD

from torch.cuda.amp import GradScaler, autocast

from torchvision.models import wide_resnet50_2

from ignite.engine import Events, Engine, create_supervised_evaluator, convert_tensor
from ignite.metrics import Accuracy, Loss
from ignite.handlers import Timer
from ignite.contrib.handlers import ProgressBar

from utils import get_train_eval_loaders


def main(dataset_path, batch_size=256, max_epochs=10):
assert torch.cuda.is_available()
    assert torch.backends.cudnn.enabled, "torch.cuda.amp benchmark requires the cudnn backend to be enabled."
torch.backends.cudnn.benchmark = True

device = "cuda"

train_loader, test_loader, eval_train_loader = get_train_eval_loaders(dataset_path, batch_size=batch_size)

model = wide_resnet50_2(num_classes=100).to(device)
optimizer = SGD(model.parameters(), lr=0.01)
criterion = CrossEntropyLoss().to(device)

    # Creates a GradScaler once at the beginning of training.
    scaler = GradScaler()

def train_step(engine, batch):
x = convert_tensor(batch[0], device, non_blocking=True)
y = convert_tensor(batch[1], device, non_blocking=True)

optimizer.zero_grad()

# Runs the forward pass with autocasting.
with autocast():
y_pred = model(x)
loss = criterion(y_pred, y)

# Scales loss. Calls backward() on scaled loss to create scaled gradients.
# Backward passes under autocast are not recommended.
# Backward ops run in the same precision that autocast used for corresponding forward ops.
scaler.scale(loss).backward()

# scaler.step() first unscales the gradients of the optimizer's assigned params.
# If these gradients do not contain infs or NaNs, optimizer.step() is then called,
# otherwise, optimizer.step() is skipped.
scaler.step(optimizer)

# Updates the scale for next iteration.
scaler.update()

return loss.item()

trainer = Engine(train_step)
timer = Timer(average=True)
timer.attach(trainer, step=Events.EPOCH_COMPLETED)
ProgressBar(persist=True).attach(trainer, output_transform=lambda out: {"batch loss": out})

metrics = {"Accuracy": Accuracy(), "Loss": Loss(criterion)}

evaluator = create_supervised_evaluator(model, metrics=metrics, device=device, non_blocking=True)

def log_metrics(engine, title):
for name in metrics:
print("\t{} {}: {:.2f}".format(title, name, engine.state.metrics[name]))

@trainer.on(Events.COMPLETED)
def run_validation(_):
print("- Mean elapsed time for 1 epoch: {}".format(timer.value()))
print("- Metrics:")
with evaluator.add_event_handler(Events.COMPLETED, log_metrics, "Train"):
evaluator.run(eval_train_loader)

with evaluator.add_event_handler(Events.COMPLETED, log_metrics, "Test"):
evaluator.run(test_loader)

trainer.run(train_loader, max_epochs=max_epochs)


if __name__ == "__main__":
fire.Fire(main)
52 changes: 52 additions & 0 deletions examples/contrib/cifar100_amp_benchmark/utils.py
@@ -0,0 +1,52 @@
import random

from torchvision.datasets.cifar import CIFAR100
from torchvision.transforms import Compose, RandomCrop, Pad, RandomHorizontalFlip
from torchvision.transforms import ToTensor, Normalize, RandomErasing

from torch.utils.data import Subset, DataLoader


def get_train_eval_loaders(path, batch_size=256):
    """Setup the dataflow:
        - load CIFAR100 train and test datasets
        - setup train/test image transforms: images are padded, randomly cropped,
          randomly flipped horizontally and augmented with random erasing (cutout)
        - setup train/test data loaders yielding mini-batches of `batch_size` examples (256 by default)

    Returns:
        train_loader, test_loader, eval_train_loader
    """
train_transform = Compose(
[
Pad(4),
RandomCrop(32),
RandomHorizontalFlip(),
ToTensor(),
Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
RandomErasing(),
]
)

test_transform = Compose([ToTensor(), Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])

train_dataset = CIFAR100(root=path, train=True, transform=train_transform, download=True)
test_dataset = CIFAR100(root=path, train=False, transform=test_transform, download=False)

train_eval_indices = [random.randint(0, len(train_dataset) - 1) for i in range(len(test_dataset))]
train_eval_dataset = Subset(train_dataset, train_eval_indices)

train_loader = DataLoader(
train_dataset, batch_size=batch_size, num_workers=12, shuffle=True, drop_last=True, pin_memory=True
)

test_loader = DataLoader(
test_dataset, batch_size=batch_size, num_workers=12, shuffle=False, drop_last=False, pin_memory=True
)

eval_train_loader = DataLoader(
train_eval_dataset, batch_size=batch_size, num_workers=12, shuffle=False, drop_last=False, pin_memory=True
)

return train_loader, test_loader, eval_train_loader
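A minimal usage sketch of the helper on its own (the dataset path is hypothetical). Note that `eval_train_loader` serves a random subset of the training set with as many examples as the test set, so the "Train" and "Test" metrics reported by the benchmarks are computed over samples of the same size:

```python
# Minimal sketch; "/data/cifar100" is a hypothetical path.
from utils import get_train_eval_loaders

train_loader, test_loader, eval_train_loader = get_train_eval_loaders("/data/cifar100", batch_size=256)

# CIFAR100 has 50000 training and 10000 test images, so this prints 50000 10000 10000.
print(len(train_loader.dataset), len(test_loader.dataset), len(eval_train_loader.dataset))
```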
