Implement sparse semantics support in gradcheck (2nd try)
ghstack-source-id: fb5de9d0fa81e42ddd728971073ff4acc6ece20a
Pull Request resolved: #95405
pearu committed Feb 24, 2023
1 parent d4882a9 commit 51a9931
Showing 3 changed files with 207 additions and 46 deletions.
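For orientation, a minimal sketch (not part of the commit) of the two gradcheck semantics these tests exercise: with masked=True (paired with check_sparse_nnz=True, as in the updated calls below) only the specified elements of a sparse input enter the Jacobian check, while with masked=False the sparse input is treated as a dense tensor whose unspecified elements are zero. The tensors and functions here are illustrative; the non-masked torch.mm case mirrors the new test_gradcheck_mm test added in this commit.

import torch
from torch.autograd import gradcheck

dtype = torch.float64  # gradcheck requires double-precision inputs

# Masked semantics: gradients are checked only at the specified elements,
# and the sparse input is allowed via check_sparse_nnz=True.
s = torch.eye(3, dtype=dtype).to_sparse().requires_grad_(True)
gradcheck(lambda t: t.to_dense(), (s,),
          check_sparse_nnz=True, check_batched_grad=False, masked=True)

# Sparse (non-masked) semantics: the sparse input is treated as a dense
# tensor with zeros at unspecified positions; check_sparse_nnz is not needed.
x = torch.eye(3, dtype=dtype).to_sparse().requires_grad_(True)
y = torch.randn(3, 3, dtype=dtype, requires_grad=True)
gradcheck(torch.mm, (x, y), check_batched_grad=False, masked=False)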
22 changes: 10 additions & 12 deletions test/test_autograd.py
@@ -4651,7 +4651,7 @@ def fn(sparse):
check_batched_grad=False, fast_mode=fast_mode)
with self.assertRaisesRegex(RuntimeError, 'gradcheck expects all tensor inputs are dense'):
gradcheck(fn, torch.rand(10, dtype=torch.double).to_sparse().requires_grad_(True), check_sparse_nnz=False,
check_batched_grad=False, fast_mode=fast_mode)
check_batched_grad=False, fast_mode=fast_mode, masked=True)
check(fast_mode=True)
check(fast_mode=False)

@@ -4665,8 +4665,8 @@ def fn(sparse_csr):

with self.assertRaisesRegex(RuntimeError, 'gradcheck expects all tensor inputs are dense'):
gradcheck(fn, torch.rand(2, 2, dtype=torch.double).to_sparse_csr().requires_grad_(True), check_sparse_nnz=False,
check_batched_grad=False, fast_mode=fast_mode)
# check(fast_mode=True) # RuntimeError: sparse_mask_sparse_csr expects self to be 2D
check_batched_grad=False, fast_mode=fast_mode, masked=True)
check(fast_mode=True)
check(fast_mode=False)

def test_gradcheck_sparse_csc_input(self):
@@ -4679,8 +4679,8 @@ def fn(sparse_csc):

with self.assertRaisesRegex(RuntimeError, 'gradcheck expects all tensor inputs are dense'):
gradcheck(fn, torch.rand(2, 2, dtype=torch.double).to_sparse_csc().requires_grad_(True), check_sparse_nnz=False,
check_batched_grad=False, fast_mode=fast_mode)
# check(fast_mode=True) # RuntimeError: Expected result Tensor to be of format CSR
check_batched_grad=False, fast_mode=fast_mode, masked=True)
check(fast_mode=True)
check(fast_mode=False)

def test_gradcheck_sparse_bsr_input(self):
@@ -4693,9 +4693,8 @@ def fn(sparse_bsr):

with self.assertRaisesRegex(RuntimeError, 'gradcheck expects all tensor inputs are dense'):
gradcheck(fn, torch.rand(2, 2, dtype=torch.double).to_sparse_bsr((2, 2)).requires_grad_(True),
check_sparse_nnz=False, check_batched_grad=False, fast_mode=fast_mode)
# RuntimeError: "empty_sparse_compressed" expected sparse compressed (non-block) tensor layout but got SparseBsr
# check(fast_mode=True)
check_sparse_nnz=False, check_batched_grad=False, fast_mode=fast_mode, masked=True)
check(fast_mode=True)
check(fast_mode=False)

def test_gradcheck_sparse_bsc_input(self):
@@ -4708,9 +4707,8 @@ def fn(sparse_bsc):

with self.assertRaisesRegex(RuntimeError, 'gradcheck expects all tensor inputs are dense'):
gradcheck(fn, torch.rand(2, 2, dtype=torch.double).to_sparse_bsc((2, 2)).requires_grad_(True),
check_sparse_nnz=False, check_batched_grad=False, fast_mode=fast_mode)
# RuntimeError: "empty_sparse_compressed" expected sparse compressed (non-block) tensor layout but got SparseBsc
# check(fast_mode=True)
check_sparse_nnz=False, check_batched_grad=False, fast_mode=fast_mode, masked=True)
check(fast_mode=True)
check(fast_mode=False)

def test_gradcheck_nondeterministic(self):
@@ -4746,7 +4744,7 @@ def check(fast_mode):
x = torch.rand(10, requires_grad=True).to_sparse()
with self.assertRaisesRegex(RuntimeError, 'dense when check_sparse_nnz is set to False.'):
gradcheck(lambda x: x.to_dense(), (x,), check_sparse_nnz=False, check_batched_grad=False,
fast_mode=fast_mode)
fast_mode=fast_mode, masked=True)
self.assertFalse(gradcheck(lambda x: x.to_dense(), (x,), check_sparse_nnz=False,
check_batched_grad=False, raise_exception=False, fast_mode=fast_mode))

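The checks re-enabled above for the CSR/CSC/BSR/BSC inputs all follow the same shape; as a hedged illustration (the actual fn and check helpers are defined above the hunks shown, so the fn below is an illustrative stand-in), the pattern is roughly:

import torch
from torch.autograd import gradcheck

# Illustrative stand-in for the fn/check helpers used by the tests above.
x = torch.rand(2, 2, dtype=torch.double).to_sparse_csr().requires_grad_(True)

def fn(sparse_csr):
    return sparse_csr.to_dense()

# Masked semantics with check_sparse_nnz=True: this is the combination the
# re-enabled check(fast_mode=...) calls exercise for the compressed layouts.
gradcheck(fn, (x,), check_sparse_nnz=True, check_batched_grad=False, masked=True)

# With check_sparse_nnz=False, sparse inputs are rejected, which is what the
# assertRaisesRegex blocks above verify.
try:
    gradcheck(fn, (x,), check_sparse_nnz=False, check_batched_grad=False, masked=True)
except RuntimeError as err:
    assert 'gradcheck expects all tensor inputs are dense' in str(err)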
127 changes: 114 additions & 13 deletions test/test_sparse.py
@@ -10,7 +10,7 @@
from torch.testing._internal.common_utils import TestCase, run_tests, skipIfRocm, do_test_dtypes, \
load_tests, TEST_NUMPY, TEST_SCIPY, IS_WINDOWS, gradcheck, coalescedonoff, \
DeterministicGuard, first_sample, TEST_WITH_CROSSREF, TEST_WITH_ROCM, skipIfTorchDynamo, \
parametrize, subtest, is_coalesced_indices, suppress_warnings
parametrize, subtest, is_coalesced_indices, suppress_warnings, is_slow_gradcheck_env
from torch.testing._internal.common_cuda import TEST_CUDA, _get_torch_cuda_version
from numbers import Number
from typing import Dict, Any
@@ -58,6 +58,15 @@ def all_sparse_layouts(test_name='layout', include_strided=False):
subtest(torch.sparse_bsc, name='SparseBSC'),
][(0 if include_strided else 1):])

def gradcheck_semantics(test_name='gradcheck'):
gradcheck_sparse = functools.partial(gradcheck, masked=False)
gradcheck_masked = functools.partial(gradcheck, masked=True, check_sparse_nnz=True)
gradcheck_sparse.masked = False
gradcheck_masked.masked = True
return parametrize(test_name, [
subtest(gradcheck_sparse, name='sparse'),
subtest(gradcheck_masked, name='masked')])


class CrossRefSparseFakeMode(torch._subclasses.CrossRefFakeMode):
def __init__(self):
@@ -402,7 +411,11 @@ def test_ctor_size_checks(self, device, dtype):

@dtypes(*floating_and_complex_types_and(torch.float16, torch.bfloat16))
@unittest.skipIf(TEST_WITH_CROSSREF, "generator unsupport triggers assertion error")
def test_to_dense(self, device, dtype):
@gradcheck_semantics()
def test_to_dense_with_gradcheck(self, device, dtype, gradcheck):
if not gradcheck.masked and is_slow_gradcheck_env():
self.skipTest('FIXME: to_dense_backward supports masked semantics only')

def test_tensor(x, res):
x.to_dense() # Tests triple to_dense for memory corruption
x.to_dense()
@@ -535,7 +548,11 @@ def test_shared(self, device, dtype):

@dtypes(torch.double, torch.cdouble)
@unittest.skipIf(TEST_WITH_CROSSREF, "generator unsupport triggers assertion error")
def test_to_dense_hybrid(self, device, dtype):
@gradcheck_semantics()
def test_to_dense_hybrid(self, device, dtype, gradcheck):
if not gradcheck.masked and is_slow_gradcheck_env():
self.skipTest('FIXME: to_dense_backward supports masked semantics only')

def test_tensor(x, res):
x.to_dense() # Tests double to_dense for memory corruption
x.to_dense()
@@ -889,7 +906,8 @@ def test_shape(sparse_dims, nnz, with_size):
@coalescedonoff
@dtypes(torch.double, torch.cdouble)
@unittest.skipIf(TEST_WITH_CROSSREF, "generator unsupport triggers assertion error")
def test_permute(self, device, dtype, coalesced):
@gradcheck_semantics()
def test_permute(self, device, dtype, coalesced, gradcheck):
# trivial checks
s = torch.rand(3, 3, 3, device=device, dtype=dtype).to_sparse()
with self.assertRaisesRegex(RuntimeError, "does not match the length"):
@@ -1513,7 +1531,8 @@ def test_shape(di, dj, dk, nnz):
@coalescedonoff
@unittest.skip("See https://github.com/pytorch/pytorch/issues/73145")
@dtypes(torch.double, torch.cdouble, torch.bfloat16)
def test_sparse_addmm(self, device, dtype, coalesced):
@gradcheck_semantics()
def test_sparse_addmm(self, device, dtype, coalesced, gradcheck):
def test_shape(m, n, p, nnz, broadcast, alpha_beta=None):
if alpha_beta is None:
alpha = random.random()
@@ -1560,15 +1579,16 @@ def test_shape(d1, d2, d3, nnz, transposed):

def fn(S, D):
return torch.sparse.mm(S, D)
gradcheck(fn, (S, D), check_sparse_nnz=True)
gradcheck(fn, (S, D), check_sparse_nnz=True, masked=True)

test_shape(7, 8, 9, 20, False)
test_shape(7, 8, 9, 20, True)

@coalescedonoff
@dtypes(torch.double)
@unittest.skipIf(TEST_WITH_CROSSREF, "generator unsupport triggers assertion error")
def test_sparse_mul(self, device, dtype, coalesced):
@gradcheck_semantics()
def test_sparse_mul(self, device, dtype, coalesced, gradcheck):
# https://github.com/pytorch/pytorch/issues/79914
a = torch.tensor([[0., 1]], dtype=dtype, device=device).to_sparse().requires_grad_(True)
b = torch.tensor([[0., 1]], dtype=dtype, device=device).to_sparse().requires_grad_(True)
@@ -1760,7 +1780,7 @@ def fn(S):
if res.is_sparse:
res = res.to_dense()
return res
gradcheck(fn, (S,), check_sparse_nnz=True)
gradcheck(fn, (S,), check_sparse_nnz=True, masked=True)
else:
S_sum = torch.sparse.sum(S, td)
D_sum = D.sum(td)
@@ -1771,7 +1791,7 @@ def fn(S):
if res.is_sparse:
res = res.to_dense()
return res
gradcheck(fn, (S,), check_sparse_nnz=True)
gradcheck(fn, (S,), check_sparse_nnz=True, masked=True)

nnz = 10
sparse_dims = 2
@@ -3557,9 +3577,9 @@ def fn(D1, D2):
# This is because cuSparse sometimes returns approximate zero values like `~e-323`
# TODO: Check this cuSparse issue.
# This happens when you do chain multiplication `torch.sparse.mm` operations
gradcheck(fn, (a, b), check_sparse_nnz=True, nondet_tol=1e-5)
gradcheck(fn, (a, b), check_sparse_nnz=True, nondet_tol=1e-5, masked=True)
else:
gradcheck(fn, (a, b), check_sparse_nnz=True)
gradcheck(fn, (a, b), check_sparse_nnz=True, masked=True)
grad_with_custom_sparsity_pattern_test_helper(sparse_dims, nnz, shape_a, shape_b)

def test_error_cases():
@@ -4060,7 +4080,8 @@ def fn(x):
check_grad_dtypes=True,
check_sparse_nnz=True,
nondet_tol=op.gradcheck_nondet_tol,
fast_mode=op.gradcheck_fast_mode))
fast_mode=op.gradcheck_fast_mode,
masked=True))


class TestSparseMaskedReductions(TestCase):
@@ -4354,14 +4375,42 @@ def test_generate_simple_inputs(self):
@parametrize("index_dtype", [torch.int32, torch.int64])
def test_to_dense(self, from_layout, device, dtype, index_dtype):
"""
This test tests conversion from any layout to any sparse layout.
This test tests conversion from any layout to strided layout.
"""
for t in self.generate_simple_inputs(
from_layout, device=device, dtype=dtype, index_dtype=index_dtype):
r = t.to_dense()
self.assertEqual(r.layout, torch.strided)
self.assertEqual(r, t)

@all_sparse_layouts('from_layout', include_strided=False)
@dtypes(torch.float64, torch.complex128)
@parametrize("index_dtype", [torch.int64])
@gradcheck_semantics()
def test_gradcheck_to_dense(self, from_layout, device, dtype, index_dtype, gradcheck):
for t in self.generate_simple_inputs(
from_layout, device=device, dtype=dtype, index_dtype=index_dtype):
batch_dim = t.dim() - t.dense_dim() - t.sparse_dim()
if batch_dim > 0:
# TODO: implement batch support in _convert_indices_from_csr_to_coo
continue
t = t.clone().detach().requires_grad_(True)
if is_slow_gradcheck_env() and not gradcheck.masked:
# TODO: remove this if-block when TODO items below are resolved
try:
gradcheck(torch.Tensor.to_dense, t)
except RuntimeError as msg:
# TODO: implement non-masked semantics support in to_dense_backward
with self.assertRaisesRegex(RuntimeError, "Jacobian mismatch"):
gradcheck(torch.Tensor.to_dense, t)
self.skipTest('non-masked semantics not supported')
r = gradcheck(torch.Tensor.to_dense, t)
self.assertTrue(r)

# when the following assert fails, it means that the if-block
# above and the assertFalse test below can be safely removed
self.assertFalse(is_slow_gradcheck_env() and not gradcheck.masked)

@all_sparse_layouts('from_layout', include_strided=True)
@all_sparse_layouts('to_layout', include_strided=False)
@dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
@@ -4629,6 +4678,58 @@ def test_unsupported_backend_error_message(self, mth, layout, device):
with self.assertRaisesRegex(RuntimeError, expected_behaviour[1]):
mth(inp)

@onlyNativeDeviceTypes
@all_sparse_layouts('layout', include_strided=not True)
@dtypes(torch.float64, torch.cdouble)
@parametrize("masked", [subtest(False, name='sparse'), subtest(True, name='masked')])
@parametrize("fast_mode", [subtest(False, name='slow'), subtest(True, name='fast')])
def test_gradcheck_mm(self, layout, dtype, device, masked, fast_mode):
# This function does not check the following cases:
# - batch or hybrid tensors because addmm does not support
# such inputs yet
# - check_forward_ad=True because of the lack of sparse tensor
# support in aten::view_as_real, torch._VF._make_dual, etc.

ref_x = torch.tensor([[1, 2, 0, 0],
[0, 6, 0, 0],
[0, 0, 0, 0],
[13, 14, 0, 15]], dtype=dtype, device=device)
ref_y = torch.tensor([[11, 12, 13, 14],
[21, 22, 23, 24],
[31, 32, 33, 34],
[41, 42, 43, 44]],
dtype=dtype, device=device)

mm = torch.sparse.mm if masked else torch.mm

blocksize = (2, 2) if layout in {torch.sparse_bsr, torch.sparse_bsc} else None
x = ref_x.to_sparse(layout=layout, blocksize=blocksize).requires_grad_(True)
y = ref_y.requires_grad_(True)

if layout is torch.sparse_bsr and not masked or layout is torch.sparse_bsc:
with self.assertRaisesRegex(
RuntimeError,
r"addmm: computation on (CPU|CUDA) is not implemented for Strided \+ Sparse(Bsr|Bsc) @ Strided"):
torch.autograd.gradcheck(mm, (x, y), check_sparse_nnz=True, fast_mode=fast_mode, masked=masked)
self.skipTest('NOT IMPL')
elif layout in {torch.sparse_csc, torch.sparse_bsr, torch.sparse_bsc} and masked:
with self.assertRaisesRegex(
RuntimeError,
r"(sparse_addmm_sparse_backward: unsupported combination of layouts,"
r" grad: Strided, mat1: Sparse(Csc|Bsr|Bsc), mat2: Strided"
r"|addmm: computation on (CPU|CUDA) is not implemented for "
r"Strided \+ Sparse(Csc|Bsr|Bsc) @ Strided without MKL)"):
torch.autograd.gradcheck(mm, (x, y), check_sparse_nnz=True, fast_mode=fast_mode, masked=masked)
self.skipTest('NOT IMPL')
else:
if masked:
r = torch.autograd.gradcheck(mm, (x, y), check_sparse_nnz=True, fast_mode=fast_mode, masked=masked)
else:
# Specifying check_sparse_nnz is unnecessary in
# non-masked/sparse semantics
r = torch.autograd.gradcheck(mm, (x, y), fast_mode=fast_mode, masked=masked)
self.assertTrue(r)


# e.g., TestSparseUnaryUfuncsCPU and TestSparseUnaryUfuncsCUDA
instantiate_device_type_tests(TestSparseUnaryUfuncs, globals(), except_for='meta')

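As a usage note, a hedged sketch of how the new gradcheck_semantics() parametrization is consumed. The helper is copied from the hunk above; the test class, its name, and the instantiate_parametrized_tests call are illustrative additions, not part of the commit.

import functools
import torch
from torch.testing._internal.common_utils import (
    TestCase, run_tests, gradcheck, parametrize, subtest,
    instantiate_parametrized_tests)

# Helper as added to test_sparse.py by this commit.
def gradcheck_semantics(test_name='gradcheck'):
    gradcheck_sparse = functools.partial(gradcheck, masked=False)
    gradcheck_masked = functools.partial(gradcheck, masked=True, check_sparse_nnz=True)
    gradcheck_sparse.masked = False
    gradcheck_masked.masked = True
    return parametrize(test_name, [
        subtest(gradcheck_sparse, name='sparse'),
        subtest(gradcheck_masked, name='masked')])

# Hypothetical consumer: each decorated test is generated twice ('_sparse' and
# '_masked'), and the injected gradcheck partial exposes .masked so the body
# can branch on the active semantics, as the updated tests above do.
class TestGradcheckSemanticsUsage(TestCase):
    @gradcheck_semantics()
    def test_to_dense_example(self, gradcheck):
        x = torch.eye(2, dtype=torch.double).to_sparse().requires_grad_(True)
        if not gradcheck.masked:
            # Mirror the FIXME skips above for ops whose backward currently
            # supports masked semantics only.
            self.skipTest('illustrative: non-masked to_dense support is still being extended')
        self.assertTrue(gradcheck(lambda t: t.to_dense(), (x,), check_batched_grad=False))

instantiate_parametrized_tests(TestGradcheckSemanticsUsage)

if __name__ == '__main__':
    run_tests()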