Implement sparse semantics support in gradcheck (2nd try) #95405

Closed · wants to merge 4 commits

Changes from all commits
22 changes: 10 additions & 12 deletions test/test_autograd.py
@@ -4651,7 +4651,7 @@ def fn(sparse):
check_batched_grad=False, fast_mode=fast_mode)
with self.assertRaisesRegex(RuntimeError, 'gradcheck expects all tensor inputs are dense'):
gradcheck(fn, torch.rand(10, dtype=torch.double).to_sparse().requires_grad_(True), check_sparse_nnz=False,
check_batched_grad=False, fast_mode=fast_mode)
check_batched_grad=False, fast_mode=fast_mode, masked=True)
check(fast_mode=True)
check(fast_mode=False)

@@ -4665,8 +4665,8 @@ def fn(sparse_csr):

with self.assertRaisesRegex(RuntimeError, 'gradcheck expects all tensor inputs are dense'):
gradcheck(fn, torch.rand(2, 2, dtype=torch.double).to_sparse_csr().requires_grad_(True), check_sparse_nnz=False,
check_batched_grad=False, fast_mode=fast_mode)
# check(fast_mode=True) # RuntimeError: sparse_mask_sparse_csr expects self to be 2D
check_batched_grad=False, fast_mode=fast_mode, masked=True)
check(fast_mode=True)
check(fast_mode=False)

def test_gradcheck_sparse_csc_input(self):
@@ -4679,8 +4679,8 @@ def fn(sparse_csc):

with self.assertRaisesRegex(RuntimeError, 'gradcheck expects all tensor inputs are dense'):
gradcheck(fn, torch.rand(2, 2, dtype=torch.double).to_sparse_csc().requires_grad_(True), check_sparse_nnz=False,
check_batched_grad=False, fast_mode=fast_mode)
# check(fast_mode=True) # RuntimeError: Expected result Tensor to be of format CSR
check_batched_grad=False, fast_mode=fast_mode, masked=True)
check(fast_mode=True)
check(fast_mode=False)

def test_gradcheck_sparse_bsr_input(self):
@@ -4693,9 +4693,8 @@ def fn(sparse_bsr):

with self.assertRaisesRegex(RuntimeError, 'gradcheck expects all tensor inputs are dense'):
gradcheck(fn, torch.rand(2, 2, dtype=torch.double).to_sparse_bsr((2, 2)).requires_grad_(True),
check_sparse_nnz=False, check_batched_grad=False, fast_mode=fast_mode)
# RuntimeError: "empty_sparse_compressed" expected sparse compressed (non-block) tensor layout but got SparseBsr
# check(fast_mode=True)
check_sparse_nnz=False, check_batched_grad=False, fast_mode=fast_mode, masked=True)
check(fast_mode=True)
check(fast_mode=False)

def test_gradcheck_sparse_bsc_input(self):
@@ -4708,9 +4707,8 @@ def fn(sparse_bsc):

with self.assertRaisesRegex(RuntimeError, 'gradcheck expects all tensor inputs are dense'):
gradcheck(fn, torch.rand(2, 2, dtype=torch.double).to_sparse_bsc((2, 2)).requires_grad_(True),
check_sparse_nnz=False, check_batched_grad=False, fast_mode=fast_mode)
# RuntimeError: "empty_sparse_compressed" expected sparse compressed (non-block) tensor layout but got SparseBsc
# check(fast_mode=True)
check_sparse_nnz=False, check_batched_grad=False, fast_mode=fast_mode, masked=True)
check(fast_mode=True)
check(fast_mode=False)

def test_gradcheck_nondeterministic(self):
@@ -4746,7 +4744,7 @@ def check(fast_mode):
x = torch.rand(10, requires_grad=True).to_sparse()
with self.assertRaisesRegex(RuntimeError, 'dense when check_sparse_nnz is set to False.'):
gradcheck(lambda x: x.to_dense(), (x,), check_sparse_nnz=False, check_batched_grad=False,
fast_mode=fast_mode)
fast_mode=fast_mode, masked=True)
self.assertFalse(gradcheck(lambda x: x.to_dense(), (x,), check_sparse_nnz=False,
check_batched_grad=False, raise_exception=False, fast_mode=fast_mode))
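
For reference, a standalone sketch of the error path these updated test_autograd.py checks exercise (the `masked` keyword is the one this PR introduces; the input tensor is illustrative and not taken from the diff):

```python
import torch

x = torch.rand(10, dtype=torch.double).to_sparse().requires_grad_(True)
try:
    torch.autograd.gradcheck(lambda t: t.to_dense(), (x,),
                             check_sparse_nnz=False, masked=True,
                             check_batched_grad=False)
except RuntimeError as e:
    # Expected to state that gradcheck expects all tensor inputs to be dense
    # when check_sparse_nnz is set to False.
    print(e)
```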

129 changes: 116 additions & 13 deletions test/test_sparse.py
@@ -10,7 +10,7 @@
from torch.testing._internal.common_utils import TestCase, run_tests, skipIfRocm, do_test_dtypes, \
load_tests, TEST_NUMPY, TEST_SCIPY, IS_WINDOWS, gradcheck, coalescedonoff, \
DeterministicGuard, first_sample, TEST_WITH_CROSSREF, TEST_WITH_ROCM, skipIfTorchDynamo, \
parametrize, subtest, is_coalesced_indices, suppress_warnings
parametrize, subtest, is_coalesced_indices, suppress_warnings, is_slow_gradcheck_env
from torch.testing._internal.common_cuda import TEST_CUDA, _get_torch_cuda_version
from numbers import Number
from typing import Dict, Any
@@ -58,6 +58,15 @@ def all_sparse_layouts(test_name='layout', include_strided=False):
subtest(torch.sparse_bsc, name='SparseBSC'),
][(0 if include_strided else 1):])

def gradcheck_semantics(test_name='gradcheck'):
gradcheck_sparse = functools.partial(gradcheck, masked=False)
gradcheck_masked = functools.partial(gradcheck, masked=True, check_sparse_nnz=True)
gradcheck_sparse.masked = False
gradcheck_masked.masked = True
return parametrize(test_name, [
subtest(gradcheck_sparse, name='sparse'),
subtest(gradcheck_masked, name='masked')])


class CrossRefSparseFakeMode(torch._subclasses.CrossRefFakeMode):
def __init__(self):
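
A note on the `gradcheck_semantics` helper added above: it runs each decorated test twice, once per semantics. Below is a minimal sketch of the two flavors it parametrizes over, spelled out against `torch.autograd.gradcheck` directly (the `masked` keyword is the one this PR introduces; the tensor is illustrative):

```python
import functools
import torch

# Non-masked ("sparse") semantics: unspecified elements behave as zeros.
gradcheck_sparse = functools.partial(torch.autograd.gradcheck, masked=False)
# Masked semantics: only the specified elements participate in differentiation.
gradcheck_masked = functools.partial(torch.autograd.gradcheck, masked=True, check_sparse_nnz=True)

x = torch.eye(2, dtype=torch.float64).to_sparse().requires_grad_(True)

# Masked gradcheck of to_dense is expected to pass.
assert gradcheck_masked(torch.Tensor.to_dense, (x,), check_batched_grad=False)

# The non-masked flavor may still fail for to_dense in the slow-gradcheck
# environment, which is what the FIXME skips below work around.
```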
@@ -402,7 +411,11 @@ def test_ctor_size_checks(self, device, dtype):

@dtypes(*floating_and_complex_types_and(torch.float16, torch.bfloat16))
@unittest.skipIf(TEST_WITH_CROSSREF, "generator unsupport triggers assertion error")
def test_to_dense(self, device, dtype):
@gradcheck_semantics()
def test_to_dense_with_gradcheck(self, device, dtype, gradcheck):
if not gradcheck.masked and is_slow_gradcheck_env():
self.skipTest('FIXME: to_dense_backward supports masked semantics only')

def test_tensor(x, res):
x.to_dense() # Tests triple to_dense for memory corruption
x.to_dense()
@@ -535,7 +548,11 @@ def test_shared(self, device, dtype):

@dtypes(torch.double, torch.cdouble)
@unittest.skipIf(TEST_WITH_CROSSREF, "generator unsupport triggers assertion error")
def test_to_dense_hybrid(self, device, dtype):
@gradcheck_semantics()
def test_to_dense_hybrid(self, device, dtype, gradcheck):
if not gradcheck.masked and is_slow_gradcheck_env():
self.skipTest('FIXME: to_dense_backward supports masked semantics only')

def test_tensor(x, res):
x.to_dense() # Tests double to_dense for memory corruption
x.to_dense()
@@ -889,7 +906,10 @@ def test_shape(sparse_dims, nnz, with_size):
@coalescedonoff
@dtypes(torch.double, torch.cdouble)
@unittest.skipIf(TEST_WITH_CROSSREF, "generator unsupport triggers assertion error")
def test_permute(self, device, dtype, coalesced):
@gradcheck_semantics()
def test_permute(self, device, dtype, coalesced, gradcheck):
if not gradcheck.masked and is_slow_gradcheck_env():
self.skipTest('FIXME: to_dense_backward supports masked semantics only')
# trivial checks
s = torch.rand(3, 3, 3, device=device, dtype=dtype).to_sparse()
with self.assertRaisesRegex(RuntimeError, "does not match the length"):
Expand Down Expand Up @@ -1513,7 +1533,8 @@ def test_shape(di, dj, dk, nnz):
@coalescedonoff
@unittest.skip("See https://github.com/pytorch/pytorch/issues/73145")
@dtypes(torch.double, torch.cdouble, torch.bfloat16)
def test_sparse_addmm(self, device, dtype, coalesced):
@gradcheck_semantics()
def test_sparse_addmm(self, device, dtype, coalesced, gradcheck):
def test_shape(m, n, p, nnz, broadcast, alpha_beta=None):
if alpha_beta is None:
alpha = random.random()
@@ -1560,15 +1581,16 @@ def test_shape(d1, d2, d3, nnz, transposed):

def fn(S, D):
return torch.sparse.mm(S, D)
gradcheck(fn, (S, D), check_sparse_nnz=True)
gradcheck(fn, (S, D), check_sparse_nnz=True, masked=True)

test_shape(7, 8, 9, 20, False)
test_shape(7, 8, 9, 20, True)

@coalescedonoff
@dtypes(torch.double)
@unittest.skipIf(TEST_WITH_CROSSREF, "generator unsupport triggers assertion error")
def test_sparse_mul(self, device, dtype, coalesced):
@gradcheck_semantics()
def test_sparse_mul(self, device, dtype, coalesced, gradcheck):
# https://github.com/pytorch/pytorch/issues/79914
a = torch.tensor([[0., 1]], dtype=dtype, device=device).to_sparse().requires_grad_(True)
b = torch.tensor([[0., 1]], dtype=dtype, device=device).to_sparse().requires_grad_(True)
@@ -1760,7 +1782,7 @@ def fn(S):
if res.is_sparse:
res = res.to_dense()
return res
gradcheck(fn, (S,), check_sparse_nnz=True)
gradcheck(fn, (S,), check_sparse_nnz=True, masked=True)
else:
S_sum = torch.sparse.sum(S, td)
D_sum = D.sum(td)
@@ -1771,7 +1793,7 @@ def fn(S):
if res.is_sparse:
res = res.to_dense()
return res
gradcheck(fn, (S,), check_sparse_nnz=True)
gradcheck(fn, (S,), check_sparse_nnz=True, masked=True)

nnz = 10
sparse_dims = 2
@@ -3557,9 +3579,9 @@ def fn(D1, D2):
# This is because cuSparse sometimes returns approximate zero values like `~e-323`
# TODO: Check this cuSparse issue.
# This happens when you do chain multiplication `torch.sparse.mm` operations
gradcheck(fn, (a, b), check_sparse_nnz=True, nondet_tol=1e-5)
gradcheck(fn, (a, b), check_sparse_nnz=True, nondet_tol=1e-5, masked=True)
else:
gradcheck(fn, (a, b), check_sparse_nnz=True)
gradcheck(fn, (a, b), check_sparse_nnz=True, masked=True)
grad_with_custom_sparsity_pattern_test_helper(sparse_dims, nnz, shape_a, shape_b)

def test_error_cases():
@@ -4060,7 +4082,8 @@ def fn(x):
check_grad_dtypes=True,
check_sparse_nnz=True,
nondet_tol=op.gradcheck_nondet_tol,
fast_mode=op.gradcheck_fast_mode))
fast_mode=op.gradcheck_fast_mode,
masked=True))


class TestSparseMaskedReductions(TestCase):
@@ -4354,14 +4377,42 @@ def test_generate_simple_inputs(self):
@parametrize("index_dtype", [torch.int32, torch.int64])
def test_to_dense(self, from_layout, device, dtype, index_dtype):
"""
This test tests conversion from any layout to any sparse layout.
This test tests conversion from any layout to strided layout.
"""
for t in self.generate_simple_inputs(
from_layout, device=device, dtype=dtype, index_dtype=index_dtype):
r = t.to_dense()
self.assertEqual(r.layout, torch.strided)
self.assertEqual(r, t)

@all_sparse_layouts('from_layout', include_strided=False)
@dtypes(torch.float64, torch.complex128)
@parametrize("index_dtype", [torch.int64])
@gradcheck_semantics()
def test_gradcheck_to_dense(self, from_layout, device, dtype, index_dtype, gradcheck):
for t in self.generate_simple_inputs(
from_layout, device=device, dtype=dtype, index_dtype=index_dtype):
batch_dim = t.dim() - t.dense_dim() - t.sparse_dim()
if batch_dim > 0:
# TODO: implement batch support in _convert_indices_from_csr_to_coo
continue
t = t.clone().detach().requires_grad_(True)
if is_slow_gradcheck_env() and not gradcheck.masked:
# TODO: remove this if-block when TODO items below are resolved
try:
gradcheck(torch.Tensor.to_dense, t)
except RuntimeError as msg:
# TODO: implement non-masked semantics support in to_dense_backward
with self.assertRaisesRegex(RuntimeError, "Jacobian mismatch"):
gradcheck(torch.Tensor.to_dense, t)
self.skipTest('non-masked semantics not supported')
r = gradcheck(torch.Tensor.to_dense, t)
self.assertTrue(r)

# when the following assert fails, it means that the if-block
# above and the assertFalse test below can be safely removed
self.assertFalse(is_slow_gradcheck_env() and not gradcheck.masked)
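
As a small aside on the `batch_dim` computation in `test_gradcheck_to_dense` above, here is an illustrative breakdown for a hybrid COO tensor (this snippet is not part of the PR; for COO inputs the batch count is always zero, batched inputs come from the compressed layouts that `generate_simple_inputs` can emit):

```python
import torch

i = torch.tensor([[0, 1]])              # indices: (sparse_dim, nnz) = (1, 2)
v = torch.tensor([[1., 2.], [3., 4.]])  # values carry one dense dimension
t = torch.sparse_coo_tensor(i, v, size=(3, 2))

print(t.sparse_dim(), t.dense_dim())              # 1 1
print(t.dim() - t.dense_dim() - t.sparse_dim())   # 0, i.e. no batch dimensions
```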

@all_sparse_layouts('from_layout', include_strided=True)
@all_sparse_layouts('to_layout', include_strided=False)
@dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
@@ -4629,6 +4680,58 @@ def test_unsupported_backend_error_message(self, mth, layout, device):
with self.assertRaisesRegex(RuntimeError, expected_behaviour[1]):
mth(inp)

@onlyNativeDeviceTypes
@all_sparse_layouts('layout', include_strided=not True)
@dtypes(torch.float64, torch.cdouble)
@parametrize("masked", [subtest(False, name='sparse'), subtest(True, name='masked')])
@parametrize("fast_mode", [subtest(False, name='slow'), subtest(True, name='fast')])
def test_gradcheck_mm(self, layout, dtype, device, masked, fast_mode):
# This function does not check the following cases:
# - batch or hybrid tensors because addmm does not support
# such inputs yet
# - check_forward_ad=True because of the lack of sparse tensor
# support in aten::view_as_real, torch._VF._make_dual, etc.

ref_x = torch.tensor([[1, 2, 0, 0],
[0, 6, 0, 0],
[0, 0, 0, 0],
[13, 14, 0, 15]], dtype=dtype, device=device)
ref_y = torch.tensor([[11, 12, 13, 14],
[21, 22, 23, 24],
[31, 32, 33, 34],
[41, 42, 43, 44]],
dtype=dtype, device=device)

mm = torch.sparse.mm if masked else torch.mm

blocksize = (2, 2) if layout in {torch.sparse_bsr, torch.sparse_bsc} else None
x = ref_x.to_sparse(layout=layout, blocksize=blocksize).requires_grad_(True)
y = ref_y.requires_grad_(True)

if layout is torch.sparse_bsr and not masked or layout is torch.sparse_bsc:
with self.assertRaisesRegex(
RuntimeError,
r"addmm: computation on (CPU|CUDA) is not implemented for Strided \+ Sparse(Bsr|Bsc) @ Strided"):
torch.autograd.gradcheck(mm, (x, y), check_sparse_nnz=True, fast_mode=fast_mode, masked=masked)
self.skipTest('NOT IMPL')
elif layout in {torch.sparse_csc, torch.sparse_bsr, torch.sparse_bsc} and masked:
with self.assertRaisesRegex(
RuntimeError,
r"(sparse_addmm_sparse_backward: unsupported combination of layouts,"
r" grad: Strided, mat1: Sparse(Csc|Bsr|Bsc), mat2: Strided"
r"|addmm: computation on (CPU|CUDA) is not implemented for "
r"Strided \+ Sparse(Csc|Bsr|Bsc) @ Strided without MKL)"):
torch.autograd.gradcheck(mm, (x, y), check_sparse_nnz=True, fast_mode=fast_mode, masked=masked)
self.skipTest('NOT IMPL')
else:
if masked:
r = torch.autograd.gradcheck(mm, (x, y), check_sparse_nnz=True, fast_mode=fast_mode, masked=masked)
else:
# Specifying check_sparse_nnz is unnecessary in
# non-masked/sparse semantics
r = torch.autograd.gradcheck(mm, (x, y), fast_mode=fast_mode, masked=masked)
self.assertTrue(r)
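
One more standalone sketch (not part of the diff) of the pairing this test encodes: `torch.sparse.mm` follows masked semantics and is therefore checked with `masked=True` plus `check_sparse_nnz=True`, while `torch.mm` on a sparse operand follows non-masked semantics and, with this PR, needs neither flag. The small COO operands below are illustrative:

```python
import torch

x = torch.tensor([[1., 0.], [0., 2.]], dtype=torch.float64).to_sparse().requires_grad_(True)
y = torch.eye(2, dtype=torch.float64, requires_grad=True)

# Masked semantics: gradients w.r.t. x are restricted to its specified elements.
assert torch.autograd.gradcheck(torch.sparse.mm, (x, y),
                                masked=True, check_sparse_nnz=True)

# Non-masked semantics: unspecified elements of x act as zeros; check_sparse_nnz
# is no longer required for sparse inputs.
assert torch.autograd.gradcheck(torch.mm, (x, y), masked=False)
```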


# e.g., TestSparseUnaryUfuncsCPU and TestSparseUnaryUfuncsCUDA
instantiate_device_type_tests(TestSparseUnaryUfuncs, globals(), except_for='meta')