Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
177 changes: 89 additions & 88 deletions test/test_type_promotion.py
Original file line number Diff line number Diff line change
Expand Up @@ -748,91 +748,92 @@ def test_abs_angle_complex_to_float(self, device, dtype):
@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
@float_double_default_dtype
@onlyCPU
def test_numpy_array_binary_ufunc_promotion(self, device):
@dtypes(*list(itertools.product(torch_to_numpy_dtype_dict.keys(),
torch_to_numpy_dtype_dict.keys())))
def test_numpy_array_binary_ufunc_promotion(self, device, dtypes):
import operator
np_types = numpy_to_torch_dtype_dict.keys()
torch_types = numpy_to_torch_dtype_dict.values()

for np_type, torch_type in itertools.product(np_types, torch_types):
t = torch.tensor((1,), device=device, dtype=torch_type)
a = np.array((1,), dtype=np_type)
a_as_t = torch.from_numpy(a).to(device=device)

for np_first in (True, False):
for op in (operator.add, torch.add):

# Acquires results of binary ufunc type promotion.
try:
actual = op(a, t) if np_first else op(t, a)
except Exception as e:
actual = e

try:
expected = op(a_as_t, t) if np_first else op(t, a_as_t)
except Exception as e:
expected = e

same_result = (type(expected) == type(actual)) and expected == actual

# Note: An "undesired failure," as opposed to an "expected failure"
# is both expected (we know the test will fail) and
# undesirable (if PyTorch was working properly the test would
# not fail). This test is affected by three issues (see below)
# that will cause undesired failures. It detects when these
# issues will occur and updates this bool accordingly.
undesired_failure = False

# A NumPy array as the first argument to the plus operator
# or as any argument to torch.add is not working as
# intended.
# See https://github.com/pytorch/pytorch/issues/36363.
if np_first and op is operator.add:
undesired_failure = True
if op is torch.add:
undesired_failure = True

# float16 x bool, uint, int, and float16 interactions are not
# working as intended.
# See https://github.com/pytorch/pytorch/issues/36058.
float16_failures = (torch.bool, torch.uint8,
torch.int8, torch.int16, torch.int32, torch.int64,
torch.float16)
if torch_type is torch.float16 and \
numpy_to_torch_dtype_dict[np_type] in float16_failures:
undesired_failure = True

if torch_type in float16_failures and np_type is np.float16:
undesired_failure = True

# bool x complex interactions are not working as intended.
# See https://github.com/pytorch/pytorch/issues/36057.
if torch_type in (torch.complex64, torch.complex128) and np_type is np.bool:
undesired_failure = True

if torch_type is torch.bool and np_type in (np.complex64, np.complex128):
undesired_failure = True

# Expects the same result if undesired_failure is false
# and a different result otherwise.
# Note: These cases prettyprint the failing inputs to make
# debugging test failures easier.
if undesired_failure and same_result:
msg = ("Failure: {0} == {1}. "
"torch type was {2}. NumPy type was {3}. np_first is {4} "
"default type is {5}.").format(actual, expected,
torch_type, np_type,
np_first,
torch.get_default_dtype())
self.fail(msg)

if not undesired_failure and not same_result:
msg = ("Failure: {0} != {1}. "
"torch type was {2}. NumPy type was {3}. np_first is {4} "
"default type is {5}.").format(actual, expected,
torch_type, np_type,
np_first,
torch.get_default_dtype())
self.fail(msg)
np_type = torch_to_numpy_dtype_dict[dtypes[0]]
torch_type = dtypes[1]

t = torch.tensor((1,), device=device, dtype=torch_type)
a = np.array((1,), dtype=np_type)
a_as_t = torch.from_numpy(a).to(device=device)

for np_first in (True, False):
for op in (operator.add, torch.add):

# Acquires results of binary ufunc type promotion.
try:
actual = op(a, t) if np_first else op(t, a)
except Exception as e:
actual = e

try:
expected = op(a_as_t, t) if np_first else op(t, a_as_t)
except Exception as e:
expected = e

same_result = (type(expected) == type(actual)) and expected == actual

# Note: An "undesired failure," as opposed to an "expected failure"
# is both expected (we know the test will fail) and
# undesirable (if PyTorch was working properly the test would
# not fail). This test is affected by three issues (see below)
# that will cause undesired failures. It detects when these
# issues will occur and updates this bool accordingly.
undesired_failure = False

# A NumPy array as the first argument to the plus operator
# or as any argument to torch.add is not working as
# intended.
# See https://github.com/pytorch/pytorch/issues/36363.
if np_first and op is operator.add:
undesired_failure = True
if op is torch.add:
undesired_failure = True

# float16 x bool, uint, int, and float16 interactions are not
# working as intended.
# See https://github.com/pytorch/pytorch/issues/36058.
float16_failures = (torch.bool, torch.uint8,
torch.int8, torch.int16, torch.int32, torch.int64,
torch.float16)
if torch_type is torch.float16 and \
numpy_to_torch_dtype_dict[np_type] in float16_failures:
undesired_failure = True

if torch_type in float16_failures and np_type is np.float16:
undesired_failure = True

# bool x complex interactions are not working as intended.
# See https://github.com/pytorch/pytorch/issues/36057.
if torch_type in (torch.complex64, torch.complex128) and np_type is np.bool:
undesired_failure = True

if torch_type is torch.bool and np_type in (np.complex64, np.complex128):
undesired_failure = True

# Expects the same result if undesired_failure is false
# and a different result otherwise.
# Note: These cases prettyprint the failing inputs to make
# debugging test failures easier.
if undesired_failure and same_result:
msg = ("Failure: {0} == {1}. "
"torch type was {2}. NumPy type was {3}. np_first is {4} "
"default type is {5}.").format(actual, expected,
torch_type, np_type,
np_first,
torch.get_default_dtype())
self.fail(msg)

if not undesired_failure and not same_result:
msg = ("Failure: {0} != {1}. "
"torch type was {2}. NumPy type was {3}. np_first is {4} "
"default type is {5}.").format(actual, expected,
torch_type, np_type,
np_first,
torch.get_default_dtype())
self.fail(msg)


@onlyOnCPUAndCUDA
Expand All @@ -841,12 +842,12 @@ def test_cat_different_dtypes(self, device):
y = torch.tensor([4, 5, 6], device=device, dtype=torch.int32)
expected_out = torch.tensor([1, 2, 3, 4, 5, 6], device=device, dtype=torch.int32)
out = torch.cat([x, y])
self.assertEqual(out, expected_out, exact_dtype=True)
self.assertEqual(out, expected_out, exact_dtype=True)
z = torch.tensor([7, 8, 9], device=device, dtype=torch.int16)
expected_out = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9],
device=device, dtype=torch.int32)
out = torch.cat([x, y, z])
self.assertEqual(out, expected_out, exact_dtype=True)
self.assertEqual(out, expected_out, exact_dtype=True)

@onlyOnCPUAndCUDA
def test_cat_out_different_dtypes(self, device):
Expand All @@ -855,13 +856,13 @@ def test_cat_out_different_dtypes(self, device):
y = torch.tensor([4, 5, 6], device=device, dtype=torch.int32)
expected_out = torch.tensor([1, 2, 3, 4, 5, 6], device=device, dtype=torch.int16)
torch.cat([x, y], out=out)
self.assertEqual(out, expected_out, exact_dtype=True)
self.assertEqual(out, expected_out, exact_dtype=True)
z = torch.tensor([7, 8, 9], device=device, dtype=torch.int16)
out = torch.zeros(9, device=device, dtype=torch.int64)
expected_out = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9],
device=device, dtype=torch.int64)
torch.cat([x, y, z], out=out)
self.assertEqual(out, expected_out, exact_dtype=True)
self.assertEqual(out, expected_out, exact_dtype=True)

@onlyOnCPUAndCUDA
def test_cat_invalid_dtype_promotion(self, device):
Expand Down
37 changes: 29 additions & 8 deletions torch/testing/_internal/common_device_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@
# (1b) @deviceCountAtLeast(<minimum number of devices to run test with>)
# testX(self, devices)
#
# (1c) @dtypes(<list of dtypes>)
# (1c) @dtypes(<list of dtypes> or <list of tuples of dtypes>)
# testX(self, device, dtype)
#
# (1d) @deviceCountAtLeast(<minimum number of devices to run test with>)
# @dtypes(<list of dtypes>)
# @dtypes(<list of dtypes> or <list of tuples of dtypes>)
# testX(self, devices, dtype)
#
#
Expand All @@ -40,8 +40,9 @@
# primary device. These tests will be skipped if the device type
# has fewer available devices than the argument to @deviceCountAtLeast.
#
# Tests like (1c) are called with a device string and a torch.dtype from
# the list of dtypes specified in the @dtypes decorator. Device-specific
# Tests like (1c) are called with a device string and a torch.dtype (or
# a tuple of torch.dtypes) from the list of dtypes (or list of tuples
# of torch.dtypes) specified in the @dtypes decorator. Device-specific
# dtype overrides can be specified using @dtypesIfCPU and @dtypesIfCUDA.
#
# Tests like (1d) take a devices argument like (1b) and a dtype
Expand Down Expand Up @@ -211,8 +212,15 @@ def instantiated_test(self, test=test):
setattr(cls, test_name, instantiated_test)
else: # Test has dtype variants
for dtype in dtypes:
dtype_str = str(dtype).split('.')[1]
dtype_test_name = test_name + "_" + dtype_str
# Constructs dtype suffix
if isinstance(dtype, (list, tuple)):
dtype_str = ""
for d in dtype:
dtype_str += "_" + str(d).split('.')[1]
else:
dtype_str = "_" + str(dtype).split('.')[1]

dtype_test_name = test_name + dtype_str
assert not hasattr(cls, dtype_test_name), "Redefinition of test {0}".format(dtype_test_name)

@wraps(test)
Expand Down Expand Up @@ -530,13 +538,26 @@ def __call__(self, fn):
# (1) Tests that accept the dtype argument MUST use this decorator.
# (2) Can be overridden for the CPU or CUDA, respectively, using dtypesIfCPU
# or dtypesIfCUDA.
# (3) Prefer the existing decorators to defining the 'device_type' kwarg.
# (3) Can accept an iterable of dtypes or an iterable of tuples
# of dtypes.
# Examples:
# @dtypes(torch.float32, torch.float64)
# @dtypes((torch.long, torch.float32), (torch.int, torch.float64))
class dtypes(object):

# Note: *args, **kwargs for Python2 compat.
# Python 3 allows (self, *args, device_type='all').
def __init__(self, *args, **kwargs):
assert all(isinstance(arg, torch.dtype) for arg in args), "Unknown dtype in {0}".format(str(args))
if len(args) > 0 and isinstance(args[0], (list, tuple)):
for arg in args:
assert isinstance(arg, (list, tuple)), \
"When one dtype variant is a tuple or list, " \
"all dtype variants must be. " \
"Received non-list non-tuple dtype {0}".format(str(arg))
assert all(isinstance(dtype, torch.dtype) for dtype in arg), "Unknown dtype in {0}".format(str(arg))
else:
assert all(isinstance(arg, torch.dtype) for arg in args), "Unknown dtype in {0}".format(str(args))

self.args = args
self.device_type = kwargs.get('device_type', 'all')

Expand Down