Sparse CSR CUDA: add addmv_out
#61407
Changes from 30 commits
@@ -39,7 +39,13 @@ class CuSparseDescriptor {
class TORCH_CUDA_CPP_API CuSparseDnMatDescriptor
    : public CuSparseDescriptor<cusparseDnMatDescr, &cusparseDestroyDnMat> {
 public:
-  CuSparseDnMatDescriptor(const Tensor& input);
+  explicit CuSparseDnMatDescriptor(const Tensor& input);
};

Reviewer: nice!

class TORCH_CUDA_CPP_API CuSparseDnVecDescriptor
    : public CuSparseDescriptor<cusparseDnVecDescr, &cusparseDestroyDnVec> {
 public:
  explicit CuSparseDnVecDescriptor(const Tensor& input);
};

class TORCH_CUDA_CPP_API CuSparseSpMatDescriptor

@@ -48,7 +54,7 @@ class TORCH_CUDA_CPP_API CuSparseSpMatCsrDescriptor
class TORCH_CUDA_CPP_API CuSparseSpMatCsrDescriptor
    : public CuSparseSpMatDescriptor {
 public:
-  CuSparseSpMatCsrDescriptor(const Tensor& input);
+  explicit CuSparseSpMatCsrDescriptor(const Tensor& input);
};

} // namespace sparse
@@ -17,6 +17,10 @@ TORCH_META_FUNC(addmv)(const Tensor &self, const Tensor &mat, const Tensor &vec,
      "size mismatch, got ", self.size(0), ", ", mat.size(0), "x", mat.size(1), ",", vec.size(0));
  auto names = at::namedinference::propagate_names_for_addmv(mat, vec, self);
  set_output(0, IntArrayRef(mat.sizes().data(), 1), {}, mat.options(), names);
+  auto result = maybe_get_output(0);
+  //this check can fire for inplace op only, for all other versions result is guaranteed to be correct size
+  TORCH_CHECK(result.dim() == 1 && result.sizes()[0] == mat.sizes()[0], "output of addmv operation should be 1D with ",
+      "size equal to mat.size(0), yet got output size ", result.sizes(), " and mat.size(0) ", mat.size(0));
  }
}

Reviewer: these lines were removed in #65686, is there a conflict?
Author: Yes, I resolved the conflict incorrectly. Will fix that.
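The added check targets the in-place variant: a size-1 self passes the broadcast-friendly size check above, but as an in-place output it cannot hold a length-mat.size(0) result. A sketch of the case this guards, assuming standard addmv broadcasting rules (hypothetical shapes):

import torch

mat = torch.randn(3, 2)
vec = torch.randn(2)

# functional op: a size-1 `self` broadcasts to size 3
print(torch.addmv(torch.zeros(1), mat, vec).shape)  # torch.Size([3])

# in-place op: `self` is also the output, so the added TORCH_CHECK should fire
t = torch.zeros(1)
t.addmv_(mat, vec)  # RuntimeError: output of addmv operation should be 1D ...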
@@ -97,14 +101,14 @@ Tensor &mv_out(const Tensor &self, const Tensor &vec, Tensor& result) {
  //in addmv, because addmv expects self to satisfy proper conditions
  //to avoid this, supply correctly sized self, its contents doesn't matter because beta is 0
  if (result.dim() > 1 || (result.numel() != self.size(0) || result.numel() !=1)) {
-    Tensor self_addmv = at::empty({self.size(0)}, self.options());
+    Tensor self_addmv = at::empty({self.size(0)}, vec.options());
    return at::addmv_out(result, self_addmv, self, vec, 0, 1);
  }
  return at::addmv_out(result, result, self, vec, 0, 1);
}

Tensor mv(const Tensor &self, const Tensor &vec) {
-  Tensor result = at::empty({self.size(0)}, self.options());
+  Tensor result = at::empty({self.size(0)}, vec.options());
  //inplace version is more efficient if we can use it
  return at::addmv_(result, self, vec, 0, 1);
}
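In mv, self is the matrix, so for a sparse CSR input self.options() would allocate a sparse result tensor; allocating with vec.options() gives the dense output a matrix-vector product should produce. A sketch of the user-visible behavior (hypothetical values):

import torch

# 2x2 CSR matrix [[1., 2.], [3., 4.]]
crow = torch.tensor([0, 2, 4])
col = torch.tensor([0, 1, 0, 1])
csr = torch.sparse_csr_tensor(crow, col, torch.tensor([1., 2., 3., 4.]))

vec = torch.tensor([1., 1.])
res = csr.matmul(vec)  # result allocated with vec's options
print(res)             # tensor([3., 7.])
print(res.layout)      # torch.strided, i.e. dense like vec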
@@ -0,0 +1,61 @@
#include <ATen/ATen.h>
#include <ATen/Dispatch.h>
#include <ATen/cuda/CUDASparse.h>
#include <ATen/native/Resize.h>
#include <ATen/native/sparse/cuda/SparseBlasImpl.h>

#include <c10/util/MaybeOwned.h>

namespace at {

Reviewer: can you also move …
Author: Moved

namespace native {

Tensor& addmv_out_sparse_csr_cuda(const Tensor &self, const Tensor &mat, const Tensor &vec, const Scalar& beta, const Scalar& alpha, Tensor& result) {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(mat.is_sparse_csr());

  TORCH_CHECK(mat.dim() == 2, "addmv: Expected mat to be 2-D");
  TORCH_CHECK(vec.dim() == 1, "addmv: Expected vec to be 1-D");

  TensorArg args[]{{result, "out", 0}, {self, "self", 1}, {mat, "mat", 2}, {vec, "vec", 3}};
Reviewer: Do we still need TensorArgs? It's a perf penalty for very small convenience of using …
Author: Alright, we don't need this check here at all because it's already there in the generated code. Example of generated code for addmv:

    at::Tensor & wrapper_out_addmv_out_out(const at::Tensor & self, const at::Tensor & mat, const at::Tensor & vec, const at::Scalar & beta, const at::Scalar & alpha, at::Tensor & out) {
      c10::optional<Device> common_device = nullopt;
      (void)common_device; // Suppress unused variable warning
      c10::impl::check_and_update_common_device(common_device, out, "wrapper_out_addmv_out_out", "out");
      c10::impl::check_and_update_common_device(common_device, self, "wrapper_out_addmv_out_out", "self");
      c10::impl::check_and_update_common_device(common_device, mat, "wrapper_out_addmv_out_out", "mat");
      c10::impl::check_and_update_common_device(common_device, vec, "wrapper_out_addmv_out_out", "vec");
      const OptionalDeviceGuard device_guard(device_of(self));
      return at::native::addmv_out_sparse_csr_cuda(self, mat, vec, beta, alpha, out);
    }

I was thinking that device checks and guards are not generated for sparse because of #59058, but the checks are not generated only for SparseCPU + dense CUDA. (lines 853 to 860 in 5f15186)

Reviewer: (references pytorch/aten/src/ATen/native/cuda/Blas.cpp, lines 99 to 100 in 5f15186)
  checkAllSameGPU(__func__, args);

  c10::MaybeOwned<Tensor> self_ = expand_size(self, {mat.size(0)});
  auto betaval = beta.toComplexDouble();

  if (&result != &self) {
    at::native::resize_output(result, self_->sizes());
    if (betaval != 0.0) {
      at::native::copy_(result, *self_);
    }
  }

  if (mat._nnz() == 0) {
    // shortcut for an empty matrix
    // By definition, when beta==0, values in self should be ignored. nans and infs
    // should not propagate
    if (betaval == 0.0) {
      return result.zero_();
    } else {
      return at::mul_out(
          const_cast<Tensor&>(result),
          self,
          at::native::scalar_tensor(
              beta, self.scalar_type(), c10::nullopt /* layout */, at::kCPU, c10::nullopt /* pin_memory */));
    }
  }

  // cuda 11.3 version computes garbage for float16 inputs
  // couldn't check bfloat16 because it requires Ampere GPU but I assume the problem is same
  // addmm works fine
Reviewer: Sorry to hear that. Is this still an ongoing problem?
Author: I'll prepare a standalone C file and test it in CUDA 11.4.
Author: Yes, #61407 (comment).
  if (vec.scalar_type() == kHalf || vec.scalar_type() == kBFloat16) {
    result.unsqueeze_(-1);
    sparse::impl::cuda::addmm_out_sparse_csr(mat, vec.unsqueeze(-1), beta, alpha, result);
    result.squeeze_(-1);
    return result;
  }

  sparse::impl::cuda::addmv_out_sparse_csr(mat, vec, beta, alpha, result);
  return result;
}

} // namespace native
} // namespace at
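Two contracts above are worth illustrating with plain dense ops, since the sparse kernel mirrors them: with beta == 0 the contents of self are ignored (NaN/Inf must not propagate), and the half/bfloat16 fallback rewrites the matrix-vector product as a matrix-matrix product against an n x 1 matrix, which is mathematically equivalent. A sketch with hypothetical values:

import torch

# beta == 0: values in `self` are ignored; NaN/Inf do not propagate
self_t = torch.tensor([float("nan"), float("inf")])
mat = torch.zeros(2, 3)
vec = torch.ones(3)
print(torch.addmv(self_t, mat, vec, beta=0, alpha=1))  # tensor([0., 0.])

# the fp16/bf16 fallback: mv expressed as mm against an n x 1 matrix
A = torch.randn(4, 3)
x = torch.randn(3)
mv = A.mv(x)
mm = A.matmul(x.unsqueeze(-1)).squeeze(-1)
print(torch.allclose(mv, mm))  # True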
@@ -55,6 +55,16 @@ void addmm_out_legacy(
  at::native::s_addmm_out_csr_sparse_dense_cuda_worker(nnz, m, n, k, result, beta, result, alpha, crow_indices, col_indices, values, mat2);
}

c10::MaybeOwned<Tensor> inline prepare_dense_vector_for_cusparse(
    const Tensor& tensor) {
  if (tensor.is_non_overlapping_and_dense()) {
    return c10::MaybeOwned<Tensor>::borrowed(tensor);
  } else {
    return c10::MaybeOwned<Tensor>::owned(
        tensor.clone(at::MemoryFormat::Contiguous));
  }
}
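The helper borrows the tensor when it already has dense, non-overlapping storage (what cuSPARSE descriptors expect) and clones to contiguous memory otherwise, so no copy is paid on the fast path. The same borrow-or-copy distinction is observable from Python through Tensor.contiguous(); a sketch of the idea, not the actual code path:

import torch

x = torch.arange(10.0)[::2]          # stride-2 view: not contiguous
print(x.is_contiguous())             # False
y = x.contiguous()                   # copies, like the owned/cloned branch
print(y.data_ptr() != x.data_ptr())  # True

z = torch.arange(5.0)                # already dense and non-overlapping
print(z.contiguous().data_ptr() == z.data_ptr())  # True: no copy, like the borrowed branch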

} // anonymous namespace

void addmm_out_sparse_csr(
@@ -166,6 +176,92 @@ void addmm_out_sparse_csr(
#endif
}

/*
  Computes a sparse matrix-dense vector product defined as
  y <- alpha*op(A)*x + beta*y

  Args:
  * `mat` - Tensor storing sparse m x n matrix A.
  * `vec` - Tensor storing dense vector x of size n.
  * `result` - [in] Tensor storing dense vector y of size m.
               [out] result of the operation.
*/
void addmv_out_sparse_csr(
    const at::sparse_csr::SparseCsrTensor& mat,
    const Tensor& vec,
    const Scalar& beta,
    const Scalar& alpha,
    const Tensor& result) {
#if !AT_USE_CUSPARSE_GENERIC_API()
  TORCH_CHECK(
      false,
      "Calling addmv on a sparse GPU tensor requires compiling ",
      "PyTorch with CUDA 10.2+ (CUDA 11+ on Windows). ",
      "Please use PyTorch built with newer CUDA version.");
#else
  cusparseOperation_t opA = CUSPARSE_OPERATION_NON_TRANSPOSE;

  c10::MaybeOwned<Tensor> result_ = prepare_dense_vector_for_cusparse(result);
  c10::MaybeOwned<Tensor> vec_ = prepare_dense_vector_for_cusparse(vec);

  // TODO: update this to support COO sparse layout
  auto descA = at::cuda::sparse::CuSparseSpMatCsrDescriptor(mat);
  auto descX = at::cuda::sparse::CuSparseDnVecDescriptor(*vec_);
  auto descY = at::cuda::sparse::CuSparseDnVecDescriptor(*result_);

  // cusparseSpMVAlg_t was updated in cuda 11.2.1 (cusparse 11.4.0)
#if CUSPARSE_VERSION >= 11400
  cusparseSpMVAlg_t alg = CUSPARSE_SPMV_ALG_DEFAULT;
#else
  cusparseSpMVAlg_t alg = CUSPARSE_MV_ALG_DEFAULT;
#endif

  // There is no dispatch for kHalf and kBFloat16 types because cusparse
  // computes garbage in this case, latest checked version of cuda is 11.3
  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
Reviewer: Is this still an ongoing problem?
      result.scalar_type(),
      "addmv_out_sparse_csr_cuda_impl",
      [&] {
        auto beta_ = beta.to<scalar_t>();
        auto alpha_ = alpha.to<scalar_t>();
        cudaDataType compute_type = at::cuda::getCudaDataType<scalar_t>();
        auto handle = at::cuda::getCurrentCUDASparseHandle();

        size_t buffer_size;
        TORCH_CUDASPARSE_CHECK(cusparseSpMV_bufferSize(
            handle,
            opA,
            &alpha_,
            descA.descriptor(),
            descX.descriptor(),
            &beta_,
            descY.descriptor(),
            compute_type,
            alg,
            &buffer_size // output
        ));

        auto& allocator = *c10::cuda::CUDACachingAllocator::get();
        auto work_data = allocator.allocate(buffer_size);

        TORCH_CUDASPARSE_CHECK(cusparseSpMV(
            handle,
            opA,
            &alpha_,
            descA.descriptor(),
            descX.descriptor(),
            &beta_,
            descY.descriptor(),
            compute_type,
            alg,
            work_data.get()));
      });
  if (!result.is_same(*result_)) {
    result.copy_(*result_);
  }
#endif
}

} // namespace cuda
} // namespace impl
} // namespace sparse
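Taken together, this is the kernel that backs torch.addmv when mat is a sparse CSR tensor on CUDA. A usage sketch, assuming a CUDA build with the cuSPARSE generic API available:

import torch

A = torch.randn(4, 3, device="cuda").to_sparse_csr()
x = torch.randn(3, device="cuda")
y = torch.randn(4, device="cuda")

out = torch.addmv(y, A, x, beta=0.5, alpha=2.0)  # y <- 2*(A @ x) + 0.5*y
ref = torch.addmv(y, A.to_dense(), x, beta=0.5, alpha=2.0)
print(torch.allclose(out, ref, rtol=1e-4, atol=1e-5))  # True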
@@ -8,7 +8,7 @@
from torch.testing._internal.common_utils import \
    (IS_MACOS, IS_WINDOWS, TestCase, run_tests, load_tests, coalescedonoff)
from torch.testing._internal.common_device_type import \
-    (instantiate_device_type_tests, dtypes, dtypesIfCUDA, onlyCPU, onlyCUDA,
+    (instantiate_device_type_tests, dtypes, dtypesIfCUDA, onlyCPU, onlyCUDA, skipCUDAIfNoCusparseGeneric,
     precisionOverride)
from torch.testing._internal.common_dtype import floating_types, get_all_dtypes
@@ -402,7 +402,11 @@ def test_matmul_device_mismatch(self, device, dtype):
        with self.assertRaisesRegex(RuntimeError, "Expected all tensors to be on the same device"):
            torch.addmm(s, csr, m2)

-    @dtypes(torch.float, torch.double)
+    @skipCUDAIfNoCusparseGeneric
+    @dtypes(*torch.testing.floating_types())
+    @dtypesIfCUDA(*get_all_complex_dtypes(),
+                  *get_all_fp_dtypes(include_half=SM53OrLater, include_bfloat16=SM80OrLater))

Reviewer: is SM80OrLater correct guard for bfloat16? For regular addmm bfloat16 is supported (with perf equivalent to fp32) for earlier architectures.
Author: Unfortunately, cuSPARSE raises …
+    @precisionOverride({torch.bfloat16: 1e-2, torch.float16: 1e-2})

Reviewer: hm, 1e-2 seems high for float16?
Author: It is. It seems that cuSPARSE uses a different accumulation strategy, or something else is different, leading to less accurate results than cuBLAS computes.
Author: Unfortunately, 1e-2 is required for tests to pass for float16. Interestingly, running the specific test … Tested on CUDA 11.4.2 and a Turing card.
Author: Using SpMV instead of SpMM makes the test pass without precision overrides for float16, so I'll remove it here.
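For context on the tolerance discussion, the fp16 error of a length-100 reduction can be estimated against a float64 reference. A sketch assuming a CUDA device; exact magnitudes vary by GPU, kernel, and library version:

import torch

torch.manual_seed(0)
A = torch.randn(100, 100, device="cuda")
x = torch.randn(100, device="cuda")

ref = A.double() @ x.double()
half = (A.half() @ x.half()).double()
print((half - ref).abs().max().item())  # compare against the 1e-2 override above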
    def test_csr_matvec(self, device, dtype):
        side = 100
        for index_dtype in [torch.int32, torch.int64]:
@@ -415,7 +419,12 @@ def test_csr_matvec(self, device, dtype):
            self.assertEqual(res, expected)

            bad_vec = torch.randn(side + 10, dtype=dtype, device=device)
-            with self.assertRaisesRegex(RuntimeError, "mv: expected"):
+            err_msg = "mv: expected"
+            # CUDA path now uses generic meta/structured implementation
+            # TODO: move CPU path to not use `mv_sparse` function
+            if self.device_type == 'cuda':
+                err_msg = "size mismatch, got"
+            with self.assertRaisesRegex(RuntimeError, err_msg):
                csr.matmul(bad_vec)

    @dtypes(torch.double)
Reviewer: this is guaranteed to be true if is_contiguous is true?
Author: That's right, I'll remove it. Good it's only for debugging 🙂 (pytorch/c10/core/TensorImpl.h, lines 2299 to 2300 in 5f15186)