
Conversation

@jjsjann123
Collaborator

Added reference implementation for the two ops;
Added opinfo tests for aten clamp_min/clamp_max;
Added opinfo reference test.

@facebook-github-bot
Contributor

facebook-github-bot commented Jun 17, 2022


✅ No Failures (0 Pending)

As of commit a741bcd (more details on the Dr. CI page):


💚 💚 Looks good so far! There are no failures yet. 💚 💚


This comment was automatically generated by Dr. CI.


torch_opinfo_name="clamp_min",
supports_nvfuser=False,
skips=(
    DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_errors'),
Collaborator Author

The same expectedFailure is under _refs.minimum & _refs.maximum, which don't have a comment explaining why.

Since our ref uses those as well, it should be the same root cause.

Collaborator

OK but what is the root cause? (Although this is probably moot because of the proposed switch to where above)

) -> TensorLikeType:
    self, min = _maybe_broadcast(self, min)

    return maximum(self, min)
Collaborator

To preserve gradient behavior (clamp doesn't spread gradients when boundary and input are the same), it's better to use where.
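A minimal sketch of the where formulation being suggested (the helper name and the pre-broadcast assumption are mine, not from the PR): on ties the condition selects the input branch, so the full gradient goes to the input rather than being split with the boundary, matching clamp.

import torch

# Hedged sketch, not the PR's implementation: a where-based clamp_min.
# Assumes a and min have already been broadcast to the same shape.
def clamp_min_where(a: torch.Tensor, min: torch.Tensor) -> torch.Tensor:
    # a >= min is True on ties, so the gradient flows only to `a` there,
    # unlike maximum, which splits the gradient 0.5/0.5 between inputs.
    return torch.where(a >= min, a, min)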

Collaborator Author

Good catch! So we should change clamp to use where as well.

clamp doesn't spread gradients when boundary and input are the same

I think clamp does propagate the gradient when input and bound are equal, but the grad behavior is indeed different from minimum/maximum:

In [31]: x = torch.ones(2, 2).requires_grad_()

In [32]: torch.clamp_max(x, torch.tensor(1.0)).sum().backward()

In [33]: x.grad
Out[33]:
tensor([[1., 1.],
        [1., 1.]])

In [34]: x = torch.ones(2, 2).requires_grad_()

In [35]: torch.minimum(x, torch.tensor(1.0)).sum().backward()

In [36]: x.grad
Out[36]:
tensor([[0.5000, 0.5000],
        [0.5000, 0.5000]])
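(minimum/maximum split the gradient evenly between the two arguments when they're equal, hence the 0.5s above.)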

    yield SampleInput(a, args=(b, c))


def _clamp_min_numpy(a, min=None):
Collaborator

Nice references

skips=(
    DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_dtypes'),
)),
BinaryUfuncInfo('clamp_min',
Collaborator

Great OpInfo additions

    type_promoting_args=("self", "min"),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def clamp_min(
Collaborator

Should these use the elementwise binary helper?

def _make_elementwise_binary_reference(

I think it would take care of out, type promotion, and _maybe_broadcast?
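For reference, a rough sketch of what going through the helper might look like (the call signature is assumed from how other binary refs are registered, not taken from this PR):

# Hypothetical registration via the helper, which would then own out=
# handling, type promotion, and broadcasting via _maybe_broadcast:
clamp_min = _make_elementwise_binary_reference(
    prims.maximum,  # assumed underlying prim, for the sketch only
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)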

@albanD albanD added the triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module label Jun 21, 2022
@ngimel
Collaborator

ngimel commented Jun 23, 2022

Yeah, when there are 2 tensors you can't have a single where and still reproduce correct nan propagation behavior; it requires a couple of wheres, or maybe where and isnan.
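A toy repro of the point (illustrative values, not from the thread):

import torch

a = torch.tensor([float("nan"), 2.0])
m = torch.tensor([1.0, float("nan")])

# A single where drops the nan coming from `a`: nan >= 1.0 is False, so the
# else branch picks the boundary value instead.
print(torch.where(a >= m, a, m))                     # tensor([1., nan])
# Or-ing in isnan(a) restores propagation from `a`; a nan in `m` already
# propagates through the else branch.
print(torch.where((a >= m) | torch.isnan(a), a, m))  # tensor([nan, nan])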

@jjsjann123
Collaborator Author

Yeah, when there are 2 tensors you can't have a single where and still reproduce correct nan propagation behavior; it requires a couple of wheres, or maybe where and isnan.

Recording some offline discussion for my own sake.

  1. We want correct nan propagation; I need to fix that.
  2. We want extremal propagation properly tested, which the current opinfo tests I put there don't cover (otherwise they should fail); double-check that (see the quick check below).
  3. Correct grad backward behavior is just nice-to-have; don't worry too much about that (short term, backward is traced at the torch level, so we'll have clamp_xxx_backward). Having said that, we also shouldn't fall back to the maximum/minimum implementation.
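A quick manual check of the extremal handling mentioned in item 2 (illustrative values, mine):

import torch

x = torch.tensor([float("nan"), float("inf"), float("-inf")])
# clamp_min should pass nan through from the input and leave inf untouched.
print(torch.clamp_min(x, 0.0))  # tensor([nan, inf, 0.])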

@jjsjann123
Collaborator Author

Yeah, when there are 2 tensors you can't have a single where and still reproduce correct nan propagation behavior; it requires a couple of wheres, or maybe where and isnan.

I was dumb and forgot that nan propagation on forward was already fixed a few days ago. Did a quick refactor to remove the redundant where/isnan, verified the opinfo test for nan propagation, and put a note there on the possibly breaking gradient behavior.

This PR should be good to go for now. @ngimel @mruberry

 if min is not None:
-    return maximum(a, min)
+    a_isnan = isnan(a)
+    condition = bitwise_or(ge(a, min), a_isnan)
Collaborator Author

Highlight this section for nan propagation. Tagging @ngimel

Collaborator

Yeah looks correct

supports_nvfuser=False,
skips=(
    # test error disabled since rhs non-tensor python scalar is supported
    DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_python_ref_errors'),
Collaborator

Convert these skips to xfails so that when the issue is fixed we know to enable the test.
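For concreteness, the conversion is a one-line decorator swap; the entry below mirrors the skip in the hunk above:

# before: stays silent forever, even after the underlying issue is fixed
DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_python_ref_errors'),
# after: turns into an unexpected-success failure once fixed, so we notice
DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_errors'),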

Collaborator

@mruberry mruberry left a comment


Cool, but swap the skips for xfails.

DecorateInfo(unittest.expectedFailure,
             'TestBinaryUfuncs',
             'test_type_promotion',
             device_type='cuda'),
Collaborator

Should this just xfail on the complex dtypes?

Collaborator

Oh, no... that test isn't instantiated for multiple dtypes, I think. My mistake.

Collaborator Author

No worries. There's a single test there: test_type_promotion_clamp_max_cuda (__main__.TestBinaryUfuncsCUDA) ... expected failure.

I checked a few other places where a similar expectedFailure is placed, and I think we are good this time 🤞

rhs_make_tensor_kwargs=dict(exclude_zero=False),
skips=(
    # clamp_max supports two tensor input with bool, but not a bool scalar
    DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_dtypes'),
Collaborator

Some sample inputs failing for a dtype shouldn't result in a test failure, so what's going on here?

Collaborator Author

This test was complaining about missing torch.bool and torch.float16 in dtypes.

======================================================================
FAIL: test_dtypes_clamp_max_cpu (__main__.TestCommonCPU)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/opt/pytorch/pytorch/torch/testing/_internal/common_device_type.py", line 377, in instantiated_test
    result = test(self, **param_kwargs)
  File "/opt/pytorch/pytorch/torch/testing/_internal/common_device_type.py", line 786, in test_wrapper
    return test(*args, **kwargs)
  File "/opt/pytorch/pytorch/torch/testing/_internal/common_device_type.py", line 821, in dep_fn
    return fn(slf, *args, **kwargs)
  File "/opt/pytorch/pytorch/torch/testing/_internal/common_device_type.py", line 979, in only_fn
    return fn(self, *args, **kwargs)
  File "test_ops.py", line 314, in test_dtypes
    self.fail(msg)
AssertionError: The supported dtypes for clamp_max on device type cpu are incorrect!
The following dtypes worked in forward but are not listed by the OpInfo: {torch.bool, torch.float16}.
The following dtypes worked in backward but are not listed by the OpInfo: {torch.float16}.

The comment here, # clamp_min supports two tensor input with bool, but not a bool scalar, was referring to a failure on a different test when I added torch.bool to the supported dtypes. (I think I also mistakenly set rhs_python_scalar=True then.)
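For what it's worth, a sketch of the dtype-list fix that failure is asking for (all_types_and is the usual helper in common_dtype; the exact clamp_max entry is assumed, not copied from the PR):

import torch
from torch.testing._internal.common_dtype import all_types_and

# Hypothetical: declare the dtypes forward actually supports, including the
# torch.bool and torch.float16 the test discovered.
dtypes = all_types_and(torch.bool, torch.float16, torch.bfloat16)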

@jjsjann123
Collaborator Author

@pytorchbot merge

@pytorchmergebot
Collaborator

@pytorchbot successfully started a merge job. Check the current status here

@pytorchmergebot
Collaborator

@jjsjann123 your PR has been successfully merged.

@github-actions
Contributor

Hey @jjsjann123.
You've committed this PR, but it does not have both a 'release notes: ...' and 'topics: ...' label. Please add one of each to the PR. The 'release notes: ...' label should represent the part of PyTorch that this PR changes (fx, autograd, distributed, etc) and the 'topics: ...' label should represent the kind of PR it is (not user facing, new feature, bug fix, perf improvement, etc). The list of valid labels can be found here for the 'release notes: ...' and here for the 'topics: ...'.
For changes that are 'topic: not user facing' there is no need for a release notes label.

facebook-github-bot pushed a commit that referenced this pull request Jun 30, 2022
Summary:
Added reference implementation for the two ops;
Added opinfo tests for aten clamp_min/clamp_max;
Added opinfo reference test.

Pull Request resolved: #79821
Approved by: https://github.com/mruberry

Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/c28315eab851b9d126457738d73deae0cccfc2bc

Reviewed By: b0noI

Differential Revision: D37523050

fbshipit-source-id: e0d72fadf88700b97a577d580ecd3cfb1034101c

Labels

cla signed · Merged · open source · triaged
