[numpy] torch.sqrt: promote integer inputs to float #47293
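For context on the change itself, here is a minimal sketch of the promotion behavior the title describes, assuming a PyTorch build that includes this PR. The dtypes shown in the comments follow PyTorch's usual integer-to-float promotion (default float dtype), not anything quoted from the PR:

```python
import numpy as np
import torch

t = torch.arange(4, dtype=torch.int64)

# With this change, sqrt of an integer tensor returns a floating-point
# tensor (the default float dtype) instead of erroring out.
print(torch.sqrt(t).dtype)  # torch.float32

# NumPy's analogous promotion; note NumPy promotes int64 to float64,
# so output dtypes can differ between the two libraries.
print(np.sqrt(np.arange(4, dtype=np.int64)).dtype)  # float64
```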
@@ -436,7 +436,37 @@ def sample_inputs(self, device, dtype, requires_grad=False):
                    ref=np.nan_to_num,
                    dtypes=all_types_and(torch.half, torch.bool),
                    dtypesIfCPU=None,
-                   dtypesIfCUDA=None)
+                   dtypesIfCUDA=None),
+    UnaryUfuncInfo('sqrt',
+                   ref=np.sqrt,
+                   domain=(0, float('inf')),
Reviewer comment: Metanote: it's a little unfortunate that we don't have dtype-specific domains, so we can't easily test taking the sqrt of negative complex values. (One possible shape for this is sketched at the end of this page.)
+                   dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16),
+                   dtypesIfCPU=all_types_and_complex_and(torch.bool, torch.bfloat16),
+                   dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
+                   decorators=(precisionOverride({torch.bfloat16: 7e-2}),),
+                   skips=(
+                       # Reference: https://github.com/pytorch/pytorch/issues/47358
+                       SkipInfo('TestUnaryUfuncs', 'test_reference_numerics',
+                                device_type='cpu', dtypes=[torch.cfloat, torch.cdouble],
+                                active_if=IS_MACOS),
+                       # Reference: https://github.com/pytorch/pytorch/pull/47293#issuecomment-721774436
+                       SkipInfo('TestUnaryUfuncs', 'test_reference_numerics',
+                                dtypes=[torch.bfloat16]),
+                       # RuntimeError: sqrt does not support automatic differentiation for outputs with complex dtype.
+                       SkipInfo('TestGradients', 'test_fn_grad',
Reviewer comment: Metanote: if this happens again we should add a property for whether a function supports complex autograd and, when it doesn't, skip the complex autograd tests. (See the sketch at the end of this page.)
+                                dtypes=[torch.cdouble]),
+                       SkipInfo('TestGradients', 'test_fn_gradgrad',
+                                dtypes=[torch.cdouble]),
+                       SkipInfo('TestGradients', 'test_method_grad',
+                                dtypes=[torch.cdouble]),
+                       SkipInfo('TestGradients', 'test_method_gradgrad',
+                                dtypes=[torch.cdouble]),
+                       SkipInfo('TestGradients', 'test_inplace_grad',
+                                dtypes=[torch.cdouble]),
+                       SkipInfo('TestGradients', 'test_inplace_gradgrad',
+                                dtypes=[torch.cdouble]),),
+                   promotes_integers_to_float=True,
+                   handles_complex_extremals=False),
 ]

 # Common operator groupings
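As a reading aid for the entry above, here is a rough, simplified sketch of what a reference-numerics check like TestUnaryUfuncs.test_reference_numerics does with this OpInfo: sample inputs inside the declared domain, run the torch op, and compare against the NumPy `ref`. This is not the actual PyTorch harness; the helper name and sampling details are assumptions.

```python
import numpy as np
import torch

def check_against_reference(torch_fn, ref_fn, dtype, low=0.0, high=10.0, n=16):
    # Sample inside the op's declared domain; sqrt declares (0, inf),
    # so we draw from a finite sub-range for the test.
    if dtype.is_floating_point:
        t = torch.linspace(low, high, n).to(dtype)
    else:
        t = torch.arange(int(low), int(low) + n, dtype=dtype)
    expected = ref_fn(t.cpu().numpy())
    actual = torch_fn(t)
    # promotes_integers_to_float=True means the torch output dtype may
    # differ from NumPy's (float32 vs. float64), so compare values,
    # not dtypes.
    np.testing.assert_allclose(actual.cpu().numpy(), expected, rtol=1e-4)

check_against_reference(torch.sqrt, np.sqrt, torch.float32)
check_against_reference(torch.sqrt, np.sqrt, torch.int32)
```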
Reviewer comment: It's great to see another OpInfo! With this ported, the PR can also remove the TorchMathTestMeta entry for sqrt in test_torch.py (see torch_op_tests; I can't link the line because the file is too big to render on GitHub).
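Following up on the two metanotes above, a purely hypothetical sketch of what dtype-specific domains and a complex-autograd flag could look like on an OpInfo-style class. None of these names exist in the codebase as of this PR; they are illustrative assumptions only.

```python
import torch

class UnaryUfuncInfoSketch:
    """Hypothetical extension of UnaryUfuncInfo (all names are assumptions)."""

    def __init__(self, name, *,
                 domain=(None, None),
                 domain_overrides=None,   # e.g. {torch.cfloat: (float('-inf'), float('inf'))}
                 supports_complex_autograd=True,
                 **kwargs):
        self.name = name
        self.default_domain = domain
        self.domain_overrides = domain_overrides or {}
        self.supports_complex_autograd = supports_complex_autograd

    def domain_for(self, dtype):
        # A dtype-specific domain would let the reference tests exercise
        # sqrt of negative complex inputs while keeping (0, inf) for
        # real dtypes.
        return self.domain_overrides.get(dtype, self.default_domain)

def should_skip_grad_test(op, dtype):
    # Would replace the six per-test SkipInfo entries in the diff above.
    return dtype.is_complex and not op.supports_complex_autograd
```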