QoL improvements for torch.testing.make_tensor #96125
Closed
Commits (10):
a17f8c3  QoL improvements for torch.testing.make_tensor (pmeier)
5714c22  Update on "QoL improvements for torch.testing.make_tensor" (pmeier)
273e2e8  Update on "QoL improvements for torch.testing.make_tensor" (pmeier)
5b86220  Update on "QoL improvements for torch.testing.make_tensor" (pmeier)
cb36212  Update on "QoL improvements for torch.testing.make_tensor" (pmeier)
7c7f31a  Update on "QoL improvements for torch.testing.make_tensor" (pmeier)
e706e47  Update on "QoL improvements for torch.testing.make_tensor" (pmeier)
b27ba07  Update on "QoL improvements for torch.testing.make_tensor" (pmeier)
0621e4b  Update on "QoL improvements for torch.testing.make_tensor" (pmeier)
eaa1141  Update on "QoL improvements for torch.testing.make_tensor" (pmeier)
@@ -8,18 +8,14 @@
 import torch
 
-# Used by make_tensor for generating complex tensor.
-complex_to_corresponding_float_type_map = {
-    torch.complex32: torch.float16,
-    torch.complex64: torch.float32,
-    torch.complex128: torch.float64,
-}
-float_to_corresponding_complex_type_map = {
-    v: k for k, v in complex_to_corresponding_float_type_map.items()
-}
+_INTEGRAL_TYPES = [torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64]
+_FLOATING_TYPES = [torch.float16, torch.bfloat16, torch.float32, torch.float64]
+_COMPLEX_TYPES = [torch.complex32, torch.complex64, torch.complex128]
+_BOOLEAN_OR_INTEGRAL_TYPES = [torch.bool, *_INTEGRAL_TYPES]
+_FLOATING_OR_COMPLEX_TYPES = [*_FLOATING_TYPES, *_COMPLEX_TYPES]
 
 
-def _uniform_random(t: torch.Tensor, low: float, high: float):
+def _uniform_random_(t: torch.Tensor, low: float, high: float) -> torch.Tensor:
     # uniform_ requires to-from <= std::numeric_limits<scalar_t>::max()
     # Work around this by scaling the range before and after the PRNG
     if high - low >= torch.finfo(t.dtype).max:

Review comment on lines +11 to +15: Defining them here to avoid doing it over and over again inside `make_tensor`.
@@ -73,20 +69,21 @@ def make_tensor(
         is determined based on the :attr:`dtype` (see the table above). Default: ``None``.
     requires_grad (Optional[bool]): If autograd should record operations on the returned tensor. Default: ``False``.
     noncontiguous (Optional[bool]): If `True`, the returned tensor will be noncontiguous. This argument is
-        ignored if the constructed tensor has fewer than two elements.
+        ignored if the constructed tensor has fewer than two elements. Mutually exclusive with ``memory_format``.
     exclude_zero (Optional[bool]): If ``True`` then zeros are replaced with the dtype's small positive value
         depending on the :attr:`dtype`. For bool and integer types zero is replaced with one. For floating
         point types it is replaced with the dtype's smallest positive normal number (the "tiny" value of the
         :attr:`dtype`'s :func:`~torch.finfo` object), and for complex types it is replaced with a complex number
         whose real and imaginary parts are both the smallest positive normal number representable by the complex
         type. Default ``False``.
-    memory_format (Optional[torch.memory_format]): The memory format of the returned tensor. Incompatible
-        with :attr:`noncontiguous`.
+    memory_format (Optional[torch.memory_format]): The memory format of the returned tensor. Mutually exclusive
+        with ``noncontiguous``.
 
 Raises:
-    ValueError: if ``requires_grad=True`` is passed for integral `dtype`
+    ValueError: If ``requires_grad=True`` is passed for integral `dtype`
+    ValueError: If ``low > high``.
+    ValueError: If either :attr:`low` or :attr:`high` is ``nan``.
+    ValueError: If both :attr:`noncontiguous` and :attr:`memory_format` are passed.
+    TypeError: If :attr:`dtype` isn't supported by this function.
 
 Examples:
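To illustrate the documented `exclude_zero` behavior (an example derived from the docstring above, not part of the diff):

import torch
from torch.testing import make_tensor

# Integral dtypes: zeros are replaced with one.
t = make_tensor((100,), dtype=torch.int64, device="cpu", low=0, high=2, exclude_zero=True)
assert bool((t != 0).all())

# Floating dtypes: zeros become the dtype's smallest positive normal number.
print(torch.finfo(torch.float32).tiny)  # ~1.1755e-38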
@@ -103,25 +100,37 @@ def make_tensor(
         [False, True]], device='cuda:0')
     """
 
-    def _modify_low_high(low, high, lowest, highest, default_low, default_high, dtype):
+    def modify_low_high(
+        low: Optional[float],
+        high: Optional[float],
+        *,
+        lowest_inclusive: float,
+        highest_exclusive: float,
+        default_low: float,
+        default_high: float,
+    ) -> Tuple[float, float]:
         """
-        Modifies (and raises ValueError when appropriate) low and high values given by the user (input_low, input_high) if required.
+        Modifies (and raises ValueError when appropriate) low and high values given by the user (input_low, input_high)
+        if required.
         """
 
-        def clamp(a, l, h):
+        def clamp(a: float, l: float, h: float) -> float:
             return min(max(a, l), h)
 
         low = low if low is not None else default_low
         high = high if high is not None else default_high
 
         # Checks for error cases
-        if low != low or high != high:
-            raise ValueError("make_tensor: one of low or high was NaN!")
-        if low > high:
-            raise ValueError("make_tensor: low must be weakly less than high!")
+        if math.isnan(low) or math.isnan(high):
+            raise ValueError(
+                f"`low` and `high` cannot be NaN, but got {low=} and {high=}"
+            )
+        elif low > high:
+            raise ValueError(
+                f"`low` must be weakly less than `high`, but got {low} >= {high}"
+            )
 
-        low = clamp(low, lowest, highest)
-        high = clamp(high, lowest, highest)
+        low = clamp(low, lowest_inclusive, highest_exclusive)
+        high = clamp(high, lowest_inclusive, highest_exclusive)
 
         if dtype in [torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64]:
             return math.floor(low), math.ceil(high)
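Worked through for `torch.uint8`, this shows why the single `default_low=-9` used in the next hunk is safe even though -9 is unrepresentable: the defaults are clamped into the dtype's range after substitution. (A standalone re-implementation of the clamp logic, for illustration only.)

import math

def clamp(a: float, l: float, h: float) -> float:
    return min(max(a, l), h)

# uint8: lowest_inclusive=0, highest_exclusive=255; the user passed
# low=high=None, so the defaults -9 and 10 are used, then clamped to [0, 255].
low, high = clamp(-9, 0, 255), clamp(10, 0, 255)
assert (math.floor(low), math.ceil(high)) == (0, 10)  # matches the old uint8 special case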
@@ -132,70 +141,65 @@ def clamp(a, l, h):
         shape = shape[0]  # type: ignore[assignment]
     shape = cast(Tuple[int, ...], tuple(shape))
 
-    _integral_types = [torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64]
-    _floating_types = [torch.float16, torch.bfloat16, torch.float32, torch.float64]
-    _complex_types = [torch.complex32, torch.complex64, torch.complex128]
-    if requires_grad and dtype not in _floating_types and dtype not in _complex_types:
-        raise ValueError("make_tensor: requires_grad must be False for integral dtype")
+    if noncontiguous and memory_format is not None:
+        raise ValueError(
+            f"The parameters `noncontiguous` and `memory_format` are mutually exclusive, "
+            f"but got {noncontiguous=} and {memory_format=}"
+        )
+
+    if requires_grad and dtype in _BOOLEAN_OR_INTEGRAL_TYPES:
+        raise ValueError(
+            f"`requires_grad=True` is not supported for boolean and integral dtypes, but got {dtype=}"
+        )
 
     if dtype is torch.bool:
-        result = torch.randint(0, 2, shape, device=device, dtype=dtype)  # type: ignore[call-overload]
-    elif dtype is torch.uint8:
-        ranges = (torch.iinfo(dtype).min, torch.iinfo(dtype).max)
+        result = torch.randint(0, 2, shape, device=device, dtype=dtype)
+    elif dtype in _INTEGRAL_TYPES:
         low, high = cast(
             Tuple[int, int],
-            _modify_low_high(low, high, ranges[0], ranges[1], 0, 10, dtype),
+            modify_low_high(
+                low,
+                high,
+                lowest_inclusive=torch.iinfo(dtype).min,
+                highest_exclusive=torch.iinfo(dtype).max,
+                # This is incorrect for `torch.uint8`, but since we clamp to `lowest`, i.e. 0 for `torch.uint8`,
+                # _after_ we use the default value, we don't need to special case it here
+                default_low=-9,
+                default_high=10,
+            ),
         )
-        result = torch.randint(low, high, shape, device=device, dtype=dtype)  # type: ignore[call-overload]
-    elif dtype in _integral_types:
-        ranges = (torch.iinfo(dtype).min, torch.iinfo(dtype).max)
-        low, high = _modify_low_high(low, high, ranges[0], ranges[1], -9, 10, dtype)
-        result = torch.randint(low, high, shape, device=device, dtype=dtype)  # type: ignore[call-overload]
-    elif dtype in _floating_types:
-        ranges_floats = (torch.finfo(dtype).min, torch.finfo(dtype).max)
-        m_low, m_high = _modify_low_high(
-            low, high, ranges_floats[0], ranges_floats[1], -9, 9, dtype
-        )
+        result = torch.randint(low, high, shape, device=device, dtype=dtype)
+    elif dtype in _FLOATING_OR_COMPLEX_TYPES:
+        low, high = modify_low_high(
+            low,
+            high,
+            lowest_inclusive=torch.finfo(dtype).min,
+            highest_exclusive=torch.finfo(dtype).max,
+            default_low=-9,
+            default_high=9,
+        )
         result = torch.empty(shape, device=device, dtype=dtype)
-        _uniform_random(result, m_low, m_high)
-    elif dtype in _complex_types:
-        float_dtype = complex_to_corresponding_float_type_map[dtype]
-        ranges_floats = (torch.finfo(float_dtype).min, torch.finfo(float_dtype).max)
-        m_low, m_high = _modify_low_high(
-            low, high, ranges_floats[0], ranges_floats[1], -9, 9, dtype
-        )
-        result = torch.empty(shape, device=device, dtype=dtype)
-        result_real = torch.view_as_real(result)
-        _uniform_random(result_real, m_low, m_high)
+        _uniform_random_(
+            torch.view_as_real(result) if dtype in _COMPLEX_TYPES else result, low, high
+        )
     else:
         raise TypeError(
             f"The requested dtype '{dtype}' is not supported by torch.testing.make_tensor()."
             " To request support, file an issue at: https://github.com/pytorch/pytorch/issues"
         )
 
-    assert not (noncontiguous and memory_format is not None)
     if noncontiguous and result.numel() > 1:
         result = torch.repeat_interleave(result, 2, dim=-1)
         result = result[..., ::2]
     elif memory_format is not None:
         result = result.clone(memory_format=memory_format)
 
     if exclude_zero:
-        if dtype in _integral_types or dtype is torch.bool:
-            replace_with = torch.tensor(1, device=device, dtype=dtype)
-        elif dtype in _floating_types:
-            replace_with = torch.tensor(
-                torch.finfo(dtype).tiny, device=device, dtype=dtype
-            )
-        else:  # dtype in _complex_types:
-            float_dtype = complex_to_corresponding_float_type_map[dtype]
-            float_eps = torch.tensor(
-                torch.finfo(float_dtype).tiny, device=device, dtype=float_dtype
-            )
-            replace_with = torch.complex(float_eps, float_eps)
-        result[result == 0] = replace_with
+        result[result == 0] = (
+            1 if dtype in _BOOLEAN_OR_INTEGRAL_TYPES else torch.finfo(dtype).tiny
+        )
 
-    if dtype in _floating_types + _complex_types:
+    if dtype in _FLOATING_OR_COMPLEX_TYPES:
         result.requires_grad = requires_grad
 
     return result
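For reference, the `noncontiguous` branch kept above relies on a repeat-then-stride trick; a small demonstration (not from the PR):

import torch

t = torch.arange(6).reshape(2, 3)
# Duplicate each element along the last dimension, then keep every other
# column: same values as `t`, but with stride 2, hence noncontiguous.
nc = torch.repeat_interleave(t, 2, dim=-1)[..., ::2]
assert torch.equal(nc, t) and not nc.is_contiguous()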
Review comment on the simplified `exclude_zero` handling: We don't need that since `torch.finfo` works for complex types as well.
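This is easy to verify: `torch.finfo` of a complex dtype reports the properties of the corresponding real dtype, which is what lets the PR drop the complex-to-float mapping table:

import torch

# finfo mirrors the real counterpart of a complex dtype.
assert torch.finfo(torch.complex64).tiny == torch.finfo(torch.float32).tiny
assert torch.finfo(torch.complex128).tiny == torch.finfo(torch.float64).tiny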