Changes from all commits (45 commits)
5eedd15
clamp_min/max support
jjsjann123 Jun 17, 2022
98a0fb7
clamp_min/max support
jjsjann123 Jun 17, 2022
fde6ac0
clamp_min/max support
jjsjann123 Jun 17, 2022
cce1c2c
patching tests
jjsjann123 Jun 17, 2022
60a739a
patching tests
jjsjann123 Jun 17, 2022
1ae170b
patching tests
jjsjann123 Jun 17, 2022
37aadfd
patching tests
jjsjann123 Jun 17, 2022
02036c3
patching tests
jjsjann123 Jun 17, 2022
00de23b
patching tests
jjsjann123 Jun 17, 2022
a3aa2ee
patching tests
jjsjann123 Jun 17, 2022
cbf5773
patching tests
jjsjann123 Jun 17, 2022
6a3c01b
patching tests
jjsjann123 Jun 17, 2022
d83f86d
patching tests
jjsjann123 Jun 17, 2022
5c30ed7
patching tests
jjsjann123 Jun 17, 2022
09bd9fb
patching tests
jjsjann123 Jun 17, 2022
2e385fb
Merge remote-tracking branch 'jiej/master' into HEAD
jjsjann123 Jun 22, 2022
9ceec45
updating from using minimum/maximum to where for gradient behavior
jjsjann123 Jun 22, 2022
df52486
updating clamp implementation
jjsjann123 Jun 22, 2022
4f4dd5b
fixing clamp
jjsjann123 Jun 22, 2022
58ccd36
fixing nan propagation for where
jjsjann123 Jun 22, 2022
e37c83f
fixing nan propagation for boundaries
jjsjann123 Jun 22, 2022
b308cba
fixing not
jjsjann123 Jun 22, 2022
b2590c3
not -> bitwise_not
jjsjann123 Jun 22, 2022
307f897
not caching isnan
jjsjann123 Jun 22, 2022
ba4b1e8
new line
jjsjann123 Jun 22, 2022
d3891c4
removing commented code
jjsjann123 Jun 22, 2022
1f0dc2e
try to put back test_dtypes?
jjsjann123 Jun 22, 2022
3a2f350
removing commented code
jjsjann123 Jun 22, 2022
1f7b279
Revert "removing commented code"
jjsjann123 Jun 22, 2022
0c46f96
fixing syntax
jjsjann123 Jun 22, 2022
7465190
clamp_min/max is not autodiffable
jjsjann123 Jun 22, 2022
dc53335
clamp_min/max doesn't accept scalar tensor
jjsjann123 Jun 23, 2022
e577fd1
clamp_min/max doesn't accept scalar tensor
jjsjann123 Jun 23, 2022
bb80a32
clamp_min/max disables type promotion test
jjsjann123 Jun 23, 2022
7153bde
clamp_min/max dispatch to lazy failed
jjsjann123 Jun 23, 2022
105d41e
clamp_min/max disable error tests
jjsjann123 Jun 23, 2022
1a25952
Merge commit '3afc802c5a5111' into HEAD
jjsjann123 Jun 24, 2022
1476493
code cleaning/refactoring
jjsjann123 Jun 24, 2022
3d34adc
errr, typo
jjsjann123 Jun 24, 2022
4bbf33c
lintrunner
jjsjann123 Jun 24, 2022
affe449
Merge commit '4331bc436ea' into HEAD
jjsjann123 Jun 28, 2022
a1272bc
skip -> expected failure
jjsjann123 Jun 28, 2022
1192aab
updating expected failure
jjsjann123 Jun 28, 2022
02f0128
updating supported dtypes
jjsjann123 Jun 28, 2022
a741bcd
typo
jjsjann123 Jun 28, 2022
46 changes: 40 additions & 6 deletions torch/_refs/__init__.py
@@ -1183,15 +1183,49 @@ def clamp(
) -> TensorLikeType:
a, min, max = _maybe_broadcast(a, min, max)

if min is not None and max is not None:
return minimum(maximum(a, min), max)
# NOTE: grad behavior with the `where` implementation is not consistent on `nan`
if min is None and max is None:
msg = "clamp called but both min and max are none!"
raise ValueError(msg)
if min is not None:
return maximum(a, min)
a_isnan = isnan(a)
condition = bitwise_or(ge(a, min), a_isnan)
Collaborator Author: Highlight this section for nan propagation. Tagging @ngimel

Collaborator: Yeah looks correct

# We should also propagate `nan` coming from the boundaries. However, that's
# not necessary since `ge` is already `False` when either operand is
# `nan`, so the line below would be redundant:
# `condition = bitwise_and(condition, bitwise_not(isnan(min)))`
a = prims.where(condition, a, min)
if max is not None:
return minimum(a, max)
a_isnan = isnan(a)
# same as above, no need to adjust `nan` from `max`
condition = bitwise_or(le(a, max), a_isnan)
a = prims.where(condition, a, max)

return a
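
For context (not part of the diff), here is a minimal standalone sketch of the same `where`-based formulation written against the public torch API rather than `prims`, showing that a `nan` in the input survives clamping instead of being replaced by the boundary. The helper name `clamp_where` is made up for illustration:

import torch

def clamp_where(a, min=None, max=None):
    # Illustrative re-implementation of the ref above using public torch ops.
    if min is None and max is None:
        raise ValueError("clamp called but both min and max are none!")
    if min is not None:
        # Keep `a` where it already satisfies the bound OR where it is nan,
        # so nan coming from `a` is propagated rather than clamped away.
        a = torch.where(torch.ge(a, min) | torch.isnan(a), a, min)
    if max is not None:
        a = torch.where(torch.le(a, max) | torch.isnan(a), a, max)
    return a

x = torch.tensor([float("nan"), -2.0, 0.5, 3.0])
print(clamp_where(x, min=torch.tensor(0.0), max=torch.tensor(1.0)))
# tensor([nan, 0.0000, 0.5000, 1.0000])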


@out_wrapper
@elementwise_type_promotion_wrapper(
type_promoting_args=("self", "min"),
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def clamp_min(
Collaborator: Should these use the elementwise binary helper?

def _make_elementwise_binary_reference(

I think it would take care of out and type promotion and _maybe_broadcast?

self: TensorLikeType,
min: Optional[TensorOrNumberLikeType] = None,
) -> TensorLikeType:
return clamp(self, min=min)


msg = "clamp called but both min and max are none!"
raise ValueError(msg)
@out_wrapper
@elementwise_type_promotion_wrapper(
type_promoting_args=("self", "max"),
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def clamp_max(
self: TensorLikeType,
max: Optional[TensorOrNumberLikeType] = None,
) -> TensorLikeType:
return clamp(self, max=max)
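
A rough sketch of what the suggestion above might look like if `clamp_min` were built via the binary-reference factory instead; the keyword arguments of `_make_elementwise_binary_reference` are assumed here and may not match the actual helper in torch/_refs:

# Hypothetical sketch only -- kwarg names for the factory are assumed.
def _clamp_min_elementwise(a: TensorLikeType, min: TensorLikeType) -> TensorLikeType:
    # Keep `a` where it already satisfies the bound or is nan (nan propagation).
    return prims.where(bitwise_or(ge(a, min), isnan(a)), a, min)

clamp_min = _make_elementwise_binary_reference(
    _clamp_min_elementwise,
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
    supports_rhs_python_scalar=False,
)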


#
62 changes: 62 additions & 0 deletions torch/testing/_internal/common_methods_invocations.py
@@ -7506,6 +7506,14 @@ def reference_inputs_elementwise_ternary(op, device, dtype, requires_grad, *, sa
yield SampleInput(a, args=(b, c))


def _clamp_min_numpy(a, min=None):
Collaborator: Nice references

return np.maximum(a, min)


def _clamp_max_numpy(a, max=None):
return np.minimum(a, max)
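
As a quick aside (not part of the diff), these numpy references line up with the torch ops on `nan` inputs, since `np.maximum`/`np.minimum` propagate `nan` the same way the `where`-based ref does:

import numpy as np
import torch

a = np.array([float("nan"), -2.0, 0.5, 3.0], dtype=np.float32)
print(_clamp_min_numpy(a, 0.0))                   # [nan 0.  0.5 3. ]
print(torch.clamp_min(torch.from_numpy(a), 0.0))  # tensor([nan, 0.0000, 0.5000, 3.0000])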


def _clamp_numpy(a, min=None, max=None):
if min is None:
return np.minimum(a, max)
@@ -10619,6 +10627,42 @@ def error_inputs_mean(op_info, device, **kwargs):
'test_reference_numerics_extremal_values',
dtypes=(torch.complex64, torch.complex128)),
)),
BinaryUfuncInfo('clamp_max',
ref=_clamp_max_numpy,
dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16),
supports_forward_ad=True,
supports_rhs_python_scalar=False,
supports_fwgrad_bwgrad=True,
rhs_make_tensor_kwargs=dict(exclude_zero=False),
skips=(
# RuntimeError: "max_elementwise_cuda" not implemented for 'ComplexFloat'
DecorateInfo(unittest.expectedFailure,
'TestBinaryUfuncs',
'test_type_promotion',
device_type='cuda'),
Collaborator: Should this just xfail on the complex dtypes?

Collaborator: Oh, no... that test isn't instantiated for multiple dtypes, I think. My mistake.

Collaborator Author: No worries. There's a single test there test_type_promotion_clamp_max_cuda (__main__.TestBinaryUfuncsCUDA) ... expected failure.

I checked a few other places where similar expectedFailure is placed and I think we are good this time 🤞

# dispatch to lazy test failed
DecorateInfo(unittest.expectedFailure, 'TestLazyOpInfo', 'test_dispatched_to_lazy'),
# test error disabled since rhs non-tensor python scalar is supported
DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_errors'),
)),
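
For reference, if the type-promotion test were instantiated per dtype, the reviewer's suggestion above could be expressed by passing `dtypes` to the decorator. This is hypothetical, since per the discussion the test is not parametrized over dtypes:

# Hypothetical narrower xfail -- only meaningful if the test were per-dtype.
DecorateInfo(unittest.expectedFailure,
             'TestBinaryUfuncs',
             'test_type_promotion',
             device_type='cuda',
             dtypes=(torch.complex64, torch.complex128)),
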
BinaryUfuncInfo('clamp_min',
Collaborator: Great OpInfo additions

ref=_clamp_min_numpy,
dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16),
supports_forward_ad=True,
supports_rhs_python_scalar=False,
supports_fwgrad_bwgrad=True,
rhs_make_tensor_kwargs=dict(exclude_zero=False),
skips=(
# RuntimeError: "min_elementwise_cuda" not implemented for 'ComplexFloat'
DecorateInfo(unittest.expectedFailure,
'TestBinaryUfuncs',
'test_type_promotion',
device_type='cuda'),
# dispatch to lazy test failed
DecorateInfo(unittest.expectedFailure, 'TestLazyOpInfo', 'test_dispatched_to_lazy'),
# test error disabled since rhs non-tensor python scalar is supported
DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_errors'),
)),
BinaryUfuncInfo('mul',
aliases=('multiply',),
dtypes=all_types_and_complex_and(torch.chalf, torch.float16, torch.bfloat16, torch.bool),
@@ -20660,6 +20704,24 @@ def __init__(
#
# Elementwise Ternary Reference OpInfos
#
ElementwiseBinaryPythonRefInfo(
"_refs.clamp_min",
torch_opinfo_name="clamp_min",
supports_nvfuser=False,
skips=(
# test error disabled since rhs non-tensor python scalar is supported
DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_errors'),
Collaborator Author: The same expectedFailure is under _refs.minimum & _refs.maximum, which doesn't have a comment explaining why.

Since our ref uses those as well, should be the same root cause.

Collaborator: OK but what is the root cause? (Although this is probably moot because of the proposed switch to where above)

),
),
ElementwiseBinaryPythonRefInfo(
"_refs.clamp_max",
torch_opinfo_name="clamp_max",
supports_nvfuser=False,
skips=(
# test error disabled since rhs non-tensor python scalar is supported
DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_errors'),
),
),
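
To illustrate why these error tests are expected to fail (based on the comments above, not verified against the test internals): the OpInfo sets supports_rhs_python_scalar=False, so the error test expects a failure when a Python scalar bound is passed, but the refs accept one just fine:

import torch

t = torch.randn(4)
torch._refs.clamp_min(t, 0.5)  # accepted -- no error raised, so the error test xfails
torch._refs.clamp_max(t, 0.5)  # same here
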
PythonRefInfo(
"_refs.clamp",
torch_opinfo_name="clamp",