CUDA support in the CSR layout: sparse_to_dense/add_sparse_csr #59011

Closed · wants to merge 2 commits
aten/src/ATen/native/native_functions.yaml (9 changes: 5 additions & 4 deletions)
@@ -352,7 +352,7 @@
variants: function, method
dispatch:
SparseCPU, SparseCUDA: add_sparse
-    SparseCsrCPU: add_sparse_csr
+    SparseCsrCPU, SparseCsrCUDA: add_sparse_csr
MkldnnCPU: mkldnn_add

- func: add_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)
@@ -361,7 +361,7 @@
structured_delegate: add.out
dispatch:
SparseCPU, SparseCUDA: add_sparse_
-    SparseCsrCPU: add_sparse_csr_
+    SparseCsrCPU, SparseCsrCUDA: add_sparse_csr_
MkldnnCPU: mkldnn_add_

- func: add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
@@ -373,6 +373,7 @@
SparseCPU: add_out_sparse_cpu
SparseCUDA: add_out_sparse_cuda
SparseCsrCPU: add_out_sparse_csr_cpu
+    SparseCsrCUDA: add_out_sparse_csr_cuda
MkldnnCPU: mkldnn_add_out

- func: _add_relu.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
@@ -4581,7 +4582,7 @@
variants: function
dispatch:
SparseCPU, SparseCUDA: resize_as_sparse_
-    SparseCsrCPU: resize_as_sparse_csr_
+    SparseCsrCPU, SparseCsrCUDA: resize_as_sparse_csr_

- func: zero_(Tensor(a!) self) -> Tensor(a!)
device_check: NoCheck # TensorIterator
@@ -4866,7 +4867,7 @@
- func: to_dense(Tensor self, ScalarType? dtype=None) -> Tensor
variants: method
dispatch:
-    SparseCPU, SparseCUDA, SparseCsrCPU: sparse_to_dense
+    SparseCPU, SparseCUDA, SparseCsrCPU, SparseCsrCUDA: sparse_to_dense
MkldnnCPU: mkldnn_to_dense

- func: to_dense_backward(Tensor grad, Tensor input) -> Tensor
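Taken together, these dispatch changes route add.Tensor, add_.Tensor, add.out, resize_as_sparse_, and to_dense to the CSR kernels for CUDA tensors as well as CPU ones. A minimal Python sketch of what that enables, assuming a CUDA build containing this patch and that CSR tensors can already be constructed on a CUDA device (shapes and the alpha value are illustrative):

import torch

# Build a small CSR tensor directly on the GPU.
crow_indices = torch.tensor([0, 2, 4], dtype=torch.int64, device="cuda")
col_indices = torch.tensor([0, 1, 0, 1], dtype=torch.int64, device="cuda")
values = torch.tensor([1., 2., 3., 4.], device="cuda")
csr = torch.sparse_csr_tensor(crow_indices, col_indices, values, (2, 2), device="cuda")

dense = torch.randn(2, 2, device="cuda")

# add.out now has a SparseCsrCUDA entry (add_out_sparse_csr_cuda).
res = torch.add(dense, csr, alpha=0.5)

# to_dense now dispatches to sparse_to_dense for SparseCsrCUDA as well.
back = csr.to_dense()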
aten/src/ATen/native/sparse/SparseCsrTensorMath.cpp (8 changes: 4 additions & 4 deletions)
@@ -288,12 +288,12 @@ Tensor& add_out_dense_sparse_csr_cpu(
auto out_strides0 = resultBuffer.strides()[0];
auto out_strides1 = resultBuffer.strides()[1];

-  for (int32_t irow = 0; irow < src_crow_indices.size(0) - 1;
+  for (index_t irow = 0; irow < src_crow_indices.size(0) - 1;
       ++irow) {
-    int32_t start_index = crow_indices_accessor[irow];
-    int32_t end_index = crow_indices_accessor[irow + 1];
+    index_t start_index = crow_indices_accessor[irow];
+    index_t end_index = crow_indices_accessor[irow + 1];

-    for (int i = start_index; i < end_index; ++i) {
+    for (index_t i = start_index; i < end_index; ++i) {
auto icol = col_indices_accessor[i];
auto index = resultBuffer.storage_offset() + irow * out_strides0 +
icol * out_strides1;
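The change above only widens the loop variables: the row and column indices are read through accessors whose type is selected by AT_DISPATCH_INDEX_TYPES, so holding them in hard-coded int32_t risked truncating int64 crow/col indices. With index_t, both supported index dtypes take the same path. A rough Python-level illustration (a sketch, not part of this PR's tests; shapes and values are made up):

import torch

dense = torch.zeros(2, 2)
values = torch.tensor([10., 20.])

# CSR indices may be int32 or int64; the CPU loop now uses the dispatched
# index_t for both instead of assuming 32-bit indices.
for index_dtype in (torch.int32, torch.int64):
    crow = torch.tensor([0, 1, 2], dtype=index_dtype)
    col = torch.tensor([1, 0], dtype=index_dtype)
    csr = torch.sparse_csr_tensor(crow, col, values, (2, 2))
    out = torch.add(dense, csr)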
aten/src/ATen/native/sparse/cuda/SparseCsrTensorMath.cu (160 changes: 160 additions & 0 deletions)
@@ -0,0 +1,160 @@
#include <ATen/ATen.h>
#include <ATen/ExpandUtils.h>
#include <ATen/InitialTensorOptions.h>
#include <ATen/NativeFunctions.h>
#include <ATen/SparseCsrTensorImpl.h>
#include <ATen/SparseCsrTensorUtils.h>
#include <ATen/SparseTensorUtils.h>
#include <ATen/WrapDimUtilsMulti.h>
#include <ATen/native/BinaryOps.h>
#include <ATen/native/Resize.h>
#include <algorithm>

#include <cuda_runtime.h>
#include <type_traits>

#include <THC/THCTensorMathPointwise.cuh>
#include <THC/THCThrustAllocator.cuh>

#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAUtils.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <ATen/native/sparse/cuda/SparseCUDABlas.cuh>

#include <thrust/device_ptr.h>
#include <thrust/execution_policy.h>
#include <thrust/for_each.h>
#include <thrust/sequence.h>

namespace at {
namespace native {

using namespace at::sparse_csr;
// certain utility functions are usable from sparse COO.
using namespace at::sparse;

Tensor& add_out_dense_sparse_csr_cuda(
Tensor& output,
const Tensor& dense,
const SparseCsrTensor& src,
const Scalar& alpha) {
TORCH_INTERNAL_ASSERT(dense.layout() == kStrided);
TORCH_INTERNAL_ASSERT(src.is_sparse_csr());
TORCH_INTERNAL_ASSERT(dense.is_cuda());

TORCH_CHECK(
output.is_contiguous(),
"out argument must be contiguous, but got: ",
output.suggest_memory_format());
TORCH_CHECK(
output.is_cuda(),
"add: expected 'out' to be CUDA tensor, but got tensor on device: ",
output.device());

TORCH_CHECK(
src.is_cuda(),
"add: expected 'other' to be a CUDA tensor, but got tensor on device: ",
src.device());

TORCH_CHECK(
dense.sizes().equals(src.sizes()),
"add: expected 'self' and 'other' to have same size, but self has size ",
dense.sizes(),
" while other has size ",
src.sizes(),
" (FYI: op2-sparse addition does not currently support broadcasting)");

auto commonDtype = promoteTypes(dense.scalar_type(), src.scalar_type());
TORCH_CHECK(
canCast(commonDtype, output.scalar_type()),
"Can't convert result type ",
commonDtype,
" to output ",
output.scalar_type(),
" in add operation");

Tensor src_values = src.values();
Tensor src_crow_indices = src.crow_indices();
Tensor src_col_indices = src.col_indices();

resize_output(output, dense.sizes());

Tensor resultBuffer = output;
Tensor valuesBuffer = src_values.to(commonDtype);
if (output.scalar_type() != commonDtype) {
resultBuffer = dense.to(commonDtype);
} else if (!is_same_tensor(output, dense)) {
resultBuffer.copy_(dense);
}
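// Dispatch first on the promoted value dtype (scalar_t), then on the index
// dtype of crow_indices (int32 or int64), so the device lambda below sees
// concrete scalar_t / index_t types.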
AT_DISPATCH_ALL_TYPES(
commonDtype,
"add_out_op2_sparse_csr",
[&valuesBuffer, &resultBuffer, &alpha, &src_crow_indices, &src_col_indices]() {
AT_DISPATCH_INDEX_TYPES(
src_crow_indices.scalar_type(),
"csr_add_out_crow_indices",
[&valuesBuffer, &resultBuffer, &alpha, &src_crow_indices, &src_col_indices]() {
scalar_t* values_accessor = valuesBuffer.data_ptr<scalar_t>();
scalar_t* out_ptr = resultBuffer.data_ptr<scalar_t>();
scalar_t cast_value = alpha.to<scalar_t>();

index_t* crow_indices_accessor = src_crow_indices.data_ptr<index_t>();
index_t* col_indices_accessor = src_col_indices.data_ptr<index_t>();
int64_t out_storage_offset = resultBuffer.storage_offset();

auto out_strides = resultBuffer.strides();
int64_t out_strides0 = out_strides[0];
int64_t out_strides1 = out_strides[1];

cudaStream_t stream = at::cuda::getCurrentCUDAStream();
auto allocator = THCThrustAllocator(globalContext().lazyInitCUDA());
auto policy = thrust::cuda::par(allocator).on(stream);

// Note that this could be wildly imbalanced if the sparsity pattern varies a lot between rows.
thrust::for_each(
policy,
thrust::make_counting_iterator(int64_t(0)),
thrust::make_counting_iterator(int64_t(src_crow_indices.size(0) - 1)),
[values_accessor,
crow_indices_accessor,
col_indices_accessor,
out_ptr,
out_storage_offset,
out_strides0,
cast_value,
out_strides1
]__device__(int64_t irow) {
index_t start_index = crow_indices_accessor[irow];
index_t end_index = crow_indices_accessor[irow + 1];

for (index_t i = start_index; i < end_index; ++i) {
auto icol = col_indices_accessor[i];
auto index = out_storage_offset + irow * out_strides0 + icol * out_strides1;
out_ptr[index] += cast_value * values_accessor[i];
}
});
});
});
if (output.scalar_type() != commonDtype) {
output.copy_(resultBuffer);
}
return output;
}

Tensor& add_out_sparse_csr_cuda(
const Tensor& self,
const SparseCsrTensor& other,
const Scalar& alpha,
SparseCsrTensor& out) {
if (self.layout() == kStrided) {
return add_out_dense_sparse_csr_cuda(out, self, other, alpha);
} else {
TORCH_CHECK(
false,
"NotImplementedError: Addition of sparse CSR tensors is not yet implemented.")
}
return out;
}

} // namespace native
} // namespace at
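The CUDA path mirrors the CPU implementation: values are promoted to the common dtype, the dense input is copied into a result buffer in that dtype, and a thrust::for_each over row indices lets each work item scatter its alpha-scaled row values into the strided output. A host-side Python sketch of the semantics each row iteration implements (for reference only, not the kernel itself; the function name is made up):

import torch

def add_out_dense_sparse_csr_reference(dense, csr, alpha):
    # Host-side reference for what add_out_dense_sparse_csr_cuda computes.
    out = dense.clone()
    crow = csr.crow_indices()
    col = csr.col_indices()
    val = csr.values()
    for irow in range(crow.numel() - 1):  # one device work item per CSR row
        start, end = int(crow[irow]), int(crow[irow + 1])
        for i in range(start, end):
            out[irow, int(col[i])] += alpha * val[i]
    return out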
test/test_sparse_csr.py (31 changes: 29 additions & 2 deletions)
@@ -278,7 +278,6 @@ def test_sparse_csr_from_dense(self, device):
self.assertEqual(torch.tensor([0, 1, 2] * 3, dtype=torch.int64), sparse.col_indices())
self.assertEqual(torch.tensor([2] * 9), sparse.values())

-    @onlyCPU
@dtypes(torch.double)
def test_dense_convert(self, device, dtype):
size = (5, 5)
@@ -400,7 +399,35 @@ def test_shape(di, dj, dk, nnz):
for k in range(2, 8):
test_shape(i, j, k, i * j // 2)

-    @onlyCPU
@dtypes(torch.float, torch.double)
def test_add(self, device, dtype):
def _test_spadd_shape(nnz, shape):
x = self.genSparseCSRTensor(shape, nnz, dtype=dtype, device=device, index_dtype=torch.int32)
y = torch.randn(*shape, dtype=dtype, device=device)
r = random.random()

res = torch.add(y, x, alpha=r)
expected = y + r * x.to_dense()
self.assertEqual(res, expected)

# Non contiguous dense tensor
s = list(shape)
s[0] = shape[-1]
s[-1] = shape[0]
y = torch.randn(*s, dtype=torch.double, device=device)
y.transpose_(0, len(s) - 1)
r = random.random()

res = torch.add(y, x, alpha=r)
expected = y + r * x.to_dense()

self.assertEqual(res, expected)

_test_spadd_shape(10, [100, 100])
_test_spadd_shape(0, [100, 100])
_test_spadd_shape(10, [100, 1])
_test_spadd_shape(10, [1, 100])

@dtypes(*torch.testing.floating_types())
def test_coo_csr_conversion(self, device, dtype):
size = (5, 5)