[sparse] Add padding for dense matrices in semi-structured sparse
Summary:

Currently we have shape constraints in semi-structured sparsity for both
CUTLASS and cuSPARSELt.

These shape constraints unfortunately apply to both the dense and sparse
matrices in sparse @ dense matmul.

This PR adds support for calling `F.pad` to pad dense matrices to the
right size with zeros and then slice out the corresponding rows from the
result matrix.
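
To make the approach concrete, here is a minimal dense-only sketch of the pad-then-slice idea; the (8, 8) minimums and the shapes are made up for the example and stand in for the real per-dtype constraints:

```
import torch
import torch.nn.functional as F

# Hypothetical sizes: a (64, 64) weight and a dense operand whose second
# dimension (20) is not a multiple of the assumed minimum of 8.
W = torch.randn(64, 64)
B = torch.randn(64, 20)

min_rows, min_cols = 8, 8  # illustrative minimums, not the real per-dtype config
pad_m, pad_n = -B.shape[0] % min_rows, -B.shape[1] % min_cols

# Zero-pad the dense matrix up to a supported shape, multiply, then slice the
# padding back off the result; the zero-filled columns contribute nothing.
B_padded = F.pad(B, (0, pad_n, 0, pad_m))
out = torch.mm(W, B_padded)[:, :B.shape[1]]

assert out.shape == (64, 20)
```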

Test Plan:
```
python test/test_sparse_semi_structured.py -k pad
```

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 34916c76e086e6d159202b55903d816942f4fca6
Pull Request resolved: #110583
jcaip committed Oct 5, 2023
1 parent 46a5558 commit 0d27856
Showing 2 changed files with 82 additions and 38 deletions.
45 changes: 33 additions & 12 deletions test/test_sparse_semi_structured.py
@@ -178,23 +178,23 @@ def test_mm_sparse_first_NT(self, dtype, device, backend):
sparse_result = torch.mm(A_sparse, B.t())
assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)

@dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
@parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
def test_mm_sparse_first_T(self, dtype, device, backend):
"""
Ensure torch.mm(A_sparse.t(), B) throws error
"""
SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
A = rand_sparse_semi_structured_mask(128, 128, dtype=dtype)
A_sparse = to_sparse_semi_structured(A)

B = torch.rand((128, 128), device=A_sparse.device).to(dtype)

with self.assertRaisesRegex(
NotImplementedError,
r"arg0: SparseSemiStructuredTensor\(.*transposed=True",
):
torch.mm(A_sparse.t(), B)

@dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
@parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
@@ -277,6 +277,27 @@ def test_linear(self, inference_mode, device, backend):

assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)

@parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
def test_padding_linear(self, device, backend):
"""
Test padding for inputs whose dimensions aren't multiples of 8
"""
SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
input = torch.rand(1, 128, device=device).half()
model = nn.Linear(128, 128).to(device).half()
m, n = model.weight.shape
mask = rand_sparse_semi_structured_mask(m, n, device=device, dtype=torch.bool)
# set masked weight
model.weight = nn.Parameter(model.weight * mask)

dense_result = model(input)

model.weight = nn.Parameter(to_sparse_semi_structured(model.weight))
with torch.inference_mode():
sparse_result = model(input)

assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)

@parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
def test_values(self, backend):
SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
75 changes: 49 additions & 26 deletions torch/sparse/semi_structured.py
@@ -179,21 +179,6 @@ def __init__(
"dtype must be one of: {_DTYPE_TO_SEMI_STRUCTURED_SPARSE_CONFIG}"
)

# check shape
m, n = original_tensor.shape
min_rows = _DTYPE_TO_SEMI_STRUCTURED_SPARSE_CONFIG[
original_tensor.dtype
].min_rows
min_cols = _DTYPE_TO_SEMI_STRUCTURED_SPARSE_CONFIG[
original_tensor.dtype
].min_cols
if m < min_rows or m % min_rows or n < min_cols or n % min_cols:
# TODO in the future we can add in padding to support dimensions that aren't perfect multiples
raise RuntimeError(
f"Error original_tensor.shape {original_tensor.shape} is not supported! "
f"Both dimensions must be larger or equal than and a multiple of ({min_rows}, {min_cols})"
)

compressed_tensor_cusparselt = None
sparse_tensor_cutlass = None
meta_tensor_cutlass = None
@@ -232,6 +217,25 @@ def __repr__(self) -> str: # type: ignore[override]

__torch_function__ = torch._C._disabled_torch_function_impl

@classmethod
def _pad_tensor_for_matmul(cls, original_tensor):
# only works on 2D tensors
assert original_tensor.dim() == 2

# check shape
m, n = original_tensor.shape
min_rows = _DTYPE_TO_SEMI_STRUCTURED_SPARSE_CONFIG[
original_tensor.dtype
].min_rows
min_cols = _DTYPE_TO_SEMI_STRUCTURED_SPARSE_CONFIG[
original_tensor.dtype
].min_cols
if m < min_rows or m % min_rows or n < min_cols or n % min_cols:
# zero-pad up to the next multiple of (min_rows, min_cols)
to_pad_m, to_pad_n = -m % min_rows, -n % min_cols
return torch.nn.functional.pad(original_tensor, (0, to_pad_n, 0, to_pad_m))
else:
return original_tensor


@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs) -> Any:
"""Overload __torch_dispatch__ to use torch._sparse_semi_structured_linear.
@@ -303,47 +307,66 @@ def __torch_dispatch__(cls, func, types, args, kwargs) -> Any:
if func is torch.ops.aten.mm.default:
input_A, input_B = args

# first element sparse
if isinstance(input_A, cls) and not input_A.transposed:
input_B_padded = cls._pad_tensor_for_matmul(input_B)
row, col = input_B.shape
if input_A.compressed_tensor_cusparselt is None:
assert input_A.sparse_tensor_cutlass is not None and input_A.meta_tensor_cutlass is not None
return torch._sparse_semi_structured_linear(
input_B.t(), input_A.sparse_tensor_cutlass, input_A.meta_tensor_cutlass
res = torch._sparse_semi_structured_linear(
input_B_padded.t(), input_A.sparse_tensor_cutlass, input_A.meta_tensor_cutlass
).t()
else:
return torch._cslt_sparse_mm(
input_A.compressed_tensor_cusparselt, input_B, None # type: ignore[arg-type]
res = torch._cslt_sparse_mm(
input_A.compressed_tensor_cusparselt, input_B_padded, None # type: ignore[arg-type]
)
return res[:, :col]

# second element sparse
elif isinstance(input_B, cls) and input_B.transposed:
input_A_padded = cls._pad_tensor_for_matmul(input_A)
row, col = input_A.shape

if input_B.compressed_tensor_cusparselt is None:
assert input_B.sparse_tensor_cutlass is not None and input_B.meta_tensor_cutlass is not None
return torch._sparse_semi_structured_linear(
input_A, input_B.sparse_tensor_cutlass, input_B.meta_tensor_cutlass
res = torch._sparse_semi_structured_linear(
input_A_padded, input_B.sparse_tensor_cutlass, input_B.meta_tensor_cutlass
)
else:
return torch._cslt_sparse_mm(input_B.compressed_tensor_cusparselt, input_A.T, None).t() # type: ignore[arg-type]
res = torch._cslt_sparse_mm(input_B.compressed_tensor_cusparselt, input_A.T, None).t() # type: ignore[arg-type]

return res[:row, :]

# When torch is run with inference mode, pytorch does not decompose torch.ops.aten.linear into a .t() and addmm(),
# so we must match the aten.linear op. In this case, we need to explicitly handle collapsing to 2d matmul
# TODO see if there's a way to force pytorch to decompose the op so we don't have to handle this here.
if func is torch.ops.aten.linear.default:
input_tensor, weight, bias = args
shape = input_tensor.shape

input_tensor_2d = input_tensor.view(-1, shape[-1])
row, col = input_tensor_2d.shape
# this is a no-op if the input already has a valid shape
input_tensor_2d_padded = cls._pad_tensor_for_matmul(input_tensor_2d)

if isinstance(weight, cls):
if weight.compressed_tensor_cusparselt is None:
assert weight.sparse_tensor_cutlass is not None and weight.meta_tensor_cutlass is not None
return torch._sparse_semi_structured_linear(
input_tensor,
res = torch._sparse_semi_structured_linear(
input_tensor_2d_padded,
weight.sparse_tensor_cutlass,
weight.meta_tensor_cutlass,
bias=bias
)
else:
return torch._cslt_sparse_mm(
res = torch._cslt_sparse_mm(
weight.compressed_tensor_cusparselt, # type: ignore[arg-type]
input_tensor.view(-1, shape[-1]).t(),
input_tensor_2d_padded.t(),
bias
).t()

return res[:row, :].view(*shape[:-1], -1)

# handle values
if func is torch.ops.aten.values.default:
if args[0].compressed_tensor_cusparselt is None:
