[primtorch] add reference for clamp_min/clamp_max #79821
Changes from all commits
`torch/_refs/__init__.py`: `clamp` is rewritten from the `minimum`/`maximum` formulation to an explicit `where` formulation:

```diff
@@ -1183,15 +1183,49 @@ def clamp(
 ) -> TensorLikeType:
     a, min, max = _maybe_broadcast(a, min, max)
 
-    if min is not None and max is not None:
-        return minimum(maximum(a, min), max)
-    if min is not None:
-        return maximum(a, min)
-    if max is not None:
-        return minimum(a, max)
-    msg = "clamp called but both min and max are none!"
-    raise ValueError(msg)
+    # NOTE: grad behavior with the `where` implementation is not consistent on `nan`
+    if min is None and max is None:
+        msg = "clamp called but both min and max are none!"
+        raise ValueError(msg)
+    if min is not None:
+        a_isnan = isnan(a)
+        condition = bitwise_or(ge(a, min), a_isnan)
+        # We should also propagate `nan` coming from the boundaries. However,
+        # that is not necessary, since `ge` is already `False` when either
+        # operand is `nan`, so the line below would be redundant:
+        # `condition = bitwise_and(condition, bitwise_not(isnan(min)))`
+        a = prims.where(condition, a, min)
+    if max is not None:
+        a_isnan = isnan(a)
+        # Same as above: no need to adjust `nan` coming from `max`
+        condition = bitwise_or(le(a, max), a_isnan)
+        a = prims.where(condition, a, max)
+
+    return a
```
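To make the `nan` semantics concrete, here is a minimal sketch (not part of the PR) that mirrors the reference using public `torch` ops; `clamp_ref` is a hypothetical stand-in for the code above:

```python
import torch

def clamp_ref(a, min=None, max=None):
    # Keep `a` where it already satisfies the bound or is nan;
    # otherwise take the bound, which also propagates a nan bound.
    if min is None and max is None:
        raise ValueError("clamp called but both min and max are none!")
    if min is not None:
        a = torch.where(torch.ge(a, min) | torch.isnan(a), a, min)
    if max is not None:
        a = torch.where(torch.le(a, max) | torch.isnan(a), a, max)
    return a

a = torch.tensor([0.0, 2.0, float("nan")])
lo = torch.tensor([1.0, float("nan"), 1.0])
hi = torch.tensor([1.5, 3.0, 3.0])
print(clamp_ref(a, lo, hi))    # -> [1., nan, nan]
print(torch.clamp(a, lo, hi))  # should agree on a recent PyTorch build
```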
The hunk continues with the new `clamp_min`/`clamp_max` one-liners, which simply delegate to `clamp`:

```diff
+
+
+@out_wrapper
+@elementwise_type_promotion_wrapper(
+    type_promoting_args=("self", "min"),
+    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
+)
+def clamp_min(
+    self: TensorLikeType,
+    min: Optional[TensorOrNumberLikeType] = None,
+) -> TensorLikeType:
+    return clamp(self, min=min)
+
+
+@out_wrapper
+@elementwise_type_promotion_wrapper(
+    type_promoting_args=("self", "max"),
+    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
+)
+def clamp_max(
+    self: TensorLikeType,
+    max: Optional[TensorOrNumberLikeType] = None,
+) -> TensorLikeType:
+    return clamp(self, max=max)
 
 #
```

> **Review comment** (on `def clamp_min(`): Should these use the elementwise binary helper? (`pytorch/torch/_refs/__init__.py`, line 684 in 399b3dc.) I think it would take care of …
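Since both wrappers delegate to `clamp`, they inherit its `nan` propagation. A quick illustration against the eager ops they model (not part of the diff):

```python
import torch

x = torch.tensor([0.5, 2.0, float("nan")])
print(torch.clamp_min(x, 1.0))  # -> [1., 2., nan]
print(torch.clamp_max(x, 1.0))  # -> [0.5, 1., nan]
```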
In the OpInfo test file, numpy reference helpers are added next to the existing `_clamp_numpy`:

```diff
@@ -7506,6 +7506,14 @@ def reference_inputs_elementwise_ternary(op, device, dtype, requires_grad, *, sa
         yield SampleInput(a, args=(b, c))
 
 
+def _clamp_min_numpy(a, min=None):
+    return np.maximum(a, min)
+
+
+def _clamp_max_numpy(a, max=None):
+    return np.minimum(a, max)
+
+
 def _clamp_numpy(a, min=None, max=None):
     if min is None:
         return np.minimum(a, max)
```

> **Review comment** (on `_clamp_min_numpy`): Nice references.
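These work as references because `np.maximum`/`np.minimum` propagate `nan` from either operand, matching the clamp semantics above. A quick check (illustrative, not part of the PR):

```python
import numpy as np

a = np.array([0.0, 2.0, np.nan])
print(np.maximum(a, 1.0))                           # -> [ 1.  2. nan]
print(np.minimum(a, np.array([1.0, np.nan, 3.0])))  # -> [ 0. nan nan]
```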
@@ -10619,6 +10627,42 @@ def error_inputs_mean(op_info, device, **kwargs): | |
| 'test_reference_numerics_extremal_values', | ||
| dtypes=(torch.complex64, torch.complex128)), | ||
| )), | ||
| BinaryUfuncInfo('clamp_max', | ||
| ref=_clamp_max_numpy, | ||
| dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16), | ||
| supports_forward_ad=True, | ||
| supports_rhs_python_scalar=False, | ||
| supports_fwgrad_bwgrad=True, | ||
| rhs_make_tensor_kwargs=dict(exclude_zero=False), | ||
| skips=( | ||
| # RuntimeError: "max_elementwise_cuda" not implemented for 'ComplexFloat' | ||
| DecorateInfo(unittest.expectedFailure, | ||
| 'TestBinaryUfuncs', | ||
| 'test_type_promotion', | ||
| device_type='cuda'), | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should this just xfail on the complex dtypes? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Oh, no... that test isn't instantiated for multiple dtypes, I think. My mistake. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No worries. There's a single test there I checked a few other places where similar expectedFailure is placed and I think we are good this time 🤞 |
||
| # dispatch to lazy test failed | ||
| DecorateInfo(unittest.expectedFailure, 'TestLazyOpInfo', 'test_dispatched_to_lazy'), | ||
| # test error disabled since rhs non-tensor python scalar is supported | ||
| DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_errors'), | ||
| )), | ||
| BinaryUfuncInfo('clamp_min', | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Great OpInfo additions |
||
| ref=_clamp_min_numpy, | ||
| dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16), | ||
| supports_forward_ad=True, | ||
| supports_rhs_python_scalar=False, | ||
| supports_fwgrad_bwgrad=True, | ||
| rhs_make_tensor_kwargs=dict(exclude_zero=False), | ||
| skips=( | ||
| # RuntimeError: "min_elementwise_cuda" not implemented for 'ComplexFloat' | ||
| DecorateInfo(unittest.expectedFailure, | ||
| 'TestBinaryUfuncs', | ||
| 'test_type_promotion', | ||
| device_type='cuda'), | ||
| # dispatch to lazy test failed | ||
| DecorateInfo(unittest.expectedFailure, 'TestLazyOpInfo', 'test_dispatched_to_lazy'), | ||
| # test error disabled since rhs non-tensor python scalar is supported | ||
| DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_errors'), | ||
| )), | ||
| BinaryUfuncInfo('mul', | ||
| aliases=('multiply',), | ||
| dtypes=all_types_and_complex_and(torch.chalf, torch.float16, torch.bfloat16, torch.bool), | ||
|
|
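The `ref=` hookup lets the binary-ufunc tests compare the torch op against the numpy helper elementwise. A hand-rolled version of that check might look like this (illustrative only, not the actual test harness):

```python
import numpy as np
import torch

a = torch.tensor([0.0, 2.0, float("nan")])
b = torch.tensor([1.0, 1.0, 1.0])

actual = torch.clamp_min(a, b)
# Same computation as the _clamp_min_numpy reference above
expected = torch.from_numpy(np.maximum(a.numpy(), b.numpy()))
torch.testing.assert_close(actual, expected, equal_nan=True)
```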
@@ -20660,6 +20704,24 @@ def __init__( | |
| # | ||
| # Elementwise Ternary Reference OpInfos | ||
| # | ||
| ElementwiseBinaryPythonRefInfo( | ||
| "_refs.clamp_min", | ||
| torch_opinfo_name="clamp_min", | ||
| supports_nvfuser=False, | ||
| skips=( | ||
| # test error disabled since rhs non-tensor python scalar is supported | ||
| DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_errors'), | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The same expectedFailure is under Since our ref uses those as well, should be the same root cause. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. OK but what is the root cause? (Although this is probably moot because of the proposed switch to |
||
| ), | ||
| ), | ||
| ElementwiseBinaryPythonRefInfo( | ||
| "_refs.clamp_max", | ||
| torch_opinfo_name="clamp_max", | ||
| supports_nvfuser=False, | ||
| skips=( | ||
| # test error disabled since rhs non-tensor python scalar is supported | ||
| DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_errors'), | ||
| ), | ||
| ), | ||
| PythonRefInfo( | ||
| "_refs.clamp", | ||
| torch_opinfo_name="clamp", | ||
|
|
||
> **Comment:** Highlighting this section for `nan` propagation. Tagging @ngimel.
>
> **Reply:** Yeah, looks correct.
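For reference, the behavior being signed off on, shown with eager `clamp` (a quick illustration, not from the thread):

```python
import torch

a = torch.tensor([0.5, float("nan")])
print(torch.clamp(a, min=0.0, max=1.0))  # -> [0.5, nan]: nan in the input survives clamping
```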