Enable backward on _foreach_zero_ #101149

Closed · wants to merge 5 commits
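For orientation, here is a minimal sketch of the behaviour this PR enables (assuming a build that includes these changes; the clone-then-zero pattern mirrors the updated test_foreach.py below, which applies the in-place op to non-leaf clones so autograd can record it):

```python
import torch

# Hedged sketch: with this PR, torch._foreach_zero_ participates in autograd.
leaves = [torch.randn(3, requires_grad=True) for _ in range(2)]
clones = [t.clone() for t in leaves]      # non-leaf tensors, safe for in-place ops
torch._foreach_zero_(clones)              # in-place zero, now recorded by autograd
torch.cat([t.view(-1) for t in clones]).sum().backward()
# zero_ severs the dependence on the inputs, so the leaves receive all-zero gradients.
print([t.grad for t in leaves])
```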
51 changes: 31 additions & 20 deletions test/test_foreach.py
@@ -400,36 +400,40 @@ def _inplace_unary_test(self, inplace, inplace_ref, inputs, is_fastpath, **kwarg
@ops(foreach_unary_op_db)
@parametrize("is_fastpath", (True, False))
def test_unary_op(self, device, dtype, op, is_fastpath):
out_place_defined = op.name != "_foreach_zero"
wrapped_op, ref, inplace_op, inplace_ref = self._get_funcs(op)
samples = op.sample_inputs(device, dtype, noncontiguous=not is_fastpath)
disable_fastpath = op.name == "_foreach_abs" and dtype in complex_types()
for sample in samples:
zero_size = sample.kwargs.pop('zero_size')
inputs = [sample.input]
if zero_size:
wrapped_op(inputs, self.is_cuda, is_fastpath and not disable_fastpath, zero_size=zero_size)
if out_place_defined:
wrapped_op(inputs, self.is_cuda, is_fastpath and not disable_fastpath, zero_size=zero_size)
inplace_op(inputs, self.is_cuda, is_fastpath and not disable_fastpath, zero_size=zero_size)
continue
inputs = [sample.input]
disable_fastpath = (op.name == "_foreach_abs" and dtype in complex_types()) or sample.kwargs.pop(
"disable_fastpath"
)
self.assertEqual(
ref(inputs),
wrapped_op(inputs, self.is_cuda, is_fastpath and not disable_fastpath, zero_size=zero_size),
)
if out_place_defined:
self.assertEqual(
ref(inputs),
wrapped_op(inputs, self.is_cuda, is_fastpath and not disable_fastpath, zero_size=zero_size),
)
self._inplace_unary_test(
inplace_op, inplace_ref, [sample.input], is_fastpath and not disable_fastpath, zero_size=zero_size
)
if op.supports_autograd and dtype in floating_types() and not zero_size:
tensors = [t.clone().detach().requires_grad_() for t in sample.input]
ref_tensors = [t.clone().detach().requires_grad_() for t in tensors]
out = wrapped_op.func(tensors)
# tensors have different shapes
torch.cat([t.view(-1) for t in out]).mean().backward()
torch.cat([ref.func(t).view(-1) for t in ref_tensors]).mean().backward()
self.assertEqual([t.grad for t in tensors], [t.grad for t in ref_tensors])
self.assertEqual(len({t.grad_fn for t in out}), 1)
if out_place_defined:
out = wrapped_op.func(tensors)
# tensors have different shapes
torch.cat([t.view(-1) for t in out]).mean().backward()
torch.cat([ref.func(t).view(-1) for t in ref_tensors]).mean().backward()
self.assertEqual([t.grad for t in tensors], [t.grad for t in ref_tensors])
self.assertEqual(len({t.grad_fn for t in out}), 1)

inplace_input_tensors = [t.clone().detach().requires_grad_() for t in tensors]
inplace_inputs = [t.clone() for t in inplace_input_tensors]
@@ -687,26 +691,31 @@ def test_binary_op_float_inf_nan(self, device, dtype, op):
@onlyCUDA
@ops(foreach_unary_op_db)
def test_unary_op_tensors_on_different_devices(self, device, dtype, op):
out_place_defined = op.name != "_foreach_zero"
method, ref, inplace_method, ref_inplace = self._get_funcs(op)
# tensors: ['cuda', 'cpu']
tensors = list(op.sample_inputs(device, dtype, num_input_tensors=[2]))[0].input
tensors[1] = tensors[1].to("cpu")
try:
actual = method((tensors,), False, False, zero_size=False)
except RuntimeError as e:
with self.assertRaisesRegex(type(e), str(e)):
ref((tensors,))
else:
expected = ref((tensors,))
self.assertEqual(expected, actual)
if out_place_defined:
try:
actual = method((tensors,), False, False, zero_size=False)
except RuntimeError as e:
with self.assertRaisesRegex(type(e), str(e)):
ref((tensors,))
else:
expected = ref((tensors,))
self.assertEqual(expected, actual)

try:
inplace_method((tensors,), False, False, zero_size=False)
except RuntimeError as e:
with self.assertRaisesRegex(type(e), str(e)):
ref_inplace((tensors,))
else:
self.assertEqual(expected, tensors)
if out_place_defined:
self.assertEqual(expected, tensors)
else:
self.assertEqual([torch.zeros_like(t) for t in tensors], tensors)

@onlyCUDA
@ops(foreach_binary_op_db)
@@ -888,6 +897,8 @@ def test_inplace_foreach_leaf_check_and_grad_fn(self, device, dtype, op):
dtypes=(torch.float,),
)
def test_outplace_with_invalid_grads(self, device, dtype, op):
if op.name in {"_foreach_zero"}:
self.skipTest(f"{op.name} does not have out-place implementation")
func, *_ = self._get_funcs(op)
sample = list(op.sample_inputs(dtype=dtype, device=device, requires_grad=True, num_input_tensors=[2], same_size=True))[0]
self.assertTrue(all(t.requires_grad for t in sample.input))
2 changes: 0 additions & 2 deletions tools/autograd/gen_variable_type.py
@@ -977,8 +977,6 @@ def gen_variable_type_func(
# No reference backward available as addcdiv/addcmul don't support Tensor as scaling factor.
("_foreach_addcdiv", "Tensor"),
("_foreach_addcmul", "Tensor"),
# FIXME(crcrpar): Let `_foreach_zero_` have backward.
("_foreach_zero", ""),
}

_foreach_ops_with_different_arity = {
7 changes: 7 additions & 0 deletions torch/testing/_internal/common_methods_invocations.py
@@ -8474,6 +8474,13 @@ def __call__(self, opinfo, device, dtype, requires_grad, **kwargs):
sample_inputs_func=foreach_inputs_sample_func(1, False, False),
supports_autograd=True,
),

ForeachFuncInfo(
'zero',
dtypes=all_types_and_complex_and(torch.bfloat16, torch.half),
sample_inputs_func=foreach_inputs_sample_func(1, False, False),
supports_autograd=True,
),
]

foreach_binary_op_db: List[OpInfo] = [
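Purely illustrative (not part of this PR's tests): one way to see the new `ForeachFuncInfo('zero', ...)` entry surface in the op database. The attribute names are standard OpInfo fields; the expected values follow from the diffs above, and the exact resolution of `inplace_variant` depends on `get_foreach_method_names`.

```python
# Illustrative check: the new "zero" entry is registered in the foreach unary
# op database with autograd support enabled.
from torch.testing._internal.common_methods_invocations import foreach_unary_op_db

zero_info = next(op for op in foreach_unary_op_db if op.name == "_foreach_zero")
print(zero_info.supports_autograd)   # True, per the entry added above
print(zero_info.inplace_variant)     # expected to resolve to torch._foreach_zero_
```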
28 changes: 18 additions & 10 deletions torch/testing/_internal/opinfo/core.py
@@ -2659,8 +2659,25 @@ def __init__(
supports_scalar_self_arg=False,
**kwargs,
):
(
foreach_method,
foreach_method_inplace,
torch_ref_method,
torch_ref_inplace,
) = get_foreach_method_names(name)
if name == "zero":
# note(crcrpar): `foreach_method` for `"zero"` is `None` but `None` would call
# `_getattr_qual` in `OpInfo.__post_init__` which should fail since `_foreach_zero`
# is not defined at the moment. Thus to skip the qualification, set a similar torch
# function.
assert foreach_method is None
foreach_method = torch.zero_
super().__init__(
"_foreach_" + name,
name="_foreach_" + name,
op=foreach_method,
ref=torch_ref_method,
method_variant=foreach_method,
inplace_variant=foreach_method_inplace,
dtypes=dtypes,
dtypesIfCUDA=dtypesIfCUDA,
dtypesIfROCM=dtypesIfROCM,
@@ -2670,15 +2687,6 @@
)
self.supports_scalar_self_arg = supports_scalar_self_arg

(
foreach_method,
foreach_method_inplace,
torch_ref_method,
torch_ref_inplace,
) = get_foreach_method_names(name)
self.method_variant = foreach_method
self.inplace_variant = foreach_method_inplace
self.ref = torch_ref_method
self.ref_inplace = torch_ref_inplace
self.supports_alpha_param = supports_alpha_param

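The note in `ForeachFuncInfo.__init__` above works around the fact that there is no out-of-place `torch._foreach_zero`, only the in-place `torch._foreach_zero_`. A rough sketch of that fallback idea follows; the helper name `resolve_foreach_op` is hypothetical and not part of the PR.

```python
import torch

def resolve_foreach_op(name: str):
    """Hypothetical helper mirroring the fallback above: return the out-of-place
    _foreach_<name> callable, or a similar torch function when only the in-place
    foreach variant exists (e.g. "zero" -> torch.zero_)."""
    op = getattr(torch, f"_foreach_{name}", None)
    if op is None:
        # "_foreach_zero" is not defined, only "_foreach_zero_"; torch.zero_ is
        # used purely as a placeholder so attribute qualification does not fail.
        op = torch.zero_
    return op

print(resolve_foreach_op("add"))   # torch._foreach_add
print(resolve_foreach_op("zero"))  # torch.zero_ (placeholder)
```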
51 changes: 32 additions & 19 deletions torchgen/api/autograd.py
@@ -310,6 +310,12 @@
return f.func.name.name.base.startswith("_foreach_")


# note(crcrpar): Most foreach functions can reference an out-place `torch` function whose schema kind
# is functional for their backward derivatives (and forward derivatives in the future), i.e.,
# they would find such one in `functional_info_by_signature`. There however are some exceptions:
_foreach_with_inplace_ref = {"_foreach_zero_"}


# Checks if `function_schema` is a native, non-foreach function which `f`, a foreach function
# reference to generate derivatives.
def is_reference_for_foreach(
@@ -318,7 +324,10 @@
) -> bool:
return (
f.func.name.name.base.split("_foreach_")[-1] == function_schema.name.name.base
and not function_schema.name.name.inplace
and (
not function_schema.name.name.inplace
or str(f.func.name) in _foreach_with_inplace_ref
)
and all(
ref_arg.type in (arg.type, getattr(arg.type, "elem", None))
for arg, ref_arg in zip(
@@ -330,39 +339,41 @@


# TODO(crcrpar): Avoid hard coding "Default" ideally.
def gen_foreach_derivativeinfo(

[GitHub Actions / bc_linter notice on line 342: Function gen_foreach_derivativeinfo: differentiability_infos was removed]
foreach_function: NativeFunction,
differentiability_infos: Dict[FunctionSchema, Dict[str, DifferentiabilityInfo]],
functional_info_by_signature: Dict[
FunctionSchema, Dict[str, DifferentiabilityInfo]
],
non_functional_info_by_signature: Dict[

[GitHub Actions / bc_linter notice on line 347: Function gen_foreach_derivativeinfo: non_functional_info_by_signature was added and is now required]
FunctionSchema, Dict[str, DifferentiabilityInfo]
],
dispatch_key: str = "Default",
) -> Tuple[Optional[DifferentiabilityInfo], bool]:
"""Generate DifferentiabilityInfo for out-place foreach function, return the existing one for in-place.

The second return value indicates whether the info is generated in this function.
"""
ref_diff_info: Optional[DifferentiabilityInfo] = None
for function_schema in functional_info_by_signature:

for function_schema, diff_info in functional_info_by_signature.items():
if not is_reference_for_foreach(foreach_function, function_schema):
continue
if function_schema in differentiability_infos:
ref_diff_info = differentiability_infos[function_schema][dispatch_key]
elif (
function_schema.signature(strip_default=True)
in functional_info_by_signature
):
ref_diff_info = functional_info_by_signature[
function_schema.signature(strip_default=True)
][dispatch_key]
else:
raise RuntimeError(
"Reference `DifferentiabilityInfo` for {} not found".format(
foreach_function.func
)
)
ref_diff_info = diff_info[dispatch_key]
if ref_diff_info is not None:
break
# note(crcrpar): It seems like `zero`'s info isn't available in functional_info_by_signature
# while the info of `zero_` is in non_functional_info_by_signature
if (
ref_diff_info is None
and foreach_function.func.kind() == SchemaKind.inplace
and str(foreach_function.func.name) in _foreach_with_inplace_ref
):
for function_schema, diff_info in non_functional_info_by_signature.items():
if not is_reference_for_foreach(foreach_function, function_schema):
continue
ref_diff_info = diff_info[dispatch_key]
if ref_diff_info is not None:
break
if ref_diff_info is None:
return None, False

Expand Down Expand Up @@ -534,7 +545,9 @@
if is_foreach_func(f):
assert f.func not in differentiability_infos
diff_info, is_generated = gen_foreach_derivativeinfo(
f, differentiability_infos, functional_info_by_signature
f,
functional_info_by_signature,
non_functional_info_by_signature,
)
if diff_info is None:
return None, False
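Finally, a rough, non-torchgen illustration of the reference lookup that the `_foreach_with_inplace_ref` note above describes: most foreach ops map to a functional (out-of-place) reference for their derivatives, while `_foreach_zero_` must fall back to the in-place `zero_` schema. The helper below is hypothetical.

```python
# Hypothetical sketch, not torchgen code: derive the name of the native op whose
# derivative a foreach function reuses, allowing an in-place reference for the
# ops listed in _foreach_with_inplace_ref (currently just "_foreach_zero_").
_foreach_with_inplace_ref = {"_foreach_zero_"}

def reference_name(foreach_name: str) -> str:
    base = foreach_name.removeprefix("_foreach_").rstrip("_")
    if foreach_name in _foreach_with_inplace_ref:
        return base + "_"        # e.g. "_foreach_zero_" -> "zero_"
    return base                  # e.g. "_foreach_add_"  -> "add"

assert reference_name("_foreach_add_") == "add"
assert reference_name("_foreach_zero_") == "zero_"
```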