Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions test/test_unary_ufuncs.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
suppress_warnings, make_tensor, TEST_SCIPY, slowTest, skipIfNoSciPy,
gradcheck, IS_WINDOWS)
from torch.testing._internal.common_methods_invocations import (
unary_ufuncs)
unary_ufuncs, _NOTHING)
from torch.testing._internal.common_device_type import (
instantiate_device_type_tests, ops, dtypes, onlyCPU, onlyOnCPUAndCUDA,
onlyCUDA, dtypesIfCUDA, precisionOverride, skipCUDAIfRocm, dtypesIfCPU,
Expand All @@ -25,6 +25,11 @@
if TEST_SCIPY:
import scipy

# Refer [scipy reference filter]
# Filter operators for which the reference function
# is available in the current environment (for reference_numerics tests).
reference_filtered_ops = list(filter(lambda op: op.ref is not _NOTHING, unary_ufuncs))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a comment explaining this filter


# Tests for unary "universal functions (ufuncs)" that accept a single
# tensor and have common properties like:
# - they are elementwise functions
Expand Down Expand Up @@ -311,14 +316,14 @@ def _helper_reference_numerics(expected, actual, msg, exact_dtype, equal_nan=Tru
# 1D tensors and a large 2D tensor with interesting and extremal values
# and discontiguities.
@suppress_warnings
@ops(unary_ufuncs)
@ops(reference_filtered_ops)
def test_reference_numerics_normal(self, device, dtype, op):
tensors = generate_numeric_tensors(device, dtype,
domain=op.domain)
self._test_reference_numerics(dtype, op, tensors)

@suppress_warnings
@ops(unary_ufuncs, allowed_dtypes=floating_and_complex_types_and(
@ops(reference_filtered_ops, allowed_dtypes=floating_and_complex_types_and(
torch.bfloat16, torch.half, torch.int8, torch.int16, torch.int32, torch.int64
))
def test_reference_numerics_hard(self, device, dtype, op):
Expand All @@ -330,7 +335,8 @@ def test_reference_numerics_hard(self, device, dtype, op):
self._test_reference_numerics(dtype, op, tensors)

@suppress_warnings
@ops(unary_ufuncs, allowed_dtypes=floating_and_complex_types_and(torch.bfloat16, torch.half))
@ops(reference_filtered_ops,
allowed_dtypes=floating_and_complex_types_and(torch.bfloat16, torch.half))
def test_reference_numerics_extremal(self, device, dtype, op):
handles_extremals = (op.handles_complex_extremals if
dtype in (torch.cfloat, torch.cdouble) else op.handles_extremals)
Expand Down
Loading