
Add amax/amin #43092

Closed

wants to merge 29 commits into from

Conversation

@zasdfgbnm (Collaborator) commented Aug 15, 2020

Add max/min operators that only return values.

Some important decisions to discuss:

| Question | Current State |
| --- | --- |
| Expose torch.max_values to Python? | No |
| Remove max_values and only keep amax? | Yes |
| Should amax support named tensors? | Not in this PR |

NumPy compatibility

Reference: https://numpy.org/doc/stable/reference/generated/numpy.amax.html

| NumPy Parameter | PyTorch Behavior |
| --- | --- |
| `axis`: None or int or tuple of ints, optional. Axis or axes along which to operate. By default, the flattened input is used. If this is a tuple of ints, the maximum is selected over multiple axes instead of a single axis or all the axes as before. | Named `dim`; behavior same as `torch.sum` (#29137) |
| `out`: ndarray, optional. Alternative output array in which to place the result. Must be of the same shape and buffer length as the expected output. | Same |
| `keepdims`: bool, optional. If this is set to True, the axes which are reduced are left in the result as dimensions with size one. With this option, the result will broadcast correctly against the input array. | Implemented as `keepdim` |
| `initial`: scalar, optional. The minimum value of an output element. Must be present to allow computation on an empty slice. | Not implemented in this PR; better to implement for all reductions in the future. |
| `where`: array_like of bool, optional. Elements to compare for the maximum. | Not implemented in this PR; better to implement for all reductions in the future. |
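
To make the mapping concrete, here is a minimal sketch of the resulting Python surface (shapes are illustrative; the signature matches the docstring quoted later in this thread):

```python
import torch

x = torch.randn(2, 3, 4)

torch.amax(x, dim=1)                # reduce one dim; returns values only, no indices
torch.amax(x, dim=(0, 2))           # tuple of dims, like numpy's axis=(0, 2)
torch.amax(x, dim=1, keepdim=True)  # keep the reduced dim as size 1
out = torch.empty(2, 4)
torch.amax(x, dim=1, out=out)       # write into a preallocated tensor
```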

Note from numpy:

NaN values are propagated, that is if at least one item is NaN, the corresponding max value will be NaN as well. To ignore NaN values (MATLAB behavior), please use nanmax.

PyTorch has the same behavior.
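
A quick illustration of that propagation (a minimal sketch, not a test from this PR):

```python
import torch

t = torch.tensor([1.0, float("nan"), 3.0])
print(torch.amax(t, dim=0))  # tensor(nan): the NaN propagates, matching numpy.amax
```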

@dr-ci bot commented Aug 15, 2020

💊 CI failures summary and remediations

As of commit 422e870 (more details on the Dr. CI page):



🕵️ 2 new failures recognized by patterns

The following CI failures do not appear to be due to upstream breakages:

See CircleCI build pytorch_windows_vs2019_py36_cuda10.1_test2 (1/2)

Step: "Attaching workspace" (full log | diagnosis details | 🔁 rerun)

Downloading workspace layers
Downloading workspace layers
  workflows/workspaces/68841824-5f4a-4015-b835-20e9672dcb2d/0/0e1964d6-c16b-4de2-bca3-7ee68c84622e/0/105.tar.gz - 8.4 MB
Applying workspace layers
  0e1964d6-c16b-4de2-bca3-7ee68c84622e

See CircleCI build pytorch_windows_vs2019_py36_cuda10.1_test1 (2/2)

Step: "Attaching workspace" (full log | diagnosis details | 🔁 rerun)

Downloading workspace layers
Downloading workspace layers
  workflows/workspaces/68841824-5f4a-4015-b835-20e9672dcb2d/0/0e1964d6-c16b-4de2-bca3-7ee68c84622e/0/105.tar.gz - 8.4 MB
Applying workspace layers
  0e1964d6-c16b-4de2-bca3-7ee68c84622e

❄️ 1 failure tentatively classified as flaky, but reruns have not yet been triggered to confirm:

See CircleCI build pytorch_linux_xenial_cuda10_2_cudnn7_py3_gcc7_test (1/1)

Step: "Run tests" (full log | diagnosis details | 🔁 rerun) ❄️

Aug 27 09:30:36 ConnectionResetError: [Errno 104] Connection reset by peer
Aug 27 09:30:36   File "/opt/conda/lib/python3.6/multiprocessing/connection.py", line 455, in accept 
Aug 27 09:30:36     deliver_challenge(c, self._authkey) 
Aug 27 09:30:36   File "/opt/conda/lib/python3.6/multiprocessing/connection.py", line 722, in deliver_challenge 
Aug 27 09:30:36     response = connection.recv_bytes(256)        # reject large message 
Aug 27 09:30:36   File "/opt/conda/lib/python3.6/multiprocessing/connection.py", line 216, in recv_bytes 
Aug 27 09:30:36     buf = self._recv_bytes(maxlength) 
Aug 27 09:30:36   File "/opt/conda/lib/python3.6/multiprocessing/connection.py", line 407, in _recv_bytes 
Aug 27 09:30:36     buf = self._recv(4) 
Aug 27 09:30:36   File "/opt/conda/lib/python3.6/multiprocessing/connection.py", line 379, in _recv 
Aug 27 09:30:36     chunk = read(handle, remaining) 
Aug 27 09:30:36 ConnectionResetError: [Errno 104] Connection reset by peer 
Aug 27 09:30:36 /opt/conda/lib/python3.6/multiprocessing/semaphore_tracker.py:143: UserWarning: semaphore_tracker: There appear to be 14 leaked semaphores to clean up at shutdown 
Aug 27 09:30:36   len(cache)) 
Aug 27 09:30:39 Process ErrorTrackingProcess-156: 
Aug 27 09:30:39 Traceback (most recent call last): 
Aug 27 09:30:39   File "/opt/conda/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap 
Aug 27 09:30:39     self.run() 
Aug 27 09:30:39   File "/var/lib/jenkins/workspace/test/test_dataloader.py", line 361, in run 
Aug 27 09:30:39     super(ErrorTrackingProcess, self).run() 
Aug 27 09:30:39   File "/opt/conda/lib/python3.6/multiprocessing/process.py", line 93, in run 
Aug 27 09:30:39     self._target(*self._args, **self._kwargs) 


@mruberry (Collaborator)

The thinking seems good, but PRs introducing new functions typically do add the out variants and derivatives, and they should be straightforward to include.

I do have a question, though: What's the difference between min_values (max_values) and amin (amax)? Should we leave min_values and max_values unchanged and deprecate them?

@zasdfgbnm zasdfgbnm changed the title [WIP]Add amax/amin and max_values and min_values [WIP]Add amax/amin Aug 19, 2020
@zasdfgbnm zasdfgbnm changed the title [WIP]Add amax/amin Add amax/amin Aug 21, 2020
@zasdfgbnm zasdfgbnm marked this pull request as ready for review August 21, 2020 00:33
@mruberry mruberry self-requested a review August 25, 2020 23:49
@mruberry (Collaborator)

Overall looking good, but you need to update tensors.rst and torch.rst with the new doc entries.

@zasdfgbnm (Collaborator, Author)

@mruberry Yes, I missed that point. Added now.

@@ -336,6 +336,10 @@ Reduction Ops

argmax
argmin
amax
amin
max

Collaborator:

This is a better place for max and min but would you please remove them from the "Comparison Ops," too, then? People should be using maximum and minimum, anyway, for comparisons.
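
As a minimal illustration of that split (assuming current torch.maximum/torch.amax behavior):

```python
import torch

a = torch.tensor([1, 4, 2])
b = torch.tensor([3, 2, 5])

torch.maximum(a, b)   # elementwise comparison -> tensor([3, 4, 5])
torch.amax(a, dim=0)  # reduction -> tensor(4)
```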

Collaborator Author:

done

@mruberry (Collaborator) left a comment:

One last nit about removing the redundant max and min torch.rst entries under Comparison Ops.

Other than that I am satisfied. Let's let @ngimel make her final determination, too.

aten/src/ATen/native/cpu/ReduceOpsKernel.cpp
r"""
amax(input, dim, keepdim=False, *, out=None) -> Tensor

Returns the maximum value of each row of the :attr:`input` tensor in the given

Collaborator:

Given that there are potentially multiple dimensions, it's "slice", not "row".

Collaborator Author:

fixed

The difference between ``max``/``min`` and ``amax``/``amin`` is:
1. ``amax``/``amin`` supports reducing on multiple dimensions,
2. ``amax``/``amin`` does not return indices,
3. ``amax``/``amin`` produces deterministic (sub)gradients unlike

Collaborator:

max and min will produce deterministic (but different) gradients soon, so it makes sense to say that amax evenly distributes gradient between equal values, while max propagates gradient only to a single index in the source tensor.
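
A minimal sketch of that distinction (illustrative values, assuming the even-distribution behavior described above):

```python
import torch

x = torch.tensor([1.0, 3.0, 3.0], requires_grad=True)
torch.amax(x, dim=0).backward()
print(x.grad)  # gradient split evenly across the tied maxima: [0.0, 0.5, 0.5]

y = torch.tensor([1.0, 3.0, 3.0], requires_grad=True)
torch.max(y, dim=0).values.backward()
print(y.grad)  # gradient flows to a single index, e.g. [0.0, 1.0, 0.0]
```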

Collaborator Author:

fixed

@zasdfgbnm (Collaborator, Author)

Just fixed the failures in doc build and doc test.

@facebook-github-bot (Contributor) left a comment:

@mruberry has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@mruberry (Collaborator)

@zasdfgbnm Would you rebase this?

@zasdfgbnm (Collaborator, Author)

@mruberry rebased

@facebook-github-bot (Contributor) left a comment:

@mruberry has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot (Contributor)

@mruberry merged this pull request in bcec8cc.

vkuzo added a commit that referenced this pull request Sep 2, 2020
Summary:

This is to align with the naming in numpy and in
#43092

Test Plan:

```
python test/test_torch.py TestTorchDeviceTypeCPU.test_aminmax_cpu_float32
python test/test_torch.py TestTorchDeviceTypeCUDA.test_aminmax_cuda_float32
```

Reviewers:

Subscribers:

Tasks:

Tags:

[ghstack-poisoned]
vkuzo added a commit that referenced this pull request Sep 2, 2020
Summary:

This is to align with the naming in numpy and in
#43092

Test Plan:

```
python test/test_torch.py TestTorchDeviceTypeCPU.test_aminmax_cpu_float32
python test/test_torch.py TestTorchDeviceTypeCUDA.test_aminmax_cuda_float32
```

Reviewers:

Subscribers:

Tasks:

Tags:

Differential Revision: [D23465298](https://our.internmc.facebook.com/intern/diff/D23465298)

[ghstack-poisoned]
facebook-github-bot pushed a commit that referenced this pull request Sep 3, 2020
Summary:
Pull Request resolved: #44001

This is to align with the naming in numpy and in
#43092

Test Plan:
```
python test/test_torch.py TestTorchDeviceTypeCPU.test_aminmax_cpu_float32
python test/test_torch.py TestTorchDeviceTypeCUDA.test_aminmax_cuda_float32
```

Imported from OSS

Reviewed By: jerryzh168

Differential Revision: D23465298

fbshipit-source-id: b599035507156cefa53942db05f93242a21c8d06
@mruberry (Collaborator)

This was reverted because it triggered a preexisting CUDA bug. Relevant log snippet:

test_blas_empty_cuda - TestTorchDeviceTypeCUDA
test_torch.py

Traceback (most recent call last):
  File "/opt/conda/lib/python3.6/site-packages/torch/testing/_internal/common_utils.py", line 815, in wrapper
    method(*args, **kwargs)
  File "/opt/conda/lib/python3.6/site-packages/torch/testing/_internal/common_utils.py", line 815, in wrapper
    method(*args, **kwargs)
  File "/opt/conda/lib/python3.6/site-packages/torch/testing/_internal/common_device_type.py", line 257, in instantiated_test
    result = test(self, *args)
  File "/opt/conda/lib/python3.6/site-packages/torch/testing/_internal/common_device_type.py", line 479, in dep_fn
    return fn(slf, device, *args, **kwargs)
  File "test_torch.py", line 13564, in test_blas_empty
    self.assertEqual(torch.zeros((5, 6), device=device), fn(torch.mm, (5, 0), (0, 6)))
  File "/opt/conda/lib/python3.6/site-packages/torch/testing/_internal/common_utils.py", line 1124, in assertEqual
    self.assertTrue(result, msg=msg)
AssertionError: False is not true : Tensors failed to compare as equal! With rtol=1.3e-06 and atol=1e-05, found 10 element(s) (out of 30) whose difference(s) exceeded the margin of error (including 10 nan comparisons). The greatest difference was nan (0.0 vs. nan), which occurred at index (1, 0).
