Commit
[sparse][semi-structured][inductor] meta registrations for _cslt_sparse_mm + additional stride checking in test.

Summary:

This PR adds meta registrations for _cslt_sparse_mm, based on the work drisspg did in #114370.
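
For context (not part of this diff): a meta registration computes output shapes, dtypes, and strides without touching real data, and torch.compile traces through it via fake tensors. A minimal sketch of that mechanism, assuming a recent PyTorch build:

```
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

# Under FakeTensorMode, ops dispatch to their meta implementations,
# so shapes/strides are inferred without allocating or running kernels.
with FakeTensorMode():
    a = torch.empty(64, 128, dtype=torch.float16)
    b = torch.empty(128, 128, dtype=torch.float16)
    out = a @ b  # no real kernel runs; the meta kernel produces a fake result
    print(out.shape, out.stride())  # torch.Size([64, 128]) (128, 1)
```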

Additionally, it updates the tests to check that the strides of the sparse result and the result returned by sparse + compile are the same, to avoid errors like those found in #114477.
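
To make that failure mode concrete, a minimal illustration (not from the PR) of how two results can match numerically and in shape while disagreeing on layout:

```
import torch

# Eager kernel returning a transposed (column-major) result:
eager = torch.empty(128, 64).t()        # shape (64, 128), stride (1, 64)
# Naive meta kernel reporting a contiguous result of the same shape:
meta = torch.empty(64, 128)             # shape (64, 128), stride (128, 1)
assert eager.shape == meta.shape
assert eager.stride() != meta.stride()  # same shape, different layout
```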

Test Plan:
```
python test/test_sparse_semi_structured.py -k compile_cusparselt
python test/test_sparse_semi_structured.py -k compile_cutlass
```


ghstack-source-id: cc03188401100f6f1268844b99d1239d498dcead
Pull Request resolved: #114685
jcaip committed Nov 28, 2023
1 parent cef79c0 commit 7d5e3e9
Showing 2 changed files with 71 additions and 19 deletions.
56 changes: 37 additions & 19 deletions test/test_sparse_semi_structured.py
```
@@ -127,8 +127,8 @@ def setUp(self):
     def tearDown(self):
         super().tearDown()
 
-    @unittest.skipIf(IS_WINDOWS, "torch.compile not support on windows")
-    def test_mlp_contiguous_relu_compile(self):
+    @staticmethod
+    def _test_mlp_contiguous_relu_compile(backend, dense_input_shape):
         """
         Test nn.Linear + .contiguous() + nn.ReLU with SparseSemiStructuredTensor + torch.compile
         We expect:
@@ -146,28 +146,46 @@ def forward(self, x):
                 x = x.contiguous()
                 return torch.nn.functional.relu(x)
 
-        def _test_mlp_contiguous_relu_compile(backend, dense_input_shape):
-            SparseSemiStructuredTensor._FORCE_CUTLASS = backend == "cutlass"
-            input = torch.rand(dense_input_shape, device="cuda").half()
+        SparseSemiStructuredTensor._FORCE_CUTLASS = backend == "cutlass"
 
-            model = Model().eval().cuda().half()
-            mod_linear = model.linear
-            m, n = mod_linear.weight.shape
-            mask = torch.Tensor([1, 0, 0, 1]).tile((m, n // 4)).bool().cuda()
-            # set masked weight
-            mod_linear.weight = nn.Parameter(mod_linear.weight * mask)
+        input = torch.rand(dense_input_shape, device="cuda").half()
+        model = Model().eval().cuda().half()
+        mod_linear = model.linear
+        m, n = mod_linear.weight.shape
+        mask = torch.Tensor([1, 0, 0, 1]).tile((m, n // 4)).bool().cuda()
+        # set masked weight
+        mod_linear.weight = nn.Parameter(mod_linear.weight * mask)
 
-            dense_result = model(input)
-            mod_linear.weight = nn.Parameter(to_sparse_semi_structured(mod_linear.weight))
+        dense_result = model(input)
+        mod_linear.weight = nn.Parameter(to_sparse_semi_structured(mod_linear.weight))
+        sparse_result = model(input)
 
-            model = torch.compile(model, backend="inductor", fullgraph=True)
-            sparse_result = model(input)
+        model = torch.compile(model, backend="inductor", fullgraph=True)
+        sparse_compile_result = model(input)
 
-            assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
+        # test that sparse_compile_result and dense_result are numerically close
+        assert torch.allclose(dense_result, sparse_compile_result, rtol=1e-3, atol=1e-3)
+        # assert sparse and sparse_compile have the same strides,
+        # as meta registrations may return contiguous tensors when the output is transposed
+        # https://github.com/pytorch/pytorch/pull/114477
+        assert sparse_result.stride() == sparse_compile_result.stride()
 
-        for backend in SEMI_STRUCTURED_SUPPORTED_BACKENDS:
-            for dense_input_shape in [(128, 128), (64, 128), (1, 128), (64, 128, 128)]:
-                _test_mlp_contiguous_relu_compile(backend, dense_input_shape)
+    @unittest.skipIf(IS_WINDOWS, "torch.compile not supported on windows")
+    @unittest.skipIf("cusparselt" not in SEMI_STRUCTURED_SUPPORTED_BACKENDS, "cusparselt not supported on this machine")
+    def test_mlp_contiguous_relu_compile_cusparselt(self):
+        """
+        test for cuSPARSELt meta registrations (_cslt_sparse_mm) + torch.compile
+        """
+        for dense_input_shape in [(1, 128), (64, 128), (128, 128), (64, 128, 128)]:
+            SparseSemiStructuredTensorCompileTest._test_mlp_contiguous_relu_compile("cusparselt", dense_input_shape)
 
+    @unittest.skipIf(IS_WINDOWS, "torch.compile not supported on windows")
+    def test_mlp_contiguous_relu_compile_cutlass(self):
+        """
+        test for CUTLASS meta registrations (_sparse_semi_structured_linear) + torch.compile
+        """
+        for dense_input_shape in [(1, 128), (64, 128), (128, 128), (64, 128, 128)]:
+            SparseSemiStructuredTensorCompileTest._test_mlp_contiguous_relu_compile("cutlass", dense_input_shape)
 
 class TestSparseSemiStructured(TestCase):
```
34 changes: 34 additions & 0 deletions torch/_meta_registrations.py
```
@@ -423,6 +423,40 @@ def meta_sparse_structured_linear(
     return output
 
 
+@register_meta(aten._cslt_sparse_mm)
+def meta__cslt_sparse_mm(
+    compressed_A: torch.Tensor,
+    dense_B: torch.Tensor,
+    bias: Optional[Tensor] = None,
+    out_dtype: Optional[torch.dtype] = None,
+    transpose_result: bool = False,
+):
+    assert dense_B.dtype in {
+        torch.float16,
+        torch.bfloat16,
+        torch.int8,
+    }, "_cslt_sparse_mm only supports fp16, bf16, and int8"
+    assert compressed_A.dtype == dense_B.dtype, "inputs must have the same dtype"
+    assert len(dense_B.shape) == 2, "_cslt_sparse_mm only supports 2d inputs"
+
+    is_int8_input_type = compressed_A.dtype == torch.int8
+    compression_factor = 10 if is_int8_input_type else 9
+    k = dense_B.size(0)
+    n = dense_B.size(1)
+    m = (compressed_A.numel() * 16) // (compression_factor * k)
+    if bias is not None:
+        assert m == bias.size(0)
+
+    mixed_dtype = out_dtype is not None and (compressed_A.dtype != out_dtype)
+    if mixed_dtype:
+        assert (
+            is_int8_input_type and out_dtype is torch.float16
+        ), "out_dtype is only supported for i8i8->fp16 matmul"
+    output_shape = (n, m) if transpose_result else (m, n)
+    result = dense_B.new_empty(output_shape, dtype=out_dtype if mixed_dtype else None)
+    return result
+
+
 @register_meta(aten.index_reduce.default)
 def meta_index_reduce(
     self: Tensor,
```
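
For reference, the compression_factor values in the meta kernel above (9 for fp16/bf16, 10 for int8) come from the 2:4 compressed format: each group of 4 elements keeps 2 values plus 2 bits of position metadata per kept value, i.e. 1 bit of metadata per original element. A quick sanity check of that arithmetic (illustrative sketch, not part of the diff):

```
# 2:4 sparsity keeps 2 of every 4 values (8/16 of numel) plus 1 bit of
# metadata per original element, measured in units of the element type:
#   fp16/bf16 (16-bit): 8/16 values + 1/16 metadata -> compressed numel = 9/16  * m * k
#   int8      (8-bit):  8/16 values + 2/16 metadata -> compressed numel = 10/16 * m * k
m, k = 64, 128
for factor in (9, 10):
    compressed_numel = m * k * factor // 16
    # recovers m exactly as the meta kernel does
    assert (compressed_numel * 16) // (factor * k) == m
```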
