From a5ba4f827c9b193c5f072475407900952418a13d Mon Sep 17 00:00:00 2001
From: Jesse Cai
Date: Tue, 10 Oct 2023 11:09:35 -0700
Subject: [PATCH] [sparse] Add padding for dense matrices in semi-structured
 sparse

Summary:

Currently we have shape constraints in semi-structured sparsity for both
CUTLASS and cuSPARSELt. These shape constraints unfortunately apply to both
the dense and sparse matrices in sparse @ dense matmul.

This PR adds support for calling `F.pad` to zero-pad dense matrices to the
required size and then slice out the corresponding rows of the result
matrix. We also raise a warning when padding occurs. (An illustrative
sketch of this pad-then-slice behavior appears after the diff.)

The tests have also been updated to take in a dense_input_shape parameter.

Test Plan:
```
python test/test_sparse_semi_structured.py
```

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 4e07920e2eb7108731ed31a80cc5dc8740bfa74e
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110583
---
 .../benchmark_semi_structured_sparsity.py |   8 +-
 test/test_sparse_semi_structured.py       | 110 ++++++++++++------
 torch/sparse/semi_structured.py           |  90 +++++++++++---
 3 files changed, 148 insertions(+), 60 deletions(-)

diff --git a/benchmarks/sparse/benchmark_semi_structured_sparsity.py b/benchmarks/sparse/benchmark_semi_structured_sparsity.py
index 78b7df3bba103..c20a32d83a71f 100644
--- a/benchmarks/sparse/benchmark_semi_structured_sparsity.py
+++ b/benchmarks/sparse/benchmark_semi_structured_sparsity.py
@@ -5,7 +5,7 @@
 import torch
 import torch.utils.benchmark as benchmark
 from torch import nn
-from torch.sparse import to_sparse_semi_structured
+from torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor
 from tqdm import tqdm
 
 
@@ -47,9 +47,9 @@ def rand_sparse_semi_structured_mask(
         .contiguous()
     )
 
-
+#@torch.inference_mode()
 def test_linear(m, k, n, dtype, contiguous, backend):
-    SparseSemiStructuredTensor.fuse_transpose = contiguous
+    SparseSemiStructuredTensor._FORCE_CUTLASS = backend == "cutlass"
     mask = rand_sparse_semi_structured_mask(m, k, dtype=dtype)
     sparse_weight = torch.rand(m, k).to(dtype).cuda() * mask
     input_tensor = torch.zeros(n, k).to(dtype).cuda()
@@ -61,6 +61,7 @@ def test_linear(m, k, n, dtype, contiguous, backend):
     ).blocked_autorange()
 
     dense_output = model(input_tensor)
+    print(dense_output.shape)
 
     # sparsify weights
     model.linear.weight = nn.Parameter(
@@ -70,6 +71,7 @@ def test_linear(m, k, n, dtype, contiguous, backend):
     )
 
     sparse_output = model(input_tensor)
+    print(sparse_output.shape)
 
     sparse_measurement = benchmark.Timer(
         stmt="model(input_tensor)",
diff --git a/test/test_sparse_semi_structured.py b/test/test_sparse_semi_structured.py
index 8cfd6dc7cba3c..db17e69001634 100644
--- a/test/test_sparse_semi_structured.py
+++ b/test/test_sparse_semi_structured.py
@@ -139,18 +139,18 @@ def test_to_sparse_semi_structured(self, dtype, backend):
 
     @dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
+    @parametrize("dense_input_shape", [(128, 1), (128, 128)])
     @parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
-    def test_mm_sparse_first_NT(self, dtype, device, backend):
+    def test_mm_sparse_first_NN(self, dense_input_shape, dtype, device, backend):
         """
         Ensure torch.mm(A_sparse, B) is correct for float16 and will throw error for int8
-        Ensure torch.mm(A_sparse, B.t()) is correct
         """
         SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
 
         A = rand_sparse_semi_structured_mask(128, 128, dtype=dtype)
         A_sparse = to_sparse_semi_structured(A)
 
-        B = torch.rand((128, 128), device=A_sparse.device).to(dtype)
+        B = torch.rand(dense_input_shape, device=A_sparse.device).to(dtype)
 
         # Currently we don't support int matmul on GPU, so evaluate on CPU and copy over
         if dtype is torch.int8:
@@ -162,7 +162,38 @@ def test_mm_sparse_first_NT(self, dtype, device, backend):
             with self.assertRaisesRegex(RuntimeError, "CUDA error: operation not supported when calling `cusparseLtMatmulDescriptorInit"):
                 sparse_result = torch.mm(A_sparse, B)
 
+        else:
+            dense_result = torch.mm(A, B)
+            sparse_result = torch.mm(A_sparse, B)
+            assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
+
+    @dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
+    @parametrize("dense_input_shape", [(1, 128), (128, 128)])
+    @parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
+    def test_mm_sparse_first_NT(self, dense_input_shape, dtype, device, backend):
+        """
+        Ensure torch.mm(A_sparse, B.t()) is correct for float16/bfloat16
+        and will throw an error for int8 + padding
+        """
+        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
+
+        A = rand_sparse_semi_structured_mask(128, 128, dtype=dtype)
+        A_sparse = to_sparse_semi_structured(A)
+
+        B = torch.rand(dense_input_shape, device=A_sparse.device).to(dtype)
+
+        # Currently we don't support int matmul on GPU, so evaluate on CPU and copy over
+        if dtype is torch.int8 and dense_input_shape == (1, 128):
+            # padding with int8 throws an error because transposing B yields a contiguous output
+            # and row-row 2:4 sparse @ dense with NN is not supported by cuSPARSELt or CUTLASS.
+            if backend == "cutlass":
+                with self.assertRaisesRegex(RuntimeError, "two_four_sgemm_cutlass_dispatch_layouts"):
+                    sparse_result = torch.mm(A_sparse, B.t())
+            else:
+                with self.assertRaisesRegex(RuntimeError,
+                                            "CUDA error: operation not supported when calling `cusparseLtMatmulDescriptorInit"):
+                    sparse_result = torch.mm(A_sparse, B.t())
+        elif dtype is torch.int8:
             # test transpose
             # NOTE: CUTLASS and cuSPARSELt have slightly different int8 behavior.
             # CUTLASS will output to an int32 tensor while cuSPARSELt will output to an int8 tensor
@@ -171,17 +202,15 @@ def test_mm_sparse_first_NT(self, dtype, device, backend):
             sparse_result = torch.mm(A_sparse, B.t())
             assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
         else:
-            dense_result = torch.mm(A, B)
-            sparse_result = torch.mm(A_sparse, B)
-            assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
             # test transpose
             dense_result = torch.mm(A, B.t())
             sparse_result = torch.mm(A_sparse, B.t())
             assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
 
     @dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
+    @parametrize("dense_input_shape", [(1, 128), (128, 128)])
     @parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
-    def test_mm_sparse_first_T(self, dtype, device, backend):
+    def test_mm_sparse_first_TN(self, dtype, dense_input_shape, device, backend):
         """
         Ensure torch.mm(A_sparse.t(), B) throws error
         """
@@ -189,7 +218,7 @@ def test_mm_sparse_first_T(self, dtype, device, backend):
         A = rand_sparse_semi_structured_mask(128, 128, dtype=dtype)
         A_sparse = to_sparse_semi_structured(A)
 
-        B = torch.rand((128, 128), device=A_sparse.device).to(dtype)
+        B = torch.rand(dense_input_shape, device=A_sparse.device).to(dtype)
 
         with self.assertRaisesRegex(
             NotImplementedError,
@@ -198,8 +227,9 @@ def test_mm_sparse_first_T(self, dtype, device, backend):
             torch.mm(A_sparse.t(), B)
 
     @dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
+    @parametrize("dense_input_shape", [(1, 128), (128, 128)])
     @parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
-    def test_mm_sparse_second_T(self, dtype, device, backend):
+    def test_mm_sparse_second_NT(self, dense_input_shape, dtype, device, backend):
         """
         Ensure torch.mm(A, B_sparse.t()) is correct
         """
@@ -207,7 +237,7 @@ def test_mm_sparse_second_T(self, dtype, device, backend):
         B = rand_sparse_semi_structured_mask(128, 128, dtype=dtype)
         B_sparse = to_sparse_semi_structured(B)
 
-        A = torch.rand((128, 128), device=B_sparse.device).to(dtype)
+        A = torch.rand(dense_input_shape, device=B_sparse.device).to(dtype)
 
         # Currently we don't support int matmul on GPU, so evaluate on CPU and copy over
         if dtype is torch.int8:
@@ -220,7 +250,27 @@ def test_mm_sparse_second_T(self, dtype, device, backend):
 
         assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
 
-    def test_cslt_sparse_mm_int8_in_fp16_out(self, device):
+    @dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
+    @parametrize("dense_input_shape", [(1, 128), (128, 128)])
+    @parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
+    def test_mm_sparse_second_NN(self, dense_input_shape, dtype, device, backend):
+        """
+        Ensure torch.mm(A, B_sparse) throws error
+        """
+        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
+        B = rand_sparse_semi_structured_mask(128, 128, dtype=dtype)
+        B_sparse = to_sparse_semi_structured(B)
+
+        A = torch.rand(dense_input_shape, device=B_sparse.device).to(dtype)
+
+        with self.assertRaisesRegex(
+            NotImplementedError,
+            r"arg1: SparseSemiStructuredTensor\(.*transposed=False",
+        ):
+            sparse_result = torch.mm(A, B_sparse)
+
+    @parametrize("dense_input_shape", [(128, 128)])
+    def test_cslt_sparse_mm_int8_in_fp16_out(self, dense_input_shape, device):
         """
         Test sparse matmul with int8 input and fp16 output for cuSPARSELt
         """
@@ -229,9 +279,9 @@ def test_cslt_sparse_mm_int8_in_fp16_out(self, device):
         A = rand_sparse_semi_structured_mask(128, 128, dtype=torch.int8)
         A_sparse = to_sparse_semi_structured(A)
 
-        B = torch.rand((128, 128), device=A_sparse.device).to(torch.int8)
+        B = torch.rand(dense_input_shape, device=A_sparse.device).to(torch.int8)
 
-        dense_result = torch.mm(A.cpu(), B.t().cpu()).to(device, dtype=torch.float16)
+        dense_result = torch.mm(A.cpu().to(torch.int64), B.t().cpu().to(torch.int64)).to(device, dtype=torch.float16)
         sparse_result = torch._cslt_sparse_mm(A_sparse.compressed_tensor_cusparselt, B.t(), out_dtype=torch.float16)
         assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
@@ -250,33 +300,16 @@ def test_cslt_sparse_mm_int8_in_int32_out(self, device):
         sparse_result = torch._cslt_sparse_mm(A_sparse.compressed_tensor_cusparselt, B.t(), out_dtype=torch.int32)
         assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
 
-    @dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
-    @parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
-    def test_mm_sparse_second_NT(self, dtype, device, backend):
-        """
-        Ensure torch.mm(A, B_sparse) throws error
-        """
-        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
-        B = rand_sparse_semi_structured_mask(128, 128, dtype=dtype)
-        B_sparse = to_sparse_semi_structured(B)
-
-        A = torch.rand((128, 128), device=B_sparse.device).to(dtype)
-
-        with self.assertRaisesRegex(
-            NotImplementedError,
-            r"arg1: SparseSemiStructuredTensor\(.*transposed=False",
-        ):
-            sparse_result = torch.mm(A, B_sparse)
-
+    @parametrize("dense_input_shape", [(1, 128), (128, 128), (64, 128, 128)])
     @parametrize("inference_mode", [subtest(True), subtest(False)])
     @parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
-    def test_linear(self, inference_mode, device, backend):
+    def test_linear(self, dense_input_shape, inference_mode, device, backend):
         """
         Test nn.Linear has the same numerics
         """
         SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
-        input = torch.rand(64, 128, 128, device=device).half()
-        model = nn.Linear(128, 128).to(device).half()
+        input = torch.rand(dense_input_shape, device=device).half()
+        model = nn.Linear(128, 256).to(device).half()
         m, n = model.weight.shape
         mask = rand_sparse_semi_structured_mask(m, n, device=device, dtype=torch.bool)
         # set masked weight
@@ -294,14 +327,15 @@ def test_linear(self, inference_mode, device, backend):
 
         assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
 
+    @parametrize("dense_input_shape", [(1, 128), (128, 128), (64, 128, 128)])
     @parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
-    def test_mlp(self, device, backend):
+    def test_mlp(self, device, dense_input_shape, backend):
         SparseSemiStructuredTensor._FORCE_CUTLASS = backend == "cutlass"
-        input = torch.rand(64, 768, 768, device=device).half()
+        input = torch.rand(dense_input_shape, device=device).half()
         model = (
             nn.Sequential(
-                nn.Linear(768, 3072),
-                nn.Linear(3072, 768),
+                nn.Linear(128, 3072),
+                nn.Linear(3072, 128),
             )
             .half()
             .to(device)
diff --git a/torch/sparse/semi_structured.py b/torch/sparse/semi_structured.py
index b6f11a816ef78..a35a5543448e0 100644
--- a/torch/sparse/semi_structured.py
+++ b/torch/sparse/semi_structured.py
@@ -53,7 +53,7 @@ class SparseSemiStructuredTensor(torch.Tensor):
 
     _FUSE_TRANSPOSE = False
     _FORCE_CUTLASS = True
-    _WARNING_SHOWN = False
+    _PROTOTYPE_WARNING_SHOWN = False
 
     @staticmethod
     def __new__(
@@ -88,7 +88,7 @@ def __new__(
         """
         assert compressed_tensor_cusparselt is None or (sparse_tensor_cutlass is None and meta_tensor_cutlass is None)
 
-        if not cls._WARNING_SHOWN:
+        if not cls._PROTOTYPE_WARNING_SHOWN:
             warnings.warn(
                 (
                     "The PyTorch API of SparseSemiStructuredTensor is in prototype stage "
@@ -98,7 +98,7 @@ def __new__(
                 ),
                 UserWarning,
             )
-            cls._WARNING_SHOWN = True
+            cls._PROTOTYPE_WARNING_SHOWN = True
 
         if original_tensor is not None:
             previous_tensor = original_tensor
@@ -208,6 +208,7 @@ def __init__(
                 compressed_tensor_cusparselt = torch._cslt_compress(original_tensor)
 
         # set values
+        self._PADDING_WARNING_SHOWN = False
         self.original_tensor = None
         self.compressed_tensor_cusparselt = compressed_tensor_cusparselt
         self.sparse_tensor_cutlass = sparse_tensor_cutlass
@@ -232,6 +233,33 @@ def __repr__(self) -> str:  # type: ignore[override]
 
     __torch_function__ = torch._C._disabled_torch_function_impl
 
+    def _pad_tensor_for_matmul(self, original_tensor: torch.Tensor) -> torch.Tensor:
+        """
+        Calculates padding for dense tensor and pads tensor if necessary.
+        If padding is not required, this function returns the original tensor.
+        """
+        # only 2d matmul
+        assert original_tensor.dim() == 2
+
+        # check shape
+        m, n = original_tensor.shape
+        min_rows = _DTYPE_TO_SEMI_STRUCTURED_SPARSE_CONFIG[original_tensor.dtype].min_rows
+        min_cols = _DTYPE_TO_SEMI_STRUCTURED_SPARSE_CONFIG[original_tensor.dtype].min_cols
+        to_pad_m = -m % min_rows if m < min_rows or m % min_rows else 0
+        to_pad_n = -n % min_cols if n < min_cols or n % min_cols else 0
+        if to_pad_m or to_pad_n:
+            warnings.warn(
+                (
+                    "Attempting to do matmul with a dense tensor that does not meet shape requirements. "
+                    f"Padding dense input tensor of shape ({m}, {n}) to ({m+to_pad_m}, {n+to_pad_n})."
+                ),
+                UserWarning,
+            )
+            return torch.nn.functional.pad(original_tensor, (0, to_pad_n, 0, to_pad_m))
+        else:
+            return original_tensor
+
     @classmethod
     def __torch_dispatch__(cls, func, types, args, kwargs) -> Any:
         """Overload __torch_dispatch__ to use torch._sparse_semi_structured_linear.
@@ -251,6 +279,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs) -> Any:
         Raises:
             NotImplementedError: If the dispatched operation is not implemented.
""" + #print(func) # Since this code runs below autograd, a detach corresponds to only returning a new object if func is torch.ops.aten.detach.default: return SparseSemiStructuredTensor( @@ -290,38 +319,52 @@ def __torch_dispatch__(cls, func, types, args, kwargs) -> Any: # F.linear(x) = addmm(bias, input, weight.t()) = b + xW' = (b + xW')'' # = (W''x' + b')' = (Wx' + b')' = addmm(bias.T, weight, input).T if isinstance(input_B, cls) and input_B.transposed: + row, col = input_A.shape + input_A_padded = input_B._pad_tensor_for_matmul(input_A) if input_B.compressed_tensor_cusparselt is None: assert input_B.sparse_tensor_cutlass is not None and input_B.meta_tensor_cutlass is not None - return torch._sparse_semi_structured_linear( - input_A, input_B.sparse_tensor_cutlass, input_B.meta_tensor_cutlass, bias=bias + res = torch._sparse_semi_structured_linear( + input_A_padded, input_B.sparse_tensor_cutlass, input_B.meta_tensor_cutlass, bias=bias ) else: - return torch._cslt_sparse_mm( - input_B.compressed_tensor_cusparselt, input_A.T, bias # type: ignore[arg-type] + res = torch._cslt_sparse_mm( + input_B.compressed_tensor_cusparselt, input_A_padded.t(), bias # type: ignore[arg-type] ).t() + return res[:row, :col] # handle mm if func is torch.ops.aten.mm.default: input_A, input_B = args + # first element sparse if isinstance(input_A, cls) and not input_A.transposed: + row, col = input_B.shape + input_B_padded = input_A._pad_tensor_for_matmul(input_B) if input_A.compressed_tensor_cusparselt is None: assert input_A.sparse_tensor_cutlass is not None and input_A.meta_tensor_cutlass is not None - return torch._sparse_semi_structured_linear( - input_B.t(), input_A.sparse_tensor_cutlass, input_A.meta_tensor_cutlass + res = torch._sparse_semi_structured_linear( + input_B_padded.t(), input_A.sparse_tensor_cutlass, input_A.meta_tensor_cutlass ).t() else: - return torch._cslt_sparse_mm( - input_A.compressed_tensor_cusparselt, input_B, None # type: ignore[arg-type] + res = torch._cslt_sparse_mm( + input_A.compressed_tensor_cusparselt, input_B_padded, None # type: ignore[arg-type] ) + return res[:row, :col] + + # second element sparse elif isinstance(input_B, cls) and input_B.transposed: + row, col = input_A.shape + input_A_padded = input_B._pad_tensor_for_matmul(input_A) + if input_B.compressed_tensor_cusparselt is None: assert input_B.sparse_tensor_cutlass is not None and input_B.meta_tensor_cutlass is not None - return torch._sparse_semi_structured_linear( - input_A, input_B.sparse_tensor_cutlass, input_B.meta_tensor_cutlass + res = torch._sparse_semi_structured_linear( + input_A_padded, input_B.sparse_tensor_cutlass, input_B.meta_tensor_cutlass ) else: - return torch._cslt_sparse_mm(input_B.compressed_tensor_cusparselt, input_A.T, None).t() # type: ignore[arg-type] + res = torch._cslt_sparse_mm(input_B.compressed_tensor_cusparselt, input_A_padded.t(), None).t() # type: ignore[arg-type] + + return res[:row, :col] # When torch is run with inference mode, pytorch does not decompose torch.ops.aten.linear into a .t() and addmm(), # so we must match the aten.linear op. 
         # In this case, we need to explicitly handle collapsing to 2d matmul
@@ -329,21 +372,30 @@ def __torch_dispatch__(cls, func, types, args, kwargs) -> Any:
         if func is torch.ops.aten.linear.default:
             input_tensor, weight, bias = args
             shape = input_tensor.shape
+
+            input_tensor_2d = input_tensor.view(-1, shape[-1])
+            row, col = input_tensor_2d.shape
+            # this is a no-op if no padding is needed
+            input_tensor_2d_padded = weight._pad_tensor_for_matmul(input_tensor_2d)
+
             if isinstance(weight, cls):
                 if weight.compressed_tensor_cusparselt is None:
                     assert weight.sparse_tensor_cutlass is not None and weight.meta_tensor_cutlass is not None
-                    return torch._sparse_semi_structured_linear(
-                        input_tensor,
+                    res = torch._sparse_semi_structured_linear(
+                        input_tensor_2d_padded,
                         weight.sparse_tensor_cutlass,
                         weight.meta_tensor_cutlass,
                         bias=bias
                     )
+                    return res[:row, :col]
                 else:
-                    return torch._cslt_sparse_mm(
+                    res = torch._cslt_sparse_mm(
                         weight.compressed_tensor_cusparselt,  # type: ignore[arg-type]
-                        input_tensor.view(-1, shape[-1]).t(),
+                        input_tensor_2d_padded.t(),
                         bias
-                    ).t().view(*shape[:-1], -1)
+                    ).t()
+                    return res[:row, :col].view(*shape[:-1], -1)
+
         # handle values
         if func is torch.ops.aten.values.default:
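
---

Note for reviewers: below is a minimal, standalone sketch of the pad-then-slice
behavior described in the summary. It is illustrative only: `pad_dense_for_matmul`
is a hypothetical helper, and the `min_rows`/`min_cols` values are placeholders
standing in for the per-dtype values that `_pad_tensor_for_matmul` looks up in
`_DTYPE_TO_SEMI_STRUCTURED_SPARSE_CONFIG`.

```python
# Illustrative sketch only -- not part of this patch.
import torch
import torch.nn.functional as F

min_rows, min_cols = 32, 64  # placeholder shape constraints, not the real per-dtype values

def pad_dense_for_matmul(dense: torch.Tensor) -> torch.Tensor:
    m, n = dense.shape
    # -m % min_rows is the smallest non-negative padding that makes m a multiple of min_rows
    to_pad_m = -m % min_rows if m < min_rows or m % min_rows else 0
    to_pad_n = -n % min_cols if n < min_cols or n % min_cols else 0
    # F.pad pads from the last dim backwards: (left, right, top, bottom)
    return F.pad(dense, (0, to_pad_n, 0, to_pad_m))

A = torch.rand(128, 128)            # stands in for the 2:4 sparse operand
B = torch.rand(128, 1)              # dense operand that violates min_cols
B_padded = pad_dense_for_matmul(B)  # zero-padded to (128, 64)
res = A @ B_padded                  # padded result: (128, 64)
res = res[:, :B.shape[1]]           # slice back to the unpadded shape (128, 1)
```

Zero-padding is safe here because each added column of zeros contributes only an
all-zero column to the product, so slicing the result back to the original shape
recovers exactly the unpadded matmul.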