sparse.mm.backward: fix for non-contiguous grad values on CPU
ghstack-source-id: bfd5b612a65f91ef50a03bf35ad9d87aa8f9f3fc
Pull Request resolved: #106127
nikitaved committed Jul 27, 2023
1 parent 0eb8cb2 commit 7bbe0cd
Showing 2 changed files with 70 additions and 34 deletions.
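
For context: on CPU, the backward of torch.sparse.mm consumed raw data pointers and assumed stride-1 (contiguous) layout, so a gradient whose values tensor was non-contiguous produced wrong results. A minimal reproducer, adapted from the regression test added below (the original report is pytorch/pytorch#102493):

    import torch

    A = torch.eye(5).to_sparse().requires_grad_(True)
    B = torch.eye(5).to_sparse()
    out = torch.sparse.mm(A, B)
    # Reducing over the coalesced values feeds sparse.mm's backward a grad
    # with non-contiguous values (see the linked issue).
    out.coalesce().values().sum().backward()
    # With the fix, the gradient of this identity-by-identity product is the
    # identity again, so A.grad matches A.
    assert torch.allclose(A.grad.to_dense(), torch.eye(5))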
92 changes: 58 additions & 34 deletions aten/src/ATen/native/sparse/SparseMatMul.cpp
@@ -50,17 +50,18 @@ void csr_to_coo(const int64_t n_row, const int64_t Ap[], int64_t Bi[]) {
   }
 }
 
+template<typename index_t_ptr = int64_t*>
 int64_t _csr_matmult_maxnnz(
     const int64_t n_row,
     const int64_t n_col,
     // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
-    const int64_t Ap[],
+    const index_t_ptr Ap,
     // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
-    const int64_t Aj[],
+    const index_t_ptr Aj,
     // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
-    const int64_t Bp[],
+    const index_t_ptr Bp,
     // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
-    const int64_t Bj[]) {
+    const index_t_ptr Bj) {
   /*
     Compute needed buffer size for matrix `C` in `C = A@B` operation.
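
The body of _csr_matmult_maxnnz is collapsed in this view; what it computes is the classic mask-based nnz count (the same scheme as SciPy's csr_matmat_maxnnz): for every output row, count each column of C at most once by remembering which row last touched it. A hedged Python sketch of that algorithm, not the exact source:

    def csr_matmult_maxnnz(n_row, n_col, Ap, Aj, Bp, Bj):
        mask = [-1] * n_col   # last row that touched each column
        nnz = 0
        for i in range(n_row):
            for jj in range(Ap[i], Ap[i + 1]):
                j = Aj[jj]
                for kk in range(Bp[j], Bp[j + 1]):
                    k = Bj[kk]
                    if mask[k] != i:   # column k not yet counted for row i
                        mask[k] = i
                        nnz += 1
        return nnz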
@@ -88,28 +89,28 @@ int64_t _csr_matmult_maxnnz(
   return nnz;
 }
 
-template<class scalar_t>
+template<typename index_t_ptr, typename scalar_t_ptr>
 void _csr_matmult(
     const int64_t n_row,
     const int64_t n_col,
     // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
-    const int64_t Ap[],
+    const index_t_ptr Ap,
     // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
-    const int64_t Aj[],
+    const index_t_ptr Aj,
     // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
-    const scalar_t Ax[],
+    const scalar_t_ptr Ax,
     // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
-    const int64_t Bp[],
+    const index_t_ptr Bp,
     // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
-    const int64_t Bj[],
+    const index_t_ptr Bj,
     // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
-    const scalar_t Bx[],
+    const scalar_t_ptr Bx,
     // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
-    int64_t Cp[],
+    typename index_t_ptr::value_type Cp[],
     // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
-    int64_t Cj[],
+    typename index_t_ptr::value_type Cj[],
     // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
-    scalar_t Cx[]) {
+    typename scalar_t_ptr::value_type Cx[]) {
   /*
     Compute CSR entries for matrix C = A@B.
@@ -133,27 +134,30 @@ void _csr_matmult(
   Note:
     Output arrays Cp, Cj, and Cx must be preallocated
   */
-  std::vector<int64_t> next(n_col, -1);
+  using index_t = typename index_t_ptr::value_type;
+  using scalar_t = typename scalar_t_ptr::value_type;
+
+  std::vector<index_t> next(n_col, -1);
   std::vector<scalar_t> sums(n_col, 0);
 
   int64_t nnz = 0;
 
   Cp[0] = 0;
 
   for (const auto i : c10::irange(n_row)) {
-    int64_t head = -2;
-    int64_t length = 0;
+    index_t head = -2;
+    index_t length = 0;
 
-    int64_t jj_start = Ap[i];
-    int64_t jj_end = Ap[i + 1];
+    index_t jj_start = Ap[i];
+    index_t jj_end = Ap[i + 1];
     for (const auto jj : c10::irange(jj_start, jj_end)) {
-      int64_t j = Aj[jj];
+      index_t j = Aj[jj];
       scalar_t v = Ax[jj];
 
-      int64_t kk_start = Bp[j];
-      int64_t kk_end = Bp[j + 1];
+      index_t kk_start = Bp[j];
+      index_t kk_end = Bp[j + 1];
       for (const auto kk : c10::irange(kk_start, kk_end)) {
-        int64_t k = Bj[kk];
+        index_t k = Bj[kk];
 
         sums[k] += v * Bx[kk];
 
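The loop above is the SMMP-style row accumulation: sums holds the running dot products for the current output row, while next threads the occupied columns through a singly linked list (head sentinel -2, "unused" sentinel -1), so only touched columns are visited when the row is flushed. A hedged Python sketch of one row, mirroring the C++ (the list insertion and the flush are partly collapsed in this view):

    def accumulate_row(i, Ap, Aj, Ax, Bp, Bj, Bx, n_col):
        next_ = [-1] * n_col   # linked list over occupied columns
        sums = [0.0] * n_col   # running dot products for row i
        head, length = -2, 0
        for jj in range(Ap[i], Ap[i + 1]):
            j, v = Aj[jj], Ax[jj]
            for kk in range(Bp[j], Bp[j + 1]):
                k = Bj[kk]
                sums[k] += v * Bx[kk]
                if next_[k] == -1:   # first touch of column k in this row
                    next_[k] = head
                    head = k
                    length += 1
        # Flush: walk the list, emit (col, value), and reset the scratch arrays.
        cols, vals = [], []
        for _ in range(length):
            cols.append(head)
            vals.append(sums[head])
            temp = head
            head = next_[temp]
            next_[temp] = -1
            sums[temp] = 0.0
        return cols, vals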
@@ -174,7 +178,7 @@ void _csr_matmult(
       Cx[nnz] = sums[head];
       nnz++;
 
-      int64_t temp = head;
+      index_t temp = head;
       head = next[head];
 
       next[temp] = -1; // clear arrays
@@ -183,6 +187,7 @@
 
     // Make sure that col indices are sorted.
     // TODO: a better approach is to implement a CSR @ CSC kernel.
+    // NOTE: Cx arrays are expected to be contiguous!
     auto col_indices_accessor = StridedRandomAccessor<int64_t>(Cj + nnz - length, 1);
     auto val_accessor = StridedRandomAccessor<scalar_t>(Cx + nnz - length, 1);
     auto kv_accessor = CompositeRandomAccessorCPU<
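
The added NOTE states the contract this patch relies on: the per-row sort that follows (its tail is collapsed above) walks Cj and Cx through stride-1 accessors, pairing each column index with its value, so the output buffers must stay contiguous even though the inputs no longer have to be. A hedged Python analogue of that keyed sort over one row segment:

    import torch

    cols = torch.tensor([4, 1, 3])         # unsorted column indices of one row
    vals = torch.tensor([0.4, 0.1, 0.3])   # their matching values
    order = cols.argsort()
    cols, vals = cols[order], vals[order]  # keys and payloads permuted together
    # cols: tensor([1, 3, 4]); vals: tensor([0.1000, 0.3000, 0.4000])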
@@ -212,13 +217,32 @@ void sparse_matmul_kernel(
   const auto mat1_csr = mat1.to_sparse_csr();
   const auto mat2_csr = mat2.to_sparse_csr();
 
-  const auto nnz = _csr_matmult_maxnnz(
-      M,
-      N,
-      mat1_csr.crow_indices().data_ptr<int64_t>(),
-      mat1_csr.col_indices().data_ptr<int64_t>(),
-      mat2_csr.crow_indices().data_ptr<int64_t>(),
-      mat2_csr.col_indices().data_ptr<int64_t>());
+  auto mat1_crow_indices_ptr = StridedRandomAccessor<int64_t>(
+      mat1_csr.crow_indices().data_ptr<int64_t>(),
+      mat1_csr.crow_indices().stride(-1));
+  auto mat1_col_indices_ptr = StridedRandomAccessor<int64_t>(
+      mat1_csr.col_indices().data_ptr<int64_t>(),
+      mat1_csr.col_indices().stride(-1));
+  auto mat1_values_ptr = StridedRandomAccessor<scalar_t>(
+      mat1_csr.values().data_ptr<scalar_t>(),
+      mat1_csr.values().stride(-1));
+  auto mat2_crow_indices_ptr = StridedRandomAccessor<int64_t>(
+      mat2_csr.crow_indices().data_ptr<int64_t>(),
+      mat2_csr.crow_indices().stride(-1));
+  auto mat2_col_indices_ptr = StridedRandomAccessor<int64_t>(
+      mat2_csr.col_indices().data_ptr<int64_t>(),
+      mat2_csr.col_indices().stride(-1));
+  auto mat2_values_ptr = StridedRandomAccessor<scalar_t>(
+      mat2_csr.values().data_ptr<scalar_t>(),
+      mat2_csr.values().stride(-1));
+
+  const auto nnz = _csr_matmult_maxnnz(
+      M,
+      N,
+      mat1_crow_indices_ptr,
+      mat1_col_indices_ptr,
+      mat2_crow_indices_ptr,
+      mat2_col_indices_ptr);
 
   auto output_indices = output._indices();
   auto output_values = output._values();
@@ -234,12 +258,12 @@
     _csr_matmult(
         M,
         N,
-        mat1_csr.crow_indices().data_ptr<int64_t>(),
-        mat1_csr.col_indices().data_ptr<int64_t>(),
-        mat1_csr.values().data_ptr<scalar_t>(),
-        mat2_csr.crow_indices().data_ptr<int64_t>(),
-        mat2_csr.col_indices().data_ptr<int64_t>(),
-        mat2_csr.values().data_ptr<scalar_t>(),
+        mat1_crow_indices_ptr,
+        mat1_col_indices_ptr,
+        mat1_values_ptr,
+        mat2_crow_indices_ptr,
+        mat2_col_indices_ptr,
+        mat2_values_ptr,
         output_indptr.data_ptr<int64_t>(),
         output_col_indices.data_ptr<int64_t>(),
         output_values.data_ptr<scalar_t>());
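The kernel-side change above is mechanical but is the heart of the fix: every raw data_ptr is wrapped in a StridedRandomAccessor that carries the tensor's own stride(-1), so element i is read at ptr + i * stride instead of ptr + i. A quick hedged illustration of why a stride of 1 cannot be assumed:

    import torch

    vals = torch.arange(10.0)[::2]  # a strided view: every other element
    print(vals.is_contiguous())     # False
    print(vals.stride(-1))          # 2; stride-1 pointer arithmetic over
                                    # vals.data_ptr() would read wrong elements
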
12 changes: 12 additions & 0 deletions test/test_sparse.py
@@ -3700,6 +3700,17 @@ def different_dtypes():
 
         self.assertRaisesRegex(RuntimeError, 'mat1 dtype Double does not match mat2 dtype Float', different_dtypes)
 
+        def test_backward_noncontiguous():
+            # Sparse.mm backward used to be wrong with non-contiguous grads,
+            # see https://github.com/pytorch/pytorch/issues/102493.
+            n_reps = 7
+            for _ in range(n_reps):
+                A = torch.eye(5).to_sparse().requires_grad_(True)
+                B = torch.eye(5).to_sparse()
+                out = torch.sparse.mm(A, B)
+                out.coalesce().values().sum().backward()
+                self.assertEqual(A.grad, A)
+
         for n in range(2, 5):
             for m in range(2, 8):
                 for p in range(2, 8):
@@ -3708,6 +3719,7 @@ def different_dtypes():
         test_sparse_matmul(2, 0, [0, 0], [0, 0])
         test_sparse_matmul(2, 0, [0, 10], [10, 0])
         test_error_cases()
+        test_backward_noncontiguous()
 
     @coalescedonoff
     @dtypes(torch.double)
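To exercise the new regression check locally, running the enclosing test should suffice; it appears to be test_sparse_matmul, given the sibling calls above, and the -k filtering comes from unittest:

    python test/test_sparse.py -k test_sparse_matmul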