extend nonzero to int64 #125850

Open
wants to merge 16 commits into main
103 changes: 76 additions & 27 deletions aten/src/ATen/native/cuda/Nonzero.cu
@@ -18,19 +18,68 @@

namespace at::native {

namespace {
// Wrapper for DeviceSelect::If to handle tensors larger than INT_MAX
// Imported from https://github.com/NVIDIA/cccl/pull/1379
// TODO: Remove this wrapper once the fix for https://github.com/NVIDIA/cccl/issues/1422 is released
template <typename InputIteratorT,
typename OutputIteratorT,
typename NumSelectedIteratorT,
typename OffsetT,
typename SelectOp>
static cudaError_t dispatch_select_if_wrapper(
void* d_temp_storage,
std::size_t& temp_storage_bytes,
InputIteratorT d_in,
OutputIteratorT d_out,
NumSelectedIteratorT d_num_selected_out,
OffsetT num_items,
SelectOp select_op,
cudaStream_t stream = 0)
{
using flag_iterator_t = cub::NullType*;
using equality_op_t = cub::NullType;

return cub::DispatchSelectIf<
@bhack (Contributor, Author) commented on May 11, 2024:
Does this require cub/cccl 2.4.0?

InputIteratorT,
flag_iterator_t,
OutputIteratorT,
NumSelectedIteratorT,
SelectOp,
equality_op_t,
OffsetT,
false>::Dispatch(d_temp_storage,
temp_storage_bytes,
d_in,
nullptr,
d_out,
d_num_selected_out,
select_op,
equality_op_t{},
num_items,
stream);
}
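For reference only (not part of this patch): the wrapper above follows the usual two-pass CUB calling convention, where the first call passes a null temp-storage pointer to query the required size and the second call performs the selection. A minimal sketch, with all buffer and function names invented for illustration:

// Illustrative sketch only. d_in, d_out and d_num_selected are assumed to be
// valid device pointers; num_items is a 64-bit element count, so inputs
// larger than INT_MAX go through the int64_t OffsetT path.
static void select_nonzero_values_example(const float* d_in, float* d_out,
                                          int* d_num_selected, int64_t num_items,
                                          cudaStream_t stream) {
  size_t temp_bytes = 0;
  // Pass 1: null temp storage, so only the required size is written to temp_bytes.
  dispatch_select_if_wrapper(nullptr, temp_bytes, d_in, d_out,
                             d_num_selected, num_items, NonZeroOp<float>{}, stream);
  void* d_temp = nullptr;
  cudaMalloc(&d_temp, temp_bytes);
  // Pass 2: the actual selection, enqueued on the given stream.
  dispatch_select_if_wrapper(d_temp, temp_bytes, d_in, d_out,
                             d_num_selected, num_items, NonZeroOp<float>{}, stream);
  cudaStreamSynchronize(stream);
  cudaFree(d_temp);
}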

template<typename T>
struct NonZeroOp
{
__host__ __device__ __forceinline__ bool operator()(const T& a) const {
return (a != T(0));
}
};

template<typename T>
struct NonZeroOp<c10::complex<T>>
{
__host__ __device__ __forceinline__ bool operator()(const c10::complex<T>& a) const {
return (a.real() != T(0) || a.imag() != T(0));
}
};

// TODO: actually support int64_t index_t
template<typename index_t>
struct TensorDims {
index_t sizes[MAX_DIMS];
};

template <typename index_t>
@@ -43,7 +92,7 @@ __global__ void write_indices(
if (index < n) {
index_t div = 1;
int64_t idx_flat = inp[index];
#pragma unroll
for (int dim = MAX_DIMS; dim >= 0; dim--) {
if (dim > ndim - 1)
continue;
@@ -57,12 +106,12 @@
} // anonymous namespace
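The write_indices kernel body is partly elided in this diff; for illustration only, the following host-side sketch shows the same kind of row-major decomposition of a flat index into per-dimension coordinates, in the spirit of the div-based loop above (all names here are invented for the example):

#include <cstdint>
#include <vector>

// Illustration only: unravel a flat row-major index into coordinates for a
// tensor with the given sizes, walking dimensions from innermost to outermost.
static std::vector<int64_t> unravel_index_example(int64_t idx_flat,
                                                  const std::vector<int64_t>& sizes) {
  std::vector<int64_t> coords(sizes.size());
  int64_t div = 1;
  for (int dim = static_cast<int>(sizes.size()) - 1; dim >= 0; dim--) {
    coords[dim] = (idx_flat / div) % sizes[dim];
    div *= sizes[dim];
  }
  return coords;
}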

template<typename scalar_t>
void nonzero_cuda_out_impl(const Tensor& self, Tensor& out) {
Tensor self_ = self.contiguous();
int64_t N = self_.numel(); // Changed to int64_t to handle larger sizes
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
// Compute number of nonzero elements
size_t temp_storage_bytes = 0;
auto& allocator = *c10::cuda::CUDACachingAllocator::get();
auto num_nonzeros = allocator.allocate(sizeof(int));
cub::TransformInputIterator<bool, NonZeroOp<scalar_t>, const scalar_t*> itr(self_.const_data_ptr<scalar_t>(), NonZeroOp<scalar_t>());
@@ -84,21 +133,21 @@ void nonzero_cuda_out_impl(const Tensor& self, Tensor& out){
if (self.dim() > 0) {
cub::CountingInputIterator<int64_t> counting_itr(0);
temp_storage_bytes = 0;
dispatch_select_if_wrapper(nullptr, temp_storage_bytes, counting_itr, out_temp.mutable_data_ptr<int64_t>(),
(int*)num_nonzeros.get(), N, NonZeroOp<scalar_t>(), stream);
temp_storage = allocator.allocate(temp_storage_bytes);
dispatch_select_if_wrapper(temp_storage.get(), temp_storage_bytes, counting_itr, out_temp.mutable_data_ptr<int64_t>(),
(int*)num_nonzeros.get(), N, NonZeroOp<scalar_t>(), stream);
if (num_nonzeros_h > 0 && self.dim() > 1) {
TensorDims<int> dims;
for (int i = 0; i < self.dim(); i++) {
dims.sizes[i] = self.sizes()[i];
}
const int nthreads = 256;
const int nblocks = (num_nonzeros_h + nthreads - 1) / nthreads;
write_indices<<<nblocks, nthreads, 0, stream>>>(out_temp.mutable_data_ptr<int64_t>(),
dims, self.dim(), num_nonzeros_h);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
}
if (need_to_copy) {
@@ -110,20 +159,20 @@
}
}

Tensor& nonzero_out_cuda(const Tensor& self, Tensor& out) {
TORCH_CHECK(self.numel() < std::numeric_limits<int64_t>::max(), "nonzero is not supported for tensors with more than INT64_MAX elements, \
See https://github.com/pytorch/pytorch/issues/51871");
TORCH_CHECK(out.dtype() == at::kLong, "Expected object of scalar type ", at::kLong, " as out, but got ", out.dtype());
TORCH_CHECK(self.device() == out.device(), "expected self and out to be on the same device, but got out on ",
out.device(), " and self on ", self.device());
TORCH_CHECK(self.dim() <= MAX_DIMS, "nonzero is not supported for tensor with more than ", MAX_DIMS, " dimensions");
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(at::ScalarType::ComplexHalf, at::ScalarType::Bool, at::ScalarType::BFloat16, at::ScalarType::Half,
self.scalar_type(), "nonzero_cuda",
[&] {nonzero_cuda_out_impl<scalar_t>(self, out);});
return out;
}

Tensor nonzero_cuda(const Tensor& self) {
Tensor out = at::detail::empty_cuda({0}, self.options().dtype(kLong));
return at::native::nonzero_out_cuda(self, out);
}
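As a usage note (illustrative, not part of the patch): nonzero_cuda allocates a fresh kLong output and forwards to nonzero_out_cuda, so callers who want to reuse a buffer can pre-allocate an empty CUDA long tensor and go through the out variant. A minimal sketch against the public ATen API, assuming the usual at::nonzero / at::nonzero_out entry points:

#include <ATen/ATen.h>

// Illustration only: exercising both entry points from host code. For CUDA
// tensors the dispatcher routes these calls to the kernels above.
void nonzero_usage_example() {
  at::Tensor self = at::rand({4, 5}, at::device(at::kCUDA).dtype(at::kFloat));
  // Functional form: returns a new [num_nonzero, ndim] tensor of int64 indices.
  at::Tensor idx = at::nonzero(self);
  // Out variant: out must be a kLong tensor on the same device as self.
  at::Tensor out = at::empty({0}, self.options().dtype(at::kLong));
  at::nonzero_out(out, self);
}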