[Sparse] Add reference implementation for addmv (#97353)
Partially addresses the problem raised in #96972

Add `test_addmv` and enable `test_block_addmv` on all platforms (so the tests can run on M1)

TODO: Make sure that the non-contiguous mode of test_block_addmv actually generates non-contiguous values; right now it probably does not, as the test passes while assuming values are contiguous.
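
As an illustration (not part of this commit), a minimal sketch of the operation the new reference path covers: `torch.addmv` with a CPU sparse CSR matrix on a build without MKL sparse (e.g. M1). The example data is hypothetical:

```python
import torch

# A small dense matrix with explicit zeros, converted to sparse CSR so that
# addmv takes the sparse CPU path.
mat = torch.tensor([[1.0, 0.0, 2.0],
                    [0.0, 3.0, 0.0],
                    [4.0, 0.0, 5.0]])
sparse_mat = mat.to_sparse_csr()
vec = torch.randn(3)
bias = torch.randn(3)

# addmv computes beta * bias + alpha * (mat @ vec); the sparse result should
# match the dense one.
dense_out = torch.addmv(bias, mat, vec, alpha=2.0, beta=0.5)
sparse_out = torch.addmv(bias, sparse_mat, vec, alpha=2.0, beta=0.5)
torch.testing.assert_close(dense_out, sparse_out)
```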

Pull Request resolved: #97353
Approved by: https://github.com/cpuhrsch
malfet authored and pytorchmergebot committed Mar 24, 2023
1 parent 31e858e commit ad5d81a
Showing 3 changed files with 114 additions and 7 deletions.
2 changes: 1 addition & 1 deletion aten/src/ATen/mkl/Sparse.h
@@ -3,7 +3,7 @@
#include <ATen/Config.h>

// MKL Sparse is not currently supported on Windows
// See https://github.com/pytorch/pytorch/pull/50937#issuecomment-779272492
// See https://github.com/pytorch/pytorch/issues/97352
#if AT_MKL_ENABLED() && (!defined(_WIN32))
#define AT_USE_MKL_SPARSE() 1
#else
105 changes: 100 additions & 5 deletions aten/src/ATen/native/sparse/SparseBlasImpl.cpp
@@ -1,5 +1,6 @@
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/Config.h>
#include <ATen/mkl/Sparse.h>
#include <ATen/native/mkl/SparseBlasImpl.h>
#include <ATen/native/sparse/SparseBlasImpl.h>
#include <ATen/SparseCsrTensorUtils.h>
@@ -14,6 +15,11 @@
#include <ATen/ops/zeros.h>
#endif

#if !AT_USE_MKL_SPARSE()
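// Only needed by the reference addmv kernels below (dispatch macro and parallel_for).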
#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#endif

namespace at {
namespace native {
namespace sparse {
@@ -212,6 +218,90 @@ Tensor& _compressed_row_strided_addmm_out(
}

namespace cpu {
#if !AT_USE_MKL_SPARSE()
namespace {
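// Reference CSR kernel: result = beta * result + alpha * (mat @ vec),
// parallelized over the rows of the sparse matrix.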
template<typename scalar_t, typename idx_t>
void addmv_sparse_csr(
const scalar_t* mat_values,
const idx_t* crow_index,
const idx_t* col_index,
const int64_t mat_rows,
const scalar_t* vec,
const scalar_t alpha,
const scalar_t beta,
scalar_t* result) {
at::parallel_for(0, mat_rows, 0, [&](int64_t rstart, int64_t rend) {
for(const auto row: c10::irange(rstart, rend)) {
scalar_t acc(0);
for(const auto idx: c10::irange(crow_index[row], crow_index[row + 1])) {
acc += mat_values[idx] * vec[col_index[idx]];
}
result[row] = acc * alpha + result[row] * beta;
}
});
}

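// Reference BSR kernel: same accumulation as the CSR kernel, but each stored
// entry is a dense (blocksize_rows x blocksize_cols) block, so every output row
// reads one row slice from each block in its block-row.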
template<typename scalar_t, typename idx_t>
void addmv_sparse_bsr(
const scalar_t* mat_values,
const idx_t* crow_index,
const idx_t* col_index,
const int64_t mat_rows,
const int64_t blocksize_rows,
const int64_t blocksize_cols,
const scalar_t* vec,
const scalar_t alpha,
const scalar_t beta,
scalar_t* result) {
at::parallel_for(0, mat_rows, 0, [&](int64_t rstart, int64_t rend) {
for(const auto row: c10::irange(rstart, rend)) {
const auto block_row = row / blocksize_rows;
const auto block_row_offset = row % blocksize_rows;
scalar_t acc(0);
for(const auto block_idx: c10::irange(crow_index[block_row], crow_index[block_row + 1])) {
const auto block_offs = (block_idx * blocksize_rows + block_row_offset) * blocksize_cols;
const auto vec_offs = col_index[block_idx] * blocksize_cols;
for(const auto idx: c10::irange(blocksize_cols)) {
acc += mat_values[block_offs + idx] * vec[vec_offs + idx];
}
}
result[row] = acc * alpha + result[row] * beta;
}
});
}

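// Dispatches to the CSR or BSR kernel above based on the layout of `mat`;
// values are made contiguous first so the kernels can index raw pointers directly.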
template<typename scalar_t, typename idx_t>
void addmv_out_sparse_csr(
const Tensor& mat,
const Tensor& vec,
const Scalar& beta,
const Scalar& alpha,
const Tensor& result) {
auto cont_values = mat.values().contiguous();
if (mat.layout() == kSparseBsr) {
addmv_sparse_bsr(cont_values.template data_ptr<scalar_t>(),
mat.crow_indices().template data_ptr<idx_t>(),
mat.col_indices().template data_ptr<idx_t>(),
mat.size(0),
mat.values().size(1),
mat.values().size(2),
vec.template data_ptr<scalar_t>(),
alpha.template to<scalar_t>(),
beta.template to<scalar_t>(),
result.template data_ptr<scalar_t>());
} else {
addmv_sparse_csr(cont_values.template data_ptr<scalar_t>(),
mat.crow_indices().template data_ptr<idx_t>(),
mat.col_indices().template data_ptr<idx_t>(),
mat.size(0),
vec.template data_ptr<scalar_t>(),
alpha.template to<scalar_t>(),
beta.template to<scalar_t>(),
result.template data_ptr<scalar_t>());
}
}
} // anonymous namespace
#endif // !AT_USE_MKL_SPARSE()

/*
Computes a sparse matrix-dense vector product defined as
@@ -229,11 +319,16 @@ void addmv_out_sparse_csr(
const Scalar& beta,
const Scalar& alpha,
const Tensor& result) {
#if !AT_MKL_ENABLED()
TORCH_CHECK(
false,
"Calling addmv on a sparse CPU tensor requires compiling PyTorch with MKL. ",
"Please use PyTorch built MKL support.");
#if !AT_USE_MKL_SPARSE()
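// MKL sparse is not available in this build (e.g. Windows or an M1 build
// without MKL); fall back to the reference kernels above.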
TORCH_CHECK(mat.layout() == kSparseBsr || mat.layout() == kSparseCsr, "Unexpected layout ", mat.layout());
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
result.scalar_type(), "addmv_out_sparse_csr_impl_reference", [&] {
if (mat.crow_indices().scalar_type() == kLong) {
addmv_out_sparse_csr<scalar_t, int64_t>(mat, vec, beta, alpha, result);
} else {
addmv_out_sparse_csr<scalar_t, int32_t>(mat, vec, beta, alpha, result);
}
});
#else
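// MKL sparse is available: use the optimized MKL implementation.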
sparse::impl::mkl::addmv_out_sparse_csr(mat, vec, beta, alpha, result);
#endif
14 changes: 13 additions & 1 deletion test/test_sparse_csr.py
@@ -1632,7 +1632,6 @@ def ref_half_bfloat16(c, a, b, alpha=None, beta=None, out=None):
@parametrize("block_size", [2, 3])
@parametrize("index_dtype", [torch.int32, torch.int64])
@parametrize("noncontiguous", [True, False])
@skipCPUIfNoMklSparse
@unittest.skipIf(not TEST_SCIPY, "SciPy not found")
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
def test_block_addmv(self, device, dtype, index_dtype, block_size, noncontiguous):
@@ -1655,6 +1654,19 @@ def test_block_addmv(self, device, dtype, index_dtype, block_size, noncontiguous):
c = make_tensor((m * block_size,), dtype=dtype, device=device, noncontiguous=noncontiguous)
self.run_test_block_addmm_addmv(torch.addmv, c, a, b, dtype=dtype, device=device)

@parametrize("matrix_shape", [(3, 3), (5, 7), (11, 9)])
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
@onlyCPU
def test_addmv(self, device, dtype, matrix_shape):
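# Compare addmv on a sparse CSR matrix against the dense reference result.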
mat = torch.randn(matrix_shape, dtype=dtype, device=device)
mat[mat.real < 0] = 0
sparse_mat = mat.to_sparse_csr()
mvec = torch.randn((mat.size(1),), dtype=dtype, device=device)
avec = torch.randn((mat.size(0),), dtype=torch.float64, device=device)
ref_output = torch.addmv(avec, mat, mvec)
output = torch.addmv(avec, sparse_mat, mvec)
self.assertEqual(ref_output, output)

@parametrize("block_size", [2, 3])
@parametrize("index_dtype", [torch.int32, torch.int64])
@parametrize("noncontiguous", [True, False])
