feat(optim): Add NAdam support for complex, with has_complex shortcut (#110634)

Partial fix: #110606

More on `has_complex` shortcut: #110613 (comment)

CC: @janeyx99 @mlazos @lezcano
Pull Request resolved: #110634
Approved by: https://github.com/lezcano
jon-chuang authored and pytorchmergebot committed Oct 6, 2023
1 parent 347ea3f commit 11047be
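
For background on the approach this commit takes: `torch.view_as_real` reinterprets a complex tensor as a real tensor with a trailing dimension of size 2 that shares the same storage, so in-place updates made through the real view land in the complex parameter. A minimal standalone sketch (not part of this diff):

import torch

p = torch.randn(3, dtype=torch.complex64)
r = torch.view_as_real(p)  # real view of shape (3, 2), same storage as p
r.mul_(0.5)                # in-place update through the real view
print(p)                   # the complex tensor reflects the update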
Showing 2 changed files with 50 additions and 5 deletions.
test/optim/test_optim.py: 22 additions, 0 deletions
@@ -1398,6 +1398,28 @@ def test_nadam(self):
         with self.assertRaisesRegex(ValueError, "Invalid momentum_decay value: -0.2"):
             optim.NAdam(None, lr=1e-2, momentum_decay=-0.2)
 
+    def test_nadam_complex(self):
+        for foreach in (False, True):
+            self._test_complex_optimizer(
+                lambda param: optim.NAdam([param], lr=1e-1, foreach=foreach)
+            )
+            self._test_complex_optimizer(
+                lambda param: optim.NAdam(
+                    [param],
+                    lr=1e-1,
+                    weight_decay=0.01,
+                    foreach=foreach,
+                )
+            )
+            self._test_complex_optimizer(
+                lambda param: optim.NAdam(
+                    [param],
+                    lr=1e-1,
+                    momentum_decay=0.01,
+                    foreach=foreach,
+                )
+            )
+
     def test_adagrad(self):
         self._test_basic_cases(
             lambda weight, bias, maximize, foreach: optim.Adagrad(
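The new test exercises `_test_complex_optimizer` for the default, weight-decay, and momentum-decay configurations under both `foreach` modes. As a rough standalone analogue of what that helper checks (the helper itself lives in the test suite; this sketch is illustrative, not the test code):

import torch
from torch import optim

# A complex parameter; with this commit, NAdam updates it via its real view.
param = torch.randn(5, dtype=torch.complex64, requires_grad=True)
opt = optim.NAdam([param], lr=1e-1)

for _ in range(3):
    opt.zero_grad()
    loss = (param * param.conj()).real.sum()  # real-valued loss of a complex parameter
    loss.backward()
    opt.step()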
torch/optim/nadam.py: 28 additions, 5 deletions
@@ -47,8 +47,10 @@ def __setstate__(self, state):
                 s['mu_product'] = torch.tensor(s['mu_product'])
 
     def _init_group(self, group, params_with_grad, grads, exp_avgs, exp_avg_sqs, mu_products, state_steps):
+        has_complex = False
         for p in group['params']:
             if p.grad is not None:
+                has_complex |= torch.is_complex(p)
                 params_with_grad.append(p)
                 if p.grad.is_sparse:
                     raise RuntimeError('NAdam does not support sparse gradients')
@@ -77,6 +79,7 @@ def _init_group(self, group, params_with_grad, grads, exp_avgs, exp_avg_sqs, mu_products, state_steps):
                 exp_avg_sqs.append(state['exp_avg_sq'])
                 mu_products.append(state['mu_product'])
                 state_steps.append(state['step'])
+        return has_complex
 
     @_use_grad_for_differentiable
     def step(self, closure=None):
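`_init_group` now aggregates a single `has_complex` flag per parameter group during the scan it already performs, which is the shortcut referenced in the commit message: when the flag is False, the per-tensor `is_complex`/`view_as_real` pass in the multi-tensor path below is skipped entirely. A minimal sketch of the pattern, with illustrative names that are not the module's code:

import torch

def init_group_flag(params):
    has_complex = False
    for p in params:
        has_complex |= torch.is_complex(p)  # one cheap check per param, OR-ed together
    return has_complex

def to_real_views(params, has_complex):
    if not has_complex:  # the shortcut: all-real groups skip the pass
        return list(params)
    return [torch.view_as_real(p) if torch.is_complex(p) else p for p in params]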
@@ -102,7 +105,7 @@ def step(self, closure=None):
             state_steps = []
             beta1, beta2 = group['betas']
 
-            self._init_group(group, params_with_grad, grads, exp_avgs, exp_avg_sqs, mu_products, state_steps)
+            has_complex = self._init_group(group, params_with_grad, grads, exp_avgs, exp_avg_sqs, mu_products, state_steps)
 
             nadam(params_with_grad,
                   grads,
@@ -119,7 +122,8 @@ def step(self, closure=None):
                   decoupled_weight_decay=group['decoupled_weight_decay'],
                   foreach=group['foreach'],
                   capturable=group['capturable'],
-                  differentiable=group['differentiable'])
+                  differentiable=group['differentiable'],
+                  has_complex=has_complex)
 
         return loss
 
@@ -195,6 +199,7 @@ def nadam(params: List[Tensor],
          foreach: Optional[bool] = None,
          capturable: bool = False,
          differentiable: bool = False,
+         has_complex: bool = False,
          *,
          beta1: float,
          beta2: float,
@@ -239,7 +244,8 @@ def nadam(params: List[Tensor],
         decoupled_weight_decay=decoupled_weight_decay,
         eps=eps,
         capturable=capturable,
-        differentiable=differentiable)
+        differentiable=differentiable,
+        has_complex=has_complex)
 
 
 def _single_tensor_nadam(params: List[Tensor],
@@ -257,7 +263,8 @@ def _single_tensor_nadam(params: List[Tensor],
                          eps: float,
                          decoupled_weight_decay: bool,
                          capturable: bool,
-                         differentiable: bool):
+                         differentiable: bool,
+                         has_complex: bool):
 
     for i, param in enumerate(params):
         grad = grads[i]
@@ -266,6 +273,12 @@ def _single_tensor_nadam(params: List[Tensor],
         mu_product = mu_products[i]
         step_t = state_steps[i]
 
+        if torch.is_complex(param):
+            param = torch.view_as_real(param)
+            grad = torch.view_as_real(grad)
+            exp_avg = torch.view_as_real(exp_avg)
+            exp_avg_sq = torch.view_as_real(exp_avg_sq)
+
         # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
         if not torch._utils.is_compiling() and capturable:
             assert (
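The single-tensor loop simply rebinds its locals to real views; the in-place state updates further down in this function (not shown in this hunk) then write real-dtype data through those views into the complex tensors. A small standalone demonstration of that aliasing, with illustrative values:

import torch

exp_avg = torch.zeros(4, dtype=torch.complex64)
grad = torch.full((4,), 1 + 1j, dtype=torch.complex64)

exp_avg_r = torch.view_as_real(exp_avg)  # rebind to real views, as the loop does
grad_r = torch.view_as_real(grad)

exp_avg_r.lerp_(grad_r, 0.1)  # real-dtype in-place update
print(exp_avg)                # the complex exp_avg has moved toward grad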
@@ -333,7 +346,8 @@ def _multi_tensor_nadam(params: List[Tensor],
                         eps: float,
                         decoupled_weight_decay: bool,
                         capturable: bool,
-                        differentiable: bool):
+                        differentiable: bool,
+                        has_complex: bool):
 
     if len(params) == 0:
         return
@@ -351,6 +365,15 @@ def _multi_tensor_nadam(params: List[Tensor],
     for ((grouped_params, grouped_grads, grouped_exp_avgs,
          grouped_exp_avg_sqs, grouped_mu_products, grouped_state_steps), _) in grouped_tensors.values():
 
+        # handle complex
+        if has_complex:
+            for i in range(len(grouped_params)):
+                if torch.is_complex(grouped_params[i]):
+                    grouped_params[i] = torch.view_as_real(grouped_params[i])
+                    grouped_grads[i] = torch.view_as_real(grouped_grads[i])
+                    grouped_exp_avgs[i] = torch.view_as_real(grouped_exp_avgs[i])
+                    grouped_exp_avg_sqs[i] = torch.view_as_real(grouped_exp_avg_sqs[i])
+
         # update steps
         torch._foreach_add_(grouped_state_steps, 1)
 
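In the multi-tensor path, rebinding the list entries is enough because the `torch._foreach_*` ops that follow mutate each tensor in the list in place, and those tensors are views into the original complex state. A standalone sketch of the same mechanism (note `_foreach_mul_` is a private PyTorch API, just as the `_foreach_add_` above):

import torch

params = [torch.ones(2, dtype=torch.complex64), torch.ones(2)]
views = [torch.view_as_real(p) if torch.is_complex(p) else p for p in params]

torch._foreach_mul_(views, 2.0)  # in-place over the whole list
print(params[0])                 # tensor([2.+0.j, 2.+0.j]); the complex original updated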
