[optim] Merge the pyi files into py files of optimizer (#125452)
Continue the work of #125153
Pull Request resolved: #125452
Approved by: https://github.com/janeyx99
david20571015 authored and pytorchmergebot committed May 14, 2024
1 parent a00a99e commit 1a28f73
Showing 28 changed files with 284 additions and 421 deletions.
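The pattern is the same across the 28 files: type annotations that previously lived in a sibling .pyi stub move onto the .py implementation, and the stub is deleted. Below is a minimal sketch of the merged form, mirroring the Adadelta hunk further down; ToyAdadelta and the trimmed parameter list are illustrative only, and ParamsT is the params type alias the diffs import from torch.optim.optimizer.

from typing import Optional

from torch.optim.optimizer import Optimizer, ParamsT


class ToyAdadelta(Optimizer):
    """Illustrative subset of the Adadelta signature; annotations now live inline."""

    def __init__(
        self,
        params: ParamsT,
        lr: float = 1.0,
        rho: float = 0.9,
        eps: float = 1e-6,
        weight_decay: float = 0,
        foreach: Optional[bool] = None,
        *,
        capturable: bool = False,
    ) -> None:
        defaults = dict(
            lr=lr,
            rho=rho,
            eps=eps,
            weight_decay=weight_decay,
            foreach=foreach,
            capturable=capturable,
        )
        super().__init__(params, defaults)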
28 changes: 14 additions & 14 deletions torch/optim/__init__.py
@@ -22,17 +22,17 @@
from .sgd import SGD
from .sparse_adam import SparseAdam

del adadelta # noqa: F821
del adagrad # noqa: F821
del adam # noqa: F821
del adamw # noqa: F821
del sparse_adam # noqa: F821
del adamax # noqa: F821
del asgd # noqa: F821
del sgd # noqa: F821
del radam # noqa: F821
del rprop # noqa: F821
del rmsprop # noqa: F821
del optimizer # noqa: F821
del nadam # noqa: F821
del lbfgs # noqa: F821
del adadelta # type: ignore[name-defined] # noqa: F821
del adagrad # type: ignore[name-defined] # noqa: F821
del adam # type: ignore[name-defined] # noqa: F821
del adamw # type: ignore[name-defined] # noqa: F821
del sparse_adam # type: ignore[name-defined] # noqa: F821
del adamax # type: ignore[name-defined] # noqa: F821
del asgd # type: ignore[name-defined] # noqa: F821
del sgd # type: ignore[name-defined] # noqa: F821
del radam # type: ignore[name-defined] # noqa: F821
del rprop # type: ignore[name-defined] # noqa: F821
del rmsprop # type: ignore[name-defined] # noqa: F821
del optimizer # type: ignore[name-defined] # noqa: F821
del nadam # type: ignore[name-defined] # noqa: F821
del lbfgs # type: ignore[name-defined] # noqa: F821
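The only change in __init__.py is appending # type: ignore[name-defined] to each del. A plausible reading: while the __init__.pyi stub existed it shadowed this file, so mypy never analyzed the del statements; with the stub gone, mypy checks __init__.py directly and, like flake8 (hence the pre-existing # noqa: F821), it cannot see that a from .adadelta import Adadelta style import also binds the submodule name adadelta in the package namespace at runtime. A self-contained toy demonstration of that runtime binding (made-up toy_optim package, not the real layout):

import importlib
import sys
import tempfile
from pathlib import Path

# Build a throwaway package whose __init__ mimics the torch/optim pattern.
tmp = Path(tempfile.mkdtemp())
pkg = tmp / "toy_optim"
pkg.mkdir()
(pkg / "sub.py").write_text("class Thing:\n    pass\n")
(pkg / "__init__.py").write_text(
    "from .sub import Thing\n"
    "print('sub' in globals())  # True: the from-import bound the submodule name\n"
    "del sub  # type: ignore[name-defined]  # noqa: F821\n"
)

sys.path.insert(0, str(tmp))
toy = importlib.import_module("toy_optim")  # prints: True
print(hasattr(toy, "sub"))  # False: the submodule name was cleaned up by `del`
print(toy.Thing)            # the re-exported class is still available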
15 changes: 0 additions & 15 deletions torch/optim/__init__.pyi

This file was deleted.

8 changes: 4 additions & 4 deletions torch/optim/adadelta.py
@@ -25,10 +25,10 @@ class Adadelta(Optimizer):
def __init__(
self,
params: ParamsT,
lr=1.0,
rho=0.9,
eps=1e-6,
weight_decay=0,
lr: float = 1.0,
rho: float = 0.9,
eps: float = 1e-6,
weight_decay: float = 0,
foreach: Optional[bool] = None,
*,
capturable: bool = False,
10 changes: 5 additions & 5 deletions torch/optim/adagrad.py
@@ -23,11 +23,11 @@ class Adagrad(Optimizer):
def __init__(
self,
params: ParamsT,
lr=1e-2,
lr_decay=0,
weight_decay=0,
initial_accumulator_value=0,
eps=1e-10,
lr: float = 1e-2,
lr_decay: float = 0,
weight_decay: float = 0,
initial_accumulator_value: float = 0,
eps: float = 1e-10,
foreach: Optional[bool] = None,
*,
maximize: bool = False,
49 changes: 27 additions & 22 deletions torch/optim/adam.py
@@ -18,6 +18,7 @@
_stack_if_compiling,
_use_grad_for_differentiable,
_view_as_real,
DeviceDict,
Optimizer,
ParamsT,
)
@@ -203,12 +204,12 @@ def step(self, closure=None):
loss = closure()

for group in self.param_groups:
params_with_grad = []
grads = []
exp_avgs = []
exp_avg_sqs = []
max_exp_avg_sqs = []
state_steps = []
params_with_grad: List[Tensor] = []
grads: List[Tensor] = []
exp_avgs: List[Tensor] = []
exp_avg_sqs: List[Tensor] = []
max_exp_avg_sqs: List[Tensor] = []
state_steps: List[Tensor] = []
beta1, beta2 = group["betas"]

has_complex = self._init_group(
@@ -506,7 +507,7 @@ def _multi_tensor_adam(
)

if maximize:
device_grads = torch._foreach_neg(device_grads)
device_grads = torch._foreach_neg(device_grads) # type: ignore[assignment]

# Update steps
# If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
@@ -524,7 +525,7 @@
if maximize:
torch._foreach_add_(device_grads, device_params, alpha=weight_decay)
else:
device_grads = torch._foreach_add(
device_grads = torch._foreach_add( # type: ignore[assignment]
device_grads, device_params, alpha=weight_decay
)

@@ -539,6 +540,9 @@
# Delete the local intermediate since it won't be used anymore to save on peak memory
del device_grads

bias_correction1: Union[Tuple[Tensor, ...], List[Tensor]]
bias_correction2: Union[Tuple[Tensor, ...], List[Tensor]]
bias_correction2_sqrt: Union[Tuple[Tensor, ...], List[Tensor]]
if capturable:
bias_correction1 = torch._foreach_pow(beta1, device_state_steps)
bias_correction2 = torch._foreach_pow(beta2, device_state_steps)
@@ -585,7 +589,7 @@ def _multi_tensor_adam(

step_size = _stack_if_compiling([(lr / bc) * -1 for bc in bias_correction1])

bias_correction2_sqrt = [_dispatch_sqrt(bc) for bc in bias_correction2]
bias_correction2_sqrt = [_dispatch_sqrt(bc) for bc in bias_correction2] # type: ignore[arg-type]

if amsgrad:
# Maintains the maximum of all 2nd moment running avg. till now
@@ -599,7 +603,7 @@
torch._foreach_div_(exp_avg_sq_sqrt, bias_correction2_sqrt)
torch._foreach_add_(exp_avg_sq_sqrt, eps)
torch._foreach_addcdiv_(
device_params, device_exp_avgs, exp_avg_sq_sqrt, step_size
device_params, device_exp_avgs, exp_avg_sq_sqrt, step_size # type: ignore[arg-type]
)


@@ -629,17 +633,18 @@ def _fused_adam(
if differentiable:
raise RuntimeError("Adam with fused=True does not support differentiable=True")

grad_scale_dict = (
{grad_scale.device: grad_scale} if grad_scale is not None else None
grad_scale_dict: DeviceDict = (
{grad_scale.device: grad_scale} if grad_scale is not None else {}
)
found_inf_dict: DeviceDict = (
{found_inf.device: found_inf} if found_inf is not None else {}
)
found_inf_dict = {found_inf.device: found_inf} if found_inf is not None else None

# We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer
# treating it as a scalar.
lr_dict = (
lr_dict: Optional[DeviceDict] = (
{lr.device: lr} if isinstance(lr, Tensor) and str(lr.device) != "cpu" else None
)

grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
[params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps]
)
@@ -656,15 +661,15 @@
) in grouped_tensors.items():
device_grad_scale, device_found_inf = None, None
if grad_scale is not None:
if device not in grad_scale_dict:
grad_scale_dict[device] = grad_scale.to(device, non_blocking=True)
device_grad_scale = grad_scale_dict[device]
device_grad_scale = grad_scale_dict.setdefault(
device, grad_scale.to(device, non_blocking=True)
)
if found_inf is not None:
if found_inf not in found_inf_dict:
found_inf_dict[device] = found_inf.to(device, non_blocking=True)
device_found_inf = found_inf_dict[device]
device_found_inf = found_inf_dict.setdefault(
device, found_inf.to(device, non_blocking=True)
)
if lr_dict is not None and device not in lr_dict:
lr_dict[device] = lr.to(device=device, non_blocking=True)
lr_dict[device] = lr.to(device=device, non_blocking=True) # type: ignore[union-attr]
lr = lr_dict[device]
torch._foreach_add_(device_state_steps, 1)
torch._fused_adam_(
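In _fused_adam, the per-device caches switch from None defaults plus membership checks to empty dicts annotated with DeviceDict (presumably a Dict[torch.device, Tensor] alias in torch.optim.optimizer) and dict.setdefault; only lr_dict stays Optional[DeviceDict], since a plain-float lr never needs a cache. A small sketch of the idiom with assumed names:

from typing import Dict

import torch
from torch import Tensor

DeviceDictSketch = Dict[torch.device, Tensor]  # stand-in for the real DeviceDict alias

grad_scale = torch.ones(())          # stand-in for an amp GradScaler scale tensor
grad_scale_dict: DeviceDictSketch = {}

for device in (torch.device("cpu"), torch.device("cpu")):
    # setdefault returns the cached copy when this device was seen before; note
    # that the `.to(...)` argument is still evaluated eagerly on every call.
    device_grad_scale = grad_scale_dict.setdefault(
        device, grad_scale.to(device, non_blocking=True)
    )
    print(device, device_grad_scale is grad_scale_dict[device])  # True both times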
22 changes: 0 additions & 22 deletions torch/optim/adam.pyi

This file was deleted.

32 changes: 17 additions & 15 deletions torch/optim/adamax.py
@@ -1,4 +1,4 @@
from typing import List, Optional
from typing import List, Optional, Tuple, Union

import torch
from torch import Tensor
@@ -16,6 +16,7 @@
_use_grad_for_differentiable,
_view_as_real,
Optimizer,
ParamsT,
)

__all__ = ["Adamax", "adamax"]
@@ -24,11 +25,11 @@
class Adamax(Optimizer):
def __init__(
self,
params,
lr=2e-3,
betas=(0.9, 0.999),
eps=1e-8,
weight_decay=0,
params: ParamsT,
lr: float = 2e-3,
betas: Tuple[float, float] = (0.9, 0.999),
eps: float = 1e-8,
weight_decay: float = 0,
foreach: Optional[bool] = None,
*,
maximize: bool = False,
@@ -128,11 +129,11 @@ def step(self, closure=None):
loss = closure()

for group in self.param_groups:
params_with_grad = []
grads = []
exp_avgs = []
exp_infs = []
state_steps = []
params_with_grad: List[Tensor] = []
grads: List[Tensor] = []
exp_avgs: List[Tensor] = []
exp_infs: List[Tensor] = []
state_steps: List[Tensor] = []

beta1, beta2 = group["betas"]
eps = group["eps"]
@@ -298,11 +299,11 @@ def _multi_tensor_adamax(
exp_infs: List[Tensor],
state_steps: List[Tensor],
*,
eps: float,
beta1: float,
beta2: float,
lr: float,
weight_decay: float,
eps: float,
maximize: bool,
differentiable: bool,
capturable: bool,
@@ -340,7 +341,7 @@ def _multi_tensor_adamax(
)

if maximize:
grouped_grads = torch._foreach_neg(grouped_grads)
grouped_grads = torch._foreach_neg(grouped_grads) # type: ignore[assignment]

# Update steps
# If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
@@ -358,7 +359,7 @@
# Re-use the intermediate memory (grouped_grads) already allocated for maximize
torch._foreach_add_(grouped_grads, grouped_params, alpha=weight_decay)
else:
grouped_grads = torch._foreach_add(
grouped_grads = torch._foreach_add( # type: ignore[assignment]
grouped_grads, grouped_params, alpha=weight_decay
)

@@ -371,13 +372,14 @@
# in this case, we need to introduce a copy of the grads
# since one has not been introduced previously
if not maximize and weight_decay == 0:
grouped_grads = torch._foreach_abs(grouped_grads)
grouped_grads = torch._foreach_abs(grouped_grads) # type: ignore[assignment]
else:
torch._foreach_abs_(grouped_grads)

torch._foreach_add_(grouped_grads, eps)
torch._foreach_maximum_(grouped_exp_infs, grouped_grads)

bias_corrections: Union[Tuple[Tensor, ...], List[Tensor]]
if capturable:
bias_corrections = torch._foreach_pow(beta1, grouped_state_steps)
# foreach_sub doesn't allow a scalar as the first arg
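The remaining annotations in adam.py and adamax.py follow one pattern: the out-of-place torch._foreach_* ops appear to be typed in the generated stubs as returning a tuple of tensors, while the hand-built alternatives are Python lists, so locals fed by either branch (bias_correction1, bias_corrections, ...) get a Union[Tuple[Tensor, ...], List[Tensor]] annotation, and reassignments such as grouped_grads = torch._foreach_neg(grouped_grads) carry # type: ignore[assignment]. A sketch under those assumptions (toy values, not the real bias-correction code):

from typing import List, Tuple, Union

import torch
from torch import Tensor

steps: List[Tensor] = [torch.tensor(3.0), torch.tensor(5.0)]
beta1 = 0.9
capturable = False

bias_corrections: Union[Tuple[Tensor, ...], List[Tensor]]
if capturable:
    # foreach path: typed as Tuple[Tensor, ...] in the stubs
    bias_corrections = torch._foreach_pow(beta1, steps)
else:
    # scalar path: builds a plain Python list of tensors
    bias_corrections = [torch.tensor(1 - beta1 ** t.item()) for t in steps]

print([float(b) for b in bias_corrections])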
13 changes: 0 additions & 13 deletions torch/optim/adamax.pyi

This file was deleted.
