Skip to content

Commit

Permalink
Revert "Include support for the scatter gather cuda kernels to allow …
Browse files Browse the repository at this point in the history
…for comp… (#124809)"

This reverts commit e09f98c.

Reverted #124809 on behalf of https://github.com/clee2000 due to windows build failure is real, https://github.com/pytorch/pytorch/actions/runs/8910674030/job/24470387612#step:11:11236 is the correct failure line, ignore the statement saying build passed, batch is errorcodes arent propagating again ([comment](#124809 (comment)))
  • Loading branch information
pytorchmergebot committed May 1, 2024
1 parent e16f1ee commit 4d41015
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 102 deletions.
88 changes: 1 addition & 87 deletions aten/src/ATen/cuda/Atomic.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -35,26 +35,6 @@ struct AtomicFPOp<at::Half> {
}
};

template <>
struct AtomicFPOp<c10::complex<float>> {
template <typename func_t>
inline __device__ c10::complex<float> operator() (c10::complex<float> *address, c10::complex<float> val, const func_t& func) {
unsigned long long int* addr_as_ull = (unsigned long long int*)address;
unsigned long long int old = *addr_as_ull;
unsigned long long int assumed, new_val;

c10::complex<float> csum;
do {
assumed = old;
csum = func(csum, val);
new_val = *reinterpret_cast<unsigned long long*>(&csum);
old = atomicCAS(addr_as_ull, assumed, new_val);
} while (assumed != old);

return *reinterpret_cast<c10::complex<float>*>(&addr_as_ull);
}
};

template <>
struct AtomicFPOp<at::BFloat16> {
template <typename func_t>
Expand Down Expand Up @@ -368,14 +348,6 @@ GPU_ATOMIC_INTEGER(Mul, a * b, int16_t)
GPU_ATOMIC_INTEGER(Mul, a * b, int32_t)
GPU_ATOMIC_INTEGER(Mul, a * b, int64_t)

inline __device__ c10::complex<float> gpuAtomicMul(c10::complex<float> *address, c10::complex<float> val){
return AtomicFPOp<c10::complex<float>>()(address, val,
[](c10::complex<float> bsum, c10::complex<float> val) {
bsum*=(val);
return bsum;
});
}

inline __device__ at::Half gpuAtomicMul(at::Half * address, at::Half val) {
return AtomicFPOp<at::Half>()(address, val,
[](at::Half bsum, at::Half val) {
Expand All @@ -397,7 +369,7 @@ inline __device__ double gpuAtomicMul(double * address, double val) {
});
}

// Don't use a templated function for this since the addition function defaults to the CUDA built-in.
// Dont use a templated function for this since the addition function defaults to the CUDA built-in.
inline __device__ float gpuAtomicMul (float * address, float val) {
unsigned int* address_as_ull = (unsigned int*)address;
unsigned int old = *address_as_ull;
Expand Down Expand Up @@ -430,29 +402,6 @@ __host__ __device__ T safe_max(T a, T b) {
return max;
}

__inline__ __device__ c10::complex<float> complex_max(c10::complex<float> a, c10::complex<float> b) {
if(at::_isnan(b)) {
return b;
} else {
// Compute the magnitude of the complex numbers and compare each to see which one is greater.
float a_magnitude = __fsqrt_rn(
(
__fmul_rn(a.real(), a.real()) +
__fmul_rn(a.imag(),a.imag())
)
);
float b_magnitude = __fsqrt_rn(
(
__fmul_rn(b.real(), b.real()) +
__fmul_rn(b.imag(),b.imag())
)
);
return std::max<float>(a_magnitude, b_magnitude);
}

}


ATOMIC_INTEGER_IMPL(Max)
GPU_ATOMIC_INTEGER(Max, safe_max(a, b), uint8_t)
GPU_ATOMIC_INTEGER(Max, safe_max(a, b), int8_t)
Expand All @@ -467,13 +416,6 @@ inline __device__ at::Half gpuAtomicMax(at::Half * address, at::Half val) {
});
}

inline __device__ c10::complex<float> gpuAtomicMax(c10::complex<float> * address, c10::complex<float> val) {
return AtomicFPOp<c10::complex<float>>()(address, val,
[](c10::complex<float> bsum, c10::complex<float> val) {
return complex_max(bsum, val);
});
}

inline __device__ at::BFloat16 gpuAtomicMax(at::BFloat16 * address, at::BFloat16 val) {
return AtomicFPOp<at::BFloat16>()(address, val,
[](at::BFloat16 bsum, at::BFloat16 val) {
Expand Down Expand Up @@ -520,27 +462,6 @@ __host__ __device__ T safe_min(T a, T b) {
return min;
}

__inline__ __device__ c10::complex<float> complex_min(c10::complex<float> a, c10::complex<float> b) {
if(at::_isnan(b)) {
return b;
} else {
// Compute the magnitude of the complex numbers and compare each to see which one is smaller.
float a_magnitude = __fsqrt_rn(
(
__fmul_rn(a.real(), a.real()) +
__fmul_rn(a.imag(),a.imag())
)
);
float b_magnitude = __fsqrt_rn(
(
__fmul_rn(b.real(), b.real()) +
__fmul_rn(b.imag(),b.imag())
)
);
return std::min<float>(a_magnitude, b_magnitude);
}
}

ATOMIC_INTEGER_IMPL(Min)
GPU_ATOMIC_INTEGER(Min, safe_min(a, b), uint8_t)
GPU_ATOMIC_INTEGER(Min, safe_min(a, b), int8_t)
Expand All @@ -555,13 +476,6 @@ inline __device__ at::Half gpuAtomicMin(at::Half * address, at::Half val) {
});
}

inline __device__ c10::complex<float> gpuAtomicMin(c10::complex<float> * address, c10::complex<float> val) {
return AtomicFPOp<c10::complex<float>>()(address, val,
[](c10::complex<float> bsum, c10::complex<float> val) {
return complex_min(bsum, val);
});
}

inline __device__ at::BFloat16 gpuAtomicMin(at::BFloat16 * address, at::BFloat16 val) {
return AtomicFPOp<at::BFloat16>()(address, val,
[](at::BFloat16 bsum, at::BFloat16 val) {
Expand Down
10 changes: 6 additions & 4 deletions aten/src/ATen/native/cuda/ScatterGatherKernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/MemoryOverlap.h>

#include <ATen/native/ScatterGatherChecks.h>
#include <ATen/native/ReduceOpsUtils.h>
#include <ATen/native/TensorIterator.h>
Expand Down Expand Up @@ -200,6 +201,7 @@ struct cuda_scatter_gather_base_kernel {
auto index_size = is_scatter_like ? self_dim_size : src_dim_size;
auto index_stride = is_scatter_like ? self_dim_stride : src_dim_stride;


AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16,
iter.dtype(),
Expand Down Expand Up @@ -257,6 +259,7 @@ struct cuda_scatter_gather_base_kernel {
auto index_size = is_scatter_like ? self_dim_size : src_dim_size;
auto index_stride = is_scatter_like ? self_dim_stride : src_dim_stride;


AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16,
iter.dtype(),
Expand Down Expand Up @@ -315,9 +318,9 @@ struct cuda_scatter_gather_base_kernel {
auto index_size = is_scatter_like ? self_dim_size : src_dim_size;
auto index_stride = is_scatter_like ? self_dim_stride : src_dim_stride;

AT_DISPATCH_ALL_TYPES_AND3(

AT_DISPATCH_ALL_TYPES_AND2(
at::ScalarType::Half, at::ScalarType::BFloat16,
at::ScalarType::ComplexFloat,
iter.dtype(),
"cuda_scatter_gather_base_kernel_func", [&] {
using dtype = typename std::conditional<cast_to_opaque,
Expand Down Expand Up @@ -447,9 +450,8 @@ struct cuda_scatter_fill_base_kernel {
auto index_size = ensure_nonempty_size(self, dim);
auto index_stride = ensure_nonempty_stride(self, dim);

AT_DISPATCH_ALL_TYPES_AND3(
AT_DISPATCH_ALL_TYPES_AND2(
at::ScalarType::Half, at::ScalarType::BFloat16,
at::ScalarType::ComplexFloat,
iter.dtype(),
"cuda_scatter_fill_base_kernel_reduce_multiply", [&] {
using dtype = typename std::conditional<cast_to_opaque,
Expand Down
12 changes: 4 additions & 8 deletions test/test_scatter_gather_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,17 +221,15 @@ def test_scatter_reduce_sum(self, device, dtype):
include_self=include_self)

@dtypes(*get_all_dtypes(include_half=True, include_bfloat16=True))
@dtypesIfCUDA(*get_all_dtypes(include_half=True, include_bfloat16=True, include_complex32=True,
include_complex=False, include_bool=False))
@dtypesIfCUDA(*get_all_dtypes(include_half=True, include_bfloat16=True, include_complex=False, include_bool=False))
def test_scatter_reduce_prod(self, device, dtype):
for include_self in (True, False):
self._test_scatter_base(torch.Tensor.scatter_reduce_, device=device, dtype=dtype,
is_scalar=False, reduction='prod', unique_indices=False,
include_self=include_self)

@dtypes(*get_all_dtypes(include_half=True, include_bfloat16=True, include_bool=False))
@dtypesIfCUDA(*get_all_dtypes(include_half=True, include_bfloat16=True, include_complex32=True,
include_complex=False, include_bool=False))
@dtypesIfCUDA(*get_all_dtypes(include_half=True, include_bfloat16=True, include_complex=False, include_bool=False))
def test_scatter_reduce_mean(self, device, dtype):
for include_self in (True, False):
for deterministic in [False, True]:
Expand All @@ -241,8 +239,7 @@ def test_scatter_reduce_mean(self, device, dtype):
include_self=include_self)

@dtypes(*get_all_dtypes(include_half=True, include_bfloat16=True, include_complex=False))
@dtypesIfCUDA(*get_all_dtypes(include_half=True, include_bfloat16=True, include_complex32=True,
include_complex=False, include_bool=False))
@dtypesIfCUDA(*get_all_dtypes(include_half=True, include_bfloat16=True, include_complex=False, include_bool=False))
def test_scatter_reduce_amax(self, device, dtype):
for include_self in (True, False):
self._test_scatter_base(torch.Tensor.scatter_reduce_, device=device, dtype=dtype,
Expand All @@ -261,8 +258,7 @@ def test_scatter_reduce_amax(self, device, dtype):


@dtypes(*get_all_dtypes(include_half=True, include_bfloat16=True, include_complex=False))
@dtypesIfCUDA(*get_all_dtypes(include_half=True, include_bfloat16=True, include_complex32=True,
include_complex=False, include_bool=False))
@dtypesIfCUDA(*get_all_dtypes(include_half=True, include_bfloat16=True, include_complex=False, include_bool=False))
def test_scatter_reduce_amin(self, device, dtype):
for include_self in (True, False):
self._test_scatter_base(torch.Tensor.scatter_reduce_, device=device, dtype=dtype,
Expand Down
6 changes: 3 additions & 3 deletions test/test_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,8 @@
_create_scaling_case, _create_scaling_models_optimizers)
from torch.testing._internal.common_mkldnn import bf32_on_and_off
from torch.testing._internal.common_dtype import (
floating_types_and, get_all_math_dtypes, all_types_and_complex_and, all_types_and, floating_types,
floating_and_complex_types, integral_types_and,
floating_types_and, get_all_math_dtypes, all_types_and_complex_and, complex_types,
all_types_and, floating_types, floating_and_complex_types, integral_types_and,
get_all_qint_dtypes,
)
from torch.testing._internal.two_tensor import TwoTensor
Expand Down Expand Up @@ -3837,7 +3837,7 @@ def test_scatter_reduce_non_unique_index(self, device, dtype):
self.assertEqual(input, result, msg=f"result: {result} input: {input} method: {str(operation)}")

@onlyCUDA
@dtypes(torch.cdouble)
@dtypes(*complex_types())
def test_scatter_reduce_multiply_unsupported_dtypes(self, device, dtype):
height = 2
width = 2
Expand Down

0 comments on commit 4d41015

Please sign in to comment.