diff --git a/aten/src/ATen/native/UnaryOps.cpp b/aten/src/ATen/native/UnaryOps.cpp
index 518032c81b04..1aebfda85da0 100644
--- a/aten/src/ATen/native/UnaryOps.cpp
+++ b/aten/src/ATen/native/UnaryOps.cpp
@@ -339,8 +339,8 @@ Tensor& sin_out(Tensor& result, const Tensor& self) { return unary_op_impl_float
 Tensor sin(const Tensor& self) { return unary_op_impl_float(self, sin_stub); }
 Tensor& sin_(Tensor& self) { return unary_op_impl_(self, at::sin_out); }
 
-Tensor& cos_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, cos_stub); }
-Tensor cos(const Tensor& self) { return unary_op_impl(self, at::cos_out); }
+Tensor& cos_out(Tensor& result, const Tensor& self) { return unary_op_impl_float_out(result, self, cos_stub); }
+Tensor cos(const Tensor& self) { return unary_op_impl_float(self, cos_stub); }
 Tensor& cos_(Tensor& self) { return unary_op_impl_(self, at::cos_out); }
 
 Tensor& sinh_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, sinh_stub); }
@@ -452,8 +452,8 @@ Tensor& tanh_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(
 Tensor tanh(const Tensor& self) { return unary_op_impl(self, at::tanh_out); }
 Tensor& tanh_(Tensor& self) { return unary_op_impl_(self, at::tanh_out); }
 
-Tensor& tan_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, tan_stub); }
-Tensor tan(const Tensor& self) { return unary_op_impl(self, at::tan_out); }
+Tensor& tan_out(Tensor& result, const Tensor& self) { return unary_op_impl_float_out(result, self, tan_stub); }
+Tensor tan(const Tensor& self) { return unary_op_impl_float(self, tan_stub); }
 Tensor& tan_(Tensor& self) { return unary_op_impl_(self, at::tan_out); }
 
 Tensor& trunc_out(Tensor& result, const Tensor& self) {
diff --git a/aten/src/ATen/native/cuda/UnaryGeometricKernels.cu b/aten/src/ATen/native/cuda/UnaryGeometricKernels.cu
index 465e54db51d6..2f7e92f3fc2e 100644
--- a/aten/src/ATen/native/cuda/UnaryGeometricKernels.cu
+++ b/aten/src/ATen/native/cuda/UnaryGeometricKernels.cu
@@ -43,7 +43,7 @@ void sin_kernel_cuda(TensorIterator& iter) {
 }
 
 void cos_kernel_cuda(TensorIterator& iter) {
-  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.dtype(), "cos_cuda", [&]() {
+  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.common_dtype(), "cos_cuda", [&]() {
     gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
       return ::cos(a);
     });
@@ -99,7 +99,7 @@ void atanh_kernel_cuda(TensorIterator& iter) {
 }
 
 void tan_kernel_cuda(TensorIterator& iter) {
-  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(ScalarType::Half, iter.dtype(), "tan_cuda", [&]() {
+  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(ScalarType::Half, iter.common_dtype(), "tan_cuda", [&]() {
     gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
       return ::tan(a);
     });
diff --git a/test/test_type_promotion.py b/test/test_type_promotion.py
index c6efa8f0d90d..c023553de402 100644
--- a/test/test_type_promotion.py
+++ b/test/test_type_promotion.py
@@ -919,8 +919,8 @@ def test_unary_op_out_casting(self, device, dtypes):
         t = torch.tensor((1), dtype=dtypes[0], device=device)
         out = torch.empty(0, dtype=dtypes[1], device=device)
 
-        ops = (torch.neg, torch.floor, torch.ceil, torch.cos, torch.erf)
-        float_only_ops = {torch.floor, torch.ceil, torch.cos, torch.erf}
+        ops = (torch.neg, torch.floor, torch.ceil, torch.erf)
+        float_only_ops = {torch.floor, torch.ceil, torch.erf}
         real_only_ops = {torch.floor, torch.ceil, torch.erf}
         for op in ops:
             if dtypes[0] is not dtypes[1]:
diff --git a/test/test_unary_ufuncs.py b/test/test_unary_ufuncs.py
index 6d4dc91ff5bd..9f3353376913 100644
--- a/test/test_unary_ufuncs.py
+++ b/test/test_unary_ufuncs.py
@@ -288,7 +288,7 @@ def test_reference_numerics(self, device, dtype, op):
                 # NOTE: For these dtypes, PyTorch computes in the default scalar type (float)
                 # while NumPy computes in float16
                 self.assertEqualHelper(actual, expected, msg, dtype=dtype,
-                                       exact_dtype=exact_dtype, rtol=1e-4, atol=1e-3)
+                                       exact_dtype=exact_dtype, rtol=1e-3, atol=1e-2)
                 continue
 
             self.assertEqualHelper(actual, expected, msg, dtype=dtype, exact_dtype=exact_dtype)
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 48b24f1ae499..9b67d31e7d16 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -274,8 +274,11 @@ def sample_inputs(self, device, dtype, requires_grad=False):
                    test_inplace_grad=False),
     UnaryUfuncInfo('cos',
                    ref=np.cos,
-                   dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16),
+                   dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16),
+                   dtypesIfCPU=all_types_and_complex_and(torch.bool, torch.bfloat16),
+                   dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
                    handles_large_floats=False,
+                   promotes_integers_to_float=True,
                    decorators=(precisionOverride({torch.bfloat16: 1e-2}),),
                    skips=(
                        SkipInfo('TestUnaryUfuncs', 'test_reference_numerics',
@@ -380,6 +383,10 @@ def sample_inputs(self, device, dtype, requires_grad=False):
                    )),
     UnaryUfuncInfo('tan',
                    ref=np.tan,
+                   dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16),
+                   dtypesIfCPU=all_types_and_complex_and(torch.bool, torch.bfloat16),
+                   dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half),
+                   promotes_integers_to_float=True,
                    skips=(
                        SkipInfo('TestUnaryUfuncs', 'test_reference_numerics',
                                 device_type='cuda', dtypes=[torch.cfloat, torch.cdouble]),
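For reference, the user-visible effect of this diff: torch.cos and torch.tan now accept integer and bool tensors and compute in a floating-point dtype, matching what torch.sin already does via unary_op_impl_float. A minimal sketch of the expected behavior, assuming a build that includes this change and the default dtype left at float32:

    import torch

    # Integer inputs are promoted to the default float dtype before the
    # kernel runs, instead of cos/tan being restricted to float inputs.
    t = torch.arange(3, dtype=torch.int64)
    print(torch.cos(t).dtype)  # torch.float32
    print(torch.tan(t).dtype)  # torch.float32

    # The out= variants dispatch on the promoted common dtype of the
    # inputs (hence iter.common_dtype() in the CUDA kernels), then cast
    # the result into out's dtype.
    out = torch.empty(3, dtype=torch.float64)
    torch.cos(t, out=out)
    print(out.dtype)  # torch.float64

The in-place variants (cos_, tan_) should still raise on integer tensors, since the promoted floating-point result cannot be cast back into an integer output.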