4 changes: 2 additions & 2 deletions test/test_tensor_creation_ops.py
@@ -21,9 +21,9 @@
onlyCUDA, skipCPUIf, dtypesIfCUDA, skipMeta)
from torch.testing._internal.common_dtype import (
all_types_and_complex_and, all_types_and, floating_and_complex_types,
floating_types, floating_and_complex_types_and, integral_types, integral_types_and, get_all_dtypes
floating_types, floating_and_complex_types_and, integral_types, integral_types_and, get_all_dtypes,
float_to_corresponding_complex_type_map
)
from torch.testing._creation import float_to_corresponding_complex_type_map

from torch.utils.dlpack import to_dlpack

8 changes: 5 additions & 3 deletions test/test_testing.py
@@ -1397,7 +1397,9 @@ def test_requires_grad(self, dtype, device, requires_grad):
t = make_tensor()
self.assertEqual(t.requires_grad, requires_grad)
else:
with self.assertRaisesRegex(ValueError, "requires_grad must be False for integral dtype"):
with self.assertRaisesRegex(
ValueError, "`requires_grad=True` is not supported for boolean and integral dtypes"
):
make_tensor()

@supported_dtypes
@@ -1431,7 +1433,7 @@ def test_memory_format(self, dtype, device, memory_format_and_shape):

@supported_dtypes
def test_noncontiguous_memory_format(self, dtype, device):
with self.assertRaises(AssertionError):
with self.assertRaisesRegex(ValueError, "`noncontiguous` and `memory_format` are mutually exclusive"):
torch.testing.make_tensor(
(2, 3, 4, 5),
dtype=dtype,
@@ -1505,7 +1507,7 @@ def test_low_high_nan(self, dtype, device, low_high):
# FIXME: bool needs to fail as well
return

with self.assertRaisesRegex(ValueError, "one of low or high was NaN"):
with self.assertRaisesRegex(ValueError, "`low` and `high` cannot be NaN"):
torch.testing.make_tensor(dtype=dtype, device=device, low=low, high=high)

# FIXME this fails for all dtypes
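For reference, here is a minimal sketch (not part of the diff) of how the updated error messages tested above surface when torch.testing.make_tensor is misused; it assumes the post-PR behavior implemented in torch/testing/_creation.py below.

import pytest
import torch
from torch.testing import make_tensor

# `requires_grad=True` is rejected for boolean and integral dtypes.
with pytest.raises(ValueError, match="not supported for boolean and integral dtypes"):
    make_tensor((2, 3), dtype=torch.int64, device="cpu", requires_grad=True)

# `noncontiguous` and `memory_format` cannot be combined.
with pytest.raises(ValueError, match="mutually exclusive"):
    make_tensor((2, 3, 4, 5), dtype=torch.float32, device="cpu",
                noncontiguous=True, memory_format=torch.channels_last)

# NaN bounds are rejected up front.
with pytest.raises(ValueError, match="cannot be NaN"):
    make_tensor((2, 3), dtype=torch.float32, device="cpu", low=float("nan"), high=1.0)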
6 changes: 2 additions & 4 deletions test/test_type_promotion.py
@@ -12,10 +12,8 @@
from torch.testing._internal.common_device_type import (instantiate_device_type_tests, onlyNativeDeviceTypes,
dtypes, onlyCPU, expectedFailureMeta, skipMeta)
from torch.testing._internal.common_dtype import (
all_types_and_complex_and, get_all_math_dtypes, floating_types, get_all_dtypes
)
from torch.testing._creation import (
float_to_corresponding_complex_type_map
all_types_and_complex_and, get_all_math_dtypes, floating_types, get_all_dtypes,
float_to_corresponding_complex_type_map,
)


140 changes: 72 additions & 68 deletions torch/testing/_creation.py
@@ -8,18 +8,14 @@

import torch

# Used by make_tensor for generating complex tensor.
complex_to_corresponding_float_type_map = {
torch.complex32: torch.float16,
torch.complex64: torch.float32,
torch.complex128: torch.float64,
}
float_to_corresponding_complex_type_map = {
v: k for k, v in complex_to_corresponding_float_type_map.items()
}

[Review comment] We don't need that since torch.finfo works for complex types as well.
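To illustrate the review comment above (a sketch, not part of the diff): torch.finfo accepts complex dtypes and reports the limits of the corresponding real dtype, which is why the explicit mapping is no longer needed.

import torch

# finfo of a complex dtype mirrors the finfo of its underlying real dtype.
assert torch.finfo(torch.complex64).tiny == torch.finfo(torch.float32).tiny
assert torch.finfo(torch.complex128).max == torch.finfo(torch.float64).max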


def _uniform_random(t: torch.Tensor, low: float, high: float):
_INTEGRAL_TYPES = [torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64]
_FLOATING_TYPES = [torch.float16, torch.bfloat16, torch.float32, torch.float64]
_COMPLEX_TYPES = [torch.complex32, torch.complex64, torch.complex128]
_BOOLEAN_OR_INTEGRAL_TYPES = [torch.bool, *_INTEGRAL_TYPES]
_FLOATING_OR_COMPLEX_TYPES = [*_FLOATING_TYPES, *_COMPLEX_TYPES]
[Review comment on lines +11 to +15] Defining them here avoids doing it over and over again inside make_tensor.


def _uniform_random_(t: torch.Tensor, low: float, high: float) -> torch.Tensor:
# uniform_ requires to-from <= std::numeric_limits<scalar_t>::max()
# Work around this by scaling the range before and after the PRNG
if high - low >= torch.finfo(t.dtype).max:
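The remainder of the function body is collapsed in this view. As a rough idea of the workaround the comment describes (a hypothetical sketch, not necessarily the exact collapsed code), the range can be halved before sampling and the result scaled back up:

import torch

def _uniform_random_sketch(t: torch.Tensor, low: float, high: float) -> torch.Tensor:
    # If the requested span would overflow uniform_'s to-from limit,
    # sample in a halved range and double the result afterwards.
    if high - low >= torch.finfo(t.dtype).max:
        return t.uniform_(low / 2, high / 2).mul_(2)
    return t.uniform_(low, high)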
@@ -73,20 +69,21 @@ def make_tensor(
is determined based on the :attr:`dtype` (see the table above). Default: ``None``.
requires_grad (Optional[bool]): If autograd should record operations on the returned tensor. Default: ``False``.
noncontiguous (Optional[bool]): If `True`, the returned tensor will be noncontiguous. This argument is
ignored if the constructed tensor has fewer than two elements.
ignored if the constructed tensor has fewer than two elements. Mutually exclusive with ``memory_format``.
exclude_zero (Optional[bool]): If ``True`` then zeros are replaced with the dtype's small positive value
depending on the :attr:`dtype`. For bool and integer types zero is replaced with one. For floating
point types it is replaced with the dtype's smallest positive normal number (the "tiny" value of the
:attr:`dtype`'s :func:`~torch.finfo` object), and for complex types it is replaced with a complex number
whose real and imaginary parts are both the smallest positive normal number representable by the complex
type. Default ``False``.
memory_format (Optional[torch.memory_format]): The memory format of the returned tensor. Incompatible
with :attr:`noncontiguous`.
memory_format (Optional[torch.memory_format]): The memory format of the returned tensor. Mutually exclusive
with ``noncontiguous``.

Raises:
ValueError: if ``requires_grad=True`` is passed for integral `dtype`
ValueError: If ``requires_grad=True`` is passed for integral `dtype`
ValueError: If ``low > high``.
ValueError: If either :attr:`low` or :attr:`high` is ``nan``.
ValueError: If both :attr:`noncontiguous` and :attr:`memory_format` are passed.
TypeError: If :attr:`dtype` isn't supported by this function.

Examples:
@@ -103,25 +100,37 @@
[False, True]], device='cuda:0')
"""

def _modify_low_high(low, high, lowest, highest, default_low, default_high, dtype):
def modify_low_high(
low: Optional[float],
high: Optional[float],
*,
lowest_inclusive: float,
highest_exclusive: float,
default_low: float,
default_high: float,
) -> Tuple[float, float]:
"""
Modifies (and raises ValueError when appropriate) low and high values given by the user (input_low, input_high) if required.
Modifies (and raises ValueError when appropriate) low and high values given by the user (input_low, input_high)
if required.
"""

def clamp(a, l, h):
def clamp(a: float, l: float, h: float) -> float:
return min(max(a, l), h)

low = low if low is not None else default_low
high = high if high is not None else default_high

# Checks for error cases
if low != low or high != high:
raise ValueError("make_tensor: one of low or high was NaN!")
if low > high:
raise ValueError("make_tensor: low must be weakly less than high!")
if math.isnan(low) or math.isnan(high):
raise ValueError(
f"`low` and `high` cannot be NaN, but got {low=} and {high=}"
)
elif low > high:
raise ValueError(
f"`low` must be weakly less than `high`, but got {low} >= {high}"
)

low = clamp(low, lowest, highest)
high = clamp(high, lowest, highest)
low = clamp(low, lowest_inclusive, highest_exclusive)
high = clamp(high, lowest_inclusive, highest_exclusive)

if dtype in [torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64]:
return math.floor(low), math.ceil(high)
@@ -132,70 +141,65 @@ def clamp(a, l, h):
shape = shape[0] # type: ignore[assignment]
shape = cast(Tuple[int, ...], tuple(shape))

_integral_types = [torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64]
_floating_types = [torch.float16, torch.bfloat16, torch.float32, torch.float64]
_complex_types = [torch.complex32, torch.complex64, torch.complex128]
if requires_grad and dtype not in _floating_types and dtype not in _complex_types:
raise ValueError("make_tensor: requires_grad must be False for integral dtype")
if noncontiguous and memory_format is not None:
raise ValueError(
f"The parameters `noncontiguous` and `memory_format` are mutually exclusive, "
f"but got {noncontiguous=} and {memory_format=}"
)

if requires_grad and dtype in _BOOLEAN_OR_INTEGRAL_TYPES:
raise ValueError(
f"`requires_grad=True` is not supported for boolean and integral dtypes, but got {dtype=}"
)

if dtype is torch.bool:
result = torch.randint(0, 2, shape, device=device, dtype=dtype) # type: ignore[call-overload]
elif dtype is torch.uint8:
ranges = (torch.iinfo(dtype).min, torch.iinfo(dtype).max)
result = torch.randint(0, 2, shape, device=device, dtype=dtype)
elif dtype in _INTEGRAL_TYPES:
low, high = cast(
Tuple[int, int],
_modify_low_high(low, high, ranges[0], ranges[1], 0, 10, dtype),
modify_low_high(
low,
high,
lowest_inclusive=torch.iinfo(dtype).min,
highest_exclusive=torch.iinfo(dtype).max,
# This is incorrect for `torch.uint8`, but since we clamp to `lowest`, i.e. 0 for `torch.uint8`,
# _after_ we use the default value, we don't need to special case it here
default_low=-9,
default_high=10,
),
)
result = torch.randint(low, high, shape, device=device, dtype=dtype) # type: ignore[call-overload]
elif dtype in _integral_types:
ranges = (torch.iinfo(dtype).min, torch.iinfo(dtype).max)
low, high = _modify_low_high(low, high, ranges[0], ranges[1], -9, 10, dtype)
result = torch.randint(low, high, shape, device=device, dtype=dtype) # type: ignore[call-overload]
elif dtype in _floating_types:
ranges_floats = (torch.finfo(dtype).min, torch.finfo(dtype).max)
m_low, m_high = _modify_low_high(
low, high, ranges_floats[0], ranges_floats[1], -9, 9, dtype
result = torch.randint(low, high, shape, device=device, dtype=dtype)
elif dtype in _FLOATING_OR_COMPLEX_TYPES:
low, high = modify_low_high(
low,
high,
lowest_inclusive=torch.finfo(dtype).min,
highest_exclusive=torch.finfo(dtype).max,
default_low=-9,
default_high=9,
)
result = torch.empty(shape, device=device, dtype=dtype)
_uniform_random(result, m_low, m_high)
elif dtype in _complex_types:
float_dtype = complex_to_corresponding_float_type_map[dtype]
ranges_floats = (torch.finfo(float_dtype).min, torch.finfo(float_dtype).max)
m_low, m_high = _modify_low_high(
low, high, ranges_floats[0], ranges_floats[1], -9, 9, dtype
_uniform_random_(
torch.view_as_real(result) if dtype in _COMPLEX_TYPES else result, low, high
)
result = torch.empty(shape, device=device, dtype=dtype)
result_real = torch.view_as_real(result)
_uniform_random(result_real, m_low, m_high)
else:
raise TypeError(
f"The requested dtype '{dtype}' is not supported by torch.testing.make_tensor()."
" To request support, file an issue at: https://github.com/pytorch/pytorch/issues"
)

assert not (noncontiguous and memory_format is not None)
if noncontiguous and result.numel() > 1:
result = torch.repeat_interleave(result, 2, dim=-1)
result = result[..., ::2]
elif memory_format is not None:
result = result.clone(memory_format=memory_format)

if exclude_zero:
if dtype in _integral_types or dtype is torch.bool:
replace_with = torch.tensor(1, device=device, dtype=dtype)
elif dtype in _floating_types:
replace_with = torch.tensor(
torch.finfo(dtype).tiny, device=device, dtype=dtype
)
else: # dtype in _complex_types:
float_dtype = complex_to_corresponding_float_type_map[dtype]
float_eps = torch.tensor(
torch.finfo(float_dtype).tiny, device=device, dtype=float_dtype
)
replace_with = torch.complex(float_eps, float_eps)
result[result == 0] = replace_with
result[result == 0] = (
1 if dtype in _BOOLEAN_OR_INTEGRAL_TYPES else torch.finfo(dtype).tiny
)

if dtype in _floating_types + _complex_types:
if dtype in _FLOATING_OR_COMPLEX_TYPES:
result.requires_grad = requires_grad

return result
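A hedged usage sketch of the refactored make_tensor (assuming the behavior shown in the diff above; values are random):

import torch
from torch.testing import make_tensor

# Integral dtypes draw from [-9, 10) by default; for torch.uint8 the low end is clamped to 0.
a = make_tensor((2, 3), dtype=torch.int32, device="cpu")

# exclude_zero=True replaces zeros: with 1 for bool/integral dtypes,
# and with torch.finfo(dtype).tiny for floating and complex dtypes.
b = make_tensor((4,), dtype=torch.float64, device="cpu", exclude_zero=True)
assert not (b == 0).any()

# noncontiguous=True (without memory_format) gives a non-contiguous result
# whenever the tensor has more than one element.
c = make_tensor((2, 3), dtype=torch.float32, device="cpu", noncontiguous=True)
assert not c.is_contiguous()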
7 changes: 7 additions & 0 deletions torch/testing/_internal/common_dtype.py
@@ -118,3 +118,10 @@ def get_all_fp_dtypes(include_half=True, include_bfloat16=True) -> List[torch.dt

def get_all_qint_dtypes() -> List[torch.dtype]:
return [torch.qint8, torch.quint8, torch.qint32, torch.quint4x2, torch.quint2x4]


float_to_corresponding_complex_type_map = {
torch.float16: torch.complex32,
torch.float32: torch.complex64,
torch.float64: torch.complex128,
}
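For completeness, a small sketch (assuming the new location shown above) of how the relocated map is used by the updated test imports:

import torch
from torch.testing._internal.common_dtype import float_to_corresponding_complex_type_map

assert float_to_corresponding_complex_type_map[torch.float32] is torch.complex64
assert float_to_corresponding_complex_type_map[torch.float64] is torch.complex128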