Add complex support for torch.{acosh, asinh, atanh} #50387

Closed · wants to merge 6 commits
Changes from 4 commits
6 changes: 3 additions & 3 deletions aten/src/ATen/native/cpu/UnaryOpsKernel.cpp
@@ -336,23 +336,23 @@ static void cosh_kernel(TensorIterator& iter) {
}

static void acosh_kernel(TensorIterator& iter) {
Collaborator comment: CUDA dispatches also need an update.

- AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "acosh_cpu", [&]() {
+ AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(iter.dtype(), "acosh_cpu", [&]() {
cpu_kernel(
iter,
[=](scalar_t a) -> scalar_t { return std::acosh(a); });
});
}

static void asinh_kernel(TensorIterator& iter) {
- AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "asinh_cpu", [&]() {
+ AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(iter.dtype(), "asinh_cpu", [&]() {
cpu_kernel(
iter,
[=](scalar_t a) -> scalar_t { return std::asinh(a); });
});
}

static void atanh_kernel(TensorIterator& iter) {
- AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "atanh_cpu", [&]() {
+ AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(iter.dtype(), "atanh_cpu", [&]() {
cpu_kernel(
iter,
[=](scalar_t a) -> scalar_t { return std::atanh(a); });
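Not part of the diff, but one low-friction way to exercise the widened CPU dispatch is the round-trip identities cosh(acosh(z)) = z, sinh(asinh(z)) = z and tanh(atanh(z)) = z, which hold on every branch and so sidestep branch-cut disagreements with other libraries. A minimal sketch, assuming a build that contains this change:

```python
import torch

# Arbitrary complex points away from the real-axis branch cuts.
z = torch.tensor([0.5 + 0.5j, -1.2 + 0.3j, 2.0 - 1.0j], dtype=torch.cdouble)

# Applying the forward hyperbolic function to the inverse should recover z.
torch.testing.assert_allclose(torch.cosh(torch.acosh(z)), z)
torch.testing.assert_allclose(torch.sinh(torch.asinh(z)), z)
torch.testing.assert_allclose(torch.tanh(torch.atanh(z)), z)
```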
6 changes: 3 additions & 3 deletions aten/src/ATen/native/cuda/UnaryGeometricKernels.cu
@@ -75,23 +75,23 @@ void tanh_kernel_cuda(TensorIterator& iter) {
}

void acosh_kernel_cuda(TensorIterator& iter) {
- AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.common_dtype(), "acosh_cuda", [&]() {
+ AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.common_dtype(), "acosh_cuda", [&]() {
gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
return ::acosh(a);
});
});
}

void asinh_kernel_cuda(TensorIterator& iter) {
- AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.common_dtype(), "asinh_cuda", [&]() {
+ AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.common_dtype(), "asinh_cuda", [&]() {
gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
return ::asinh(a);
});
});
}

void atanh_kernel_cuda(TensorIterator& iter) {
- AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.common_dtype(), "atanh_cuda", [&]() {
+ AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.common_dtype(), "atanh_cuda", [&]() {
gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
return ::atanh(a);
});
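A similarly hedged spot check of the new complex CUDA paths, again not part of the PR, is to compare them against the CPU kernels on the same inputs:

```python
import torch

if torch.cuda.is_available():
    # Random complex inputs; their imaginary parts are almost surely nonzero,
    # so the samples stay off the real-axis branch cuts of acosh/atanh.
    z = torch.randn(1024, dtype=torch.cfloat)
    for fn in (torch.acosh, torch.asinh, torch.atanh):
        cpu_out = fn(z)
        cuda_out = fn(z.cuda()).cpu()
        torch.testing.assert_allclose(cuda_out, cpu_out, rtol=1e-5, atol=1e-5)
```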
2 changes: 1 addition & 1 deletion tools/autograd/gen_variable_type.py
@@ -80,7 +80,7 @@
'unbind', 'split', 'split_with_sizes', 'unsafe_split', 'split_with_sizes_backward',
'dot', 'vdot', 'cholesky', 'triangular_solve', 'mm', '_unsafe_view', 'mv', 'ger',
'bmm', 'diagonal', 'alias', 'atan', 'log', 'log10', 'log1p', 'log2', 'reciprocal',
- 'tan', 'pow', 'rsqrt', 'tanh', 'tanh_backward', 'asinh', 'acosh', 'take', 'fill_',
+ 'tan', 'pow', 'rsqrt', 'tanh', 'tanh_backward', 'asinh', 'acosh', 'atanh', 'take', 'fill_',
'exp', 'nonzero', 'mean', 'inverse', 'solve', 'linalg_cholesky', 'addcmul', 'addcdiv',
'matrix_exp', 'linalg_eigh', 'cholesky_solve', 'linalg_qr', '_svd_helper', '_fft_c2c', '_fft_r2c',
'linalg_solve', 'sqrt', 'stack', 'gather', 'index_select', 'index_add_', 'linalg_inv',
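This list in gen_variable_type.py tracks ops whose backward formulas are treated as implemented for complex types, so adding 'atanh' lets autograd run through a complex atanh instead of raising. A hypothetical illustration, not taken from the PR's tests:

```python
import torch

# Complex leaf tensor; backward through atanh is allowed once 'atanh'
# is on the complex-gradient whitelist.
z = torch.tensor([0.3 + 0.4j, -0.2 + 1.5j], dtype=torch.cdouble, requires_grad=True)
out = torch.atanh(z)

# Reduce to a real scalar so backward() needs no explicit gradient argument.
out.abs().sum().backward()
print(z.grad)
```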
37 changes: 27 additions & 10 deletions torch/testing/_internal/common_methods_invocations.py
@@ -773,16 +773,21 @@ def sample_inputs_fliplr_flipud(op_info, device, dtype, requires_grad):
UnaryUfuncInfo('acosh',
ref=np.arccosh,
domain=(1, float('inf')),
- dtypes=all_types_and(torch.bool),
- dtypesIfCPU=all_types_and(torch.bool),
- dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16),
+ dtypes=all_types_and_complex_and(torch.bool),
+ dtypesIfCPU=all_types_and_complex_and(torch.bool),
+ dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
promotes_integers_to_float=True,
decorators=(precisionOverride({torch.bfloat16: 5e-2}),),
test_inplace_grad=False,
skips=(
# RuntimeError: "rsqrt_cuda" not implemented for 'BFloat16'
SkipInfo('TestCommon', 'test_variant_consistency_jit',
device_type='cuda', dtypes=[torch.bfloat16]),
+ SkipInfo('TestUnaryUfuncs', 'test_reference_numerics',
+ device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]),
+ SkipInfo('TestUnaryUfuncs', 'test_reference_numerics',
+ device_type='cuda', dtypes=[torch.cfloat, torch.cdouble],
+ active_if=IS_WINDOWS),
)),
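As a side note on the domain=(1, float('inf')) entry above: that interval is where acosh is real-valued, while the complex overload added in this PR accepts inputs outside it. A tiny illustration, hypothetical and not part of the OpInfo machinery:

```python
import torch

print(torch.acosh(torch.tensor(0.5)))       # nan: 0.5 lies outside the real domain [1, inf)
print(torch.acosh(torch.tensor(0.5 + 0j)))  # finite, essentially purely imaginary result
```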
OpInfo('addmm',
dtypes=floating_types(),
@@ -827,16 +832,21 @@ def sample_inputs_fliplr_flipud(op_info, device, dtype, requires_grad):
# NOTE: derivative for inplace asinh is not implemented
UnaryUfuncInfo('asinh',
ref=np.arcsinh,
- dtypes=all_types_and(torch.bool),
- dtypesIfCPU=all_types_and(torch.bool),
- dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16),
+ dtypes=all_types_and_complex_and(torch.bool),
+ dtypesIfCPU=all_types_and_complex_and(torch.bool),
+ dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
promotes_integers_to_float=True,
decorators=(precisionOverride({torch.bfloat16: 5e-2}),),
test_inplace_grad=False,
skips=(
# RuntimeError: "rsqrt_cuda" not implemented for 'BFloat16'
SkipInfo('TestCommon', 'test_variant_consistency_jit',
device_type='cuda', dtypes=[torch.bfloat16]),
+ SkipInfo('TestUnaryUfuncs', 'test_reference_numerics',
+ device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]),
+ SkipInfo('TestUnaryUfuncs', 'test_reference_numerics',
+ device_type='cuda', dtypes=[torch.cfloat, torch.cdouble],
+ active_if=IS_WINDOWS),
)),
UnaryUfuncInfo('atan',
ref=np.arctan,
@@ -857,12 +867,19 @@ def sample_inputs_fliplr_flipud(op_info, device, dtype, requires_grad):
UnaryUfuncInfo('atanh',
ref=np.arctanh,
domain=(-1, 1),
- dtypes=all_types_and(torch.bool),
- dtypesIfCPU=all_types_and(torch.bool),
- dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16),
+ dtypes=all_types_and_complex_and(torch.bool),
+ dtypesIfCPU=all_types_and_complex_and(torch.bool),
+ dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
promotes_integers_to_float=True,
decorators=(precisionOverride({torch.bfloat16: 1e-2}),),
- test_inplace_grad=False),
+ test_inplace_grad=False,
+ skips=(
+ SkipInfo('TestUnaryUfuncs', 'test_reference_numerics',
+ device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]),
+ SkipInfo('TestUnaryUfuncs', 'test_reference_numerics',
+ device_type='cuda', dtypes=[torch.cfloat, torch.cdouble],
+ active_if=IS_WINDOWS),
+ )),
OpInfo('broadcast_to',
dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
supports_tensor_out=False,