[numpy] torch.cos, torch.tan : promote integer inputs to float (#46706)

Summary:
References #42515

cc: mruberry

Pull Request resolved: #46706

Reviewed By: izdeby

Differential Revision: D24537262

Pulled By: mruberry

fbshipit-source-id: e57377a625814a3f34a765ce6bfd63a33c02a5d9
kshitij12345 authored and facebook-github-bot committed Oct 29, 2020
1 parent 42a5114 commit 5c8aad1
Showing 5 changed files with 17 additions and 10 deletions.
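
The user-visible effect of the change, sketched below in a minimal example (assumes a build containing this commit and the default scalar type torch.float32): integer and boolean inputs to torch.cos and torch.tan are now promoted to a floating point dtype, matching NumPy's convention of returning floating results for integer inputs.

    import torch

    # Before this commit, integer inputs to these ops were not promoted;
    # after it they compute in the default floating point dtype.
    t = torch.arange(5)                  # torch.int64
    print(torch.cos(t).dtype)            # torch.float32 (the default dtype)
    print(torch.tan(t).dtype)            # torch.float32

    b = torch.tensor([True, False])
    print(torch.cos(b).dtype)            # bool inputs are promoted as well
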
8 changes: 4 additions & 4 deletions aten/src/ATen/native/UnaryOps.cpp
@@ -339,8 +339,8 @@ Tensor& sin_out(Tensor& result, const Tensor& self) { return unary_op_impl_float
Tensor sin(const Tensor& self) { return unary_op_impl_float(self, sin_stub); }
Tensor& sin_(Tensor& self) { return unary_op_impl_(self, at::sin_out); }

-Tensor& cos_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, cos_stub); }
-Tensor cos(const Tensor& self) { return unary_op_impl(self, at::cos_out); }
+Tensor& cos_out(Tensor& result, const Tensor& self) { return unary_op_impl_float_out(result, self, cos_stub); }
+Tensor cos(const Tensor& self) { return unary_op_impl_float(self, cos_stub); }
Tensor& cos_(Tensor& self) { return unary_op_impl_(self, at::cos_out); }

Tensor& sinh_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, sinh_stub); }
@@ -452,8 +452,8 @@ Tensor& tanh_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(
Tensor tanh(const Tensor& self) { return unary_op_impl(self, at::tanh_out); }
Tensor& tanh_(Tensor& self) { return unary_op_impl_(self, at::tanh_out); }

-Tensor& tan_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, tan_stub); }
-Tensor tan(const Tensor& self) { return unary_op_impl(self, at::tan_out); }
+Tensor& tan_out(Tensor& result, const Tensor& self) { return unary_op_impl_float_out(result, self, tan_stub); }
+Tensor tan(const Tensor& self) { return unary_op_impl_float(self, tan_stub); }
Tensor& tan_(Tensor& self) { return unary_op_impl_(self, at::tan_out); }

Tensor& trunc_out(Tensor& result, const Tensor& self) {
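In ATen, the switch from unary_op_impl / unary_op_impl_out to unary_op_impl_float / unary_op_impl_float_out is what opts cos and tan into integer-to-float promotion. The promotion rule those helpers follow can be sketched in Python (promoted_unary_float_dtype is a hypothetical illustration, not part of the ATen or torch API):

    import torch

    def promoted_unary_float_dtype(dtype: torch.dtype) -> torch.dtype:
        # Hypothetical sketch of the rule: integral and bool inputs compute in
        # the default scalar type; floating and complex inputs keep their dtype.
        if dtype in (torch.bool, torch.uint8, torch.int8, torch.int16,
                     torch.int32, torch.int64):
            return torch.get_default_dtype()
        return dtype

    assert promoted_unary_float_dtype(torch.int32) == torch.get_default_dtype()
    assert promoted_unary_float_dtype(torch.float64) == torch.float64
    assert promoted_unary_float_dtype(torch.complex64) == torch.complex64
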
4 changes: 2 additions & 2 deletions aten/src/ATen/native/cuda/UnaryGeometricKernels.cu
@@ -43,7 +43,7 @@ void sin_kernel_cuda(TensorIterator& iter) {
}

void cos_kernel_cuda(TensorIterator& iter) {
-AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.dtype(), "cos_cuda", [&]() {
+AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.common_dtype(), "cos_cuda", [&]() {
gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
return ::cos(a);
});
@@ -99,7 +99,7 @@ void atanh_kernel_cuda(TensorIterator& iter) {
}

void tan_kernel_cuda(TensorIterator& iter) {
-AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(ScalarType::Half, iter.dtype(), "tan_cuda", [&]() {
+AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(ScalarType::Half, iter.common_dtype(), "tan_cuda", [&]() {
gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
return ::tan(a);
});
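The CUDA kernels now dispatch on iter.common_dtype() rather than iter.dtype(), so once integer inputs are promoted, the kernel is instantiated for the promoted floating dtype instead of the output dtype alone. A quick sanity check on a GPU (skipped when CUDA is unavailable; dtypes assume the default scalar type):

    import torch

    # Promotion should behave the same on CUDA as on CPU after this change.
    if torch.cuda.is_available():
        t = torch.arange(8, device="cuda", dtype=torch.int64)
        assert torch.cos(t).dtype == torch.get_default_dtype()
        assert torch.tan(t).dtype == torch.get_default_dtype()
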
4 changes: 2 additions & 2 deletions test/test_type_promotion.py
@@ -919,8 +919,8 @@ def test_unary_op_out_casting(self, device, dtypes):
t = torch.tensor((1), dtype=dtypes[0], device=device)
out = torch.empty(0, dtype=dtypes[1], device=device)

-ops = (torch.neg, torch.floor, torch.ceil, torch.cos, torch.erf)
-float_only_ops = {torch.floor, torch.ceil, torch.cos, torch.erf}
+ops = (torch.neg, torch.floor, torch.ceil, torch.erf)
+float_only_ops = {torch.floor, torch.ceil, torch.erf}
real_only_ops = {torch.floor, torch.ceil, torch.erf}
for op in ops:
if dtypes[0] is not dtypes[1]:
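torch.cos is dropped from this test's op list because its out= casting behavior now follows the integer-promoting path. A reduced, purely observational sketch of the kind of probe test_unary_op_out_casting performs (it prints outcomes rather than asserting the exact error/success cases the real test checks):

    import torch

    t = torch.tensor(1, dtype=torch.int64)
    out = torch.empty(0, dtype=torch.float64)
    for op in (torch.neg, torch.floor, torch.ceil, torch.cos, torch.tan):
        try:
            op(t, out=out)
            print(op.__name__, "->", out.dtype)
        except RuntimeError as err:
            print(op.__name__, "raised:", err)
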
2 changes: 1 addition & 1 deletion test/test_unary_ufuncs.py
@@ -288,7 +288,7 @@ def test_reference_numerics(self, device, dtype, op):
# NOTE: For these dtypes, PyTorch computes in the default scalar type (float)
# while NumPy computes in float16
self.assertEqualHelper(actual, expected, msg, dtype=dtype,
-exact_dtype=exact_dtype, rtol=1e-4, atol=1e-3)
+exact_dtype=exact_dtype, rtol=1e-3, atol=1e-2)
continue

self.assertEqualHelper(actual, expected, msg, dtype=dtype, exact_dtype=exact_dtype)
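The reference-numerics test loosens its tolerances for the dtypes where PyTorch computes in the default scalar type while NumPy computes in float16. A small sketch of such a comparison against the NumPy reference (assumes a CUDA device, where torch.cos supports float16; input values are arbitrary):

    import numpy as np
    import torch

    if torch.cuda.is_available():
        x = torch.linspace(-3, 3, 50, dtype=torch.float16, device="cuda")
        actual = torch.cos(x).cpu().numpy()
        expected = np.cos(x.cpu().numpy())   # NumPy computes this in float16
        np.testing.assert_allclose(actual, expected, rtol=1e-3, atol=1e-2)
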
9 changes: 8 additions & 1 deletion torch/testing/_internal/common_methods_invocations.py
@@ -274,8 +274,11 @@ def sample_inputs(self, device, dtype, requires_grad=False):
test_inplace_grad=False),
UnaryUfuncInfo('cos',
ref=np.cos,
-dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16),
+dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16),
+dtypesIfCPU=all_types_and_complex_and(torch.bool, torch.bfloat16),
+dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
handles_large_floats=False,
+promotes_integers_to_float=True,
decorators=(precisionOverride({torch.bfloat16: 1e-2}),),
skips=(
SkipInfo('TestUnaryUfuncs', 'test_reference_numerics',
@@ -380,6 +383,10 @@ def sample_inputs(self, device, dtype, requires_grad=False):
)),
UnaryUfuncInfo('tan',
ref=np.tan,
+dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16),
+dtypesIfCPU=all_types_and_complex_and(torch.bool, torch.bfloat16),
+dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half),
+promotes_integers_to_float=True,
skips=(
SkipInfo('TestUnaryUfuncs', 'test_reference_numerics',
device_type='cuda', dtypes=[torch.cfloat, torch.cdouble]),
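The OpInfo entries for cos and tan now list integral and bool dtypes and set promotes_integers_to_float=True, so the generic unary-ufunc tests exercise those inputs and expect a floating result. The property that flag encodes can be checked in isolation (a standalone sketch, not the OpInfo test machinery itself):

    import torch

    for op in (torch.cos, torch.tan):
        for dtype in (torch.bool, torch.uint8, torch.int32, torch.int64):
            x = torch.ones(4, dtype=dtype)
            # Integral and bool inputs should produce the default floating dtype.
            assert op(x).dtype == torch.get_default_dtype()
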
