ForeachFuncInfo("zero")
Browse files Browse the repository at this point in the history
with a fake `op` as arg to OpInfo dunder init

Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
crcrpar authored and pytorchmergebot committed May 17, 2023
1 parent 6ec8fc2 commit c83cb65
Showing 3 changed files with 57 additions and 56 deletions.
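The gist of the change: there is no out-of-place `torch._foreach_zero`, only the in-place `torch._foreach_zero_`, so the new `ForeachFuncInfo("zero")` entry passes `torch.zero_` as a fake `op` to `OpInfo.__init__`, and the foreach tests skip their out-of-place paths for this op. Below is a minimal, illustrative sketch (not part of the patch) of the behavior the OpInfo-driven tests now cover for `_foreach_zero_`:

    import torch

    # Clones of leaf tensors so the in-place op can record a grad_fn.
    tensors = [torch.randn(2, 3, requires_grad=True).clone(),
               torch.randn(5, requires_grad=True).clone()]
    torch._foreach_zero_(tensors)  # in-place variant: zeroes every tensor in the list
    assert all(t.eq(0).all() for t in tensors)
    assert all(t.grad_fn.name() == "ZeroBackward0" for t in tensors)
    # No out-of-place counterpart is defined (true as of this commit):
    assert not hasattr(torch, "_foreach_zero")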
78 changes: 32 additions & 46 deletions test/test_foreach.py
@@ -14,7 +14,7 @@
     (instantiate_device_type_tests, dtypes, onlyCUDA, ops, OpDTypes)
 from torch.testing._internal.common_methods_invocations import (
     foreach_unary_op_db, foreach_binary_op_db, foreach_pointwise_op_db,
-    foreach_reduce_op_db, foreach_lerp_op_db, foreach_inputs_sample_func)
+    foreach_reduce_op_db, foreach_lerp_op_db)
 from torch.testing._internal.common_dtype import (
     all_types_and_complex_and, integral_types, complex_types,
     floating_types_and, floating_types, integral_types_and,
@@ -400,36 +400,40 @@ def _inplace_unary_test(self, inplace, inplace_ref, inputs, is_fastpath, **kwarg
     @ops(foreach_unary_op_db)
     @parametrize("is_fastpath", (True, False))
     def test_unary_op(self, device, dtype, op, is_fastpath):
+        out_place_defined = op.name != "_foreach_zero"
         wrapped_op, ref, inplace_op, inplace_ref = self._get_funcs(op)
         samples = op.sample_inputs(device, dtype, noncontiguous=not is_fastpath)
         disable_fastpath = op.name == "_foreach_abs" and dtype in complex_types()
         for sample in samples:
             zero_size = sample.kwargs.pop('zero_size')
             inputs = [sample.input]
             if zero_size:
-                wrapped_op(inputs, self.is_cuda, is_fastpath and not disable_fastpath, zero_size=zero_size)
+                if out_place_defined:
+                    wrapped_op(inputs, self.is_cuda, is_fastpath and not disable_fastpath, zero_size=zero_size)
                 inplace_op(inputs, self.is_cuda, is_fastpath and not disable_fastpath, zero_size=zero_size)
                 continue
             inputs = [sample.input]
             disable_fastpath = (op.name == "_foreach_abs" and dtype in complex_types()) or sample.kwargs.pop(
                 "disable_fastpath"
             )
-            self.assertEqual(
-                ref(inputs),
-                wrapped_op(inputs, self.is_cuda, is_fastpath and not disable_fastpath, zero_size=zero_size),
-            )
+            if out_place_defined:
+                self.assertEqual(
+                    ref(inputs),
+                    wrapped_op(inputs, self.is_cuda, is_fastpath and not disable_fastpath, zero_size=zero_size),
+                )
             self._inplace_unary_test(
                 inplace_op, inplace_ref, [sample.input], is_fastpath and not disable_fastpath, zero_size=zero_size
             )
             if op.supports_autograd and dtype in floating_types() and not zero_size:
                 tensors = [t.clone().detach().requires_grad_() for t in sample.input]
                 ref_tensors = [t.clone().detach().requires_grad_() for t in tensors]
-                out = wrapped_op.func(tensors)
-                # tensors have different shapes
-                torch.cat([t.view(-1) for t in out]).mean().backward()
-                torch.cat([ref.func(t).view(-1) for t in ref_tensors]).mean().backward()
-                self.assertEqual([t.grad for t in tensors], [t.grad for t in ref_tensors])
-                self.assertEqual(len({t.grad_fn for t in out}), 1)
+                if out_place_defined:
+                    out = wrapped_op.func(tensors)
+                    # tensors have different shapes
+                    torch.cat([t.view(-1) for t in out]).mean().backward()
+                    torch.cat([ref.func(t).view(-1) for t in ref_tensors]).mean().backward()
+                    self.assertEqual([t.grad for t in tensors], [t.grad for t in ref_tensors])
+                    self.assertEqual(len({t.grad_fn for t in out}), 1)
 
                 inplace_input_tensors = [t.clone().detach().requires_grad_() for t in tensors]
                 inplace_inputs = [t.clone() for t in inplace_input_tensors]
@@ -687,26 +691,31 @@ def test_binary_op_float_inf_nan(self, device, dtype, op):
     @onlyCUDA
     @ops(foreach_unary_op_db)
     def test_unary_op_tensors_on_different_devices(self, device, dtype, op):
+        out_place_defined = op.name != "_foreach_zero"
         method, ref, inplace_method, ref_inplace = self._get_funcs(op)
         # tensors: ['cuda', 'cpu]
         tensors = list(op.sample_inputs(device, dtype, num_input_tensors=[2]))[0].input
         tensors[1] = tensors[1].to("cpu")
-        try:
-            actual = method((tensors,), False, False, zero_size=False)
-        except RuntimeError as e:
-            with self.assertRaisesRegex(type(e), str(e)):
-                ref((tensors,))
-        else:
-            expected = ref((tensors,))
-            self.assertEqual(expected, actual)
+        if out_place_defined:
+            try:
+                actual = method((tensors,), False, False, zero_size=False)
+            except RuntimeError as e:
+                with self.assertRaisesRegex(type(e), str(e)):
+                    ref((tensors,))
+            else:
+                expected = ref((tensors,))
+                self.assertEqual(expected, actual)
 
         try:
             inplace_method((tensors,), False, False, zero_size=False)
         except RuntimeError as e:
             with self.assertRaisesRegex(type(e), str(e)):
                 ref_inplace((tensors,))
         else:
-            self.assertEqual(expected, tensors)
+            if out_place_defined:
+                self.assertEqual(expected, tensors)
+            else:
+                self.assertEqual([torch.zeros_like(t) for t in tensors], tensors)
 
     @onlyCUDA
     @ops(foreach_binary_op_db)
@@ -888,6 +897,8 @@ def test_inplace_foreach_leaf_check_and_grad_fn(self, device, dtype, op):
         dtypes=(torch.float,),
     )
     def test_outplace_with_invalid_grads(self, device, dtype, op):
+        if op.name in {"_foreach_zero"}:
+            self.skipTest(f"{op.name} does not have out-place implementation")
         func, *_ = self._get_funcs(op)
         sample = list(op.sample_inputs(dtype=dtype, device=device, requires_grad=True, num_input_tensors=[2], same_size=True))[0]
         self.assertTrue(all(t.requires_grad for t in sample.input))
@@ -900,31 +911,6 @@ def test_outplace_with_invalid_grads(self, device, dtype, op):
         self.assertIsNotNone(sample.input[0].grad)
         self.assertIsNone(sample.input[1].grad)
 
-    # note(crcrpar): this is clumsy but we don't have out-place `_foreach_zero` so make it a special case.
-    @dtypes(*floating_types_and(torch.half, torch.bfloat16,))
-    def test_foreach_zero(self, device, dtype):
-
-        # Needed to let sample inputs func below work.
-        class FakeOpInfo:
-            ref = torch.zero_
-
-        for sample in foreach_inputs_sample_func(1, False, False)(FakeOpInfo(), device, dtype, requires_grad=True):
-            original_inputs = sample.input
-            inputs = [t.clone() for t in original_inputs]
-            zeros = [torch.zeros_like(t) for t in original_inputs]
-            torch._foreach_zero_(inputs)
-            # checking a function call doesn't fail is enough when `zero_size`
-            if sample.kwargs["zero_size"]:
-                continue
-            self.assertEqual(inputs, zeros)
-            self.assertTrue(all(t.grad_fn.name() == "ZeroBackward0" for t in inputs))
-            sum_of_cloned_tensors = torch.cat([t.clone().view(-1) for t in inputs]).sum()
-            grad_output = torch.rand_like(sum_of_cloned_tensors)
-            self.assertEqual(
-                torch.autograd.grad(sum_of_cloned_tensors, inputs=original_inputs, grad_outputs=(grad_output,)),
-                zeros,
-            )
-
 
 instantiate_device_type_tests(TestForeach, globals())
 
7 changes: 7 additions & 0 deletions torch/testing/_internal/common_methods_invocations.py
@@ -8474,6 +8474,13 @@ def __call__(self, opinfo, device, dtype, requires_grad, **kwargs):
         sample_inputs_func=foreach_inputs_sample_func(1, False, False),
         supports_autograd=True,
     ),
+
+    ForeachFuncInfo(
+        'zero',
+        dtypes=all_types_and_complex_and(torch.bfloat16, torch.half),
+        sample_inputs_func=foreach_inputs_sample_func(1, False, False),
+        supports_autograd=True,
+    ),
 ]
 
 foreach_binary_op_db: List[OpInfo] = [
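Assuming the entry above is constructed the way the opinfo/core.py hunk below describes, it should resolve roughly as follows (an illustrative sketch against the internal test API, not part of the patch):

    from torch.testing._internal.common_methods_invocations import foreach_unary_op_db

    zero_info = next(op for op in foreach_unary_op_db if op.name == "_foreach_zero")
    print(zero_info.inplace_variant)  # torch._foreach_zero_, the real in-place op
    print(zero_info.op)               # torch.zero_, the stand-in "fake op" from the commit title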
28 changes: 18 additions & 10 deletions torch/testing/_internal/opinfo/core.py
@@ -2659,8 +2659,25 @@ def __init__(
         supports_scalar_self_arg=False,
         **kwargs,
     ):
+        (
+            foreach_method,
+            foreach_method_inplace,
+            torch_ref_method,
+            torch_ref_inplace,
+        ) = get_foreach_method_names(name)
+        if name == "zero":
+            # note(crcrpar): `foreach_method` for `"zero"` is `None` but `None` would call
+            # `_getattr_qual` in `OpInfo.__post_init__` which should fail since `_foreach_zero`
+            # is not defined at the moment. Thus to skip the qualification, set a similar torch
+            # function.
+            assert foreach_method is None
+            foreach_method = torch.zero_
         super().__init__(
-            "_foreach_" + name,
+            name="_foreach_" + name,
+            op=foreach_method,
+            ref=torch_ref_method,
+            method_variant=foreach_method,
+            inplace_variant=foreach_method_inplace,
             dtypes=dtypes,
             dtypesIfCUDA=dtypesIfCUDA,
             dtypesIfROCM=dtypesIfROCM,
Expand All @@ -2670,15 +2687,6 @@ def __init__(
)
self.supports_scalar_self_arg = supports_scalar_self_arg

(
foreach_method,
foreach_method_inplace,
torch_ref_method,
torch_ref_inplace,
) = get_foreach_method_names(name)
self.method_variant = foreach_method
self.inplace_variant = foreach_method_inplace
self.ref = torch_ref_method
self.ref_inplace = torch_ref_inplace
self.supports_alpha_param = supports_alpha_param

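Why the stand-in is needed, per the note in the hunk above: with `op=None`, `OpInfo.__post_init__` resolves the op by name on `torch` (via `_getattr_qual`), and only the in-place variant exists. A rough illustration of that failure mode (an assumption based on this commit's comment, not the actual `_getattr_qual` code):

    import torch

    name = "_foreach_zero"
    assert hasattr(torch, name + "_")  # the in-place variant is defined
    assert not hasattr(torch, name)    # the out-of-place name is not, so a plain
                                       # lookup would raise; hence op=torch.zero_ is passed.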