[numpy] torch.{a}tanh : promote integer inputs to float #47064

8 changes: 4 additions & 4 deletions aten/src/ATen/native/UnaryOps.cpp
@@ -369,8 +369,8 @@ Tensor& arcsinh_out(Tensor& result, const Tensor& self) { return at::asinh_out(r
Tensor arcsinh(const Tensor& self) { return self.asinh(); }
Tensor& arcsinh_(Tensor& self) { return self.asinh_(); }

-Tensor& atanh_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, atanh_stub); }
-Tensor atanh(const Tensor& self) { return unary_op_impl(self, at::atanh_out); }
+Tensor& atanh_out(Tensor& result, const Tensor& self) { return unary_op_impl_float_out(result, self, atanh_stub); }
+Tensor atanh(const Tensor& self) { return unary_op_impl_float(self, atanh_stub); }
Tensor& atanh_(Tensor& self) { return unary_op_impl_(self, at::atanh_out); }

// arctanh, alias for atanh
@@ -448,8 +448,8 @@ Tensor& nan_to_num_(
return at::nan_to_num_out(self, self, nan, pos_inf, neg_inf);
}

-Tensor& tanh_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, tanh_stub); }
-Tensor tanh(const Tensor& self) { return unary_op_impl(self, at::tanh_out); }
+Tensor& tanh_out(Tensor& result, const Tensor& self) { return unary_op_impl_float_out(result, self, tanh_stub); }
+Tensor tanh(const Tensor& self) { return unary_op_impl_float(self, tanh_stub); }
Tensor& tanh_(Tensor& self) { return unary_op_impl_(self, at::tanh_out); }

Tensor& tan_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, tan_stub); }
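Routing `tanh`/`atanh` through the `unary_op_impl_float` variants means integer (and bool) inputs are computed in the default floating-point dtype, matching NumPy. A minimal behavioral sketch, assuming this patch is applied (values arbitrary):

```python
import torch
import numpy as np

t = torch.tensor([0, 1], dtype=torch.int64)

# With this patch, integer inputs dispatch through the float variants and
# come back in the default floating-point dtype.
print(torch.tanh(t).dtype)    # torch.float32
print(torch.atanh(t).dtype)   # torch.float32

# NumPy reference behavior: integers promote to float64.
print(np.tanh(np.array([0, 1])).dtype)   # float64
```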
4 changes: 2 additions & 2 deletions aten/src/ATen/native/cuda/UnaryGeometricKernels.cu
@@ -67,7 +67,7 @@ void cosh_kernel_cuda(TensorIterator& iter) {
}

void tanh_kernel_cuda(TensorIterator& iter) {
-AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.dtype(), "tanh_cuda", [&]() {
+AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.common_dtype(), "tanh_cuda", [&]() {
gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
return ::tanh(a);
});
@@ -91,7 +91,7 @@ void asinh_kernel_cuda(TensorIterator& iter) {
}

void atanh_kernel_cuda(TensorIterator& iter) {
-AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.dtype(), "atanh_cuda", [&]() {
+AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.common_dtype(), "atanh_cuda", [&]() {
gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
return ::atanh(a);
});
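Dispatching on `iter.common_dtype()` rather than `iter.dtype()` matters once input and output dtypes can differ: the CUDA kernel must be instantiated for the computation (floating) type, not the integer input type. A rough illustration of such a mixed-dtype call, assuming a CUDA device is available:

```python
import torch

# Input is int32 but the computation and output dtype is float32, so the
# kernel has to be dispatched on the iterator's common (float) dtype.
x = torch.arange(4, dtype=torch.int32, device="cuda")
out = torch.empty(4, dtype=torch.float32, device="cuda")
torch.tanh(x, out=out)
print(out)
```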
5 changes: 3 additions & 2 deletions torch/csrc/jit/tensorexpr/kernel.cpp
@@ -1123,8 +1123,9 @@ Tensor* TensorExprKernel::computeValue(const torch::jit::Value* v) {
} break;

case aten::tanh: {
-return computeOneOperand(
-    "aten_tanh", v, [](const ExprHandle& a) { return tanh(a); });
+return computeOneOperand("aten_tanh", v, [](const ExprHandle& a) {
+  return tanh(promoteIntegerToFloat(a));
+});
} break;

case aten::sqrt: {
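The TensorExpr fuser mirrors the eager-mode promotion, so a fused graph also computes `tanh` in floating point. A sketch of the path this covers (whether fusion actually kicks in depends on device and JIT settings; the dtype result holds either way):

```python
import torch

@torch.jit.script
def f(x):
    return torch.tanh(x)

x = torch.arange(4)      # int64 input
y = f(x)
for _ in range(2):       # extra runs give the fuser a chance to compile
    y = f(x)
print(y.dtype)           # torch.float32, matching eager mode
```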
11 changes: 8 additions & 3 deletions torch/testing/_internal/common_methods_invocations.py
@@ -268,8 +268,10 @@ def sample_inputs(self, device, dtype, requires_grad=False):
UnaryUfuncInfo('atanh',
ref=np.arctanh,
domain=(-1, 1),
-dtypesIfCPU=floating_types(),
-dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
+dtypes=all_types_and(torch.bool),
+dtypesIfCPU=all_types_and(torch.bool),
+dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16),
+promotes_integers_to_float=True,
decorators=(precisionOverride({torch.bfloat16: 1e-2}),),
test_inplace_grad=False),
UnaryUfuncInfo('cos',
@@ -395,7 +397,10 @@ def sample_inputs(self, device, dtype, requires_grad=False):
UnaryUfuncInfo('tanh',
ref=np.tanh,
decorators=(precisionOverride({torch.bfloat16: 1e-2}),),
-dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16),
+dtypes=all_types_and_complex_and(torch.bool),
+dtypesIfCPU=all_types_and_complex_and(torch.bool, torch.bfloat16),
+dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
+promotes_integers_to_float=True,
skips=(
SkipInfo('TestUnaryUfuncs', 'test_reference_numerics',
device_type='cuda', dtypes=[torch.cfloat, torch.cdouble]),
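The new `promotes_integers_to_float=True` entries let the generic `UnaryUfuncInfo`-driven tests assert the promotion across the whole dtype matrix. A trivial standalone stand-in for that property (a hypothetical check, not the suite's actual assertion):

```python
import torch

def check_promotes_to_float(op, dtype=torch.int64):
    # Integer input in, default floating-point dtype out.
    result = op(torch.zeros(3, dtype=dtype))
    assert result.dtype == torch.get_default_dtype()
    return result.dtype

for op in (torch.tanh, torch.atanh):
    print(op.__name__, check_promotes_to_float(op))
```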