Add capturable ASGD impl #107857

Closed · wants to merge 25 commits · Changes from 4 commits
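For context, a "capturable" optimizer keeps its scalar state (step, eta, mu) as tensors on the parameter's device instead of CPU scalars, so the whole step can be captured in a CUDA graph or compiled by Inductor without syncing back to the host. A minimal usage sketch of the flag this PR adds to ASGD (the keyword comes from the diff below; everything else is illustrative):

import torch

model = torch.nn.Linear(8, 8).cuda()
# capturable=True is the flag proposed in this PR; foreach=True selects the multi-tensor path changed below
opt = torch.optim.ASGD(model.parameters(), lr=0.01, foreach=True, capturable=True)

loss = model(torch.randn(4, 8, device="cuda")).sum()
loss.backward()

compiled_step = torch.compile(opt.step)  # scalar state lives on device, so the step needs no host syncs
compiled_step()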
6 changes: 4 additions & 2 deletions test/inductor/test_compiled_optimizers.py
@@ -9,8 +9,8 @@

import torch._inductor

# The rest of the optimizers not yet imported: Adamax, ASGD, LBFGS, NAdam, RAdam, SGD, SparseAdam
from torch.optim import Adadelta, Adagrad, Adam, AdamW, RMSprop, Rprop
# The rest of the optimizers not yet imported: Adamax, LBFGS, NAdam, RAdam, SGD, SparseAdam
from torch.optim import Adadelta, Adagrad, Adam, AdamW, ASGD, RMSprop, Rprop

from torch.testing._internal.common_utils import TEST_WITH_ROCM, TestCase

@@ -165,6 +165,7 @@ def tearDown(self):
test_rmsprop = make_test(RMSprop, kernel_count=1, lr=0.01)
test_adadelta = make_test(Adadelta, kernel_count=1, lr=0.01)
test_adagrad = make_test(Adagrad, kernel_count=5, lr=0.01)
test_asgd = make_test(ASGD, kernel_count=10, lr=0.01)
# test_sgd = make_test(SGD, kernel_count=1, lr=0.01)

test_adam_recompile = make_recompile_test(Adam, lr=0.01)
@@ -176,6 +177,7 @@ def tearDown(self):
test_rmsprop_recompile = make_recompile_test(RMSprop, kernel_count=1, lr=0.01)
test_adadelta_recompile = make_recompile_test(Adadelta, kernel_count=1, lr=0.01)
test_adagrad_recompile = make_recompile_test(Adagrad, kernel_count=5, lr=0.01)
test_asgd_recompile = make_recompile_test(ASGD, kernel_count=10, lr=0.01)
# test_sgd_recompile = make_recompile_test(SGD, kernel_count=1, lr=0.01)


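For context, the make_test/make_recompile_test entries above compile an optimizer step with Inductor and check how many kernels are generated (kernel_count=10 for ASGD). A rough standalone sketch of the same kind of check, using torch._inductor.metrics directly; the test harness internals are not shown here, so this is illustrative rather than the tests' actual mechanism:

import torch
import torch._inductor.metrics as metrics

p = torch.nn.Parameter(torch.randn(10, device="cuda"))
opt = torch.optim.ASGD([p], lr=0.01)
p.grad = torch.randn_like(p)

metrics.reset()
torch.compile(opt.step)()
print(metrics.generated_kernel_count)  # the tests above expect 10 kernels for ASGD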
88 changes: 60 additions & 28 deletions torch/optim/asgd.py
@@ -41,6 +41,7 @@ def __init__(
foreach=foreach,
maximize=maximize,
differentiable=differentiable,
capturable=False,
)
super().__init__(params, defaults)

@@ -50,6 +51,7 @@ def __setstate__(self, state):
group.setdefault("foreach", None)
group.setdefault("maximize", False)
group.setdefault("differentiable", False)
group.setdefault("capturable", False)
state_values = list(self.state.values())
step_is_tensor = (len(state_values) != 0) and torch.is_tensor(
state_values[0]["step"]
@@ -81,12 +83,20 @@ def _init_group(self, group, params_with_grad, grads, mus, axs, etas, state_step
state = self.state[p]
# State initialization
if len(state) == 0:
state["step"] = torch.tensor(0.0)
state["eta"] = torch.tensor(group["lr"])
state["mu"] = torch.tensor(1.0)
state["ax"] = torch.zeros_like(
p, memory_format=torch.preserve_format
)
if group["capturable"]:
state["step"] = torch.zeros((), dtype=torch.float, device=p.device)
state["eta"] = torch.tensor(group["lr"], dtype=torch.float, device=p.device)
state["mu"] = torch.ones((), dtype=torch.float, device=p.device)
state["ax"] = torch.zeros_like(
p, memory_format=torch.preserve_format
)
else:
state["step"] = torch.tensor(0.0)
state["eta"] = torch.tensor(group["lr"])
state["mu"] = torch.tensor(1.0)
state["ax"] = torch.zeros_like(
p, memory_format=torch.preserve_format
)

mus.append(state["mu"])
axs.append(state["ax"])
@@ -131,6 +141,7 @@ def step(self, closure=None):
foreach=group["foreach"],
maximize=group["maximize"],
differentiable=group["differentiable"],
capturable=group["capturable"],
)

return loss
@@ -177,6 +188,7 @@ def asgd(
t0: float,
alpha: float,
weight_decay: float,
capturable: bool = False,
):
r"""Functional API that performs asgd algorithm computation.

@@ -208,6 +220,7 @@
weight_decay=weight_decay,
maximize=maximize,
differentiable=differentiable,
capturable=capturable,
)


@@ -286,6 +299,7 @@ def _multi_tensor_asgd(
weight_decay: float,
maximize: bool,
differentiable: bool,
capturable: bool,
):

if len(params) == 0:
@@ -318,25 +332,43 @@ def _view_complex_as_real(tensor_list):
else:
grouped_grads = torch._foreach_add(grouped_grads, grouped_params, alpha=weight_decay)

# decay term
eta = _get_value(grouped_etas[0])
torch._foreach_mul_(grouped_params, 1 - lambd * eta)

# update parameter
torch._foreach_add_(grouped_params, grouped_grads, alpha=-eta)

# averaging
for i in range(len(grouped_axs)):
if is_compiling() or grouped_mus[i].item() != 1:
grouped_axs[i].add_(grouped_params[i].sub(grouped_axs[i]).mul(grouped_mus[i]))
else:
grouped_axs[i].copy_(grouped_params[i])

# update eta and mu
for i in range(len(grouped_mus)):
new_eta = _to_tensor(
lr / (1 + lambd * lr * _get_value(grouped_state_steps[i]) ** alpha)
)
grouped_etas[i].copy_(new_eta)
new_mu = _to_tensor(1 / max(1, _get_value(grouped_state_steps[i]) - t0))
grouped_mus[i].copy_(new_mu)
if capturable:
# decay term
decay = torch._foreach_add(torch._foreach_mul(grouped_etas, -lambd), 1)
# update parameter
torch._foreach_mul_(grouped_params,
torch._foreach_add(torch._foreach_div(torch._foreach_mul(grouped_grads,
torch._foreach_mul(grouped_etas, -1)),
torch._foreach_mul(grouped_params, decay)),
1.0))

torch._foreach_add_(grouped_axs, torch._foreach_mul(torch._foreach_sub(grouped_params, grouped_axs), grouped_mus))
# copy the updated mu values into grouped_mus
# these memory bound ops will be fused by the compiler so it doesn't matter
torch._foreach_copy_(grouped_mus,
torch._foreach_reciprocal(torch._foreach_maximum(torch._foreach_sub(grouped_state_steps, t0),
1.0)))
else:
# decay term
eta = _get_value(grouped_etas[0])
torch._foreach_mul_(grouped_params, 1 - lambd * eta)

# update parameter
torch._foreach_add_(grouped_params, grouped_grads, alpha=-eta)

# averaging
for i in range(len(grouped_axs)):
if is_compiling() or grouped_mus[i].item() != 1:
grouped_axs[i].add_(grouped_params[i].sub(grouped_axs[i]).mul(grouped_mus[i]))
else:
grouped_axs[i].copy_(grouped_params[i])

# update eta and mu
for i in range(len(grouped_mus)):
new_eta = _to_tensor(
lr / (1 + lambd * lr * _get_value(grouped_state_steps[i]) ** alpha)
)
grouped_etas[i].copy_(new_eta)
new_mu = _to_tensor(1 / max(1, _get_value(grouped_state_steps[i]) - t0))
grouped_mus[i].copy_(new_mu)
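For reference, the eager (non-capturable) branch above implements the per-parameter ASGD recurrence. Written out for a single parameter, a readability sketch that simply mirrors the expressions in the diff (not code from the PR):

def asgd_reference_step(param, grad, ax, eta, mu, step, lr, lambd, alpha, t0, weight_decay):
    if weight_decay != 0:
        grad = grad + weight_decay * param
    param.mul_(1 - lambd * eta)           # decay term
    param.add_(grad, alpha=-eta)          # gradient step
    if mu != 1:
        ax.add_((param - ax) * mu)        # running average of iterates
    else:
        ax.copy_(param)
    new_eta = lr / (1 + lambd * lr * step ** alpha)   # same expression as the foreach path above
    new_mu = 1 / max(1, step - t0)
    return new_eta, new_mu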