Merged
Changes from all commits
54 commits
5ad0e48
amp init
Jan 29, 2021
621850f
docs complete
Jan 29, 2021
8ec807b
add tests
Jan 29, 2021
05864d6
Merge remote-tracking branch 'upstream/master' into engine/create_sup…
Jan 29, 2021
8aa972a
unscale_ + clip_grad_norm_, move checks to private func, more edge ca…
Jan 29, 2021
0e1c387
Merge remote-tracking branch 'upstream/master' into engine/create_sup…
Jan 29, 2021
8f1280a
scaler must be provided by user and its optional
Jan 30, 2021
7cba2a3
full docstring, on_cuda_amp test
Jan 30, 2021
599b20d
Merge remote-tracking branch 'upstream/master' into engine/create_sup…
Jan 31, 2021
afce317
Merge branch 'master' into engine/create_supervised_trainer
sdesrozis Feb 5, 2021
e113e46
Merge branch 'engine/create_supervised_trainer' of https://github.com…
Feb 9, 2021
c475a56
Merge remote-tracking branch 'upstream/master' into engine/create_sup…
Feb 13, 2021
eeab791
Merge remote-tracking branch 'upstream/master' into engine/create_sup…
Feb 15, 2021
2a125f2
extract into 4 functions for normal, amp, apex and tpu training
Feb 15, 2021
a79aae9
Merge remote-tracking branch 'upstream/master' into engine/create_sup…
Feb 15, 2021
f7d7ca7
explicit training step, independent mode
Feb 15, 2021
1b5f11c
mypy fix
Feb 15, 2021
0096942
fix(tests): pytest.raises checks with match, skipif < 1.6.0
Feb 15, 2021
1d1bbf3
fix(tests): align tests name, coverage append in tpu ci
Feb 15, 2021
48c3539
fix: remove ununsed amp import
Feb 15, 2021
cef8078
fix: docstring with default values, more tests, code review suggestions
Feb 16, 2021
cb1cd8f
fix(docs): update function names
Feb 16, 2021
96727d8
fix: docstring from code review
Feb 16, 2021
43bf0bb
fix: engine state only has attribute scaler if scaler is only True
Feb 16, 2021
ea8185a
Merge remote-tracking branch 'upstream/master' into engine/create_sup…
Feb 16, 2021
6ca8742
Merge branch 'master' into engine/create_supervised_trainer
sdesrozis Feb 16, 2021
358fde0
Merge remote-tracking branch 'upstream/master' into engine/create_sup…
Feb 17, 2021
01e30cd
Merge branch 'engine/create_supervised_trainer' of https://github.com…
Feb 17, 2021
588a44f
fix: address code review
Feb 17, 2021
173cd09
fix: create scaler or None in _check_arg
Feb 17, 2021
039a38f
fix: no return for scaler in supervised_training_step_amp
Feb 17, 2021
3a16519
Merge branch 'master' into engine/create_supervised_trainer
sdesrozis Feb 17, 2021
ed36d25
Merge remote-tracking branch 'upstream/master' into engine/create_sup…
Feb 18, 2021
4a95c3b
Merge branch 'master' into engine/create_supervised_trainer
sdesrozis Feb 19, 2021
707f34d
Merge branch 'engine/create_supervised_trainer' of https://github.com…
Feb 19, 2021
29650b2
Merge remote-tracking branch 'upstream/master' into engine/create_sup…
Feb 19, 2021
a80fc02
Merge remote-tracking branch 'upstream/master' into engine/create_sup…
Feb 19, 2021
b32413d
fix: gpu tests for apex
Feb 19, 2021
a2bd32c
fix: gpu tests for apex and amp
Feb 19, 2021
a0a32c7
Merge branch 'master' into engine/create_supervised_trainer
vfdev-5 Feb 19, 2021
412e9d3
chore: add more tests for coverage
Feb 20, 2021
ea877f7
fix: state only has scaler attribute if True
Feb 20, 2021
9f71a1a
fix: use prefix for scaler
Feb 20, 2021
c8d788a
Merge remote-tracking branch 'upstream/master' into engine/create_sup…
Feb 21, 2021
4b05cc1
Apply suggestions from code review
Feb 21, 2021
5d3175d
Merge branch 'engine/create_supervised_trainer' of https://github.com…
Feb 21, 2021
f5d0d31
Merge branch 'master' into engine/create_supervised_trainer
Feb 21, 2021
bbc450b
Merge branch 'engine/create_supervised_trainer' of https://github.com…
Feb 21, 2021
4387e32
fix: skip apex tests if apex is not installed
Feb 21, 2021
935c852
fix: skip apex error test if apex
Feb 21, 2021
832564f
fix: ImportError instead of ModuleNotFoundError
Feb 21, 2021
dbe7a3d
fix(docs): no device tpu in gpu functions and vice versa
Feb 21, 2021
e4681d8
fix: raise an error instead of warn for invalid scaler and amp_mode
Feb 21, 2021
c8fcae4
Merge branch 'master' into engine/create_supervised_trainer
vfdev-5 Feb 21, 2021
7 changes: 7 additions & 0 deletions docs/source/engine.rst
@@ -47,6 +47,13 @@ More details about those structures can be found in :doc:`concepts`.

.. autofunction:: ignite.engine.create_supervised_evaluator

.. autofunction:: ignite.engine.supervised_training_step

.. autofunction:: ignite.engine.supervised_training_step_amp

.. autofunction:: ignite.engine.supervised_training_step_apex

.. autofunction:: ignite.engine.supervised_training_step_tpu

Resuming the training
---------------------
304 changes: 278 additions & 26 deletions ignite/engine/__init__.py
@@ -10,10 +10,6 @@
from ignite.metrics import Metric
from ignite.utils import convert_tensor

if idist.has_xla_support:
import torch_xla.core.xla_model as xm


__all__ = [
"State",
"create_supervised_trainer",
@@ -25,6 +21,10 @@
"EventEnum",
"CallableEventWithFilter",
"RemovableEventHandle",
"supervised_training_step",
"supervised_training_step_amp",
"supervised_training_step_apex",
"supervised_training_step_tpu",
]


@@ -41,6 +41,233 @@ def _prepare_batch(
)


def supervised_training_step(
model: torch.nn.Module,
optimizer: torch.optim.Optimizer,
loss_fn: Union[Callable, torch.nn.Module],
device: Optional[Union[str, torch.device]] = None,
non_blocking: bool = False,
prepare_batch: Callable = _prepare_batch,
output_transform: Callable = lambda x, y, y_pred, loss: loss.item(),
) -> Callable:
"""Factory function for supervised training.

Args:
model (torch.nn.Module): the model to train.
optimizer (torch.optim.Optimizer): the optimizer to use.
loss_fn (torch.nn loss function): the loss function to use.
device (str, optional): device type specification (default: None).
Applies to batches after starting the engine. Model *will not* be moved.
Device can be CPU, GPU.
non_blocking (bool, optional): if True and this copy is between CPU and GPU, the copy may occur asynchronously
with respect to the host. For other cases, this argument has no effect.
prepare_batch (callable, optional): function that receives `batch`, `device`, `non_blocking` and outputs
tuple of tensors `(batch_x, batch_y)`.
output_transform (callable, optional): function that receives 'x', 'y', 'y_pred', 'loss' and returns value
to be assigned to engine's state.output after each iteration. Default is returning `loss.item()`.

Returns:
Callable: update function.

.. versionadded:: 0.5.0
"""

    def update(engine: Engine, batch: Sequence[torch.Tensor]) -> Union[Any, Tuple[torch.Tensor]]:
        model.train()
        optimizer.zero_grad()
        x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
        y_pred = model(x)
        loss = loss_fn(y_pred, y)
        loss.backward()
        optimizer.step()
        return output_transform(x, y, y_pred, loss)

    return update
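
Illustrative aside (not part of the diff): the factory returns a plain update function, so it can be dropped straight into an ``Engine``. The sketch below is a hypothetical usage example; the toy model, optimizer, and dataset are assumptions, and it relies on the function being importable from ``ignite.engine`` as the updated ``__all__`` suggests.

# Hypothetical usage sketch -- the toy model, optimizer and data are assumptions.
import torch
from torch.utils.data import DataLoader, TensorDataset
from ignite.engine import Engine, supervised_training_step

model = torch.nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
loss_fn = torch.nn.CrossEntropyLoss()

update_fn = supervised_training_step(model, optimizer, loss_fn, device="cpu")
trainer = Engine(update_fn)

data = DataLoader(TensorDataset(torch.randn(64, 10), torch.randint(0, 2, (64,))), batch_size=8)
trainer.run(data, max_epochs=2)  # engine.state.output holds loss.item() for each iteration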


def supervised_training_step_amp(
model: torch.nn.Module,
optimizer: torch.optim.Optimizer,
loss_fn: Union[Callable, torch.nn.Module],
device: Optional[Union[str, torch.device]] = None,
non_blocking: bool = False,
prepare_batch: Callable = _prepare_batch,
output_transform: Callable = lambda x, y, y_pred, loss: loss.item(),
scaler: Optional["torch.cuda.amp.GradScaler"] = None,
) -> Callable:
"""Factory function for supervised training using ``torch.cuda.amp``.

Args:
model (torch.nn.Module): the model to train.
optimizer (torch.optim.Optimizer): the optimizer to use.
loss_fn (torch.nn loss function): the loss function to use.
device (str, optional): device type specification (default: None).
Applies to batches after starting the engine. Model *will not* be moved.
Device can be CPU, GPU.
non_blocking (bool, optional): if True and this copy is between CPU and GPU, the copy may occur asynchronously
with respect to the host. For other cases, this argument has no effect.
prepare_batch (callable, optional): function that receives `batch`, `device`, `non_blocking` and outputs
tuple of tensors `(batch_x, batch_y)`.
output_transform (callable, optional): function that receives 'x', 'y', 'y_pred', 'loss' and returns value
to be assigned to engine's state.output after each iteration. Default is returning `loss.item()`.
scaler (torch.cuda.amp.GradScaler, optional): GradScaler instance for gradient scaling. (default: None)

Returns:
Callable: update function

.. versionadded:: 0.5.0
"""

    try:
        from torch.cuda.amp import autocast
    except ImportError:
        raise ImportError("Please install torch>=1.6.0 to use amp_mode='amp'.")

    def update(engine: Engine, batch: Sequence[torch.Tensor]) -> Union[Any, Tuple[torch.Tensor]]:
        model.train()
        optimizer.zero_grad()
        x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
        with autocast(enabled=True):
            y_pred = model(x)
            loss = loss_fn(y_pred, y)
        if scaler:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()
        return output_transform(x, y, y_pred, loss)

    return update
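
Illustrative aside (not part of the diff): a sketch of wiring this step up with an explicit ``GradScaler``. It assumes torch>=1.6.0 and a CUDA device; the toy model and the ``train_loader`` referenced in the comment are assumptions.

# Hypothetical usage sketch -- requires torch>=1.6.0 and CUDA; model/optimizer are toys.
import torch
from ignite.engine import Engine, supervised_training_step_amp

model = torch.nn.Linear(10, 2).to("cuda")
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
loss_fn = torch.nn.CrossEntropyLoss()
scaler = torch.cuda.amp.GradScaler()

update_fn = supervised_training_step_amp(
    model, optimizer, loss_fn, device="cuda", scaler=scaler
)
trainer = Engine(update_fn)
# trainer.run(train_loader, max_epochs=2)  # train_loader is assumed to exist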


def supervised_training_step_apex(
model: torch.nn.Module,
optimizer: torch.optim.Optimizer,
loss_fn: Union[Callable, torch.nn.Module],
device: Optional[Union[str, torch.device]] = None,
non_blocking: bool = False,
prepare_batch: Callable = _prepare_batch,
output_transform: Callable = lambda x, y, y_pred, loss: loss.item(),
) -> Callable:
"""Factory function for supervised training using apex.

Args:
model (torch.nn.Module): the model to train.
optimizer (torch.optim.Optimizer): the optimizer to use.
loss_fn (torch.nn loss function): the loss function to use.
device (str, optional): device type specification (default: None).
Applies to batches after starting the engine. Model *will not* be moved.
Device can be CPU, GPU.
non_blocking (bool, optional): if True and this copy is between CPU and GPU, the copy may occur asynchronously
with respect to the host. For other cases, this argument has no effect.
prepare_batch (callable, optional): function that receives `batch`, `device`, `non_blocking` and outputs
tuple of tensors `(batch_x, batch_y)`.
output_transform (callable, optional): function that receives 'x', 'y', 'y_pred', 'loss' and returns value
to be assigned to engine's state.output after each iteration. Default is returning `loss.item()`.

Returns:
Callable: update function.

.. versionadded:: 0.5.0
"""

    try:
        from apex import amp as apex_amp
    except ModuleNotFoundError:
        raise ModuleNotFoundError("Please install apex from https://github.com/nvidia/apex to use amp_mode='apex'.")

    def update(engine: Engine, batch: Sequence[torch.Tensor]) -> Union[Any, Tuple[torch.Tensor]]:
        model.train()
        optimizer.zero_grad()
        x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
        y_pred = model(x)
        loss = loss_fn(y_pred, y)
        with apex_amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
        optimizer.step()
        return output_transform(x, y, y_pred, loss)

    return update
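
Illustrative aside (not part of the diff): apex requires wrapping the model and optimizer with ``amp.initialize`` before the step function is created, matching the warning added to ``create_supervised_trainer`` below. The toy model, the ``opt_level`` choice, and the ``train_loader`` are assumptions.

# Hypothetical usage sketch -- requires NVIDIA apex and CUDA; opt_level is an example choice.
import torch
from apex import amp
from ignite.engine import Engine, supervised_training_step_apex

model = torch.nn.Linear(10, 2).to("cuda")
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
loss_fn = torch.nn.CrossEntropyLoss()

model, optimizer = amp.initialize(model, optimizer, opt_level="O2")
update_fn = supervised_training_step_apex(model, optimizer, loss_fn, device="cuda")
trainer = Engine(update_fn)
# trainer.run(train_loader, max_epochs=2)  # train_loader is assumed to exist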


def supervised_training_step_tpu(
model: torch.nn.Module,
optimizer: torch.optim.Optimizer,
loss_fn: Union[Callable, torch.nn.Module],
device: Optional[Union[str, torch.device]] = None,
non_blocking: bool = False,
prepare_batch: Callable = _prepare_batch,
output_transform: Callable = lambda x, y, y_pred, loss: loss.item(),
) -> Callable:
"""Factory function for supervised training using ``torch_xla``.

Args:
model (torch.nn.Module): the model to train.
optimizer (torch.optim.Optimizer): the optimizer to use.
loss_fn (torch.nn loss function): the loss function to use.
device (str, optional): device type specification (default: None).
Applies to batches after starting the engine. Model *will not* be moved.
Device can be CPU, TPU.
non_blocking (bool, optional): if True and this copy is between CPU and GPU, the copy may occur asynchronously
with respect to the host. For other cases, this argument has no effect.
prepare_batch (callable, optional): function that receives `batch`, `device`, `non_blocking` and outputs
tuple of tensors `(batch_x, batch_y)`.
output_transform (callable, optional): function that receives 'x', 'y', 'y_pred', 'loss' and returns value
to be assigned to engine's state.output after each iteration. Default is returning `loss.item()`.

Returns:
Callable: update function.

.. versionadded:: 0.5.0
"""
    try:
        import torch_xla.core.xla_model as xm
    except ModuleNotFoundError:
        raise ModuleNotFoundError("torch_xla cannot be imported, please install PyTorch XLA.")

    def update(engine: Engine, batch: Sequence[torch.Tensor]) -> Union[Any, Tuple[torch.Tensor]]:
        model.train()
        optimizer.zero_grad()
        x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
        y_pred = model(x)
        loss = loss_fn(y_pred, y)
        loss.backward()
        xm.optimizer_step(optimizer, barrier=True)
        return output_transform(x, y, y_pred, loss)

    return update
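
Illustrative aside (not part of the diff): on a TPU host the model must already live on the XLA device, since the step only moves batches. The toy model and the ``train_loader`` referenced in the comment are assumptions.

# Hypothetical usage sketch -- requires torch_xla on a TPU host.
import torch
import torch_xla.core.xla_model as xm
from ignite.engine import Engine, supervised_training_step_tpu

device = xm.xla_device()
model = torch.nn.Linear(10, 2).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
loss_fn = torch.nn.CrossEntropyLoss()

update_fn = supervised_training_step_tpu(model, optimizer, loss_fn, device=device)
trainer = Engine(update_fn)
# trainer.run(train_loader, max_epochs=2)  # train_loader is assumed to exist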


def _check_arg(
    on_tpu: bool, amp_mode: Optional[str], scaler: Optional[Union[bool, "torch.cuda.amp.GradScaler"]]
) -> Tuple[Optional[str], Optional["torch.cuda.amp.GradScaler"]]:
    """Checking tpu, amp and GradScaler instance combinations."""
    if on_tpu and not idist.has_xla_support:
        raise RuntimeError("In order to run on TPU, please install PyTorch XLA")

    if amp_mode and on_tpu:
        raise ValueError("amp_mode cannot be used with xla device. Consider using amp_mode=None or device='cuda'.")

    if scaler:
        if amp_mode != "amp":
            raise ValueError(f"scaler argument is {scaler}, but amp_mode is {amp_mode}. Consider using amp_mode='amp'.")
        elif amp_mode == "amp" and isinstance(scaler, bool):
            try:
                from torch.cuda.amp import GradScaler
            except ImportError:
                raise ImportError("Please install torch>=1.6.0 to use scaler argument.")
            scaler = GradScaler(enabled=True)

    if on_tpu:
        return "tpu", None
    elif scaler and amp_mode == "amp":
        return amp_mode, scaler  # type: ignore[return-value]
    else:
        return amp_mode, None
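
Illustrative aside (not part of the diff): a rough sketch of the contract ``_check_arg`` enforces, as read from the code above. It is a private helper, imported here only for demonstration, and the GradScaler path assumes torch>=1.6.0.

# Hypothetical sketch of accepted and rejected argument combinations.
from ignite.engine import _check_arg

assert _check_arg(False, None, False) == (None, None)       # plain training
assert _check_arg(False, "apex", False) == ("apex", None)   # apex mode, no scaler
mode, grad_scaler = _check_arg(False, "amp", True)          # scaler=True -> GradScaler created
try:
    _check_arg(False, None, True)                           # scaler without amp_mode="amp"
except ValueError:
    pass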


def create_supervised_trainer(
model: torch.nn.Module,
optimizer: torch.optim.Optimizer,
@@ -50,12 +277,14 @@ def create_supervised_trainer(
prepare_batch: Callable = _prepare_batch,
output_transform: Callable = lambda x, y, y_pred, loss: loss.item(),
deterministic: bool = False,
amp_mode: Optional[str] = None,
scaler: Union[bool, "torch.cuda.amp.GradScaler"] = False,
) -> Engine:
"""Factory function for creating a trainer for supervised models.

Args:
model (`torch.nn.Module`): the model to train.
optimizer (`torch.optim.Optimizer`): the optimizer to use.
model (torch.nn.Module): the model to train.
optimizer (torch.optim.Optimizer): the optimizer to use.
loss_fn (torch.nn loss function): the loss function to use.
device (str, optional): device type specification (default: None).
Applies to batches after starting the engine. Model *will not* be moved.
@@ -69,48 +298,71 @@
deterministic (bool, optional): if True, returns deterministic engine of type
:class:`~ignite.engine.deterministic.DeterministicEngine`, otherwise :class:`~ignite.engine.engine.Engine`
(default: False).
amp_mode (str, optional): can be ``amp`` or ``apex``, model and optimizer will be casted to float16 using
`torch.cuda.amp <https://pytorch.org/docs/stable/amp.html>`_ for ``amp`` and
using `apex <https://nvidia.github.io/apex>`_ for ``apex``. (default: None)
scaler (torch.cuda.amp.GradScaler, bool, optional): GradScaler instance for gradient scaling if `torch>=1.6.0`
and ``amp_mode`` is ``amp``. If ``amp_mode`` is ``apex``, this argument will be ignored.
If True, will create default GradScaler. If GradScaler instance is passed, it will be used instead.
(default: False)

Note:
If ``scaler`` is True, a GradScaler instance will be created internally and the trainer state will have an
attribute named ``scaler`` for that instance, which can be used for saving and loading.

Note:
`engine.state.output` for this engine is defined by `output_transform` parameter and is the loss
of the processed batch by default.

.. warning::

The internal use of `device` has changed.
`device` will now *only* be used to move the input data to the correct device.
The `model` should be moved by the user before creating an optimizer.
For more information see:

- `PyTorch Documentation <https://pytorch.org/docs/stable/optim.html#constructing-it>`_

- `PyTorch's Explanation <https://github.com/pytorch/pytorch/issues/7844#issuecomment-503713840>`_

.. warning::
If ``amp_mode='apex'``, the model(s) and optimizer(s) must be initialized beforehand
since ``amp.initialize`` should be called after you have finished constructing your model(s)
and optimizer(s), but before you send your model through any DistributedDataParallel wrapper.

See more: https://nvidia.github.io/apex/amp.html#module-apex.amp

Returns:
Engine: a trainer engine with supervised update function.

.. versionchanged:: 0.5.0

- Added ``amp_mode`` argument for automatic mixed precision.
- Added ``scaler`` argument for gradient scaling.
"""

    device_type = device.type if isinstance(device, torch.device) else device
    on_tpu = "xla" in device_type if device_type is not None else False

    if on_tpu and not idist.has_xla_support:
        raise RuntimeError("In order to run on TPU, please install PyTorch XLA")

    def _update(engine: Engine, batch: Sequence[torch.Tensor]) -> Union[Any, Tuple[torch.Tensor]]:
        model.train()
        optimizer.zero_grad()
        x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
        y_pred = model(x)
        loss = loss_fn(y_pred, y)
        loss.backward()

        if on_tpu:
            xm.optimizer_step(optimizer, barrier=True)
        else:
            optimizer.step()

        return output_transform(x, y, y_pred, loss)
    mode, _scaler = _check_arg(on_tpu, amp_mode, scaler)

    if mode == "amp":
        _update = supervised_training_step_amp(
            model, optimizer, loss_fn, device, non_blocking, prepare_batch, output_transform, _scaler
        )
    elif mode == "apex":
        _update = supervised_training_step_apex(
            model, optimizer, loss_fn, device, non_blocking, prepare_batch, output_transform
        )
    elif mode == "tpu":
        _update = supervised_training_step_tpu(
            model, optimizer, loss_fn, device, non_blocking, prepare_batch, output_transform
        )
    else:
        _update = supervised_training_step(
            model, optimizer, loss_fn, device, non_blocking, prepare_batch, output_transform
        )

    trainer = Engine(_update) if not deterministic else DeterministicEngine(_update)
    if _scaler and scaler and isinstance(scaler, bool):
        trainer.state.scaler = _scaler  # type: ignore[attr-defined]

    return trainer
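
Illustrative aside (not part of the diff): an end-to-end sketch of the new ``amp_mode`` and ``scaler`` arguments. It assumes torch>=1.6.0 and a CUDA device; the toy model and data are assumptions.

# Hypothetical end-to-end sketch -- requires torch>=1.6.0 and CUDA.
import torch
from torch.utils.data import DataLoader, TensorDataset
from ignite.engine import create_supervised_trainer

model = torch.nn.Linear(10, 2).to("cuda")
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
loss_fn = torch.nn.CrossEntropyLoss()

trainer = create_supervised_trainer(
    model, optimizer, loss_fn, device="cuda", amp_mode="amp", scaler=True
)
data = DataLoader(TensorDataset(torch.randn(64, 10), torch.randint(0, 2, (64,))), batch_size=8)
trainer.run(data, max_epochs=2)
print(type(trainer.state.scaler))  # GradScaler created internally because scaler=True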

3 changes: 3 additions & 0 deletions mypy.ini
@@ -24,6 +24,9 @@ warn_unreachable = False
warn_unused_configs = True
warn_unused_ignores = True

[mypy-apex.*]
ignore_missing_imports = True

[mypy-clearml.*]
ignore_missing_imports = True
