Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 28 additions & 14 deletions aten/src/ATen/native/sparse/cuda/SoftMax.cu
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,18 @@ static int getNumThreads(int nElem) {
return threadSizes[4];
}

// Number of entries in the dense part of a sparse tensor.
//
// `sizes` holds the dimensions of the sparse tensor and `sparse_dim` is the
// dimension of its sparse part. The result is the product of the entries of
// `sizes` from position `sparse_dim` through the end.
int64_t get_nvalues(const IntArrayRef& sizes, int64_t sparse_dim) {
  const auto dense_begin = sizes.begin() + sparse_dim;
  const auto dense_end = sizes.end();
  return c10::multiply_integers(dense_begin, dense_end);
}

template <typename scalar_t, bool LogSoftMax>
__global__ void cuda_sparse_coo_softmax_kernel(
int64_t* sorted_pool_indices,
int64_t size,
int64_t pool_size,
int64_t* pool_sizes,
int64_t* pool_offsets,
int64_t nvalues,
Expand All @@ -103,7 +111,7 @@ __global__ void cuda_sparse_coo_softmax_kernel(
int index = tid + blkid * blksz;
int step = blksz * gridsz;

while (index < size) {
while (index < pool_size) {
int64_t offset = pool_offsets[index];
int64_t* pool_indices = sorted_pool_indices + offset;
int64_t pool_indices_size = pool_sizes[index];
Expand Down Expand Up @@ -408,7 +416,7 @@ void cuda_sparse_coo_softmax(

auto nnz = values.size(0);
auto sizes = input.sizes();
auto nvalues = values.numel() / nnz;
auto nvalues = get_nvalues(sizes, sparse_dim);
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is what the CPU kernel does.


/* Prepare accessors */
auto values_2 = values.view({nnz, nvalues});
Expand All @@ -429,17 +437,23 @@ void cuda_sparse_coo_softmax(
int block_size = getNumThreads(pool_size);
const int grid_size = (pool_size + block_size - 1) / block_size;

cuda_sparse_coo_softmax_kernel<scalar_t, LogSoftMax>
<<<grid_size, block_size, 0, stream>>>(
sorted_indices.data_ptr<int64_t>(),
pool_size,
pool_sizes.data_ptr<int64_t>(),
pool_offsets.data_ptr<int64_t>(),
nvalues,
mx_buffer.data_ptr<scalar_t>(),
values_accessor,
out_values_accessor);
C10_CUDA_KERNEL_LAUNCH_CHECK();
// If either nvalues or pool_size is zero, then cuda_sparse_coo_softmax_kernel
// won't actually perform any computation. Further, a zero value yields an
// invalid configuration parameter for the launch. So let's not
// launch a kernel unless both are non-zero.
if (nvalues > 0 && pool_size > 0) {
cuda_sparse_coo_softmax_kernel<scalar_t, LogSoftMax>
<<<grid_size, block_size, 0, stream>>>(
sorted_indices.data_ptr<int64_t>(),
pool_size,
pool_sizes.data_ptr<int64_t>(),
pool_offsets.data_ptr<int64_t>(),
nvalues,
mx_buffer.data_ptr<scalar_t>(),
values_accessor,
out_values_accessor);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
}

template <typename scalar_t, bool LogSoftMax>
Expand Down
7 changes: 7 additions & 0 deletions test/test_sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -3383,6 +3383,13 @@ def sparse_log(x):
test_op(3, 100, [3, 4, 2, 3, 5, 2], coalesced)
test_op(4, 100, [3, 4, 2, 3, 5, 2], coalesced)


# Exercises torch.sparse.softmax on a sparse COO tensor built from empty
# indices and empty values with shape (3,), i.e. no specified elements.
# The densified output is asserted equal to torch.zeros_like of the input.
@dtypes(torch.double)
def test_softmax_zero_nnz(self, device, dtype):
t = torch.sparse_coo_tensor([[]], [], (3,), device=device, dtype=dtype)
out = torch.sparse.softmax(t, 0)
self.assertEqual(out.to_dense(), torch.zeros_like(t))

# TODO: Investigate why ROCm's cusparseXcsrgemm2Nnz function doesn't return the same nnz value as CUDA
@skipIfRocm
@coalescedonoff
Expand Down