diff --git a/aten/src/ATen/native/cuda/Nonzero.cu b/aten/src/ATen/native/cuda/Nonzero.cu
index e5fb9230de76..ad41292d3b46 100644
--- a/aten/src/ATen/native/cuda/Nonzero.cu
+++ b/aten/src/ATen/native/cuda/Nonzero.cu
@@ -55,6 +55,34 @@ __global__ void write_indices(
   }
 }
 
+template <typename index_t>
+__global__ void write_indices_static(
+    int64_t* inp,
+    TensorDims<index_t> dims,
+    int ndim,
+    const int* num_nonzeros_ptr,
+    int n,
+    int64_t fill_value) {
+  auto index = threadIdx.x + blockIdx.x * blockDim.x;
+  index_t num_nonzeros = *num_nonzeros_ptr;
+  if (index < n) {
+    index_t div = 1;
+    int64_t idx_flat = inp[index];
+#pragma unroll
+    for (int dim = MAX_DIMS; dim >= 0; dim--) {
+      if (dim > ndim - 1)
+        continue;
+      auto dim_size = dims.sizes[dim];
+      if (index < num_nonzeros) {
+        inp[index + dim * n] = (idx_flat / div) % dim_size;
+      } else {
+        inp[index + dim * n] = fill_value;
+      }
+      div *= dim_size;
+    }
+  }
+}
+
 } //anonymous namespace
 
 template <typename scalar_t>
@@ -137,4 +165,132 @@ Tensor nonzero_cuda(const Tensor& self){
   Tensor out = at::detail::empty_cuda({0}, self.options().dtype(kLong));
   return at::native::nonzero_out_cuda(self, out);
 }
+
+template <typename scalar_t>
+void nonzero_static_cuda_out_impl(
+    const Tensor& self,
+    int size,
+    int64_t fill_value,
+    Tensor& out) {
+  Tensor self_ = self.contiguous();
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  // compute number of nonzero elements
+  size_t temp_storage_bytes = 0;
+  auto& allocator = *c10::cuda::CUDACachingAllocator::get();
+  auto num_nonzeros = allocator.allocate(sizeof(int));
+  cub::TransformInputIterator<bool, NonZeroOp<scalar_t>, const scalar_t*> itr(
+      self_.const_data_ptr<scalar_t>(), NonZeroOp<scalar_t>());
+  // expected output size is num_nonzeros x ndim
+  // we are producing output with size {num_nonzeros, ndim} and strides {1,
+  // num_nonzeros} (that is, transposed ndim x num_nonzeros output) we are able
+  // to directly use passed output with this size and strides, and we can also
+  // (per contract) resize passed output with incorrect sizes anyway we want.
+  // However, out with correct sizes and incorrect strides will have to be
+  // copied to from the intermediate we've produced.
+  bool need_to_copy = out.dim() == 2 && out.sizes()[0] == size &&
+      out.sizes()[1] == self.dim() && !out.t().is_contiguous();
+  at::Tensor out_temp = need_to_copy
+      ? Tensor(at::detail::empty_cuda({self.dim(), size}, out.options()))
+      : out.resize_({self.dim(), size});
+  // Scalars are expected to produce output of size (1,0), so we can't write to
+  // it
+  if (size > 0 && self.dim() > 0) {
+    cub::CountingInputIterator<int64_t> counting_itr(0);
+    temp_storage_bytes = 0;
+    cub::DeviceSelect::Flagged(
+        nullptr,
+        temp_storage_bytes,
+        counting_itr,
+        itr,
+        out_temp.mutable_data_ptr<int64_t>(),
+        (int*)num_nonzeros.get(),
+        std::min(size, (int)self_.numel()),
+        stream);
+    auto temp_storage = allocator.allocate(temp_storage_bytes);
+    cub::DeviceSelect::Flagged(
+        temp_storage.get(),
+        temp_storage_bytes,
+        counting_itr,
+        itr,
+        out_temp.mutable_data_ptr<int64_t>(),
+        (int*)num_nonzeros.get(),
+        std::min(size, (int)self_.numel()),
+        stream);
+    TensorDims<int64_t> dims;
+    for (int i = 0; i < self.dim(); i++) {
+      dims.sizes[i] = self.sizes()[i];
+    }
+    const int nthreads = 256;
+    // if size is 0, then the launch will fail.
+    const int nblocks = (size + nthreads - 1) / nthreads;
+    write_indices_static<<<nblocks, nthreads, 0, stream>>>(
+        out_temp.mutable_data_ptr<int64_t>(),
+        dims,
+        self.dim(),
+        (int*)num_nonzeros.get(),
+        size,
+        fill_value);
+    C10_CUDA_KERNEL_LAUNCH_CHECK();
+  }
+  if (need_to_copy) {
+    out.copy_(out_temp.t());
+  } else {
+    // transpose out so it is correct size
+    Tensor out_ = out_temp.t();
+    out.set_(out_);
+  }
+}
+
+Tensor& nonzero_static_out_cuda(
+    const Tensor& self,
+    int64_t size,
+    int64_t fill_value,
+    Tensor& out) {
+  TORCH_CHECK(
+      self.numel() < std::numeric_limits<int>::max(),
+      "nonzero_static is not supported for tensors with more than INT_MAX elements, \
+  See https://github.com/pytorch/pytorch/issues/51871");
+  TORCH_CHECK(
+      size < std::numeric_limits<int>::max(),
+      "nonzero_static is not supported for output tensors with more than INT_MAX elements, \
+  See https://github.com/pytorch/pytorch/issues/51871");
+  TORCH_CHECK(
+      out.dtype() == at::kLong,
+      "Expected object of scalar type ",
+      at::kLong,
+      " as out, but got ",
+      out.dtype());
+  TORCH_CHECK(
+      self.device() == out.device(),
+      "expected self and out to be on the same device, but got out on ",
+      out.device(),
+      " and self on ",
+      self.device());
+  TORCH_CHECK(
+      self.dim() <= MAX_DIMS,
+      "nonzero is not supported for tensor with more than ",
+      MAX_DIMS,
+      " dimensions");
+  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
+      at::ScalarType::ComplexHalf,
+      at::ScalarType::Bool,
+      at::ScalarType::BFloat16,
+      at::ScalarType::Half,
+      self.scalar_type(),
+      "nonzero_static_cuda",
+      [&] {
+        nonzero_static_cuda_out_impl<scalar_t>(self, size, fill_value, out);
+      });
+  return out;
+}
+
+Tensor nonzero_static_cuda(
+    const Tensor& self,
+    int64_t size,
+    int64_t fill_value) {
+  Tensor out =
+      at::detail::empty_cuda({size, self.dim()}, self.options().dtype(kLong));
+  return at::native::nonzero_static_out_cuda(self, size, fill_value, out);
+}
+
 } //namespace at::native
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 8219383bd843..02a49c25148d 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -9260,11 +9260,13 @@
 - func: nonzero_static.out(Tensor self, *, int size, int fill_value=-1, Tensor(a!) out) -> Tensor(a!)
   dispatch:
     CPU: nonzero_static_out_cpu
+    CUDA: nonzero_static_out_cuda
 
 - func: nonzero_static(Tensor self, *, int size, int fill_value=-1) -> Tensor
   variants: method, function
   dispatch:
     CPU: nonzero_static_cpu
+    CUDA: nonzero_static_cuda
 
 - func: nonzero_numpy(Tensor self) -> Tensor[]
   variants: method, function
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 5d93e9ef68bd..aabbe6bb70ca 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -20297,9 +20297,8 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs):
     OpInfo('nonzero_static',
            dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf),
            sample_inputs_func=sample_inputs_nonzero_static,
-           supports_out=False,
+           supports_out=False,  # TODO: shouldn't this be true?
            supports_autograd=False,
-           decorators=[onlyCPU],
            skips=(
                DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'),
                DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'),
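
For reference, the index math in `write_indices_static` is a standard row-major unravel of each selected flat index, written column-wise into the transposed `{ndim, size}` intermediate, with rows past `num_nonzeros` filled with `fill_value`. Below is a minimal pure-Python sketch of the same per-row computation; `unravel_rows` is an illustrative name, not part of the patch.

```python
def unravel_rows(flat_indices, sizes, size, fill_value=-1):
    # Reference for the per-thread work in write_indices_static: unravel a
    # flat index of a contiguous tensor into per-dimension indices, and pad
    # rows beyond the number of nonzeros with fill_value.
    ndim = len(sizes)
    rows = []
    for i in range(size):
        if i < len(flat_indices):
            flat, row, div = flat_indices[i], [0] * ndim, 1
            # Walk dimensions from last to first, mirroring the kernel's loop
            # with an accumulated divisor (div *= dim_size each step).
            for dim in range(ndim - 1, -1, -1):
                row[dim] = (flat // div) % sizes[dim]
                div *= sizes[dim]
            rows.append(row)
        else:
            rows.append([fill_value] * ndim)  # padded row
    return rows


# unravel_rows([0, 3], [2, 2], size=4) -> [[0, 0], [1, 1], [-1, -1], [-1, -1]]
```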
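With the new dispatch entries in `native_functions.yaml`, `torch.nonzero_static` should now be callable on CUDA tensors with the same fixed-shape contract as the CPU path: the result always has shape `(size, input.dim())`, padded with `fill_value` when there are fewer nonzeros and truncated when there are more. A small usage sketch (assumes a CUDA-capable build; the outputs shown are expected values, not captured logs):

```python
import torch

x = torch.tensor([[1, 0], [0, 2]], device="cuda")

# size larger than the nonzero count: trailing rows are padded with fill_value.
print(torch.nonzero_static(x, size=4, fill_value=-1))
# expected:
# tensor([[ 0,  0],
#         [ 1,  1],
#         [-1, -1],
#         [-1, -1]], device='cuda:0')

# size smaller than the nonzero count: the result is truncated to `size` rows.
print(torch.nonzero_static(x, size=1))
# expected: tensor([[0, 0]], device='cuda:0')
```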