Implement nonzero_static in CUDA. #136415
@@ -55,6 +55,34 @@ __global__ void write_indices(
   }
 }
 
+template <typename index_t>
+__global__ void write_indices_static(
+    int64_t* inp,
+    TensorDims<index_t> dims,
+    int ndim,
+    const int* num_nonzeros_ptr,
+    int n,
+    int64_t fill_value) {
+  auto index = threadIdx.x + blockIdx.x * blockDim.x;
+  index_t num_nonzeros = *num_nonzeros_ptr;
+  if (index < n) {
+    index_t div = 1;
+    int64_t idx_flat = inp[index];
+#pragma unroll
+    for (int dim = MAX_DIMS; dim >= 0; dim--) {
+      if (dim > ndim - 1)
+        continue;
+      auto dim_size = dims.sizes[dim];
+      if (index < num_nonzeros) {
+        inp[index + dim * n] = (idx_flat / div) % dim_size;
+      } else {
+        inp[index + dim * n] = fill_value;
+      }
+      div *= dim_size;
+    }
+  }
+}
+
 } //anonymous namespace
 
 template<typename scalar_t>
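
For orientation, each thread of `write_indices_static` takes the flat index stored at `inp[index]` and unravels it into per-dimension coordinates, writing them column-wise (`inp[index + dim * n]`) and substituting `fill_value` for rows at or beyond `*num_nonzeros_ptr`. Below is a minimal host-side sketch of the same index arithmetic; the helper name and shapes are illustrative, not part of the PR.

```cpp
#include <cstdint>
#include <vector>

// Hypothetical host-side mirror of the per-thread arithmetic in
// write_indices_static: peel coordinates off from the innermost (last)
// dimension outward, exactly as the div/% loop in the kernel does.
std::vector<int64_t> unravel(int64_t idx_flat, const std::vector<int64_t>& sizes) {
  std::vector<int64_t> coords(sizes.size());
  int64_t div = 1;
  for (int dim = static_cast<int>(sizes.size()) - 1; dim >= 0; --dim) {
    coords[dim] = (idx_flat / div) % sizes[dim];
    div *= sizes[dim];
  }
  return coords;
}

// Example: for a tensor of shape {2, 3}, flat index 4 unravels to {1, 1}.
// The kernel stores these coordinates column-wise, so the intermediate
// buffer ends up with shape {ndim, n} rather than {n, ndim}.
```
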
@@ -137,4 +165,132 @@ Tensor nonzero_cuda(const Tensor& self){
   Tensor out = at::detail::empty_cuda({0}, self.options().dtype(kLong));
   return at::native::nonzero_out_cuda(self, out);
 }
+
+template <typename scalar_t>
+void nonzero_static_cuda_out_impl(
+    const Tensor& self,
+    int size,
+    int64_t fill_value,
+    Tensor& out) {
+  Tensor self_ = self.contiguous();
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  // compute number of nonzero elements
+  size_t temp_storage_bytes = 0;
+  auto& allocator = *c10::cuda::CUDACachingAllocator::get();
+  auto num_nonzeros = allocator.allocate(sizeof(int));
+  cub::TransformInputIterator<bool, NonZeroOp<scalar_t>, const scalar_t*> itr(
+      self_.const_data_ptr<scalar_t>(), NonZeroOp<scalar_t>());
+  // expected output size is num_nonzeros x ndim
+  // we are producing output with size {num_nonzeros, ndim} and strides {1,
+  // num_nonzeros} (that is, transposed ndim x num_nonzeros output) we are able
+  // to directly use passed output with this size and strides, and we can also
+  // (per contract) resize passed output with incorrect sizes anyway we want.
+  // However, out with correct sizes and incorrect strides will have to be
+  // copied to from the intermediate we've produced.
+  bool need_to_copy = out.dim() == 2 && out.sizes()[0] == size &&
+      out.sizes()[1] == self.dim() && !out.t().is_contiguous();
+  at::Tensor out_temp = need_to_copy
+      ? Tensor(at::detail::empty_cuda({self.dim(), size}, out.options()))
+      : out.resize_({self.dim(), size});
+  // Scalars are expected to produce output of size (1,0), so we can't write to
+  // it
+  if (size > 0 && self.dim() > 0) {
+    cub::CountingInputIterator<int64_t> counting_itr(0);
+    temp_storage_bytes = 0;
+    cub::DeviceSelect::Flagged(
+        nullptr,
+        temp_storage_bytes,
+        counting_itr,
+        itr,
+        out_temp.mutable_data_ptr<int64_t>(),
+        (int*)num_nonzeros.get(),
+        std::min(size, (int)self_.numel()),
+        stream);
+    auto temp_storage = allocator.allocate(temp_storage_bytes);
+    cub::DeviceSelect::Flagged(
+        temp_storage.get(),
+        temp_storage_bytes,
+        counting_itr,
+        itr,
+        out_temp.mutable_data_ptr<int64_t>(),
+        (int*)num_nonzeros.get(),
+        std::min(size, (int)self_.numel()),
|
Review comment (on the `std::min(size, (int)self_.numel())` argument above): Hm, so to make sure that all elements are inspected I'll need to pass …

Reply: Right. Thank you for pointing this out. `size < input.numel()` is a case I did not consider. I can boil down your concern into this example: … As far as how to handle this in a reasonable way, I am still thinking. Doing a scan for the number of nonzeros up to a particular index (rather than a reduce), followed by a binary search for `size`, would allow us to specify `num_items` correctly. But I want to think more about whether there is a better way.
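
The example referred to in the reply is not preserved in this capture; the following is a hypothetical sketch of the same concern (input values, `size`, and `fill_value` are illustrative assumptions, not taken from the thread):

```cpp
#include <ATen/ATen.h>

// Hypothetical illustration: if num_items passed to cub::DeviceSelect::Flagged
// is capped at min(size, numel), elements past the first `size` positions are
// never inspected.
void sketch() {
  // self = [0, 0, 0, 1]: a single nonzero at flat index 3, numel() = 4.
  auto self = at::zeros({4}, at::dtype(at::kLong).device(at::kCUDA));
  self.narrow(0, 3, 1).fill_(1);

  // With size = 2 the expected result has shape {2, 1}: the index of the one
  // nonzero, then fill_value padding -> [[3], [-1]].
  // The concern: if only the first size = 2 elements are scanned, the nonzero
  // at index 3 is missed and the result would come back as [[-1], [-1]].
  auto out = at::nonzero_static(self, /*size=*/2, /*fill_value=*/-1);
}
```
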
+        stream);
+    TensorDims<int> dims;
+    for (int i = 0; i < self.dim(); i++) {
+      dims.sizes[i] = self.sizes()[i];
+    }
+    const int nthreads = 256;
+    // if size is 0, then the launch will fail.
+    const int nblocks = (size + nthreads - 1) / nthreads;
+    write_indices_static<<<nblocks, nthreads, 0, stream>>>(
+        out_temp.mutable_data_ptr<int64_t>(),
+        dims,
+        self.dim(),
+        (int*)num_nonzeros.get(),
+        size,
+        fill_value);
+    C10_CUDA_KERNEL_LAUNCH_CHECK();
+  }
+  if (need_to_copy) {
+    out.copy_(out_temp.t());
+  } else {
+    // transpose out so it is correct size
+    Tensor out_ = out_temp.t();
+    out.set_(out_);
+  }
+}
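
As the comment block inside `nonzero_static_cuda_out_impl` describes, the kernel fills a contiguous `{ndim, size}` intermediate and the caller-visible result is its transpose. A small sketch of the stride reasoning this relies on (the shapes below are illustrative assumptions):

```cpp
#include <ATen/ATen.h>

// Layout sketch for the contract used above (hypothetical ndim = 3, size = 5).
void layout_sketch() {
  const int64_t ndim = 3, size = 5;

  // The kernel writes column-wise into a contiguous {ndim, size} buffer:
  // row `dim` holds the dim-th coordinate of every output row.
  auto out_temp = at::empty({ndim, size}, at::kLong);

  // Its transpose is the user-facing {size, ndim} result with strides
  // {1, size}; this is only a stride swap, no data movement.
  auto result = out_temp.t();  // sizes {5, 3}, strides {1, 5}

  // The need_to_copy test inverts this: writing in place into the caller's
  // `out` is only possible when out.t() is already contiguous, i.e. when
  // `out` itself has the transposed layout shown here.
  bool in_place_possible = result.t().is_contiguous();  // true
}
```
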
+
+Tensor& nonzero_static_out_cuda(
+    const Tensor& self,
+    int64_t size,
+    int64_t fill_value,
+    Tensor& out) {
+  TORCH_CHECK(
+      self.numel() < std::numeric_limits<int>::max(),
+      "nonzero_static is not supported for tensors with more than INT_MAX elements, \
+  See https://github.com/pytorch/pytorch/issues/51871");
+  TORCH_CHECK(
+      size < std::numeric_limits<int>::max(),
+      "nonzero_static is not supported for output tensors with more than INT_MAX elements, \
+  See https://github.com/pytorch/pytorch/issues/51871");
+  TORCH_CHECK(
+      out.dtype() == at::kLong,
+      "Expected object of scalar type ",
+      at::kLong,
+      " as out, but got ",
+      out.dtype());
+  TORCH_CHECK(
+      self.device() == out.device(),
+      "expected self and out to be on the same device, but got out on ",
+      out.device(),
+      " and self on ",
+      self.device());
+  TORCH_CHECK(
+      self.dim() <= MAX_DIMS,
+      "nonzero is not supported for tensor with more than ",
+      MAX_DIMS,
+      " dimensions");
+  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
+      at::ScalarType::ComplexHalf,
+      at::ScalarType::Bool,
+      at::ScalarType::BFloat16,
+      at::ScalarType::Half,
+      self.scalar_type(),
+      "nonzero_static_cuda",
+      [&] {
+        nonzero_static_cuda_out_impl<scalar_t>(self, size, fill_value, out);
+      });
+  return out;
+}
+
+Tensor nonzero_static_cuda(
+    const Tensor& self,
+    int64_t size,
+    int64_t fill_value) {
+  Tensor out =
+      at::detail::empty_cuda({size, self.dim()}, self.options().dtype(kLong));
+  return at::native::nonzero_static_out_cuda(self, size, fill_value, out);
+}
+
 } //namespace at::native
Review comment (on the allocation of `out` in `nonzero_static_cuda` just above): this will always fail, unless you either allocate `out` with 0 size (like for the non-static case) or allocate it as a transposed tensor; as written, you'll incur an extra copy.
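
To make the remark above concrete: a plain `empty_cuda({size, self.dim()}, ...)` allocation is row-major, so its transpose is not contiguous (for `size > 1` and `ndim > 1`) and the `need_to_copy` branch with its extra `copy_` runs on every call. A sketch of the stride arithmetic and the two alternatives mentioned (shapes are illustrative assumptions):

```cpp
#include <ATen/ATen.h>

// Stride check behind the review remark (hypothetical size = 5, ndim = 3).
void allocation_sketch() {
  const int64_t size = 5, ndim = 3;

  // What nonzero_static_cuda allocates: row-major {size, ndim}, strides {ndim, 1}.
  auto out_row_major = at::empty({size, ndim}, at::kLong);
  bool writable_in_place = out_row_major.t().is_contiguous();  // false -> copy_ path

  // Alternative (a): a 0-size out, as nonzero_cuda uses; the impl resize_()s it.
  auto out_zero = at::empty({0}, at::kLong);

  // Alternative (b): an already-transposed allocation with strides {1, size},
  // which passes the out.t().is_contiguous() test and avoids the extra copy.
  auto out_transposed = at::empty({ndim, size}, at::kLong).t();
  bool transposed_ok = out_transposed.t().is_contiguous();  // true
}
```
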