
Sparse CSR CUDA: add addmv_out #61407

Closed · wants to merge 32 commits
Changes from 30 commits

Commits (32)
b5c6ded
Sparse CSR CUDA: add `addmv_out`
IvanYashchuk Jul 8, 2021
f3b3cbe
Update on "Sparse CSR CUDA: add `addmv_out`"
IvanYashchuk Jul 9, 2021
d205cca
Update on "Sparse CSR CUDA: add `addmv_out`"
IvanYashchuk Jul 9, 2021
e6f7588
Update on "Sparse CSR CUDA: add `addmv_out`"
IvanYashchuk Jul 9, 2021
4a4ec64
Update on "Sparse CSR CUDA: add `addmv_out`"
IvanYashchuk Jul 9, 2021
33c92cc
Update on "Sparse CSR CUDA: add `addmv_out`"
IvanYashchuk Jul 12, 2021
0788c03
Update on "Sparse CSR CUDA: add `addmv_out`"
IvanYashchuk Jul 19, 2021
f47bab9
Update on "Sparse CSR CUDA: add `addmv_out`"
IvanYashchuk Jul 26, 2021
d09d1f3
Update on "Sparse CSR CUDA: add `addmv_out`"
IvanYashchuk Jul 26, 2021
f359add
Update on "Sparse CSR CUDA: add `addmv_out`"
IvanYashchuk Aug 18, 2021
82daaa5
Update on "Sparse CSR CUDA: add `addmv_out`"
IvanYashchuk Aug 19, 2021
3d8285f
Update on "Sparse CSR CUDA: add `addmv_out`"
IvanYashchuk Aug 24, 2021
6cc2b6b
Update on "Sparse CSR CUDA: add `addmv_out`"
IvanYashchuk Aug 24, 2021
9da55a2
Update on "Sparse CSR CUDA: add `addmv_out`"
IvanYashchuk Aug 25, 2021
d967d0e
Update on "Sparse CSR CUDA: add `addmv_out`"
IvanYashchuk Aug 25, 2021
8ae6d51
Update on "Sparse CSR CUDA: add `addmv_out`"
IvanYashchuk Aug 25, 2021
cae220f
Update on "Sparse CSR CUDA: add `addmv_out`"
IvanYashchuk Aug 25, 2021
8f6e268
Update on "Sparse CSR CUDA: add `addmv_out`"
IvanYashchuk Aug 25, 2021
e898771
Update on "Sparse CSR CUDA: add `addmv_out`"
IvanYashchuk Aug 26, 2021
ad26134
Update on "Sparse CSR CUDA: add `addmv_out`"
IvanYashchuk Aug 26, 2021
fa1d8e2
Update on "Sparse CSR CUDA: add `addmv_out`"
IvanYashchuk Aug 26, 2021
9025902
Update on "Sparse CSR CUDA: add `addmv_out`"
IvanYashchuk Aug 31, 2021
42f51cf
Update on "Sparse CSR CUDA: add `addmv_out`"
IvanYashchuk Sep 24, 2021
35dc3c6
Update on "Sparse CSR CUDA: add `addmv_out`"
IvanYashchuk Sep 24, 2021
c002264
Update on "Sparse CSR CUDA: add `addmv_out`"
IvanYashchuk Sep 24, 2021
2397b70
Update on "Sparse CSR CUDA: add `addmv_out`"
IvanYashchuk Sep 24, 2021
86b0f1b
Update on "Sparse CSR CUDA: add `addmv_out`"
IvanYashchuk Sep 25, 2021
39de443
Update on "Sparse CSR CUDA: add `addmv_out`"
IvanYashchuk Sep 27, 2021
dfc64cd
Update on "Sparse CSR CUDA: add `addmv_out`"
IvanYashchuk Sep 28, 2021
31c4794
Update on "Sparse CSR CUDA: add `addmv_out`"
IvanYashchuk Oct 11, 2021
de10bfe
Update on "Sparse CSR CUDA: add `addmv_out`"
IvanYashchuk Oct 12, 2021
e78419a
Update on "Sparse CSR CUDA: add `addmv_out`"
IvanYashchuk Oct 12, 2021
1 change: 1 addition & 0 deletions BUILD.bazel
@@ -377,6 +377,7 @@ filegroup(
"aten/src/ATen/native/miopen/Conv_miopen.cpp",
"aten/src/ATen/native/miopen/RNN_miopen.cpp",
"aten/src/ATen/native/sparse/cuda/SparseCUDATensor.cpp",
"aten/src/ATen/native/sparse/cuda/SparseBlas.cpp",
"aten/src/ATen/native/sparse/cuda/SparseBlasImpl.cpp",
"aten/src/THC/THCGeneral.cpp",
"aten/src/THC/THCStorageCopy.cpp",
17 changes: 17 additions & 0 deletions aten/src/ATen/cuda/CUDASparseDescriptors.cpp
@@ -91,6 +91,23 @@ CuSparseDnMatDescriptor::CuSparseDnMatDescriptor(const Tensor& input) {
descriptor_.reset(raw_descriptor);
}

CuSparseDnVecDescriptor::CuSparseDnVecDescriptor(const Tensor& input) {
// cuSPARSE doesn't support batched vectors
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(input.dim() == 1);

// cuSPARSE doesn't support non-contiguous vectors
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(input.is_contiguous());
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(input.is_non_overlapping_and_dense());
Collaborator:
this is guaranteed to be true if is_contiguous is true?

Collaborator Author:
That's right, I'll remove it. Good that it's only used for debugging 🙂

pytorch/c10/core/TensorImpl.h, lines 2299–2300 at 5f15186:

is_non_overlapping_and_dense_ =
    is_contiguous_ || compute_non_overlapping_and_dense();


cudaDataType value_type = ScalarTypeToCudaDataType(input.scalar_type());
check_supported_cuda_type(value_type);

cusparseDnVecDescr_t raw_descriptor;
TORCH_CUDASPARSE_CHECK(cusparseCreateDnVec(
&raw_descriptor, input.numel(), input.data_ptr(), value_type));
descriptor_.reset(raw_descriptor);
}

CuSparseSpMatCsrDescriptor::CuSparseSpMatCsrDescriptor(const Tensor& input) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(input.is_sparse_csr());
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(input.dim() == 2);
10 changes: 8 additions & 2 deletions aten/src/ATen/cuda/CUDASparseDescriptors.h
@@ -39,7 +39,13 @@ class CuSparseDescriptor {
class TORCH_CUDA_CPP_API CuSparseDnMatDescriptor
: public CuSparseDescriptor<cusparseDnMatDescr, &cusparseDestroyDnMat> {
public:
CuSparseDnMatDescriptor(const Tensor& input);
explicit CuSparseDnMatDescriptor(const Tensor& input);
Collaborator:
nice!

};

class TORCH_CUDA_CPP_API CuSparseDnVecDescriptor
: public CuSparseDescriptor<cusparseDnVecDescr, &cusparseDestroyDnVec> {
public:
explicit CuSparseDnVecDescriptor(const Tensor& input);
};

class TORCH_CUDA_CPP_API CuSparseSpMatDescriptor
@@ -48,7 +54,7 @@ class TORCH_CUDA_CPP_API CuSparseSpMatDescriptor
class TORCH_CUDA_CPP_API CuSparseSpMatCsrDescriptor
: public CuSparseSpMatDescriptor {
public:
CuSparseSpMatCsrDescriptor(const Tensor& input);
explicit CuSparseSpMatCsrDescriptor(const Tensor& input);
};

} // namespace sparse
8 changes: 6 additions & 2 deletions aten/src/ATen/native/Blas.cpp
@@ -17,6 +17,10 @@ TORCH_META_FUNC(addmv)(const Tensor &self, const Tensor &mat, const Tensor &vec,
"size mismatch, got ", self.size(0), ", ", mat.size(0), "x", mat.size(1), ",", vec.size(0));
auto names = at::namedinference::propagate_names_for_addmv(mat, vec, self);
set_output(0, IntArrayRef(mat.sizes().data(), 1), {}, mat.options(), names);
auto result = maybe_get_output(0);
Collaborator:
these lines were removed in #65686, is there a conflict?

Collaborator Author:
Yes, I resolved the conflict incorrectly. Will fix that.

//this check can fire for inplace op only, for all other versions result is guaranteed to be correct size
TORCH_CHECK(result.dim() == 1 && result.sizes()[0] == mat.sizes()[0], "output of addmv operation should be 1D with ",
"size equal to mat.size(0), yet got output size ", result.sizes(), " and mat.size(0) ", mat.size(0));
}
}

@@ -97,14 +101,14 @@ Tensor &mv_out(const Tensor &self, const Tensor &vec, Tensor& result) {
//in addmv, because addmv expects self to satisfy proper conditions
//to avoid this, supply correctly sized self, its contents doesn't matter because beta is 0
if (result.dim() > 1 || (result.numel() != self.size(0) || result.numel() !=1)) {
Tensor self_addmv = at::empty({self.size(0)}, self.options());
Tensor self_addmv = at::empty({self.size(0)}, vec.options());
return at::addmv_out(result, self_addmv, self, vec, 0, 1);
}
return at::addmv_out(result, result, self, vec, 0, 1);
}

Tensor mv(const Tensor &self, const Tensor &vec) {
Tensor result = at::empty({self.size(0)}, self.options());
Tensor result = at::empty({self.size(0)}, vec.options());
//inplace version is more efficient if we can use it
return at::addmv_(result, self, vec, 0, 1);
}
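As an aside on the mv change above: the result is now allocated with vec.options() instead of self.options() because self can be a sparse CSR matrix, while the output of a matrix-vector product must be a dense (strided) vector. A minimal Python sketch of the layouts involved (the tensors below are illustrative, not the PR's test code, and assume a CUDA build that includes this PR):

import torch

# Tiny 2x2 CSR matrix and a dense vector on the GPU.
crow_indices = torch.tensor([0, 1, 2], dtype=torch.int64, device="cuda")
col_indices = torch.tensor([0, 1], dtype=torch.int64, device="cuda")
values = torch.tensor([1.0, 2.0], device="cuda")
A = torch.sparse_csr_tensor(crow_indices, col_indices, values, (2, 2))
x = torch.ones(2, device="cuda")

print(A.layout)               # torch.sparse_csr
print(x.layout)               # torch.strided
print(torch.mv(A, x).layout)  # torch.strided -- allocated from the dense vec's options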
5 changes: 3 additions & 2 deletions aten/src/ATen/native/native_functions.yaml
@@ -454,6 +454,7 @@
dispatch:
CPU: addmv_out_cpu
CUDA: addmv_out_cuda
SparseCsrCUDA: addmv_out_sparse_csr_cuda

- func: addr(Tensor self, Tensor vec1, Tensor vec2, *, Scalar beta=1, Scalar alpha=1) -> Tensor
variants: function, method
@@ -3112,8 +3113,8 @@
- func: mv(Tensor self, Tensor vec) -> Tensor
variants: function, method
dispatch:
CPU, CUDA: mv
SparseCPU, SparseCUDA, SparseCsrCPU, SparseCsrCUDA: mv_sparse
CPU, CUDA, SparseCsrCUDA: mv
SparseCPU, SparseCUDA, SparseCsrCPU: mv_sparse

- func: mv.out(Tensor self, Tensor vec, *, Tensor(a!) out) -> Tensor(a!)
dispatch:
61 changes: 61 additions & 0 deletions aten/src/ATen/native/sparse/cuda/SparseBlas.cpp
@@ -0,0 +1,61 @@
#include <ATen/ATen.h>
#include <ATen/Dispatch.h>
#include <ATen/cuda/CUDASparse.h>
#include <ATen/native/Resize.h>
#include <ATen/native/sparse/cuda/SparseBlasImpl.h>

#include <c10/util/MaybeOwned.h>

namespace at {
Collaborator:
can you also move addmm_out_sparse_csr_dense_cuda from SparseCsrTensorMath.cu here? That would be a logical place.

Collaborator Author:
Moved addmm_out_sparse_csr_dense_cuda in a separate PR, #66485.

namespace native {

Tensor& addmv_out_sparse_csr_cuda(const Tensor &self, const Tensor &mat, const Tensor &vec, const Scalar& beta, const Scalar& alpha, Tensor& result) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(mat.is_sparse_csr());

TORCH_CHECK(mat.dim() == 2, "addmv: Expected mat to be 2-D");
TORCH_CHECK(vec.dim() == 1, "addmv: Expected vec to be 1-D");

TensorArg args[]{{result, "out", 0}, {self, "self", 1}, {mat, "mat", 2}, {vec, "vec", 3}};
Collaborator:
Do we still need TensorArgs? It's a perf penalty for very small convenience of using checkAllSameGPU. #62653 is landing soon that will enable these checks conveniently on the Tensors.
Also, out of curiosity, how do get_device and is_cuda in checkAllSameGPU work for a sparse mat?

Collaborator Author:
Alright, we don't need this check here at all because it's already present in the generated code.
SparseCsrCUDA is part of is_cuda_dispatch_key, and the device check is generated using

device_check = RegisterDispatchKey.gen_device_check(f.device_check, list(device_check_args), name)

Example of the generated code for addmv:

at::Tensor & wrapper_out_addmv_out_out(const at::Tensor & self, const at::Tensor & mat, const at::Tensor & vec, const at::Scalar & beta, const at::Scalar & alpha, at::Tensor & out) {
  c10::optional<Device> common_device = nullopt;
  (void)common_device; // Suppress unused variable warning

  c10::impl::check_and_update_common_device(common_device, out, "wrapper_out_addmv_out_out", "out");
  c10::impl::check_and_update_common_device(common_device, self, "wrapper_out_addmv_out_out", "self");
  c10::impl::check_and_update_common_device(common_device, mat, "wrapper_out_addmv_out_out", "mat");
  c10::impl::check_and_update_common_device(common_device, vec, "wrapper_out_addmv_out_out", "vec");

  const OptionalDeviceGuard device_guard(device_of(self));
  return at::native::addmv_out_sparse_csr_cuda(self, mat, vec, beta, alpha, out);
}

I was thinking that device checks and guards are not generated for sparse because of #59058, but the checks are omitted only for SparseCPU + dense CUDA.

is_cuda is TensorImpl's method; it is not virtual and simply checks whether key_set_ contains DispatchKey::SparseCsrCUDA (among the other CUDA keys):

bool is_cuda() const {
  // NB: This method is not virtual and avoid dispatches for performance
  // reasons.
  return key_set_.has(DispatchKey::CUDA) ||
      key_set_.has(DispatchKey::SparseCUDA) ||
      key_set_.has(DispatchKey::SparseCsrCUDA) ||
      key_set_.has(DispatchKey::QuantizedCUDA);
}

Collaborator Author:
addmm_out_cuda_impl (the dense CUDA implementation) has this "sameGPU" check, and it probably shouldn't be there:

TensorArg args[]{{result, "out", 0}, {self, "self", 1}, {mat1, "mat1", 2}, {mat2, "mat2", 3}};
checkAllSameGPU(__func__, args);
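A short Python-side sketch of what the generated device check buys (the tensor construction below is illustrative, and the error text mirrors the expectation in the existing test_matmul_device_mismatch test):

import torch

# With the generated check, mixing devices fails before addmv_out_sparse_csr_cuda runs,
# so an explicit checkAllSameGPU inside the kernel is redundant.
crow_indices = torch.tensor([0, 1, 2, 3], dtype=torch.int32, device="cuda")
col_indices = torch.tensor([0, 1, 2], dtype=torch.int32, device="cuda")
values = torch.ones(3, device="cuda")
A = torch.sparse_csr_tensor(crow_indices, col_indices, values, (3, 3))

x = torch.randn(3, device="cuda")
y_cpu = torch.randn(3)  # deliberately on the wrong device

try:
    torch.addmv(y_cpu, A, x)
except RuntimeError as e:
    print(e)  # "Expected all tensors to be on the same device, ..."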

checkAllSameGPU(__func__, args);

c10::MaybeOwned<Tensor> self_ = expand_size(self, {mat.size(0)});
auto betaval = beta.toComplexDouble();

if (&result != &self) {
at::native::resize_output(result, self_->sizes());
if (betaval != 0.0) {
at::native::copy_(result, *self_);
}
}

if (mat._nnz() == 0) {
// shortcut for an empty matrix
// By definition, when beta==0, values in self should be ignored. nans and infs
// should not propagate
if (betaval == 0.0) {
return result.zero_();
} else {
return at::mul_out(
const_cast<Tensor&>(result),
self,
at::native::scalar_tensor(
beta, self.scalar_type(), c10::nullopt /* layout */, at::kCPU, c10::nullopt /* pin_memory */));
}
}

// cuda 11.3 version computes garbage for float16 inputs
// couldn't check bfloat16 because it requires Ampere GPU but I assume the problem is same
// addmm works fine
Collaborator:
Sorry to hear that. Is this still an ongoing problem?

Collaborator Author:
I'll prepare a standalone C file and test it in CUDA 11.4.
if (vec.scalar_type() == kHalf || vec.scalar_type() == kBFloat16) {
result.unsqueeze_(-1);
sparse::impl::cuda::addmm_out_sparse_csr(mat, vec.unsqueeze(-1), beta, alpha, result);
result.squeeze_(-1);
return result;
}

sparse::impl::cuda::addmv_out_sparse_csr(mat, vec, beta, alpha, result);
return result;
}

} // namespace native
} // namespace at
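For reference, a hedged end-to-end sketch of what this kernel enables from Python (a sketch only: it assumes a CUDA build with the cuSPARSE generic API and that Tensor.to_sparse_csr() is available for CUDA tensors; sizes, sparsity, and tolerances are arbitrary). Half and bfloat16 inputs take the unsqueeze/addmm detour shown above, but the call site looks the same:

import torch

device = "cuda"
n = 100

# Random matrix made ~90% sparse, then converted to CSR so dispatch
# reaches addmv_out_sparse_csr_cuda.
dense = torch.randn(n, n, device=device)
dense = dense * (torch.rand(n, n, device=device) > 0.9)
A = dense.to_sparse_csr()

x = torch.randn(n, device=device)
y = torch.randn(n, device=device)

out = torch.addmv(y, A, x, beta=0.5, alpha=2.0)  # 0.5*y + 2.0*(A @ x)
ref = 0.5 * y + 2.0 * (dense @ x)
print(torch.allclose(out, ref, rtol=1e-5, atol=1e-5))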
96 changes: 96 additions & 0 deletions aten/src/ATen/native/sparse/cuda/SparseBlasImpl.cpp
@@ -55,6 +55,16 @@ void addmm_out_legacy(
at::native::s_addmm_out_csr_sparse_dense_cuda_worker(nnz, m, n, k, result, beta, result, alpha, crow_indices, col_indices, values, mat2);
}

c10::MaybeOwned<Tensor> inline prepare_dense_vector_for_cusparse(
const Tensor& tensor) {
if (tensor.is_non_overlapping_and_dense()) {
return c10::MaybeOwned<Tensor>::borrowed(tensor);
} else {
return c10::MaybeOwned<Tensor>::owned(
tensor.clone(at::MemoryFormat::Contiguous));
}
}

} // anonymous namespace

void addmm_out_sparse_csr(
@@ -166,6 +176,92 @@ void addmm_out_sparse_csr(
#endif
}

/*
Computes a sparse matrix-dense vector product defined as
y <- alpha*op(A)*x + beta*y

Args:
* `mat` - Tensor storing sparse m x n matrix A.
* `vec` - Tensor storing dense vector x of size n.
* `result` - [in] Tensor storing dense vector y of size m.
[out] result of the operation.
*/
void addmv_out_sparse_csr(
const at::sparse_csr::SparseCsrTensor& mat,
const Tensor& vec,
const Scalar& beta,
const Scalar& alpha,
const Tensor& result) {
#if !AT_USE_CUSPARSE_GENERIC_API()
TORCH_CHECK(
false,
"Calling addmv on a sparse GPU tensor requires compiling ",
"PyTorch with CUDA 10.2+ (CUDA 11+ on Windows). ",
"Please use PyTorch built with newer CUDA version.");
#else
cusparseOperation_t opA = CUSPARSE_OPERATION_NON_TRANSPOSE;

c10::MaybeOwned<Tensor> result_ = prepare_dense_vector_for_cusparse(result);
c10::MaybeOwned<Tensor> vec_ = prepare_dense_vector_for_cusparse(vec);

// TODO: update this to support COO sparse layout
auto descA = at::cuda::sparse::CuSparseSpMatCsrDescriptor(mat);
auto descX = at::cuda::sparse::CuSparseDnVecDescriptor(*vec_);
auto descY = at::cuda::sparse::CuSparseDnVecDescriptor(*result_);

// cusparseSpMVAlg_t was updated in cuda 11.2.1 (cusparse 11.4.0)
#if CUSPARSE_VERSION >= 11400
cusparseSpMVAlg_t alg = CUSPARSE_SPMV_ALG_DEFAULT;
#else
cusparseSpMVAlg_t alg = CUSPARSE_MV_ALG_DEFAULT;
#endif

// There is no dispatch for kHalf and kBFloat16 types because cusparse
// computes garbage in this case, latest checked version of cuda is 11.3
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
Collaborator:
Is this still an ongoing problem?
result.scalar_type(),
"addmv_out_sparse_csr_cuda_impl",
[&] {
auto beta_ = beta.to<scalar_t>();
auto alpha_ = alpha.to<scalar_t>();
cudaDataType compute_type = at::cuda::getCudaDataType<scalar_t>();
auto handle = at::cuda::getCurrentCUDASparseHandle();

size_t buffer_size;
TORCH_CUDASPARSE_CHECK(cusparseSpMV_bufferSize(
handle,
opA,
&alpha_,
descA.descriptor(),
descX.descriptor(),
&beta_,
descY.descriptor(),
compute_type,
alg,
&buffer_size // output
));

auto& allocator = *c10::cuda::CUDACachingAllocator::get();
auto work_data = allocator.allocate(buffer_size);

TORCH_CUDASPARSE_CHECK(cusparseSpMV(
handle,
opA,
&alpha_,
descA.descriptor(),
descX.descriptor(),
&beta_,
descY.descriptor(),
compute_type,
alg,
work_data.get()));
});
if (!result.is_same(*result_)) {
result.copy_(*result_);
}
#endif
}

} // namespace cuda
} // namespace impl
} // namespace sparse
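The prepare_dense_vector_for_cusparse helper above exists because cuSPARSE dense-vector descriptors require contiguous storage. A rough Python analogy for the situation it handles (an illustration, not the PR's actual code path):

import torch

base = torch.randn(200, device="cuda")
strided_vec = base[::2]  # stride-2 view: not contiguous
assert not strided_vec.is_contiguous()

# Mirrors the borrowed/owned logic: contiguous inputs are used as-is, anything else
# is materialized into a contiguous clone before cusparseSpMV, and the kernel's output
# is copied back into the caller's tensor (the result.copy_(*result_) at the end).
vec_for_cusparse = strided_vec if strided_vec.is_contiguous() else strided_vec.contiguous()
print(vec_for_cusparse.is_contiguous())  # True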
7 changes: 7 additions & 0 deletions aten/src/ATen/native/sparse/cuda/SparseBlasImpl.h
@@ -17,6 +17,13 @@ void addmm_out_sparse_csr(
const Scalar& alpha,
const Tensor& result);

void addmv_out_sparse_csr(
const at::sparse_csr::SparseCsrTensor& mat,
const Tensor& vec,
const Scalar& beta,
const Scalar& alpha,
const Tensor& result);

} // namespace cuda
} // namespace impl
} // namespace sparse
15 changes: 12 additions & 3 deletions test/test_sparse_csr.py
@@ -8,7 +8,7 @@
from torch.testing._internal.common_utils import \
(IS_MACOS, IS_WINDOWS, TestCase, run_tests, load_tests, coalescedonoff)
from torch.testing._internal.common_device_type import \
(instantiate_device_type_tests, dtypes, dtypesIfCUDA, onlyCPU, onlyCUDA,
(instantiate_device_type_tests, dtypes, dtypesIfCUDA, onlyCPU, onlyCUDA, skipCUDAIfNoCusparseGeneric,
precisionOverride)
from torch.testing._internal.common_dtype import floating_types, get_all_dtypes

@@ -402,7 +402,11 @@ def test_matmul_device_mismatch(self, device, dtype):
with self.assertRaisesRegex(RuntimeError, "Expected all tensors to be on the same device"):
torch.addmm(s, csr, m2)

@dtypes(torch.float, torch.double)
@skipCUDAIfNoCusparseGeneric
@dtypes(*torch.testing.floating_types())
@dtypesIfCUDA(*get_all_complex_dtypes(),
*get_all_fp_dtypes(include_half=SM53OrLater, include_bfloat16=SM80OrLater))
Collaborator:
is SM80OrLater the correct guard for bfloat16? For regular addmm, bfloat16 is supported (with perf equivalent to fp32) on earlier architectures.

Collaborator Author:
Unfortunately, cuSPARSE raises CUSPARSE_STATUS_ARCH_MISMATCH on earlier architectures.
From the documentation:

Unsupported data types and Compute Capability (CC):
__half on GPUs with CC < 53 (e.g. Kepler)
__nv_bfloat16 on GPUs with CC < 80 (e.g. Kepler, Maxwell, Pascal, Volta, Turing)

@precisionOverride({torch.bfloat16: 1e-2, torch.float16: 1e-2})
Collaborator:
hm, 1e-2 seems high for float16?

Collaborator Author:
It is. It seems that cuSPARSE uses a different accumulation strategy, or something else is different, leading to less accurate results than cuBLAS computes.
I'll verify the tolerances again.

Collaborator Author:
Unfortunately, 1e-2 is required for the float16 tests to pass.
Running python -m pytest test/test_sparse_csr.py -k "test_csr_matvec" -vvv fails with

Tensors failed to compare as equal!With rtol=0.001 and atol=0.001, found 5 element(s) (out of 100) whose difference(s) exceeded the margin of error (including 0 nan comparisons). The greatest difference was 0.0078125 (4.2578125 vs. 4.25), which occurred at index 69.

Interestingly, running the specific test python -m pytest test/test_sparse_csr.py -k "test_csr_matvec_cuda_float16" -vvv to generate a different input of the same size gives exactly the same greatest difference of 0.0078125!

Tensors failed to compare as equal!With rtol=0.001 and atol=0.001, found 4 element(s) (out of 100) whose difference(s) exceeded the margin of error (including 0 nan comparisons). The greatest difference was 0.0078125 (-3.404296875 vs. -3.412109375), which occurred at index 92.

Tested on CUDA 11.4.2 and a Turing card.

Collaborator Author:
Using SpMV instead of SpMM makes the test pass without precision overrides for float16, so I'll remove the override here.

def test_csr_matvec(self, device, dtype):
side = 100
for index_dtype in [torch.int32, torch.int64]:
@@ -415,7 +419,12 @@ def test_csr_matvec(self, device, dtype):
self.assertEqual(res, expected)

bad_vec = torch.randn(side + 10, dtype=dtype, device=device)
with self.assertRaisesRegex(RuntimeError, "mv: expected"):
err_msg = "mv: expected"
# CUDA path now uses generic meta/structured implementation
# TODO: move CPU path to not use `mv_sparse` function
if self.device_type == 'cuda':
err_msg = "size mismatch, got"
with self.assertRaisesRegex(RuntimeError, err_msg):
csr.matmul(bad_vec)

@dtypes(torch.double)
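The SM53OrLater / SM80OrLater guards used on test_csr_matvec above follow the compute-capability limits quoted from the cuSPARSE documentation in the review thread. A small sketch of the equivalent runtime check (the variable names are mine, not the test suite's):

import torch

major, minor = torch.cuda.get_device_capability()
supports_half_spmv = (major, minor) >= (5, 3)      # __half needs CC >= 5.3
supports_bfloat16_spmv = (major, minor) >= (8, 0)  # __nv_bfloat16 needs CC >= 8.0
print(supports_half_spmv, supports_bfloat16_spmv)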
5 changes: 4 additions & 1 deletion torch/testing/_internal/common_device_type.py
@@ -14,7 +14,7 @@
skipCUDANonDefaultStreamIf, TEST_WITH_ASAN, TEST_WITH_UBSAN, TEST_WITH_TSAN, \
IS_SANDCASTLE, IS_FBCODE, IS_REMOTE_GPU, DeterministicGuard, TEST_SKIP_NOARCH, \
_TestParametrizer, dtype_name, TEST_WITH_MIOPEN_SUGGEST_NHWC
from torch.testing._internal.common_cuda import _get_torch_cuda_version
from torch.testing._internal.common_cuda import _get_torch_cuda_version, TEST_CUSPARSE_GENERIC
from torch.testing._internal.common_dtype import get_all_dtypes

# The implementation should be moved here as soon as the deprecation period is over.
@@ -1211,6 +1211,9 @@ def wrap_fn(self, *args, **kwargs):
return wrap_fn
return dec_fn

# Skips a test on CUDA if cuSparse generic API is not available
def skipCUDAIfNoCusparseGeneric(fn):
return skipCUDAIf(not TEST_CUSPARSE_GENERIC, "cuSparse Generic API not available")(fn)

def skipCUDAIfNoCudnn(fn):
return skipCUDAIfCudnnVersionLessThan(0)(fn)