156 changes: 156 additions & 0 deletions aten/src/ATen/native/cuda/Nonzero.cu
@@ -55,6 +55,34 @@ __global__ void write_indices(
}
}

template <typename index_t>
__global__ void write_indices_static(
    int64_t* inp,
    TensorDims<index_t> dims,
    int ndim,
    const int* num_nonzeros_ptr,
    int n,
    int64_t fill_value) {
  auto index = threadIdx.x + blockIdx.x * blockDim.x;
  index_t num_nonzeros = *num_nonzeros_ptr;
  if (index < n) {
    index_t div = 1;
    int64_t idx_flat = inp[index];
#pragma unroll
    for (int dim = MAX_DIMS; dim >= 0; dim--) {
      if (dim > ndim - 1)
        continue;
      auto dim_size = dims.sizes[dim];
      if (index < num_nonzeros) {
        // Decompose the flat index into its coordinate along this dimension.
        inp[index + dim * n] = (idx_flat / div) % dim_size;
      } else {
        // Rows past the actual number of nonzeros are padded with fill_value.
        inp[index + dim * n] = fill_value;
      }
      div *= dim_size;
    }
  }
}
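For intuition, here is a standalone host-side sketch of the same flat-index decomposition the kernel's inner loop performs (illustrative only, not part of the patch; the fill_value padding is omitted):

```cpp
#include <array>
#include <cstdint>

// Walks dimensions from innermost to outermost, mirroring the loop above:
// coord[dim] = (flat / div) % sizes[dim], with div multiplied by each size.
std::array<int64_t, 2> decompose(int64_t flat, const std::array<int64_t, 2>& sizes) {
  std::array<int64_t, 2> coord{};
  int64_t div = 1;
  for (int dim = 1; dim >= 0; dim--) {
    coord[dim] = (flat / div) % sizes[dim];
    div *= sizes[dim];
  }
  return coord; // e.g. flat index 4 with sizes {2, 3} -> coordinate {1, 1}
}
```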

} //anonymous namespace

template<typename scalar_t>
@@ -137,4 +165,132 @@ Tensor nonzero_cuda(const Tensor& self){
  Tensor out = at::detail::empty_cuda({0}, self.options().dtype(kLong));
  return at::native::nonzero_out_cuda(self, out);
}

template <typename scalar_t>
void nonzero_static_cuda_out_impl(
    const Tensor& self,
    int size,
    int64_t fill_value,
    Tensor& out) {
  Tensor self_ = self.contiguous();
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  // compute number of nonzero elements
  size_t temp_storage_bytes = 0;
  auto& allocator = *c10::cuda::CUDACachingAllocator::get();
  auto num_nonzeros = allocator.allocate(sizeof(int));
  cub::TransformInputIterator<bool, NonZeroOp<scalar_t>, const scalar_t*> itr(
      self_.const_data_ptr<scalar_t>(), NonZeroOp<scalar_t>());
  // The expected output has size {size, ndim}. We produce it with sizes
  // {size, ndim} and strides {1, size}, i.e. as a transposed, contiguous
  // {ndim, size} buffer. A passed `out` that already has these sizes and
  // strides can be written to directly, and (per the out= contract) an `out`
  // with incorrect sizes may be resized however we want. However, an `out`
  // with the correct sizes but different strides has to be copied to from
  // the intermediate buffer we produce.
  bool need_to_copy = out.dim() == 2 && out.sizes()[0] == size &&
      out.sizes()[1] == self.dim() && !out.t().is_contiguous();
Review comment (Collaborator):
this will always fail, unless you either allocate out with 0 size (like for the non-static case) or allocate it as a transposed tensor, and thus you'll incur extra copy
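A minimal sketch of the reviewer's second option on the caller side: allocate `out` as a transposed view so that need_to_copy stays false and the kernel writes into it directly (hypothetical caller code, reusing the empty_cuda call that appears elsewhere in this file):

```cpp
// Hypothetical caller-side allocation: back `out` with a contiguous
// {ndim, size} buffer and hand over its transposed view, so that
// out.sizes() == {size, ndim} and out.t().is_contiguous() holds.
at::Tensor buf = at::detail::empty_cuda(
    {self.dim(), size}, self.options().dtype(at::kLong));
at::Tensor out = buf.t(); // sizes {size, ndim}, strides {1, size}
```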

  at::Tensor out_temp = need_to_copy
      ? Tensor(at::detail::empty_cuda({self.dim(), size}, out.options()))
      : out.resize_({self.dim(), size});
  // Scalars are expected to produce output of size (1, 0), so we can't write
  // to it.
  if (size > 0 && self.dim() > 0) {
    cub::CountingInputIterator<int64_t> counting_itr(0);
    temp_storage_bytes = 0;
    cub::DeviceSelect::Flagged(
        nullptr,
        temp_storage_bytes,
        counting_itr,
        itr,
        out_temp.mutable_data_ptr<int64_t>(),
        (int*)num_nonzeros.get(),
        std::min(size, (int)self_.numel()),
        stream);
    auto temp_storage = allocator.allocate(temp_storage_bytes);
    cub::DeviceSelect::Flagged(
        temp_storage.get(),
        temp_storage_bytes,
        counting_itr,
        itr,
        out_temp.mutable_data_ptr<int64_t>(),
        (int*)num_nonzeros.get(),
        std::min(size, (int)self_.numel()),
Review comment (Collaborator):
Hm, so to make sure that all elements are inspected I'll need to pass a size that's equal to or greater than self_.numel(), right? So if, for example, I know that there are only N nonzero elements in my tensor, I can't make use of this information and would still need to allocate a full-size output, which is oftentimes prohibitively expensive (nonzero was implemented like this at first, but people complained about it, so we changed the implementation to first count the nonzeros and then allocate the output).
I don't know how to do the right thing with the cub APIs, though.

Reply (Collaborator, Author), quoting the concern above:

Right. Thank you for pointing this out. size < input.numel() is a case I did not consider. I can boil down your concern into this example:

In [4]: torch.nonzero_static(torch.tensor([1,0,1,0]),  size=2, fill_value=-1)
Out[4]: 
tensor([[0],
        [2]])

In [5]: torch.nonzero_static(torch.tensor([1,0,1,0], device="cuda"),  size=2, fill_value=-1)
Out[5]: 
tensor([[ 0],
        [-1]], device='cuda:0')

As far as how to handle this in a reasonable way, I am still thinking. Doing a scan of the number of nonzeros up to each index (rather than a reduce), followed by a binary search for "size", would allow us to specify num_items correctly. But I want to think more about whether there is a better way.
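A rough, untested sketch of that scan-plus-search idea (the buffer names and the thrust call are assumptions, not part of the patch): an inclusive prefix sum of the nonzero flags gives the running nonzero count, and a device-side lower_bound for `size` on that monotone array locates the size-th nonzero, which is the num_items that DeviceSelect::Flagged actually needs.

```cpp
// Needs <thrust/binary_search.h> and <thrust/execution_policy.h>, plus an
// extra buffer of numel() ints for the running counts.
size_t scan_bytes = 0;
auto counts = allocator.allocate(self_.numel() * sizeof(int));
int* counts_ptr = static_cast<int*>(counts.get());
cub::DeviceScan::InclusiveSum(
    nullptr, scan_bytes, itr, counts_ptr, (int)self_.numel(), stream);
auto scan_storage = allocator.allocate(scan_bytes);
cub::DeviceScan::InclusiveSum(
    scan_storage.get(), scan_bytes, itr, counts_ptr, (int)self_.numel(), stream);
// First position whose running count reaches `size`; +1 converts that index
// into the number of input elements DeviceSelect::Flagged must inspect.
int* pos = thrust::lower_bound(
    thrust::cuda::par.on(stream), counts_ptr, counts_ptr + self_.numel(), size);
int64_t num_items = std::min<int64_t>(self_.numel(), (pos - counts_ptr) + 1);
```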

        stream);
    TensorDims<int> dims;
    for (int i = 0; i < self.dim(); i++) {
      dims.sizes[i] = self.sizes()[i];
    }
    const int nthreads = 256;
    // if size were 0, nblocks would be 0 and the launch would fail; this is
    // guarded by the size > 0 check above.
    const int nblocks = (size + nthreads - 1) / nthreads;
    write_indices_static<<<nblocks, nthreads, 0, stream>>>(
        out_temp.mutable_data_ptr<int64_t>(),
        dims,
        self.dim(),
        (int*)num_nonzeros.get(),
        size,
        fill_value);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  }
  if (need_to_copy) {
    out.copy_(out_temp.t());
  } else {
    // transpose out_temp so out gets the expected {size, ndim} shape
    Tensor out_ = out_temp.t();
    out.set_(out_);
  }
}

Tensor& nonzero_static_out_cuda(
    const Tensor& self,
    int64_t size,
    int64_t fill_value,
    Tensor& out) {
  TORCH_CHECK(
      self.numel() < std::numeric_limits<int>::max(),
      "nonzero_static is not supported for tensors with more than INT_MAX elements, \
  See https://github.com/pytorch/pytorch/issues/51871");
  TORCH_CHECK(
      size < std::numeric_limits<int>::max(),
      "nonzero_static is not supported for output tensors with more than INT_MAX elements, \
  See https://github.com/pytorch/pytorch/issues/51871");
  TORCH_CHECK(
      out.dtype() == at::kLong,
      "Expected object of scalar type ",
      at::kLong,
      " as out, but got ",
      out.dtype());
  TORCH_CHECK(
      self.device() == out.device(),
      "expected self and out to be on the same device, but got out on ",
      out.device(),
      " and self on ",
      self.device());
  TORCH_CHECK(
      self.dim() <= MAX_DIMS,
      "nonzero is not supported for tensor with more than ",
      MAX_DIMS,
      " dimensions");
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
      at::ScalarType::ComplexHalf,
      at::ScalarType::Bool,
      at::ScalarType::BFloat16,
      at::ScalarType::Half,
      self.scalar_type(),
      "nonzero_static_cuda",
      [&] {
        nonzero_static_cuda_out_impl<scalar_t>(self, size, fill_value, out);
      });
  return out;
}

Tensor nonzero_static_cuda(
    const Tensor& self,
    int64_t size,
    int64_t fill_value) {
  Tensor out =
      at::detail::empty_cuda({size, self.dim()}, self.options().dtype(kLong));
  return at::native::nonzero_static_out_cuda(self, size, fill_value, out);
}

} //namespace at::native
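For context, a hypothetical end-to-end use of the new CUDA path from the C++ frontend (illustrative only, not part of the patch; assumes the generated Tensor::nonzero_static method and mirrors the Python example in the review thread above):

```cpp
#include <torch/torch.h>

void example() {
  torch::Tensor t = torch::tensor({1, 0, 1, 0}).to(torch::kCUDA);
  // Keep at most `size` nonzero indices; remaining rows are padded with
  // fill_value.
  torch::Tensor idx = t.nonzero_static(/*size=*/2, /*fill_value=*/-1);
  // idx has shape {2, 1}; with num_items clamped to min(size, numel()), the
  // second row can come back as fill_value here rather than index 2 (the
  // behavior discussed in the review thread).
}
```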
2 changes: 2 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -9260,11 +9260,13 @@
- func: nonzero_static.out(Tensor self, *, int size, int fill_value=-1, Tensor(a!) out) -> Tensor(a!)
  dispatch:
    CPU: nonzero_static_out_cpu
    CUDA: nonzero_static_out_cuda

- func: nonzero_static(Tensor self, *, int size, int fill_value=-1) -> Tensor
  variants: method, function
  dispatch:
    CPU: nonzero_static_cpu
    CUDA: nonzero_static_cuda

- func: nonzero_numpy(Tensor self) -> Tensor[]
  variants: method, function
3 changes: 1 addition & 2 deletions torch/testing/_internal/common_methods_invocations.py
@@ -20297,9 +20297,8 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs):
 OpInfo('nonzero_static',
        dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf),
        sample_inputs_func=sample_inputs_nonzero_static,
-       supports_out=False,
+       supports_out=False,  # TODO: shouldn't this be true?
        supports_autograd=False,
-       decorators=[onlyCPU],
        skips=(
            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'),
            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'),