Add complex support for torch.{acosh, asinh, atanh}
ghstack-source-id: dac14477cb366ff0ea3a7953ec3a1a1879e902c6
Pull Request resolved: #50387
anjali411 committed Jan 17, 2021
1 parent 7e05d07 commit fdc5e87
Showing 4 changed files with 34 additions and 17 deletions.
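
In user-facing terms, the change lets torch.acosh, torch.asinh, and torch.atanh accept complex tensors instead of rejecting the dtype. A minimal usage sketch (not taken from the PR; assumes a build that includes these kernels):

    import torch

    # Hypothetical example: complex inputs now reach the new kernel paths.
    z = torch.tensor([2.0 + 3.0j, -0.5 + 0.25j], dtype=torch.cfloat)
    print(torch.acosh(z))
    print(torch.asinh(z))
    print(torch.atanh(z))
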
aten/src/ATen/native/cpu/UnaryOpsKernel.cpp (3 additions, 3 deletions)
@@ -336,23 +336,23 @@ static void cosh_kernel(TensorIterator& iter) {
 }
 
 static void acosh_kernel(TensorIterator& iter) {
-  AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "acosh_cpu", [&]() {
+  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(iter.dtype(), "acosh_cpu", [&]() {
     cpu_kernel(
         iter,
         [=](scalar_t a) -> scalar_t { return std::acosh(a); });
   });
 }
 
 static void asinh_kernel(TensorIterator& iter) {
-  AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "asinh_cpu", [&]() {
+  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(iter.dtype(), "asinh_cpu", [&]() {
     cpu_kernel(
         iter,
         [=](scalar_t a) -> scalar_t { return std::asinh(a); });
   });
 }
 
 static void atanh_kernel(TensorIterator& iter) {
-  AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "atanh_cpu", [&]() {
+  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(iter.dtype(), "atanh_cpu", [&]() {
     cpu_kernel(
         iter,
         [=](scalar_t a) -> scalar_t { return std::atanh(a); });
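
The CPU change only widens the dispatch macro so the existing lambdas are also instantiated for complex scalar types. A rough self-consistency sketch of what that enables (illustrative only, not part of the PR's tests):

    import torch

    # The inverse functions should round-trip through their forward counterparts
    # for complex inputs, up to floating-point error.
    z = torch.randn(16, dtype=torch.cdouble)
    assert (torch.cosh(torch.acosh(z)) - z).abs().max() < 1e-10
    assert (torch.sinh(torch.asinh(z)) - z).abs().max() < 1e-10
    assert (torch.tanh(torch.atanh(z)) - z).abs().max() < 1e-10
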
aten/src/ATen/native/cuda/UnaryGeometricKernels.cu (3 additions, 3 deletions)
@@ -75,23 +75,23 @@ void tanh_kernel_cuda(TensorIterator& iter) {
 }
 
 void acosh_kernel_cuda(TensorIterator& iter) {
-  AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.common_dtype(), "acosh_cuda", [&]() {
+  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.common_dtype(), "acosh_cuda", [&]() {
     gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
       return ::acosh(a);
     });
   });
 }
 
 void asinh_kernel_cuda(TensorIterator& iter) {
-  AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.common_dtype(), "asinh_cuda", [&]() {
+  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.common_dtype(), "asinh_cuda", [&]() {
     gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
       return ::asinh(a);
     });
   });
 }
 
 void atanh_kernel_cuda(TensorIterator& iter) {
-  AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.common_dtype(), "atanh_cuda", [&]() {
+  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.common_dtype(), "atanh_cuda", [&]() {
     gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
       return ::atanh(a);
     });
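
The CUDA side mirrors the CPU change: the dispatch now includes the complex types alongside Half and BFloat16, so cfloat/cdouble tensors reach the GPU lambdas. A hedged smoke-test sketch (assumes a CUDA build; not from the PR):

    import torch

    if torch.cuda.is_available():
        z = torch.randn(16, dtype=torch.cfloat, device='cuda')
        out = torch.atanh(z)
        assert out.dtype == torch.cfloat and out.device.type == 'cuda'
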
tools/autograd/gen_variable_type.py (1 addition, 1 deletion)
@@ -80,7 +80,7 @@
     'unbind', 'split', 'split_with_sizes', 'unsafe_split', 'split_with_sizes_backward',
     'dot', 'vdot', 'cholesky', 'triangular_solve', 'mm', '_unsafe_view', 'mv', 'ger',
     'bmm', 'diagonal', 'alias', 'atan', 'log', 'log10', 'log1p', 'log2', 'reciprocal',
-    'tan', 'pow', 'rsqrt', 'tanh', 'tanh_backward', 'asinh', 'acosh', 'take', 'fill_',
+    'tan', 'pow', 'rsqrt', 'tanh', 'tanh_backward', 'asinh', 'acosh', 'atanh', 'take', 'fill_',
     'exp', 'nonzero', 'mean', 'inverse', 'solve', 'linalg_cholesky', 'addcmul', 'addcdiv',
     'matrix_exp', 'linalg_eigh', 'cholesky_solve', 'linalg_qr', '_svd_helper', '_fft_c2c', '_fft_r2c',
     'linalg_solve', 'sqrt', 'stack', 'gather', 'index_select', 'index_add_', 'linalg_inv',
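
Adding 'atanh' to this list marks its backward formula as usable with complex inputs, matching 'asinh' and 'acosh' which were already present. A hedged sketch of what that permits (illustrative only; the stated gradient follows PyTorch's conjugate Wirtinger convention):

    import torch

    z = torch.randn(8, dtype=torch.cdouble, requires_grad=True)
    out = torch.atanh(z)
    # Backward with an explicit complex output gradient; before this change the
    # generated autograd code would reject complex differentiation for atanh.
    (grad,) = torch.autograd.grad(out, z, grad_outputs=torch.ones_like(out))
    # d/dz atanh(z) = 1 / (1 - z**2); autograd returns its conjugate.
    print(grad)
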
torch/testing/_internal/common_methods_invocations.py (27 additions, 10 deletions)
@@ -773,16 +773,21 @@ def sample_inputs_fliplr_flipud(op_info, device, dtype, requires_grad):
     UnaryUfuncInfo('acosh',
                    ref=np.arccosh,
                    domain=(1, float('inf')),
-                   dtypes=all_types_and(torch.bool),
-                   dtypesIfCPU=all_types_and(torch.bool),
-                   dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16),
+                   dtypes=all_types_and_complex_and(torch.bool),
+                   dtypesIfCPU=all_types_and_complex_and(torch.bool),
+                   dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
                    promotes_integers_to_float=True,
                    decorators=(precisionOverride({torch.bfloat16: 5e-2}),),
                    test_inplace_grad=False,
                    skips=(
                        # RuntimeError: "rsqrt_cuda" not implemented for 'BFloat16'
                        SkipInfo('TestCommon', 'test_variant_consistency_jit',
                                 device_type='cuda', dtypes=[torch.bfloat16]),
+                       SkipInfo('TestUnaryUfuncs', 'test_reference_numerics',
+                                device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]),
+                       SkipInfo('TestUnaryUfuncs', 'test_reference_numerics',
+                                device_type='cuda', dtypes=[torch.cfloat, torch.cdouble],
+                                active_if=IS_WINDOWS),
                    )),
     OpInfo('addmm',
            dtypes=floating_types(),
@@ -827,16 +832,21 @@ def sample_inputs_fliplr_flipud(op_info, device, dtype, requires_grad):
     # NOTE: derivative for inplace asinh is not implemented
     UnaryUfuncInfo('asinh',
                    ref=np.arcsinh,
-                   dtypes=all_types_and(torch.bool),
-                   dtypesIfCPU=all_types_and(torch.bool),
-                   dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16),
+                   dtypes=all_types_and_complex_and(torch.bool),
+                   dtypesIfCPU=all_types_and_complex_and(torch.bool),
+                   dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
                    promotes_integers_to_float=True,
                    decorators=(precisionOverride({torch.bfloat16: 5e-2}),),
                    test_inplace_grad=False,
                    skips=(
                        # RuntimeError: "rsqrt_cuda" not implemented for 'BFloat16'
                        SkipInfo('TestCommon', 'test_variant_consistency_jit',
                                 device_type='cuda', dtypes=[torch.bfloat16]),
+                       SkipInfo('TestUnaryUfuncs', 'test_reference_numerics',
+                                device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]),
+                       SkipInfo('TestUnaryUfuncs', 'test_reference_numerics',
+                                device_type='cuda', dtypes=[torch.cfloat, torch.cdouble],
+                                active_if=IS_WINDOWS),
                    )),
     UnaryUfuncInfo('atan',
                    ref=np.arctan,
@@ -857,12 +867,19 @@ def sample_inputs_fliplr_flipud(op_info, device, dtype, requires_grad):
     UnaryUfuncInfo('atanh',
                    ref=np.arctanh,
                    domain=(-1, 1),
-                   dtypes=all_types_and(torch.bool),
-                   dtypesIfCPU=all_types_and(torch.bool),
-                   dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16),
+                   dtypes=all_types_and_complex_and(torch.bool),
+                   dtypesIfCPU=all_types_and_complex_and(torch.bool),
+                   dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
                    promotes_integers_to_float=True,
                    decorators=(precisionOverride({torch.bfloat16: 1e-2}),),
-                   test_inplace_grad=False),
+                   test_inplace_grad=False,
+                   skips=(
+                       SkipInfo('TestUnaryUfuncs', 'test_reference_numerics',
+                                device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]),
+                       SkipInfo('TestUnaryUfuncs', 'test_reference_numerics',
+                                device_type='cuda', dtypes=[torch.cfloat, torch.cdouble],
+                                active_if=IS_WINDOWS),
+                   )),
     OpInfo('broadcast_to',
            dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
            supports_tensor_out=False,
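
The OpInfo entries above gain the complex dtypes and, at the same time, new SkipInfo lines that opt cfloat/cdouble out of test_reference_numerics on CPU (and on CUDA only under Windows). A rough approximation of the comparison those skips disable (not the actual test harness):

    import numpy as np
    import torch

    # Compare the torch op against its NumPy reference on the same complex values;
    # this is roughly what test_reference_numerics automates via ref=np.arctanh.
    z = torch.randn(32, dtype=torch.cfloat)
    diff = np.abs(torch.atanh(z).numpy() - np.arctanh(z.numpy()))
    print(diff.max())
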
