OpInfo: Remove promotes_integers_to_float and infer it instead (#50279)
Summary:
Pull Request resolved: #50279

This allows different sample inputs to have different promotion behavior for the same
operator. For example, `div(..., rounding_mode='true')` promotes integer inputs to float,
but the other rounding modes do not. The current boolean flag is too restrictive to express this.
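
The inference amounts to comparing the dtype of a sample's actual output against the input dtype with `torch.can_cast`. A minimal sketch of the idea (not part of this diff; `torch.neg`, which keeps the integer dtype, and `torch.sin`, which promotes to float, are just illustrative ops):

```python
import torch

x = torch.arange(4, dtype=torch.int64)

# If the result dtype cannot be cast back to the input dtype, the op promoted,
# so its inplace variant (and a same-dtype `out=` tensor) must raise.
for op in (torch.neg, torch.sin):
    result = op(x)
    promotes = not torch.can_cast(result.dtype, x.dtype)
    print(op.__name__, result.dtype, promotes)

# neg torch.int64 False
# sin torch.float32 True
```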

Test Plan: Imported from OSS

Reviewed By: ngimel

Differential Revision: D25950011

Pulled By: mruberry

fbshipit-source-id: 7e82b82bedc626b2b6970d92d5b25676183ec384
peterbell10 authored and facebook-github-bot committed Jan 22, 2021
1 parent 4bbff92 commit 0436ea1
Showing 4 changed files with 60 additions and 75 deletions.
3 changes: 1 addition & 2 deletions test/test_ops.py
@@ -221,8 +221,7 @@ def test_variant_consistency_eager(self, device, dtype, op):
             for variant in variants:
                 # Verifies that inplace operations that promote int->float fail
                 # on tensors with integer dtypes.
-                if (variant is inplace and op.promotes_integers_to_float and
-                        dtype in (torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64)):
+                if (variant is inplace and not torch.can_cast(expected_forward.dtype, dtype)):
                     try:
                         variant_forward = variant(*(clone_input_helper(input) for input in sample.input),
                                                   *sample.args,
56 changes: 16 additions & 40 deletions test/test_unary_ufuncs.py
@@ -10,16 +10,16 @@
 
 from torch._six import inf, nan
 from torch.testing._internal.common_utils import (
-    TestCase, run_tests, torch_to_numpy_dtype_dict, suppress_warnings,
-    IS_MACOS, make_tensor, TEST_SCIPY, slowTest, skipIfNoSciPy)
+    TestCase, run_tests, torch_to_numpy_dtype_dict, numpy_to_torch_dtype_dict,
+    suppress_warnings, IS_MACOS, make_tensor, TEST_SCIPY, slowTest, skipIfNoSciPy)
 from torch.testing._internal.common_methods_invocations import (
     unary_ufuncs)
 from torch.testing._internal.common_device_type import (
     instantiate_device_type_tests, ops, dtypes, onlyCPU, onlyOnCPUAndCUDA,
     onlyCUDA, dtypesIfCUDA, precisionOverride, skipCUDAIfRocm, dtypesIfCPU,
     OpDTypes)
 from torch.testing import (
-    floating_types_and, integral_types, all_types_and_complex_and, floating_types)
+    floating_types_and, all_types_and_complex_and, floating_types)
 
 if TEST_SCIPY:
     import scipy
@@ -214,7 +214,7 @@ def _fn(t):
             if alt is None:
                 continue
 
-            if inplace and op.promotes_integers_to_float and dtype in integral_types() + (torch.bool,):
+            if inplace and not torch.can_cast(expected.dtype, dtype):
                 # Assert that RuntimeError is raised
                 # for inplace variant of Operators that
                 # promote integer input to floating dtype.
@@ -285,7 +285,7 @@ def test_reference_numerics(self, device, dtype, op):
             msg = None
 
             exact_dtype = True
-            if op.promotes_integers_to_float and dtype in integral_types() + (torch.bool,):
+            if not torch.can_cast(numpy_to_torch_dtype_dict[expected.dtype.type], dtype):
                 exact_dtype = False
 
             if dtype in [torch.uint8, torch.int8, torch.bool]:
@@ -402,53 +402,29 @@ def test_batch_vs_slicing(self, device, dtype, op):
 
         self.assertEqual(actual, expected)
 
-    def _test_out_arg(self, op, input, output):
-        dtype = input.dtype
-        out_dtype = output.dtype
-        if dtype is out_dtype:
-            expected = op(input)
-            op(input, out=output)
-            self.assertEqual(output, expected)
+    def _test_out_arg(self, op, input, output, expected):
+        if op.safe_casts_outputs:
+            expect_fail = not torch.can_cast(expected.dtype, output.dtype)
         else:
-            with self.assertRaises(RuntimeError):
-                op(input, out=output)
+            expect_fail = output.dtype != expected.dtype
 
-    def _test_out_promote_int_to_float_op(self, op, input, output):
-        def compare_out(op, input, out):
-            out_dtype = out.dtype
-            expected = op(input)
-            op(input, out=out)
-            self.assertEqual(out, expected.to(out_dtype))
-
-        dtype = input.dtype
-        out_dtype = output.dtype
-        if out_dtype.is_floating_point and not dtype.is_complex:
-            compare_out(op, input, output)
-        elif out_dtype.is_floating_point and dtype.is_complex:
-            if op.supports_complex_to_float:
-                compare_out(op, input, output)
-            else:
-                # Can't cast complex to float
-                with self.assertRaises(RuntimeError):
-                    op(input, out=output)
-        elif out_dtype.is_complex:
-            compare_out(op, input, output)
-        else:
-            # Can't cast to Integral types
+        if expect_fail:
             with self.assertRaises(RuntimeError):
                 op(input, out=output)
+        else:
+            res = op(input, out=output)
+            self.assertTrue(res is output)
+            self.assertEqual(output, expected.to(output.dtype))
 
     @ops(unary_ufuncs, dtypes=OpDTypes.supported)
     def test_out_arg_all_dtypes(self, device, dtype, op):
         input = make_tensor((64, 64), dtype=dtype, device=device,
                             low=op.domain[0], high=op.domain[1])
+        expected = op(input)
 
         for out_dtype in all_types_and_complex_and(torch.bool, torch.half):
            out = torch.empty_like(input, dtype=out_dtype)
-            if op.promotes_integers_to_float:
-                self._test_out_promote_int_to_float_op(op, input, out)
-            else:
-                self._test_out_arg(op, input, out)
+            self._test_out_arg(op, input, out, expected)
 
     @dtypes(*(torch.testing.get_all_int_dtypes() + [torch.bool] +
               torch.testing.get_all_fp_dtypes(include_bfloat16=False)))
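
For context on what the rewritten `_test_out_arg` asserts: a safe-casting op accepts any `out=` tensor whose dtype the computed result can be cast to, and rejects the rest with a RuntimeError. A small standalone illustration (not part of the diff; `torch.sqrt` is merely one example of a unary op that promotes integers and safe-casts its output):

```python
import torch

x = torch.arange(1, 5, dtype=torch.int64)
expected = torch.sqrt(x)                      # int64 input promotes to float32

# float32 can be cast up to float64, so a float64 out= tensor is accepted.
out_ok = torch.empty(4, dtype=torch.float64)
torch.sqrt(x, out=out_ok)

# float32 cannot be cast back to int64, so an int64 out= tensor must fail.
assert not torch.can_cast(expected.dtype, torch.int64)
try:
    torch.sqrt(x, out=torch.empty(4, dtype=torch.int64))
except RuntimeError as err:
    print("rejected as expected:", err)
```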
