
torch.reciprocal: promote integer inputs to float #49102

Closed
wants to merge 13 commits into from
4 changes: 2 additions & 2 deletions aten/src/ATen/native/UnaryOps.cpp
@@ -308,8 +308,8 @@ Tensor& digamma_out(Tensor& result, const Tensor& self) { return unary_op_impl_o
Tensor digamma(const Tensor& self) { return unary_op_impl(self, digamma_out); }
Tensor& digamma_(Tensor& self) { return unary_op_impl_(self, digamma_out); }

Tensor& reciprocal_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, reciprocal_stub); }
Tensor reciprocal(const Tensor& self) { return unary_op_impl(self, at::reciprocal_out); }
Tensor& reciprocal_out(Tensor& result, const Tensor& self) { return unary_op_impl_float_out(result, self, reciprocal_stub); }
Tensor reciprocal(const Tensor& self) { return unary_op_impl_float(self, reciprocal_stub); }
Tensor& reciprocal_(Tensor& self) { return unary_op_impl_(self, at::reciprocal_out); }

Tensor& rsqrt_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, rsqrt_stub); }
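Note: routing reciprocal and reciprocal_out through unary_op_impl_float / unary_op_impl_float_out means integral (and bool) inputs are now promoted to the default scalar type before the kernel runs. A minimal sketch of the resulting Python-level behavior, assuming the default dtype is torch.float32:

import torch

x = torch.arange(1, 5)       # int64 tensor: [1, 2, 3, 4]
y = torch.reciprocal(x)      # input promoted to the default scalar type
print(y.dtype)               # torch.float32
print(y)                     # tensor([1.0000, 0.5000, 0.3333, 0.2500])

The in-place variant reciprocal_ is expected to still reject integer tensors, since the promoted floating-point result cannot be written back into an integral tensor.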
8 changes: 4 additions & 4 deletions aten/src/ATen/native/cpu/UnaryOpsKernel.cpp
@@ -246,12 +246,12 @@ static void logical_not_kernel(TensorIterator& iter) {
});
}

static void reciprocal_kernel(TensorIterator& iter) {
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kBFloat16, kHalf, iter.dtype(), "reciprocal_cpu", [&]() {
static void reciprocal_kernel(TensorIterator& iter) __ubsan_ignore_float_divide_by_zero__ {
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kBFloat16, kHalf, iter.common_dtype(), "reciprocal_cpu", [&]() {
cpu_kernel_vec(
iter,
[=](scalar_t a) -> scalar_t { return static_cast<scalar_t>(1.0) / a; },
[=](Vec256<scalar_t> a) { return a.reciprocal(); });
[=](scalar_t a) __ubsan_ignore_float_divide_by_zero__ -> scalar_t { return static_cast<scalar_t>(1.0) / a; },
[=](Vec256<scalar_t> a) __ubsan_ignore_float_divide_by_zero__ { return a.reciprocal(); });
});
}

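Note: the dispatch now keys on iter.common_dtype() rather than iter.dtype(), because with an integral input promoted to float the iterator's input dtype no longer matches the dtype the kernel should compute in. A small sketch of the user-visible effect, assuming an explicit floating-point out= tensor is accepted for an integral input:

import torch

src = torch.tensor([1, 2, 4])               # int64 input
out = torch.empty(3, dtype=torch.float64)
torch.reciprocal(src, out=out)              # computed in the common dtype (float64)
print(out)                                  # tensor([1.0000, 0.5000, 0.2500], dtype=torch.float64)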
2 changes: 1 addition & 1 deletion aten/src/ATen/native/cuda/UnaryFractionKernels.cu
@@ -88,7 +88,7 @@ __host__ __device__ static inline c10::complex<T> reciprocal_wrapper(c10::comple
}

void reciprocal_kernel_cuda(TensorIterator& iter) {
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.dtype(), "reciprocal_cuda", [&]() {
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.common_dtype(), "reciprocal_cuda", [&]() {
gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
return reciprocal_wrapper(a);
});
2 changes: 1 addition & 1 deletion c10/util/BFloat16-inl.h
@@ -59,7 +59,7 @@ inline C10_HOST_DEVICE BFloat16 operator*(const BFloat16& a, const BFloat16& b)
return static_cast<float>(a) * static_cast<float>(b);
}

inline C10_HOST_DEVICE BFloat16 operator/(const BFloat16& a, const BFloat16& b) {
inline C10_HOST_DEVICE BFloat16 operator/(const BFloat16& a, const BFloat16& b) __ubsan_ignore_float_divide_by_zero__ {
return static_cast<float>(a) / static_cast<float>(b);
}

2 changes: 1 addition & 1 deletion c10/util/Half-inl.h
@@ -66,7 +66,7 @@ inline C10_HOST_DEVICE Half operator*(const Half& a, const Half& b) {
return static_cast<float>(a) * static_cast<float>(b);
}

inline C10_HOST_DEVICE Half operator/(const Half& a, const Half& b) {
inline C10_HOST_DEVICE Half operator/(const Half& a, const Half& b) __ubsan_ignore_float_divide_by_zero__ {
return static_cast<float>(a) / static_cast<float>(b);
}

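Note: the __ubsan_ignore_float_divide_by_zero__ annotations here and on the kernels above suppress a UBSan diagnostic for a division that is intentional: the reciprocal of zero follows IEEE-754 semantics and yields inf, and the Half/BFloat16 operators perform that division in float. A short illustration of the intended behavior (a sketch, not taken from the test suite):

import torch

for dtype in (torch.float16, torch.bfloat16, torch.float32):
    z = torch.zeros(2, dtype=dtype)
    print(dtype, torch.reciprocal(z))       # tensor([inf, inf], dtype=...)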
16 changes: 7 additions & 9 deletions test/jit/test_save_load.py
@@ -360,16 +360,14 @@ def _helper(m, fn):
else:
fn_result = self._try_fn(fn, a, b)

if not a.is_floating_point():
# NOTE: Torchscript rewrites the module forward into
# torch.reciprocal(a) * b, but torch.reciprocal isn't
# implemented for integer dtypes.
self.assertTrue(m_result, Exception)
self.assertTrue('"reciprocal_cpu" not implemented for' in str(m_result))
elif isinstance(m_result, Exception):
self.assertTrue(fn_result, Exception)
else:
if isinstance(m_result, Exception):
self.assertTrue(isinstance(fn_result, Exception))
elif fn is torch.div or a.is_floating_point():
self.assertEqual(m_result, fn_result)
else:
# Skip when fn is not torch.div and a is integral because
# historic_div_scalar_int performs floored division
pass

if isinstance(b, float):
_helper(v3_module_float, historic_div_scalar_float_reciprocal)
12 changes: 2 additions & 10 deletions test/test_binary_ufuncs.py
@@ -786,10 +786,7 @@ def _wrapped_floordiv(a, b):
def _wrapped_div_scalar(a):
return a / 5

# NOTE: this will fail when given an integer input, since
# the JIT implements division as
# torch.reciprocal(a) * 5, and reciprocal is only
# implemented for float types.
# NOTE: the JIT implements division as torch.reciprocal(a) * 5
def _wrapped_rdiv_scalar(a):
return 5 / a

@@ -819,12 +816,7 @@ def _wrapped_rfloordiv_scalar(a):
if a == 0:
continue

if a_t.is_floating_point():
self.assertEqual(5 / a, scripted_rdiv_scalar(a_t))
else:
with self.assertRaises(RuntimeError):
scripted_rdiv_scalar(a_t)

self.assertEqual(5 / a, scripted_rdiv_scalar(a_t))

# Handles Issue 45199 (see comment above)
if a_t.is_floating_point():
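Note: with integer promotion in place, the scripted 5 / a path, which the NOTE above says the JIT implements as torch.reciprocal(a) * 5, no longer needs an integral-input special case and is asserted to match eager division for every dtype. A sketch of what the simplified assertion exercises:

import torch

@torch.jit.script
def rdiv_scalar(a):
    # the JIT implements this as torch.reciprocal(a) * 5
    return 5 / a

t = torch.tensor([1, 2, 4])                 # integral input no longer raises
print(rdiv_scalar(t))                       # tensor([5.0000, 2.5000, 1.2500])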
46 changes: 0 additions & 46 deletions test/test_unary_ufuncs.py
@@ -904,52 +904,6 @@ def test_silu(self, device, dtype):
input_noncontig, inplace=True), expected_output_noncontig,
atol=atol, rtol=rtol)

# Opinfo reciprocal
@onlyCPU
@dtypes(torch.float, torch.double)
def test_reciprocal(self, device, dtype):
a = torch.randn(100, 89, device=device, dtype=dtype)
res_div = 1 / a
res_reciprocal = a.clone()
res_reciprocal.reciprocal_()
self.assertEqual(res_reciprocal, res_div)

@dtypes(torch.float, torch.double, torch.complex64, torch.complex128)
def test_reciprocal_complex(self, device, dtype):
t = torch.randn(10, 10, dtype=dtype, device=device)
expected = torch.from_numpy(np.reciprocal(t.cpu().numpy()))
actual = torch.reciprocal(t).cpu()
self.assertEqual(expected, actual)

@onlyCUDA
@dtypes(torch.complex64, torch.complex128)
def test_reciprocal_complex_extremal(self, device, dtype):
vals = (
# Inf and Zeros
complex(float('inf'), float('inf')),
complex(float('inf'), 0.),
complex(0., float('inf')),
complex(0., 0.),

# Nans and Zeros
complex(float('nan'), 0.),
complex(0., float('nan')),
complex(float('nan'), float('nan')),

# Inf and Nans
complex(float('nan'), float('inf')),
complex(float('inf'), float('nan')),

# Extremal and Normal Number
complex(float('nan'), 2.0),
complex(float('inf'), 2.0),
complex(2.0, float('nan')),
complex(2.0, float('inf')),
complex(2.0, 0.0),
complex(0.0, 2.0))

self.compare_with_numpy(torch.reciprocal, np.reciprocal, vals, device, dtype)

# do ops like threshold need a test_unary(_nonufunc) test suite?
@onlyCPU
@dtypes(*torch.testing.get_all_math_dtypes('cpu'))
5 changes: 5 additions & 0 deletions torch/_torch_docs.py
@@ -1538,6 +1538,11 @@ def merge_dicts(*dicts):

Returns a new tensor with the reciprocal of the elements of :attr:`input`

.. note::
Unlike NumPy's reciprocal, torch.reciprocal supports integral inputs. Integral
inputs to reciprocal are automatically :ref:`promoted <type-promotion-doc>` to
the default scalar type.

.. math::
\text{out}_{i} = \frac{1}{\text{input}_{i}}
""" + r"""
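Note: NumPy's np.reciprocal is not intended for integral inputs (integer "reciprocal" truncates toward zero), which is what the new doc note contrasts with. A quick illustration of the documented behavior, assuming the default scalar type is torch.float32:

import numpy as np
import torch

print(np.reciprocal(np.array([1, 2, 4])))         # array([1, 0, 0]) -- integer arithmetic truncates
print(torch.reciprocal(torch.tensor([1, 2, 4])))  # tensor([1.0000, 0.5000, 0.2500]) -- promoted to float32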
5 changes: 1 addition & 4 deletions torch/tensor.py
@@ -550,10 +550,7 @@ def __rdiv__(self, other):
from torch.overrides import has_torch_function, handle_torch_function
if type(self) is not Tensor and type(other) is not Tensor and has_torch_function(relevant_args):
return handle_torch_function(Tensor.__rdiv__, relevant_args, self, other)
if self.dtype.is_floating_point or self.dtype.is_complex:
return self.reciprocal() * other
else:
return self.to(torch.get_default_dtype()).reciprocal() * other
return self.reciprocal() * other

__rtruediv__ = __rdiv__
__itruediv__ = _C._TensorBase.__idiv__
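Note: since reciprocal itself now performs the promotion, __rdiv__ no longer needs the explicit .to(torch.get_default_dtype()) branch; self.reciprocal() * other covers every dtype. For example, in eager mode with the default dtype torch.float32:

import torch

t = torch.tensor([1, 2, 4])      # integer tensor
r = 3 / t                        # dispatches to Tensor.__rtruediv__ (aliased to __rdiv__)
print(r, r.dtype)                # tensor([3.0000, 1.5000, 0.7500]) torch.float32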
20 changes: 16 additions & 4 deletions torch/testing/_internal/common_methods_invocations.py
@@ -784,6 +784,22 @@ def sample_inputs(self, device, dtype, requires_grad=False):
dtypes=all_types_and(torch.half, torch.bool),
dtypesIfCPU=None,
dtypesIfCUDA=None),
UnaryUfuncInfo('reciprocal',
ref=np_unary_ufunc_integer_promotion_wrapper(np.reciprocal),
dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
dtypesIfCPU=None,
dtypesIfCUDA=None,
assert_autodiffed=True,
skip_bfloat16_grad=True,
promotes_integers_to_float=True,
skips=(
# Reference: https://github.com/pytorch/pytorch/issues/45690
SkipInfo('TestUnaryUfuncs', 'test_reference_numerics',
dtypes=[torch.cfloat, torch.cdouble]),
# Reference: https://github.com/pytorch/pytorch/pull/49102#issuecomment-744604601
SkipInfo('TestUnaryUfuncs', 'test_reference_numerics',
dtypes=[torch.bfloat16]),
)),
UnaryUfuncInfo('sqrt',
ref=np.sqrt,
domain=(0, float('inf')),
@@ -1165,10 +1181,6 @@ def method_tests():
('atan2', (S, S, S), ((S,),), 'broadcast_rhs'),
('atan2', (S,), ((S, S, S),), 'broadcast_lhs'),
('atan2', (S, 1, S), ((S, S),), 'broadcast_all'),
('reciprocal', torch.rand(S, S, S) + 0.1, NO_ARGS, '', (True,)),
('reciprocal', uniform_scalar(0.1, requires_grad=True), NO_ARGS, 'scalar', (True,)),
('reciprocal', torch.randn(S, S, S, dtype=torch.cdouble) + 0.1, NO_ARGS, 'complex', (True,)),
('reciprocal', uniform_scalar(0.1j), NO_ARGS, 'complex_scalar', (True,)),
('round', (S, S, S), NO_ARGS, '', (True,)),
('round', (), NO_ARGS, 'scalar', (True,)),
('sign', (S, S, S), NO_ARGS),
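Note: the new OpInfo entry replaces the old method_tests() entries and compares against NumPy through np_unary_ufunc_integer_promotion_wrapper, which has to mimic torch's integer-to-float promotion on the NumPy side. The sketch below is a hypothetical stand-in for that helper, not the actual implementation: it casts bool/integer inputs to the NumPy dtype matching torch's default scalar type before applying the ufunc.

import numpy as np
import torch

def promote_then_apply(np_ufunc):
    # hypothetical stand-in for np_unary_ufunc_integer_promotion_wrapper
    def wrapped(x):
        if x.dtype.kind in "bui":                        # bool, unsigned int, signed int
            default_np = torch.empty(0).numpy().dtype    # NumPy dtype of torch's default scalar type
            x = x.astype(default_np)
        return np_ufunc(x)
    return wrapped

ref = promote_then_apply(np.reciprocal)
print(ref(np.array([1, 2, 4])))                          # matches torch.reciprocal's promoted result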