[numpy] torch.cos, torch.tan : promote integer inputs to float #46706

Closed
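For context, the user-visible effect of this PR: torch.cos and torch.tan now follow NumPy and promote integer (and bool) inputs to a floating-point result instead of erroring out. A minimal sketch of the expected behavior after the change (dtypes assume the default float32 scalar type; printed values illustrative):

    import numpy as np
    import torch

    t = torch.arange(4, dtype=torch.int64)

    # Integer and bool inputs are promoted to the default floating dtype,
    # mirroring NumPy's behavior for these ufuncs.
    print(torch.cos(t).dtype)          # torch.float32
    print(torch.tan(t).dtype)          # torch.float32

    # NumPy reference: integer inputs come back as float64.
    print(np.cos(np.arange(4)).dtype)  # float64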
8 changes: 4 additions & 4 deletions aten/src/ATen/native/UnaryOps.cpp
@@ -335,8 +335,8 @@ Tensor& sin_out(Tensor& result, const Tensor& self) { return unary_op_impl_float
 Tensor sin(const Tensor& self) { return unary_op_impl_float(self, sin_stub); }
 Tensor& sin_(Tensor& self) { return unary_op_impl_(self, at::sin_out); }
 
-Tensor& cos_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, cos_stub); }
-Tensor cos(const Tensor& self) { return unary_op_impl(self, at::cos_out); }
+Tensor& cos_out(Tensor& result, const Tensor& self) { return unary_op_impl_float_out(result, self, cos_stub); }
+Tensor cos(const Tensor& self) { return unary_op_impl_float(self, cos_stub); }
 Tensor& cos_(Tensor& self) { return unary_op_impl_(self, at::cos_out); }
 
 Tensor& sinh_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, sinh_stub); }
@@ -448,8 +448,8 @@ Tensor& tanh_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(
 Tensor tanh(const Tensor& self) { return unary_op_impl(self, at::tanh_out); }
 Tensor& tanh_(Tensor& self) { return unary_op_impl_(self, at::tanh_out); }
 
-Tensor& tan_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, tan_stub); }
-Tensor tan(const Tensor& self) { return unary_op_impl(self, at::tan_out); }
+Tensor& tan_out(Tensor& result, const Tensor& self) { return unary_op_impl_float_out(result, self, tan_stub); }
+Tensor tan(const Tensor& self) { return unary_op_impl_float(self, tan_stub); }
 Tensor& tan_(Tensor& self) { return unary_op_impl_(self, at::tan_out); }
 
 Tensor& trunc_out(Tensor& result, const Tensor& self) {
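Switching cos/cos_out and tan/tan_out from unary_op_impl/unary_op_impl_out to the unary_op_impl_float variants is what routes integral inputs through a floating-point computation. One consequence worth noting (a hedged sketch; the exact error message may differ): the functional form now accepts integer tensors, while the in-place form still cannot, because a float result cannot be written back into an integer tensor.

    import torch

    t = torch.tensor([0, 1, 2], dtype=torch.int32)

    # Functional form: result is computed and returned in a floating dtype.
    print(torch.cos(t).dtype)   # torch.float32

    # In-place form: expected to raise, since the promoted float result
    # cannot be stored back into the int32 tensor.
    try:
        t.cos_()
    except RuntimeError as e:
        print("cos_ raised:", e)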
4 changes: 2 additions & 2 deletions aten/src/ATen/native/cuda/UnaryGeometricKernels.cu
@@ -43,7 +43,7 @@ void sin_kernel_cuda(TensorIterator& iter) {
 }
 
 void cos_kernel_cuda(TensorIterator& iter) {
-  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.dtype(), "cos_cuda", [&]() {
+  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.common_dtype(), "cos_cuda", [&]() {
     gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
       return ::cos(a);
     });
@@ -99,7 +99,7 @@ void atanh_kernel_cuda(TensorIterator& iter) {
 }
 
 void tan_kernel_cuda(TensorIterator& iter) {
-  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(ScalarType::Half, iter.dtype(), "tan_cuda", [&]() {
+  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(ScalarType::Half, iter.common_dtype(), "tan_cuda", [&]() {
     gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
       return ::tan(a);
     });
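The CUDA kernels now dispatch on iter.common_dtype() rather than iter.dtype(): with promotion in play, the computation dtype is the promoted floating type, which can differ from the dtype of any single operand (for example the out= tensor). A quick hedged check that the CUDA path matches the CPU behavior, guarded so it only runs when a GPU is present:

    import torch

    if torch.cuda.is_available():
        t = torch.arange(4, dtype=torch.int64, device="cuda")
        # Integer inputs go through the promoted floating-point code path
        # on CUDA as well.
        print(torch.cos(t).dtype, torch.tan(t).dtype)  # torch.float32 torch.float32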
4 changes: 2 additions & 2 deletions test/test_type_promotion.py
@@ -919,8 +919,8 @@ def test_unary_op_out_casting(self, device, dtypes):
         t = torch.tensor((1), dtype=dtypes[0], device=device)
         out = torch.empty(0, dtype=dtypes[1], device=device)
 
-        ops = (torch.neg, torch.floor, torch.ceil, torch.cos, torch.erf, torch.log)
-        float_only_ops = {torch.floor, torch.ceil, torch.cos, torch.erf, torch.log}
+        ops = (torch.neg, torch.floor, torch.ceil, torch.erf, torch.log)
+        float_only_ops = {torch.floor, torch.ceil, torch.erf, torch.log}
         real_only_ops = {torch.floor, torch.ceil, torch.erf}
         for op in ops:
             if dtypes[0] is not dtypes[1]:
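cos is dropped from ops and float_only_ops because it is no longer "float only": an integer input with a floating out= tensor is now expected to succeed rather than raise. A hedged sketch of the behavior the old entry was asserting against (output value illustrative):

    import torch

    t = torch.tensor(1, dtype=torch.long)
    out = torch.empty(0, dtype=torch.float32)

    # Previously an integer input raised even with a float out= tensor;
    # with integer->float promotion this call should now succeed.
    torch.cos(t, out=out)
    print(out)   # tensor(0.5403)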
2 changes: 1 addition & 1 deletion test/test_unary_ufuncs.py
@@ -288,7 +288,7 @@ def test_reference_numerics(self, device, dtype, op):
                 # NOTE: For these dtypes, PyTorch computes in the default scalar type (float)
                 # while NumPy computes in float16
                 self.assertEqualHelper(actual, expected, msg, dtype=dtype,
-                                       exact_dtype=exact_dtype, rtol=1e-4, atol=1e-3)
+                                       exact_dtype=exact_dtype, rtol=1e-3, atol=1e-2)
                 continue
 
             self.assertEqualHelper(actual, expected, msg, dtype=dtype, exact_dtype=exact_dtype)
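The looser rtol/atol here accounts for the case the NOTE above describes: the NumPy reference is evaluated in float16 while PyTorch computes in the default scalar type, so discrepancies on the order of float16 resolution are expected. A small standalone illustration (pure NumPy, values illustrative):

    import numpy as np

    x = np.float16(1.2345)

    # Reference evaluated in float16 vs. the same value evaluated in float32:
    # the gap is on the order of float16 resolution (~1e-3), which is roughly
    # what the relaxed tolerances need to absorb.
    ref16 = np.cos(x)                 # float16 result
    ref32 = np.cos(np.float32(x))     # float32 result
    print(abs(float(ref32) - float(ref16)))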
9 changes: 8 additions & 1 deletion torch/testing/_internal/common_methods_invocations.py
@@ -274,8 +274,11 @@ def sample_inputs(self, device, dtype, requires_grad=False):
                    test_inplace_grad=False),
     UnaryUfuncInfo('cos',
                    ref=np.cos,
-                   dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16),
+                   dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16),
+                   dtypesIfCPU=all_types_and_complex_and(torch.bool, torch.bfloat16),
+                   dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
+                   handles_large_floats=False,
+                   promotes_integers_to_float=True,
                    decorators=(precisionOverride({torch.bfloat16: 1e-2}),),
                    skips=(
                        SkipInfo('TestUnaryUfuncs', 'test_reference_numerics',
@@ -371,6 +374,10 @@ def sample_inputs(self, device, dtype, requires_grad=False):
                    )),
     UnaryUfuncInfo('tan',
                    ref=np.tan,
+                   dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16),
+                   dtypesIfCPU=all_types_and_complex_and(torch.bool, torch.bfloat16),
+                   dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half),
+                   promotes_integers_to_float=True,
                    skips=(
                        SkipInfo('TestUnaryUfuncs', 'test_reference_numerics',
                                 device_type='cuda', dtypes=[torch.cfloat, torch.cdouble]),
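The new dtypes/dtypesIfCPU/dtypesIfCUDA lists advertise integer and bool support to the OpInfo-based tests, and promotes_integers_to_float=True tells them to expect a floating result dtype rather than an exact dtype match against NumPy. A hedged sketch of the kind of comparison those generic tests end up doing:

    import numpy as np
    import torch

    t = torch.arange(5, dtype=torch.int32)

    actual = torch.tan(t)             # promoted to float32
    expected = np.tan(t.numpy())      # NumPy promotes int32 -> float64

    assert actual.dtype.is_floating_point
    assert np.allclose(actual.numpy(), expected, rtol=1e-5, atol=1e-5)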