diff --git a/aten/src/ATen/native/UnaryOps.cpp b/aten/src/ATen/native/UnaryOps.cpp
index daea7e7f68bb..900f5ee72f7a 100644
--- a/aten/src/ATen/native/UnaryOps.cpp
+++ b/aten/src/ATen/native/UnaryOps.cpp
@@ -343,8 +343,8 @@ Tensor& cos_out(Tensor& result, const Tensor& self) { return unary_op_impl_float
 Tensor cos(const Tensor& self) { return unary_op_impl_float(self, cos_stub); }
 Tensor& cos_(Tensor& self) { return unary_op_impl_(self, at::cos_out); }
 
-Tensor& sinh_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, sinh_stub); }
-Tensor sinh(const Tensor& self) { return unary_op_impl(self, at::sinh_out); }
+Tensor& sinh_out(Tensor& result, const Tensor& self) { return unary_op_impl_float_out(result, self, sinh_stub); }
+Tensor sinh(const Tensor& self) { return unary_op_impl_float(self, sinh_stub); }
 Tensor& sinh_(Tensor& self) { return unary_op_impl_(self, at::sinh_out); }
 
 Tensor& cosh_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, cosh_stub); }
diff --git a/aten/src/ATen/native/cuda/UnaryGeometricKernels.cu b/aten/src/ATen/native/cuda/UnaryGeometricKernels.cu
index 46281b6573aa..2488528f5e2c 100644
--- a/aten/src/ATen/native/cuda/UnaryGeometricKernels.cu
+++ b/aten/src/ATen/native/cuda/UnaryGeometricKernels.cu
@@ -51,7 +51,7 @@ void cos_kernel_cuda(TensorIterator& iter) {
 }
 
 void sinh_kernel_cuda(TensorIterator& iter) {
-  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(ScalarType::Half, iter.dtype(), "sinh_cuda", [&]() {
+  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(ScalarType::Half, iter.common_dtype(), "sinh_cuda", [&]() {
     gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
       return ::sinh(a);
     });
diff --git a/torch/csrc/jit/tensorexpr/kernel.cpp b/torch/csrc/jit/tensorexpr/kernel.cpp
index d36f4f428c53..50f285104d95 100644
--- a/torch/csrc/jit/tensorexpr/kernel.cpp
+++ b/torch/csrc/jit/tensorexpr/kernel.cpp
@@ -1148,8 +1148,9 @@ Tensor* TensorExprKernel::computeValue(const torch::jit::Value* v) {
     } break;
 
     case aten::sinh: {
-      return computeOneOperand(
-          "aten_sinh", v, [](const ExprHandle& a) { return sinh(a); });
+      return computeOneOperand("aten_sinh", v, [](const ExprHandle& a) {
+        return sinh(promoteIntegerToFloat(a));
+      });
     } break;
 
     case aten::atan: {
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 424d6d254470..8506ec37e056 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -1,4 +1,4 @@
-from functools import reduce
+from functools import reduce, wraps
 from operator import mul, itemgetter
 import collections
 
@@ -21,7 +21,8 @@
     random_symmetric_matrix, random_symmetric_psd_matrix,
     random_symmetric_pd_matrix, make_nonzero_det,
     random_fullrank_matrix_distinct_singular_value, set_rng_seed,
-    TEST_WITH_ROCM, IS_WINDOWS, IS_MACOS, make_tensor, TEST_SCIPY)
+    TEST_WITH_ROCM, IS_WINDOWS, IS_MACOS, make_tensor, TEST_SCIPY,
+    torch_to_numpy_dtype_dict)
 
 if TEST_SCIPY:
     import scipy.special
@@ -262,6 +263,31 @@ def sample_inputs_addmm(self, device, dtype, requires_grad):
                                  low=None, high=None, requires_grad=False))),)
 
 
+def np_unary_ufunc_integer_promotion_wrapper(fn):
+    # Wrapper that passes PyTorch's default scalar
+    # type as an argument to the wrapped NumPy
+    # unary ufunc when given an integer input.
+    # This mimicks PyTorch's integer->floating point
+    # type promotion.
+    #
+    # This is necessary when NumPy promotes
+    # integer types to double, since PyTorch promotes
+    # integer types to the default scalar type.
+
+    # Helper to determine if promotion is needed
+    def is_integral(dtype):
+        return dtype in [np.bool, np.uint8, np.int8, np.int16, np.int32, np.int64]
+
+    # NOTE: Promotion in PyTorch is from integer types to the default dtype
+    np_dtype = torch_to_numpy_dtype_dict[torch.get_default_dtype()]
+
+    @wraps(fn)
+    def wrapped_fn(x):
+        if is_integral(x.dtype):
+            return fn(x, dtype=np_dtype)
+        return fn(x)
+
+    return wrapped_fn
 
 # Operator database (sorted alphabetically)
 op_db: List[Any] = [
@@ -508,8 +534,10 @@ def sample_inputs_addmm(self, device, dtype, requires_grad):
                                 dtypes=[torch.float], active_if=TEST_WITH_ROCM),
                    )),
     UnaryUfuncInfo('sinh',
-                   ref=np.sinh,
-                   dtypesIfCPU=floating_and_complex_types(),
+                   ref=np_unary_ufunc_integer_promotion_wrapper(np.sinh),
+                   dtypesIfCPU=all_types_and_complex_and(torch.bool),
+                   dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half),
+                   promotes_integers_to_float=True,
                    assert_autodiffed=True,
                    decorators=(precisionOverride({torch.float16: 1e-2}),),
                    skips=(
@@ -519,6 +547,9 @@ def sample_inputs_addmm(self, device, dtype, requires_grad):
                        SkipInfo('TestUnaryUfuncs', 'test_reference_numerics',
                                 device_type='cuda', dtypes=[torch.cfloat, torch.cdouble],
                                 active_if=IS_WINDOWS),
+                       # Reference: https://github.com/pytorch/pytorch/issues/48641
+                       SkipInfo('TestUnaryUfuncs', 'test_reference_numerics',
+                                device_type='cpu', dtypes=[torch.int8]),
                        SkipInfo('TestCommon', 'test_variant_consistency_jit',
                                 device_type='cuda', dtypes=[torch.float16]),
                    )),
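A quick end-to-end sketch of the behavior this patch targets (not part of the diff; it assumes the default dtype is torch.float32 and mirrors the np_unary_ufunc_integer_promotion_wrapper added above rather than the OpInfo machinery itself):

import numpy as np
import torch

# Integral input: with this change, sinh computes in the default scalar type
# (assumed float32 here) rather than in the input's integer dtype.
t = torch.arange(5, dtype=torch.int64)
out = torch.sinh(t)
assert out.dtype == torch.float32

# NumPy would promote int64 to float64, so the reference recomputes np.sinh
# in PyTorch's default dtype before comparing, as the wrapper does.
ref = np.sinh(t.numpy(), dtype=np.float32)
assert torch.allclose(out, torch.from_numpy(ref))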