Add error message for complex alpha and non-complex inputs #54964

Closed
zou3519 wants to merge 2 commits into gh/zou3519/349/base from gh/zou3519/349/head

Conversation

@zou3519 (Contributor) commented Mar 30, 2021

Stack from ghstack:

Previously, the following would error out with a strange error message:

```
import torch
x=torch.randn(2)
torch.rsub(x, 1, alpha=2j)

Traceback (most recent call last)
<ipython-input-2-caf2a1c03d0b> in <module>
      1 import torch
      2 x=torch.randn(2)
----> 3 torch.rsub(x, 1, alpha=2j)

RuntimeError: value cannot be converted to type float without overflow: (-0,-2)
```

This happens because the alpha check does not reject a complex `alpha` when the input tensor `x` is not complex. Instead, the error is thrown further along, in the implementation of `torch.sub`, when it coerces `alpha` to the same dtype as the input tensor:
https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp#L53

This PR fixes the bad error message by adding a new condition to the alpha check (a sketch of the intended behavior is shown below).

Test Plan:
- `pytest test/test_binary_ufuncs.py`
- NB: `add`, `sub`, and `rsub` all share the same alpha check. The test only exercises it through `torch.add`, but that should be sufficient (see the illustrative test sketch below).

Differential Revision: [D27504017](https://our.internmc.facebook.com/intern/diff/D27504017)
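For illustration, a minimal sketch of the intended behavior after this change. The exact error wording is an assumption and not quoted from the new check; only the fact that the call now fails up front with a clear `RuntimeError` comes from this PR.

```python
import torch

x = torch.randn(2)  # float32, i.e. a non-complex input tensor

# After this PR, a complex alpha paired with non-complex inputs should fail
# fast in the alpha check rather than hit the float-overflow message above.
try:
    torch.rsub(x, 1, alpha=2j)
except RuntimeError as e:
    print(e)  # descriptive message about complex alpha vs. non-complex inputs

# Complex input tensors still accept a complex alpha.
z = torch.randn(2, dtype=torch.complex64)
print(torch.rsub(z, 1, alpha=2j))  # computes 1 - 2j * z elementwise
```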
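And a hedged sketch of the kind of test the test plan describes; the test name and the `pytest.raises` pattern are illustrative assumptions, not the actual test added to `test/test_binary_ufuncs.py`.

```python
import pytest
import torch


def test_complex_alpha_with_noncomplex_inputs_raises():
    # add, sub, and rsub share the same alpha check, so exercising
    # torch.add should cover all three ops.
    x = torch.randn(2)  # non-complex inputs
    with pytest.raises(RuntimeError):
        torch.add(x, x, alpha=2j)

    # Complex inputs with a complex alpha remain valid.
    z = torch.randn(2, dtype=torch.complex64)
    torch.add(z, z, alpha=2j)
```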
@facebook-github-bot (Contributor) commented Mar 30, 2021

💊 CI failures summary and remediations

As of commit 1733ae7 (more details on the Dr. CI page):


  • 3/3 failures possibly* introduced in this PR
    • 2/3 non-scanned failure(s)

1 failure not recognized by patterns:

| Job | Step | Action |
| --- | --- | --- |
| GitHub Actions test | Unknown | 🔁 rerun |

ci.pytorch.org: 1 failed


This comment was automatically generated by Dr. CI.

zou3519 added a commit that referenced this pull request Mar 30, 2021
ghstack-source-id: 61d5b74
Pull Request resolved: #54964
@zou3519 requested a review from anjali411 on March 30, 2021 16:13
zou3519 added a commit that referenced this pull request Apr 7, 2021
ghstack-source-id: 943e7ea
Pull Request resolved: #54964
@codecov bot commented Apr 7, 2021

Codecov Report

Merging #54964 (1733ae7) into gh/zou3519/349/base (82006ba) will decrease coverage by 0.01%.
The diff coverage is 100.00%.

@@                   Coverage Diff                   @@
##           gh/zou3519/349/base   #54964      +/-   ##
=======================================================
- Coverage                77.43%   77.42%   -0.02%     
=======================================================
  Files                     1895     1895              
  Lines                   187516   187518       +2     
=======================================================
- Hits                    145196   145178      -18     
- Misses                   42320    42340      +20     

@facebook-github-bot (Contributor) commented:

@zou3519 merged this pull request in 1e70d21.

@facebook-github-bot deleted the gh/zou3519/349/head branch on April 11, 2021 14:19