Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

torch.reciprocal: promote integer inputs to float #49102

Closed
wants to merge 13 commits into from
4 changes: 2 additions & 2 deletions aten/src/ATen/native/UnaryOps.cpp
Expand Up @@ -308,8 +308,8 @@ Tensor& digamma_out(Tensor& result, const Tensor& self) { return unary_op_impl_o
Tensor digamma(const Tensor& self) { return unary_op_impl(self, digamma_out); }
Tensor& digamma_(Tensor& self) { return unary_op_impl_(self, digamma_out); }

// torch.reciprocal entry points. The *_float wrappers promote integral
// (and bool) inputs to the default floating-point dtype before computing
// 1/x, so e.g. reciprocal of an int tensor yields a float tensor
// (integer->float promotion; matches np.reciprocal under promotion).
// NOTE: the diff residue duplicating the pre-change unary_op_impl_out
// versions of these two functions has been dropped; only the post-change
// definitions remain so the file has a single definition of each symbol.
Tensor& reciprocal_out(Tensor& result, const Tensor& self) { return unary_op_impl_float_out(result, self, reciprocal_stub); }
Tensor reciprocal(const Tensor& self) { return unary_op_impl_float(self, reciprocal_stub); }
// In-place variant: routes through at::reciprocal_out on self. For integer
// self this will fail at the out-variant's dtype check, since the promoted
// float result cannot be stored back into an integral tensor.
Tensor& reciprocal_(Tensor& self) { return unary_op_impl_(self, at::reciprocal_out); }

Tensor& rsqrt_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, rsqrt_stub); }
Expand Down
4 changes: 2 additions & 2 deletions aten/src/ATen/native/cpu/UnaryOpsKernel.cpp
Expand Up @@ -247,10 +247,10 @@ static void logical_not_kernel(TensorIterator& iter) {
}

// CPU kernel for torch.reciprocal: elementwise 1/x.
// Dispatches on iter.common_dtype() rather than iter.dtype() so that
// integer inputs, which are promoted to a floating-point common dtype
// (see PR: "torch.reciprocal: promote integer inputs to float"), select
// the floating-point instantiation instead of failing to dispatch.
// NOTE: diff residue (old dtype()/un-annotated-lambda lines) removed;
// this is the post-change state of the function.
static void reciprocal_kernel(TensorIterator& iter) {
  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kBFloat16, kHalf, iter.common_dtype(), "reciprocal_cpu", [&]() {
    cpu_kernel_vec(
        iter,
        // 1/0 is well-defined for IEEE floats (inf); the annotation only
        // suppresses UBSan's float-divide-by-zero diagnostic.
        [=](scalar_t a) __ubsan_ignore_float_divide_by_zero__ -> scalar_t { return static_cast<scalar_t>(1.0) / a; },
        // Vectorized path: delegate to the Vec256 reciprocal primitive.
        [=](Vec256<scalar_t> a) { return a.reciprocal(); });
  });
}
Expand Down
2 changes: 1 addition & 1 deletion aten/src/ATen/native/cuda/UnaryFractionKernels.cu
Expand Up @@ -88,7 +88,7 @@ __host__ __device__ static inline c10::complex<T> reciprocal_wrapper(c10::comple
}

void reciprocal_kernel_cuda(TensorIterator& iter) {
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.dtype(), "reciprocal_cuda", [&]() {
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.common_dtype(), "reciprocal_cuda", [&]() {
gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
return reciprocal_wrapper(a);
});
Expand Down
13 changes: 4 additions & 9 deletions test/jit/test_save_load.py
Expand Up @@ -360,16 +360,11 @@ def _helper(m, fn):
else:
fn_result = self._try_fn(fn, a, b)

if not a.is_floating_point():
# NOTE: Torchscript rewrites the module forward into
# torch.reciprocal(a) * b, but torch.reciprocal is
# not implemented for integer dtypes.
self.assertTrue(m_result, Exception)
self.assertTrue('"reciprocal_cpu" not implemented for' in str(m_result))
elif isinstance(m_result, Exception):
self.assertTrue(fn_result, Exception)
if isinstance(m_result, Exception):
self.assertTrue(isinstance(fn_result, Exception))
else:
self.assertEqual(m_result, fn_result)
if fn is torch.div or a.is_floating_point():
soulitzer marked this conversation as resolved.
Show resolved Hide resolved
self.assertEqual(m_result, fn_result)

if isinstance(b, float):
_helper(v3_module_float, historic_div_scalar_float_reciprocal)
Expand Down
7 changes: 1 addition & 6 deletions test/test_binary_ufuncs.py
Expand Up @@ -819,12 +819,7 @@ def _wrapped_rfloordiv_scalar(a):
if a == 0:
continue

if a_t.is_floating_point():
self.assertEqual(5 / a, scripted_rdiv_scalar(a_t))
else:
with self.assertRaises(RuntimeError):
scripted_rdiv_scalar(a_t)

self.assertEqual(5 / a, scripted_rdiv_scalar(a_t))

# Handles Issue 45199 (see comment above)
if a_t.is_floating_point():
Expand Down
2 changes: 1 addition & 1 deletion torch/csrc/jit/tensorexpr/kernel.cpp
Expand Up @@ -942,7 +942,7 @@ Tensor* TensorExprKernel::computeValue(const torch::jit::Value* v) {

case aten::reciprocal: {
return computeOneOperand("aten_reciprocal", v, [](const ExprHandle& a) {
return ExprHandle(1.0f) / a;
return ExprHandle(1.0f) / promoteIntegerToFloat(a);
});
} break;

Expand Down
18 changes: 14 additions & 4 deletions torch/testing/_internal/common_methods_invocations.py
Expand Up @@ -784,6 +784,20 @@ def sample_inputs(self, device, dtype, requires_grad=False):
dtypes=all_types_and(torch.half, torch.bool),
dtypesIfCPU=None,
dtypesIfCUDA=None),
UnaryUfuncInfo('reciprocal',
ref=np_unary_ufunc_integer_promotion_wrapper(np.reciprocal),
dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
dtypesIfCPU=None,
dtypesIfCUDA=None,
assert_autodiffed=True,
skip_bfloat16_grad=True,
skips=(
SkipInfo('TestUnaryUfuncs', 'test_reference_numerics',
soulitzer marked this conversation as resolved.
Show resolved Hide resolved
dtypes=[torch.cfloat, torch.cdouble]),
SkipInfo('TestUnaryUfuncs', 'test_reference_numerics',
dtypes=[torch.bfloat16]),
),
promotes_integers_to_float=True),
UnaryUfuncInfo('sqrt',
ref=np.sqrt,
domain=(0, float('inf')),
Expand Down Expand Up @@ -1165,10 +1179,6 @@ def method_tests():
('atan2', (S, S, S), ((S,),), 'broadcast_rhs'),
('atan2', (S,), ((S, S, S),), 'broadcast_lhs'),
('atan2', (S, 1, S), ((S, S),), 'broadcast_all'),
('reciprocal', torch.rand(S, S, S) + 0.1, NO_ARGS, '', (True,)),
soulitzer marked this conversation as resolved.
Show resolved Hide resolved
('reciprocal', uniform_scalar(0.1, requires_grad=True), NO_ARGS, 'scalar', (True,)),
('reciprocal', torch.randn(S, S, S, dtype=torch.cdouble) + 0.1, NO_ARGS, 'complex', (True,)),
('reciprocal', uniform_scalar(0.1j), NO_ARGS, 'complex_scalar', (True,)),
('round', (S, S, S), NO_ARGS, '', (True,)),
('round', (), NO_ARGS, 'scalar', (True,)),
('sign', (S, S, S), NO_ARGS),
Expand Down