Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

igamma and igammac: port to structured #57626

Closed
wants to merge 2 commits into the base branch from the author's branch
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
38 changes: 4 additions & 34 deletions aten/src/ATen/native/BinaryOps.cpp
Expand Up @@ -75,6 +75,8 @@ TORCH_META_FUNC(atan2) (const Tensor& self, const Tensor& other) {

CREATE_BINARY_META_FUNC(gcd);
CREATE_BINARY_META_FUNC(nextafter);
CREATE_BINARY_META_FUNC(igamma);
CREATE_BINARY_META_FUNC(igammac);

} // namespace meta

Expand Down Expand Up @@ -214,6 +216,8 @@ TORCH_IMPL_FUNC(func##_out) (const Tensor& self, const Tensor& other, const Tens

CREATE_BINARY_TORCH_IMPL_FUNC(gcd);
CREATE_BINARY_TORCH_IMPL_FUNC(nextafter);
CREATE_BINARY_TORCH_IMPL_FUNC(igamma);
CREATE_BINARY_TORCH_IMPL_FUNC(igammac);

Tensor special_xlog1py(const Scalar& x, const Tensor& y) {
return at::special_xlog1py(wrapped_scalar_tensor(x), y);
Expand Down Expand Up @@ -1111,40 +1115,6 @@ Tensor& hypot_(Tensor& self, const Tensor& other) {
return at::hypot_out(self, self, other);
}

// Out= variant of the regularized lower incomplete gamma function.
// Builds a binary TensorIterator over (self, other) writing into the
// caller-provided `result`, then dispatches to the device-specific stub.
Tensor& igamma_out(const Tensor& self, const Tensor& other, Tensor& result) {
  auto binary_iter = TensorIterator::binary_op(result, self, other);
  igamma_stub(binary_iter.device_type(), binary_iter);
  return result;
}

// Functional variant: passes an undefined output tensor so the iterator
// allocates the result itself, then returns the iterator's output.
Tensor igamma(const Tensor& self, const Tensor& other) {
  Tensor allocated_out;
  auto binary_iter = TensorIterator::binary_op(allocated_out, self, other);
  igamma_stub(binary_iter.device_type(), binary_iter);
  return binary_iter.output();
}

// In-place variant: forwards to the out= overload with `self` used as both
// the first input and the output tensor.
Tensor& igamma_(Tensor& self, const Tensor& other) {
  return at::igamma_out(self, self, other);
}

// Out= variant of the regularized upper incomplete gamma function.
// Mirrors igamma_out: binary iterator over (self, other) into `result`,
// dispatched to the device-specific igammac stub.
Tensor& igammac_out(const Tensor& self, const Tensor& other, Tensor& result) {
  auto binary_iter = TensorIterator::binary_op(result, self, other);
  igammac_stub(binary_iter.device_type(), binary_iter);
  return result;
}

// Functional variant: the iterator allocates the output tensor because an
// undefined Tensor is passed as the destination.
Tensor igammac(const Tensor& self, const Tensor& other) {
  Tensor allocated_out;
  auto binary_iter = TensorIterator::binary_op(allocated_out, self, other);
  igammac_stub(binary_iter.device_type(), binary_iter);
  return binary_iter.output();
}

// In-place variant: forwards to the out= overload with `self` used as both
// the first input and the output tensor.
Tensor& igammac_(Tensor& self, const Tensor& other) {
  return at::igammac_out(self, self, other);
}

// Note: this function is only for testing.
// It is undocumented and should not be used outside of tests.
Tensor _test_serialization_subcmul(const Tensor& self, const Tensor& other, const Scalar& alpha) {
Expand Down
4 changes: 2 additions & 2 deletions aten/src/ATen/native/BinaryOps.h
Expand Up @@ -86,8 +86,8 @@ DECLARE_DISPATCH(binary_fn, logaddexp2_stub);
// Dispatch-stub declarations for binary elementwise ops. Per the kernel
// signatures in this diff, `structured_binary_fn` stubs receive a
// TensorIteratorBase& (structured kernels) while `binary_fn` stubs receive
// a TensorIterator&.
DECLARE_DISPATCH(structured_binary_fn, gcd_stub);
DECLARE_DISPATCH(binary_fn, lcm_stub);
DECLARE_DISPATCH(binary_fn, hypot_stub);
// NOTE(review): the next four lines show both the pre-port (binary_fn) and
// post-port (structured_binary_fn) declarations of igamma_stub/igammac_stub
// from an unmarked diff; only the structured_binary_fn pair should remain
// after the structured port -- confirm against the applied patch.
DECLARE_DISPATCH(binary_fn, igamma_stub);
DECLARE_DISPATCH(binary_fn, igammac_stub);
DECLARE_DISPATCH(structured_binary_fn, igamma_stub);
DECLARE_DISPATCH(structured_binary_fn, igammac_stub);
DECLARE_DISPATCH(structured_binary_fn, nextafter_stub);
DECLARE_DISPATCH(binary_fn, heaviside_stub);
DECLARE_DISPATCH(structured_binary_fn, copysign_stub);
Expand Down
4 changes: 2 additions & 2 deletions aten/src/ATen/native/cpu/BinaryOpsKernel.cpp
Expand Up @@ -892,7 +892,7 @@ void hypot_kernel(TensorIterator& iter) {
});
}

void igamma_kernel(TensorIterator& iter) {
void igamma_kernel(TensorIteratorBase& iter) {
AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, iter.dtype(), "igamma_cpu", [&]() {
cpu_kernel_vec(
iter,
Expand All @@ -905,7 +905,7 @@ void igamma_kernel(TensorIterator& iter) {
});
}

void igammac_kernel(TensorIterator& iter) {
void igammac_kernel(TensorIteratorBase& iter) {
AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, iter.dtype(), "igammac_cpu", [&]() {
cpu_kernel_vec(
iter,
Expand Down
4 changes: 2 additions & 2 deletions aten/src/ATen/native/cuda/IGammaKernel.cu
Expand Up @@ -515,7 +515,7 @@ __noinline__ __host__ __device__ scalar_t calc_igamma(scalar_t a, scalar_t x) {

namespace at { namespace native {

void igamma_kernel_cuda(TensorIterator& iter) {
void igamma_kernel_cuda(TensorIteratorBase& iter) {

AT_DISPATCH_FLOATING_TYPES(iter.common_dtype(), "igamma_cuda", [&]() {
gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
Expand All @@ -524,7 +524,7 @@ void igamma_kernel_cuda(TensorIterator& iter) {
});
}

void igammac_kernel_cuda(TensorIterator& iter) {
void igammac_kernel_cuda(TensorIteratorBase& iter) {

AT_DISPATCH_FLOATING_TYPES(iter.common_dtype(), "igammac_cuda", [&]() {
gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
Expand Down
16 changes: 8 additions & 8 deletions aten/src/ATen/native/native_functions.yaml
Expand Up @@ -6556,32 +6556,32 @@
CompositeExplicitAutograd: hypot_

# Structured out= kernel: the meta function declared in BinaryOps.cpp fixes
# the output shape/dtype, and the TORCH_IMPL_FUNC (igamma_out) fills the
# pre-allocated output.
- func: igamma.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
structured: True
structured_inherits: TensorIteratorBase
dispatch:
CPU, CUDA: igamma_out

# Functional variant generated from the structured out= kernel.
# NOTE(review): a structured_delegate entry should not also carry its own
# CPU/CUDA dispatch line; the dispatch lines below appear to be the pre-port
# side of an unmarked diff -- confirm only structured_delegate remains.
- func: igamma(Tensor self, Tensor other) -> Tensor
structured_delegate: igamma.out
variants: method, function
dispatch:
CPU, CUDA: igamma

# In-place variant, also delegated to the structured out= kernel.
# NOTE(review): same pre/post-port dispatch overlap as above -- verify.
- func: igamma_(Tensor(a!) self, Tensor other) -> Tensor(a!)
structured_delegate: igamma.out
variants: method
dispatch:
CPU, CUDA: igamma_

# Structured out= kernel for the complementary (upper) incomplete gamma.
- func: igammac.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
structured: True
structured_inherits: TensorIteratorBase
dispatch:
CPU, CUDA: igammac_out

# Functional variant delegated to igammac.out.
# NOTE(review): same pre/post-port dispatch overlap as above -- verify.
- func: igammac(Tensor self, Tensor other) -> Tensor
structured_delegate: igammac.out
variants: method, function
dispatch:
CPU, CUDA: igammac

# In-place variant delegated to igammac.out.
# NOTE(review): same pre/post-port dispatch overlap as above -- verify.
- func: igammac_(Tensor(a!) self, Tensor other) -> Tensor(a!)
structured_delegate: igammac.out
variants: method
dispatch:
CPU, CUDA: igammac_

- func: nextafter.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
structured: True
Expand Down