diff --git a/test/inductor/test_compiled_optimizers.py b/test/inductor/test_compiled_optimizers.py index 3df77096f3b8d..7c6ecf77ea2c4 100644 --- a/test/inductor/test_compiled_optimizers.py +++ b/test/inductor/test_compiled_optimizers.py @@ -10,8 +10,8 @@ import torch._inductor -# The rest of the optimizers not yet imported: Adamax, ASGD, LBFGS, NAdam, RAdam, SGD, SparseAdam -from torch.optim import Adadelta, Adagrad, Adam, AdamW, RMSprop, Rprop +# The rest of the optimizers not yet imported: Adamax, LBFGS, NAdam, RAdam, SGD, SparseAdam +from torch.optim import Adadelta, Adagrad, Adam, AdamW, ASGD, RMSprop, Rprop from torch.testing._internal.common_utils import TEST_WITH_ROCM, TestCase @@ -76,6 +76,8 @@ def test_fn(self): with torch.set_grad_enabled(False): compiled_step() + compiled_step() + opt_eager.step() opt_eager.step() self.assertEqual( @@ -166,6 +168,9 @@ def tearDown(self): test_rmsprop = make_test(RMSprop, kernel_count=1, lr=0.01) test_adadelta = make_test(Adadelta, kernel_count=1, lr=0.01) test_adagrad = make_test(Adagrad, kernel_count=5, lr=0.01) + test_asgd_default = make_test(ASGD, kernel_count=2, lr=0.1) + test_asgd_single = make_test(ASGD, kernel_count=12, lr=0.1, foreach=False) + test_asgd_foreach = make_test(ASGD, kernel_count=2, lr=0.1, foreach=True) # test_sgd = make_test(SGD, kernel_count=1, lr=0.01) test_adam_recompile = make_recompile_test(Adam, lr=0.01) @@ -177,6 +182,13 @@ def tearDown(self): test_rmsprop_recompile = make_recompile_test(RMSprop, kernel_count=1, lr=0.01) test_adadelta_recompile = make_recompile_test(Adadelta, kernel_count=1, lr=0.01) test_adagrad_recompile = make_recompile_test(Adagrad, kernel_count=5, lr=0.01) + test_asgd_recompile_default = make_recompile_test(ASGD, kernel_count=2, lr=0.01) + test_asgd_recompile_single = make_recompile_test( + ASGD, kernel_count=12, lr=0.01, foreach=False + ) + test_asgd_recompile_foreach = make_recompile_test( + ASGD, kernel_count=2, lr=0.01, foreach=True + ) # test_sgd_recompile = make_recompile_test(SGD, kernel_count=1, lr=0.01) @requires_cuda() diff --git a/test/optim/test_optim.py b/test/optim/test_optim.py index 167786973e1c1..a1bd45b76fbb0 100644 --- a/test/optim/test_optim.py +++ b/test/optim/test_optim.py @@ -669,6 +669,9 @@ def _test_derived_optimizers_varying_tensors(self, optimizer_with_kwargs, kwarg) res, state = [], [] for enabled in (False, True): kwargs_clone = deepcopy(kwargs) + if optimizer_constructor.__name__ == "ASGD" and kwarg == "foreach" and not enabled: + # single tensor ASGD does not support capturable + kwargs_clone["capturable"] = False kwargs_clone[kwarg] = enabled params_clone = [] @@ -731,6 +734,9 @@ def _test_derived_optimizers(self, optimizer_pairs_with_flags, flag): ) model.to(dtype=torch.float64, device=device) params_with_flags = deepcopy(params) + if optimizer_constructor.__name__ == "ASGD" and flag == "foreach" and not flag_value: + # single tensor ASGD does not support capturable + params_with_flags["capturable"] = False params_with_flags[flag] = flag_value # foreach/fused optimizers should be tested with a param_groups['params'] with @@ -779,7 +785,12 @@ def _test_foreach_memory(self, optimizer_pairs_with_flags): max_mems = [] for flag_value in (False, True): kwargs_with_flags = deepcopy(kwargs) - kwargs_with_flags['foreach'] = flag_value + if optimizer_constructor.__name__ == "ASGD" and kwargs_with_flags.get("capturable", False) and not flag_value: + # single tensor ASGD does not support capturable + kwargs_with_flags["capturable"] = False + + kwargs_with_flags["foreach"] = 
flag_value + # The 128 is critical here! Our CUDACachingAllocator allocates in blocks of 512, # meaning any tensor that occupies <512 bytes of memory will allocate a whole @@ -807,14 +818,16 @@ def _test_foreach_memory(self, optimizer_pairs_with_flags): intermediate_size = nparams * param.nelement() * param.element_size() nintermediates = 1 # we expect a budget of 1 intermediate most of the time if (('capturable' in kwargs_with_flags and kwargs_with_flags['capturable']) or - optimizer_constructor.__name__ == "Adadelta"): + optimizer_constructor.__name__ in ["Adadelta", "ASGD"]): # with capturable in Adam(W), we have 2 extra intermediates for the bias_corrections # with Adadelta, we have 2 extra for (acc_delta + eps) and (square_avg + eps) + # ASGD allocates axs, 2x mus, 2x etas, and grads at the same time nintermediates = 3 if optimizer_constructor.__name__ == "NAdam": # with capturable in NAdam, we have 3 extra intermediates for the # bias_correction, mus, and mu_nexts nintermediates = 5 + elif optimizer_constructor.__name__ in ["NAdam", "Adagrad", "RMSprop"]: # NAdam uses two intermediates at the same time (grads & exp_avg_sq_sqrt) # Adagrad uses std and grads at the same time @@ -896,6 +909,10 @@ def _multi_tensor_optimizer_configs(self): (optim.ASGD, dict(weight_decay=1)), (optim.ASGD, dict(weight_decay=0, maximize=True)), (optim.ASGD, dict(weight_decay=1, maximize=True)), + (optim.ASGD, dict(weight_decay=0, capturable=True)), + (optim.ASGD, dict(weight_decay=1, capturable=True)), + (optim.ASGD, dict(weight_decay=0, maximize=True, capturable=True)), + (optim.ASGD, dict(weight_decay=1, maximize=True, capturable=True)), (optim.Adamax, dict(weight_decay=0)), (optim.Adamax, dict(weight_decay=1)), (optim.Adamax, dict(weight_decay=0, maximize=True)), diff --git a/test/test_cuda.py b/test/test_cuda.py index 3a274f6a6ede5..534ae8aa62124 100644 --- a/test/test_cuda.py +++ b/test/test_cuda.py @@ -3060,8 +3060,12 @@ def test_graph_optims(self): ] + [ (optimizer_ctor, {"lr": 0.1, "betas": (0.8, 0.7), "fused": True, "amsgrad": amsgrad}) for optimizer_ctor, amsgrad in product((torch.optim.Adam, torch.optim.AdamW), (False, True)) + ] + [ + (torch.optim.ASGD, {"lr": 0.1, "foreach": True, "maximize": maximize, "weight_decay": weight_decay}) + for maximize, weight_decay in product((False, True), (0.0, 0.1)) ] + for optimizer_ctor, kwargs in cases: with self.subTest(optimizer_ctor=optimizer_ctor, kwargs=kwargs): self._test_graphed_optimizer(3, 2, optimizer_ctor, kwargs) diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py index 8a6d229e4110f..638e046b27365 100644 --- a/torch/_meta_registrations.py +++ b/torch/_meta_registrations.py @@ -2891,6 +2891,7 @@ def meta__foreach_binop__list(self, other, alpha=1): aten._foreach_mul_.Scalar, aten._foreach_sub_.Scalar, aten._foreach_div_.Scalar, + aten._foreach_maximum_.Scalar, ] ) def meta__foreach_binop__scalar(self, scalar=1): diff --git a/torch/optim/asgd.py b/torch/optim/asgd.py index 5e140b0ca2ad7..ee4ad7f019169 100644 --- a/torch/optim/asgd.py +++ b/torch/optim/asgd.py @@ -2,15 +2,15 @@ from torch import Tensor from .optimizer import (Optimizer, _use_grad_for_differentiable, _get_value, _default_to_fused_or_foreach, - _differentiable_doc, _foreach_doc, _maximize_doc) + _differentiable_doc, _foreach_doc, _maximize_doc, _capturable_doc) from torch._utils import is_compiling from typing import List, Optional __all__ = ["ASGD", "asgd"] -def _to_tensor(x): +def _to_tensor(x, device=None): if not isinstance(x, torch.Tensor): - return 
torch.tensor(x) + return torch.tensor(x, device=device) return x @@ -26,12 +26,16 @@ def __init__( foreach: Optional[bool] = None, maximize: bool = False, differentiable: bool = False, + capturable: bool = False, ): if not 0.0 <= lr: raise ValueError(f"Invalid learning rate: {lr}") if not 0.0 <= weight_decay: raise ValueError(f"Invalid weight_decay value: {weight_decay}") + if foreach is False and capturable: + raise ValueError("Capturable not supported with single tensor ASGD") + defaults = dict( lr=lr, lambd=lambd, @@ -41,6 +45,7 @@ def __init__( foreach=foreach, maximize=maximize, differentiable=differentiable, + capturable=capturable, ) super().__init__(params, defaults) @@ -50,6 +55,7 @@ def __setstate__(self, state): group.setdefault("foreach", None) group.setdefault("maximize", False) group.setdefault("differentiable", False) + group.setdefault("capturable", False) state_values = list(self.state.values()) step_is_tensor = (len(state_values) != 0) and torch.is_tensor( state_values[0]["step"] @@ -81,9 +87,9 @@ def _init_group(self, group, params_with_grad, grads, mus, axs, etas, state_step state = self.state[p] # State initialization if len(state) == 0: - state["step"] = torch.tensor(0.0) - state["eta"] = torch.tensor(group["lr"]) - state["mu"] = torch.tensor(1.0) + state["step"] = torch.zeros((), device=p.device) + state["eta"] = torch.tensor(group["lr"], device=p.device) + state["mu"] = torch.ones((), device=p.device) state["ax"] = torch.zeros_like( p, memory_format=torch.preserve_format ) @@ -131,6 +137,7 @@ def step(self, closure=None): foreach=group["foreach"], maximize=group["maximize"], differentiable=group["differentiable"], + capturable=group["capturable"], ) return loss @@ -152,6 +159,7 @@ def step(self, closure=None): {_foreach_doc} {_maximize_doc} {_differentiable_doc} + {_capturable_doc} For ASGD, capturable is only supported when foreach is True. .. 
_Acceleration of stochastic approximation by averaging: https://dl.acm.org/citation.cfm?id=131098 @@ -171,6 +179,7 @@ def asgd( foreach: Optional[bool] = None, maximize: bool = False, differentiable: bool = False, + capturable: bool = False, *, lambd: float, lr: float, @@ -192,6 +201,8 @@ def asgd( if foreach and not torch.jit.is_scripting(): func = _multi_tensor_asgd else: + if capturable and not is_compiling(): + raise RuntimeError("Capturable not supported with single tensor ASGD") func = _single_tensor_asgd func( @@ -208,6 +219,7 @@ def asgd( weight_decay=weight_decay, maximize=maximize, differentiable=differentiable, + capturable=capturable, ) @@ -226,12 +238,8 @@ def _single_tensor_asgd( weight_decay: float, maximize: bool, differentiable: bool, + capturable: bool, ): - def _to_tensor(x): - if not isinstance(x, torch.Tensor): - return torch.tensor(x) - return x - for i, param in enumerate(params): grad = grads[i] grad = grad if not maximize else -grad @@ -286,6 +294,7 @@ def _multi_tensor_asgd( weight_decay: float, maximize: bool, differentiable: bool, + capturable: bool, ): if len(params) == 0: @@ -294,8 +303,8 @@ def _multi_tensor_asgd( assert not differentiable, "_foreach ops don't support autograd" grouped_tensors = Optimizer._group_tensors_by_device_and_dtype([params, grads, axs, mus, etas, state_steps]) - for ((grouped_params, grouped_grads, grouped_axs, grouped_mus, - grouped_etas, grouped_state_steps), _) in grouped_tensors.values(): + for ((device, _), ((grouped_params, grouped_grads, grouped_axs, grouped_mus, + grouped_etas, grouped_state_steps), _)) in grouped_tensors.items(): if maximize: grouped_grads = torch._foreach_neg(grouped_grads) @@ -311,32 +320,65 @@ def _view_complex_as_real(tensor_list): # update step torch._foreach_add_(grouped_state_steps, 1) + # intermediate = grad + param * lambd if weight_decay != 0: - # Re-use the intermediate memory (grouped_grads) already allocated for maximize if maximize: torch._foreach_add_(grouped_grads, grouped_params, alpha=weight_decay) + intermediate = grouped_grads else: - grouped_grads = torch._foreach_add(grouped_grads, grouped_params, alpha=weight_decay) - - # decay term - eta = _get_value(grouped_etas[0]) - torch._foreach_mul_(grouped_params, 1 - lambd * eta) - - # update parameter - torch._foreach_add_(grouped_params, grouped_grads, alpha=-eta) - - # averaging - for i in range(len(grouped_axs)): - if is_compiling() or grouped_mus[i].item() != 1: - grouped_axs[i].add_(grouped_params[i].sub(grouped_axs[i]).mul(grouped_mus[i])) - else: - grouped_axs[i].copy_(grouped_params[i]) + intermediate = torch._foreach_add(grouped_grads, grouped_params, alpha=weight_decay) - # update eta and mu - for i in range(len(grouped_mus)): - new_eta = _to_tensor( - lr / (1 + lambd * lr * _get_value(grouped_state_steps[i]) ** alpha) - ) - grouped_etas[i].copy_(new_eta) - new_mu = _to_tensor(1 / max(1, _get_value(grouped_state_steps[i]) - t0)) - grouped_mus[i].copy_(new_mu) + torch._foreach_add_(intermediate, grouped_params, alpha=lambd) + else: + intermediate = torch._foreach_add(grouped_grads, grouped_params, alpha=lambd) + + # update param + # param * (1 - lambd * eta) - eta * grad + # => param - param * lambd * eta - eta * grad + # => param - eta * intermediate + torch._foreach_addcmul_(grouped_params, intermediate, grouped_etas, value=-1) + del intermediate + + # update grouped_axs + # averaging: ax = ax + mu * (param - ax) + # Note (mlazos): We can't use lerp here since it requires weight to be float64 + # and our grouping code requires 
dtypes to match for all tensors in a group (and it should, since + # we use the mus in other places) + # all dtypes need to match, so we could introduce a cast in a loop + # but since this only adds one additional kernel launch, this looks like the cleaner + # and faster solution + intermediate = torch._foreach_sub(grouped_params, grouped_axs) + torch._foreach_addcmul_(grouped_axs, intermediate, grouped_mus) + del intermediate + + if capturable: + # update grouped_mus + new_mus = torch._foreach_sub(grouped_state_steps, t0) + torch._foreach_maximum_(new_mus, 1.0) + torch._foreach_reciprocal_(new_mus) + torch._foreach_copy_(grouped_mus, new_mus) + del new_mus + + # update eta = lr / (1 + lambd * lr * step^alpha) + new_etas = torch._foreach_pow(grouped_state_steps, alpha) + torch._foreach_mul_(new_etas, lambd) + torch._foreach_mul_(new_etas, lr) + torch._foreach_add_(new_etas, 1) + torch._foreach_reciprocal_(new_etas) + torch._foreach_mul_(new_etas, lr) + torch._foreach_copy_(grouped_etas, new_etas) + else: + step = grouped_state_steps[0].item() + new_etas = [] + new_mus = [] + + for i in range(len(grouped_mus)): + new_eta = _to_tensor( + lr / (1 + lambd * lr * step ** alpha), device=device + ) + new_etas.append(new_eta) + new_mu = _to_tensor(1 / max(1, step - t0), device=device) + new_mus.append(new_mu) + + torch._foreach_copy_(grouped_etas, new_etas) + torch._foreach_copy_(grouped_mus, new_mus)
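
For reviewers, here is a minimal single-tensor sketch of the update that the new capturable `_multi_tensor_asgd` path expresses with `_foreach` ops. This is an illustrative reference only, not part of the patch: `asgd_reference_step` and its signature are hypothetical, the hyperparameter defaults are assumed to mirror `torch.optim.ASGD`, and `ax`, `mu`, `eta`, `step` are assumed to be 0-dim float tensors as initialized in `_init_group`.

```python
import torch

def asgd_reference_step(param, grad, ax, mu, eta, step,
                        lambd=1e-4, alpha=0.75, t0=1e6, lr=1e-2,
                        weight_decay=0.0, maximize=False):
    """One ASGD step on a single parameter (hypothetical reference, not the PyTorch impl)."""
    step += 1                      # torch._foreach_add_(grouped_state_steps, 1)
    if maximize:
        grad = -grad               # torch._foreach_neg(grouped_grads)
    # intermediate = grad (+ weight_decay * param) + lambd * param
    intermediate = grad + weight_decay * param + lambd * param
    # param <- param - eta * intermediate
    #        = param * (1 - lambd * eta) - eta * (grad + weight_decay * param)
    param.addcmul_(intermediate, eta, value=-1)   # _foreach_addcmul_(..., value=-1)
    # averaging: ax <- ax + mu * (param - ax)
    ax.addcmul_(param - ax, mu)
    # mu <- 1 / max(1, step - t0)
    mu.copy_(torch.clamp(step - t0, min=1).reciprocal())
    # eta <- lr / (1 + lambd * lr * step ** alpha)
    eta.copy_(lr / (1 + lambd * lr * step ** alpha))
```

Note that the patch only wires `capturable` through the foreach path: `ASGD(..., foreach=False, capturable=True)` raises `ValueError("Capturable not supported with single tensor ASGD")`, which is why the compiled-optimizer and `_test_derived_optimizers` changes above drop `capturable` whenever the single-tensor path is exercised.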