From 8db72a430d0c3a7d3388749d5d438fb805f53407 Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Fri, 13 Oct 2023 09:45:35 -0700 Subject: [PATCH] [sparse] Add padding for dense matrices in semi-structured sparse (#110583) Summary: Currently we have shape constraints in semi-structured sparsity for both CUTLASS and cuSPARSELt These shape constraints unfortunately apply to both the dense and sparse matrices in sparsedense matmul. This PR adds in support for calling `F.pad` in order to pad dense matrices to the right size with zeros and then pull out the corresponding rows from the resultant result matrix. We also throw a warning in this case. The tests have also been updated to take in a dense_input_shape parameter. Test Plan: ``` python test/test_sparse_semi_structured.py ``` Reviewers: Subscribers: Tasks: Tags: Pull Request resolved: https://github.com/pytorch/pytorch/pull/110583 Approved by: https://github.com/alexsamardzic, https://github.com/cpuhrsch --- .../benchmark_semi_structured_sparsity.py | 6 +- test/test_sparse_semi_structured.py | 114 ++++++++++++------ torch/sparse/semi_structured.py | 80 +++++++++--- 3 files changed, 139 insertions(+), 61 deletions(-) diff --git a/benchmarks/sparse/benchmark_semi_structured_sparsity.py b/benchmarks/sparse/benchmark_semi_structured_sparsity.py index 78b7df3bba103..59aa63915883d 100644 --- a/benchmarks/sparse/benchmark_semi_structured_sparsity.py +++ b/benchmarks/sparse/benchmark_semi_structured_sparsity.py @@ -5,7 +5,7 @@ import torch import torch.utils.benchmark as benchmark from torch import nn -from torch.sparse import to_sparse_semi_structured +from torch.sparse import SparseSemiStructuredTensor, to_sparse_semi_structured from tqdm import tqdm @@ -49,7 +49,7 @@ def rand_sparse_semi_structured_mask( def test_linear(m, k, n, dtype, contiguous, backend): - SparseSemiStructuredTensor.fuse_transpose = contiguous + SparseSemiStructuredTensor._FORCE_CUTLASS = backend == "cutlass" mask = rand_sparse_semi_structured_mask(m, k, dtype=dtype) sparse_weight = torch.rand(m, k).to(dtype).cuda() * mask input_tensor = torch.zeros(n, k).to(dtype).cuda() @@ -61,6 +61,7 @@ def test_linear(m, k, n, dtype, contiguous, backend): ).blocked_autorange() dense_output = model(input_tensor) + print(dense_output.shape) # sparsify weights model.linear.weight = nn.Parameter( @@ -70,6 +71,7 @@ def test_linear(m, k, n, dtype, contiguous, backend): ) sparse_output = model(input_tensor) + print(sparse_output.shape) sparse_measurement = benchmark.Timer( stmt="model(input_tensor)", diff --git a/test/test_sparse_semi_structured.py b/test/test_sparse_semi_structured.py index 1f0cf17c4c5ea..317a54c3227be 100644 --- a/test/test_sparse_semi_structured.py +++ b/test/test_sparse_semi_structured.py @@ -43,7 +43,7 @@ # check if cslt is available for now using this: # TODO when we add cusparselt as a backend, we can update this to be use torch.cusparselt.is_available() try: - torch._cslt_compress(torch.ones(128, 128).cuda()) + torch._cslt_compress(torch.ones(128, 256).cuda()) SEMI_STRUCTURED_SUPPORTED_BACKENDS.append("cusparselt") except Exception: pass @@ -127,7 +127,7 @@ def setUp(self): def test_to_sparse_semi_structured(self, dtype, backend): SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass") - A = rand_sparse_semi_structured_mask(128, 128, dtype=dtype) + A = rand_sparse_semi_structured_mask(128, 256, dtype=dtype) A_sparse = to_sparse_semi_structured(A) assert A.shape == A_sparse.shape @@ -139,18 +139,18 @@ def test_to_sparse_semi_structured(self, dtype, backend): @dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES) + @parametrize("dense_input_shape", [(128, 1), (128, 64), (128, 128)]) @parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS) - def test_mm_sparse_first_NT(self, dtype, device, backend): + def test_mm_sparse_first_NN(self, dense_input_shape, dtype, device, backend): """ Ensure torch.mm(A_sparse, B) is correct for float16 and will throw error for int8 - Ensure torch.mm(A_sparse, B.t()) is correct """ SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass") - A = rand_sparse_semi_structured_mask(128, 128, dtype=dtype) + A = rand_sparse_semi_structured_mask(256, 128, dtype=dtype) A_sparse = to_sparse_semi_structured(A) - B = torch.rand((128, 128), device=A_sparse.device).to(dtype) + B = torch.rand(dense_input_shape, device=A_sparse.device).to(dtype) # Currently we don't support int matmul on GPU, so evaluate on CPU and copy over if dtype is torch.int8: @@ -162,7 +162,38 @@ def test_mm_sparse_first_NT(self, dtype, device, backend): with self.assertRaisesRegex(RuntimeError, "CUDA error: operation not supported when calling `cusparseLtMatmulDescriptorInit"): sparse_result = torch.mm(A_sparse, B) + else: + dense_result = torch.mm(A, B) + sparse_result = torch.mm(A_sparse, B) + assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3) + + @dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES) + @parametrize("dense_input_shape", [(1, 128), (64, 128), (128, 128)]) + @parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS) + def test_mm_sparse_first_NT(self, dense_input_shape, dtype, device, backend): + """ + Ensure torch.mm(A_sparse, B.t()) is correct for float16/bfloat16 + and will throw an error for int8 + padding + """ + SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass") + + A = rand_sparse_semi_structured_mask(256, 128, dtype=dtype) + A_sparse = to_sparse_semi_structured(A) + + B = torch.rand(dense_input_shape, device=A_sparse.device).to(dtype) + # Currently we don't support int matmul on GPU, so evaluate on CPU and copy over + if dtype is torch.int8 and dense_input_shape in {(1, 128), (64, 128)}: + # padding with int8 throws an error because transposing B yields a contiguous output + # and row-row 2:4 sparse @ dense with NN is not supported by cuSPARSELt or CUTLASS. + if backend == "cutlass": + with self.assertRaisesRegex(RuntimeError, "two_four_sgemm_cutlass_dispatch_layouts"): + sparse_result = torch.mm(A_sparse, B.t()) + else: + with self.assertRaisesRegex(RuntimeError, + "CUDA error: operation not supported when calling `cusparseLtMatmulDescriptorInit"): + sparse_result = torch.mm(A_sparse, B.t()) + elif dtype is torch.int8: # test transpose # NOTE: CUTLASS and cuSPARSELt have slightly different int8 behavior. # CUTLASS will output to an int32 tensor while cuSPARSELt will output to a int8 tensor @@ -170,25 +201,23 @@ def test_mm_sparse_first_NT(self, dtype, device, backend): sparse_result = torch.mm(A_sparse, B.t()) assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3) else: - dense_result = torch.mm(A, B) - sparse_result = torch.mm(A_sparse, B) - assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3) # test transpose dense_result = torch.mm(A, B.t()) sparse_result = torch.mm(A_sparse, B.t()) assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3) @dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES) + @parametrize("dense_input_shape", [(1, 128), (64, 128), (128, 128)]) @parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS) - def test_mm_sparse_first_T(self, dtype, device, backend): + def test_mm_sparse_first_TN(self, dtype, dense_input_shape, device, backend): """ Ensure torch.mm(A_sparse.t(), B) throws error """ SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass") - A = rand_sparse_semi_structured_mask(128, 128, dtype=dtype) + A = rand_sparse_semi_structured_mask(128, 256, dtype=dtype) A_sparse = to_sparse_semi_structured(A) - B = torch.rand((128, 128), device=A_sparse.device).to(dtype) + B = torch.rand(dense_input_shape, device=A_sparse.device).to(dtype) with self.assertRaisesRegex( NotImplementedError, @@ -197,16 +226,17 @@ def test_mm_sparse_first_T(self, dtype, device, backend): torch.mm(A_sparse.t(), B) @dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES) + @parametrize("dense_input_shape", [(1, 128), (64, 128), (128, 128)]) @parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS) - def test_mm_sparse_second_T(self, dtype, device, backend): + def test_mm_sparse_second_NT(self, dense_input_shape, dtype, device, backend): """ Ensure torch.mm(A, B_sparse.t()) is correct """ SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass") - B = rand_sparse_semi_structured_mask(128, 128, dtype=dtype) + B = rand_sparse_semi_structured_mask(256, 128, dtype=dtype) B_sparse = to_sparse_semi_structured(B) - A = torch.rand((128, 128), device=B_sparse.device).to(dtype) + A = torch.rand(dense_input_shape, device=B_sparse.device).to(dtype) # Currently we don't support int matmul on GPU, so evaluate on CPU and copy over if dtype is torch.int8: @@ -218,32 +248,18 @@ def test_mm_sparse_second_T(self, dtype, device, backend): assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3) - def test_cslt_sparse_mm_int8_in_fp16_out(self, device): - """ - This test is only needed for cuSPARSELt - """ - if "cusparselt" in SEMI_STRUCTURED_SUPPORTED_BACKENDS: - SparseSemiStructuredTensor._FORCE_CUTLASS = False - A = rand_sparse_semi_structured_mask(128, 128, dtype=torch.int8) - A_sparse = to_sparse_semi_structured(A) - - B = torch.rand((128, 128), device=A_sparse.device).to(torch.int8) - - dense_result = torch.mm(A.cpu(), B.t().cpu()).to(device, dtype=torch.float16) - sparse_result = torch._cslt_sparse_mm(A_sparse.compressed_tensor_cusparselt, B.t(), out_dtype=torch.float16) - assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3) - @dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES) + @parametrize("dense_input_shape", [(1, 128), (64, 128), (128, 128)]) @parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS) - def test_mm_sparse_second_NT(self, dtype, device, backend): + def test_mm_sparse_second_NN(self, dense_input_shape, dtype, device, backend): """ Ensure torch.mm(A, B_sparse) throws error """ SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass") - B = rand_sparse_semi_structured_mask(128, 128, dtype=dtype) + B = rand_sparse_semi_structured_mask(256, 128, dtype=dtype) B_sparse = to_sparse_semi_structured(B) - A = torch.rand((128, 128), device=B_sparse.device).to(dtype) + A = torch.rand(dense_input_shape, device=B_sparse.device).to(dtype) with self.assertRaisesRegex( NotImplementedError, @@ -251,15 +267,32 @@ def test_mm_sparse_second_NT(self, dtype, device, backend): ): sparse_result = torch.mm(A, B_sparse) + @parametrize("dense_input_shape", [(128, 128)]) + def test_cslt_sparse_mm_int8_in_fp16_out(self, dense_input_shape, device): + """ + This test is only needed for cuSPARSELt + """ + if "cusparselt" in SEMI_STRUCTURED_SUPPORTED_BACKENDS: + SparseSemiStructuredTensor._FORCE_CUTLASS = False + A = rand_sparse_semi_structured_mask(128, 256, dtype=torch.int8) + A_sparse = to_sparse_semi_structured(A) + + B = torch.rand(dense_input_shape, device=A_sparse.device).to(torch.int8) + + dense_result = torch.mm(A.cpu().to(torch.int64), B.t().cpu().to(torch.int64)).to(device, dtype=torch.float16) + sparse_result = torch._cslt_sparse_mm(A_sparse.compressed_tensor_cusparselt, B.t(), out_dtype=torch.float16) + assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3) + + @parametrize("dense_input_shape", [(1, 128), (64, 128), (128, 128), (64, 128, 128)]) @parametrize("inference_mode", [subtest(True), subtest(False)]) @parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS) - def test_linear(self, inference_mode, device, backend): + def test_linear(self, dense_input_shape, inference_mode, device, backend): """ Test nn.Linear has the same numerics """ SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass") - input = torch.rand(64, 128, 128, device=device).half() - model = nn.Linear(128, 128).to(device).half() + input = torch.rand((dense_input_shape), device=device).half() + model = nn.Linear(128, 256).to(device).half() m, n = model.weight.shape mask = rand_sparse_semi_structured_mask(m, n, device=device, dtype=torch.bool) # set masked weight @@ -277,14 +310,15 @@ def test_linear(self, inference_mode, device, backend): assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3) + @parametrize("dense_input_shape", [(1, 128), (64, 128), (128, 128), (64, 128, 128)]) @parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS) - def test_mlp(self, device, backend): + def test_mlp(self, device, dense_input_shape, backend): SparseSemiStructuredTensor._FORCE_CUTLASS = backend == "cutlass" - input = torch.rand(64, 768, 768, device=device).half() + input = torch.rand(dense_input_shape, device=device).half() model = ( nn.Sequential( - nn.Linear(768, 3072), - nn.Linear(3072, 768), + nn.Linear(128, 256), + nn.Linear(256, 128), ) .half() .to(device) diff --git a/torch/sparse/semi_structured.py b/torch/sparse/semi_structured.py index b6f11a816ef78..a5828ef2c8dc8 100644 --- a/torch/sparse/semi_structured.py +++ b/torch/sparse/semi_structured.py @@ -53,7 +53,7 @@ class SparseSemiStructuredTensor(torch.Tensor): _FUSE_TRANSPOSE = False _FORCE_CUTLASS = True - _WARNING_SHOWN = False + _PROTOTYPE_WARNING_SHOWN = False @staticmethod def __new__( @@ -88,7 +88,7 @@ def __new__( """ assert compressed_tensor_cusparselt is None or (sparse_tensor_cutlass is None and meta_tensor_cutlass is None) - if not cls._WARNING_SHOWN: + if not cls._PROTOTYPE_WARNING_SHOWN: warnings.warn( ( "The PyTorch API of SparseSemiStructuredTensor is in prototype stage " @@ -98,7 +98,7 @@ def __new__( ), UserWarning, ) - cls._WARNING_SHOWN = True + cls._PROTOTYPE_WARNING_SHOWN = True if original_tensor is not None: previous_tensor = original_tensor @@ -232,6 +232,26 @@ def __repr__(self) -> str: # type: ignore[override] __torch_function__ = torch._C._disabled_torch_function_impl + def _pad_tensor_for_matmul(self, original_tensor : torch.Tensor) -> torch.Tensor: + """ + Calculates padding for dense tensor and pads tensor if necessary. + If padding is not required, this function returns the original tensor. + """ + # only 2d matmul + assert original_tensor.dim() == 2 + + # check shape + m, n = original_tensor.shape + min_rows = _DTYPE_TO_SEMI_STRUCTURED_SPARSE_CONFIG[original_tensor.dtype].min_rows + min_cols = _DTYPE_TO_SEMI_STRUCTURED_SPARSE_CONFIG[original_tensor.dtype].min_cols + to_pad_m = -m % min_rows if m < min_rows or m % min_rows else 0 + to_pad_n = -n % min_cols if n < min_cols or n % min_rows else 0 + if to_pad_m or to_pad_n: + return torch.nn.functional.pad(original_tensor, (0, to_pad_n, 0, to_pad_m)) + else: + return original_tensor + + @classmethod def __torch_dispatch__(cls, func, types, args, kwargs) -> Any: """Overload __torch_dispatch__ to use torch._sparse_semi_structured_linear. @@ -290,38 +310,52 @@ def __torch_dispatch__(cls, func, types, args, kwargs) -> Any: # F.linear(x) = addmm(bias, input, weight.t()) = b + xW' = (b + xW')'' # = (W''x' + b')' = (Wx' + b')' = addmm(bias.T, weight, input).T if isinstance(input_B, cls) and input_B.transposed: + row, col = input_A.shape + input_A_padded = input_B._pad_tensor_for_matmul(input_A) if input_B.compressed_tensor_cusparselt is None: assert input_B.sparse_tensor_cutlass is not None and input_B.meta_tensor_cutlass is not None - return torch._sparse_semi_structured_linear( - input_A, input_B.sparse_tensor_cutlass, input_B.meta_tensor_cutlass, bias=bias + res = torch._sparse_semi_structured_linear( + input_A_padded, input_B.sparse_tensor_cutlass, input_B.meta_tensor_cutlass, bias=bias ) else: - return torch._cslt_sparse_mm( - input_B.compressed_tensor_cusparselt, input_A.T, bias # type: ignore[arg-type] + res = torch._cslt_sparse_mm( + input_B.compressed_tensor_cusparselt, input_A_padded.t(), bias # type: ignore[arg-type] ).t() + return res[:row, :] # handle mm if func is torch.ops.aten.mm.default: input_A, input_B = args + # first element sparse if isinstance(input_A, cls) and not input_A.transposed: + row, col = input_B.shape + input_B_padded = input_A._pad_tensor_for_matmul(input_B) if input_A.compressed_tensor_cusparselt is None: assert input_A.sparse_tensor_cutlass is not None and input_A.meta_tensor_cutlass is not None - return torch._sparse_semi_structured_linear( - input_B.t(), input_A.sparse_tensor_cutlass, input_A.meta_tensor_cutlass + res = torch._sparse_semi_structured_linear( + input_B_padded.t(), input_A.sparse_tensor_cutlass, input_A.meta_tensor_cutlass ).t() else: - return torch._cslt_sparse_mm( - input_A.compressed_tensor_cusparselt, input_B, None # type: ignore[arg-type] + res = torch._cslt_sparse_mm( + input_A.compressed_tensor_cusparselt, input_B_padded, None # type: ignore[arg-type] ) + return res[:, :col] + + # second element sparse elif isinstance(input_B, cls) and input_B.transposed: + row, col = input_A.shape + input_A_padded = input_B._pad_tensor_for_matmul(input_A) + if input_B.compressed_tensor_cusparselt is None: assert input_B.sparse_tensor_cutlass is not None and input_B.meta_tensor_cutlass is not None - return torch._sparse_semi_structured_linear( - input_A, input_B.sparse_tensor_cutlass, input_B.meta_tensor_cutlass + res = torch._sparse_semi_structured_linear( + input_A_padded, input_B.sparse_tensor_cutlass, input_B.meta_tensor_cutlass ) else: - return torch._cslt_sparse_mm(input_B.compressed_tensor_cusparselt, input_A.T, None).t() # type: ignore[arg-type] + res = torch._cslt_sparse_mm(input_B.compressed_tensor_cusparselt, input_A_padded.t(), None).t() # type: ignore[arg-type] + + return res[:row, :] # When torch is run with inference mode, pytorch does not decompose torch.ops.aten.linear into a .t() and addmm(), # so we must match the aten.linear op. In this case, we need to explicitly handle collapsing to 2d matmul @@ -329,21 +363,29 @@ def __torch_dispatch__(cls, func, types, args, kwargs) -> Any: if func is torch.ops.aten.linear.default: input_tensor, weight, bias = args shape = input_tensor.shape + + input_tensor_2d = input_tensor.view(-1, shape[-1]) + row, col = input_tensor_2d.shape + # this is a noop if already padded + input_tensor_2d_padded = weight._pad_tensor_for_matmul(input_tensor_2d) + if isinstance(weight, cls): if weight.compressed_tensor_cusparselt is None: assert weight.sparse_tensor_cutlass is not None and weight.meta_tensor_cutlass is not None - return torch._sparse_semi_structured_linear( - input_tensor, + res = torch._sparse_semi_structured_linear( + input_tensor_2d_padded, weight.sparse_tensor_cutlass, weight.meta_tensor_cutlass, bias=bias ) else: - return torch._cslt_sparse_mm( + res = torch._cslt_sparse_mm( weight.compressed_tensor_cusparselt, # type: ignore[arg-type] - input_tensor.view(-1, shape[-1]).t(), + input_tensor_2d_padded.t(), bias - ).t().view(*shape[:-1], -1) + ).t() + return res[:row, :].view(*shape[:-1], -1) + # handle values if func is torch.ops.aten.values.default: