Add capturable ASGD impl #107857

Closed
wants to merge 25 commits
8 changes: 6 additions & 2 deletions test/inductor/test_compiled_optimizers.py
@@ -10,8 +10,8 @@

import torch._inductor

# The rest of the optimizers not yet imported: Adamax, ASGD, LBFGS, NAdam, RAdam, SGD, SparseAdam
from torch.optim import Adadelta, Adagrad, Adam, AdamW, RMSprop, Rprop
# The rest of the optimizers not yet imported: Adamax, LBFGS, NAdam, RAdam, SGD, SparseAdam
from torch.optim import Adadelta, Adagrad, Adam, AdamW, ASGD, RMSprop, Rprop

from torch.testing._internal.common_utils import TEST_WITH_ROCM, TestCase

@@ -76,6 +76,8 @@ def test_fn(self):

with torch.set_grad_enabled(False):
compiled_step()
compiled_step()
opt_eager.step()
opt_eager.step()

self.assertEqual(
@@ -166,6 +168,7 @@ def tearDown(self):
test_rmsprop = make_test(RMSprop, kernel_count=1, lr=0.01)
test_adadelta = make_test(Adadelta, kernel_count=1, lr=0.01)
test_adagrad = make_test(Adagrad, kernel_count=5, lr=0.01)
test_asgd = make_test(ASGD, kernel_count=2, lr=0.1)
# test_sgd = make_test(SGD, kernel_count=1, lr=0.01)

test_adam_recompile = make_recompile_test(Adam, lr=0.01)
@@ -177,6 +180,7 @@ def tearDown(self):
test_rmsprop_recompile = make_recompile_test(RMSprop, kernel_count=1, lr=0.01)
test_adadelta_recompile = make_recompile_test(Adadelta, kernel_count=1, lr=0.01)
test_adagrad_recompile = make_recompile_test(Adagrad, kernel_count=5, lr=0.01)
test_asgd_recompile = make_recompile_test(ASGD, kernel_count=2, lr=0.01)
# test_sgd_recompile = make_recompile_test(SGD, kernel_count=1, lr=0.01)

@requires_cuda()
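For orientation, the new `test_asgd` and `test_asgd_recompile` entries follow the same pattern as the neighboring compiled-optimizer tests: run a `torch.compile`d optimizer step in lockstep with an eager one and compare the resulting parameters (the `kernel_count` argument additionally constrains how many Inductor kernels the compiled step generates, which this sketch omits). A minimal, self-contained sketch of that pattern; the helper below is illustrative, not the actual `make_test` implementation:

```python
# Hypothetical standalone check, loosely mirroring what make_test(ASGD, ...) verifies:
# a torch.compile'd ASGD step should produce the same parameters as an eager step.
import torch
from torch.optim import ASGD


def check_compiled_asgd_matches_eager(lr: float = 0.1) -> None:
    model_eager = torch.nn.Linear(4, 4)
    model_compiled = torch.nn.Linear(4, 4)
    model_compiled.load_state_dict(model_eager.state_dict())

    # Populate identical gradients on both copies.
    inp = torch.randn(2, 4)
    model_eager(inp).sum().backward()
    model_compiled(inp).sum().backward()

    opt_eager = ASGD(model_eager.parameters(), lr=lr)
    opt_compiled = ASGD(model_compiled.parameters(), lr=lr)

    @torch.compile()
    def compiled_step():
        opt_compiled.step()

    # Same cadence as the test above: two compiled steps vs. two eager steps.
    with torch.set_grad_enabled(False):
        compiled_step()
        compiled_step()
    opt_eager.step()
    opt_eager.step()

    for p_eager, p_compiled in zip(model_eager.parameters(), model_compiled.parameters()):
        assert torch.allclose(p_eager, p_compiled)
```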
21 changes: 19 additions & 2 deletions test/optim/test_optim.py
@@ -669,6 +669,9 @@ def _test_derived_optimizers_varying_tensors(self, optimizer_with_kwargs, kwarg)
res, state = [], []
for enabled in (False, True):
kwargs_clone = deepcopy(kwargs)
if optimizer_constructor.__name__ == "ASGD" and kwarg == "foreach" and not enabled:
# single tensor ASGD does not support capturable
kwargs_clone["capturable"] = False
kwargs_clone[kwarg] = enabled

params_clone = []
@@ -731,6 +734,9 @@ def _test_derived_optimizers(self, optimizer_pairs_with_flags, flag):
)
model.to(dtype=torch.float64, device=device)
params_with_flags = deepcopy(params)
if optimizer_constructor.__name__ == "ASGD" and flag == "foreach" and not flag_value:
# single tensor ASGD does not support capturable
params_with_flags["capturable"] = False
params_with_flags[flag] = flag_value

# foreach/fused optimizers should be tested with a param_groups['params'] with
@@ -779,7 +785,12 @@ def _test_foreach_memory(self, optimizer_pairs_with_flags):
max_mems = []
for flag_value in (False, True):
kwargs_with_flags = deepcopy(kwargs)
kwargs_with_flags['foreach'] = flag_value
if optimizer_constructor.__name__ == "ASGD" and kwargs_with_flags.get("capturable", False) and not flag_value:
# single tensor ASGD does not support capturable
kwargs_with_flags["capturable"] = False

kwargs_with_flags["foreach"] = flag_value


# The 128 is critical here! Our CUDACachingAllocator allocates in blocks of 512,
# meaning any tensor that occupies <512 bytes of memory will allocate a whole
@@ -807,14 +818,16 @@ def _test_foreach_memory(self, optimizer_pairs_with_flags):
intermediate_size = nparams * param.nelement() * param.element_size()
nintermediates = 1 # we expect a budget of 1 intermediate most of the time
if (('capturable' in kwargs_with_flags and kwargs_with_flags['capturable']) or
optimizer_constructor.__name__ == "Adadelta"):
optimizer_constructor.__name__ in ["Adadelta", "ASGD"]):
# with capturable in Adam(W), we have 2 extra intermediates for the bias_corrections
# with Adadelta, we have 2 extra for (acc_delta + eps) and (square_avg + eps)
# ASGD allocates axs, 2x mus, 2x etas, and grads at the same time
nintermediates = 3
if optimizer_constructor.__name__ == "NAdam":
# with capturable in NAdam, we have 3 extra intermediates for the
# bias_correction, mus, and mu_nexts
nintermediates = 5

elif optimizer_constructor.__name__ in ["NAdam", "Adagrad", "RMSprop"]:
# NAdam uses two intermediates at the same time (grads & exp_avg_sq_sqrt)
# Adagrad uses std and grads at the same time
Expand Down Expand Up @@ -896,6 +909,10 @@ def _multi_tensor_optimizer_configs(self):
(optim.ASGD, dict(weight_decay=1)),
(optim.ASGD, dict(weight_decay=0, maximize=True)),
(optim.ASGD, dict(weight_decay=1, maximize=True)),
(optim.ASGD, dict(weight_decay=0, capturable=True)),
(optim.ASGD, dict(weight_decay=1, capturable=True)),
(optim.ASGD, dict(weight_decay=0, maximize=True, capturable=True)),
(optim.ASGD, dict(weight_decay=1, maximize=True, capturable=True)),
(optim.Adamax, dict(weight_decay=0)),
(optim.Adamax, dict(weight_decay=1)),
(optim.Adamax, dict(weight_decay=0, maximize=True)),
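The `capturable` special-casing in these helpers mirrors the constructor check added in `torch/optim/asgd.py` further down: for ASGD, `capturable=True` is only supported together with the foreach implementation, and explicitly requesting the single-tensor path with it raises. A small illustration of that behavior:

```python
# Illustration of the foreach/capturable constraint the tests above work around.
import torch
from torch.optim import ASGD

params = [torch.nn.Parameter(torch.randn(2, 2))]

# Supported: multi-tensor ASGD with capturable state (step/eta/mu stay tensors on
# the params' device, so no host-side scalar math is needed per step).
opt = ASGD(params, lr=0.1, foreach=True, capturable=True)

# Rejected: explicitly requesting the single-tensor path together with capturable.
try:
    ASGD(params, lr=0.1, foreach=False, capturable=True)
except ValueError as err:
    print(err)  # Capturable not supported with single tensor ASGD
```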
4 changes: 4 additions & 0 deletions test/test_cuda.py
@@ -3060,8 +3060,12 @@ def test_graph_optims(self):
] + [
(optimizer_ctor, {"lr": 0.1, "betas": (0.8, 0.7), "fused": True, "amsgrad": amsgrad})
for optimizer_ctor, amsgrad in product((torch.optim.Adam, torch.optim.AdamW), (False, True))
] + [
(torch.optim.ASGD, {"lr": 0.1, "foreach": True, "maximize": maximize, "weight_decay": weight_decay})
for maximize, weight_decay in product((False, True), (0.0, 0.1))
]


for optimizer_ctor, kwargs in cases:
with self.subTest(optimizer_ctor=optimizer_ctor, kwargs=kwargs):
self._test_graphed_optimizer(3, 2, optimizer_ctor, kwargs)
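These cases feed ASGD through `_test_graphed_optimizer`, i.e. capture and replay of optimizer steps inside a CUDA graph; keeping the step/eta/mu math on the device is what lets `step()` be captured without a host sync. As a rough standalone illustration (not the test's actual code), the standard whole-iteration capture recipe could be used with the new flag; the model, shapes, and warmup count below are assumptions for the example:

```python
# Hedged sketch: capturing a full forward/backward/ASGD step into a CUDA graph,
# following the documented torch.cuda.CUDAGraph whole-iteration capture recipe.
import torch

model = torch.nn.Linear(8, 8, device="cuda")
opt = torch.optim.ASGD(model.parameters(), lr=0.1, foreach=True, capturable=True)
static_input = torch.randn(4, 8, device="cuda")

# Warm up on a side stream before capture.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        opt.zero_grad(set_to_none=True)
        model(static_input).sum().backward()
        opt.step()
torch.cuda.current_stream().wait_stream(s)

# Capture one iteration; grads are set to None first so backward() allocates
# them from the graph's private memory pool.
g = torch.cuda.CUDAGraph()
opt.zero_grad(set_to_none=True)
with torch.cuda.graph(g):
    model(static_input).sum().backward()
    opt.step()

# Replay with new data copied into the static input tensor.
static_input.copy_(torch.randn(4, 8, device="cuda"))
g.replay()
```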
117 changes: 80 additions & 37 deletions torch/optim/asgd.py
@@ -2,15 +2,15 @@
from torch import Tensor

from .optimizer import (Optimizer, _use_grad_for_differentiable, _get_value, _default_to_fused_or_foreach,
_differentiable_doc, _foreach_doc, _maximize_doc)
_differentiable_doc, _foreach_doc, _maximize_doc, _capturable_doc)
from torch._utils import is_compiling
from typing import List, Optional

__all__ = ["ASGD", "asgd"]

def _to_tensor(x):
def _to_tensor(x, device=None):
if not isinstance(x, torch.Tensor):
return torch.tensor(x)
return torch.tensor(x, device=device)

return x

@@ -26,12 +26,16 @@ def __init__(
foreach: Optional[bool] = None,
maximize: bool = False,
differentiable: bool = False,
capturable: bool = False,
):
if not 0.0 <= lr:
raise ValueError(f"Invalid learning rate: {lr}")
if not 0.0 <= weight_decay:
raise ValueError(f"Invalid weight_decay value: {weight_decay}")

if foreach is False and capturable:
raise ValueError("Capturable not supported with single tensor ASGD")

defaults = dict(
lr=lr,
lambd=lambd,
@@ -41,6 +45,7 @@ def __init__(
foreach=foreach,
maximize=maximize,
differentiable=differentiable,
capturable=capturable,
)
super().__init__(params, defaults)

@@ -50,6 +55,7 @@ def __setstate__(self, state):
group.setdefault("foreach", None)
group.setdefault("maximize", False)
group.setdefault("differentiable", False)
group.setdefault("capturable", False)
state_values = list(self.state.values())
step_is_tensor = (len(state_values) != 0) and torch.is_tensor(
state_values[0]["step"]
@@ -81,9 +87,9 @@ def _init_group(self, group, params_with_grad, grads, mus, axs, etas, state_step
state = self.state[p]
# State initialization
if len(state) == 0:
state["step"] = torch.tensor(0.0)
state["eta"] = torch.tensor(group["lr"])
state["mu"] = torch.tensor(1.0)
state["step"] = torch.zeros((), device=p.device)
state["eta"] = torch.tensor(group["lr"], device=p.device)
state["mu"] = torch.ones((), device=p.device)
state["ax"] = torch.zeros_like(
p, memory_format=torch.preserve_format
)
@@ -131,6 +137,7 @@ def step(self, closure=None):
foreach=group["foreach"],
maximize=group["maximize"],
differentiable=group["differentiable"],
capturable=group["capturable"],
)

return loss
@@ -152,6 +159,7 @@ def step(self, closure=None):
{_foreach_doc}
{_maximize_doc}
{_differentiable_doc}
{_capturable_doc} For ASGD, capturable is only supported when foreach is True.

.. _Acceleration of stochastic approximation by averaging:
https://dl.acm.org/citation.cfm?id=131098
@@ -171,6 +179,7 @@ def asgd(
foreach: Optional[bool] = None,
maximize: bool = False,
differentiable: bool = False,
capturable: bool = False,
*,
lambd: float,
lr: float,
@@ -192,6 +201,8 @@ def asgd(
if foreach and not torch.jit.is_scripting():
func = _multi_tensor_asgd
else:
if capturable and not is_compiling():
raise RuntimeError("Capturable not supported with single tensor ASGD")
func = _single_tensor_asgd

func(
@@ -208,6 +219,7 @@ def asgd(
weight_decay=weight_decay,
maximize=maximize,
differentiable=differentiable,
capturable=capturable,
)


@@ -226,12 +238,8 @@ def _single_tensor_asgd(
weight_decay: float,
maximize: bool,
differentiable: bool,
capturable: bool,
):
def _to_tensor(x):
if not isinstance(x, torch.Tensor):
return torch.tensor(x)
return x

for i, param in enumerate(params):
grad = grads[i]
grad = grad if not maximize else -grad
@@ -286,6 +294,7 @@ def _multi_tensor_asgd(
weight_decay: float,
maximize: bool,
differentiable: bool,
capturable: bool,
):

if len(params) == 0:
@@ -294,8 +303,8 @@ def _multi_tensor_asgd(
assert not differentiable, "_foreach ops don't support autograd"

grouped_tensors = Optimizer._group_tensors_by_device_and_dtype([params, grads, axs, mus, etas, state_steps])
for ((grouped_params, grouped_grads, grouped_axs, grouped_mus,
grouped_etas, grouped_state_steps), _) in grouped_tensors.values():
for ((device, _), ((grouped_params, grouped_grads, grouped_axs, grouped_mus,
grouped_etas, grouped_state_steps), _)) in grouped_tensors.items():
if maximize:
grouped_grads = torch._foreach_neg(grouped_grads)

@@ -311,32 +320,66 @@ def _view_complex_as_real(tensor_list):
# update step
torch._foreach_add_(grouped_state_steps, 1)

+        # intermediate = grad + param * lambd
         if weight_decay != 0:
             # Re-use the intermediate memory (grouped_grads) already allocated for maximize
             if maximize:
                 torch._foreach_add_(grouped_grads, grouped_params, alpha=weight_decay)
+                intermediate = grouped_grads
             else:
-                grouped_grads = torch._foreach_add(grouped_grads, grouped_params, alpha=weight_decay)

-        # decay term
-        eta = _get_value(grouped_etas[0])
-        torch._foreach_mul_(grouped_params, 1 - lambd * eta)

-        # update parameter
-        torch._foreach_add_(grouped_params, grouped_grads, alpha=-eta)

-        # averaging
-        for i in range(len(grouped_axs)):
-            if is_compiling() or grouped_mus[i].item() != 1:
-                grouped_axs[i].add_(grouped_params[i].sub(grouped_axs[i]).mul(grouped_mus[i]))
-            else:
-                grouped_axs[i].copy_(grouped_params[i])
+                intermediate = torch._foreach_add(grouped_grads, grouped_params, alpha=weight_decay)

-        # update eta and mu
-        for i in range(len(grouped_mus)):
-            new_eta = _to_tensor(
-                lr / (1 + lambd * lr * _get_value(grouped_state_steps[i]) ** alpha)
-            )
-            grouped_etas[i].copy_(new_eta)
-            new_mu = _to_tensor(1 / max(1, _get_value(grouped_state_steps[i]) - t0))
-            grouped_mus[i].copy_(new_mu)
+            torch._foreach_add_(intermediate, grouped_params, alpha=lambd)
+        else:
+            intermediate = torch._foreach_add(grouped_grads, grouped_params, alpha=lambd)

# update param
# param * (1 - lambd * eta) - eta * grad
# => param - param * lambd * eta - eta * grad
# => param - eta * intermediate
torch._foreach_addcmul_(grouped_params, intermediate, grouped_etas, value=-1)
del intermediate

# update grouped_axs
# averaging: ax = ax + mu * (param - ax)
# Note (mlazos): We can't use lerp here because it requires the weight to be float64,
# while our grouping code requires all tensors in a group to share a dtype (as they should,
# since the mus are used elsewhere). We could cast in a loop instead, but the addcmul below
# adds only one extra kernel launch, so it is the cleaner and faster option.
intermediate = torch._foreach_sub(grouped_params, grouped_axs)
torch._foreach_addcmul_(grouped_axs, intermediate, grouped_mus)
del intermediate

if capturable:
# update grouped_mus
new_mus = torch._foreach_sub(grouped_state_steps, t0)
torch._foreach_maximum_(new_mus, 1.0)
torch._foreach_reciprocal_(new_mus)
torch._foreach_copy_(grouped_mus, new_mus)
del new_mus

# update eta = lr / (1 + lambd * lr * step^alpha)
new_etas = torch._foreach_pow(grouped_state_steps, alpha)
torch._foreach_mul_(new_etas, lambd)
torch._foreach_mul_(new_etas, lr)
torch._foreach_add_(new_etas, 1)
torch._foreach_reciprocal_(new_etas)
torch._foreach_mul_(new_etas, lr)
torch._foreach_copy_(grouped_etas, new_etas)
del new_etas
else:
step = grouped_state_steps[0].item()
new_etas = []
new_mus = []

for i in range(len(grouped_mus)):
new_eta = _to_tensor(
lr / (1 + lambd * lr * step ** alpha), device=device
)
new_etas.append(new_eta)
new_mu = _to_tensor(1 / max(1, step - t0), device=device)
new_mus.append(new_mu)

torch._foreach_copy_(grouped_etas, new_etas)
torch._foreach_copy_(grouped_mus, new_mus)
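Written out from the foreach code above, one ASGD step for a parameter p with gradient g, averaged iterate ax, and per-parameter state (step count t, eta, mu) amounts to the following (maximize simply negates the gradient first):

```latex
% lambda = lambd, alpha = alpha, t_0 = t0, lr = the learning-rate hyperparameter;
% eta_{t-1} and mu_{t-1} are the state tensors before this step's update.
\begin{align*}
g_t  &\leftarrow g_t + \lambda_{\mathrm{wd}}\, p_{t-1} && \text{(optional weight decay)} \\
u_t  &= g_t + \lambda\, p_{t-1} && \text{(the ``intermediate'')} \\
p_t  &= p_{t-1} - \eta_{t-1}\, u_t = p_{t-1}\,(1 - \lambda\,\eta_{t-1}) - \eta_{t-1}\, g_t \\
ax_t &= ax_{t-1} + \mu_{t-1}\,(p_t - ax_{t-1}) \\
\eta_t &= \frac{\mathrm{lr}}{1 + \lambda\,\mathrm{lr}\, t^{\alpha}} \\
\mu_t  &= \frac{1}{\max(1,\; t - t_0)}
\end{align*}
```

The capturable branch computes the last two updates with foreach tensor ops so no `.item()` sync is needed, while the non-capturable branch reads the step count once on the host and builds the new eta and mu values there.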