diff --git a/test/test_sparse_semi_structured.py b/test/test_sparse_semi_structured.py
index 31d294f248786..2d2ccb31eb324 100644
--- a/test/test_sparse_semi_structured.py
+++ b/test/test_sparse_semi_structured.py
@@ -127,8 +127,8 @@ def setUp(self):
     def tearDown(self):
         super().tearDown()
 
-    @unittest.skipIf(IS_WINDOWS, "torch.compile not support on windows")
-    def test_mlp_contiguous_relu_compile(self):
+    @staticmethod
+    def _test_mlp_contiguous_relu_compile(backend, dense_input_shape):
         """
         Test nn.Linear + .contiguous() + nn.ReLU with SparseSemiStructuredTensor + torch.compile
         We expect:
@@ -146,28 +146,47 @@ def forward(self, x):
                 x = x.contiguous()
                 return torch.nn.functional.relu(x)
 
-        def _test_mlp_contiguous_relu_compile(backend, dense_input_shape):
-            SparseSemiStructuredTensor._FORCE_CUTLASS = backend == "cutlass"
-            input = torch.rand(dense_input_shape, device="cuda").half()
+        SparseSemiStructuredTensor._FORCE_CUTLASS = backend == "cutlass"
 
-            model = Model().eval().cuda().half()
-            mod_linear = model.linear
-            m, n = mod_linear.weight.shape
-            mask = torch.Tensor([1, 0, 0, 1]).tile((m, n // 4)).bool().cuda()
-            # set masked weight
-            mod_linear.weight = nn.Parameter(mod_linear.weight * mask)
+        input = torch.rand(dense_input_shape, device="cuda").half()
+        model = Model().eval().cuda().half()
+        mod_linear = model.linear
+        m, n = mod_linear.weight.shape
+        mask = torch.Tensor([1, 0, 0, 1]).tile((m, n // 4)).bool().cuda()
+        # set masked weight
+        mod_linear.weight = nn.Parameter(mod_linear.weight * mask)
 
-            dense_result = model(input)
-            mod_linear.weight = nn.Parameter(to_sparse_semi_structured(mod_linear.weight))
+        dense_result = model(input)
+        mod_linear.weight = nn.Parameter(to_sparse_semi_structured(mod_linear.weight))
+        sparse_result = model(input)
 
-            model = torch.compile(model)
-            sparse_result = model(input)
+        model = torch.compile(model, backend="inductor", fullgraph=True)
+        sparse_compile_result = model(input)
 
-            assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
+        # test that sparse_compile_result and dense_result are numerically close
+        assert torch.allclose(dense_result, sparse_compile_result, rtol=1e-3, atol=1e-3)
+        # assert sparse and sparse_compile have the same strides,
+        # as meta registrations may return contiguous tensors when the output is transposed
+        # https://github.com/pytorch/pytorch/pull/114477
+        assert sparse_result.stride() == sparse_compile_result.stride()
 
-        for backend in SEMI_STRUCTURED_SUPPORTED_BACKENDS:
-            for dense_input_shape in [(1, 128), (64, 128), (128, 128), (64, 128, 128)]:
-                _test_mlp_contiguous_relu_compile(backend, dense_input_shape)
+    @unittest.skipIf(IS_WINDOWS, "torch.compile not support on windows")
+    @unittest.skipIf("cusparselt" not in SEMI_STRUCTURED_SUPPORTED_BACKENDS, "cusparselt not supported on this machine")
+    def test_mlp_contiguous_relu_compile_cusparselt(self):
+        """
+        test for cuSPARSELt meta registrations (_cslt_sparse_mm) + torch.compile
+        """
+        for dense_input_shape in [(1, 128), (64, 128), (128, 128), (64, 128, 128)]:
+            SparseSemiStructuredTensorCompileTest._test_mlp_contiguous_relu_compile("cusparselt", dense_input_shape)
+
+    @unittest.skipIf(IS_WINDOWS, "torch.compile not support on windows")
+    @unittest.skipIf("cutlass" not in SEMI_STRUCTURED_SUPPORTED_BACKENDS, "cutlass not supported on this machine")
+    def test_mlp_contiguous_relu_compile_cutlass(self):
+        """
+        test for CUTLASS meta registrations (_sparse_semi_structured_linear) + torch.compile
+        """
+        for dense_input_shape in [(1, 128), (64, 128), (128, 128), (64, 128, 128)]:
+            SparseSemiStructuredTensorCompileTest._test_mlp_contiguous_relu_compile("cutlass", dense_input_shape)
 
 
 class TestSparseSemiStructured(TestCase):
diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py
index bb10a34c4c064..752879eaea3bb 100644
--- a/torch/_meta_registrations.py
+++ b/torch/_meta_registrations.py
@@ -394,6 +394,43 @@ def meta_unsqueeze_(self, dim):
     return self
 
 
+@register_meta(aten._cslt_sparse_mm)
+def meta__cslt_sparse_mm(
+    compressed_A: torch.Tensor,
+    dense_B: torch.Tensor,
+    bias: Optional[Tensor] = None,
+    out_dtype: Optional[torch.dtype] = None,
+    transpose_result: bool = False,
+):
+    """
+    Meta registration for aten._cslt_sparse_mm: validates dtypes/shapes and
+    returns an empty tensor with the output shape and dtype, without computing.
+    The result is (m, n), or (n, m) when transpose_result is True, where
+    dense_B is (k, n) and m is derived from compressed_A.numel().
+    """
+    assert dense_B.dtype in {torch.float16, torch.bfloat16, torch.int8}, "_cslt_sparse_mm only supports fp16, bf16, and int8"
+    assert compressed_A.dtype == dense_B.dtype, "inputs must have the same dtype"
+    assert len(dense_B.shape) == 2, "_cslt_sparse_mm only supports 2d inputs"
+
+    is_int8_input_type = compressed_A.dtype == torch.int8
+    # m is recovered assuming compressed_A has m * k * compression_factor / 16 elements
+    compression_factor = 10 if is_int8_input_type else 9
+    k = dense_B.size(0)
+    n = dense_B.size(1)
+    m = (compressed_A.numel() * 16) // (compression_factor * k)
+    if bias is not None:
+        assert m == bias.size(0)
+
+    # out_dtype only changes the result dtype when it differs from the input
+    # dtype; the only such combination accepted is int8 inputs -> fp16 output
+    mixed_dtype = out_dtype is not None and compressed_A.dtype != out_dtype
+    if mixed_dtype:
+        assert is_int8_input_type and out_dtype == torch.float16, "out_dtype is only supported for i8i8->fp16 matmul"
+    output_shape = (n, m) if transpose_result else (m, n)
+    result = dense_B.new_empty(output_shape, dtype=out_dtype if mixed_dtype else None)
+    return result
+
+
 @register_meta(aten.index_reduce.default)
 def meta_index_reduce(
     self: Tensor,