[Sparse] Add reference implementation for addmv #97353

Closed
wants to merge 7 commits
Changes from 2 commits
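The diff below adds a CPU reference path so torch.addmv works on sparse compressed matrices even when PyTorch is built without MKL sparse support. A minimal usage sketch of the operation this enables (illustrative only, not part of the PR):

import torch

dense = torch.randn(5, 7)
dense[dense < 0.5] = 0                 # sparsify
mat = dense.to_sparse_csr()
vec = torch.randn(7)
inp = torch.randn(5)

# addmv computes beta * inp + alpha * (mat @ vec)
out = torch.addmv(inp, mat, vec)
assert torch.allclose(out, torch.addmv(inp, dense, vec))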
2 changes: 1 addition & 1 deletion aten/src/ATen/mkl/Sparse.h
@@ -3,7 +3,7 @@
#include <ATen/Config.h>

// MKL Sparse is not currently supported on Windows
// See https://github.com/pytorch/pytorch/pull/50937#issuecomment-779272492
// See https://github.com/pytorch/pytorch/issues/97352
#if AT_MKL_ENABLED() && (!defined(_WIN32))
#define AT_USE_MKL_SPARSE() 1
#else
88 changes: 83 additions & 5 deletions aten/src/ATen/native/sparse/SparseBlasImpl.cpp
@@ -1,5 +1,6 @@
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/Config.h>
#include <ATen/mkl/Sparse.h>
#include <ATen/native/mkl/SparseBlasImpl.h>
#include <ATen/native/sparse/SparseBlasImpl.h>
#include <ATen/SparseCsrTensorUtils.h>
@@ -14,6 +15,11 @@
#include <ATen/ops/zeros.h>
#endif

#if !AT_USE_MKL_SPARSE()
#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#endif

namespace at {
namespace native {
namespace sparse {
@@ -212,6 +218,56 @@ Tensor& _compressed_row_strided_addmm_out(
}

namespace cpu {
#if !AT_USE_MKL_SPARSE()
template<typename scalar_t>
void addmv_sparse_csr(
const scalar_t* mat_values,
const int64_t* crow_index,
const int64_t* col_index,
const int64_t rows,
const int64_t cols,
const scalar_t* vec,
const scalar_t alpha,
const scalar_t beta,
scalar_t* result) {
at::parallel_for(0, rows, 0, [&](int64_t rstart, int64_t rend) {
for(const auto row: c10::irange(rstart, rend)) {
scalar_t acc(0);
for(const auto idx: c10::irange(crow_index[row], crow_index[row+1])) {
acc += mat_values[idx]*vec[col_index[idx]];
}
result[row] = acc * alpha + result[row]*beta;
}
});
}

template<typename scalar_t>
void addmv_sparse_bsr(
const scalar_t* mat_values,
const int64_t* crow_index,
const int64_t* col_index,
const int64_t rows,
const int64_t cols,
const int64_t blocksize,
const scalar_t* vec,
const scalar_t alpha,
const scalar_t beta,
scalar_t* result) {
at::parallel_for(0, rows, 0, [&](int64_t rstart, int64_t rend) {
for(const auto row: c10::irange(rstart, rend)) {
const auto brow = row / blocksize;
const auto rrow = row % blocksize;
scalar_t acc(0);
for(const auto bidx: c10::irange(crow_index[brow], crow_index[brow+1])) {
for(const auto idx: c10::irange(blocksize)) {
acc += mat_values[(bidx*blocksize+rrow)*blocksize + idx]*vec[col_index[bidx]*blocksize + idx];
}
}
result[row] = acc * alpha + result[row]*beta;
}
});
}
#endif // !AT_USE_MKL_SPARSE()

/*
Computes a sparse matrix-dense vector product defined as
@@ -229,11 +285,33 @@ void addmv_out_sparse_csr(
const Scalar& beta,
const Scalar& alpha,
const Tensor& result) {
#if !AT_MKL_ENABLED()
TORCH_CHECK(
false,
"Calling addmv on a sparse CPU tensor requires compiling PyTorch with MKL. ",
"Please use PyTorch built MKL support.");
#if !AT_USE_MKL_SPARSE()
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
result.scalar_type(), "addmv_out_sparse_csr_impl_reference", [&] {
if (mat.layout() == kSparseBsr) {
addmv_sparse_bsr(mat.values().data_ptr<scalar_t>(),
mat.crow_indices().toType(kLong).data_ptr<int64_t>(),
mat.col_indices().toType(kLong).data_ptr<int64_t>(),
mat.size(0),
mat.size(1),
mat.values().size(1),
vec.data_ptr<scalar_t>(),
alpha.to<scalar_t>(),
beta.to<scalar_t>(),
result.data_ptr<scalar_t>());
} else {
addmv_sparse_csr(mat.values().data_ptr<scalar_t>(),
mat.crow_indices().data_ptr<int64_t>(),
mat.col_indices().data_ptr<int64_t>(),
mat.size(0),
mat.size(1),
vec.data_ptr<scalar_t>(),
alpha.to<scalar_t>(),
beta.to<scalar_t>(),
result.data_ptr<scalar_t>());
}

});
#else
sparse::impl::mkl::addmv_out_sparse_csr(mat, vec, beta, alpha, result);
#endif
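For readers of the kernels above, here is a minimal Python sketch of the same reference computation (illustrative pseudocode mirroring the C++ loops; function names and the flat-array layout are assumptions, not the PR's code):

def addmv_csr_reference(values, crow, col, vec, alpha, beta, result):
    # One output row per CSR row: accumulate values[idx] * vec[col[idx]].
    for row in range(len(crow) - 1):
        acc = sum(values[idx] * vec[col[idx]] for idx in range(crow[row], crow[row + 1]))
        result[row] = alpha * acc + beta * result[row]

def addmv_bsr_reference(values, crow, col, blocksize, vec, alpha, beta, result):
    # BSR stores square blocksize x blocksize blocks; with values flattened row-major,
    # block bidx, row rrow, column idx lives at (bidx * blocksize + rrow) * blocksize + idx.
    for row in range(len(result)):
        brow, rrow = divmod(row, blocksize)
        acc = 0
        for bidx in range(crow[brow], crow[brow + 1]):
            for idx in range(blocksize):
                acc += values[(bidx * blocksize + rrow) * blocksize + idx] * vec[col[bidx] * blocksize + idx]
        result[row] = alpha * acc + beta * result[row]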
14 changes: 13 additions & 1 deletion test/test_sparse_csr.py
@@ -1632,7 +1632,6 @@ def ref_half_bfloat16(c, a, b, alpha=None, beta=None, out=None):
@parametrize("block_size", [2, 3])
@parametrize("index_dtype", [torch.int32, torch.int64])
@parametrize("noncontiguous", [True, False])
@skipCPUIfNoMklSparse
@unittest.skipIf(not TEST_SCIPY, "SciPy not found")
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
def test_block_addmv(self, device, dtype, index_dtype, block_size, noncontiguous):
@@ -1655,6 +1654,19 @@ def test_block_addmv(self, device, dtype, index_dtype, block_size, noncontiguous
c = make_tensor((m * block_size,), dtype=dtype, device=device, noncontiguous=noncontiguous)
self.run_test_block_addmm_addmv(torch.addmv, c, a, b, dtype=dtype, device=device)

@parametrize("matrix_shape", [(3, 3), (5, 7), (11, 9)])
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
@onlyCPU
def test_addmv(self, device, dtype, matrix_shape):
mat = torch.randn(matrix_shape, dtype=dtype, device=device)
mat[mat.real < 0] = 0
sparse_mat = mat.to_sparse_csr()
mvec = torch.randn((mat.size(1),), dtype=dtype, device=device)
avec = torch.randn((mat.size(0),), dtype=torch.float64, device=device)
ref_output = torch.addmv(avec, mat, mvec)
output = torch.addmv(avec, sparse_mat, mvec)
self.assertEqual(ref_output, output)

@parametrize("block_size", [2, 3])
@parametrize("index_dtype", [torch.int32, torch.int64])
@parametrize("noncontiguous", [True, False])
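Relatedly, dropping @skipCPUIfNoMklSparse above means test_block_addmv now exercises the new BSR branch on CPU builds without MKL sparse. A hedged sketch of the equivalent user-level call (assumes square blocks, as the reference kernel does):

import torch

dense = torch.randn(6, 9)
bsr = dense.to_sparse_bsr(blocksize=(3, 3))
vec = torch.randn(9)
inp = torch.randn(6)

out = torch.addmv(inp, bsr, vec)       # beta * inp + alpha * (bsr @ vec)
assert torch.allclose(out, torch.addmv(inp, dense, vec))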