diff --git a/aten/src/ATen/native/cuda/Nonzero.cu b/aten/src/ATen/native/cuda/Nonzero.cu
index e87f46cd844e..46f1b235c9a9 100644
--- a/aten/src/ATen/native/cuda/Nonzero.cu
+++ b/aten/src/ATen/native/cuda/Nonzero.cu
@@ -18,19 +18,68 @@
 namespace at::native {
 
-namespace{
+namespace {
+// Wrapper for DeviceSelect::If to handle tensors larger than INT_MAX
+// Imported from https://github.com/NVIDIA/cccl/pull/1379
+// #TODO: Remove the wrapper when https://github.com/NVIDIA/cccl/issues/1422 is released
+template <typename InputIteratorT,
+          typename OutputIteratorT,
+          typename NumSelectedIteratorT,
+          typename OffsetT,
+          typename SelectOp>
+static cudaError_t dispatch_select_if_wrapper(
+    void* d_temp_storage,
+    std::size_t& temp_storage_bytes,
+    InputIteratorT d_in,
+    OutputIteratorT d_out,
+    NumSelectedIteratorT d_num_selected_out,
+    OffsetT num_items,
+    SelectOp select_op,
+    cudaStream_t stream = 0)
+{
+  using flag_iterator_t = cub::NullType*;
+  using equality_op_t = cub::NullType;
+
+  return cub::DispatchSelectIf<
+      InputIteratorT,
+      flag_iterator_t,
+      OutputIteratorT,
+      NumSelectedIteratorT,
+      SelectOp,
+      equality_op_t,
+      OffsetT,
+      false>::Dispatch(d_temp_storage,
+                       temp_storage_bytes,
+                       d_in,
+                       nullptr,
+                       d_out,
+                       d_num_selected_out,
+                       select_op,
+                       equality_op_t{},
+                       num_items,
+                       stream);
+}
+
 template <typename T>
 struct NonZeroOp {
   __host__ __device__ __forceinline__ bool operator()(const T& a) const {
-    return (a!=T(0));
+    return (a != T(0));
+  }
+};
+
+template <typename T>
+struct NonZeroOp<c10::complex<T>>
+{
+  __host__ __device__ __forceinline__ bool operator()(const c10::complex<T>& a) const {
+    return (a.real() != T(0) || a.imag() != T(0));
   }
 };
 
 //TODO: actually support int64_t index_t
 template <typename index_t>
 struct TensorDims {
-  index_t sizes[MAX_DIMS];
+  index_t sizes[MAX_DIMS];
 };
 
 template <typename index_t>
@@ -43,7 +92,7 @@ __global__ void write_indices(
   if (index < n) {
     index_t div = 1;
     int64_t idx_flat = inp[index];
-#pragma unroll
+    #pragma unroll
     for (int dim = MAX_DIMS; dim >= 0; dim--) {
       if (dim > ndim - 1)
         continue;
@@ -57,12 +106,12 @@
 } //anonymous namespace
 
 template <typename scalar_t>
-void nonzero_cuda_out_impl(const Tensor& self, Tensor& out){
+void nonzero_cuda_out_impl(const Tensor& self, Tensor& out) {
   Tensor self_ = self.contiguous();
-  int N = self_.numel();
+  int64_t N = self_.numel(); // Changed to int64_t to handle larger sizes
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-// compute number of nonzero elements
-  size_t temp_storage_bytes=0;
+  // Compute number of nonzero elements
+  size_t temp_storage_bytes = 0;
   auto& allocator = *c10::cuda::CUDACachingAllocator::get();
   auto num_nonzeros = allocator.allocate(sizeof(int));
   cub::TransformInputIterator<bool, NonZeroOp<scalar_t>, const scalar_t*> itr(self_.const_data_ptr<scalar_t>(), NonZeroOp<scalar_t>());
@@ -84,21 +133,21 @@
   if (self.dim() > 0) {
     cub::CountingInputIterator<int64_t> counting_itr(0);
     temp_storage_bytes = 0;
-    cub::DeviceSelect::Flagged(nullptr, temp_storage_bytes, counting_itr, itr,
-        out_temp.mutable_data_ptr<int64_t>(), (int*)num_nonzeros.get(), N, stream);
+    dispatch_select_if_wrapper(nullptr, temp_storage_bytes, counting_itr, out_temp.mutable_data_ptr<int64_t>(),
+        (int*)num_nonzeros.get(), N, NonZeroOp<scalar_t>(), stream);
     temp_storage = allocator.allocate(temp_storage_bytes);
-    cub::DeviceSelect::Flagged(temp_storage.get(), temp_storage_bytes, counting_itr, itr,
-        out_temp.mutable_data_ptr<int64_t>(), (int*)num_nonzeros.get(), N, stream);
-    if (num_nonzeros_h > 0 && self.dim() > 1){
-      TensorDims<int> dims;
-      for (int i=0; i<self.dim(); i++){
-        dims.sizes[i] = self.sizes()[i];
-      }
-      const int nthreads = 256;
-      const int nblocks = (num_nonzeros_h+nthreads-1)/nthreads;
-      write_indices<<<nblocks, nthreads, 0, stream>>>(out_temp.mutable_data_ptr<int64_t>(),
-      dims, self.dim(), num_nonzeros_h);
-      C10_CUDA_KERNEL_LAUNCH_CHECK();
+    dispatch_select_if_wrapper(temp_storage.get(), temp_storage_bytes, counting_itr, out_temp.mutable_data_ptr<int64_t>(),
+        (int*)num_nonzeros.get(), N, NonZeroOp<scalar_t>(), stream);
+    if (num_nonzeros_h > 0 && self.dim() > 1) {
+      TensorDims<int> dims;
+      for (int i = 0; i < self.dim(); i++) {
+        dims.sizes[i] = self.sizes()[i];
+      }
+      const int nthreads = 256;
+      const int nblocks = (num_nonzeros_h + nthreads - 1) / nthreads;
+      write_indices<<<nblocks, nthreads, 0, stream>>>(out_temp.mutable_data_ptr<int64_t>(),
+          dims, self.dim(), num_nonzeros_h);
+      C10_CUDA_KERNEL_LAUNCH_CHECK();
     }
   }
   if (need_to_copy) {
@@ -110,20 +159,20 @@
   }
 }
 
-Tensor& nonzero_out_cuda(const Tensor& self, Tensor& out){
-  TORCH_CHECK(self.numel() < std::numeric_limits<int>::max(), "nonzero is not supported for tensors with more than INT_MAX elements, \
+Tensor& nonzero_out_cuda(const Tensor& self, Tensor& out) {
+  TORCH_CHECK(self.numel() < std::numeric_limits<int>::max(), "nonzero is not supported for tensors with more than INT_MAX elements, \
   See https://github.com/pytorch/pytorch/issues/51871");
   TORCH_CHECK(out.dtype() == at::kLong, "Expected object of scalar type ", at::kLong, " as out, but got ", out.dtype());
   TORCH_CHECK(self.device() == out.device(), "expected self and out to be on the same device, but got out on ", out.device(), " and self on ", self.device());
   TORCH_CHECK(self.dim() <= MAX_DIMS, "nonzero is not supported for tensor with more than ", MAX_DIMS, " dimensions");
   AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(at::ScalarType::ComplexHalf, at::ScalarType::Bool, at::ScalarType::BFloat16, at::ScalarType::Half,
-    self.scalar_type(), "nonzero_cuda",
-    [&] {nonzero_cuda_out_impl<scalar_t>(self, out);});
+      self.scalar_type(), "nonzero_cuda",
+      [&] {nonzero_cuda_out_impl<scalar_t>(self, out);});
   return out;
 }
 
-Tensor nonzero_cuda(const Tensor& self){
+Tensor nonzero_cuda(const Tensor& self) {
   Tensor out = at::detail::empty_cuda({0}, self.options().dtype(kLong));
   return at::native::nonzero_out_cuda(self, out);
 }
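Reviewer note, not part of the patch: the sketch below exercises the same two-phase temp-storage convention the new wrapper follows (a first call with a null temp-storage pointer only reports the required scratch size, a second call performs the selection). It is a minimal standalone harness built on the public cub::DeviceSelect::If entry point rather than on dispatch_select_if_wrapper itself, and the functor and buffer names (IsNonZero, d_in, d_out, d_num_selected) are illustrative only; the wrapper in the diff differs in that it instantiates cub::DispatchSelectIf directly so OffsetT can be wider than int.

// Hypothetical standalone check of the two-phase CUB select pattern (nvcc, any recent CUDA toolkit).
#include <cub/cub.cuh>
#include <cuda_runtime.h>
#include <cstdio>
#include <vector>

// Predicate analogous to NonZeroOp<T> in the patch, fixed to float for brevity.
struct IsNonZero {
  __host__ __device__ bool operator()(float x) const { return x != 0.0f; }
};

int main() {
  const int N = 8;
  std::vector<float> h_in = {0.f, 1.f, 0.f, 2.f, 0.f, 0.f, 3.f, 0.f};

  float *d_in = nullptr, *d_out = nullptr;
  int* d_num_selected = nullptr;
  cudaMalloc(&d_in, N * sizeof(float));
  cudaMalloc(&d_out, N * sizeof(float));
  cudaMalloc(&d_num_selected, sizeof(int));
  cudaMemcpy(d_in, h_in.data(), N * sizeof(float), cudaMemcpyHostToDevice);

  // Phase 1: query the temporary storage size (d_temp_storage == nullptr).
  void* d_temp_storage = nullptr;
  size_t temp_storage_bytes = 0;
  cub::DeviceSelect::If(d_temp_storage, temp_storage_bytes, d_in, d_out,
                        d_num_selected, N, IsNonZero());

  // Phase 2: allocate the scratch buffer and run the selection for real.
  cudaMalloc(&d_temp_storage, temp_storage_bytes);
  cub::DeviceSelect::If(d_temp_storage, temp_storage_bytes, d_in, d_out,
                        d_num_selected, N, IsNonZero());

  int num_selected = 0;
  cudaMemcpy(&num_selected, d_num_selected, sizeof(int), cudaMemcpyDeviceToHost);
  printf("selected %d non-zero values\n", num_selected);

  cudaFree(d_temp_storage);
  cudaFree(d_num_selected);
  cudaFree(d_out);
  cudaFree(d_in);
  return 0;
}

Error checking is omitted for brevity; in ATen the equivalent scratch allocation goes through the CUDACachingAllocator (as in the patch) rather than raw cudaMalloc.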