Add capturable ASGD impl #107857

Closed · wants to merge 25 commits · changes from 18 commits
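For context, the sketch below illustrates what the capturable flag added in this PR is for: with capturable=True, ASGD keeps step, eta, and mu as tensors on the parameters' device, so opt.step() has no host-side .item() syncs and can be captured in a CUDA graph. This is an editorial sketch under the assumption of a CUDA build and a PyTorch version containing this change, not code taken from the PR itself.

import torch

params = [torch.randn(16, device="cuda", requires_grad=True)]
opt = torch.optim.ASGD(params, lr=0.1, foreach=True, capturable=True)

# Warm up on a side stream before capture, as the CUDA graphs docs recommend.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        params[0].grad = torch.randn_like(params[0])
        opt.step()
torch.cuda.current_stream().wait_stream(s)

# Capture one optimizer step, then replay it with grads refreshed in place.
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    opt.step()

params[0].grad.copy_(torch.randn_like(params[0]))
g.replay()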
8 changes: 6 additions & 2 deletions test/inductor/test_compiled_optimizers.py
@@ -9,8 +9,8 @@

import torch._inductor

# The rest of the optimizers not yet imported: Adamax, ASGD, LBFGS, NAdam, RAdam, SGD, SparseAdam
from torch.optim import Adadelta, Adagrad, Adam, AdamW, RMSprop, Rprop
# The rest of the optimizers not yet imported: Adamax, LBFGS, NAdam, RAdam, SGD, SparseAdam
from torch.optim import Adadelta, Adagrad, Adam, AdamW, ASGD, RMSprop, Rprop

from torch.testing._internal.common_utils import TEST_WITH_ROCM, TestCase

@@ -75,6 +75,8 @@ def test_fn(self):

with torch.set_grad_enabled(False):
compiled_step()
compiled_step()
opt_eager.step()
opt_eager.step()

self.assertEqual(
@@ -165,6 +167,7 @@ def tearDown(self):
test_rmsprop = make_test(RMSprop, kernel_count=1, lr=0.01)
test_adadelta = make_test(Adadelta, kernel_count=1, lr=0.01)
test_adagrad = make_test(Adagrad, kernel_count=5, lr=0.01)
test_asgd = make_test(ASGD, kernel_count=2, lr=0.1)
# test_sgd = make_test(SGD, kernel_count=1, lr=0.01)

test_adam_recompile = make_recompile_test(Adam, lr=0.01)
@@ -176,6 +179,7 @@ def tearDown(self):
test_rmsprop_recompile = make_recompile_test(RMSprop, kernel_count=1, lr=0.01)
test_adadelta_recompile = make_recompile_test(Adadelta, kernel_count=1, lr=0.01)
test_adagrad_recompile = make_recompile_test(Adagrad, kernel_count=5, lr=0.01)
test_asgd_recompile = make_recompile_test(ASGD, kernel_count=2, lr=0.01)
# test_sgd_recompile = make_recompile_test(SGD, kernel_count=1, lr=0.01)
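These tests compile the optimizer step with torch.compile and count the Inductor kernels it produces (kernel_count=2 for ASGD above). A minimal sketch of the pattern being exercised, assuming a recent PyTorch with Inductor available; the shapes and hyperparameters here are made up:

import torch

p = torch.randn(8, requires_grad=True)
p.grad = torch.randn_like(p)
opt = torch.optim.ASGD([p], lr=0.1, foreach=True)

@torch.compile()
def compiled_step():
    opt.step()

# Grad computation is not needed here, only the optimizer update itself.
with torch.set_grad_enabled(False):
    compiled_step()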


10 changes: 9 additions & 1 deletion test/optim/test_optim.py
@@ -807,14 +807,18 @@ def _test_foreach_memory(self, optimizer_pairs_with_flags):
intermediate_size = nparams * param.nelement() * param.element_size()
nintermediates = 1 # we expect a budget of 1 intermediate most of the time
if (('capturable' in kwargs_with_flags and kwargs_with_flags['capturable']) or
optimizer_constructor.__name__ == "Adadelta"):
optimizer_constructor.__name__ in ["Adadelta", "ASGD"]):
# with capturable in Adam(W), we have 2 extra intermediates for the bias_corrections
# with Adadelta, we have 2 extra for (acc_delta + eps) and (square_avg + eps)
nintermediates = 3
if optimizer_constructor.__name__ == "NAdam":
# with capturable in NAdam, we have 3 extra intermediates for the
# bias_correction, mus, and mu_nexts
nintermediates = 5
elif optimizer_constructor.__name__ == "ASGD":
# ASGD allocates axs, 2x mus, 2x etas, and grads at the same time
nintermediates = 4

elif optimizer_constructor.__name__ in ["NAdam", "Adagrad", "RMSprop"]:
# NAdam uses two intermediates at the same time (grads & exp_avg_sq_sqrt)
# Adagrad uses std and grads at the same time
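To make the budget above concrete with hypothetical numbers (the exact assertion lives outside this excerpt): for 2 parameters of 256 float32 elements each, intermediate_size = 2 * 256 * 4 = 2048 bytes, so an optimizer allowed nintermediates = 4, as ASGD is here, may allocate roughly 4 * 2048 = 8192 bytes of temporaries during a single step before the test fails.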
@@ -896,6 +900,10 @@ def _multi_tensor_optimizer_configs(self):
(optim.ASGD, dict(weight_decay=1)),
(optim.ASGD, dict(weight_decay=0, maximize=True)),
(optim.ASGD, dict(weight_decay=1, maximize=True)),
(optim.ASGD, dict(weight_decay=0, capturable=True)),
(optim.ASGD, dict(weight_decay=1, capturable=True)),
(optim.ASGD, dict(weight_decay=0, maximize=True, capturable=True)),
(optim.ASGD, dict(weight_decay=1, maximize=True, capturable=True)),
(optim.Adamax, dict(weight_decay=0)),
(optim.Adamax, dict(weight_decay=1)),
(optim.Adamax, dict(weight_decay=0, maximize=True)),
4 changes: 4 additions & 0 deletions test/test_cuda.py
@@ -3060,8 +3060,12 @@ def test_graph_optims(self):
] + [
(optimizer_ctor, {"lr": 0.1, "betas": (0.8, 0.7), "fused": True, "amsgrad": amsgrad})
for optimizer_ctor, amsgrad in product((torch.optim.Adam, torch.optim.AdamW), (False, True))
] + [
(torch.optim.ASGD, {"lr": 0.1, "foreach": True, "maximize": maximize, "weight_decay": weight_decay})
for maximize, weight_decay in product((False, True), (0.0, 0.1))
]


for optimizer_ctor, kwargs in cases:
with self.subTest(optimizer_ctor=optimizer_ctor, kwargs=kwargs):
self._test_graphed_optimizer(3, 2, optimizer_ctor, kwargs)
10 changes: 10 additions & 0 deletions torch/_inductor/decomposition.py
@@ -498,6 +498,16 @@ def _foreach_lerp_scalar(start_tensors, end_tensors, weight):
)


@register_decomposition(aten._foreach_lerp.List)
def _foreach_lerp_list(start_tensors, end_tensors, weight_tensors):
return aten._foreach_add.List(
start_tensors,
aten._foreach_mul.List(
aten._foreach_sub.List(end_tensors, start_tensors), weight_tensors
),
)
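The decomposition above rewrites the tensor-weight lerp as start + (end - start) * weight for each list element. As a quick sanity check of that identity (an editorial sketch, not part of the PR; it only uses torch.lerp and basic elementwise ops):

import torch

starts = [torch.randn(4) for _ in range(3)]
ends = [torch.randn(4) for _ in range(3)]
weights = [torch.rand(4) for _ in range(3)]

decomposed = [s + (e - s) * w for s, e, w in zip(starts, ends, weights)]
reference = [torch.lerp(s, e, w) for s, e, w in zip(starts, ends, weights)]
for d, r in zip(decomposed, reference):
    torch.testing.assert_close(d, r)  # raises if the two differ beyond tolerance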


@aten.miopen_batch_norm.default.py_impl(torch._C.DispatchKey.Autograd)
@register_decomposition(aten.miopen_batch_norm)
def miopen_batch_norm(
25 changes: 25 additions & 0 deletions torch/_meta_registrations.py
@@ -2856,6 +2856,25 @@ def _check_foreach_binop_tensor_lists(self, other):
)


def _check_foreach_ternop_tensor_lists(self, other0, other1):
torch._check(
isinstance(self, List)
and isinstance(other0, List)
and isinstance(other1, List),
lambda: (
"The first three arguments of must be List[Tensor], "
f"but got {type(self)}, {type(other0)} and {type(other1)}."
),
)
torch._check(
len(self) > 0 and len(self) == len(other0) and len(self) == len(other1),
lambda: (
"self and other0 and other1 must be non-empty and match in length, "
f"but got {len(self)}, {len(other0)} and {len(other1)}."
),
)


@register_meta(
[
aten._foreach_add.List,
@@ -2885,12 +2904,18 @@ def meta__foreach_binop__list(self, other, alpha=1):
_check_foreach_binop_tensor_lists(self, other)


@register_meta([aten._foreach_lerp_.List])
def meta__foreach_ternop__list(self, other0, other1):
_check_foreach_ternop_tensor_lists(self, other0, other1)
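The registration above means the in-place list-weight lerp only needs shape and type checks at trace time. A hedged sketch of how that meta kernel gets exercised, assuming a build where the aten._foreach_lerp_.List overload is available (as added in this PR); on the meta device no real data is touched:

import torch

a = [torch.empty(4, device="meta") for _ in range(3)]
b = [torch.empty(4, device="meta") for _ in range(3)]
w = [torch.empty(4, device="meta") for _ in range(3)]

# Passes the list-length and type checks registered above; no computation runs.
torch._foreach_lerp_(a, b, w)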


@register_meta(
[
aten._foreach_add_.Scalar,
aten._foreach_mul_.Scalar,
aten._foreach_sub_.Scalar,
aten._foreach_div_.Scalar,
aten._foreach_maximum_.Scalar,
]
)
def meta__foreach_binop__scalar(self, scalar=1):
101 changes: 67 additions & 34 deletions torch/optim/asgd.py
@@ -8,9 +8,9 @@

__all__ = ["ASGD", "asgd"]

def _to_tensor(x):
def _to_tensor(x, device=None):
if not isinstance(x, torch.Tensor):
return torch.tensor(x)
return torch.tensor(x, device=device)

return x

@@ -26,6 +26,7 @@ def __init__(
foreach: Optional[bool] = None,
maximize: bool = False,
differentiable: bool = False,
capturable: bool = False,
):
if not 0.0 <= lr:
raise ValueError(f"Invalid learning rate: {lr}")
@@ -41,6 +42,7 @@ def __init__(
foreach=foreach,
maximize=maximize,
differentiable=differentiable,
capturable=capturable,
)
super().__init__(params, defaults)

@@ -50,6 +52,7 @@ def __setstate__(self, state):
group.setdefault("foreach", None)
group.setdefault("maximize", False)
group.setdefault("differentiable", False)
group.setdefault("capturable", False)
state_values = list(self.state.values())
step_is_tensor = (len(state_values) != 0) and torch.is_tensor(
state_values[0]["step"]
@@ -81,9 +84,9 @@ def _init_group(self, group, params_with_grad, grads, mus, axs, etas, state_step
state = self.state[p]
# State initialization
if len(state) == 0:
state["step"] = torch.tensor(0.0)
state["eta"] = torch.tensor(group["lr"])
state["mu"] = torch.tensor(1.0)
state["step"] = torch.zeros((), device=p.device)
state["eta"] = torch.tensor(group["lr"], device=p.device)
state["mu"] = torch.ones((), device=p.device)
state["ax"] = torch.zeros_like(
p, memory_format=torch.preserve_format
)
@@ -131,6 +134,7 @@ def step(self, closure=None):
foreach=group["foreach"],
maximize=group["maximize"],
differentiable=group["differentiable"],
capturable=group["capturable"],
)

return loss
@@ -171,6 +175,7 @@ def asgd(
foreach: Optional[bool] = None,
maximize: bool = False,
differentiable: bool = False,
capturable: bool = False,
*,
lambd: float,
lr: float,
@@ -208,6 +213,7 @@ def asgd(
weight_decay=weight_decay,
maximize=maximize,
differentiable=differentiable,
capturable=capturable,
)


@@ -226,12 +232,8 @@ def _single_tensor_asgd(
weight_decay: float,
maximize: bool,
differentiable: bool,
capturable: bool,
):
def _to_tensor(x):
if not isinstance(x, torch.Tensor):
return torch.tensor(x)
return x

for i, param in enumerate(params):
grad = grads[i]
grad = grad if not maximize else -grad
@@ -286,6 +288,7 @@ def _multi_tensor_asgd(
weight_decay: float,
maximize: bool,
differentiable: bool,
capturable: bool,
):

if len(params) == 0:
@@ -311,32 +314,62 @@ def _view_complex_as_real(tensor_list):
# update step
torch._foreach_add_(grouped_state_steps, 1)

# intermediate = grad + param * lambd
if weight_decay != 0:
# Re-use the intermediate memory (grouped_grads) already allocated for maximize
if maximize:
torch._foreach_add_(grouped_grads, grouped_params, alpha=weight_decay)
intermediate = grouped_grads
else:
grouped_grads = torch._foreach_add(grouped_grads, grouped_params, alpha=weight_decay)

# decay term
eta = _get_value(grouped_etas[0])
torch._foreach_mul_(grouped_params, 1 - lambd * eta)
intermediate = torch._foreach_add(grouped_grads, grouped_params, alpha=weight_decay)

# update parameter
torch._foreach_add_(grouped_params, grouped_grads, alpha=-eta)

# averaging
for i in range(len(grouped_axs)):
if is_compiling() or grouped_mus[i].item() != 1:
grouped_axs[i].add_(grouped_params[i].sub(grouped_axs[i]).mul(grouped_mus[i]))
else:
grouped_axs[i].copy_(grouped_params[i])

# update eta and mu
for i in range(len(grouped_mus)):
new_eta = _to_tensor(
lr / (1 + lambd * lr * _get_value(grouped_state_steps[i]) ** alpha)
)
grouped_etas[i].copy_(new_eta)
new_mu = _to_tensor(1 / max(1, _get_value(grouped_state_steps[i]) - t0))
grouped_mus[i].copy_(new_mu)
torch._foreach_add_(intermediate, grouped_params, alpha=lambd)
else:
intermediate = torch._foreach_add(grouped_grads, grouped_params, alpha=lambd)

# update param
# param * (1 - lambd * eta) - eta * grad
# => param - param * lambd * eta - eta * grad
# => param - eta * intermediate
torch._foreach_addcmul_(grouped_params, intermediate, grouped_etas, value=-1)

# update grouped_axs
# averaging: ax = ax + mu * (param - ax)
# Note (mlazos): we can't use lerp here because it requires the weight to be float64,
# while our grouping code requires dtypes to match for all tensors in a group (and they
# should, since we use the mus elsewhere). We could introduce a cast in a loop instead,
# but doing the update with addcmul only adds one extra kernel launch, so it looks like
# the cleaner and faster solution.
intermediate = torch._foreach_sub(grouped_params, grouped_axs)
torch._foreach_addcmul_(grouped_axs, intermediate, grouped_mus)

if capturable:
# update grouped_mus
new_mus = torch._foreach_sub(grouped_state_steps, t0)
torch._foreach_maximum_(new_mus, 1.0)
torch._foreach_reciprocal_(new_mus)
torch._foreach_copy_(grouped_mus, new_mus)

# update eta = lr / (1 + lambd * lr * step^alpha)
new_etas = torch._foreach_pow(grouped_state_steps, alpha)
torch._foreach_mul_(new_etas, lambd)
torch._foreach_mul_(new_etas, lr)
torch._foreach_add_(new_etas, 1)
torch._foreach_reciprocal_(new_etas)
torch._foreach_mul_(new_etas, lr)
torch._foreach_copy_(grouped_etas, new_etas)
else:
step = grouped_state_steps[0].item()
new_etas = []
new_mus = []

for i in range(len(grouped_mus)):
new_eta = _to_tensor(
lr / (1 + lambd * lr * step ** alpha), device=grouped_etas[i].device
)
new_etas.append(new_eta)
new_mu = _to_tensor(1 / max(1, step - t0), device=grouped_mus[i].device)
new_mus.append(new_mu)

torch._foreach_copy_(grouped_etas, new_etas)
torch._foreach_copy_(grouped_mus, new_mus)
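To summarize the algebra spelled out in the comments of the foreach path above, here is a hedged single-tensor reference for one step (an editorial sketch for checking the update rules, not the PR's implementation; the maximize path is omitted and PyTorch's default ASGD hyperparameters are assumed):

import torch

def asgd_step_reference(param, grad, ax, eta, mu, step,
                        lambd=1e-4, lr=1e-2, t0=1e6, alpha=0.75, weight_decay=0.0):
    # With capturable=True, step, eta, and mu are 0-dim tensors on param.device.
    step += 1
    intermediate = grad + weight_decay * param if weight_decay != 0 else grad
    intermediate = intermediate + lambd * param
    # param * (1 - lambd * eta) - eta * grad  ==  param - eta * intermediate
    param.addcmul_(intermediate, eta, value=-1)
    # ax = ax + mu * (param - ax), the addcmul form of lerp(ax, param, mu)
    ax.addcmul_(param - ax, mu)
    # eta = lr / (1 + lambd * lr * step**alpha);  mu = 1 / max(1, step - t0)
    eta.copy_(lr / (1 + lambd * lr * step ** alpha))
    mu.copy_((step - t0).clamp(min=1).reciprocal())
    return param, ax, eta, mu, step

As in the optimizer itself, this would run under torch.no_grad(); the ax update uses addcmul rather than lerp, matching the note in the diff about lerp's float64 weight requirement.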