From 65dbb6cb0e3bd0ef398bed00b5a707a69ae23197 Mon Sep 17 00:00:00 2001 From: bhack Date: Thu, 9 May 2024 16:21:04 +0000 Subject: [PATCH 01/11] extend nonzero to int64 --- aten/src/ATen/native/cuda/Nonzero.cu | 177 +++++++++++++++------------ 1 file changed, 102 insertions(+), 75 deletions(-) diff --git a/aten/src/ATen/native/cuda/Nonzero.cu b/aten/src/ATen/native/cuda/Nonzero.cu index 5d62f7711d104..af0770ea7be98 100644 --- a/aten/src/ATen/native/cuda/Nonzero.cu +++ b/aten/src/ATen/native/cuda/Nonzero.cu @@ -30,7 +30,7 @@ struct NonZeroOp //TODO: actually support int64_t index_t template struct TensorDims { - index_t sizes[MAX_DIMS]; + index_t sizes[MAX_DIMS]; }; template @@ -39,92 +39,119 @@ __global__ void write_indices( TensorDims dims, int ndim, index_t n) { - auto index = threadIdx.x + blockIdx.x * blockDim.x; - if (index < n) { - index_t div = 1; - int64_t idx_flat = inp[index]; + auto index = threadIdx.x + blockIdx.x * blockDim.x; + if (index < n) { + index_t div = 1; + int64_t idx_flat = inp[index]; #pragma unroll - for (int dim = MAX_DIMS; dim >= 0; dim--) { - if (dim > ndim - 1) - continue; - auto dim_size = dims.sizes[dim]; - inp[index + dim * n] = (idx_flat / div) % dim_size; - div *= dim_size; + for (int dim = MAX_DIMS; dim >= 0; dim--) { + if (dim > ndim - 1) + continue; + auto dim_size = dims.sizes[dim]; + inp[index + dim * n] = (idx_flat / div) % dim_size; + div *= dim_size; + } } - } } -} //anonymous namespace +// Temporary wrapper for DeviceSelect::If from https://github.com/NVIDIA/cccl/pull/1379 +// until CCCL https://github.com/NVIDIA/cccl/issues/1422 is resolved +template +CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE static cudaError_t dispatch_select_if_wrapper( + void* d_temp_storage, + std::size_t& temp_storage_bytes, + InputIteratorT d_in, + OutputIteratorT d_out, + NumSelectedIteratorT d_num_selected_out, + OffsetT num_items, + SelectOp select_op, + cudaStream_t stream = 0) +{ + using flag_iterator_t = cub::NullType*; + using equality_op_t = cub::NullType; + + return cub::DispatchSelectIf< + InputIteratorT, + flag_iterator_t, + OutputIteratorT, + NumSelectedIteratorT, + SelectOp, + equality_op_t, + OffsetT, + false>::Dispatch(d_temp_storage, + temp_storage_bytes, + d_in, + nullptr, + d_out, + d_num_selected_out, + select_op, + equality_op_t{}, + num_items, + stream); +} + +DECLARE_LAUNCH_WRAPPER(cub::DeviceSelect::If, select_if); +DECLARE_LAUNCH_WRAPPER(dispatch_select_if_wrapper, dispatch_select_if); + +} // anonymous namespace template -void nonzero_cuda_out_impl(const Tensor& self, Tensor& out){ - Tensor self_ = self.contiguous(); - int N = self_.numel(); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); +void nonzero_cuda_out_impl(const Tensor& self, Tensor& out) { + Tensor self_ = self.contiguous(); + int64_t N = self_.numel(); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); // compute number of nonzero elements size_t temp_storage_bytes=0; - auto& allocator = *c10::cuda::CUDACachingAllocator::get(); - auto num_nonzeros = allocator.allocate(sizeof(int)); - cub::TransformInputIterator, const scalar_t*> itr(self_.const_data_ptr(), NonZeroOp()); - cub::DeviceReduce::Sum(nullptr, temp_storage_bytes, itr, (int*)num_nonzeros.get(), N, stream); - auto temp_storage = allocator.allocate(temp_storage_bytes); - cub::DeviceReduce::Sum(temp_storage.get(), temp_storage_bytes, itr, (int*)num_nonzeros.get(), N, stream); - int num_nonzeros_h; - at::cuda::memcpy_and_sync(&num_nonzeros_h, num_nonzeros.get(), sizeof(int), cudaMemcpyDeviceToHost, stream); 
- //expected output size is num_nonzeros x ndim - //we are producing output with size {num_nonzeros, ndim} and strides {1, num_nonzeros} (that is, transposed ndim x num_nonzeros output) - //we are able to directly use passed output with this size and strides, and we can also (per contract) - //resize passed output with incorrect sizes anyway we want. - //However, out with correct sizes and incorrect strides will have to be copied to from the intermediate we've produced. - bool need_to_copy = out.dim() == 2 && out.sizes()[0] == num_nonzeros_h && out.sizes()[1] == self.dim() && !out.t().is_contiguous(); - at::Tensor out_temp = need_to_copy ? - Tensor(at::detail::empty_cuda({self.dim(), num_nonzeros_h}, out.options())) : - out.resize_({self.dim(), num_nonzeros_h}); - //Scalars are expected to produce output of size (1,0), so we can't write to it - if (self.dim() > 0) { - cub::CountingInputIterator counting_itr(0); - temp_storage_bytes = 0; - cub::DeviceSelect::Flagged(nullptr, temp_storage_bytes, counting_itr, itr, - out_temp.mutable_data_ptr(), (int*)num_nonzeros.get(), N, stream); - temp_storage = allocator.allocate(temp_storage_bytes); - cub::DeviceSelect::Flagged(temp_storage.get(), temp_storage_bytes, counting_itr, itr, - out_temp.mutable_data_ptr(), (int*)num_nonzeros.get(), N, stream); - if (num_nonzeros_h > 0 && self.dim() > 1){ - TensorDims dims; - for (int i=0; i, const scalar_t*> itr(self_.const_data_ptr(), NonZeroOp()); + cub::DeviceReduce::Sum(nullptr, temp_storage_bytes, itr, static_cast(num_nonzeros.get()), N, stream); + auto temp_storage = allocator.allocate(temp_storage_bytes); + cub::DeviceReduce::Sum(temp_storage.get(), temp_storage_bytes, itr, static_cast(num_nonzeros.get()), N, stream); + int64_t num_nonzeros_h; + at::cuda::memcpy_and_sync(&num_nonzeros_h, num_nonzeros.get(), sizeof(int64_t), cudaMemcpyDeviceToHost, stream); + bool need_to_copy = out.dim() == 2 && out.sizes()[0] == num_nonzeros_h && out.sizes()[1] == self.dim() && !out.t().is_contiguous(); + at::Tensor out_temp = need_to_copy ? 
Tensor(at::detail::empty_cuda({self.dim(), num_nonzeros_h}, out.options())) : out.resize_({self.dim(), num_nonzeros_h}); + if (self.dim() > 0) { + cub::CountingInputIterator counting_itr(0); + temp_storage_bytes = 0; + dispatch_select_if_wrapper(nullptr, temp_storage_bytes, counting_itr, out_temp.mutable_data_ptr(), static_cast(num_nonzeros.get()), N, NonZeroOp(), stream); + temp_storage = allocator.allocate(temp_storage_bytes); + dispatch_select_if_wrapper(temp_storage.get(), temp_storage_bytes, counting_itr, out_temp.mutable_data_ptr(), static_cast(num_nonzeros.get()), N, NonZeroOp(), stream); + if (num_nonzeros_h > 0 && self.dim() > 1){ + TensorDims dims; + for (int i=0; i < self.dim(); i++){ + dims.sizes[i] = self.sizes()[i]; + } + const int nthreads = 256; + const int nblocks = (num_nonzeros_h + nthreads - 1)/nthreads; + write_indices<<>>(out_temp.mutable_data_ptr(), dims, self.dim(), num_nonzeros_h); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } - const int nthreads = 256; - const int nblocks = (num_nonzeros_h + nthreads -1)/nthreads; - write_indices<<>>(out_temp.mutable_data_ptr(), - dims, self.dim(), num_nonzeros_h); - C10_CUDA_KERNEL_LAUNCH_CHECK(); } - } - if (need_to_copy) { - out.copy_(out_temp.t()); - } else { - //transpose out so it is correct size - Tensor out_ = out_temp.t(); - out.set_(out_); - } + if (need_to_copy) { + out.copy_(out_temp.t()); + } else { + Tensor out_ = out_temp.t(); + out.set_(out_); + } } -Tensor& nonzero_out_cuda(const Tensor& self, Tensor& out){ - TORCH_CHECK(self.numel() < std::numeric_limits::max(), "nonzero is not supported for tensors with more than INT_MAX elements, \ - file a support request"); - TORCH_CHECK(out.dtype() == at::kLong, "Expected object of scalar type ", at::kLong, " as out, but got ", out.dtype()); - TORCH_CHECK(self.device() == out.device(), "expected self and out to be on the same device, but got out on ", - out.device(), " and self on ", self.device()); - TORCH_CHECK(self.dim() <= MAX_DIMS, "nonzero is not supported for tensor with more than ", MAX_DIMS, " dimensions"); - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(at::ScalarType::ComplexHalf, at::ScalarType::Bool, at::ScalarType::BFloat16, at::ScalarType::Half, - self.scalar_type(), "nonzero_cuda", - [&] {nonzero_cuda_out_impl(self, out);}); - return out; +Tensor& nonzero_out_cuda(const Tensor& self, Tensor& out) { + TORCH_CHECK(self.numel() < std::numeric_limits::max(), "nonzero is not supported for tensors with more than INT64_MAX elements."); + TORCH_CHECK(out.dtype() == at::kLong, "Expected object of scalar type ", at::kLong, " as out, but got ", out.dtype()); + TORCH_CHECK(self.device() == out.device(), "Expected self and out to be on the same device, but got out on ", out.device(), " and self on ", self.device()); + TORCH_CHECK(self.dim() <= MAX_DIMS, "nonzero is not supported for tensor with more than ", MAX_DIMS, " dimensions"); + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(at::ScalarType::ComplexHalf, at::ScalarType::Bool, at::ScalarType::BFloat16, at::ScalarType::Half, self.scalar_type(), "nonzero_cuda", [&] {nonzero_cuda_out_impl(self, out);}); + return out; } -Tensor nonzero_cuda(const Tensor& self){ - Tensor out = at::detail::empty_cuda({0}, self.options().dtype(kLong)); - return at::native::nonzero_out_cuda(self, out); +Tensor nonzero_cuda(const Tensor& self) { + Tensor out = at::detail::empty_cuda({0}, self.options().dtype(kLong)); + return at::native::nonzero_out_cuda(self, out); } -} //namespace at::native +} // namespace at::native From 8cce2c78a71112c8601d57b76b195104835b5e93 
Mon Sep 17 00:00:00 2001
From: bhack
Date: Thu, 9 May 2024 16:38:11 +0000
Subject: [PATCH 02/11] Reintroduce inline comments

---
 aten/src/ATen/native/cuda/Nonzero.cu | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/aten/src/ATen/native/cuda/Nonzero.cu b/aten/src/ATen/native/cuda/Nonzero.cu
index af0770ea7be98..2b8bad0f18db1 100644
--- a/aten/src/ATen/native/cuda/Nonzero.cu
+++ b/aten/src/ATen/native/cuda/Nonzero.cu
@@ -114,8 +114,14 @@ void nonzero_cuda_out_impl(const Tensor& self, Tensor& out) {
   cub::DeviceReduce::Sum(temp_storage.get(), temp_storage_bytes, itr, static_cast<int64_t*>(num_nonzeros.get()), N, stream);
   int64_t num_nonzeros_h;
   at::cuda::memcpy_and_sync(&num_nonzeros_h, num_nonzeros.get(), sizeof(int64_t), cudaMemcpyDeviceToHost, stream);
+  //expected output size is num_nonzeros x ndim
+  //we are producing output with size {num_nonzeros, ndim} and strides {1, num_nonzeros} (that is, transposed ndim x num_nonzeros output)
+  //we are able to directly use passed output with this size and strides, and we can also (per contract)
+  //resize passed output with incorrect sizes anyway we want.
+  //However, out with correct sizes and incorrect strides will have to be copied to from the intermediate we've produced.
   bool need_to_copy = out.dim() == 2 && out.sizes()[0] == num_nonzeros_h && out.sizes()[1] == self.dim() && !out.t().is_contiguous();
   at::Tensor out_temp = need_to_copy ? Tensor(at::detail::empty_cuda({self.dim(), num_nonzeros_h}, out.options())) : out.resize_({self.dim(), num_nonzeros_h});
+  //Scalars are expected to produce output of size (1,0), so we can't write to it
   if (self.dim() > 0) {
     cub::CountingInputIterator<int64_t> counting_itr(0);
     temp_storage_bytes = 0;

From a2554de9e89601c7f1f8be9565e24609742ecb1e Mon Sep 17 00:00:00 2001
From: bhack
Date: Thu, 9 May 2024 21:53:34 +0200
Subject: [PATCH 03/11] Remove macro

---
 aten/src/ATen/native/cuda/Nonzero.cu | 2 --
 1 file changed, 2 deletions(-)

diff --git a/aten/src/ATen/native/cuda/Nonzero.cu b/aten/src/ATen/native/cuda/Nonzero.cu
index 2b8bad0f18db1..96a81ed078698 100644
--- a/aten/src/ATen/native/cuda/Nonzero.cu
+++ b/aten/src/ATen/native/cuda/Nonzero.cu
@@ -94,8 +94,6 @@ CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE static cudaError_t dispatch_select_if_wra
       stream);
 }
 
-DECLARE_LAUNCH_WRAPPER(cub::DeviceSelect::If, select_if);
-DECLARE_LAUNCH_WRAPPER(dispatch_select_if_wrapper, dispatch_select_if);
 
 } // anonymous namespace
 
From 2362079b3bf152ccdc2eb0bc48fcc387b5aae1ec Mon Sep 17 00:00:00 2001
From: bhack
Date: Fri, 10 May 2024 00:06:13 +0200
Subject: [PATCH 04/11] Remove typo

---
 aten/src/ATen/native/cuda/Nonzero.cu | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/aten/src/ATen/native/cuda/Nonzero.cu b/aten/src/ATen/native/cuda/Nonzero.cu
index 96a81ed078698..bf9e8779ff69b 100644
--- a/aten/src/ATen/native/cuda/Nonzero.cu
+++ b/aten/src/ATen/native/cuda/Nonzero.cu
@@ -61,7 +61,7 @@ template <typename InputIteratorT,
           typename NumSelectedIteratorT,
           typename OffsetT,
           typename SelectOp>
-CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE static cudaError_t dispatch_select_if_wrapper(
+static cudaError_t dispatch_select_if_wrapper(
     void* d_temp_storage,
     std::size_t& temp_storage_bytes,
     InputIteratorT d_in,

From 0841181c973172fafd15eb13b14f5d20de4a1bfd Mon Sep 17 00:00:00 2001
From: bhack
Date: Fri, 10 May 2024 02:37:07 +0200
Subject: [PATCH 05/11] Refactor

---
 aten/src/ATen/native/cuda/Nonzero.cu | 60 ++++++++++++++--------
 1 file changed, 31 insertions(+), 29 deletions(-)

diff --git a/aten/src/ATen/native/cuda/Nonzero.cu b/aten/src/ATen/native/cuda/Nonzero.cu
index bf9e8779ff69b..7d24b662b72ec 100644
---
a/aten/src/ATen/native/cuda/Nonzero.cu +++ b/aten/src/ATen/native/cuda/Nonzero.cu @@ -15,7 +15,6 @@ #include #endif - namespace at::native { namespace{ @@ -30,7 +29,7 @@ struct NonZeroOp //TODO: actually support int64_t index_t template struct TensorDims { - index_t sizes[MAX_DIMS]; + index_t sizes[MAX_DIMS]; }; template @@ -39,19 +38,19 @@ __global__ void write_indices( TensorDims dims, int ndim, index_t n) { - auto index = threadIdx.x + blockIdx.x * blockDim.x; - if (index < n) { - index_t div = 1; - int64_t idx_flat = inp[index]; + auto index = threadIdx.x + blockIdx.x * blockDim.x; + if (index < n) { + index_t div = 1; + int64_t idx_flat = inp[index]; #pragma unroll - for (int dim = MAX_DIMS; dim >= 0; dim--) { - if (dim > ndim - 1) - continue; - auto dim_size = dims.sizes[dim]; - inp[index + dim * n] = (idx_flat / div) % dim_size; - div *= dim_size; - } + for (int dim = MAX_DIMS; dim >= 0; dim--) { + if (dim > ndim - 1) + continue; + auto dim_size = dims.sizes[dim]; + inp[index + dim * n] = (idx_flat / div) % dim_size; + div *= dim_size; } + } } // Temporary wrapper for DeviceSelect::If from https://github.com/NVIDIA/cccl/pull/1379 @@ -92,9 +91,6 @@ static cudaError_t dispatch_select_if_wrapper( equality_op_t{}, num_items, stream); -} - - } // anonymous namespace template @@ -102,38 +98,43 @@ void nonzero_cuda_out_impl(const Tensor& self, Tensor& out) { Tensor self_ = self.contiguous(); int64_t N = self_.numel(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); -// compute number of nonzero elements - size_t temp_storage_bytes=0; + + size_t temp_storage_bytes = 0; auto& allocator = *c10::cuda::CUDACachingAllocator::get(); auto num_nonzeros = allocator.allocate(sizeof(int64_t)); + cub::TransformInputIterator, const scalar_t*> itr(self_.const_data_ptr(), NonZeroOp()); cub::DeviceReduce::Sum(nullptr, temp_storage_bytes, itr, static_cast(num_nonzeros.get()), N, stream); auto temp_storage = allocator.allocate(temp_storage_bytes); cub::DeviceReduce::Sum(temp_storage.get(), temp_storage_bytes, itr, static_cast(num_nonzeros.get()), N, stream); + int64_t num_nonzeros_h; at::cuda::memcpy_and_sync(&num_nonzeros_h, num_nonzeros.get(), sizeof(int64_t), cudaMemcpyDeviceToHost, stream); - //expected output size is num_nonzeros x ndim - //we are producing output with size {num_nonzeros, ndim} and strides {1, num_nonzeros} (that is, transposed ndim x num_nonzeros output) - //we are able to directly use passed output with this size and strides, and we can also (per contract) - //resize passed output with incorrect sizes anyway we want. - //However, out with correct sizes and incorrect strides will have to be copied to from the intermediate we've produced. + //expected output size is num_nonzeros x ndim + //we are producing output with size {num_nonzeros, ndim} and strides {1, num_nonzeros} (that is, transposed ndim x num_nonzeros output) + //we are able to directly use passed output with this size and strides, and we can also (per contract) + //resize passed output with incorrect sizes anyway we want. + //However, out with correct sizes and incorrect strides will have to be copied to from the intermediate we've produced. bool need_to_copy = out.dim() == 2 && out.sizes()[0] == num_nonzeros_h && out.sizes()[1] == self.dim() && !out.t().is_contiguous(); - at::Tensor out_temp = need_to_copy ? 
Tensor(at::detail::empty_cuda({self.dim(), num_nonzeros_h}, out.options())) : out.resize_({self.dim(), num_nonzeros_h}); - //Scalars are expected to produce output of size (1,0), so we can't write to it + at::Tensor out_temp = need_to_copy ? + Tensor(at::detail::empty_cuda({self.dim(), num_nonzeros_h}, out.options())) : out.resize_({self.dim(), num_nonzeros_h}); + //Scalars are expected to produce output of size (1,0), so we can't write to it if (self.dim() > 0) { cub::CountingInputIterator counting_itr(0); temp_storage_bytes = 0; dispatch_select_if_wrapper(nullptr, temp_storage_bytes, counting_itr, out_temp.mutable_data_ptr(), static_cast(num_nonzeros.get()), N, NonZeroOp(), stream); temp_storage = allocator.allocate(temp_storage_bytes); dispatch_select_if_wrapper(temp_storage.get(), temp_storage_bytes, counting_itr, out_temp.mutable_data_ptr(), static_cast(num_nonzeros.get()), N, NonZeroOp(), stream); - if (num_nonzeros_h > 0 && self.dim() > 1){ + + if (num_nonzeros_h > 0 && self.dim() > 1) { TensorDims dims; - for (int i=0; i < self.dim(); i++){ + for (int i = 0; i < self.dim(); i++) { dims.sizes[i] = self.sizes()[i]; } const int nthreads = 256; - const int nblocks = (num_nonzeros_h + nthreads - 1)/nthreads; - write_indices<<>>(out_temp.mutable_data_ptr(), dims, self.dim(), num_nonzeros_h); + const int nblocks = (num_nonzeros_h + nthreads - 1) / nthreads; + write_indices<<>>(out_temp.mutable_data_ptr(), + dims, self.dim(), num_nonzeros_h); C10_CUDA_KERNEL_LAUNCH_CHECK(); } } @@ -158,4 +159,5 @@ Tensor nonzero_cuda(const Tensor& self) { Tensor out = at::detail::empty_cuda({0}, self.options().dtype(kLong)); return at::native::nonzero_out_cuda(self, out); } + } // namespace at::native From a47573bac609334520400ffcd8a8d34bf6989a9d Mon Sep 17 00:00:00 2001 From: bhack Date: Fri, 10 May 2024 15:05:32 +0000 Subject: [PATCH 06/11] Fix typo --- aten/src/ATen/native/cuda/Nonzero.cu | 1 + 1 file changed, 1 insertion(+) diff --git a/aten/src/ATen/native/cuda/Nonzero.cu b/aten/src/ATen/native/cuda/Nonzero.cu index 7d24b662b72ec..917e381994d2f 100644 --- a/aten/src/ATen/native/cuda/Nonzero.cu +++ b/aten/src/ATen/native/cuda/Nonzero.cu @@ -91,6 +91,7 @@ static cudaError_t dispatch_select_if_wrapper( equality_op_t{}, num_items, stream); +} } // anonymous namespace template From dc29c0f2f80f2f58c4dc49542f4a9bbb83a49884 Mon Sep 17 00:00:00 2001 From: bhack Date: Fri, 10 May 2024 21:19:48 +0200 Subject: [PATCH 07/11] Add complex template --- aten/src/ATen/native/cuda/Nonzero.cu | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/aten/src/ATen/native/cuda/Nonzero.cu b/aten/src/ATen/native/cuda/Nonzero.cu index 917e381994d2f..694b702182c40 100644 --- a/aten/src/ATen/native/cuda/Nonzero.cu +++ b/aten/src/ATen/native/cuda/Nonzero.cu @@ -26,6 +26,14 @@ struct NonZeroOp } }; +// Specialization for complex types +template +struct NonZeroOp> { + __host__ __device__ __forceinline__ bool operator()(const c10::complex& a) const { + return (a.real() != T(0)) || (a.imag() != T(0)); + } +}; + //TODO: actually support int64_t index_t template struct TensorDims { @@ -99,34 +107,30 @@ void nonzero_cuda_out_impl(const Tensor& self, Tensor& out) { Tensor self_ = self.contiguous(); int64_t N = self_.numel(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - size_t temp_storage_bytes = 0; auto& allocator = *c10::cuda::CUDACachingAllocator::get(); auto num_nonzeros = allocator.allocate(sizeof(int64_t)); - cub::TransformInputIterator, const scalar_t*> itr(self_.const_data_ptr(), 
NonZeroOp()); cub::DeviceReduce::Sum(nullptr, temp_storage_bytes, itr, static_cast(num_nonzeros.get()), N, stream); auto temp_storage = allocator.allocate(temp_storage_bytes); cub::DeviceReduce::Sum(temp_storage.get(), temp_storage_bytes, itr, static_cast(num_nonzeros.get()), N, stream); - int64_t num_nonzeros_h; at::cuda::memcpy_and_sync(&num_nonzeros_h, num_nonzeros.get(), sizeof(int64_t), cudaMemcpyDeviceToHost, stream); //expected output size is num_nonzeros x ndim //we are producing output with size {num_nonzeros, ndim} and strides {1, num_nonzeros} (that is, transposed ndim x num_nonzeros output) //we are able to directly use passed output with this size and strides, and we can also (per contract) //resize passed output with incorrect sizes anyway we want. - //However, out with correct sizes and incorrect strides will have to be copied to from the intermediate we've produced. + //However, out with correct sizes and incorrect strides will have to be copied to from the intermediate we've produced. bool need_to_copy = out.dim() == 2 && out.sizes()[0] == num_nonzeros_h && out.sizes()[1] == self.dim() && !out.t().is_contiguous(); at::Tensor out_temp = need_to_copy ? Tensor(at::detail::empty_cuda({self.dim(), num_nonzeros_h}, out.options())) : out.resize_({self.dim(), num_nonzeros_h}); - //Scalars are expected to produce output of size (1,0), so we can't write to it + //Scalars are expected to produce output of size (1,0), so we can't write to it if (self.dim() > 0) { cub::CountingInputIterator counting_itr(0); temp_storage_bytes = 0; dispatch_select_if_wrapper(nullptr, temp_storage_bytes, counting_itr, out_temp.mutable_data_ptr(), static_cast(num_nonzeros.get()), N, NonZeroOp(), stream); temp_storage = allocator.allocate(temp_storage_bytes); dispatch_select_if_wrapper(temp_storage.get(), temp_storage_bytes, counting_itr, out_temp.mutable_data_ptr(), static_cast(num_nonzeros.get()), N, NonZeroOp(), stream); - if (num_nonzeros_h > 0 && self.dim() > 1) { TensorDims dims; for (int i = 0; i < self.dim(); i++) { @@ -160,5 +164,4 @@ Tensor nonzero_cuda(const Tensor& self) { Tensor out = at::detail::empty_cuda({0}, self.options().dtype(kLong)); return at::native::nonzero_out_cuda(self, out); } - } // namespace at::native From 58eeb415d93d09784cffd7418aeb5387862dcaac Mon Sep 17 00:00:00 2001 From: bhack Date: Sat, 11 May 2024 00:18:07 +0200 Subject: [PATCH 08/11] Reformat code --- aten/src/ATen/native/cuda/Nonzero.cu | 192 ++++++++++++++------------- 1 file changed, 97 insertions(+), 95 deletions(-) diff --git a/aten/src/ATen/native/cuda/Nonzero.cu b/aten/src/ATen/native/cuda/Nonzero.cu index 694b702182c40..850ad3d2f2be8 100644 --- a/aten/src/ATen/native/cuda/Nonzero.cu +++ b/aten/src/ATen/native/cuda/Nonzero.cu @@ -15,54 +15,11 @@ #include #endif -namespace at::native { - -namespace{ -template -struct NonZeroOp -{ - __host__ __device__ __forceinline__ bool operator()(const T& a) const { - return (a!=T(0)); - } -}; - -// Specialization for complex types -template -struct NonZeroOp> { - __host__ __device__ __forceinline__ bool operator()(const c10::complex& a) const { - return (a.real() != T(0)) || (a.imag() != T(0)); - } -}; -//TODO: actually support int64_t index_t -template -struct TensorDims { - index_t sizes[MAX_DIMS]; -}; - -template -__global__ void write_indices( - int64_t* inp, - TensorDims dims, - int ndim, - index_t n) { - auto index = threadIdx.x + blockIdx.x * blockDim.x; - if (index < n) { - index_t div = 1; - int64_t idx_flat = inp[index]; -#pragma unroll - for (int dim = MAX_DIMS; 
dim >= 0; dim--) { - if (dim > ndim - 1) - continue; - auto dim_size = dims.sizes[dim]; - inp[index + dim * n] = (idx_flat / div) % dim_size; - div *= dim_size; - } - } -} +namespace at::native { -// Temporary wrapper for DeviceSelect::If from https://github.com/NVIDIA/cccl/pull/1379 -// until CCCL https://github.com/NVIDIA/cccl/issues/1422 is resolved +namespace { +// Wrapper for DeviceSelect::If to handle tensors larger than INT_MAX template +struct NonZeroOp +{ + __host__ __device__ __forceinline__ bool operator()(const T& a) const { + return (a != T(0)); + } +}; + +//TODO: actually support int64_t index_t +template +struct TensorDims { + index_t sizes[MAX_DIMS]; +}; + +template +__global__ void write_indices( + int64_t* inp, + TensorDims dims, + int ndim, + index_t n) { + auto index = threadIdx.x + blockIdx.x * blockDim.x; + if (index < n) { + index_t div = 1; + int64_t idx_flat = inp[index]; + #pragma unroll + for (int dim = MAX_DIMS; dim >= 0; dim--) { + if (dim > ndim - 1) + continue; + auto dim_size = dims.sizes[dim]; + inp[index + dim * n] = (idx_flat / div) % dim_size; + div *= dim_size; + } + } +} + +} //anonymous namespace template void nonzero_cuda_out_impl(const Tensor& self, Tensor& out) { - Tensor self_ = self.contiguous(); - int64_t N = self_.numel(); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - size_t temp_storage_bytes = 0; - auto& allocator = *c10::cuda::CUDACachingAllocator::get(); - auto num_nonzeros = allocator.allocate(sizeof(int64_t)); - cub::TransformInputIterator, const scalar_t*> itr(self_.const_data_ptr(), NonZeroOp()); - cub::DeviceReduce::Sum(nullptr, temp_storage_bytes, itr, static_cast(num_nonzeros.get()), N, stream); - auto temp_storage = allocator.allocate(temp_storage_bytes); - cub::DeviceReduce::Sum(temp_storage.get(), temp_storage_bytes, itr, static_cast(num_nonzeros.get()), N, stream); - int64_t num_nonzeros_h; - at::cuda::memcpy_and_sync(&num_nonzeros_h, num_nonzeros.get(), sizeof(int64_t), cudaMemcpyDeviceToHost, stream); + Tensor self_ = self.contiguous(); + int64_t N = self_.numel(); // Changed to int64_t to handle larger sizes + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + //Compute number of nonzero elements + size_t temp_storage_bytes = 0; + auto& allocator = *c10::cuda::CUDACachingAllocator::get(); + auto num_nonzeros = allocator.allocate(sizeof(int)); + cub::TransformInputIterator, const scalar_t*> itr(self_.const_data_ptr(), NonZeroOp()); + cub::DeviceReduce::Sum(nullptr, temp_storage_bytes, itr, (int*)num_nonzeros.get(), N, stream); + auto temp_storage = allocator.allocate(temp_storage_bytes); + cub::DeviceReduce::Sum(temp_storage.get(), temp_storage_bytes, itr, (int*)num_nonzeros.get(), N, stream); + int num_nonzeros_h; + at::cuda::memcpy_and_sync(&num_nonzeros_h, num_nonzeros.get(), sizeof(int), cudaMemcpyDeviceToHost, stream); //expected output size is num_nonzeros x ndim //we are producing output with size {num_nonzeros, ndim} and strides {1, num_nonzeros} (that is, transposed ndim x num_nonzeros output) //we are able to directly use passed output with this size and strides, and we can also (per contract) //resize passed output with incorrect sizes anyway we want. //However, out with correct sizes and incorrect strides will have to be copied to from the intermediate we've produced. - bool need_to_copy = out.dim() == 2 && out.sizes()[0] == num_nonzeros_h && out.sizes()[1] == self.dim() && !out.t().is_contiguous(); - at::Tensor out_temp = need_to_copy ? 
- Tensor(at::detail::empty_cuda({self.dim(), num_nonzeros_h}, out.options())) : out.resize_({self.dim(), num_nonzeros_h}); - //Scalars are expected to produce output of size (1,0), so we can't write to it - if (self.dim() > 0) { - cub::CountingInputIterator counting_itr(0); - temp_storage_bytes = 0; - dispatch_select_if_wrapper(nullptr, temp_storage_bytes, counting_itr, out_temp.mutable_data_ptr(), static_cast(num_nonzeros.get()), N, NonZeroOp(), stream); - temp_storage = allocator.allocate(temp_storage_bytes); - dispatch_select_if_wrapper(temp_storage.get(), temp_storage_bytes, counting_itr, out_temp.mutable_data_ptr(), static_cast(num_nonzeros.get()), N, NonZeroOp(), stream); - if (num_nonzeros_h > 0 && self.dim() > 1) { - TensorDims dims; - for (int i = 0; i < self.dim(); i++) { - dims.sizes[i] = self.sizes()[i]; - } - const int nthreads = 256; - const int nblocks = (num_nonzeros_h + nthreads - 1) / nthreads; - write_indices<<>>(out_temp.mutable_data_ptr(), - dims, self.dim(), num_nonzeros_h); - C10_CUDA_KERNEL_LAUNCH_CHECK(); - } - } - if (need_to_copy) { - out.copy_(out_temp.t()); - } else { - Tensor out_ = out_temp.t(); - out.set_(out_); + bool need_to_copy = out.dim() == 2 && out.sizes()[0] == num_nonzeros_h && out.sizes()[1] == self.dim() && !out.t().is_contiguous(); + at::Tensor out_temp = need_to_copy ? + Tensor(at::detail::empty_cuda({self.dim(), num_nonzeros_h}, out.options())) : + out.resize_({self.dim(), num_nonzeros_h}); + //Scalars are expected to produce output of size (1,0), so we can't write to it + if (self.dim() > 0) { + cub::CountingInputIterator counting_itr(0); + temp_storage_bytes = 0; + dispatch_select_if_wrapper(nullptr, temp_storage_bytes, counting_itr, out_temp.mutable_data_ptr(), + (int*)num_nonzeros.get(), N, NonZeroOp(), stream); + temp_storage = allocator.allocate(temp_storage_bytes); + dispatch_select_if_wrapper(temp_storage.get(), temp_storage_bytes, counting_itr, out_temp.mutable_data_ptr(), + (int*)num_nonzeros.get(), N, NonZeroOp(), stream); + if (num_nonzeros_h > 0 && self.dim() > 1) { + TensorDims dims; + for (int i = 0; i < self.dim(); i++) { + dims.sizes[i] = self.sizes()[i]; + } + const int nthreads = 256; + const int nblocks = (num_nonzeros_h + nthreads - 1) / nthreads; + write_indices<<>>(out_temp.mutable_data_ptr(), + dims, self.dim(), num_nonzeros_h); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } + } + if (need_to_copy) { + out.copy_(out_temp.t()); + } else { + //transpose out so it is correct size + Tensor out_ = out_temp.t(); + out.set_(out_); + } } Tensor& nonzero_out_cuda(const Tensor& self, Tensor& out) { - TORCH_CHECK(self.numel() < std::numeric_limits::max(), "nonzero is not supported for tensors with more than INT64_MAX elements."); - TORCH_CHECK(out.dtype() == at::kLong, "Expected object of scalar type ", at::kLong, " as out, but got ", out.dtype()); - TORCH_CHECK(self.device() == out.device(), "Expected self and out to be on the same device, but got out on ", out.device(), " and self on ", self.device()); - TORCH_CHECK(self.dim() <= MAX_DIMS, "nonzero is not supported for tensor with more than ", MAX_DIMS, " dimensions"); - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(at::ScalarType::ComplexHalf, at::ScalarType::Bool, at::ScalarType::BFloat16, at::ScalarType::Half, self.scalar_type(), "nonzero_cuda", [&] {nonzero_cuda_out_impl(self, out);}); - return out; + TORCH_CHECK(self.numel() < std::numeric_limits::max(), "nonzero is not supported for tensors with more than INT_MAX elements, \ + See https://github.com/pytorch/pytorch/issues/51871"); + 
TORCH_CHECK(out.dtype() == at::kLong, "Expected object of scalar type ", at::kLong, " as out, but got ", out.dtype()); + TORCH_CHECK(self.device() == out.device(), "expected self and out to be on the same device, but got out on ", + out.device(), " and self on ", self.device()); + TORCH_CHECK(self.dim() <= MAX_DIMS, "nonzero is not supported for tensor with more than ", MAX_DIMS, " dimensions"); + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(at::ScalarType::ComplexHalf, at::ScalarType::Bool, at::ScalarType::BFloat16, at::ScalarType::Half, + self.scalar_type(), "nonzero_cuda", + [&] {nonzero_cuda_out_impl(self, out);}); + return out; } Tensor nonzero_cuda(const Tensor& self) { - Tensor out = at::detail::empty_cuda({0}, self.options().dtype(kLong)); - return at::native::nonzero_out_cuda(self, out); + Tensor out = at::detail::empty_cuda({0}, self.options().dtype(kLong)); + return at::native::nonzero_out_cuda(self, out); } -} // namespace at::native +} //namespace at::native From 052bb6100217375592abc15ea3d6340dfddb5838 Mon Sep 17 00:00:00 2001 From: bhack Date: Sat, 11 May 2024 00:20:18 +0200 Subject: [PATCH 09/11] Add extra comments --- aten/src/ATen/native/cuda/Nonzero.cu | 2 ++ 1 file changed, 2 insertions(+) diff --git a/aten/src/ATen/native/cuda/Nonzero.cu b/aten/src/ATen/native/cuda/Nonzero.cu index 850ad3d2f2be8..796406ef5d6c6 100644 --- a/aten/src/ATen/native/cuda/Nonzero.cu +++ b/aten/src/ATen/native/cuda/Nonzero.cu @@ -20,6 +20,8 @@ namespace at::native { namespace { // Wrapper for DeviceSelect::If to handle tensors larger than INT_MAX +// Imported from https://github.com/NVIDIA/cccl/pull/1379 +// #todo Remove the wrapper when https://github.com/NVIDIA/cccl/issues/1422 is released template Date: Sat, 11 May 2024 00:21:15 +0200 Subject: [PATCH 10/11] Format TODO --- aten/src/ATen/native/cuda/Nonzero.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aten/src/ATen/native/cuda/Nonzero.cu b/aten/src/ATen/native/cuda/Nonzero.cu index 796406ef5d6c6..c79882ca99958 100644 --- a/aten/src/ATen/native/cuda/Nonzero.cu +++ b/aten/src/ATen/native/cuda/Nonzero.cu @@ -21,7 +21,7 @@ namespace at::native { namespace { // Wrapper for DeviceSelect::If to handle tensors larger than INT_MAX // Imported from https://github.com/NVIDIA/cccl/pull/1379 -// #todo Remove the wrapper when https://github.com/NVIDIA/cccl/issues/1422 is released +// #TODO: Remove the wrapper when https://github.com/NVIDIA/cccl/issues/1422 is released template Date: Mon, 13 May 2024 13:15:50 +0200 Subject: [PATCH 11/11] Add bool op for complex --- aten/src/ATen/native/cuda/Nonzero.cu | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/aten/src/ATen/native/cuda/Nonzero.cu b/aten/src/ATen/native/cuda/Nonzero.cu index c79882ca99958..46f1b235c9a92 100644 --- a/aten/src/ATen/native/cuda/Nonzero.cu +++ b/aten/src/ATen/native/cuda/Nonzero.cu @@ -68,6 +68,14 @@ struct NonZeroOp } }; +template +struct NonZeroOp> +{ + __host__ __device__ __forceinline__ bool operator()(const c10::complex& a) const { + return (a.real() != T(0) || a.imag() != T(0)); + } +}; + //TODO: actually support int64_t index_t template struct TensorDims { @@ -105,7 +113,7 @@ void nonzero_cuda_out_impl(const Tensor& self, Tensor& out) { //Compute number of nonzero elements size_t temp_storage_bytes = 0; auto& allocator = *c10::cuda::CUDACachingAllocator::get(); - auto num_nonzeros = allocator.allocate(sizeof(int)); + auto num_nonzeros = allocator.allocate(sizeof(int)); cub::TransformInputIterator, const scalar_t*> 
itr(self_.const_data_ptr(), NonZeroOp()); cub::DeviceReduce::Sum(nullptr, temp_storage_bytes, itr, (int*)num_nonzeros.get(), N, stream); auto temp_storage = allocator.allocate(temp_storage_bytes); @@ -126,7 +134,7 @@ void nonzero_cuda_out_impl(const Tensor& self, Tensor& out) { cub::CountingInputIterator counting_itr(0); temp_storage_bytes = 0; dispatch_select_if_wrapper(nullptr, temp_storage_bytes, counting_itr, out_temp.mutable_data_ptr(), - (int*)num_nonzeros.get(), N, NonZeroOp(), stream); + (int*)num_nonzeros.get(), N, NonZeroOp(), stream); temp_storage = allocator.allocate(temp_storage_bytes); dispatch_select_if_wrapper(temp_storage.get(), temp_storage_bytes, counting_itr, out_temp.mutable_data_ptr(), (int*)num_nonzeros.get(), N, NonZeroOp(), stream); @@ -141,7 +149,7 @@ void nonzero_cuda_out_impl(const Tensor& self, Tensor& out) { dims, self.dim(), num_nonzeros_h); C10_CUDA_KERNEL_LAUNCH_CHECK(); } - } + } if (need_to_copy) { out.copy_(out_temp.t()); } else {