Fix deprecated warnings for nan_to_num (#46309)
Summary:
Related to #44592

This PR fixes the deprecation warnings for the `nan_to_num` function.

Below is the warning emitted when building the latest code:
```
../aten/src/ATen/native/UnaryOps.cpp: In function ‘at::Tensor& at::native::nan_to_num_out(at::Tensor&,
const at::Tensor&, c10::optional<double>, c10::optional<double>, c10::optional<double>)’:
../aten/src/ATen/native/UnaryOps.cpp:397:45: warning: ‘bool c10::isIntegralType(c10::ScalarType)’
is deprecated: isIntegralType is deprecated.
Please use the overload with 'includeBool' parameter instead. [-Wdeprecated-declarations]
   if (c10::isIntegralType(self.scalar_type())) {
```

The deprecated warning is defined in `ScalarType.h`.
https://github.com/pytorch/pytorch/blob/d790ec6de01a61fe81733c41a64b6092bacfb7bd/c10/core/ScalarType.h#L255-L260
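
For context, the deprecation exists because the one-argument overload silently treats `Bool` as non-integral. A minimal standalone sketch of the two overloads (not the real `c10` code, which lives in `ScalarType.h` and covers the full dtype list) illustrates why call sites that want `bool` handled must opt in explicitly:

```cpp
// Simplified stand-in for c10::ScalarType; the real enum has many more entries.
enum class ScalarType { Byte, Char, Short, Int, Long, Bool, Float, Double };

// New overload: caller decides whether Bool counts as integral.
static inline bool isIntegralType(ScalarType t, bool includeBool) {
  bool isIntegral =
      (t == ScalarType::Byte || t == ScalarType::Char ||
       t == ScalarType::Short || t == ScalarType::Int ||
       t == ScalarType::Long);
  return isIntegral || (includeBool && t == ScalarType::Bool);
}

// Deprecated one-argument overload: forwards with includeBool=false,
// so Bool is silently excluded -- hence the warning at the call site.
[[deprecated("isIntegralType is deprecated. Please use the overload with "
             "'includeBool' parameter instead.")]]
static inline bool isIntegralType(ScalarType t) {
  return isIntegralType(t, /*includeBool=*/false);
}
```

Because `nan_to_num` should treat `bool` tensors like other integral tensors (copy them through unchanged), the fix below passes `/*include_bool=*/true` rather than merely silencing the warning.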

Pull Request resolved: #46309

Reviewed By: mrshenli

Differential Revision: D24310248

Pulled By: heitorschueroff

fbshipit-source-id: 0f9f2ad304eb5a2da9d2b415343f2fc9029037af
RockingJavaBean authored and facebook-github-bot committed Oct 16, 2020
1 parent ecf6335 commit 7b788d1
Showing 3 changed files with 4 additions and 3 deletions.
2 changes: 1 addition & 1 deletion aten/src/ATen/native/UnaryOps.cpp
```diff
@@ -394,7 +394,7 @@ Tensor& nan_to_num_out(
     c10::optional<double> pos_inf,
     c10::optional<double> neg_inf) {

-  if (c10::isIntegralType(self.scalar_type())) {
+  if (c10::isIntegralType(self.scalar_type(), /*include_bool=*/true)) {
    result.resize_as_(self);
    result.copy_(self);
    return result;
```
3 changes: 2 additions & 1 deletion test/test_unary_ufuncs.py
```diff
@@ -378,7 +378,8 @@ def test_batch_vs_slicing(self, device, dtype, op):

         self.assertEqual(actual, expected)

-    @dtypes(*(torch.testing.get_all_int_dtypes() + torch.testing.get_all_fp_dtypes(include_bfloat16=False)))
+    @dtypes(*(torch.testing.get_all_int_dtypes() + [torch.bool] +
+              torch.testing.get_all_fp_dtypes(include_bfloat16=False)))
     def test_nan_to_num(self, device, dtype):
         for contiguous in [False, True]:
             x = make_tensor((64, 64), low=0., high=100., dtype=dtype, device=device)
```
2 changes: 1 addition & 1 deletion torch/testing/_internal/common_methods_invocations.py
```diff
@@ -395,7 +395,7 @@ def sample_inputs(self, device, dtype, requires_grad=False):
                    dtypesIfCUDA=None),
     UnaryUfuncInfo('nan_to_num',
                    ref=np.nan_to_num,
-                   dtypes=all_types_and(torch.half),
+                   dtypes=all_types_and(torch.half, torch.bool),
                    dtypesIfCPU=None,
                    dtypesIfCUDA=None)
 ]
```
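
The test diff above uses `np.nan_to_num` as the reference implementation. For dtypes that cannot represent NaN or infinity (integers and booleans), `nan_to_num` is an identity copy, which is exactly the early-return path the C++ change routes `bool` through. A quick NumPy sketch (assuming NumPy is available) of the expected behavior:

```python
import numpy as np

# Integral and boolean inputs cannot hold NaN/inf, so nan_to_num
# just returns an unchanged copy.
x = np.array([0, 1, 2], dtype=np.int64)
b = np.array([True, False])
assert np.array_equal(np.nan_to_num(x), x)
assert np.array_equal(np.nan_to_num(b), b)

# For floating inputs, NaN maps to 0.0 and infinities to large
# finite values; finite entries pass through untouched.
f = np.array([np.nan, np.inf, -np.inf, 1.5])
out = np.nan_to_num(f)
assert out[0] == 0.0
assert np.isfinite(out).all()
assert out[3] == 1.5
```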
