[sparse] Add padding for dense matrices in semi-structured sparse (#110583)

Summary:

Currently we have shape constraints in semi-structured sparsity for both
CUTLASS and cuSPARSELt.

These shape constraints unfortunately apply to both the dense and sparse
matrices in sparse @ dense matmul.

This PR adds support for calling `F.pad` to pad dense matrices to the
right size with zeros, and then pulls out the corresponding rows from the
result matrix (see the sketch below). We also throw a warning in this case.

The tests have also been updated to take a dense_input_shape parameter.
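
A minimal sketch of the padding idea, assuming a hypothetical `min_rows`
alignment requirement (the real constraint depends on dtype and backend, and
this is not the exact code added in `torch/sparse`):

```
import warnings

import torch
import torch.nn.functional as F


def pad_dense_mm(dense, weight_t, min_rows=8):
    # dense @ weight_t, where the backend requires the dense operand's row
    # count to be a multiple of `min_rows` (an assumed value for illustration).
    nrows = dense.shape[0]
    pad = (-nrows) % min_rows
    if pad:
        warnings.warn("dense matrix padded with zeros to satisfy sparse shape constraints")
        # F.pad pads the last dim first: (left, right, top, bottom) for a 2D
        # tensor, so this appends `pad` zero rows at the bottom.
        dense = F.pad(dense, (0, 0, 0, pad))
    result = torch.mm(dense, weight_t)
    # pull out only the rows corresponding to the original (unpadded) input
    return result[:nrows]


# e.g. a (5, 128) activation is padded to (8, 128); the output is sliced back to 5 rows
out = pad_dense_mm(torch.randn(5, 128), torch.randn(128, 128))
```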

Test Plan:
```
python test/test_sparse_semi_structured.py
```

Reviewers:

Subscribers:

Tasks:

Tags:
Pull Request resolved: #110583
Approved by: https://github.com/alexsamardzic, https://github.com/cpuhrsch
jcaip authored and pytorchmergebot committed Oct 13, 2023
1 parent 2b6f281 commit 8db72a4
Showing 3 changed files with 139 additions and 61 deletions.
6 changes: 4 additions & 2 deletions benchmarks/sparse/benchmark_semi_structured_sparsity.py
@@ -5,7 +5,7 @@
import torch
import torch.utils.benchmark as benchmark
from torch import nn
from torch.sparse import to_sparse_semi_structured
from torch.sparse import SparseSemiStructuredTensor, to_sparse_semi_structured
from tqdm import tqdm


@@ -49,7 +49,7 @@ def rand_sparse_semi_structured_mask(


def test_linear(m, k, n, dtype, contiguous, backend):
SparseSemiStructuredTensor.fuse_transpose = contiguous
SparseSemiStructuredTensor._FORCE_CUTLASS = backend == "cutlass"
mask = rand_sparse_semi_structured_mask(m, k, dtype=dtype)
sparse_weight = torch.rand(m, k).to(dtype).cuda() * mask
input_tensor = torch.zeros(n, k).to(dtype).cuda()
@@ -61,6 +61,7 @@ def test_linear(m, k, n, dtype, contiguous, backend):
).blocked_autorange()

dense_output = model(input_tensor)
print(dense_output.shape)

# sparsify weights
model.linear.weight = nn.Parameter(
@@ -70,6 +71,7 @@ def test_linear(m, k, n, dtype, contiguous, backend):
)

sparse_output = model(input_tensor)
print(sparse_output.shape)

sparse_measurement = benchmark.Timer(
stmt="model(input_tensor)",
114 changes: 74 additions & 40 deletions test/test_sparse_semi_structured.py
@@ -43,7 +43,7 @@
# check if cslt is available for now using this:
# TODO when we add cusparselt as a backend, we can update this to be use torch.cusparselt.is_available()
try:
torch._cslt_compress(torch.ones(128, 128).cuda())
torch._cslt_compress(torch.ones(128, 256).cuda())
SEMI_STRUCTURED_SUPPORTED_BACKENDS.append("cusparselt")
except Exception:
pass
@@ -127,7 +127,7 @@ def setUp(self):
def test_to_sparse_semi_structured(self, dtype, backend):
SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")

A = rand_sparse_semi_structured_mask(128, 128, dtype=dtype)
A = rand_sparse_semi_structured_mask(128, 256, dtype=dtype)
A_sparse = to_sparse_semi_structured(A)

assert A.shape == A_sparse.shape
@@ -139,18 +139,18 @@ def test_to_sparse_semi_structured(self, dtype, backend):


@dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
@parametrize("dense_input_shape", [(128, 1), (128, 64), (128, 128)])
@parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
def test_mm_sparse_first_NT(self, dtype, device, backend):
def test_mm_sparse_first_NN(self, dense_input_shape, dtype, device, backend):
"""
Ensure torch.mm(A_sparse, B) is correct for float16 and will throw error for int8
Ensure torch.mm(A_sparse, B.t()) is correct
"""
SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")

A = rand_sparse_semi_structured_mask(128, 128, dtype=dtype)
A = rand_sparse_semi_structured_mask(256, 128, dtype=dtype)
A_sparse = to_sparse_semi_structured(A)

B = torch.rand((128, 128), device=A_sparse.device).to(dtype)
B = torch.rand(dense_input_shape, device=A_sparse.device).to(dtype)

# Currently we don't support int matmul on GPU, so evaluate on CPU and copy over
if dtype is torch.int8:
@@ -162,33 +162,62 @@ def test_mm_sparse_first_NT(self, dtype, device, backend):
with self.assertRaisesRegex(RuntimeError,
"CUDA error: operation not supported when calling `cusparseLtMatmulDescriptorInit"):
sparse_result = torch.mm(A_sparse, B)
else:
dense_result = torch.mm(A, B)
sparse_result = torch.mm(A_sparse, B)
assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)

@dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
@parametrize("dense_input_shape", [(1, 128), (64, 128), (128, 128)])
@parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
def test_mm_sparse_first_NT(self, dense_input_shape, dtype, device, backend):
"""
Ensure torch.mm(A_sparse, B.t()) is correct for float16/bfloat16
and will throw an error for int8 + padding
"""
SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")

A = rand_sparse_semi_structured_mask(256, 128, dtype=dtype)
A_sparse = to_sparse_semi_structured(A)

B = torch.rand(dense_input_shape, device=A_sparse.device).to(dtype)

# Currently we don't support int matmul on GPU, so evaluate on CPU and copy over
if dtype is torch.int8 and dense_input_shape in {(1, 128), (64, 128)}:
# padding with int8 throws an error because transposing B yields a contiguous output
# and row-row 2:4 sparse @ dense with NN is not supported by cuSPARSELt or CUTLASS.
if backend == "cutlass":
with self.assertRaisesRegex(RuntimeError, "two_four_sgemm_cutlass_dispatch_layouts"):
sparse_result = torch.mm(A_sparse, B.t())
else:
with self.assertRaisesRegex(RuntimeError,
"CUDA error: operation not supported when calling `cusparseLtMatmulDescriptorInit"):
sparse_result = torch.mm(A_sparse, B.t())
elif dtype is torch.int8:
# test transpose
# NOTE: CUTLASS and cuSPARSELt have slightly different int8 behavior.
# CUTLASS will output to an int32 tensor while cuSPARSELt will output to a int8 tensor
dense_result = torch.mm(A.cpu(), B.t().cpu()).to(device, dtype=torch.int32 if backend == "cutlass" else torch.int8)
sparse_result = torch.mm(A_sparse, B.t())
assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
else:
dense_result = torch.mm(A, B)
sparse_result = torch.mm(A_sparse, B)
assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
# test transpose
dense_result = torch.mm(A, B.t())
sparse_result = torch.mm(A_sparse, B.t())
assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)

@dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
@parametrize("dense_input_shape", [(1, 128), (64, 128), (128, 128)])
@parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
def test_mm_sparse_first_T(self, dtype, device, backend):
def test_mm_sparse_first_TN(self, dtype, dense_input_shape, device, backend):
"""
Ensure torch.mm(A_sparse.t(), B) throws error
"""
SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
A = rand_sparse_semi_structured_mask(128, 128, dtype=dtype)
A = rand_sparse_semi_structured_mask(128, 256, dtype=dtype)
A_sparse = to_sparse_semi_structured(A)

B = torch.rand((128, 128), device=A_sparse.device).to(dtype)
B = torch.rand(dense_input_shape, device=A_sparse.device).to(dtype)

with self.assertRaisesRegex(
NotImplementedError,
@@ -197,16 +226,17 @@ def test_mm_sparse_first_T(self, dtype, device, backend):
torch.mm(A_sparse.t(), B)

@dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
@parametrize("dense_input_shape", [(1, 128), (64, 128), (128, 128)])
@parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
def test_mm_sparse_second_T(self, dtype, device, backend):
def test_mm_sparse_second_NT(self, dense_input_shape, dtype, device, backend):
"""
Ensure torch.mm(A, B_sparse.t()) is correct
"""
SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
B = rand_sparse_semi_structured_mask(128, 128, dtype=dtype)
B = rand_sparse_semi_structured_mask(256, 128, dtype=dtype)
B_sparse = to_sparse_semi_structured(B)

A = torch.rand((128, 128), device=B_sparse.device).to(dtype)
A = torch.rand(dense_input_shape, device=B_sparse.device).to(dtype)

# Currently we don't support int matmul on GPU, so evaluate on CPU and copy over
if dtype is torch.int8:
@@ -218,48 +248,51 @@ def test_mm_sparse_second_T(self, dtype, device, backend):

assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)

def test_cslt_sparse_mm_int8_in_fp16_out(self, device):
"""
This test is only needed for cuSPARSELt
"""
if "cusparselt" in SEMI_STRUCTURED_SUPPORTED_BACKENDS:
SparseSemiStructuredTensor._FORCE_CUTLASS = False
A = rand_sparse_semi_structured_mask(128, 128, dtype=torch.int8)
A_sparse = to_sparse_semi_structured(A)

B = torch.rand((128, 128), device=A_sparse.device).to(torch.int8)

dense_result = torch.mm(A.cpu(), B.t().cpu()).to(device, dtype=torch.float16)
sparse_result = torch._cslt_sparse_mm(A_sparse.compressed_tensor_cusparselt, B.t(), out_dtype=torch.float16)
assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)

@dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
@parametrize("dense_input_shape", [(1, 128), (64, 128), (128, 128)])
@parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
def test_mm_sparse_second_NT(self, dtype, device, backend):
def test_mm_sparse_second_NN(self, dense_input_shape, dtype, device, backend):
"""
Ensure torch.mm(A, B_sparse) throws error
"""
SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
B = rand_sparse_semi_structured_mask(128, 128, dtype=dtype)
B = rand_sparse_semi_structured_mask(256, 128, dtype=dtype)
B_sparse = to_sparse_semi_structured(B)

A = torch.rand((128, 128), device=B_sparse.device).to(dtype)
A = torch.rand(dense_input_shape, device=B_sparse.device).to(dtype)

with self.assertRaisesRegex(
NotImplementedError,
r"arg1: SparseSemiStructuredTensor\(.*transposed=False",
):
sparse_result = torch.mm(A, B_sparse)

@parametrize("dense_input_shape", [(128, 128)])
def test_cslt_sparse_mm_int8_in_fp16_out(self, dense_input_shape, device):
"""
This test is only needed for cuSPARSELt
"""
if "cusparselt" in SEMI_STRUCTURED_SUPPORTED_BACKENDS:
SparseSemiStructuredTensor._FORCE_CUTLASS = False
A = rand_sparse_semi_structured_mask(128, 256, dtype=torch.int8)
A_sparse = to_sparse_semi_structured(A)

B = torch.rand(dense_input_shape, device=A_sparse.device).to(torch.int8)

dense_result = torch.mm(A.cpu().to(torch.int64), B.t().cpu().to(torch.int64)).to(device, dtype=torch.float16)
sparse_result = torch._cslt_sparse_mm(A_sparse.compressed_tensor_cusparselt, B.t(), out_dtype=torch.float16)
assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)

@parametrize("dense_input_shape", [(1, 128), (64, 128), (128, 128), (64, 128, 128)])
@parametrize("inference_mode", [subtest(True), subtest(False)])
@parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
def test_linear(self, inference_mode, device, backend):
def test_linear(self, dense_input_shape, inference_mode, device, backend):
"""
Test nn.Linear has the same numerics
"""
SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
input = torch.rand(64, 128, 128, device=device).half()
model = nn.Linear(128, 128).to(device).half()
input = torch.rand((dense_input_shape), device=device).half()
model = nn.Linear(128, 256).to(device).half()
m, n = model.weight.shape
mask = rand_sparse_semi_structured_mask(m, n, device=device, dtype=torch.bool)
# set masked weight
@@ -277,14 +310,15 @@ def test_linear(self, inference_mode, device, backend):

assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)

@parametrize("dense_input_shape", [(1, 128), (64, 128), (128, 128), (64, 128, 128)])
@parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
def test_mlp(self, device, backend):
def test_mlp(self, device, dense_input_shape, backend):
SparseSemiStructuredTensor._FORCE_CUTLASS = backend == "cutlass"
input = torch.rand(64, 768, 768, device=device).half()
input = torch.rand(dense_input_shape, device=device).half()
model = (
nn.Sequential(
nn.Linear(768, 3072),
nn.Linear(3072, 768),
nn.Linear(128, 256),
nn.Linear(256, 128),
)
.half()
.to(device)
