From 69524b506a79223d365bdd51a1a14d38bab1c448 Mon Sep 17 00:00:00 2001 From: Chengji Yao Date: Mon, 30 Nov 2020 13:17:33 +0800 Subject: [PATCH 1/4] support amp (auto mixed precision) --- test/args_parse.py | 1 + test/test_amp.py | 61 ++ test/test_train_mp_mnist.py | 2 + torch_xla/amp/__init__.py | 2 + torch_xla/amp/autocast_mode.py | 238 ++++++++ torch_xla/amp/grad_scaler.py | 563 ++++++++++++++++++ torch_xla/csrc/aten_xla_type.cpp | 23 + torch_xla/csrc/aten_xla_type.h | 10 + torch_xla/csrc/batch_norm.cpp | 45 +- ...p_foreach_non_finite_check_and_unscale.cpp | 50 ++ ...amp_foreach_non_finite_check_and_unscale.h | 20 + torch_xla/csrc/ops/amp_update_scale.cpp | 44 ++ torch_xla/csrc/ops/amp_update_scale.h | 20 + torch_xla/csrc/ops/xla_ops.cpp | 3 + torch_xla/csrc/ops/xla_ops.h | 2 + torch_xla/csrc/tensor.h | 11 + torch_xla/csrc/tensor_methods.cpp | 46 ++ torch_xla/csrc/xla_lower_util.cpp | 63 ++ torch_xla/csrc/xla_lower_util.h | 6 + 19 files changed, 1208 insertions(+), 2 deletions(-) create mode 100644 test/test_amp.py create mode 100644 torch_xla/amp/__init__.py create mode 100644 torch_xla/amp/autocast_mode.py create mode 100644 torch_xla/amp/grad_scaler.py create mode 100644 torch_xla/csrc/ops/amp_foreach_non_finite_check_and_unscale.cpp create mode 100644 torch_xla/csrc/ops/amp_foreach_non_finite_check_and_unscale.h create mode 100644 torch_xla/csrc/ops/amp_update_scale.cpp create mode 100644 torch_xla/csrc/ops/amp_update_scale.h diff --git a/test/args_parse.py b/test/args_parse.py index 0c6479ce8d7..e69d9960b87 100644 --- a/test/args_parse.py +++ b/test/args_parse.py @@ -32,6 +32,7 @@ def parse_common_options(datadir=None, parser.add_argument('--fake_data', action='store_true') parser.add_argument('--tidy', action='store_true') parser.add_argument('--metrics_debug', action='store_true') + parser.add_argument('--amp', action='store_true') if opts: for name, aopts in opts: parser.add_argument(name, **aopts) diff --git a/test/test_amp.py b/test/test_amp.py new file mode 100644 index 00000000000..af3aaf9b760 --- /dev/null +++ b/test/test_amp.py @@ -0,0 +1,61 @@ +import torch +import torch_xla.core.xla_model as xm +import unittest + + +class TestAmp(unittest.TestCase): + + def test_amp_update_scale(self): + device = xm.xla_device() + growth_tracker = torch.tensor(0, dtype=torch.int32, device=device) + current_scale = torch.tensor(4, dtype=torch.float, device=device) + found_inf = torch.tensor(0, dtype=torch.float, device=device) + scale_growth_factor = 2.0 + scale_backoff_factor = 0.5 + growth_interval = 3 + current_scale = torch._amp_update_scale(growth_tracker, current_scale, + found_inf, scale_growth_factor, + scale_backoff_factor, + growth_interval) + self.assertAlmostEqual(current_scale.item(), 4.0) + self.assertEqual(growth_tracker.item(), 1) + current_scale = torch._amp_update_scale(growth_tracker, current_scale, + found_inf, scale_growth_factor, + scale_backoff_factor, + growth_interval) + self.assertAlmostEqual(current_scale.item(), 4.0) + self.assertEqual(growth_tracker.item(), 2) + current_scale = torch._amp_update_scale(growth_tracker, current_scale, + found_inf, scale_growth_factor, + scale_backoff_factor, + growth_interval) + self.assertAlmostEqual(current_scale.item(), 8.0) + self.assertEqual(growth_tracker.item(), 0) + found_inf = torch.tensor(1, dtype=torch.float, device=device) + current_scale = torch._amp_update_scale(growth_tracker, current_scale, + found_inf, scale_growth_factor, + scale_backoff_factor, + growth_interval) + self.assertAlmostEqual(current_scale.item(), 
4.0) + self.assertEqual(growth_tracker.item(), 0) + + def test_amp_foreach_non_finite_check_and_unscale(self): + device = xm.xla_device() + grads = [torch.tensor([1, 2, 3, 4], dtype=torch.float, device=device)] + inv_scale = torch.tensor(0.2, dtype=torch.float, device=device) + found_inf = torch.tensor(0, dtype=torch.float, device=device) + torch._amp_foreach_non_finite_check_and_unscale_(grads, found_inf, + inv_scale) + self.assertAlmostEqual(found_inf.item(), 0.0) + + grads = [ + torch.tensor([1, 2, 3, float('nan')], dtype=torch.float, device=device), + torch.tensor([1, 2, 3, 5], dtype=torch.float, device=device) + ] + torch._amp_foreach_non_finite_check_and_unscale_(grads, found_inf, + inv_scale) + self.assertAlmostEqual(found_inf.item(), 1.0) + + +if __name__ == '__main__': + unittest.main() diff --git a/test/test_train_mp_mnist.py b/test/test_train_mp_mnist.py index 7c556191294..05f9adf82d2 100644 --- a/test/test_train_mp_mnist.py +++ b/test/test_train_mp_mnist.py @@ -24,6 +24,7 @@ import torch_xla.core.xla_model as xm import torch_xla.distributed.xla_multiprocessing as xmp import torch_xla.test.test_utils as test_utils +from torch_xla.amp import autocast, GradScaler class MNIST(nn.Module): @@ -118,6 +119,7 @@ def train_mnist(flags, **kwargs): writer = test_utils.get_summary_writer(flags.logdir) optimizer = optim.SGD(model.parameters(), lr=lr, momentum=flags.momentum) loss_fn = nn.NLLLoss() + scaler = GradScaler() def train_loop_fn(loader): tracker = xm.RateTracker() diff --git a/torch_xla/amp/__init__.py b/torch_xla/amp/__init__.py new file mode 100644 index 00000000000..de6de31877e --- /dev/null +++ b/torch_xla/amp/__init__.py @@ -0,0 +1,2 @@ +from .autocast_mode import autocast, custom_fwd, custom_bwd # noqa: F401 +from .grad_scaler import GradScaler # noqa: F401 \ No newline at end of file diff --git a/torch_xla/amp/autocast_mode.py b/torch_xla/amp/autocast_mode.py new file mode 100644 index 00000000000..fc995f6b747 --- /dev/null +++ b/torch_xla/amp/autocast_mode.py @@ -0,0 +1,238 @@ +import torch +import functools +import warnings +try: + import numpy as np +except ModuleNotFoundError: + np = None +from torch._six import container_abcs, string_classes + + +class autocast(object): + r""" + Instances of :class:`autocast` serve as context managers or decorators that + allow regions of your script to run in mixed precision. + + In these regions, CUDA ops run in an op-specific dtype chosen by autocast + to improve performance while maintaining accuracy. + See the :ref:`Autocast Op Reference` for details. + + When entering an autocast-enabled region, Tensors may be any type. + You should not call ``.half()`` on your model(s) or inputs when using autocasting. + + :class:`autocast` should wrap only the forward pass(es) of your network, including the loss + computation(s). Backward passes under autocast are not recommended. + Backward ops run in the same type that autocast used for corresponding forward ops. + + Example:: + + # Creates model and optimizer in default precision + model = Net().cuda() + optimizer = optim.SGD(model.parameters(), ...) 
+ + for input, target in data: + optimizer.zero_grad() + + # Enables autocasting for the forward pass (model + loss) + with autocast(): + output = model(input) + loss = loss_fn(output, target) + + # Exits the context manager before backward() + loss.backward() + optimizer.step() + + See the :ref:`Automatic Mixed Precision examples` for usage (along with gradient scaling) + in more complex scenarios (e.g., gradient penalty, multiple models/losses, custom autograd functions). + + :class:`autocast` can also be used as a decorator, e.g., on the ``forward`` method of your model:: + + class AutocastModel(nn.Module): + ... + @autocast() + def forward(self, input): + ... + + Floating-point Tensors produced in an autocast-enabled region may be ``float16``. + After returning to an autocast-disabled region, using them with floating-point + Tensors of different dtypes may cause type mismatch errors. If so, cast the Tensor(s) + produced in the autocast region back to ``float32`` (or other dtype if desired). + If a Tensor from the autocast region is already ``float32``, the cast is a no-op, + and incurs no additional overhead. Example:: + + # Creates some tensors in default dtype (here assumed to be float32) + a_float32 = torch.rand((8, 8), device="cuda") + b_float32 = torch.rand((8, 8), device="cuda") + c_float32 = torch.rand((8, 8), device="cuda") + d_float32 = torch.rand((8, 8), device="cuda") + + with autocast(): + # torch.mm is on autocast's list of ops that should run in float16. + # Inputs are float32, but the op runs in float16 and produces float16 output. + # No manual casts are required. + e_float16 = torch.mm(a_float32, b_float32) + # Also handles mixed input types + f_float16 = torch.mm(d_float32, e_float16) + + # After exiting autocast, calls f_float16.float() to use with d_float32 + g_float32 = torch.mm(d_float32, f_float16.float()) + + Type mismatch errors *in* an autocast-enabled region are a bug; if this is what you observe, + please file an issue. + + ``autocast(enabled=False)`` subregions can be nested in autocast-enabled regions. + Locally disabling autocast can be useful, for example, if you want to force a subregion + to run in a particular ``dtype``. Disabling autocast gives you explicit control over + the execution type. In the subregion, inputs from the surrounding region + should be cast to ``dtype`` before use:: + + # Creates some tensors in default dtype (here assumed to be float32) + a_float32 = torch.rand((8, 8), device="cuda") + b_float32 = torch.rand((8, 8), device="cuda") + c_float32 = torch.rand((8, 8), device="cuda") + d_float32 = torch.rand((8, 8), device="cuda") + + with autocast(): + e_float16 = torch.mm(a_float32, b_float32) + + with autocast(enabled=False): + # Calls e_float16.float() to ensure float32 execution + # (necessary because e_float16 was created in an autocasted region) + f_float32 = torch.mm(c_float32, e_float16.float()) + + # No manual casts are required when re-entering the autocast-enabled region. + # torch.mm again runs in float16 and produces float16 output, regardless of input types. + g_float16 = torch.mm(d_float32, f_float32) + + The autocast state is thread-local. If you want it enabled in a new thread, the context manager or decorator + must be invoked in that thread. This affects :class:`torch.nn.DataParallel` and + :class:`torch.nn.parallel.DistributedDataParallel` when used with more than one GPU per process + (see :ref:`Working with Multiple GPUs`). 
+ + Arguments: + enabled(bool, optional, default=True): Whether autocasting should be enabled in the region. + """ + + def __init__(self, enabled=True): + self._enabled = enabled + + def __enter__(self): + self.prev = torch.is_autocast_enabled() + torch.set_autocast_enabled(self._enabled) + torch.autocast_increment_nesting() + + def __exit__(self, *args): + # Drop the cache when we exit to a nesting level that's outside any instance of autocast. + if torch.autocast_decrement_nesting() == 0: + torch.clear_autocast_cache() + torch.set_autocast_enabled(self.prev) + return False + + def __call__(self, func): + + @functools.wraps(func) + def decorate_autocast(*args, **kwargs): + with self: + return func(*args, **kwargs) + + return decorate_autocast + + +# Casts Tensors and containers of Tensors. Special-cases passthroughs for strings and np.ndarrays, which +# may be falsely detected as "Iterables." +def _cast(value, dtype): + if isinstance(value, torch.Tensor): + is_eligible = ( + value.is_floating_point() and value.is_cuda and + (value.dtype is not torch.float64)) + return value.to(dtype) if is_eligible else value + elif isinstance(value, string_classes): + return value + elif np is not None and isinstance(value, np.ndarray): + return value + elif isinstance(value, container_abcs.Mapping): + return {_cast(k, dtype): _cast(v, dtype) for k, v in value.items()} + elif isinstance(value, container_abcs.Iterable): + iterable = map(lambda v: _cast(v, dtype), value) + if isinstance(value, list) or isinstance(value, tuple): + return type(value)(iterable) + else: + return iterable + else: + return value + + +# custom_fwd is a decorator that may or may not be used with arguments, following +# https://github.com/dabeaz/python-cookbook/tree/master/src/9/defining_a_decorator_that_takes_an_optional_argument. +# this works: +# @custom_fwd +# def forward(...): +# this also works: +# @custom_fwd(cast_inputs=torch.float) +# def forward(...): +# TODO: when python 2 support is dropped, change the signature to +# def custom_fwd(fwd=None, *, cast_inputs=None) with internal changes following the link above. +def custom_fwd(fwd=None, **kwargs): + """ + Helper decorator for ``forward`` methods of custom autograd functions (subclasses of + :class:`torch.autograd.Function`). See the :ref:`example page` for more detail. + + Arguments: + cast_inputs (:class:`torch.dtype` or None, optional, default=None): If not ``None``, + when ``forward`` runs in an autocast-enabled region, casts incoming + floating-point CUDA Tensors to the target dtype (non-floating-point Tensors are not affected), + then executes ``forward`` with autocast disabled. + If ``None``, ``forward``'s internal ops execute with the current autocast state. + + .. note:: + If the decorated ``forward`` is called outside an autocast-enabled region, + :func:`custom_fwd` is a no-op and ``cast_inputs`` has no effect. 
+ """ + if fwd is None: + if len(kwargs) == 0: + cast_inputs = None + else: + assert len(kwargs) == 1 + cast_inputs = kwargs["cast_inputs"] + return functools.partial(custom_fwd, cast_inputs=cast_inputs) + + if len(kwargs) == 0: + cast_inputs = None + else: + assert len(kwargs) == 1 + cast_inputs = kwargs["cast_inputs"] + + @functools.wraps(fwd) + def decorate_fwd(*args, **kwargs): + if cast_inputs is None: + args[0]._fwd_used_autocast = torch.is_autocast_enabled() + return fwd(*args, **kwargs) + else: + autocast_context = torch.is_autocast_enabled() + args[0]._fwd_used_autocast = False + if autocast_context: + with autocast(enabled=False): + return fwd(*_cast(args, cast_inputs), **_cast(kwargs, cast_inputs)) + else: + return fwd(*args, **kwargs) + + return decorate_fwd + + +# Autograd ensures incoming gradients are the same type as forward outputs. Allowing a separate +# cast_inputs argument on custom_bwd is unnecessary and could cause errors if it doesn't match +# cast_inputs supplied to custom_fwd. +def custom_bwd(bwd): + """ + Helper decorator for backward methods of custom autograd functions (subclasses of + :class:`torch.autograd.Function`). + Ensures that ``backward`` executes with the same autocast state as ``forward``. + See the :ref:`example page` for more detail. + """ + + @functools.wraps(bwd) + def decorate_bwd(*args, **kwargs): + with autocast(args[0]._fwd_used_autocast): + return bwd(*args, **kwargs) + + return decorate_bwd diff --git a/torch_xla/amp/grad_scaler.py b/torch_xla/amp/grad_scaler.py new file mode 100644 index 00000000000..a70096d8d00 --- /dev/null +++ b/torch_xla/amp/grad_scaler.py @@ -0,0 +1,563 @@ +import torch +from collections import defaultdict +from torch._six import container_abcs +import warnings +from enum import Enum +from typing import Any, Dict, List, Optional, Tuple + +import torch_xla.core.xla_model as xm + + +class _MultiDeviceReplicator(object): + """ + Lazily serves copies of a tensor to requested devices. Copies are cached per-device. + """ + + def __init__(self, master_tensor: torch.Tensor) -> None: + self.master = master_tensor + self._per_device_tensors: Dict[torch.device, torch.Tensor] = {} + + def get(self, device) -> torch.Tensor: + retval = self._per_device_tensors.get(device, None) + if retval is None: + retval = self.master.to(device=device, non_blocking=True, copy=True) + self._per_device_tensors[device] = retval + return retval + + +# Defines default_factory for GradScaler's _per_optimizer_states defaultdict, +# as well as associated "enum" values. Prefers defining these at top level because +# - Lambdas can't be pickled, so we don't want to supply a lambda as the factory. +# - Defining READY, UNSCALED, STEPPED and _refresh_per_optimizer_state within GradScaler +# causes a circular reference, which we'd rather avoid. +class OptState(Enum): + READY = 0 + UNSCALED = 1 + STEPPED = 2 + + +def _refresh_per_optimizer_state(): + return {"stage": OptState.READY, "found_inf_per_device": {}} + + +class GradScaler(object): + _scale: Optional[torch.Tensor] + _grows_tracker: Optional[torch.Tensor] + _per_optimizer_states: Dict[int, Dict[str, Any]] + """ + An instance ``scaler`` of :class:`GradScaler` helps perform the steps of gradient scaling + conveniently. + + * ``scaler.scale(loss)`` multiplies a given loss by ``scaler``'s current scale factor. + * ``scaler.step(optimizer)`` safely unscales gradients and calls ``optimizer.step()``. + * ``scaler.update()`` updates ``scaler``'s scale factor. 
+ + Example:: + + # Creates a GradScaler once at the beginning of training. + scaler = GradScaler() + + for epoch in epochs: + for input, target in data: + optimizer.zero_grad() + output = model(input) + loss = loss_fn(output, target) + + # Scales loss. Calls backward() on scaled loss to create scaled gradients. + scaler.scale(loss).backward() + + # scaler.step() first unscales gradients of the optimizer's params. + # If gradients don't contain infs/NaNs, optimizer.step() is then called, + # otherwise, optimizer.step() is skipped. + scaler.step(optimizer) + + # Updates the scale for next iteration. + scaler.update() + + See the :ref:`Automatic Mixed Precision examples` for usage + (along with autocasting) in more complex cases like gradient clipping, gradient accumulation, gradient penalty, + and multiple losses/optimizers. + + ``scaler`` dynamically estimates the scale factor each iteration. To minimize gradient underflow, + a large scale factor should be used. However, ``float16`` values can "overflow" (become inf or NaN) if + the scale factor is too large. Therefore, the optimal scale factor is the largest factor that can be used + without incurring inf or NaN gradient values. + ``scaler`` approximates the optimal scale factor over time by checking the gradients for infs and NaNs during every + ``scaler.step(optimizer)`` (or optional separate ``scaler.unscale_(optimizer)``, see :meth:`unscale_`). + + * If infs/NaNs are found, ``scaler.step(optimizer)`` skips the underlying ``optimizer.step()`` (so the params + themselves remain uncorrupted) and ``update()`` multiplies the scale by ``backoff_factor``. + + * If no infs/NaNs are found, ``scaler.step(optimizer)`` runs the underlying ``optimizer.step()`` as usual. + If ``growth_interval`` unskipped iterations occur consecutively, ``update()`` multiplies the scale by + ``growth_factor``. + + The scale factor often causes infs/NaNs to appear in gradients for the first few iterations as its + value calibrates. ``scaler.step`` will skip the underlying ``optimizer.step()`` for these + iterations. After that, step skipping should occur rarely (once every few hundred or thousand iterations). + + Arguments: + init_scale (float, optional, default=2.**16): Initial scale factor. + growth_factor (float, optional, default=2.0): Factor by which the scale is multiplied during + :meth:`update` if no inf/NaN gradients occur for ``growth_interval`` consecutive iterations. + backoff_factor (float, optional, default=0.5): Factor by which the scale is multiplied during + :meth:`update` if inf/NaN gradients occur in an iteration. + growth_interval (int, optional, default=2000): Number of consecutive iterations without inf/NaN gradients + that must occur for the scale to be multiplied by ``growth_factor``. + enabled (bool, optional, default=True): If ``False``, disables gradient scaling. :meth:`step` simply + invokes the underlying ``optimizer.step()``, and other methods become no-ops. + """ + + def __init__(self, + init_scale=2.**16, + growth_factor=2.0, + backoff_factor=0.5, + growth_interval=2000, + enabled=True): + self._enabled = enabled + + if self._enabled: + assert growth_factor > 1.0, "The growth factor must be > 1.0." + assert backoff_factor < 1.0, "The backoff factor must be < 1.0." 
+ + self._init_scale = init_scale + # self._scale will be lazily initialized during the first call to scale() + self._scale = None + self._growth_factor = growth_factor + self._backoff_factor = backoff_factor + self._growth_interval = growth_interval + self._init_growth_tracker = 0 + # self._growth_tracker will be lazily initialized during the first call to scale() + self._growth_tracker = None + self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state) + + def _check_scale_growth_tracker( + self, funcname) -> Tuple[torch.Tensor, torch.Tensor]: + fix = "This may indicate your script did not use scaler.scale(loss or outputs) earlier in the iteration." + assert self._scale is not None, "Attempted {} but _scale is None. ".format( + funcname) + fix + assert self._growth_tracker is not None, "Attempted {} but _growth_tracker is None. ".format( + funcname) + fix + return (self._scale, self._growth_tracker) + + def _lazy_init_scale_growth_tracker(self, dev): + assert self._growth_tracker is None, "_growth_tracker initialized before _scale" + self._scale = torch.tensor( + self._init_scale, dtype=torch.float32, device=dev) + self._growth_tracker = torch.tensor( + self._init_growth_tracker, dtype=torch.int32, device=dev) + + def scale(self, outputs): + """ + Multiplies ('scales') a tensor or list of tensors by the scale factor. + + Returns scaled outputs. If this instance of :class:`GradScaler` is not enabled, outputs are returned + unmodified. + + Arguments: + outputs (Tensor or iterable of Tensors): Outputs to scale. + """ + if not self._enabled: + return outputs + + # Short-circuit for the common case. + if isinstance(outputs, torch.Tensor): + if self._scale is None: + self._lazy_init_scale_growth_tracker(outputs.device) + assert self._scale is not None + # return outputs * self._scale.to(device=outputs.device, non_blocking=True) + return outputs * self._scale + + # Invoke the more complex machinery only if we're treating multiple outputs. + stash: List[_MultiDeviceReplicator] = [ + ] # holds a reference that can be overwritten by apply_scale + + def apply_scale(val): + if isinstance(val, torch.Tensor): + if len(stash) == 0: + if self._scale is None: + self._lazy_init_scale_growth_tracker(val.device) + assert self._scale is not None + stash.append(_MultiDeviceReplicator(self._scale)) + return val * stash[0].get(val.device) + elif isinstance(val, container_abcs.Iterable): + iterable = map(apply_scale, val) + if isinstance(val, list) or isinstance(val, tuple): + return type(val)(iterable) + else: + return iterable + else: + raise ValueError("outputs must be a Tensor or an iterable of Tensors") + + return apply_scale(outputs) + + def _unscale_grads_(self, optimizer, inv_scale, found_inf, allow_fp16): + per_device_inv_scale = _MultiDeviceReplicator(inv_scale) + per_device_found_inf = _MultiDeviceReplicator(found_inf) + + # To set up _amp_foreach_non_finite_check_and_unscale_, split grads by device and dtype. + # There could be hundreds of grads, so we'd like to iterate through them just once. + # However, we don't know their devices or dtypes in advance. + + # https://stackoverflow.com/questions/5029934/defaultdict-of-defaultdict + # Google says mypy struggles with defaultdicts type annotations. 
+ per_device_and_dtype_grads = defaultdict( + lambda: defaultdict(list)) # type: ignore[var-annotated] + with torch.no_grad(): + for group in optimizer.param_groups: + for param in group["params"]: + if param.grad is None: + continue + if (not allow_fp16) and param.grad.dtype == torch.float16: + raise ValueError("Attempting to unscale FP16 gradients.") + if param.grad.is_sparse: + # is_coalesced() == False means the sparse grad has values with duplicate indices. + # coalesce() deduplicates indices and adds all values that have the same index. + # For scaled fp16 values, there's a good chance coalescing will cause overflow, + # so we should check the coalesced _values(). + if param.grad.dtype is torch.float16: + param.grad = param.grad.coalesce() + to_unscale = param.grad._values() + else: + to_unscale = param.grad + + # TODO: is there a way to split by device and dtype without appending in the inner loop? + per_device_and_dtype_grads[to_unscale.device][ + to_unscale.dtype].append(to_unscale) + + for device, per_dtype_grads in per_device_and_dtype_grads.items(): + for grads in per_dtype_grads.values(): + torch._amp_foreach_non_finite_check_and_unscale_( + grads, per_device_found_inf.get(device), + per_device_inv_scale.get(device)) + + return per_device_found_inf._per_device_tensors + + def unscale_(self, optimizer): + """ + Divides ("unscales") the optimizer's gradient tensors by the scale factor. + + :meth:`unscale_` is optional, serving cases where you need to + :ref:`modify or inspect gradients` + between the backward pass(es) and :meth:`step`. + If :meth:`unscale_` is not called explicitly, gradients will be unscaled automatically during :meth:`step`. + + Simple example, using :meth:`unscale_` to enable clipping of unscaled gradients:: + + ... + scaler.scale(loss).backward() + scaler.unscale_(optimizer) + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm) + scaler.step(optimizer) + scaler.update() + + Arguments: + optimizer (torch.optim.Optimizer): Optimizer that owns the gradients to be unscaled. + + .. note:: + :meth:`unscale_` does not incur a CPU-GPU sync. + + .. warning:: + :meth:`unscale_` should only be called once per optimizer per :meth:`step` call, + and only after all gradients for that optimizer's assigned parameters have been accumulated. + Calling :meth:`unscale_` twice for a given optimizer between each :meth:`step` triggers a RuntimeError. + + .. warning:: + :meth:`unscale_` may unscale sparse gradients out of place, replacing the ``.grad`` attribute. + """ + if not self._enabled: + return + + self._check_scale_growth_tracker("unscale_") + + optimizer_state = self._per_optimizer_states[id(optimizer)] + + if optimizer_state["stage"] is OptState.UNSCALED: + raise RuntimeError( + "unscale_() has already been called on this optimizer since the last update()." + ) + elif optimizer_state["stage"] is OptState.STEPPED: + raise RuntimeError("unscale_() is being called after step().") + + # FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64. + assert self._scale is not None + inv_scale = self._scale.double().reciprocal().float() + found_inf = torch.tensor( + 0.0, dtype=torch.float32, device=self._scale.device) + + optimizer_state["found_inf_per_device"] = self._unscale_grads_( + optimizer, inv_scale, found_inf, False) + optimizer_state["stage"] = OptState.UNSCALED + + def step(self, optimizer, *args, **kwargs): + """ + :meth:`step` carries out the following two operations: + + 1. 
Internally invokes ``unscale_(optimizer)`` (unless :meth:`unscale_` was explicitly called for ``optimizer`` + earlier in the iteration). As part of the :meth:`unscale_`, gradients are checked for infs/NaNs. + 2. If no inf/NaN gradients are found, invokes ``optimizer.step()`` using the unscaled + gradients. Otherwise, ``optimizer.step()`` is skipped to avoid corrupting the params. + + ``*args`` and ``**kwargs`` are forwarded to ``optimizer.step()``. + + Returns the return value of ``optimizer.step(*args, **kwargs)``. + + Arguments: + optimizer (torch.optim.Optimizer): Optimizer that applies the gradients. + args: Any arguments. + kwargs: Any keyword arguments. + + .. warning:: + Closure use is not currently supported. + """ + if (not self._enabled): + return optimizer.step(*args, **kwargs) + + if "closure" in kwargs: + raise RuntimeError( + "Closure use is not currently supported if GradScaler is enabled.") + + self._check_scale_growth_tracker("step") + + optimizer_state = self._per_optimizer_states[id(optimizer)] + + if optimizer_state["stage"] is OptState.STEPPED: + raise RuntimeError( + "step() has already been called since the last update().") + + retval = None + + if (hasattr(optimizer, "_step_supports_amp_scaling") and + optimizer._step_supports_amp_scaling): + # This optimizer has customized scale-handling logic, so we can call optimizer.step() directly. + # The contract with custom optimizers is that their step() should accept an additional, + # optional grad_scaler kwarg. We append self to the kwargs so the custom optimizer has full information: + # it can query its own state, invoke unscale_ on itself, etc + retval = optimizer.step(*args, **dict(kwargs, grad_scaler=self)) + optimizer_state["stage"] = OptState.STEPPED + return retval + + if optimizer_state["stage"] is OptState.READY: + self.unscale_(optimizer) + + assert len(optimizer_state["found_inf_per_device"] + ) > 0, "No inf checks were recorded for this optimizer." + + # call mark_step before v.item() to make sure the gradients could be reused in optimizer.step + xm.mark_step() + if not sum( + v.item() for v in optimizer_state["found_inf_per_device"].values()): + retval = optimizer.step(*args, **kwargs) + + optimizer_state["stage"] = OptState.STEPPED + + return retval + + def update(self, new_scale=None): + """ + Updates the scale factor. + + If any optimizer steps were skipped the scale is multiplied by ``backoff_factor`` + to reduce it. If ``growth_interval`` unskipped iterations occurred consecutively, + the scale is multiplied by ``growth_factor`` to increase it. + + Passing ``new_scale`` sets the scale directly. + + Arguments: + new_scale (float or :class:`torch.cuda.FloatTensor`, optional, default=None): New scale factor. + + .. warning:: + :meth:`update` should only be called at the end of the iteration, after ``scaler.step(optimizer)`` has + been invoked for all optimizers used this iteration. + """ + if not self._enabled: + return + + _scale, _growth_tracker = self._check_scale_growth_tracker("update") + + if new_scale is not None: + # Accept a new user-defined scale. + if isinstance(new_scale, float): + self._scale = torch.tensor( + new_scale, dtype=torch.float32, device=_scale.device) + else: + reason = "new_scale should be a float or a 1-element torch.cuda.FloatTensor with requires_grad=False." 
+ assert isinstance( + new_scale, + torch.cuda.FloatTensor), reason # type: ignore[attr-defined] + assert new_scale.numel() == 1, reason + assert new_scale.requires_grad is False, reason + self._scale = new_scale + else: + # Consume shared inf/nan data collected from optimizers to update the scale. + # If all found_inf tensors are on the same device as self._scale, this operation is asynchronous. + found_infs = [ + found_inf.to(device=_scale.device, non_blocking=True) + for state in self._per_optimizer_states.values() + for found_inf in state["found_inf_per_device"].values() + ] + + assert len(found_infs) > 0, "No inf checks were recorded prior to update." + + found_inf_combined = found_infs[0] + if len(found_infs) > 1: + for i in range(1, len(found_infs)): + found_inf_combined += found_infs[i] + + self._scale = torch._amp_update_scale(_growth_tracker, _scale, + found_inf_combined, + self._growth_factor, + self._backoff_factor, + self._growth_interval) + + # To prepare for next iteration, clear the data collected from optimizers this iteration. + self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state) + + def _get_scale_async(self): + return self._scale + + def get_scale(self): + """ + Returns a Python float containing the current scale, or 1.0 if scaling is disabled. + + .. warning:: + :meth:`get_scale` incurs a CPU-GPU sync. + """ + if self._enabled: + return self._init_scale if self._scale is None else self._get_scale_async( + ).item() + else: + return 1.0 + + def get_growth_factor(self): + r""" + Returns a Python float containing the scale growth factor. + """ + return self._growth_factor + + def set_growth_factor(self, new_factor): + r""" + Arguments: + new_scale (float): Value to use as the new scale growth factor. + """ + self._growth_factor = new_factor + + def get_backoff_factor(self): + r""" + Returns a Python float containing the scale backoff factor. + """ + return self._backoff_factor + + def set_backoff_factor(self, new_factor): + r""" + Arguments: + new_scale (float): Value to use as the new scale backoff factor. + """ + self._backoff_factor = new_factor + + def get_growth_interval(self): + r""" + Returns a Python int containing the growth interval. + """ + return self._growth_interval + + def set_growth_interval(self, new_interval): + r""" + Arguments: + new_interval (int): Value to use as the new growth interval. + """ + self._growth_interval = new_interval + + def _get_growth_tracker(self): + if self._enabled: + return self._init_growth_tracker if self._growth_tracker is None else self._growth_tracker.item( + ) + else: + return 0 + + def is_enabled(self): + r""" + Returns a bool indicating whether this instance is enabled. + """ + return self._enabled + + def state_dict(self): + r""" + Returns the state of the scaler as a :class:`dict`. It contains five entries: + + * ``"scale"`` - a Python float containing the current scale + * ``"growth_factor"`` - a Python float containing the current growth factor + * ``"backoff_factor"`` - a Python float containing the current backoff factor + * ``"growth_interval"`` - a Python int containing the current growth interval + * ``"_growth_tracker"`` - a Python int containing the number of recent consecutive unskipped steps. + + If this instance is not enabled, returns an empty dict. + + .. note:: + If you wish to checkpoint the scaler's state after a particular iteration, :meth:`state_dict` + should be called after :meth:`update`. 
+ """ + return { + "scale": self.get_scale(), + "growth_factor": self._growth_factor, + "backoff_factor": self._backoff_factor, + "growth_interval": self._growth_interval, + "_growth_tracker": self._get_growth_tracker() + } if self._enabled else {} + + def load_state_dict(self, state_dict): + r""" + Loads the scaler state. If this instance is disabled, :meth:`load_state_dict` is a no-op. + + Arguments: + state_dict(dict): scaler state. Should be an object returned from a call to :meth:`state_dict`. + """ + if not self._enabled: + return + + if len(state_dict) == 0: + raise RuntimeError( + "The source state dict is empty, possibly because it was saved " + "from a disabled instance of GradScaler.") + + self._init_scale = state_dict["scale"] + if self._scale is not None: + self._scale.fill_(state_dict["scale"]) + self._growth_factor = state_dict["growth_factor"] + self._backoff_factor = state_dict["backoff_factor"] + self._growth_interval = state_dict["growth_interval"] + self._init_growth_tracker = state_dict["_growth_tracker"] + if self._growth_tracker is not None: + self._growth_tracker.fill_(state_dict["_growth_tracker"]) + + def __getstate__(self): + state = self.__dict__.copy() + if self._enabled: + assert len(self._per_optimizer_states) == 0, "A GradScaler instance may only be pickled at the beginning "\ + "of an iteration, or at the end after scaler.update()." + # Pickling _scale and _growth_tracker Tensors directly triggers + # "warnings.warn("pickle support for Storage will be removed in 1.5..." + # so instead, we set the unpickled instance up to reinitialize them lazily. + state['_init_scale'] = self.get_scale() + state['_init_growth_tracker'] = self._get_growth_tracker() + state['_scale'] = None + state['_growth_tracker'] = None + return state + + def __setstate__(self, state): + self.__dict__.update(state) + + def _check_inf_per_device(self, optimizer): + _scale, _ = self._check_scale_growth_tracker("_check_inf_per_device") + + dummy_inv_scale = torch.tensor( + 1.0, dtype=torch.float32, device=_scale.device) + found_inf = torch.tensor(0.0, dtype=torch.float32, device=_scale.device) + + self._per_optimizer_states[id(optimizer)]["found_inf_per_device"] = \ + self._unscale_grads_(optimizer, dummy_inv_scale, found_inf, True) + + return self._per_optimizer_states[id(optimizer)]["found_inf_per_device"] + + def _found_inf_per_device(self, optimizer): + return self._per_optimizer_states[id(optimizer)]["found_inf_per_device"] diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp index 649f1dcb59c..76ac397b1cb 100644 --- a/torch_xla/csrc/aten_xla_type.cpp +++ b/torch_xla/csrc/aten_xla_type.cpp @@ -302,6 +302,29 @@ at::Tensor AtenXlaType::_adaptive_avg_pool2d_backward( bridge::GetXlaTensor(grad_output), bridge::GetXlaTensor(self))); } +void AtenXlaType::_amp_foreach_non_finite_check_and_unscale_( + at::TensorList self, at::Tensor& found_inf, const at::Tensor& inv_scale) { + XLA_FN_COUNTER("xla::"); + auto xla_self = bridge::GetXlaTensors(self); + absl::Span self_tensor{xla_self}; + XLATensor found_inf_tensor = bridge::GetXlaTensor(found_inf); + XLATensor::_amp_foreach_non_finite_check_and_unscale_( + self_tensor, found_inf_tensor, bridge::GetXlaTensor(inv_scale)); +} + +at::Tensor AtenXlaType::_amp_update_scale(at::Tensor& growth_tracker, + const at::Tensor& current_scale, + const at::Tensor& found_inf, + double scale_growth_factor, + double scale_backoff_factor, + int64_t growth_interval) { + XLA_FN_COUNTER("xla::"); + return 
bridge::AtenFromXlaTensor(XLATensor::_amp_update_scale( + bridge::GetXlaTensor(growth_tracker), bridge::GetXlaTensor(current_scale), + bridge::GetXlaTensor(found_inf), scale_growth_factor, + scale_backoff_factor, growth_interval)); +} + at::Tensor AtenXlaType::_copy_from(const at::Tensor& self, const at::Tensor& dst, bool non_blocking) { XLA_FN_COUNTER("xla::"); diff --git a/torch_xla/csrc/aten_xla_type.h b/torch_xla/csrc/aten_xla_type.h index d1924bf665c..0e70aeec8b9 100644 --- a/torch_xla/csrc/aten_xla_type.h +++ b/torch_xla/csrc/aten_xla_type.h @@ -43,6 +43,16 @@ class AtenXlaType { static at::Tensor _adaptive_avg_pool2d_backward(const at::Tensor& grad_output, const at::Tensor& self); + static void _amp_foreach_non_finite_check_and_unscale_( + at::TensorList self, at::Tensor& found_inf, const at::Tensor& inv_scale); + + static at::Tensor _amp_update_scale(at::Tensor& growth_tracker, + const at::Tensor& current_scale, + const at::Tensor& found_inf, + double scale_growth_factor, + double scale_backoff_factor, + int64_t growth_interval); + static at::Tensor _copy_from(const at::Tensor& self, const at::Tensor& dst, bool non_blocking); diff --git a/torch_xla/csrc/batch_norm.cpp b/torch_xla/csrc/batch_norm.cpp index de52b50bda7..92250fa8ae2 100644 --- a/torch_xla/csrc/batch_norm.cpp +++ b/torch_xla/csrc/batch_norm.cpp @@ -1,11 +1,25 @@ #include "torch_xla/csrc/batch_norm.h" #include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "torch_xla/csrc/helpers.h" namespace torch_xla { namespace { +bool IsF32BatchNormWithFP16Inputs(const xla::XlaOp& input, + const xla::XlaOp& weight) { + xla::XlaBuilder* builder = input.builder(); + if (builder->GetShape(input).ok() && builder->GetShape(weight).ok() && + builder->GetShape(input).ValueOrDie().element_type() == + xla::PrimitiveType::F16 && + builder->GetShape(weight).ValueOrDie().element_type() == + xla::PrimitiveType::F32) { + return true; + } + return false; +} + xla::XlaOp VarianceRecover(xla::XlaOp invstd, float eps_value) { const xla::Shape& invstd_shape = XlaHelpers::ShapeOfXlaOp(invstd); xla::XlaOp eps = XlaHelpers::ScalarValue( @@ -26,31 +40,58 @@ xla::XlaOp BatchNormVarianceInvert(xla::XlaOp variance, float eps_value) { BatchNormOutput BuildBatchNormTraining(xla::XlaOp input, xla::XlaOp weight, xla::XlaOp bias, float eps_value) { + bool is_batchnorm_with_fp16_inputs = + IsF32BatchNormWithFP16Inputs(input, weight); + if (is_batchnorm_with_fp16_inputs) { + input = xla::ConvertElementType(input, xla::PrimitiveType::F32); + } xla::XlaOp outputs = xla::BatchNormTraining(input, weight, bias, eps_value, /*feature_index=*/1); xla::XlaOp output = xla::GetTupleElement(outputs, 0); xla::XlaOp batch_mean = xla::GetTupleElement(outputs, 1); xla::XlaOp batch_variance = xla::GetTupleElement(outputs, 2); + if (is_batchnorm_with_fp16_inputs) { + output = xla::ConvertElementType(output, xla::PrimitiveType::F16); + } return {output, batch_mean, batch_variance}; } xla::XlaOp BuildBatchNormInference(xla::XlaOp input, xla::XlaOp weight, xla::XlaOp bias, xla::XlaOp mean, xla::XlaOp variance, float eps_value) { - return xla::BatchNormInference(input, weight, bias, mean, variance, eps_value, - /*feature_index=*/1); + bool is_batchnorm_with_fp16_inputs = + IsF32BatchNormWithFP16Inputs(input, weight); + if (is_batchnorm_with_fp16_inputs) { + input = xla::ConvertElementType(input, xla::PrimitiveType::F32); + } + xla::XlaOp output = + xla::BatchNormInference(input, weight, bias, mean, variance, eps_value, 
+ /*feature_index=*/1); + if (is_batchnorm_with_fp16_inputs) { + output = xla::ConvertElementType(output, xla::PrimitiveType::F16); + } + return output; } BatchNormGrads BuildBatchNormBackward(xla::XlaOp grad, xla::XlaOp input, xla::XlaOp weight, xla::XlaOp save_mean, xla::XlaOp save_invstd, bool training, float eps_value) { + bool is_batchnorm_with_fp16_inputs = + IsF32BatchNormWithFP16Inputs(input, weight); + if (is_batchnorm_with_fp16_inputs) { + input = xla::ConvertElementType(input, xla::PrimitiveType::F32); + grad = xla::ConvertElementType(grad, xla::PrimitiveType::F32); + } xla::XlaOp grads = xla::BatchNormGrad(input, weight, save_mean, VarianceRecover(save_invstd, eps_value), grad, eps_value, /*feature_index=*/1); xla::XlaOp grad_input = xla::GetTupleElement(grads, 0); xla::XlaOp grad_weight = xla::GetTupleElement(grads, 1); xla::XlaOp grad_bias = xla::GetTupleElement(grads, 2); + if (is_batchnorm_with_fp16_inputs) { + grad_input = xla::ConvertElementType(grad_input, xla::PrimitiveType::F16); + } return {grad_input, grad_weight, grad_bias}; } diff --git a/torch_xla/csrc/ops/amp_foreach_non_finite_check_and_unscale.cpp b/torch_xla/csrc/ops/amp_foreach_non_finite_check_and_unscale.cpp new file mode 100644 index 00000000000..777fee03712 --- /dev/null +++ b/torch_xla/csrc/ops/amp_foreach_non_finite_check_and_unscale.cpp @@ -0,0 +1,50 @@ +#include "torch_xla/csrc/ops/amp_foreach_non_finite_check_and_unscale.h" + +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/xla_client/util.h" +#include "torch_xla/csrc/lowering_context.h" +#include "torch_xla/csrc/ops/xla_ops.h" +#include "torch_xla/csrc/tensor_util.h" +#include "torch_xla/csrc/xla_lower_util.h" + +namespace torch_xla { +namespace ir { +namespace ops { +namespace { + +xla::Shape NodeOutputShape(const OpList& inputs) { + std::vector output_shapes; + output_shapes.reserve(inputs.size() - 1); + for (size_t i = 0; i < inputs.size() - 2; ++i) { + const xla::Shape& input_shape = inputs[i].shape(); + output_shapes.push_back(input_shape); + } + output_shapes.push_back(xla::ShapeUtil::MakeShape( + inputs[inputs.size() - 2].shape().element_type(), {})); + return xla::ShapeUtil::MakeTupleShape(output_shapes); +} + +} // namespace + +AmpForachNonFiniteCheckAndUnscale::AmpForachNonFiniteCheckAndUnscale( + const OpList& inputs) + : Node(xla_amp_foreach_non_finite_check_and_unscale, inputs, + NodeOutputShape(inputs), + /*num_outputs=*/inputs.size() - 1) {} + +NodePtr AmpForachNonFiniteCheckAndUnscale::Clone(OpList operands) const { + return MakeNode(operands); +} + +XlaOpVector AmpForachNonFiniteCheckAndUnscale::Lower( + LoweringContext* loctx) const { + std::vector inputs; + for (size_t i = 0; i < num_outputs() + 1; ++i) { + inputs.push_back(loctx->GetOutputOp(operand(i))); + } + return ReturnOps(BuildAmpForachNonFiniteCheckAndUnscale(inputs), loctx); +} + +} // namespace ops +} // namespace ir +} // namespace torch_xla \ No newline at end of file diff --git a/torch_xla/csrc/ops/amp_foreach_non_finite_check_and_unscale.h b/torch_xla/csrc/ops/amp_foreach_non_finite_check_and_unscale.h new file mode 100644 index 00000000000..6ab1988546b --- /dev/null +++ b/torch_xla/csrc/ops/amp_foreach_non_finite_check_and_unscale.h @@ -0,0 +1,20 @@ +#pragma once + +#include "torch_xla/csrc/ir.h" + +namespace torch_xla { +namespace ir { +namespace ops { + +class AmpForachNonFiniteCheckAndUnscale : public Node { + public: + AmpForachNonFiniteCheckAndUnscale(const OpList& inputs); + + NodePtr Clone(OpList operands) const override; + 
+ XlaOpVector Lower(LoweringContext* loctx) const override; +}; + +} // namespace ops +} // namespace ir +} // namespace torch_xla \ No newline at end of file diff --git a/torch_xla/csrc/ops/amp_update_scale.cpp b/torch_xla/csrc/ops/amp_update_scale.cpp new file mode 100644 index 00000000000..97ea894d9ad --- /dev/null +++ b/torch_xla/csrc/ops/amp_update_scale.cpp @@ -0,0 +1,44 @@ +#include "torch_xla/csrc/ops/amp_update_scale.h" + +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/xla_client/util.h" +#include "torch_xla/csrc/lowering_context.h" +#include "torch_xla/csrc/ops/xla_ops.h" +#include "torch_xla/csrc/tensor_util.h" +#include "torch_xla/csrc/xla_lower_util.h" + +namespace torch_xla { +namespace ir { +namespace ops { +namespace { + +xla::Shape NodeOutputShape(const OpList& inputs) { + std::vector output_shapes; + for (size_t i = 0; i < 2; ++i) { + const xla::Shape& input_shape = inputs[i].shape(); + output_shapes.push_back(input_shape); + } + return xla::ShapeUtil::MakeTupleShape(output_shapes); +} + +} // namespace + +AmpUpdateScale::AmpUpdateScale(const OpList& inputs) + : Node(xla_amp_update_scale, inputs, NodeOutputShape(inputs), + /*num_outputs=*/2) {} + +NodePtr AmpUpdateScale::Clone(OpList operands) const { + return MakeNode(operands); +} + +XlaOpVector AmpUpdateScale::Lower(LoweringContext* loctx) const { + std::vector inputs; + for (size_t i = 0; i < 6; ++i) { + inputs.push_back(loctx->GetOutputOp(operand(i))); + } + return ReturnOps(BuildAmpUpdateScale(inputs), loctx); +} + +} // namespace ops +} // namespace ir +} // namespace torch_xla \ No newline at end of file diff --git a/torch_xla/csrc/ops/amp_update_scale.h b/torch_xla/csrc/ops/amp_update_scale.h new file mode 100644 index 00000000000..e07933d8b10 --- /dev/null +++ b/torch_xla/csrc/ops/amp_update_scale.h @@ -0,0 +1,20 @@ +#pragma once + +#include "torch_xla/csrc/ir.h" + +namespace torch_xla { +namespace ir { +namespace ops { + +class AmpUpdateScale : public Node { + public: + AmpUpdateScale(const OpList& inputs); + + NodePtr Clone(OpList operands) const override; + + XlaOpVector Lower(LoweringContext* loctx) const override; +}; + +} // namespace ops +} // namespace ir +} // namespace torch_xla \ No newline at end of file diff --git a/torch_xla/csrc/ops/xla_ops.cpp b/torch_xla/csrc/ops/xla_ops.cpp index a669b5b9863..852f95398c3 100644 --- a/torch_xla/csrc/ops/xla_ops.cpp +++ b/torch_xla/csrc/ops/xla_ops.cpp @@ -5,6 +5,9 @@ namespace ir { namespace ops { const OpKindWrapper xla_all_to_all("xla::all_to_all"); +const OpKindWrapper xla_amp_foreach_non_finite_check_and_unscale( + "xla::amp_foreach_non_finite_check_and_unscale"); +const OpKindWrapper xla_amp_update_scale("xla::amp_update_scale"); const OpKindWrapper xla_as_strided_view_update("xla::as_strided_view_update"); const OpKindWrapper xla_cast("xla::cast"); const OpKindWrapper xla_collective_permute("xla::collective_permute"); diff --git a/torch_xla/csrc/ops/xla_ops.h b/torch_xla/csrc/ops/xla_ops.h index 6b8e1cbaab4..c15bcc3dada 100644 --- a/torch_xla/csrc/ops/xla_ops.h +++ b/torch_xla/csrc/ops/xla_ops.h @@ -29,6 +29,8 @@ class OpKindWrapper { }; extern const OpKindWrapper xla_all_to_all; +extern const OpKindWrapper xla_amp_foreach_non_finite_check_and_unscale; +extern const OpKindWrapper xla_amp_update_scale; extern const OpKindWrapper xla_as_strided_view_update; extern const OpKindWrapper xla_cast; extern const OpKindWrapper xla_collective_permute; diff --git a/torch_xla/csrc/tensor.h b/torch_xla/csrc/tensor.h index 
33c7eb8f14f..840410dae2d 100644 --- a/torch_xla/csrc/tensor.h +++ b/torch_xla/csrc/tensor.h @@ -246,6 +246,17 @@ class XLATensor { static XLATensor _adaptive_avg_pool2d_backward(const XLATensor& grad_output, const XLATensor& input); + static void _amp_foreach_non_finite_check_and_unscale_( + absl::Span self, XLATensor& found_inf, + const XLATensor& inv_scale); + + static XLATensor _amp_update_scale(XLATensor growth_tracker, + const XLATensor& current_scale, + const XLATensor& found_inf, + float scale_growth_factor, + float scale_backoff_factor, + int growth_interval); + static XLATensor abs(const XLATensor& input); static void abs_(XLATensor& input); diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp index 61a9504d8ee..912684829a5 100644 --- a/torch_xla/csrc/tensor_methods.cpp +++ b/torch_xla/csrc/tensor_methods.cpp @@ -21,6 +21,8 @@ #include "torch_xla/csrc/ops/all.h" #include "torch_xla/csrc/ops/all_reduce.h" #include "torch_xla/csrc/ops/all_to_all.h" +#include "torch_xla/csrc/ops/amp_foreach_non_finite_check_and_unscale.h" +#include "torch_xla/csrc/ops/amp_update_scale.h" #include "torch_xla/csrc/ops/any.h" #include "torch_xla/csrc/ops/arg_max.h" #include "torch_xla/csrc/ops/arg_min.h" @@ -468,6 +470,50 @@ XLATensor XLATensor::_adaptive_avg_pool2d_backward(const XLATensor& grad_output, grad_output.GetIrValue(), input.GetIrValue())); } +void XLATensor::_amp_foreach_non_finite_check_and_unscale_( + absl::Span self, XLATensor& found_inf, + const XLATensor& inv_scale) { + std::vector inputs; + for (const auto& x : self) { + inputs.push_back(x.GetIrValue()); + } + inputs.push_back(found_inf.GetIrValue()); + inputs.push_back(inv_scale.GetIrValue()); + absl::Span inputs_span{inputs}; + ir::NodePtr node = + ir::MakeNode(inputs_span); + for (size_t i = 0; i < self.size(); ++i) { + self[i].SetInPlaceIrValue(ir::Value(node, i)); + } + found_inf.SetInPlaceIrValue(ir::Value(node, self.size())); +} + +XLATensor XLATensor::_amp_update_scale(XLATensor growth_tracker, + const XLATensor& current_scale, + const XLATensor& found_inf, + float scale_growth_factor, + float scale_backoff_factor, + int growth_interval) { + ir::Value scale_growth_factor_ir = GetIrValueForScalar( + scale_growth_factor, xla::PrimitiveType::F32, growth_tracker.GetDevice()); + ir::Value scale_backoff_factor_ir = + GetIrValueForScalar(scale_backoff_factor, xla::PrimitiveType::F32, + growth_tracker.GetDevice()); + ir::Value growth_interval_ir = GetIrValueForScalar( + growth_interval, xla::PrimitiveType::S32, growth_tracker.GetDevice()); + std::vector inputs; + inputs.push_back(growth_tracker.GetIrValue()); + inputs.push_back(current_scale.GetIrValue()); + inputs.push_back(found_inf.GetIrValue()); + inputs.push_back(scale_growth_factor_ir); + inputs.push_back(scale_backoff_factor_ir); + inputs.push_back(growth_interval_ir); + absl::Span inputs_span{inputs}; + ir::NodePtr node = ir::MakeNode(inputs_span); + growth_tracker.SetInPlaceIrValue(ir::Value(node, 0)); + return current_scale.CreateFrom(ir::Value(node, 1)); +} + XLATensor XLATensor::abs(const XLATensor& input) { return input.CreateFrom(ir::ops::Abs(input.GetIrValue())); } diff --git a/torch_xla/csrc/xla_lower_util.cpp b/torch_xla/csrc/xla_lower_util.cpp index e1a90226de9..53bdffc7922 100644 --- a/torch_xla/csrc/xla_lower_util.cpp +++ b/torch_xla/csrc/xla_lower_util.cpp @@ -749,4 +749,67 @@ xla::XlaOp BuildMaskedScatter(xla::XlaOp input, xla::XlaOp mask, scatter_dnums); } +std::vector BuildAmpForachNonFiniteCheckAndUnscale( + const std::vector& 
inputs) { + const xla::PrimitiveType origin_type = xla::PrimitiveType::F32; + std::vector found_infs; + xla::XlaOp one = xla::One(inputs[0].builder(), xla::PrimitiveType::S32); + for (size_t i = 0; i < inputs.size() - 2; ++i) { + xla::XlaOp all_finite = + xla::ReduceAll(xla::ConvertElementType(xla::IsFinite(inputs[i]), + xla::PrimitiveType::S32), + one, + xla::CreateScalarAndComputation(xla::PrimitiveType::S32, + inputs[i].builder())); + found_infs.push_back(one - all_finite); + } + xla::XlaOp found_inf = xla::ConvertElementType(inputs[inputs.size() - 2], + xla::PrimitiveType::S32); + for (size_t i = 0; i < found_infs.size(); ++i) { + found_inf = xla::Or(found_inf, found_infs[i]); + } + xla::XlaOp inv_scale = inputs[inputs.size() - 1]; + std::vector results; + + for (size_t i = 0; i < inputs.size() - 2; ++i) { + results.push_back(inputs[i] * inv_scale); + } + results.push_back(xla::ConvertElementType(found_inf, origin_type)); + return results; +} + +std::vector BuildAmpUpdateScale( + const std::vector& inputs) { + const auto& growth_tracker = inputs[0]; + xla::XlaOp one = xla::One(growth_tracker.builder(), xla::PrimitiveType::S32); + xla::XlaOp one_float = + xla::One(growth_tracker.builder(), xla::PrimitiveType::F32); + const auto& current_scale = inputs[1]; + const auto& found_inf = xla::Min( + xla::ConvertElementType(inputs[2], xla::PrimitiveType::S32), one); + const auto& growth_factor = inputs[3]; + const auto& backoff_factor = inputs[4]; + const auto& growth_interval = inputs[5]; + + xla::XlaOp all_finite = one - found_inf; + xla::XlaOp not_achieve_interval = + xla::Min((growth_interval - one - growth_tracker), one); + xla::XlaOp new_growth_tracker = + (growth_tracker + one) * all_finite * not_achieve_interval; + xla::XlaOp new_scale = + current_scale * + xla::Max(growth_factor * xla::ConvertElementType( + all_finite * (one - not_achieve_interval), + xla::PrimitiveType::F32), + one_float) * + (backoff_factor * + xla::ConvertElementType(found_inf, xla::PrimitiveType::F32) + + xla::ConvertElementType((one - found_inf) * one, + xla::PrimitiveType::F32)); + std::vector results; + results.push_back(new_growth_tracker); + results.push_back(new_scale); + return results; +} + } // namespace torch_xla diff --git a/torch_xla/csrc/xla_lower_util.h b/torch_xla/csrc/xla_lower_util.h index e63efe08135..ace04605ea2 100644 --- a/torch_xla/csrc/xla_lower_util.h +++ b/torch_xla/csrc/xla_lower_util.h @@ -88,4 +88,10 @@ std::vector BuildMaskedSelect(xla::XlaOp input, xla::XlaOp mask); xla::XlaOp BuildMaskedScatter(xla::XlaOp input, xla::XlaOp mask, xla::XlaOp source); +std::vector BuildAmpForachNonFiniteCheckAndUnscale( + const std::vector& inputs); + +std::vector BuildAmpUpdateScale( + const std::vector& inputs); + } // namespace torch_xla From ac0d0c7aaa573762e182cf0bd8aa0ee48c42eb87 Mon Sep 17 00:00:00 2001 From: Chengji Yao Date: Fri, 11 Dec 2020 07:29:56 +0800 Subject: [PATCH 2/4] update amp to match the code change in pytorch --- torch_xla/amp/autocast_mode.py | 239 +------------- torch_xla/amp/grad_scaler.py | 554 +-------------------------------- 2 files changed, 5 insertions(+), 788 deletions(-) diff --git a/torch_xla/amp/autocast_mode.py b/torch_xla/amp/autocast_mode.py index fc995f6b747..2f7a0a83555 100644 --- a/torch_xla/amp/autocast_mode.py +++ b/torch_xla/amp/autocast_mode.py @@ -1,238 +1,5 @@ import torch -import functools -import warnings -try: - import numpy as np -except ModuleNotFoundError: - np = None -from torch._six import container_abcs, string_classes - -class autocast(object): - 
r""" - Instances of :class:`autocast` serve as context managers or decorators that - allow regions of your script to run in mixed precision. - - In these regions, CUDA ops run in an op-specific dtype chosen by autocast - to improve performance while maintaining accuracy. - See the :ref:`Autocast Op Reference` for details. - - When entering an autocast-enabled region, Tensors may be any type. - You should not call ``.half()`` on your model(s) or inputs when using autocasting. - - :class:`autocast` should wrap only the forward pass(es) of your network, including the loss - computation(s). Backward passes under autocast are not recommended. - Backward ops run in the same type that autocast used for corresponding forward ops. - - Example:: - - # Creates model and optimizer in default precision - model = Net().cuda() - optimizer = optim.SGD(model.parameters(), ...) - - for input, target in data: - optimizer.zero_grad() - - # Enables autocasting for the forward pass (model + loss) - with autocast(): - output = model(input) - loss = loss_fn(output, target) - - # Exits the context manager before backward() - loss.backward() - optimizer.step() - - See the :ref:`Automatic Mixed Precision examples` for usage (along with gradient scaling) - in more complex scenarios (e.g., gradient penalty, multiple models/losses, custom autograd functions). - - :class:`autocast` can also be used as a decorator, e.g., on the ``forward`` method of your model:: - - class AutocastModel(nn.Module): - ... - @autocast() - def forward(self, input): - ... - - Floating-point Tensors produced in an autocast-enabled region may be ``float16``. - After returning to an autocast-disabled region, using them with floating-point - Tensors of different dtypes may cause type mismatch errors. If so, cast the Tensor(s) - produced in the autocast region back to ``float32`` (or other dtype if desired). - If a Tensor from the autocast region is already ``float32``, the cast is a no-op, - and incurs no additional overhead. Example:: - - # Creates some tensors in default dtype (here assumed to be float32) - a_float32 = torch.rand((8, 8), device="cuda") - b_float32 = torch.rand((8, 8), device="cuda") - c_float32 = torch.rand((8, 8), device="cuda") - d_float32 = torch.rand((8, 8), device="cuda") - - with autocast(): - # torch.mm is on autocast's list of ops that should run in float16. - # Inputs are float32, but the op runs in float16 and produces float16 output. - # No manual casts are required. - e_float16 = torch.mm(a_float32, b_float32) - # Also handles mixed input types - f_float16 = torch.mm(d_float32, e_float16) - - # After exiting autocast, calls f_float16.float() to use with d_float32 - g_float32 = torch.mm(d_float32, f_float16.float()) - - Type mismatch errors *in* an autocast-enabled region are a bug; if this is what you observe, - please file an issue. - - ``autocast(enabled=False)`` subregions can be nested in autocast-enabled regions. - Locally disabling autocast can be useful, for example, if you want to force a subregion - to run in a particular ``dtype``. Disabling autocast gives you explicit control over - the execution type. 
In the subregion, inputs from the surrounding region - should be cast to ``dtype`` before use:: - - # Creates some tensors in default dtype (here assumed to be float32) - a_float32 = torch.rand((8, 8), device="cuda") - b_float32 = torch.rand((8, 8), device="cuda") - c_float32 = torch.rand((8, 8), device="cuda") - d_float32 = torch.rand((8, 8), device="cuda") - - with autocast(): - e_float16 = torch.mm(a_float32, b_float32) - - with autocast(enabled=False): - # Calls e_float16.float() to ensure float32 execution - # (necessary because e_float16 was created in an autocasted region) - f_float32 = torch.mm(c_float32, e_float16.float()) - - # No manual casts are required when re-entering the autocast-enabled region. - # torch.mm again runs in float16 and produces float16 output, regardless of input types. - g_float16 = torch.mm(d_float32, f_float32) - - The autocast state is thread-local. If you want it enabled in a new thread, the context manager or decorator - must be invoked in that thread. This affects :class:`torch.nn.DataParallel` and - :class:`torch.nn.parallel.DistributedDataParallel` when used with more than one GPU per process - (see :ref:`Working with Multiple GPUs`). - - Arguments: - enabled(bool, optional, default=True): Whether autocasting should be enabled in the region. - """ - - def __init__(self, enabled=True): - self._enabled = enabled - - def __enter__(self): - self.prev = torch.is_autocast_enabled() - torch.set_autocast_enabled(self._enabled) - torch.autocast_increment_nesting() - - def __exit__(self, *args): - # Drop the cache when we exit to a nesting level that's outside any instance of autocast. - if torch.autocast_decrement_nesting() == 0: - torch.clear_autocast_cache() - torch.set_autocast_enabled(self.prev) - return False - - def __call__(self, func): - - @functools.wraps(func) - def decorate_autocast(*args, **kwargs): - with self: - return func(*args, **kwargs) - - return decorate_autocast - - -# Casts Tensors and containers of Tensors. Special-cases passthroughs for strings and np.ndarrays, which -# may be falsely detected as "Iterables." -def _cast(value, dtype): - if isinstance(value, torch.Tensor): - is_eligible = ( - value.is_floating_point() and value.is_cuda and - (value.dtype is not torch.float64)) - return value.to(dtype) if is_eligible else value - elif isinstance(value, string_classes): - return value - elif np is not None and isinstance(value, np.ndarray): - return value - elif isinstance(value, container_abcs.Mapping): - return {_cast(k, dtype): _cast(v, dtype) for k, v in value.items()} - elif isinstance(value, container_abcs.Iterable): - iterable = map(lambda v: _cast(v, dtype), value) - if isinstance(value, list) or isinstance(value, tuple): - return type(value)(iterable) - else: - return iterable - else: - return value - - -# custom_fwd is a decorator that may or may not be used with arguments, following -# https://github.com/dabeaz/python-cookbook/tree/master/src/9/defining_a_decorator_that_takes_an_optional_argument. -# this works: -# @custom_fwd -# def forward(...): -# this also works: -# @custom_fwd(cast_inputs=torch.float) -# def forward(...): -# TODO: when python 2 support is dropped, change the signature to -# def custom_fwd(fwd=None, *, cast_inputs=None) with internal changes following the link above. -def custom_fwd(fwd=None, **kwargs): - """ - Helper decorator for ``forward`` methods of custom autograd functions (subclasses of - :class:`torch.autograd.Function`). See the :ref:`example page` for more detail. 
- - Arguments: - cast_inputs (:class:`torch.dtype` or None, optional, default=None): If not ``None``, - when ``forward`` runs in an autocast-enabled region, casts incoming - floating-point CUDA Tensors to the target dtype (non-floating-point Tensors are not affected), - then executes ``forward`` with autocast disabled. - If ``None``, ``forward``'s internal ops execute with the current autocast state. - - .. note:: - If the decorated ``forward`` is called outside an autocast-enabled region, - :func:`custom_fwd` is a no-op and ``cast_inputs`` has no effect. - """ - if fwd is None: - if len(kwargs) == 0: - cast_inputs = None - else: - assert len(kwargs) == 1 - cast_inputs = kwargs["cast_inputs"] - return functools.partial(custom_fwd, cast_inputs=cast_inputs) - - if len(kwargs) == 0: - cast_inputs = None - else: - assert len(kwargs) == 1 - cast_inputs = kwargs["cast_inputs"] - - @functools.wraps(fwd) - def decorate_fwd(*args, **kwargs): - if cast_inputs is None: - args[0]._fwd_used_autocast = torch.is_autocast_enabled() - return fwd(*args, **kwargs) - else: - autocast_context = torch.is_autocast_enabled() - args[0]._fwd_used_autocast = False - if autocast_context: - with autocast(enabled=False): - return fwd(*_cast(args, cast_inputs), **_cast(kwargs, cast_inputs)) - else: - return fwd(*args, **kwargs) - - return decorate_fwd - - -# Autograd ensures incoming gradients are the same type as forward outputs. Allowing a separate -# cast_inputs argument on custom_bwd is unnecessary and could cause errors if it doesn't match -# cast_inputs supplied to custom_fwd. -def custom_bwd(bwd): - """ - Helper decorator for backward methods of custom autograd functions (subclasses of - :class:`torch.autograd.Function`). - Ensures that ``backward`` executes with the same autocast state as ``forward``. - See the :ref:`example page` for more detail. - """ - - @functools.wraps(bwd) - def decorate_bwd(*args, **kwargs): - with autocast(args[0]._fwd_used_autocast): - return bwd(*args, **kwargs) - - return decorate_bwd +autocast = torch.cuda.amp.autocast +custom_fwd = torch.cuda.amp.custom_fwd +custom_bwd = torch.cuda.amp.custom_bwd diff --git a/torch_xla/amp/grad_scaler.py b/torch_xla/amp/grad_scaler.py index a70096d8d00..77c54599c16 100644 --- a/torch_xla/amp/grad_scaler.py +++ b/torch_xla/amp/grad_scaler.py @@ -1,563 +1,13 @@ import torch -from collections import defaultdict -from torch._six import container_abcs -import warnings -from enum import Enum -from typing import Any, Dict, List, Optional, Tuple - import torch_xla.core.xla_model as xm -class _MultiDeviceReplicator(object): - """ - Lazily serves copies of a tensor to requested devices. Copies are cached per-device. - """ - - def __init__(self, master_tensor: torch.Tensor) -> None: - self.master = master_tensor - self._per_device_tensors: Dict[torch.device, torch.Tensor] = {} - - def get(self, device) -> torch.Tensor: - retval = self._per_device_tensors.get(device, None) - if retval is None: - retval = self.master.to(device=device, non_blocking=True, copy=True) - self._per_device_tensors[device] = retval - return retval - - -# Defines default_factory for GradScaler's _per_optimizer_states defaultdict, -# as well as associated "enum" values. Prefers defining these at top level because -# - Lambdas can't be pickled, so we don't want to supply a lambda as the factory. -# - Defining READY, UNSCALED, STEPPED and _refresh_per_optimizer_state within GradScaler -# causes a circular reference, which we'd rather avoid. 
-class OptState(Enum): - READY = 0 - UNSCALED = 1 - STEPPED = 2 - - -def _refresh_per_optimizer_state(): - return {"stage": OptState.READY, "found_inf_per_device": {}} - - -class GradScaler(object): - _scale: Optional[torch.Tensor] - _grows_tracker: Optional[torch.Tensor] - _per_optimizer_states: Dict[int, Dict[str, Any]] - """ - An instance ``scaler`` of :class:`GradScaler` helps perform the steps of gradient scaling - conveniently. - - * ``scaler.scale(loss)`` multiplies a given loss by ``scaler``'s current scale factor. - * ``scaler.step(optimizer)`` safely unscales gradients and calls ``optimizer.step()``. - * ``scaler.update()`` updates ``scaler``'s scale factor. - - Example:: - - # Creates a GradScaler once at the beginning of training. - scaler = GradScaler() - - for epoch in epochs: - for input, target in data: - optimizer.zero_grad() - output = model(input) - loss = loss_fn(output, target) - - # Scales loss. Calls backward() on scaled loss to create scaled gradients. - scaler.scale(loss).backward() - - # scaler.step() first unscales gradients of the optimizer's params. - # If gradients don't contain infs/NaNs, optimizer.step() is then called, - # otherwise, optimizer.step() is skipped. - scaler.step(optimizer) - - # Updates the scale for next iteration. - scaler.update() - - See the :ref:`Automatic Mixed Precision examples` for usage - (along with autocasting) in more complex cases like gradient clipping, gradient accumulation, gradient penalty, - and multiple losses/optimizers. - - ``scaler`` dynamically estimates the scale factor each iteration. To minimize gradient underflow, - a large scale factor should be used. However, ``float16`` values can "overflow" (become inf or NaN) if - the scale factor is too large. Therefore, the optimal scale factor is the largest factor that can be used - without incurring inf or NaN gradient values. - ``scaler`` approximates the optimal scale factor over time by checking the gradients for infs and NaNs during every - ``scaler.step(optimizer)`` (or optional separate ``scaler.unscale_(optimizer)``, see :meth:`unscale_`). - - * If infs/NaNs are found, ``scaler.step(optimizer)`` skips the underlying ``optimizer.step()`` (so the params - themselves remain uncorrupted) and ``update()`` multiplies the scale by ``backoff_factor``. - - * If no infs/NaNs are found, ``scaler.step(optimizer)`` runs the underlying ``optimizer.step()`` as usual. - If ``growth_interval`` unskipped iterations occur consecutively, ``update()`` multiplies the scale by - ``growth_factor``. - - The scale factor often causes infs/NaNs to appear in gradients for the first few iterations as its - value calibrates. ``scaler.step`` will skip the underlying ``optimizer.step()`` for these - iterations. After that, step skipping should occur rarely (once every few hundred or thousand iterations). - - Arguments: - init_scale (float, optional, default=2.**16): Initial scale factor. - growth_factor (float, optional, default=2.0): Factor by which the scale is multiplied during - :meth:`update` if no inf/NaN gradients occur for ``growth_interval`` consecutive iterations. - backoff_factor (float, optional, default=0.5): Factor by which the scale is multiplied during - :meth:`update` if inf/NaN gradients occur in an iteration. - growth_interval (int, optional, default=2000): Number of consecutive iterations without inf/NaN gradients - that must occur for the scale to be multiplied by ``growth_factor``. - enabled (bool, optional, default=True): If ``False``, disables gradient scaling. 
:meth:`step` simply - invokes the underlying ``optimizer.step()``, and other methods become no-ops. - """ - - def __init__(self, - init_scale=2.**16, - growth_factor=2.0, - backoff_factor=0.5, - growth_interval=2000, - enabled=True): - self._enabled = enabled - - if self._enabled: - assert growth_factor > 1.0, "The growth factor must be > 1.0." - assert backoff_factor < 1.0, "The backoff factor must be < 1.0." - - self._init_scale = init_scale - # self._scale will be lazily initialized during the first call to scale() - self._scale = None - self._growth_factor = growth_factor - self._backoff_factor = backoff_factor - self._growth_interval = growth_interval - self._init_growth_tracker = 0 - # self._growth_tracker will be lazily initialized during the first call to scale() - self._growth_tracker = None - self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state) - - def _check_scale_growth_tracker( - self, funcname) -> Tuple[torch.Tensor, torch.Tensor]: - fix = "This may indicate your script did not use scaler.scale(loss or outputs) earlier in the iteration." - assert self._scale is not None, "Attempted {} but _scale is None. ".format( - funcname) + fix - assert self._growth_tracker is not None, "Attempted {} but _growth_tracker is None. ".format( - funcname) + fix - return (self._scale, self._growth_tracker) - - def _lazy_init_scale_growth_tracker(self, dev): - assert self._growth_tracker is None, "_growth_tracker initialized before _scale" - self._scale = torch.tensor( - self._init_scale, dtype=torch.float32, device=dev) - self._growth_tracker = torch.tensor( - self._init_growth_tracker, dtype=torch.int32, device=dev) - - def scale(self, outputs): - """ - Multiplies ('scales') a tensor or list of tensors by the scale factor. - - Returns scaled outputs. If this instance of :class:`GradScaler` is not enabled, outputs are returned - unmodified. - - Arguments: - outputs (Tensor or iterable of Tensors): Outputs to scale. - """ - if not self._enabled: - return outputs - - # Short-circuit for the common case. - if isinstance(outputs, torch.Tensor): - if self._scale is None: - self._lazy_init_scale_growth_tracker(outputs.device) - assert self._scale is not None - # return outputs * self._scale.to(device=outputs.device, non_blocking=True) - return outputs * self._scale - - # Invoke the more complex machinery only if we're treating multiple outputs. - stash: List[_MultiDeviceReplicator] = [ - ] # holds a reference that can be overwritten by apply_scale - - def apply_scale(val): - if isinstance(val, torch.Tensor): - if len(stash) == 0: - if self._scale is None: - self._lazy_init_scale_growth_tracker(val.device) - assert self._scale is not None - stash.append(_MultiDeviceReplicator(self._scale)) - return val * stash[0].get(val.device) - elif isinstance(val, container_abcs.Iterable): - iterable = map(apply_scale, val) - if isinstance(val, list) or isinstance(val, tuple): - return type(val)(iterable) - else: - return iterable - else: - raise ValueError("outputs must be a Tensor or an iterable of Tensors") - - return apply_scale(outputs) - - def _unscale_grads_(self, optimizer, inv_scale, found_inf, allow_fp16): - per_device_inv_scale = _MultiDeviceReplicator(inv_scale) - per_device_found_inf = _MultiDeviceReplicator(found_inf) - - # To set up _amp_foreach_non_finite_check_and_unscale_, split grads by device and dtype. - # There could be hundreds of grads, so we'd like to iterate through them just once. - # However, we don't know their devices or dtypes in advance. 
- - # https://stackoverflow.com/questions/5029934/defaultdict-of-defaultdict - # Google says mypy struggles with defaultdicts type annotations. - per_device_and_dtype_grads = defaultdict( - lambda: defaultdict(list)) # type: ignore[var-annotated] - with torch.no_grad(): - for group in optimizer.param_groups: - for param in group["params"]: - if param.grad is None: - continue - if (not allow_fp16) and param.grad.dtype == torch.float16: - raise ValueError("Attempting to unscale FP16 gradients.") - if param.grad.is_sparse: - # is_coalesced() == False means the sparse grad has values with duplicate indices. - # coalesce() deduplicates indices and adds all values that have the same index. - # For scaled fp16 values, there's a good chance coalescing will cause overflow, - # so we should check the coalesced _values(). - if param.grad.dtype is torch.float16: - param.grad = param.grad.coalesce() - to_unscale = param.grad._values() - else: - to_unscale = param.grad - - # TODO: is there a way to split by device and dtype without appending in the inner loop? - per_device_and_dtype_grads[to_unscale.device][ - to_unscale.dtype].append(to_unscale) - - for device, per_dtype_grads in per_device_and_dtype_grads.items(): - for grads in per_dtype_grads.values(): - torch._amp_foreach_non_finite_check_and_unscale_( - grads, per_device_found_inf.get(device), - per_device_inv_scale.get(device)) - - return per_device_found_inf._per_device_tensors - - def unscale_(self, optimizer): - """ - Divides ("unscales") the optimizer's gradient tensors by the scale factor. - - :meth:`unscale_` is optional, serving cases where you need to - :ref:`modify or inspect gradients` - between the backward pass(es) and :meth:`step`. - If :meth:`unscale_` is not called explicitly, gradients will be unscaled automatically during :meth:`step`. - - Simple example, using :meth:`unscale_` to enable clipping of unscaled gradients:: - - ... - scaler.scale(loss).backward() - scaler.unscale_(optimizer) - torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm) - scaler.step(optimizer) - scaler.update() - - Arguments: - optimizer (torch.optim.Optimizer): Optimizer that owns the gradients to be unscaled. - - .. note:: - :meth:`unscale_` does not incur a CPU-GPU sync. - - .. warning:: - :meth:`unscale_` should only be called once per optimizer per :meth:`step` call, - and only after all gradients for that optimizer's assigned parameters have been accumulated. - Calling :meth:`unscale_` twice for a given optimizer between each :meth:`step` triggers a RuntimeError. - - .. warning:: - :meth:`unscale_` may unscale sparse gradients out of place, replacing the ``.grad`` attribute. - """ - if not self._enabled: - return - - self._check_scale_growth_tracker("unscale_") - - optimizer_state = self._per_optimizer_states[id(optimizer)] - - if optimizer_state["stage"] is OptState.UNSCALED: - raise RuntimeError( - "unscale_() has already been called on this optimizer since the last update()." - ) - elif optimizer_state["stage"] is OptState.STEPPED: - raise RuntimeError("unscale_() is being called after step().") - - # FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64. 
- assert self._scale is not None - inv_scale = self._scale.double().reciprocal().float() - found_inf = torch.tensor( - 0.0, dtype=torch.float32, device=self._scale.device) - - optimizer_state["found_inf_per_device"] = self._unscale_grads_( - optimizer, inv_scale, found_inf, False) - optimizer_state["stage"] = OptState.UNSCALED - - def step(self, optimizer, *args, **kwargs): - """ - :meth:`step` carries out the following two operations: - - 1. Internally invokes ``unscale_(optimizer)`` (unless :meth:`unscale_` was explicitly called for ``optimizer`` - earlier in the iteration). As part of the :meth:`unscale_`, gradients are checked for infs/NaNs. - 2. If no inf/NaN gradients are found, invokes ``optimizer.step()`` using the unscaled - gradients. Otherwise, ``optimizer.step()`` is skipped to avoid corrupting the params. - - ``*args`` and ``**kwargs`` are forwarded to ``optimizer.step()``. - - Returns the return value of ``optimizer.step(*args, **kwargs)``. - - Arguments: - optimizer (torch.optim.Optimizer): Optimizer that applies the gradients. - args: Any arguments. - kwargs: Any keyword arguments. - - .. warning:: - Closure use is not currently supported. - """ - if (not self._enabled): - return optimizer.step(*args, **kwargs) - - if "closure" in kwargs: - raise RuntimeError( - "Closure use is not currently supported if GradScaler is enabled.") - - self._check_scale_growth_tracker("step") - - optimizer_state = self._per_optimizer_states[id(optimizer)] - - if optimizer_state["stage"] is OptState.STEPPED: - raise RuntimeError( - "step() has already been called since the last update().") +class GradScaler(torch.cuda.amp.GradScaler): + def _maybe_opt_step(self, optimizer, optimizer_state, *args, **kwargs): retval = None - - if (hasattr(optimizer, "_step_supports_amp_scaling") and - optimizer._step_supports_amp_scaling): - # This optimizer has customized scale-handling logic, so we can call optimizer.step() directly. - # The contract with custom optimizers is that their step() should accept an additional, - # optional grad_scaler kwarg. We append self to the kwargs so the custom optimizer has full information: - # it can query its own state, invoke unscale_ on itself, etc - retval = optimizer.step(*args, **dict(kwargs, grad_scaler=self)) - optimizer_state["stage"] = OptState.STEPPED - return retval - - if optimizer_state["stage"] is OptState.READY: - self.unscale_(optimizer) - - assert len(optimizer_state["found_inf_per_device"] - ) > 0, "No inf checks were recorded for this optimizer." - - # call mark_step before v.item() to make sure the gradients could be reused in optimizer.step xm.mark_step() if not sum( v.item() for v in optimizer_state["found_inf_per_device"].values()): retval = optimizer.step(*args, **kwargs) - - optimizer_state["stage"] = OptState.STEPPED - return retval - - def update(self, new_scale=None): - """ - Updates the scale factor. - - If any optimizer steps were skipped the scale is multiplied by ``backoff_factor`` - to reduce it. If ``growth_interval`` unskipped iterations occurred consecutively, - the scale is multiplied by ``growth_factor`` to increase it. - - Passing ``new_scale`` sets the scale directly. - - Arguments: - new_scale (float or :class:`torch.cuda.FloatTensor`, optional, default=None): New scale factor. - - .. warning:: - :meth:`update` should only be called at the end of the iteration, after ``scaler.step(optimizer)`` has - been invoked for all optimizers used this iteration. 
- """ - if not self._enabled: - return - - _scale, _growth_tracker = self._check_scale_growth_tracker("update") - - if new_scale is not None: - # Accept a new user-defined scale. - if isinstance(new_scale, float): - self._scale = torch.tensor( - new_scale, dtype=torch.float32, device=_scale.device) - else: - reason = "new_scale should be a float or a 1-element torch.cuda.FloatTensor with requires_grad=False." - assert isinstance( - new_scale, - torch.cuda.FloatTensor), reason # type: ignore[attr-defined] - assert new_scale.numel() == 1, reason - assert new_scale.requires_grad is False, reason - self._scale = new_scale - else: - # Consume shared inf/nan data collected from optimizers to update the scale. - # If all found_inf tensors are on the same device as self._scale, this operation is asynchronous. - found_infs = [ - found_inf.to(device=_scale.device, non_blocking=True) - for state in self._per_optimizer_states.values() - for found_inf in state["found_inf_per_device"].values() - ] - - assert len(found_infs) > 0, "No inf checks were recorded prior to update." - - found_inf_combined = found_infs[0] - if len(found_infs) > 1: - for i in range(1, len(found_infs)): - found_inf_combined += found_infs[i] - - self._scale = torch._amp_update_scale(_growth_tracker, _scale, - found_inf_combined, - self._growth_factor, - self._backoff_factor, - self._growth_interval) - - # To prepare for next iteration, clear the data collected from optimizers this iteration. - self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state) - - def _get_scale_async(self): - return self._scale - - def get_scale(self): - """ - Returns a Python float containing the current scale, or 1.0 if scaling is disabled. - - .. warning:: - :meth:`get_scale` incurs a CPU-GPU sync. - """ - if self._enabled: - return self._init_scale if self._scale is None else self._get_scale_async( - ).item() - else: - return 1.0 - - def get_growth_factor(self): - r""" - Returns a Python float containing the scale growth factor. - """ - return self._growth_factor - - def set_growth_factor(self, new_factor): - r""" - Arguments: - new_scale (float): Value to use as the new scale growth factor. - """ - self._growth_factor = new_factor - - def get_backoff_factor(self): - r""" - Returns a Python float containing the scale backoff factor. - """ - return self._backoff_factor - - def set_backoff_factor(self, new_factor): - r""" - Arguments: - new_scale (float): Value to use as the new scale backoff factor. - """ - self._backoff_factor = new_factor - - def get_growth_interval(self): - r""" - Returns a Python int containing the growth interval. - """ - return self._growth_interval - - def set_growth_interval(self, new_interval): - r""" - Arguments: - new_interval (int): Value to use as the new growth interval. - """ - self._growth_interval = new_interval - - def _get_growth_tracker(self): - if self._enabled: - return self._init_growth_tracker if self._growth_tracker is None else self._growth_tracker.item( - ) - else: - return 0 - - def is_enabled(self): - r""" - Returns a bool indicating whether this instance is enabled. - """ - return self._enabled - - def state_dict(self): - r""" - Returns the state of the scaler as a :class:`dict`. 
It contains five entries: - - * ``"scale"`` - a Python float containing the current scale - * ``"growth_factor"`` - a Python float containing the current growth factor - * ``"backoff_factor"`` - a Python float containing the current backoff factor - * ``"growth_interval"`` - a Python int containing the current growth interval - * ``"_growth_tracker"`` - a Python int containing the number of recent consecutive unskipped steps. - - If this instance is not enabled, returns an empty dict. - - .. note:: - If you wish to checkpoint the scaler's state after a particular iteration, :meth:`state_dict` - should be called after :meth:`update`. - """ - return { - "scale": self.get_scale(), - "growth_factor": self._growth_factor, - "backoff_factor": self._backoff_factor, - "growth_interval": self._growth_interval, - "_growth_tracker": self._get_growth_tracker() - } if self._enabled else {} - - def load_state_dict(self, state_dict): - r""" - Loads the scaler state. If this instance is disabled, :meth:`load_state_dict` is a no-op. - - Arguments: - state_dict(dict): scaler state. Should be an object returned from a call to :meth:`state_dict`. - """ - if not self._enabled: - return - - if len(state_dict) == 0: - raise RuntimeError( - "The source state dict is empty, possibly because it was saved " - "from a disabled instance of GradScaler.") - - self._init_scale = state_dict["scale"] - if self._scale is not None: - self._scale.fill_(state_dict["scale"]) - self._growth_factor = state_dict["growth_factor"] - self._backoff_factor = state_dict["backoff_factor"] - self._growth_interval = state_dict["growth_interval"] - self._init_growth_tracker = state_dict["_growth_tracker"] - if self._growth_tracker is not None: - self._growth_tracker.fill_(state_dict["_growth_tracker"]) - - def __getstate__(self): - state = self.__dict__.copy() - if self._enabled: - assert len(self._per_optimizer_states) == 0, "A GradScaler instance may only be pickled at the beginning "\ - "of an iteration, or at the end after scaler.update()." - # Pickling _scale and _growth_tracker Tensors directly triggers - # "warnings.warn("pickle support for Storage will be removed in 1.5..." - # so instead, we set the unpickled instance up to reinitialize them lazily. 
- state['_init_scale'] = self.get_scale() - state['_init_growth_tracker'] = self._get_growth_tracker() - state['_scale'] = None - state['_growth_tracker'] = None - return state - - def __setstate__(self, state): - self.__dict__.update(state) - - def _check_inf_per_device(self, optimizer): - _scale, _ = self._check_scale_growth_tracker("_check_inf_per_device") - - dummy_inv_scale = torch.tensor( - 1.0, dtype=torch.float32, device=_scale.device) - found_inf = torch.tensor(0.0, dtype=torch.float32, device=_scale.device) - - self._per_optimizer_states[id(optimizer)]["found_inf_per_device"] = \ - self._unscale_grads_(optimizer, dummy_inv_scale, found_inf, True) - - return self._per_optimizer_states[id(optimizer)]["found_inf_per_device"] - - def _found_inf_per_device(self, optimizer): - return self._per_optimizer_states[id(optimizer)]["found_inf_per_device"] From c79782027ea60d31d5d2169803c1e6a1e73e571e Mon Sep 17 00:00:00 2001 From: Chengji Yao Date: Sat, 2 Jan 2021 07:34:39 +0800 Subject: [PATCH 3/4] add amp op cpp test --- .circleci/test.sh | 3 + test/args_parse.py | 1 - test/cpp/test_aten_xla_tensor.cpp | 102 +++++++++ test/test_amp.py | 61 ------ test/test_train_mp_mnist.py | 2 - test/test_train_mp_mnist_amp.py | 194 ++++++++++++++++++ torch_xla/amp/__init__.py | 2 +- torch_xla/csrc/aten_xla_type.cpp | 5 +- torch_xla/csrc/batch_norm.cpp | 9 +- ...p_foreach_non_finite_check_and_unscale.cpp | 43 ++-- ...amp_foreach_non_finite_check_and_unscale.h | 6 +- torch_xla/csrc/ops/amp_update_scale.cpp | 43 ++-- torch_xla/csrc/ops/amp_update_scale.h | 11 +- torch_xla/csrc/ops/xla_ops.cpp | 3 - torch_xla/csrc/ops/xla_ops.h | 2 - torch_xla/csrc/tensor.h | 6 +- torch_xla/csrc/tensor_methods.cpp | 33 +-- torch_xla/csrc/xla_lower_util.cpp | 92 +++++---- torch_xla/csrc/xla_lower_util.h | 15 +- 19 files changed, 448 insertions(+), 185 deletions(-) delete mode 100644 test/test_amp.py create mode 100644 test/test_train_mp_mnist_amp.py diff --git a/.circleci/test.sh b/.circleci/test.sh index eaf5755cdac..9a67de17f9b 100755 --- a/.circleci/test.sh +++ b/.circleci/test.sh @@ -20,6 +20,9 @@ echo "Running Python Tests" echo "Running MNIST Test" python test/test_train_mnist.py --tidy +if [ -x "$(command -v nvidia-smi)" ]; then + python test/test_train_mp_mnist_amp.py +fi echo "Running C++ Tests" pushd test/cpp diff --git a/test/args_parse.py b/test/args_parse.py index e69d9960b87..0c6479ce8d7 100644 --- a/test/args_parse.py +++ b/test/args_parse.py @@ -32,7 +32,6 @@ def parse_common_options(datadir=None, parser.add_argument('--fake_data', action='store_true') parser.add_argument('--tidy', action='store_true') parser.add_argument('--metrics_debug', action='store_true') - parser.add_argument('--amp', action='store_true') if opts: for name, aopts in opts: parser.add_argument(name, **aopts) diff --git a/test/cpp/test_aten_xla_tensor.cpp b/test/cpp/test_aten_xla_tensor.cpp index ab1ca22bfa3..da89600f4cf 100644 --- a/test/cpp/test_aten_xla_tensor.cpp +++ b/test/cpp/test_aten_xla_tensor.cpp @@ -10007,5 +10007,107 @@ TEST_F(AtenXlaTensorTest, TestEmbeddingBackward) { } } +TEST_F(AtenXlaTensorTest, TestAmpForeachNonFiniteCheckAndUnscale) { + torch::Tensor grads0 = + torch::tensor({1, 2, 3, 4}, torch::TensorOptions(torch::kFloat)); + torch::Tensor grads1 = torch::tensor({1.0, 2.0, std::nan("1"), 4.0}, + torch::TensorOptions(torch::kFloat)); + torch::Tensor inv_scale = + torch::scalar_tensor(0.2, torch::TensorOptions(torch::kFloat)); + torch::Tensor found_inf = + torch::scalar_tensor(0, torch::TensorOptions(torch::kFloat)); + 
torch::Tensor grads_output0 = grads0 * inv_scale; + torch::Tensor found_inf_output0 = + torch::scalar_tensor(0, torch::TensorOptions(torch::kFloat)); + torch::Tensor found_inf_output1 = + torch::scalar_tensor(1, torch::TensorOptions(torch::kFloat)); + ForEachDevice([&](const torch::Device& device) { + torch::Tensor xla_grads0 = CopyToDevice(grads0, device); + torch::Tensor xla_inv_scale = CopyToDevice(inv_scale, device); + torch::Tensor xla_found_inf = CopyToDevice(found_inf, device); + torch::_amp_foreach_non_finite_check_and_unscale_(xla_grads0, xla_found_inf, + xla_inv_scale); + AllClose(grads_output0, xla_grads0, /*rtol=*/1e-2, /*atol=*/1e-4); + AllEqual(found_inf_output0, xla_found_inf); + + torch::Tensor xla_grads1 = CopyToDevice(grads1, device); + torch::_amp_foreach_non_finite_check_and_unscale_(xla_grads1, xla_found_inf, + xla_inv_scale); + AllEqual(found_inf_output1, xla_found_inf); + }); + ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); + ExpectCounterChanged("xla::_amp_foreach_non_finite_check_and_unscale_", + cpp_test::GetIgnoredCounters()); +} + +TEST_F(AtenXlaTensorTest, TestAmpUpdateScale) { + torch::Tensor growth_tracker = + torch::scalar_tensor(0, torch::TensorOptions(torch::kInt32)); + torch::Tensor current_scale = + torch::scalar_tensor(4, torch::TensorOptions(torch::kFloat)); + torch::Tensor found_inf = + torch::scalar_tensor(1, torch::TensorOptions(torch::kFloat)); + torch::Tensor not_found_inf = + torch::scalar_tensor(0, torch::TensorOptions(torch::kFloat)); + float scale_growth_factor = 2.0; + float scale_backoff_factor = 0.5; + int growth_interval = 3; + + torch::Tensor growth_tracker_result0 = + torch::scalar_tensor(1, torch::TensorOptions(torch::kInt32)); + torch::Tensor current_scale_result0 = + torch::scalar_tensor(4, torch::TensorOptions(torch::kFloat)); + torch::Tensor growth_tracker_result1 = + torch::scalar_tensor(2, torch::TensorOptions(torch::kInt32)); + torch::Tensor current_scale_result1 = + torch::scalar_tensor(4, torch::TensorOptions(torch::kFloat)); + torch::Tensor growth_tracker_result2 = + torch::scalar_tensor(0, torch::TensorOptions(torch::kInt32)); + torch::Tensor current_scale_result2 = + torch::scalar_tensor(8, torch::TensorOptions(torch::kFloat)); + torch::Tensor growth_tracker_result3 = + torch::scalar_tensor(0, torch::TensorOptions(torch::kInt32)); + torch::Tensor current_scale_result3 = + torch::scalar_tensor(4, torch::TensorOptions(torch::kFloat)); + + ForEachDevice([&](const torch::Device& device) { + torch::Tensor xla_growth_tracker = CopyToDevice(growth_tracker, device); + torch::Tensor xla_current_scale = CopyToDevice(current_scale, device); + torch::Tensor xla_found_inf = CopyToDevice(found_inf, device); + torch::Tensor xla_not_found_inf = CopyToDevice(not_found_inf, device); + + xla_current_scale = torch::_amp_update_scale( + xla_growth_tracker, xla_current_scale, xla_not_found_inf, + scale_growth_factor, scale_backoff_factor, growth_interval); + AllClose(current_scale_result0, xla_current_scale, /*rtol=*/1e-2, + /*atol=*/1e-4); + AllEqual(growth_tracker_result0, xla_growth_tracker); + + xla_current_scale = torch::_amp_update_scale( + xla_growth_tracker, xla_current_scale, xla_not_found_inf, + scale_growth_factor, scale_backoff_factor, growth_interval); + AllClose(current_scale_result1, xla_current_scale, /*rtol=*/1e-2, + /*atol=*/1e-4); + AllEqual(growth_tracker_result1, xla_growth_tracker); + + xla_current_scale = torch::_amp_update_scale( + xla_growth_tracker, xla_current_scale, xla_not_found_inf, + 
scale_growth_factor, scale_backoff_factor, growth_interval); + AllClose(current_scale_result2, xla_current_scale, /*rtol=*/1e-2, + /*atol=*/1e-4); + AllEqual(growth_tracker_result2, xla_growth_tracker); + + xla_current_scale = torch::_amp_update_scale( + xla_growth_tracker, xla_current_scale, xla_found_inf, + scale_growth_factor, scale_backoff_factor, growth_interval); + AllClose(current_scale_result3, xla_current_scale, /*rtol=*/1e-2, + /*atol=*/1e-4); + AllEqual(growth_tracker_result3, xla_growth_tracker); + }); + ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); + ExpectCounterChanged("xla::_amp_update_scale", + cpp_test::GetIgnoredCounters()); +} + } // namespace cpp_test } // namespace torch_xla diff --git a/test/test_amp.py b/test/test_amp.py deleted file mode 100644 index af3aaf9b760..00000000000 --- a/test/test_amp.py +++ /dev/null @@ -1,61 +0,0 @@ -import torch -import torch_xla.core.xla_model as xm -import unittest - - -class TestAmp(unittest.TestCase): - - def test_amp_update_scale(self): - device = xm.xla_device() - growth_tracker = torch.tensor(0, dtype=torch.int32, device=device) - current_scale = torch.tensor(4, dtype=torch.float, device=device) - found_inf = torch.tensor(0, dtype=torch.float, device=device) - scale_growth_factor = 2.0 - scale_backoff_factor = 0.5 - growth_interval = 3 - current_scale = torch._amp_update_scale(growth_tracker, current_scale, - found_inf, scale_growth_factor, - scale_backoff_factor, - growth_interval) - self.assertAlmostEqual(current_scale.item(), 4.0) - self.assertEqual(growth_tracker.item(), 1) - current_scale = torch._amp_update_scale(growth_tracker, current_scale, - found_inf, scale_growth_factor, - scale_backoff_factor, - growth_interval) - self.assertAlmostEqual(current_scale.item(), 4.0) - self.assertEqual(growth_tracker.item(), 2) - current_scale = torch._amp_update_scale(growth_tracker, current_scale, - found_inf, scale_growth_factor, - scale_backoff_factor, - growth_interval) - self.assertAlmostEqual(current_scale.item(), 8.0) - self.assertEqual(growth_tracker.item(), 0) - found_inf = torch.tensor(1, dtype=torch.float, device=device) - current_scale = torch._amp_update_scale(growth_tracker, current_scale, - found_inf, scale_growth_factor, - scale_backoff_factor, - growth_interval) - self.assertAlmostEqual(current_scale.item(), 4.0) - self.assertEqual(growth_tracker.item(), 0) - - def test_amp_foreach_non_finite_check_and_unscale(self): - device = xm.xla_device() - grads = [torch.tensor([1, 2, 3, 4], dtype=torch.float, device=device)] - inv_scale = torch.tensor(0.2, dtype=torch.float, device=device) - found_inf = torch.tensor(0, dtype=torch.float, device=device) - torch._amp_foreach_non_finite_check_and_unscale_(grads, found_inf, - inv_scale) - self.assertAlmostEqual(found_inf.item(), 0.0) - - grads = [ - torch.tensor([1, 2, 3, float('nan')], dtype=torch.float, device=device), - torch.tensor([1, 2, 3, 5], dtype=torch.float, device=device) - ] - torch._amp_foreach_non_finite_check_and_unscale_(grads, found_inf, - inv_scale) - self.assertAlmostEqual(found_inf.item(), 1.0) - - -if __name__ == '__main__': - unittest.main() diff --git a/test/test_train_mp_mnist.py b/test/test_train_mp_mnist.py index 05f9adf82d2..7c556191294 100644 --- a/test/test_train_mp_mnist.py +++ b/test/test_train_mp_mnist.py @@ -24,7 +24,6 @@ import torch_xla.core.xla_model as xm import torch_xla.distributed.xla_multiprocessing as xmp import torch_xla.test.test_utils as test_utils -from torch_xla.amp import autocast, GradScaler class 
MNIST(nn.Module): @@ -119,7 +118,6 @@ def train_mnist(flags, **kwargs): writer = test_utils.get_summary_writer(flags.logdir) optimizer = optim.SGD(model.parameters(), lr=lr, momentum=flags.momentum) loss_fn = nn.NLLLoss() - scaler = GradScaler() def train_loop_fn(loader): tracker = xm.RateTracker() diff --git a/test/test_train_mp_mnist_amp.py b/test/test_train_mp_mnist_amp.py new file mode 100644 index 00000000000..4c6f54195c7 --- /dev/null +++ b/test/test_train_mp_mnist_amp.py @@ -0,0 +1,194 @@ +import args_parse + +FLAGS = args_parse.parse_common_options( + datadir='/tmp/mnist-data', + batch_size=128, + momentum=0.5, + lr=0.01, + target_accuracy=98.0, + num_epochs=18) + +import os +import shutil +import sys +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torchvision import datasets, transforms +import torch_xla +import torch_xla.debug.metrics as met +import torch_xla.distributed.parallel_loader as pl +import torch_xla.utils.utils as xu +import torch_xla.core.xla_model as xm +import torch_xla.distributed.xla_multiprocessing as xmp +import torch_xla.test.test_utils as test_utils +from torch_xla.amp import autocast, GradScaler + + +class MNIST(nn.Module): + + def __init__(self): + super(MNIST, self).__init__() + self.conv1 = nn.Conv2d(1, 10, kernel_size=5) + self.bn1 = nn.BatchNorm2d(10) + self.conv2 = nn.Conv2d(10, 20, kernel_size=5) + self.bn2 = nn.BatchNorm2d(20) + self.fc1 = nn.Linear(320, 50) + self.fc2 = nn.Linear(50, 10) + + def forward(self, x): + x = F.relu(F.max_pool2d(self.conv1(x), 2)) + x = self.bn1(x) + x = F.relu(F.max_pool2d(self.conv2(x), 2)) + x = self.bn2(x) + x = torch.flatten(x, 1) + x = F.relu(self.fc1(x)) + x = self.fc2(x) + return F.log_softmax(x, dim=1) + + +def _train_update(device, x, loss, tracker, writer): + test_utils.print_training_update( + device, + x, + loss.item(), + tracker.rate(), + tracker.global_rate(), + summary_writer=writer) + + +def train_mnist(flags, **kwargs): + torch.manual_seed(1) + + if flags.fake_data: + train_loader = xu.SampleGenerator( + data=(torch.zeros(flags.batch_size, 1, 28, + 28), torch.zeros(flags.batch_size, + dtype=torch.int64)), + sample_count=60000 // flags.batch_size // xm.xrt_world_size()) + test_loader = xu.SampleGenerator( + data=(torch.zeros(flags.batch_size, 1, 28, + 28), torch.zeros(flags.batch_size, + dtype=torch.int64)), + sample_count=10000 // flags.batch_size // xm.xrt_world_size()) + else: + train_dataset = datasets.MNIST( + os.path.join(flags.datadir, str(xm.get_ordinal())), + train=True, + download=True, + transform=transforms.Compose( + [transforms.ToTensor(), + transforms.Normalize((0.1307,), (0.3081,))])) + test_dataset = datasets.MNIST( + os.path.join(flags.datadir, str(xm.get_ordinal())), + train=False, + download=True, + transform=transforms.Compose( + [transforms.ToTensor(), + transforms.Normalize((0.1307,), (0.3081,))])) + train_sampler = None + if xm.xrt_world_size() > 1: + train_sampler = torch.utils.data.distributed.DistributedSampler( + train_dataset, + num_replicas=xm.xrt_world_size(), + rank=xm.get_ordinal(), + shuffle=True) + train_loader = torch.utils.data.DataLoader( + train_dataset, + batch_size=flags.batch_size, + sampler=train_sampler, + drop_last=flags.drop_last, + shuffle=False if train_sampler else True, + num_workers=flags.num_workers) + test_loader = torch.utils.data.DataLoader( + test_dataset, + batch_size=flags.batch_size, + drop_last=flags.drop_last, + shuffle=False, + num_workers=flags.num_workers) + + # Scale 
learning rate to num cores + lr = flags.lr * xm.xrt_world_size() + + device = xm.xla_device() + model = MNIST().to(device) + writer = None + if xm.is_master_ordinal(): + writer = test_utils.get_summary_writer(flags.logdir) + optimizer = optim.SGD(model.parameters(), lr=lr, momentum=flags.momentum) + loss_fn = nn.NLLLoss() + scaler = GradScaler() + + def train_loop_fn(loader): + tracker = xm.RateTracker() + model.train() + for step, (data, target) in enumerate(loader): + optimizer.zero_grad() + with autocast(): + output = model(data) + loss = loss_fn(output, target) + scaler.scale(loss).backward() + gradients = xm._fetch_gradients(optimizer) + xm.all_reduce('sum', gradients, scale=1.0 / xm.xrt_world_size()) + scaler.step(optimizer) + scaler.update() + xm.mark_step() + tracker.add(flags.batch_size) + if step % flags.log_steps == 0: + xm.add_step_closure( + _train_update, args=(device, step, loss, tracker, writer)) + + def test_loop_fn(loader): + total_samples = 0 + correct = 0 + model.eval() + for data, target in loader: + output = model(data) + pred = output.max(1, keepdim=True)[1] + correct += pred.eq(target.view_as(pred)).sum() + total_samples += data.size()[0] + + accuracy = 100.0 * correct.item() / total_samples + accuracy = xm.mesh_reduce('test_accuracy', accuracy, np.mean) + return accuracy + + train_device_loader = pl.MpDeviceLoader(train_loader, device) + test_device_loader = pl.MpDeviceLoader(test_loader, device) + accuracy, max_accuracy = 0.0, 0.0 + for epoch in range(1, flags.num_epochs + 1): + xm.master_print('Epoch {} train begin {}'.format(epoch, test_utils.now())) + train_loop_fn(train_device_loader) + xm.master_print('Epoch {} train end {}'.format(epoch, test_utils.now())) + + accuracy = test_loop_fn(test_device_loader) + xm.master_print('Epoch {} test end {}, Accuracy={:.2f}'.format( + epoch, test_utils.now(), accuracy)) + max_accuracy = max(accuracy, max_accuracy) + test_utils.write_to_summary( + writer, + epoch, + dict_to_write={'Accuracy/test': accuracy}, + write_xla_metrics=True) + if flags.metrics_debug: + xm.master_print(met.metrics_report()) + + test_utils.close_summary_writer(writer) + xm.master_print('Max Accuracy: {:.2f}%'.format(max_accuracy)) + return max_accuracy + + +def _mp_fn(index, flags): + torch.set_default_tensor_type('torch.FloatTensor') + accuracy = train_mnist(flags) + if flags.tidy and os.path.isdir(flags.datadir): + shutil.rmtree(flags.datadir) + if accuracy < flags.target_accuracy: + print('Accuracy {} is below target {}'.format(accuracy, + flags.target_accuracy)) + sys.exit(21) + + +if __name__ == '__main__': + xmp.spawn(_mp_fn, args=(FLAGS,), nprocs=FLAGS.num_cores) diff --git a/torch_xla/amp/__init__.py b/torch_xla/amp/__init__.py index de6de31877e..1c0ecd08876 100644 --- a/torch_xla/amp/__init__.py +++ b/torch_xla/amp/__init__.py @@ -1,2 +1,2 @@ from .autocast_mode import autocast, custom_fwd, custom_bwd # noqa: F401 -from .grad_scaler import GradScaler # noqa: F401 \ No newline at end of file +from .grad_scaler import GradScaler # noqa: F401 diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp index 76ac397b1cb..fd89ac874c8 100644 --- a/torch_xla/csrc/aten_xla_type.cpp +++ b/torch_xla/csrc/aten_xla_type.cpp @@ -305,11 +305,10 @@ at::Tensor AtenXlaType::_adaptive_avg_pool2d_backward( void AtenXlaType::_amp_foreach_non_finite_check_and_unscale_( at::TensorList self, at::Tensor& found_inf, const at::Tensor& inv_scale) { XLA_FN_COUNTER("xla::"); - auto xla_self = bridge::GetXlaTensors(self); - absl::Span 
self_tensor{xla_self}; XLATensor found_inf_tensor = bridge::GetXlaTensor(found_inf); XLATensor::_amp_foreach_non_finite_check_and_unscale_( - self_tensor, found_inf_tensor, bridge::GetXlaTensor(inv_scale)); + bridge::GetXlaTensors(self), found_inf_tensor, + bridge::GetXlaTensor(inv_scale)); } at::Tensor AtenXlaType::_amp_update_scale(at::Tensor& growth_tracker, diff --git a/torch_xla/csrc/batch_norm.cpp b/torch_xla/csrc/batch_norm.cpp index 92250fa8ae2..315538dc0c2 100644 --- a/torch_xla/csrc/batch_norm.cpp +++ b/torch_xla/csrc/batch_norm.cpp @@ -9,11 +9,9 @@ namespace { bool IsF32BatchNormWithFP16Inputs(const xla::XlaOp& input, const xla::XlaOp& weight) { - xla::XlaBuilder* builder = input.builder(); - if (builder->GetShape(input).ok() && builder->GetShape(weight).ok() && - builder->GetShape(input).ValueOrDie().element_type() == + if (XlaHelpers::ShapeOfXlaOp(input).element_type() == xla::PrimitiveType::F16 && - builder->GetShape(weight).ValueOrDie().element_type() == + XlaHelpers::ShapeOfXlaOp(weight).element_type() == xla::PrimitiveType::F32) { return true; } @@ -42,6 +40,7 @@ BatchNormOutput BuildBatchNormTraining(xla::XlaOp input, xla::XlaOp weight, xla::XlaOp bias, float eps_value) { bool is_batchnorm_with_fp16_inputs = IsF32BatchNormWithFP16Inputs(input, weight); + // Handle the mixed precision use case. if (is_batchnorm_with_fp16_inputs) { input = xla::ConvertElementType(input, xla::PrimitiveType::F32); } @@ -61,6 +60,7 @@ xla::XlaOp BuildBatchNormInference(xla::XlaOp input, xla::XlaOp weight, xla::XlaOp variance, float eps_value) { bool is_batchnorm_with_fp16_inputs = IsF32BatchNormWithFP16Inputs(input, weight); + // Handle the mixed precision use case. if (is_batchnorm_with_fp16_inputs) { input = xla::ConvertElementType(input, xla::PrimitiveType::F32); } @@ -79,6 +79,7 @@ BatchNormGrads BuildBatchNormBackward(xla::XlaOp grad, xla::XlaOp input, float eps_value) { bool is_batchnorm_with_fp16_inputs = IsF32BatchNormWithFP16Inputs(input, weight); + // Handle the mixed precision use case. 
if (is_batchnorm_with_fp16_inputs) { input = xla::ConvertElementType(input, xla::PrimitiveType::F32); grad = xla::ConvertElementType(grad, xla::PrimitiveType::F32); diff --git a/torch_xla/csrc/ops/amp_foreach_non_finite_check_and_unscale.cpp b/torch_xla/csrc/ops/amp_foreach_non_finite_check_and_unscale.cpp index 777fee03712..c1da99ce656 100644 --- a/torch_xla/csrc/ops/amp_foreach_non_finite_check_and_unscale.cpp +++ b/torch_xla/csrc/ops/amp_foreach_non_finite_check_and_unscale.cpp @@ -12,39 +12,56 @@ namespace ir { namespace ops { namespace { -xla::Shape NodeOutputShape(const OpList& inputs) { +xla::Shape NodeOutputShape(const OpList& inputs, const Value& found_inf) { std::vector output_shapes; - output_shapes.reserve(inputs.size() - 1); - for (size_t i = 0; i < inputs.size() - 2; ++i) { + output_shapes.reserve(inputs.size() + 1); + for (size_t i = 0; i < inputs.size(); ++i) { const xla::Shape& input_shape = inputs[i].shape(); output_shapes.push_back(input_shape); } - output_shapes.push_back(xla::ShapeUtil::MakeShape( - inputs[inputs.size() - 2].shape().element_type(), {})); + output_shapes.push_back( + xla::ShapeUtil::MakeShape(found_inf.shape().element_type(), {})); return xla::ShapeUtil::MakeTupleShape(output_shapes); } +std::vector GetOperandList(absl::Span operands, + const Value& found_inf, + const Value& inv_scale) { + std::vector operand_list(operands.begin(), operands.end()); + operand_list.push_back(found_inf); + operand_list.push_back(inv_scale); + return operand_list; +} + } // namespace AmpForachNonFiniteCheckAndUnscale::AmpForachNonFiniteCheckAndUnscale( - const OpList& inputs) - : Node(xla_amp_foreach_non_finite_check_and_unscale, inputs, - NodeOutputShape(inputs), - /*num_outputs=*/inputs.size() - 1) {} + const OpList& inputs, const Value& found_inf, const Value& inv_scale) + : Node(ir::OpKind(at::aten::_amp_foreach_non_finite_check_and_unscale_), + GetOperandList(inputs, found_inf, inv_scale), + NodeOutputShape(inputs, found_inf), + /*num_outputs=*/inputs.size() + 1) {} NodePtr AmpForachNonFiniteCheckAndUnscale::Clone(OpList operands) const { - return MakeNode(operands); + std::vector operand_list(operands.begin(), operands.end() - 2); + size_t sz = operand_list.size(); + return MakeNode(operand_list, operands[sz], + operands[sz + 1]); } XlaOpVector AmpForachNonFiniteCheckAndUnscale::Lower( LoweringContext* loctx) const { std::vector inputs; - for (size_t i = 0; i < num_outputs() + 1; ++i) { + for (size_t i = 0; i < operands().size() - 2; ++i) { inputs.push_back(loctx->GetOutputOp(operand(i))); } - return ReturnOps(BuildAmpForachNonFiniteCheckAndUnscale(inputs), loctx); + return ReturnOps( + BuildAmpForeachNonFiniteCheckAndUnscale( + inputs, loctx->GetOutputOp(operand(operands().size() - 2)), + loctx->GetOutputOp(operand(operands().size() - 1))), + loctx); } } // namespace ops } // namespace ir -} // namespace torch_xla \ No newline at end of file +} // namespace torch_xla diff --git a/torch_xla/csrc/ops/amp_foreach_non_finite_check_and_unscale.h b/torch_xla/csrc/ops/amp_foreach_non_finite_check_and_unscale.h index 6ab1988546b..e651e1fc0ee 100644 --- a/torch_xla/csrc/ops/amp_foreach_non_finite_check_and_unscale.h +++ b/torch_xla/csrc/ops/amp_foreach_non_finite_check_and_unscale.h @@ -8,7 +8,9 @@ namespace ops { class AmpForachNonFiniteCheckAndUnscale : public Node { public: - AmpForachNonFiniteCheckAndUnscale(const OpList& inputs); + AmpForachNonFiniteCheckAndUnscale(const OpList& inputs, + const Value& found_inf, + const Value& inv_scale); NodePtr Clone(OpList operands) 
const override; @@ -17,4 +19,4 @@ class AmpForachNonFiniteCheckAndUnscale : public Node { } // namespace ops } // namespace ir -} // namespace torch_xla \ No newline at end of file +} // namespace torch_xla diff --git a/torch_xla/csrc/ops/amp_update_scale.cpp b/torch_xla/csrc/ops/amp_update_scale.cpp index 97ea894d9ad..a16eed38af1 100644 --- a/torch_xla/csrc/ops/amp_update_scale.cpp +++ b/torch_xla/csrc/ops/amp_update_scale.cpp @@ -12,33 +12,42 @@ namespace ir { namespace ops { namespace { -xla::Shape NodeOutputShape(const OpList& inputs) { - std::vector output_shapes; - for (size_t i = 0; i < 2; ++i) { - const xla::Shape& input_shape = inputs[i].shape(); - output_shapes.push_back(input_shape); - } - return xla::ShapeUtil::MakeTupleShape(output_shapes); +xla::Shape NodeOutputShape(const Value& growth_tracker, + const Value& current_scale) { + return xla::ShapeUtil::MakeTupleShape( + {growth_tracker.shape(), current_scale.shape()}); } } // namespace -AmpUpdateScale::AmpUpdateScale(const OpList& inputs) - : Node(xla_amp_update_scale, inputs, NodeOutputShape(inputs), - /*num_outputs=*/2) {} +AmpUpdateScale::AmpUpdateScale(const Value& growth_tracker, + const Value& current_scale, + const Value& found_inf, + double scale_growth_factor, + double scale_backoff_factor, int growth_interval) + : Node(ir::OpKind(at::aten::_amp_update_scale), + {growth_tracker, current_scale, found_inf}, + NodeOutputShape(growth_tracker, current_scale), + /*num_outputs=*/2), + scale_growth_factor_(scale_growth_factor), + scale_backoff_factor_(scale_backoff_factor), + growth_interval_(growth_interval) {} NodePtr AmpUpdateScale::Clone(OpList operands) const { - return MakeNode(operands); + return MakeNode(operands[0], operands[1], operands[2], + scale_growth_factor_, scale_backoff_factor_, + growth_interval_); } XlaOpVector AmpUpdateScale::Lower(LoweringContext* loctx) const { - std::vector inputs; - for (size_t i = 0; i < 6; ++i) { - inputs.push_back(loctx->GetOutputOp(operand(i))); - } - return ReturnOps(BuildAmpUpdateScale(inputs), loctx); + return ReturnOps( + BuildAmpUpdateScale(loctx->GetOutputOp(operand(0)), + loctx->GetOutputOp(operand(1)), + loctx->GetOutputOp(operand(2)), scale_growth_factor_, + scale_backoff_factor_, growth_interval_), + loctx); } } // namespace ops } // namespace ir -} // namespace torch_xla \ No newline at end of file +} // namespace torch_xla diff --git a/torch_xla/csrc/ops/amp_update_scale.h b/torch_xla/csrc/ops/amp_update_scale.h index e07933d8b10..145669a5de6 100644 --- a/torch_xla/csrc/ops/amp_update_scale.h +++ b/torch_xla/csrc/ops/amp_update_scale.h @@ -8,13 +8,20 @@ namespace ops { class AmpUpdateScale : public Node { public: - AmpUpdateScale(const OpList& inputs); + AmpUpdateScale(const Value& growth_tracker, const Value& current_scale, + const Value& found_inf, double scale_growth_factor, + double scale_backoff_factor, int growth_interval); NodePtr Clone(OpList operands) const override; XlaOpVector Lower(LoweringContext* loctx) const override; + + private: + double scale_growth_factor_; + double scale_backoff_factor_; + int growth_interval_; }; } // namespace ops } // namespace ir -} // namespace torch_xla \ No newline at end of file +} // namespace torch_xla diff --git a/torch_xla/csrc/ops/xla_ops.cpp b/torch_xla/csrc/ops/xla_ops.cpp index 852f95398c3..a669b5b9863 100644 --- a/torch_xla/csrc/ops/xla_ops.cpp +++ b/torch_xla/csrc/ops/xla_ops.cpp @@ -5,9 +5,6 @@ namespace ir { namespace ops { const OpKindWrapper xla_all_to_all("xla::all_to_all"); -const OpKindWrapper 
xla_amp_foreach_non_finite_check_and_unscale( - "xla::amp_foreach_non_finite_check_and_unscale"); -const OpKindWrapper xla_amp_update_scale("xla::amp_update_scale"); const OpKindWrapper xla_as_strided_view_update("xla::as_strided_view_update"); const OpKindWrapper xla_cast("xla::cast"); const OpKindWrapper xla_collective_permute("xla::collective_permute"); diff --git a/torch_xla/csrc/ops/xla_ops.h b/torch_xla/csrc/ops/xla_ops.h index c15bcc3dada..6b8e1cbaab4 100644 --- a/torch_xla/csrc/ops/xla_ops.h +++ b/torch_xla/csrc/ops/xla_ops.h @@ -29,8 +29,6 @@ class OpKindWrapper { }; extern const OpKindWrapper xla_all_to_all; -extern const OpKindWrapper xla_amp_foreach_non_finite_check_and_unscale; -extern const OpKindWrapper xla_amp_update_scale; extern const OpKindWrapper xla_as_strided_view_update; extern const OpKindWrapper xla_cast; extern const OpKindWrapper xla_collective_permute; diff --git a/torch_xla/csrc/tensor.h b/torch_xla/csrc/tensor.h index 840410dae2d..8063f1bfb91 100644 --- a/torch_xla/csrc/tensor.h +++ b/torch_xla/csrc/tensor.h @@ -247,14 +247,14 @@ class XLATensor { const XLATensor& input); static void _amp_foreach_non_finite_check_and_unscale_( - absl::Span self, XLATensor& found_inf, + std::vector self, XLATensor& found_inf, const XLATensor& inv_scale); static XLATensor _amp_update_scale(XLATensor growth_tracker, const XLATensor& current_scale, const XLATensor& found_inf, - float scale_growth_factor, - float scale_backoff_factor, + double scale_growth_factor, + double scale_backoff_factor, int growth_interval); static XLATensor abs(const XLATensor& input); diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp index 912684829a5..c8d28adfbda 100644 --- a/torch_xla/csrc/tensor_methods.cpp +++ b/torch_xla/csrc/tensor_methods.cpp @@ -471,17 +471,14 @@ XLATensor XLATensor::_adaptive_avg_pool2d_backward(const XLATensor& grad_output, } void XLATensor::_amp_foreach_non_finite_check_and_unscale_( - absl::Span self, XLATensor& found_inf, + std::vector self, XLATensor& found_inf, const XLATensor& inv_scale) { std::vector inputs; for (const auto& x : self) { inputs.push_back(x.GetIrValue()); } - inputs.push_back(found_inf.GetIrValue()); - inputs.push_back(inv_scale.GetIrValue()); - absl::Span inputs_span{inputs}; - ir::NodePtr node = - ir::MakeNode(inputs_span); + ir::NodePtr node = ir::MakeNode( + inputs, found_inf.GetIrValue(), inv_scale.GetIrValue()); for (size_t i = 0; i < self.size(); ++i) { self[i].SetInPlaceIrValue(ir::Value(node, i)); } @@ -491,25 +488,13 @@ void XLATensor::_amp_foreach_non_finite_check_and_unscale_( XLATensor XLATensor::_amp_update_scale(XLATensor growth_tracker, const XLATensor& current_scale, const XLATensor& found_inf, - float scale_growth_factor, - float scale_backoff_factor, + double scale_growth_factor, + double scale_backoff_factor, int growth_interval) { - ir::Value scale_growth_factor_ir = GetIrValueForScalar( - scale_growth_factor, xla::PrimitiveType::F32, growth_tracker.GetDevice()); - ir::Value scale_backoff_factor_ir = - GetIrValueForScalar(scale_backoff_factor, xla::PrimitiveType::F32, - growth_tracker.GetDevice()); - ir::Value growth_interval_ir = GetIrValueForScalar( - growth_interval, xla::PrimitiveType::S32, growth_tracker.GetDevice()); - std::vector inputs; - inputs.push_back(growth_tracker.GetIrValue()); - inputs.push_back(current_scale.GetIrValue()); - inputs.push_back(found_inf.GetIrValue()); - inputs.push_back(scale_growth_factor_ir); - inputs.push_back(scale_backoff_factor_ir); - 
inputs.push_back(growth_interval_ir); - absl::Span inputs_span{inputs}; - ir::NodePtr node = ir::MakeNode(inputs_span); + ir::NodePtr node = ir::MakeNode( + growth_tracker.GetIrValue(), current_scale.GetIrValue(), + found_inf.GetIrValue(), scale_growth_factor, scale_backoff_factor, + growth_interval); growth_tracker.SetInPlaceIrValue(ir::Value(node, 0)); return current_scale.CreateFrom(ir::Value(node, 1)); } diff --git a/torch_xla/csrc/xla_lower_util.cpp b/torch_xla/csrc/xla_lower_util.cpp index 53bdffc7922..79699fe5ac8 100644 --- a/torch_xla/csrc/xla_lower_util.cpp +++ b/torch_xla/csrc/xla_lower_util.cpp @@ -749,63 +749,71 @@ xla::XlaOp BuildMaskedScatter(xla::XlaOp input, xla::XlaOp mask, scatter_dnums); } -std::vector BuildAmpForachNonFiniteCheckAndUnscale( - const std::vector& inputs) { - const xla::PrimitiveType origin_type = xla::PrimitiveType::F32; - std::vector found_infs; - xla::XlaOp one = xla::One(inputs[0].builder(), xla::PrimitiveType::S32); - for (size_t i = 0; i < inputs.size() - 2; ++i) { +std::vector BuildAmpForeachNonFiniteCheckAndUnscale( + const std::vector& inputs, const xla::XlaOp& found_inf_float, + const xla::XlaOp& inv_scale) { + const xla::PrimitiveType origin_type = + XlaHelpers::ShapeOfXlaOp(found_inf_float).element_type(); + xla::XlaOp one = xla::One(inputs[0].builder(), xla::PrimitiveType::PRED); + xla::XlaOp found_inf = + xla::ConvertElementType(found_inf_float, xla::PrimitiveType::PRED); + for (size_t i = 0; i < inputs.size(); ++i) { xla::XlaOp all_finite = - xla::ReduceAll(xla::ConvertElementType(xla::IsFinite(inputs[i]), - xla::PrimitiveType::S32), - one, - xla::CreateScalarAndComputation(xla::PrimitiveType::S32, + xla::ReduceAll(xla::IsFinite(inputs[i]), one, + xla::CreateScalarAndComputation(xla::PrimitiveType::PRED, inputs[i].builder())); - found_infs.push_back(one - all_finite); + found_inf = xla::Or(found_inf, xla::Not(all_finite)); } - xla::XlaOp found_inf = xla::ConvertElementType(inputs[inputs.size() - 2], - xla::PrimitiveType::S32); - for (size_t i = 0; i < found_infs.size(); ++i) { - found_inf = xla::Or(found_inf, found_infs[i]); - } - xla::XlaOp inv_scale = inputs[inputs.size() - 1]; std::vector results; - - for (size_t i = 0; i < inputs.size() - 2; ++i) { + for (size_t i = 0; i < inputs.size(); ++i) { results.push_back(inputs[i] * inv_scale); } results.push_back(xla::ConvertElementType(found_inf, origin_type)); return results; } -std::vector BuildAmpUpdateScale( - const std::vector& inputs) { - const auto& growth_tracker = inputs[0]; +std::vector BuildAmpUpdateScale(const xla::XlaOp& growth_tracker, + const xla::XlaOp& current_scale, + const xla::XlaOp& found_inf_float, + double scale_growth_factor, + double scale_backoff_factor, + int scale_growth_interval) { xla::XlaOp one = xla::One(growth_tracker.builder(), xla::PrimitiveType::S32); xla::XlaOp one_float = xla::One(growth_tracker.builder(), xla::PrimitiveType::F32); - const auto& current_scale = inputs[1]; - const auto& found_inf = xla::Min( - xla::ConvertElementType(inputs[2], xla::PrimitiveType::S32), one); - const auto& growth_factor = inputs[3]; - const auto& backoff_factor = inputs[4]; - const auto& growth_interval = inputs[5]; - - xla::XlaOp all_finite = one - found_inf; - xla::XlaOp not_achieve_interval = - xla::Min((growth_interval - one - growth_tracker), one); + xla::XlaOp found_inf = + xla::ConvertElementType(found_inf_float, xla::PrimitiveType::PRED); + const auto& growth_factor = XlaHelpers::ScalarValue( + scale_growth_factor, + 
+      XlaHelpers::ShapeOfXlaOp(current_scale).element_type(),
+      growth_tracker.builder());
+  const auto& backoff_factor = XlaHelpers::ScalarValue(
+      scale_backoff_factor,
+      XlaHelpers::ShapeOfXlaOp(current_scale).element_type(),
+      growth_tracker.builder());
+  const auto& growth_interval = XlaHelpers::ScalarValue(
+      scale_growth_interval,
+      XlaHelpers::ShapeOfXlaOp(growth_tracker).element_type(),
+      growth_tracker.builder());
+
+  xla::XlaOp all_finite = xla::Not(found_inf);
+  xla::XlaOp not_achieve_interval = xla::ConvertElementType(
+      growth_interval - one - growth_tracker, xla::PrimitiveType::PRED);
   xla::XlaOp new_growth_tracker =
-      (growth_tracker + one) * all_finite * not_achieve_interval;
+      (growth_tracker + one) *
+      ConvertElementType(xla::And(all_finite, not_achieve_interval),
+                         xla::PrimitiveType::S32);
+  xla::XlaOp growth_factor_or_one = xla::Max(
+      growth_factor * xla::ConvertElementType(
+                          xla::And(all_finite, xla::Not(not_achieve_interval)),
+                          xla::PrimitiveType::F32),
+      one_float);
+  xla::XlaOp backoff_factor_or_one =
+      backoff_factor *
+          xla::ConvertElementType(found_inf, xla::PrimitiveType::F32) +
+      xla::ConvertElementType(all_finite, xla::PrimitiveType::F32);
   xla::XlaOp new_scale =
-      current_scale *
-      xla::Max(growth_factor * xla::ConvertElementType(
-                                   all_finite * (one - not_achieve_interval),
-                                   xla::PrimitiveType::F32),
-               one_float) *
-      (backoff_factor *
-           xla::ConvertElementType(found_inf, xla::PrimitiveType::F32) +
-       xla::ConvertElementType((one - found_inf) * one,
-                               xla::PrimitiveType::F32));
+      current_scale * growth_factor_or_one * backoff_factor_or_one;
   std::vector<xla::XlaOp> results;
   results.push_back(new_growth_tracker);
   results.push_back(new_scale);
diff --git a/torch_xla/csrc/xla_lower_util.h b/torch_xla/csrc/xla_lower_util.h
index ace04605ea2..f474b009c36 100644
--- a/torch_xla/csrc/xla_lower_util.h
+++ b/torch_xla/csrc/xla_lower_util.h
@@ -88,10 +88,15 @@ std::vector<xla::XlaOp> BuildMaskedSelect(xla::XlaOp input, xla::XlaOp mask);
 xla::XlaOp BuildMaskedScatter(xla::XlaOp input, xla::XlaOp mask,
                               xla::XlaOp source);

-std::vector<xla::XlaOp> BuildAmpForachNonFiniteCheckAndUnscale(
-    const std::vector<xla::XlaOp>& inputs);
-
-std::vector<xla::XlaOp> BuildAmpUpdateScale(
-    const std::vector<xla::XlaOp>& inputs);
+std::vector<xla::XlaOp> BuildAmpForeachNonFiniteCheckAndUnscale(
+    const std::vector<xla::XlaOp>& inputs, const xla::XlaOp& found_inf_float,
+    const xla::XlaOp& inv_scale);
+
+std::vector<xla::XlaOp> BuildAmpUpdateScale(const xla::XlaOp& growth_tracker,
+                                            const xla::XlaOp& current_scale,
+                                            const xla::XlaOp& found_inf,
+                                            double scale_growth_factor,
+                                            double scale_backoff_factor,
+                                            int scale_growth_interval);

 }  // namespace torch_xla

From 92118558237ae8ae1237b81e16fd5c8517720d8c Mon Sep 17 00:00:00 2001
From: Chengji Yao
Date: Thu, 18 Feb 2021 13:03:17 +0800
Subject: [PATCH 4/4] xla amp can now handle tensors of shape (1,)
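
The previous lowering effectively required 0-d `found_inf` and `inv_scale`
tensors, while GradScaler typically keeps its scale and found_inf bookkeeping
tensors with shape (1,). The op output shape now follows `found_inf.shape()`,
and `inv_scale` is reduced to a scalar via `XLATensor::max` before the IR node
is built. The CI invocation also adds `--fake_data` so the GPU AMP test runs
on synthetic data.

A minimal sketch of the case this fixes (tensor values below are illustrative
only, not taken from a real training run):

    import torch
    import torch_xla.core.xla_model as xm

    device = xm.xla_device()
    grads = [torch.tensor([1.0, 2.0, 3.0], device=device)]
    # Shape (1,) tensors, as GradScaler allocates them.
    found_inf = torch.tensor([0.0], device=device)
    inv_scale = torch.tensor([0.5], device=device)
    torch._amp_foreach_non_finite_check_and_unscale_(grads, found_inf, inv_scale)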
---
 .circleci/test.sh                                          | 2 +-
 .../csrc/ops/amp_foreach_non_finite_check_and_unscale.cpp | 3 +--
 torch_xla/csrc/tensor_methods.cpp                          | 3 ++-
 3 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/.circleci/test.sh b/.circleci/test.sh
index 9a67de17f9b..43e45635873 100755
--- a/.circleci/test.sh
+++ b/.circleci/test.sh
@@ -21,7 +21,7 @@ echo "Running Python Tests"
 echo "Running MNIST Test"
 python test/test_train_mnist.py --tidy
 if [ -x "$(command -v nvidia-smi)" ]; then
-  python test/test_train_mp_mnist_amp.py
+  python test/test_train_mp_mnist_amp.py --fake_data
 fi

 echo "Running C++ Tests"
diff --git a/torch_xla/csrc/ops/amp_foreach_non_finite_check_and_unscale.cpp b/torch_xla/csrc/ops/amp_foreach_non_finite_check_and_unscale.cpp
index c1da99ce656..19da4ea2881 100644
--- a/torch_xla/csrc/ops/amp_foreach_non_finite_check_and_unscale.cpp
+++ b/torch_xla/csrc/ops/amp_foreach_non_finite_check_and_unscale.cpp
@@ -19,8 +19,7 @@ xla::Shape NodeOutputShape(const OpList& inputs, const Value& found_inf) {
     const xla::Shape& input_shape = inputs[i].shape();
     output_shapes.push_back(input_shape);
   }
-  output_shapes.push_back(
-      xla::ShapeUtil::MakeShape(found_inf.shape().element_type(), {}));
+  output_shapes.push_back(found_inf.shape());
   return xla::ShapeUtil::MakeTupleShape(output_shapes);
 }

diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp
index c8d28adfbda..6329fefd968 100644
--- a/torch_xla/csrc/tensor_methods.cpp
+++ b/torch_xla/csrc/tensor_methods.cpp
@@ -474,11 +474,12 @@ void XLATensor::_amp_foreach_non_finite_check_and_unscale_(
     std::vector<XLATensor> self, XLATensor& found_inf,
     const XLATensor& inv_scale) {
   std::vector<ir::Value> inputs;
+  XLATensor new_inv_scale = XLATensor::max(inv_scale);
   for (const auto& x : self) {
     inputs.push_back(x.GetIrValue());
   }
   ir::NodePtr node = ir::MakeNode<ir::ops::AmpForeachNonFiniteCheckAndUnscale>(
-      inputs, found_inf.GetIrValue(), inv_scale.GetIrValue());
+      inputs, found_inf.GetIrValue(), new_inv_scale.GetIrValue());
   for (size_t i = 0; i < self.size(); ++i) {
     self[i].SetInPlaceIrValue(ir::Value(node, i));
   }
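
For reference, a minimal single-device sketch of how the autocast/GradScaler
API added by this series is meant to be used in a training step. The model,
data, and hyperparameters below are placeholders; in a multi-process run the
gradients would additionally need to be all-reduced before scaler.step, since
the usual xm.optimizer_step call would bypass the scaler.

    import torch
    import torch.nn as nn
    import torch.optim as optim
    import torch_xla.core.xla_model as xm
    from torch_xla.amp import autocast, GradScaler

    device = xm.xla_device()
    model = nn.Linear(16, 4).to(device)   # placeholder model
    optimizer = optim.SGD(model.parameters(), lr=0.01)
    scaler = GradScaler()

    data = torch.randn(8, 16, device=device)
    target = torch.randint(0, 4, (8,), device=device)

    optimizer.zero_grad()
    with autocast():                      # forward pass in mixed precision
        loss = nn.functional.cross_entropy(model(data), target)
    scaler.scale(loss).backward()         # scale the loss to avoid underflow
    scaler.step(optimizer)                # unscale grads, skip step on inf/nan
    scaler.update()                       # grow or back off the scale
    xm.mark_step()                        # materialize the XLA graph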