Add complex template
bhack committed May 10, 2024
1 parent a47573b commit dc29c0f
Showing 1 changed file with 10 additions and 7 deletions.
17 changes: 10 additions & 7 deletions aten/src/ATen/native/cuda/Nonzero.cu
@@ -26,6 +26,14 @@ struct NonZeroOp
  }
};

// Specialization for complex types
template<typename T>
struct NonZeroOp<c10::complex<T>> {
  __host__ __device__ __forceinline__ bool operator()(const c10::complex<T>& a) const {
    return (a.real() != T(0)) || (a.imag() != T(0));
  }
};

//TODO: actually support int64_t index_t
template<typename index_t>
struct TensorDims {
@@ -99,34 +107,30 @@ void nonzero_cuda_out_impl(const Tensor& self, Tensor& out) {
  Tensor self_ = self.contiguous();
  int64_t N = self_.numel();
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  size_t temp_storage_bytes = 0;
  auto& allocator = *c10::cuda::CUDACachingAllocator::get();
  auto num_nonzeros = allocator.allocate(sizeof(int64_t));

  cub::TransformInputIterator<bool, NonZeroOp<scalar_t>, const scalar_t*> itr(self_.const_data_ptr<scalar_t>(), NonZeroOp<scalar_t>());
  cub::DeviceReduce::Sum(nullptr, temp_storage_bytes, itr, static_cast<int64_t*>(num_nonzeros.get()), N, stream);
  auto temp_storage = allocator.allocate(temp_storage_bytes);
  cub::DeviceReduce::Sum(temp_storage.get(), temp_storage_bytes, itr, static_cast<int64_t*>(num_nonzeros.get()), N, stream);

  int64_t num_nonzeros_h;
  at::cuda::memcpy_and_sync(&num_nonzeros_h, num_nonzeros.get(), sizeof(int64_t), cudaMemcpyDeviceToHost, stream);
  //expected output size is num_nonzeros x ndim
  //we are producing output with size {num_nonzeros, ndim} and strides {1, num_nonzeros} (that is, transposed ndim x num_nonzeros output)
  //we are able to directly use passed output with this size and strides, and we can also (per contract)
  //resize passed output with incorrect sizes anyway we want.
  //However, out with correct sizes and incorrect strides will have to be copied to from the intermediate we've produced.
  bool need_to_copy = out.dim() == 2 && out.sizes()[0] == num_nonzeros_h && out.sizes()[1] == self.dim() && !out.t().is_contiguous();
  at::Tensor out_temp = need_to_copy ?
      Tensor(at::detail::empty_cuda({self.dim(), num_nonzeros_h}, out.options())) : out.resize_({self.dim(), num_nonzeros_h});
  //Scalars are expected to produce output of size (1,0), so we can't write to it
  if (self.dim() > 0) {
    cub::CountingInputIterator<int64_t> counting_itr(0);
    temp_storage_bytes = 0;
    dispatch_select_if_wrapper(nullptr, temp_storage_bytes, counting_itr, out_temp.mutable_data_ptr<int64_t>(), static_cast<int64_t*>(num_nonzeros.get()), N, NonZeroOp<scalar_t>(), stream);
    temp_storage = allocator.allocate(temp_storage_bytes);
    dispatch_select_if_wrapper(temp_storage.get(), temp_storage_bytes, counting_itr, out_temp.mutable_data_ptr<int64_t>(), static_cast<int64_t*>(num_nonzeros.get()), N, NonZeroOp<scalar_t>(), stream);

    if (num_nonzeros_h > 0 && self.dim() > 1) {
      TensorDims<int64_t> dims;
      for (int i = 0; i < self.dim(); i++) {
@@ -160,5 +164,4 @@ Tensor nonzero_cuda(const Tensor& self) {
  Tensor out = at::detail::empty_cuda({0}, self.options().dtype(kLong));
  return at::native::nonzero_out_cuda(self, out);
}

} // namespace at::native
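
For reference, the new specialization treats a complex element as nonzero when either its real or imaginary part is nonzero. Below is a minimal host-side sketch of that predicate semantics, assuming a standalone translation unit with std::complex standing in for c10::complex (NonZeroOpSketch is a hypothetical name, not part of this patch):

// Sketch only: same truth table as the NonZeroOp<c10::complex<T>> specialization above,
// written against std::complex so it compiles outside the PyTorch tree.
#include <cassert>
#include <complex>

template <typename T>
struct NonZeroOpSketch {
  bool operator()(const std::complex<T>& a) const {
    // A complex value counts as nonzero iff either component is nonzero.
    return (a.real() != T(0)) || (a.imag() != T(0));
  }
};

int main() {
  NonZeroOpSketch<float> pred;
  const std::complex<float> purely_real(1.0f, 0.0f);
  const std::complex<float> purely_imag(0.0f, 2.0f);
  const std::complex<float> zero(0.0f, 0.0f);
  assert(pred(purely_real));   // purely real: counted
  assert(pred(purely_imag));   // purely imaginary: counted
  assert(!pred(zero));         // both components zero: skipped
  return 0;
}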

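The surrounding nonzero_cuda_out_impl keeps its two-pass CUB structure for complex inputs: it first counts the nonzero elements with cub::DeviceReduce::Sum so it can size the output tensor, then selects the flat indices of the elements whose flag is true. The standalone sketch below is a hedged illustration of that pattern, not the PyTorch implementation: it substitutes cub::DeviceSelect::Flagged, thrust::complex, and raw cudaMalloc for the dispatch_select_if_wrapper, c10::complex, and caching allocator used in the actual file, folds the count into the num-selected output of the Flagged call, and omits the later step in the hunk that expands flat indices into per-dimension coordinates when self.dim() > 1.

// complex_nonzero_sketch.cu -- build with: nvcc complex_nonzero_sketch.cu
#include <cub/cub.cuh>
#include <thrust/complex.h>
#include <cstdio>

// Same predicate shape as the specialization added in this commit.
struct IsNonZeroComplex {
  __host__ __device__ __forceinline__ bool operator()(const thrust::complex<float>& a) const {
    return (a.real() != 0.f) || (a.imag() != 0.f);
  }
};

int main() {
  const int N = 6;
  thrust::complex<float> h_in[N] = {
      {0.f, 0.f}, {1.f, 0.f}, {0.f, 2.f}, {0.f, 0.f}, {3.f, 4.f}, {0.f, 0.f}};

  thrust::complex<float>* d_in = nullptr;
  int64_t* d_out = nullptr;           // flat indices of the nonzero elements
  int64_t* d_num_selected = nullptr;  // how many indices were written
  cudaMalloc(&d_in, N * sizeof(*d_in));
  cudaMalloc(&d_out, N * sizeof(*d_out));
  cudaMalloc(&d_num_selected, sizeof(*d_num_selected));
  cudaMemcpy(d_in, h_in, N * sizeof(*d_in), cudaMemcpyHostToDevice);

  // Lazily map each element to a bool flag, as the TransformInputIterator in the patch does.
  cub::TransformInputIterator<bool, IsNonZeroComplex, const thrust::complex<float>*>
      flags(d_in, IsNonZeroComplex());
  cub::CountingInputIterator<int64_t> indices(0);

  // Standard CUB two-phase call: query temp storage size, then run the selection.
  void* d_temp = nullptr;
  size_t temp_bytes = 0;
  cub::DeviceSelect::Flagged(nullptr, temp_bytes, indices, flags, d_out, d_num_selected, N);
  cudaMalloc(&d_temp, temp_bytes);
  cub::DeviceSelect::Flagged(d_temp, temp_bytes, indices, flags, d_out, d_num_selected, N);

  int64_t num = 0;
  cudaMemcpy(&num, d_num_selected, sizeof(num), cudaMemcpyDeviceToHost);
  int64_t h_out[N];
  cudaMemcpy(h_out, d_out, num * sizeof(int64_t), cudaMemcpyDeviceToHost);
  printf("nonzero count = %lld, flat indices:", (long long)num);  // expect 3: 1 2 4
  for (int64_t i = 0; i < num; ++i) printf(" %lld", (long long)h_out[i]);
  printf("\n");

  cudaFree(d_temp);
  cudaFree(d_in);
  cudaFree(d_out);
  cudaFree(d_num_selected);
  return 0;
}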