[sparse] Add padding for dense matrices in semi-structured sparse
Summary:

Currently we have shape constraints in semi-structured sparsity for both CUTLASS and cuSPARSELt.

These shape constraints unfortunately apply to both the dense and sparse
matrices in sparse × dense matmul.

This PR adds support for calling `F.pad` to zero-pad dense matrices to a
supported size and then pull the corresponding rows out of the result
matrix.

We also throw a warning in this case.
The tests have been updated to take a `dense_input_shape` parameter.
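
For context, a minimal sketch of the padding idea described above (illustrative only: the helper name, the use of `torch.mm`, and the minimum row multiple of 8 are assumptions, not the code this PR adds):

```
import torch
import torch.nn.functional as F

def sparse_mm_with_padding(dense_input, sparse_weight, multiple=8):
    # Zero-pad the dense input's rows up to the nearest supported multiple,
    # run the matmul against the 2:4 sparse weight, then slice the padded
    # rows back out of the result. multiple=8 stands in for whatever minimum
    # the backend actually requires.
    rows = dense_input.shape[0]
    pad = -rows % multiple
    if pad:
        # F.pad pads the last dim first: (left, right, top, bottom)
        dense_input = F.pad(dense_input, (0, 0, 0, pad))
    result = torch.mm(dense_input, sparse_weight.t())
    return result[:rows]  # drop the rows that came from the zero padding
```

Under these assumptions, a (1, 128) half-precision input would be padded to (8, 128) before the matmul and only the first row of the result returned.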

Test Plan:
```
python test/test_sparse_semi_structured.py
```

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 4e07920e2eb7108731ed31a80cc5dc8740bfa74e
Pull Request resolved: #110583
jcaip committed Oct 10, 2023
1 parent f10aab0 commit a5ba4f8
Showing 3 changed files with 148 additions and 60 deletions.
8 changes: 5 additions & 3 deletions benchmarks/sparse/benchmark_semi_structured_sparsity.py
@@ -5,7 +5,7 @@
import torch
import torch.utils.benchmark as benchmark
from torch import nn
from torch.sparse import to_sparse_semi_structured
from torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor
from tqdm import tqdm


@@ -47,9 +47,9 @@ def rand_sparse_semi_structured_mask(
.contiguous()
)


#@torch.inference_mode()
def test_linear(m, k, n, dtype, contiguous, backend):
SparseSemiStructuredTensor.fuse_transpose = contiguous
SparseSemiStructuredTensor._FORCE_CUTLASS = backend == "cutlass"
mask = rand_sparse_semi_structured_mask(m, k, dtype=dtype)
sparse_weight = torch.rand(m, k).to(dtype).cuda() * mask
input_tensor = torch.zeros(n, k).to(dtype).cuda()
@@ -61,6 +61,7 @@ def test_linear(m, k, n, dtype, contiguous, backend):
).blocked_autorange()

dense_output = model(input_tensor)
print(dense_output.shape)

# sparsify weights
model.linear.weight = nn.Parameter(
@@ -70,6 +71,7 @@ def test_linear(m, k, n, dtype, contiguous, backend):
)

sparse_output = model(input_tensor)
print(sparse_output.shape)

sparse_measurement = benchmark.Timer(
stmt="model(input_tensor)",
110 changes: 72 additions & 38 deletions test/test_sparse_semi_structured.py
@@ -139,18 +139,18 @@ def test_to_sparse_semi_structured(self, dtype, backend):


@dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
@parametrize("dense_input_shape", [(128, 1), (128, 128)])
@parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
def test_mm_sparse_first_NT(self, dtype, device, backend):
def test_mm_sparse_first_NN(self, dense_input_shape, dtype, device, backend):
"""
Ensure torch.mm(A_sparse, B) is correct for float16 and will throw error for int8
Ensure torch.mm(A_sparse, B.t()) is correct
"""
SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")

A = rand_sparse_semi_structured_mask(128, 128, dtype=dtype)
A_sparse = to_sparse_semi_structured(A)

B = torch.rand((128, 128), device=A_sparse.device).to(dtype)
B = torch.rand(dense_input_shape, device=A_sparse.device).to(dtype)

# Currently we don't support int matmul on GPU, so evaluate on CPU and copy over
if dtype is torch.int8:
@@ -162,7 +162,38 @@ def test_mm_sparse_first_NT(self, dtype, device, backend):
with self.assertRaisesRegex(RuntimeError,
"CUDA error: operation not supported when calling `cusparseLtMatmulDescriptorInit"):
sparse_result = torch.mm(A_sparse, B)
else:
dense_result = torch.mm(A, B)
sparse_result = torch.mm(A_sparse, B)
assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)

@dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
@parametrize("dense_input_shape", [(1, 128), (128, 128)])
@parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
def test_mm_sparse_first_NT(self, dense_input_shape, dtype, device, backend):
"""
Ensure torch.mm(A_sparse, B.t()) is correct for float16/bfloat16
and will throw an error for int8 + padding
"""
SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")

A = rand_sparse_semi_structured_mask(128, 128, dtype=dtype)
A_sparse = to_sparse_semi_structured(A)

B = torch.rand(dense_input_shape, device=A_sparse.device).to(dtype)

# Currently we don't support int matmul on GPU, so evaluate on CPU and copy over
if dtype is torch.int8 and dense_input_shape == (1, 128):
# padding with int8 throws an error because transposing B yields a contiguous output
# and row-row 2:4 sparse @ dense with NN is not supported by cuSPARSELt or CUTLASS.
if backend == "cutlass":
with self.assertRaisesRegex(RuntimeError, "two_four_sgemm_cutlass_dispatch_layouts"):
sparse_result = torch.mm(A_sparse, B.t())
else:
with self.assertRaisesRegex(RuntimeError,
"CUDA error: operation not supported when calling `cusparseLtMatmulDescriptorInit"):
sparse_result = torch.mm(A_sparse, B.t())
elif dtype is torch.int8:
# test transpose
# NOTE: CUTLASS and cuSPARSELt have slightly different int8 behavior.
# CUTLASS will output to an int32 tensor while cuSPARSELt will output to a int8 tensor
@@ -171,25 +202,23 @@ def test_mm_sparse_first_NT(self, dtype, device, backend):
sparse_result = torch.mm(A_sparse, B.t())
assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
else:
dense_result = torch.mm(A, B)
sparse_result = torch.mm(A_sparse, B)
assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
# test transpose
dense_result = torch.mm(A, B.t())
sparse_result = torch.mm(A_sparse, B.t())
assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)

@dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
@parametrize("dense_input_shape", [(1, 128), (128, 128)])
@parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
def test_mm_sparse_first_T(self, dtype, device, backend):
def test_mm_sparse_first_TN(self, dtype, dense_input_shape, device, backend):
"""
Ensure torch.mm(A_sparse.t(), B) throws error
"""
SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
A = rand_sparse_semi_structured_mask(128, 128, dtype=dtype)
A_sparse = to_sparse_semi_structured(A)

B = torch.rand((128, 128), device=A_sparse.device).to(dtype)
B = torch.rand(dense_input_shape, device=A_sparse.device).to(dtype)

with self.assertRaisesRegex(
NotImplementedError,
@@ -198,16 +227,17 @@ def test_mm_sparse_first_T(self, dtype, device, backend):
torch.mm(A_sparse.t(), B)

@dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
@parametrize("dense_input_shape", [(1, 128), (128, 128)])
@parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
def test_mm_sparse_second_T(self, dtype, device, backend):
def test_mm_sparse_second_NT(self, dense_input_shape, dtype, device, backend):
"""
Ensure torch.mm(A, B_sparse.t()) is correct
"""
SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
B = rand_sparse_semi_structured_mask(128, 128, dtype=dtype)
B_sparse = to_sparse_semi_structured(B)

A = torch.rand((128, 128), device=B_sparse.device).to(dtype)
A = torch.rand(dense_input_shape, device=B_sparse.device).to(dtype)

# Currently we don't support int matmul on GPU, so evaluate on CPU and copy over
if dtype is torch.int8:
@@ -220,7 +250,27 @@ def test_mm_sparse_second_T(self, dtype, device, backend):

assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)

def test_cslt_sparse_mm_int8_in_fp16_out(self, device):
@dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
@parametrize("dense_input_shape", [(1, 128), (128, 128)])
@parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
def test_mm_sparse_second_NN(self, dense_input_shape, dtype, device, backend):
"""
Ensure torch.mm(A, B_sparse) throws error
"""
SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
B = rand_sparse_semi_structured_mask(128, 128, dtype=dtype)
B_sparse = to_sparse_semi_structured(B)

A = torch.rand(dense_input_shape, device=B_sparse.device).to(dtype)

with self.assertRaisesRegex(
NotImplementedError,
r"arg1: SparseSemiStructuredTensor\(.*transposed=False",
):
sparse_result = torch.mm(A, B_sparse)

@parametrize("dense_input_shape", [(128, 128)])
def test_cslt_sparse_mm_int8_in_fp16_out(self, dense_input_shape, device):
"""
Test sparse matmul with int8 input and fp16 output for cuSPARSELt
"""
@@ -229,9 +279,9 @@ def test_cslt_sparse_mm_int8_in_fp16_out(self, device):
A = rand_sparse_semi_structured_mask(128, 128, dtype=torch.int8)
A_sparse = to_sparse_semi_structured(A)

B = torch.rand((128, 128), device=A_sparse.device).to(torch.int8)
B = torch.rand(dense_input_shape, device=A_sparse.device).to(torch.int8)

dense_result = torch.mm(A.cpu(), B.t().cpu()).to(device, dtype=torch.float16)
dense_result = torch.mm(A.cpu().to(torch.int64), B.t().cpu().to(torch.int64)).to(device, dtype=torch.float16)
sparse_result = torch._cslt_sparse_mm(A_sparse.compressed_tensor_cusparselt, B.t(), out_dtype=torch.float16)
assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)

@@ -250,33 +300,16 @@ def test_cslt_sparse_mm_int8_in_int32_out(self, device):
sparse_result = torch._cslt_sparse_mm(A_sparse.compressed_tensor_cusparselt, B.t(), out_dtype=torch.int32)
assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)

@dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
@parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
def test_mm_sparse_second_NT(self, dtype, device, backend):
"""
Ensure torch.mm(A, B_sparse) throws error
"""
SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
B = rand_sparse_semi_structured_mask(128, 128, dtype=dtype)
B_sparse = to_sparse_semi_structured(B)

A = torch.rand((128, 128), device=B_sparse.device).to(dtype)

with self.assertRaisesRegex(
NotImplementedError,
r"arg1: SparseSemiStructuredTensor\(.*transposed=False",
):
sparse_result = torch.mm(A, B_sparse)

@parametrize("dense_input_shape", [(1, 128), (128, 128), (64, 128, 128)])
@parametrize("inference_mode", [subtest(True), subtest(False)])
@parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
def test_linear(self, inference_mode, device, backend):
def test_linear(self, dense_input_shape, inference_mode, device, backend):
"""
Test nn.Linear has the same numerics
"""
SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
input = torch.rand(64, 128, 128, device=device).half()
model = nn.Linear(128, 128).to(device).half()
input = torch.rand((dense_input_shape), device=device).half()
model = nn.Linear(128, 256).to(device).half()
m, n = model.weight.shape
mask = rand_sparse_semi_structured_mask(m, n, device=device, dtype=torch.bool)
# set masked weight
@@ -294,14 +327,15 @@ def test_linear(self, inference_mode, device, backend):

assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)

@parametrize("dense_input_shape", [(1, 128), (128, 128), (64, 128, 128)])
@parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
def test_mlp(self, device, backend):
def test_mlp(self, device, dense_input_shape, backend):
SparseSemiStructuredTensor._FORCE_CUTLASS = backend == "cutlass"
input = torch.rand(64, 768, 768, device=device).half()
input = torch.rand(dense_input_shape, device=device).half()
model = (
nn.Sequential(
nn.Linear(768, 3072),
nn.Linear(3072, 768),
nn.Linear(128, 3072),
nn.Linear(3072, 128),
)
.half()
.to(device)
