Add complex support for torch.{acosh, asinh, atanh}
ghstack-source-id: fe59a5244e702853f1e372c935a0b1bd9486bd30
Pull Request resolved: #50387
anjali411 committed Jan 15, 2021
1 parent f10e7aa commit d6673fe
Showing 5 changed files with 34 additions and 20 deletions.
6 changes: 3 additions & 3 deletions aten/src/ATen/native/cpu/UnaryOpsKernel.cpp
@@ -336,23 +336,23 @@ static void cosh_kernel(TensorIterator& iter) {
 }

 static void acosh_kernel(TensorIterator& iter) {
-  AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "acosh_cpu", [&]() {
+  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(iter.dtype(), "acosh_cpu", [&]() {
     cpu_kernel(
       iter,
       [=](scalar_t a) -> scalar_t { return std::acosh(a); });
   });
 }

 static void asinh_kernel(TensorIterator& iter) {
-  AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "asinh_cpu", [&]() {
+  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(iter.dtype(), "asinh_cpu", [&]() {
     cpu_kernel(
       iter,
       [=](scalar_t a) -> scalar_t { return std::asinh(a); });
   });
 }

 static void atanh_kernel(TensorIterator& iter) {
-  AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "atanh_cpu", [&]() {
+  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(iter.dtype(), "atanh_cpu", [&]() {
     cpu_kernel(
       iter,
       [=](scalar_t a) -> scalar_t { return std::atanh(a); });
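The CPU change above simply widens the dispatch macro so the existing std:: implementations also run on complex inputs. A minimal usage sketch (not part of this diff, assuming a build that includes the change):

import torch

# Complex CPU tensors now route through the FLOATING_AND_COMPLEX dispatch above.
z = torch.tensor([0.3 + 0.4j, 1.5 - 0.2j], dtype=torch.cfloat)
print(torch.acosh(z))
print(torch.asinh(z))
print(torch.atanh(z))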
6 changes: 3 additions & 3 deletions aten/src/ATen/native/cuda/UnaryGeometricKernels.cu
@@ -75,23 +75,23 @@ void tanh_kernel_cuda(TensorIterator& iter) {
 }

 void acosh_kernel_cuda(TensorIterator& iter) {
-  AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.common_dtype(), "acosh_cuda", [&]() {
+  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.common_dtype(), "acosh_cuda", [&]() {
     gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
       return ::acosh(a);
     });
   });
 }

 void asinh_kernel_cuda(TensorIterator& iter) {
-  AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.common_dtype(), "asinh_cuda", [&]() {
+  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.common_dtype(), "asinh_cuda", [&]() {
     gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
       return ::asinh(a);
     });
   });
 }

 void atanh_kernel_cuda(TensorIterator& iter) {
-  AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.common_dtype(), "atanh_cuda", [&]() {
+  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.common_dtype(), "atanh_cuda", [&]() {
     gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
       return ::atanh(a);
     });
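The CUDA kernels mirror the CPU change while keeping the existing Half/BFloat16 coverage. A small hedged sketch, assuming a CUDA-enabled build with this patch:

import torch

if torch.cuda.is_available():
    # Complex CUDA tensors dispatch through the widened macro above.
    z = torch.randn(4, dtype=torch.cfloat, device='cuda')
    print(torch.asinh(z))
    print(torch.acosh(z + 2))    # inputs shifted/scaled only to stay in well-behaved regions
    print(torch.atanh(z * 0.5))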
3 changes: 0 additions & 3 deletions test/test_torch.py
@@ -6904,9 +6904,6 @@ def inner(self, device, dtype):
         torch.testing.get_all_fp_dtypes() + _complex_types, [torch.bfloat16]),
     ('asin', '', _small_3d, lambda t, d: [], 1e-3, 1e-2, 1e-5, _float_types, [torch.bfloat16]),
     ('atan', '', _small_3d, lambda t, d: [], 1e-3, 1e-2, 1e-5, _float_types, [torch.bfloat16]),
-    ('acosh', '', lambda t, d: _small_3d(t, d) + 1, lambda t, d: [], 1e-3, 1e-2, 1e-5, torch.testing.get_all_fp_dtypes()),
-    ('asinh', '', _small_3d, lambda t, d: [], 1e-3, 1e-2, 1e-5, torch.testing.get_all_fp_dtypes()),
-    ('atanh', '', _small_3d, lambda t, d: [], 1e-3, 1e-2, 1e-5, torch.testing.get_all_fp_dtypes()),
     ('erf', '', _small_3d, lambda t, d: [], 1e-3, 1e-2, 1e-5, torch.testing.get_all_fp_dtypes(), [torch.bfloat16]),
     ('erfc', '', _small_3d, lambda t, d: [], 1e-3, 1e-2, 1e-5, _float_types, [torch.bfloat16]),
     ('rad2deg', '', _small_3d, lambda t, d: [], 1e-1, 1e-0, 1e-5, torch.testing.get_all_fp_dtypes(), [torch.bfloat16]),
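The three rows removed here are the legacy fp-only entries for these ops in test_torch.py; coverage moves to the OpInfo entries edited further below, which validate against NumPy references. Roughly the kind of comparison those OpInfo-driven tests perform (a standalone sketch, not the actual test code):

import numpy as np
import torch

x = torch.tensor([0.2 + 0.7j, -0.4 - 0.1j], dtype=torch.cdouble)
# Compare the torch result against the NumPy reference named in the OpInfo entry (np.arcsinh).
np.testing.assert_allclose(torch.asinh(x).numpy(), np.arcsinh(x.numpy()), rtol=1e-6)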
2 changes: 1 addition & 1 deletion tools/autograd/gen_variable_type.py
@@ -80,7 +80,7 @@
     'unbind', 'split', 'split_with_sizes', 'unsafe_split', 'split_with_sizes_backward',
     'dot', 'vdot', 'cholesky', 'triangular_solve', 'mm', '_unsafe_view', 'mv', 'ger',
     'bmm', 'diagonal', 'alias', 'atan', 'log', 'log10', 'log1p', 'log2', 'reciprocal',
-    'tan', 'pow', 'rsqrt', 'tanh', 'tanh_backward', 'asinh', 'acosh', 'take', 'fill_',
+    'tan', 'pow', 'rsqrt', 'tanh', 'tanh_backward', 'asinh', 'acosh', 'atanh', 'take', 'fill_',
     'exp', 'nonzero', 'mean', 'inverse', 'solve', 'linalg_cholesky', 'addcmul', 'addcdiv',
     'matrix_exp', 'linalg_eigh', 'cholesky_solve', 'linalg_qr', '_svd_helper', '_fft_c2c', '_fft_r2c',
     'linalg_solve', 'sqrt', 'stack', 'gather', 'index_select', 'index_add_', 'linalg_inv',
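Adding 'atanh' to this allowlist marks its backward formula as usable for complex autograd. A hedged sketch of the kind of numerical check that relies on this, assuming a build with the patch:

import torch

# gradcheck compares the analytic complex gradient of atanh against finite differences;
# double precision keeps the comparison meaningful, and inputs stay away from +/-1.
x = (torch.rand(5, dtype=torch.cdouble) - 0.5).requires_grad_(True)
print(torch.autograd.gradcheck(torch.atanh, (x,)))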
37 changes: 27 additions & 10 deletions torch/testing/_internal/common_methods_invocations.py
@@ -724,16 +724,21 @@ def sample_inputs_fliplr_flipud(op_info, device, dtype, requires_grad):
     UnaryUfuncInfo('acosh',
                    ref=np.arccosh,
                    domain=(1, float('inf')),
-                   dtypes=all_types_and(torch.bool),
-                   dtypesIfCPU=all_types_and(torch.bool),
-                   dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16),
+                   dtypes=all_types_and_complex_and(torch.bool),
+                   dtypesIfCPU=all_types_and_complex_and(torch.bool),
+                   dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
                    promotes_integers_to_float=True,
                    decorators=(precisionOverride({torch.bfloat16: 5e-2}),),
                    test_inplace_grad=False,
                    skips=(
                        # RuntimeError: "rsqrt_cuda" not implemented for 'BFloat16'
                        SkipInfo('TestCommon', 'test_variant_consistency_jit',
                                 device_type='cuda', dtypes=[torch.bfloat16]),
+                       SkipInfo('TestUnaryUfuncs', 'test_reference_numerics',
+                                device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]),
+                       SkipInfo('TestUnaryUfuncs', 'test_reference_numerics',
+                                device_type='cuda', dtypes=[torch.cfloat, torch.cdouble],
+                                active_if=IS_WINDOWS),
                    )),
     OpInfo('addmm',
            dtypes=floating_types(),
@@ -777,16 +782,21 @@ def sample_inputs_fliplr_flipud(op_info, device, dtype, requires_grad):
     # NOTE: derivative for inplace asinh is not implemented
     UnaryUfuncInfo('asinh',
                    ref=np.arcsinh,
-                   dtypes=all_types_and(torch.bool),
-                   dtypesIfCPU=all_types_and(torch.bool),
-                   dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16),
+                   dtypes=all_types_and_complex_and(torch.bool),
+                   dtypesIfCPU=all_types_and_complex_and(torch.bool),
+                   dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
                    promotes_integers_to_float=True,
                    decorators=(precisionOverride({torch.bfloat16: 5e-2}),),
                    test_inplace_grad=False,
                    skips=(
                        # RuntimeError: "rsqrt_cuda" not implemented for 'BFloat16'
                        SkipInfo('TestCommon', 'test_variant_consistency_jit',
                                 device_type='cuda', dtypes=[torch.bfloat16]),
+                       SkipInfo('TestUnaryUfuncs', 'test_reference_numerics',
+                                device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]),
+                       SkipInfo('TestUnaryUfuncs', 'test_reference_numerics',
+                                device_type='cuda', dtypes=[torch.cfloat, torch.cdouble],
+                                active_if=IS_WINDOWS),
                    )),
     UnaryUfuncInfo('atan',
                    ref=np.arctan,
@@ -807,12 +817,19 @@ def sample_inputs_fliplr_flipud(op_info, device, dtype, requires_grad):
     UnaryUfuncInfo('atanh',
                    ref=np.arctanh,
                    domain=(-1, 1),
-                   dtypes=all_types_and(torch.bool),
-                   dtypesIfCPU=all_types_and(torch.bool),
-                   dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16),
+                   dtypes=all_types_and_complex_and(torch.bool),
+                   dtypesIfCPU=all_types_and_complex_and(torch.bool),
+                   dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
                    promotes_integers_to_float=True,
                    decorators=(precisionOverride({torch.bfloat16: 1e-2}),),
-                   test_inplace_grad=False),
+                   test_inplace_grad=False,
+                   skips=(
+                       SkipInfo('TestUnaryUfuncs', 'test_reference_numerics',
+                                device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]),
+                       SkipInfo('TestUnaryUfuncs', 'test_reference_numerics',
+                                device_type='cuda', dtypes=[torch.cfloat, torch.cdouble],
+                                active_if=IS_WINDOWS),
+                   )),
     OpInfo('broadcast_to',
            dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
            supports_tensor_out=False,
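Note that domain=(-1, 1) above describes the real reference only; for complex inputs atanh is defined off that interval, which is part of what the new complex dtypes exercise. A small hedged illustration, assuming a build with this patch:

import torch

print(torch.atanh(torch.tensor(2.0)))          # nan: 2.0 lies outside the real domain (-1, 1)
print(torch.atanh(torch.tensor(2.0 + 1.0j)))   # finite complex result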
