Disable grouping by dtype and device if compiling #102771

Closed · wants to merge 4 commits
Changes from all commits
torch/optim/adadelta.py (3 changes: 1 addition & 2 deletions)
@@ -3,7 +3,6 @@

from .optimizer import (Optimizer, _use_grad_for_differentiable, _default_to_fused_or_foreach,
_differentiable_doc, _foreach_doc, _maximize_doc)
- from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype
from typing import List, Optional

__all__ = ["Adadelta", "adadelta"]
@@ -276,7 +275,7 @@ def _multi_tensor_adadelta(
if len(params) == 0:
return

- grouped_tensors = _group_tensors_by_device_and_dtype([params, grads, square_avgs, acc_deltas])
+ grouped_tensors = Optimizer._group_tensors_by_device_and_dtype([params, grads, square_avgs, acc_deltas])
for device_params, device_grads, device_square_avgs, device_acc_deltas in grouped_tensors.values():
if maximize:
device_grads = torch._foreach_neg(device_grads)
torch/optim/adagrad.py (3 changes: 1 addition & 2 deletions)
@@ -3,7 +3,6 @@

from .optimizer import (Optimizer, _use_grad_for_differentiable, _get_value,
_default_to_fused_or_foreach, _differentiable_doc, _foreach_doc, _maximize_doc)
- from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype
from typing import List, Optional

__all__ = ["Adagrad", "adagrad"]
@@ -321,7 +320,7 @@ def _multi_tensor_adagrad(
if len(params) == 0:
return

- grouped_tensorlists = _group_tensors_by_device_and_dtype([params, grads, state_sums, state_steps])
+ grouped_tensorlists = Optimizer._group_tensors_by_device_and_dtype([params, grads, state_sums, state_steps])
for device_params, device_grads, device_state_sums, device_state_steps in grouped_tensorlists.values():

if maximize:
torch/optim/adam.py (7 changes: 4 additions & 3 deletions)
@@ -5,7 +5,6 @@
from .optimizer import (Optimizer, _use_grad_for_differentiable, _get_value, _stack_if_compiling,
_dispatch_sqrt, _default_to_fused_or_foreach, _capturable_doc,
_differentiable_doc, _foreach_doc, _fused_doc, _maximize_doc)
- from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype

__all__ = ['Adam', 'adam']

@@ -424,7 +423,8 @@ def _multi_tensor_adam(params: List[Tensor],

assert not differentiable, "_foreach ops don't support autograd"

- grouped_tensors = _group_tensors_by_device_and_dtype([params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps])
+ grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
+     [params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps])
for (device_params, device_grads, device_exp_avgs, device_exp_avg_sqs,
device_max_exp_avg_sqs, device_state_steps) in grouped_tensors.values():

@@ -532,7 +532,8 @@ def _fused_adam(
) -> None:
grad_scale_dict = {grad_scale.device: grad_scale} if grad_scale is not None else None
found_inf_dict = {found_inf.device: found_inf} if found_inf is not None else None
- grouped_tensors = _group_tensors_by_device_and_dtype([params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps])
+ grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
+     [params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps])
for (device, dtype) in grouped_tensors:
(
device_params,
torch/optim/adamax.py (3 changes: 1 addition & 2 deletions)
@@ -4,7 +4,6 @@
from .optimizer import (Optimizer, _use_grad_for_differentiable, _get_value, _stack_if_compiling,
_default_to_fused_or_foreach, _differentiable_doc, _maximize_doc, _foreach_doc)
from typing import List, Optional
- from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype

__all__ = ["Adamax", "adamax"]

@@ -305,7 +304,7 @@ def _multi_tensor_adamax(
if len(params) == 0:
return

- grouped_tensors = _group_tensors_by_device_and_dtype([params, grads, exp_avgs, exp_infs, state_steps])
+ grouped_tensors = Optimizer._group_tensors_by_device_and_dtype([params, grads, exp_avgs, exp_infs, state_steps])
for grouped_params, grouped_grads, grouped_exp_avgs, grouped_exp_infs, grouped_state_steps in grouped_tensors.values():
if maximize:
grouped_grads = torch._foreach_neg(grouped_grads)
torch/optim/adamw.py (6 changes: 3 additions & 3 deletions)
@@ -4,7 +4,6 @@
_stack_if_compiling, _capturable_doc, _differentiable_doc, _foreach_doc,
_fused_doc, _maximize_doc, _default_to_fused_or_foreach)
from typing import List, Optional
- from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype

__all__ = ["AdamW", "adamw"]

@@ -476,7 +475,7 @@ def _multi_tensor_adamw(

assert grad_scale is None and found_inf is None

- grouped_tensors = _group_tensors_by_device_and_dtype([
+ grouped_tensors = Optimizer._group_tensors_by_device_and_dtype([
params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps])
for (device_params, device_grads, device_exp_avgs, device_exp_avg_sqs,
device_max_exp_avg_sqs, device_state_steps) in grouped_tensors.values():
@@ -593,7 +592,8 @@ def _fused_adamw(
raise RuntimeError("_fused_adamw is not differentiable")
grad_scale_dict = {grad_scale.device: grad_scale} if grad_scale is not None else None
found_inf_dict = {found_inf.device: found_inf} if found_inf is not None else None
- grouped_tensors = _group_tensors_by_device_and_dtype([params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps])
+ grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
+     [params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps])
for (device, dtype) in grouped_tensors:
(
device_params,
torch/optim/asgd.py (3 changes: 1 addition & 2 deletions)
@@ -4,7 +4,6 @@
from .optimizer import (Optimizer, _use_grad_for_differentiable, _get_value, _default_to_fused_or_foreach,
_differentiable_doc, _foreach_doc, _maximize_doc)
from torch._utils import is_compiling
- from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype
from typing import List, Optional

__all__ = ["ASGD", "asgd"]
@@ -294,7 +293,7 @@ def _multi_tensor_asgd(

assert not differentiable, "_foreach ops don't support autograd"

- grouped_tensors = _group_tensors_by_device_and_dtype([params, grads, axs, mus, etas, state_steps])
+ grouped_tensors = Optimizer._group_tensors_by_device_and_dtype([params, grads, axs, mus, etas, state_steps])
for (grouped_params, grouped_grads, grouped_axs, grouped_mus,
grouped_etas, grouped_state_steps) in grouped_tensors.values():
if maximize:
torch/optim/nadam.py (4 changes: 1 addition & 3 deletions)
@@ -3,7 +3,6 @@
from .optimizer import (Optimizer, _use_grad_for_differentiable, _get_value, _dispatch_sqrt, _stack_if_compiling,
_differentiable_doc, _foreach_doc, _default_to_fused_or_foreach)
from typing import List, Optional
- from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype

__all__ = ['NAdam', 'nadam']

@@ -291,8 +290,7 @@ def _multi_tensor_nadam(params: List[Tensor],

assert not differentiable, "_foreach ops don't support autograd"

- grouped_tensors = _group_tensors_by_device_and_dtype([params, grads, exp_avgs, exp_avg_sqs,
-     mu_products, state_steps])
+ grouped_tensors = Optimizer._group_tensors_by_device_and_dtype([params, grads, exp_avgs, exp_avg_sqs, mu_products, state_steps])
for (grouped_params, grouped_grads, grouped_exp_avgs,
grouped_exp_avg_sqs, grouped_mu_products, grouped_state_steps) in grouped_tensors.values():

torch/optim/optimizer.py (13 changes: 13 additions & 0 deletions)
@@ -11,6 +11,7 @@
import torch.utils.hooks as hooks
from torch.utils.hooks import RemovableHandle
from torch._utils import is_compiling
+ from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype

__all__ = ['Optimizer', 'register_optimizer_step_pre_hook', 'register_optimizer_step_post_hook']
_global_optimizer_pre_hooks: Dict[int, Callable] = OrderedDict()
@@ -288,6 +289,18 @@ def wrapper(*args, **kwargs):

return wrapper

+    @staticmethod
+    def _group_tensors_by_device_and_dtype(tensorlistlist, with_indices=False):
"""Groups a list of lists of tensors by device and dtype.
Skips this step if we are compiling since this will occur during inductor lowering."""
if is_compiling():
if with_indices:
indices = list(range(len(tensorlistlist[0])))
tensorlistlist.append(indices)
return {(None, None): tensorlistlist}
else:
return _group_tensors_by_device_and_dtype(tensorlistlist, with_indices)

def _patch_step_function(self):
self._zero_grad_profile_name = "Optimizer.zero_grad#{}.zero_grad".format(self.__class__.__name__)
hooked = getattr(self.__class__.step, "hooked", None)
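For orientation (not part of the diff), here is a minimal sketch of how the call sites in this PR consume the new wrapper; the tensor names and the final _foreach_add_ step are illustrative assumptions, not code from the PR. In eager mode the helper buckets the tensor lists by (device, dtype); while torch.compile is tracing, is_compiling() is true and the lists come back untouched under a single (None, None) key, so the loop shape stays identical and the per-device/dtype split is left to Inductor lowering.

import torch
from torch.optim.optimizer import Optimizer

# Illustrative stand-ins for the params/grads a _multi_tensor_* impl receives.
params = [torch.randn(3) for _ in range(2)]
grads = [torch.ones_like(p) for p in params]

# Eager: {(device, dtype): [params_bucket, grads_bucket], ...}
# Compiling: {(None, None): [params, grads]} with no grouping done here.
grouped = Optimizer._group_tensors_by_device_and_dtype([params, grads])
for device_params, device_grads in grouped.values():
    torch._foreach_add_(device_params, device_grads, alpha=-0.1)

Because the return shape is the same in both modes, each optimizer in this PR only needs the one-line change from the free function to the Optimizer static method.
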
torch/optim/radam.py (3 changes: 1 addition & 2 deletions)
@@ -5,7 +5,6 @@
from .optimizer import (Optimizer, _use_grad_for_differentiable, _get_value, _dispatch_sqrt, _stack_if_compiling,
_default_to_fused_or_foreach, _differentiable_doc, _foreach_doc)
from typing import List, Optional
- from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype

__all__ = ["RAdam", "radam"]

@@ -315,7 +314,7 @@ def _multi_tensor_radam(

assert not differentiable, "_foreach ops don't support autograd"

- grouped_tensors = _group_tensors_by_device_and_dtype([params, grads, exp_avgs, exp_avg_sqs, state_steps])
+ grouped_tensors = Optimizer._group_tensors_by_device_and_dtype([params, grads, exp_avgs, exp_avg_sqs, state_steps])
for grouped_params, grouped_grads, grouped_exp_avgs, grouped_exp_avg_sqs, grouped_state_steps in grouped_tensors.values():
# Update steps
torch._foreach_add_(grouped_state_steps, 1)
torch/optim/rmsprop.py (3 changes: 1 addition & 2 deletions)
@@ -3,7 +3,6 @@
from .optimizer import (Optimizer, _default_to_fused_or_foreach, _use_grad_for_differentiable,
_differentiable_doc, _foreach_doc, _maximize_doc)
from typing import List, Optional
- from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype

__all__ = ["RMSprop", "rmsprop"]

@@ -326,7 +325,7 @@ def _multi_tensor_rmsprop(

assert not differentiable, "_foreach ops don't support autograd"

- grouped_tensors = _group_tensors_by_device_and_dtype([params, grads, square_avgs, grad_avgs, momentum_buffer_list])
+ grouped_tensors = Optimizer._group_tensors_by_device_and_dtype([params, grads, square_avgs, grad_avgs, momentum_buffer_list])
for (grouped_params, grouped_grads, grouped_square_avgs, grouped_grad_avgs,
grouped_momentum_buffer_list) in grouped_tensors.values():
if maximize:
torch/optim/rprop.py (3 changes: 1 addition & 2 deletions)
@@ -3,7 +3,6 @@
from .optimizer import (Optimizer, _use_grad_for_differentiable, _default_to_fused_or_foreach,
_differentiable_doc, _foreach_doc, _maximize_doc)
from typing import List, Optional
- from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype

__all__ = ["Rprop", "rprop"]

@@ -281,7 +280,7 @@ def _multi_tensor_rprop(

assert not differentiable, "_foreach ops don't support autograd"

- grouped_tensors = _group_tensors_by_device_and_dtype([params, grads, prevs, step_sizes])
+ grouped_tensors = Optimizer._group_tensors_by_device_and_dtype([params, grads, prevs, step_sizes])
for grouped_params, grouped_grads, grouped_prevs, grouped_step_sizes in grouped_tensors.values():
# Handle complex params
def _view_complex_as_real(tensor_list):
torch/optim/sgd.py (3 changes: 1 addition & 2 deletions)
@@ -3,7 +3,6 @@
from .optimizer import (Optimizer, required, _use_grad_for_differentiable, _default_to_fused_or_foreach,
_differentiable_doc, _foreach_doc, _maximize_doc)
from typing import List, Optional
- from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype

__all__ = ['SGD', 'sgd']

@@ -280,7 +279,7 @@ def _multi_tensor_sgd(params: List[Tensor],
if len(params) == 0:
return

- grouped_tensors = _group_tensors_by_device_and_dtype([params, grads, momentum_buffer_list], with_indices=True)
+ grouped_tensors = Optimizer._group_tensors_by_device_and_dtype([params, grads, momentum_buffer_list], with_indices=True)
for device_params, device_grads, device_momentum_buffer_list, indices in grouped_tensors.values():
device_has_sparse_grad = any(grad.is_sparse for grad in device_grads)

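
The sgd.py hunk above is the only call site that passes with_indices=True, which is why the compile branch of the new method appends list(range(len(tensorlistlist[0]))). A small hypothetical sketch, assuming two CPU parameters and lazily created momentum buffers as in _multi_tensor_sgd, shows that the four-way unpacking and the write-back into the original momentum_buffer_list work the same whether or not any real grouping happened:

import torch
from torch.optim.optimizer import Optimizer

# Illustrative inputs; in sgd.py these come from the optimizer's param groups.
params = [torch.randn(4) for _ in range(2)]
grads = [torch.ones_like(p) for p in params]
momentum_buffer_list = [None, None]  # buffers are created lazily, as in sgd.py

grouped = Optimizer._group_tensors_by_device_and_dtype(
    [params, grads, momentum_buffer_list], with_indices=True)
for device_params, device_grads, device_buffers, indices in grouped.values():
    for j, i in enumerate(indices):
        if device_buffers[j] is None:
            # `indices` maps each grouped position back to its slot in the
            # original lists, so the new buffer lands in the right place.
            momentum_buffer_list[i] = torch.clone(device_grads[j]).detach()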