Meta register all foreach ops #112281

Closed · wants to merge 19 commits into from

330 changes: 182 additions & 148 deletions torch/_meta_registrations.py
@@ -1,5 +1,6 @@
import math
from enum import Enum
from functools import partial
from typing import List, Optional, Sequence, Tuple, Union

import torch
@@ -2938,37 +2939,166 @@ def meta_addbmm(self, batch1, batch2, *, beta=1, alpha=1):
return self.new_empty(self.size())


@register_meta(
def register_meta_foreach(ops):
def wrapper(fn):
def register(op):
op_name = str(op).split(".")[1]
scalar_op = getattr(aten, op_name.replace("_foreach_", ""))

_add_op_to_registry(
meta_table,
op,
partial(
fn,
_scalar_op=scalar_op,
),
)

pytree.tree_map_(register, ops)
return fn

return wrapper
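

# A minimal standalone sketch (plain Python, no torch) of the name mangling that
# register_meta_foreach relies on: the scalar counterpart of a foreach op is
# found by stripping the "_foreach_" prefix from the overload packet name.
# The helper name below is hypothetical and only illustrative.
def _scalar_op_name_sketch(foreach_op_qualname):
    # e.g. "aten._foreach_abs" -> "_foreach_abs" -> "abs"
    op_name = foreach_op_qualname.split(".")[1]
    return op_name.replace("_foreach_", "")

# _scalar_op_name_sketch("aten._foreach_abs") == "abs"
# _scalar_op_name_sketch("aten._foreach_clamp_min_") == "clamp_min_"  (in-place ops keep the trailing underscore)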


@register_meta_foreach(
[
aten._foreach_abs_.default,
aten._foreach_neg_.default,
aten._foreach_reciprocal_.default,
aten._foreach_sqrt_.default,
aten._foreach_sign_.default,
]
aten._foreach_abs,
aten._foreach_acos,
aten._foreach_asin,
aten._foreach_atan,
aten._foreach_ceil,
aten._foreach_cos,
aten._foreach_cosh,
aten._foreach_erf,
aten._foreach_erfc,
aten._foreach_exp,
aten._foreach_expm1,
aten._foreach_frac,
aten._foreach_floor,
aten._foreach_lgamma,
aten._foreach_log,
aten._foreach_log10,
aten._foreach_log1p,
aten._foreach_log2,
aten._foreach_neg,
aten._foreach_reciprocal,
aten._foreach_round,
aten._foreach_sigmoid,
aten._foreach_sign,
aten._foreach_sin,
aten._foreach_sinh,
aten._foreach_sqrt,
aten._foreach_tan,
aten._foreach_tanh,
aten._foreach_trunc,
aten._foreach_zero,
aten._foreach_add,
aten._foreach_sub,
aten._foreach_mul,
aten._foreach_div,
aten._foreach_clamp_min,
aten._foreach_clamp_max,
aten._foreach_lerp,
],
)
def meta__foreach_unaop_(self):
def _meta_foreach_out_of_place(*args, _scalar_op=None, **kwargs):
torch._check(
isinstance(self, List),
lambda: f"Expect List[Tensor] but got {type(self)}",
isinstance(args[0], list),
lambda: (f"The first argument must be List[Tensor], but got {type(args[0])}."),
)

nelem = len(args[0])
torch._check(
nelem > 0,
lambda: ("Tensor list must have at least one tensor."),
)

@register_meta(
nlists = 1
for iarg, arg in enumerate(args[1:]):
if isinstance(arg, list):
nlists += 1
torch._check(
len(arg) == nelem,
lambda: (
f"self and argument-{iarg+2} must match in length, "
f"but got {nelem} and {len(arg)}."
),
)
elif isinstance(arg, Tensor):
torch._check(
arg.dim() == 0 and arg.numel() == 1,
lambda: (
"scalar tensor expected to be 0 dim but it has "
f"{arg.dim()} dimensions and {arg.numel()} elements."
),
)
else:
break

result = []
for elem in range(nelem):
each_args = [args[i][elem] for i in range(nlists)]
result.append(_scalar_op(*each_args, *args[nlists:], **kwargs))

return result
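

# The loop above amounts to this standalone sketch (no torch; the 0-dim
# scalar-tensor validation from the real code is omitted): every leading list
# argument is zipped element-by-element into the scalar op, and any trailing
# non-list arguments (alpha, scalars, ...) are forwarded unchanged.
def _foreach_apply_sketch(scalar_op, *args, **kwargs):
    assert isinstance(args[0], list) and len(args[0]) > 0
    nlists = 1
    for arg in args[1:]:
        if isinstance(arg, list):
            assert len(arg) == len(args[0])
            nlists += 1
        else:
            break  # first non-list argument ends the list section
    return [
        scalar_op(*(args[i][j] for i in range(nlists)), *args[nlists:], **kwargs)
        for j in range(len(args[0]))
    ]

# e.g. _foreach_apply_sketch(lambda a, b, alpha=1: a + alpha * b, [1, 2], [3, 4], alpha=10) == [31, 42]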


@register_meta_foreach(
[
aten._foreach_abs.default,
aten._foreach_neg.default,
aten._foreach_reciprocal.default,
aten._foreach_sqrt.default,
aten._foreach_sign.default,
aten._foreach_abs_,
aten._foreach_acos_,
aten._foreach_asin_,
aten._foreach_atan_,
aten._foreach_ceil_,
aten._foreach_cos_,
aten._foreach_cosh_,
aten._foreach_erf_,
aten._foreach_erfc_,
aten._foreach_exp_,
aten._foreach_expm1_,
aten._foreach_frac_,
aten._foreach_floor_,
aten._foreach_lgamma_,
aten._foreach_log_,
aten._foreach_log10_,
aten._foreach_log1p_,
aten._foreach_log2_,
aten._foreach_neg_,
aten._foreach_reciprocal_,
aten._foreach_round_,
aten._foreach_sigmoid_,
aten._foreach_sign_,
aten._foreach_sin_,
aten._foreach_sinh_,
aten._foreach_sqrt_,
aten._foreach_tan_,
aten._foreach_tanh_,
aten._foreach_trunc_,
aten._foreach_zero_,
aten._foreach_add_,
aten._foreach_sub_,
aten._foreach_mul_,
aten._foreach_div_,
aten._foreach_clamp_min_,
aten._foreach_clamp_max_,
aten._foreach_lerp_,
aten._foreach_copy_,
]
)
def meta__foreach_unaop(self):
def _meta_foreach_inplace(*args, _scalar_op=None, **kwargs):
_meta_foreach_out_of_place(*args, _scalar_op=_scalar_op, **kwargs)
return
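

# A hedged usage sketch (assumes a PyTorch build that includes these
# registrations): with the two generic handlers above, foreach ops can run on
# the "meta" device purely for shape/dtype inference, e.g.
#
#   xs = [torch.empty(2, 3, device="meta"), torch.empty(5, device="meta")]
#   ys = torch._foreach_abs(xs)   # out-of-place: new meta tensors, same shapes
#   torch._foreach_abs_(xs)       # in-place: returns nothing, only metadata is checked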


@register_meta([aten._foreach_pow.ScalarAndTensor])
def meta__foreach_pow_scalar_and_tensor(self, exponent):
# Only foreach_pow has a ScalarAndTensor method and needs special
# handling because it does not work with _meta_foreach_out_of_place.
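    # A hedged usage note: this overload takes a Python scalar base and a list
    # of exponent tensors, e.g. torch._foreach_pow(2.0, [t1, t2]) produces one
    # output per exponent, so the result shapes follow `exponent` below.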
torch._check(
isinstance(self, List),
lambda: f"Expect List[Tensor] but got {type(self)}",
isinstance(exponent, List),
lambda: f"exponent must be a tensor list but got {type(exponent)}",
)
return [torch.empty_like(s) for s in self]
return [torch.empty_like(e) for e in exponent]


def _check_foreach_binop_tensor_lists(self, other):
@@ -2990,130 +3120,25 @@ def _check_foreach_binop_tensor_lists(self, other):

@register_meta(
[
aten._foreach_add.List,
aten._foreach_sub.List,
aten._foreach_mul.List,
aten._foreach_div.List,
aten._foreach_maximum.List,
aten._foreach_minimum.List,
aten._foreach_clamp_min.List,
aten._foreach_clamp_max.List,
]
)
def meta__foreach_binop_list(self, other, alpha=1):
_check_foreach_binop_tensor_lists(self, other)
return [torch.empty_like(s) for s in self]


@register_meta(
[
aten._foreach_add_.List,
aten._foreach_sub_.List,
aten._foreach_mul_.List,
aten._foreach_div_.List,
aten._foreach_maximum_.List,
aten._foreach_minimum_.List,
aten._foreach_clamp_min_.List,
aten._foreach_clamp_max_.List,
]
)
def meta__foreach_binop__list(self, other, alpha=1):
_check_foreach_binop_tensor_lists(self, other)


@register_meta(
[
aten._foreach_add.Tensor,
]
)
def meta__foreach_binop_tensor(self, other, alpha=1):
torch._check(
isinstance(self, List),
lambda: f"The first argument must be List[Tensor], but got {type(self)}.",
)
torch._check(
isinstance(other, torch.Tensor),
lambda: f"The second argument must be Tensor, but got {type(other)}.",
)
return [torch.empty_like(s) for s in self]


@register_meta(
[
aten._foreach_add_.Tensor,
]
)
def meta__foreach_binop__tensor(self, other, alpha=1):
torch._check(
isinstance(self, List),
lambda: f"The first argument must be List[Tensor], but got {type(self)}.",
)
torch._check(
isinstance(other, torch.Tensor),
lambda: f"The second argument must be Tensor, but got {type(other)}.",
)


@register_meta(
[
aten._foreach_add_.Scalar,
aten._foreach_mul_.Scalar,
aten._foreach_sub_.Scalar,
aten._foreach_div_.Scalar,
aten._foreach_maximum_.Scalar,
aten._foreach_maximum,
aten._foreach_minimum,
]
)
def meta__foreach_binop__scalar(self, scalar=1):
torch._check(
isinstance(self, List),
lambda: f"The first argument of must be List[Tensor], but got {type(self)}.",
)
def meta__foreach_binop_scalar(*args):
# aten.maximum(Tensor, Scalar) does not exist.
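    # (Assumption: only output shapes/dtypes matter on the meta backend, so
    # clamp_min works as a stand-in for both maximum and minimum.)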
return _meta_foreach_out_of_place(*args, _scalar_op=aten.clamp_min)


@register_meta(
[
aten._foreach_add.Scalar,
aten._foreach_div.Scalar,
aten._foreach_mul.Scalar,
aten._foreach_sub.Scalar,
]
)
def meta__foreach_binop_scalar(self, scalar=1):
torch._check(
isinstance(self, List),
lambda: f"The first argument of must be List[Tensor], but got {type(self)}.",
)
return [torch.empty_like(s) for s in self]


@register_meta(
[
aten._foreach_addcdiv_.Scalar,
aten._foreach_addcmul_.Scalar,
aten._foreach_maximum_,
aten._foreach_minimum_,
]
)
def meta__foreach_addcop__scalar(self, tensor1, tensor2, scalar=1):
torch._check(
all(isinstance(l, List) for l in [self, tensor1, tensor2]),
lambda: (
"All arguments of _foreach_addc*_ must be List[Tensor], "
f"but got {type(self)}, {type(tensor1)}, and {type(tensor2)}"
),
)
torch._check(len(self) > 0, lambda: "input tensor list must not be empty.")
torch._check(
len(self) == len(tensor1) and len(self) == len(tensor2),
lambda: "All input tensor lists must have the same length",
)


@register_meta(
[
aten._foreach_lerp_.Scalar,
]
)
def meta__foreach_lerp__scalar(self, other, scalar=1):
_check_foreach_binop_tensor_lists(self, other)
def meta__foreach_binop__scalar(*args):
# aten.maximum(Tensor, Scalar) does not exist
_meta_foreach_inplace(*args, _scalar_op=aten.clamp_min_)
return


@register_meta(
@@ -3123,6 +3148,8 @@ def meta__foreach_lerp__scalar(self, other, scalar=1):
]
)
def meta__foreach_addcop_scalar(self, tensor1, tensor2, scalar=1):
    # foreach_addcdiv and addcdiv have different signatures and
    # cannot use _meta_foreach_out_of_place.
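    # (Presumably because addcdiv/addcmul take `value` as a keyword-only
    # argument, whereas the foreach variants pass the scalar positionally.)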
torch._check(
all(isinstance(l, List) for l in [self, tensor1, tensor2]),
lambda: (
@@ -3139,15 +3166,6 @@ def meta__foreach_addcop_scalar(self, tensor1, tensor2, scalar=1):
return [torch.empty_like(s) for s in self]


@register_meta([aten._foreach_pow.ScalarAndTensor])
def meta__foreach_pow_scalar_and_tensor(self, exponent):
torch._check(
isinstance(exponent, List),
lambda: f"exponent must be a tensor list but got {type(exponent)}",
)
return [torch.empty_like(e) for e in exponent]


@register_meta([aten._foreach_addcdiv_.Tensor, aten._foreach_addcmul_.Tensor])
def meta__foreach_addcop_tensor(self, tensor1, tensor2, scalars):
torch._check(
@@ -3165,9 +3183,25 @@ def meta__foreach_addcop_tensor(self, tensor1, tensor2, scalars):
)


@register_meta([aten._foreach_copy_])
def meta__foreach_copy_inplace(self, src, non_blocking=False):
_check_foreach_binop_tensor_lists(self, src)
@register_meta(
[
aten._foreach_addcdiv_.Scalar,
aten._foreach_addcmul_.Scalar,
]
)
def meta__foreach_addcop__scalar(self, tensor1, tensor2, scalar=1):
torch._check(
all(isinstance(l, List) for l in [self, tensor1, tensor2]),
lambda: (
"All arguments of _foreach_addc*_ must be List[Tensor], "
f"but got {type(self)}, {type(tensor1)}, and {type(tensor2)}"
),
)
torch._check(len(self) > 0, lambda: "input tensor list must not be empty.")
torch._check(
len(self) == len(tensor1) and len(self) == len(tensor2),
lambda: "All input tensor lists must have the same length",
)


@register_meta([aten._fused_adam_.default])