diff --git a/test/test_sparse_semi_structured.py b/test/test_sparse_semi_structured.py
index 57a71f91ee54e..9f403407af0d6 100644
--- a/test/test_sparse_semi_structured.py
+++ b/test/test_sparse_semi_structured.py
@@ -178,23 +178,23 @@ def test_mm_sparse_first_NT(self, dtype, device, backend):
         sparse_result = torch.mm(A_sparse, B.t())
         assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)

     @dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
     @parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
     def test_mm_sparse_first_T(self, dtype, device, backend):
         """
         Ensure torch.mm(A_sparse.t(), B) throws error
         """
         SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
         A = rand_sparse_semi_structured_mask(128, 128, dtype=dtype)
         A_sparse = to_sparse_semi_structured(A)

         B = torch.rand((128, 128), device=A_sparse.device).to(dtype)

         with self.assertRaisesRegex(
             NotImplementedError,
             r"arg0: SparseSemiStructuredTensor\(.*transposed=True",
         ):
             torch.mm(A_sparse.t(), B)

     @dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
     @parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
@@ -277,6 +277,27 @@ def test_linear(self, inference_mode, device, backend):

         assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)

+    @parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
+    def test_padding_linear(self, device, backend):
+        """
+        Test padding for inputs that aren't multiples of 8
+        """
+        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
+        input = torch.rand(1, 128, device=device).half()
+        model = nn.Linear(128, 128).to(device).half()
+        m, n = model.weight.shape
+        mask = rand_sparse_semi_structured_mask(m, n, device=device, dtype=torch.bool)
+        # set masked weight
+        model.weight = nn.Parameter(model.weight * mask)
+
+        dense_result = model(input)
+
+        model.weight = nn.Parameter(to_sparse_semi_structured(model.weight))
+        with torch.inference_mode():
+            sparse_result = model(input)
+
+        assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
+
     @parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
     def test_values(self, backend):
         SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
diff --git a/torch/sparse/semi_structured.py b/torch/sparse/semi_structured.py
index f1e78317d5b11..1ed5089f1db12 100644
--- a/torch/sparse/semi_structured.py
+++ b/torch/sparse/semi_structured.py
@@ -179,21 +179,6 @@ def __init__(
                 "dtype must be one of: {_DTYPE_TO_SEMI_STRUCTURED_SPARSE_CONFIG}"
             )

-        # check shape
-        m, n = original_tensor.shape
-        min_rows = _DTYPE_TO_SEMI_STRUCTURED_SPARSE_CONFIG[
-            original_tensor.dtype
-        ].min_rows
-        min_cols = _DTYPE_TO_SEMI_STRUCTURED_SPARSE_CONFIG[
-            original_tensor.dtype
-        ].min_cols
-        if m < min_rows or m % min_rows or n < min_cols or n % min_cols:
-            # TODO in the future we can add in padding to support dimensions that aren't perfect multiples
-            raise RuntimeError(
-                f"Error original_tensor.shape {original_tensor.shape} is not supported! "
-                f"Both dimensions must be larger or equal than and a multiple of ({min_rows}, {min_cols})"
-            )
-
         compressed_tensor_cusparselt = None
         sparse_tensor_cutlass = None
         meta_tensor_cutlass = None
@@ -232,6 +217,25 @@ def __repr__(self) -> str:  # type: ignore[override]

     __torch_function__ = torch._C._disabled_torch_function_impl

+    @staticmethod
+    def _pad_tensor_for_matmul(original_tensor):
+        # only works on 2d tensors
+        assert original_tensor.dim() == 2
+
+        # check shape
+        m, n = original_tensor.shape
+        min_rows = _DTYPE_TO_SEMI_STRUCTURED_SPARSE_CONFIG[
+            original_tensor.dtype
+        ].min_rows
+        min_cols = _DTYPE_TO_SEMI_STRUCTURED_SPARSE_CONFIG[
+            original_tensor.dtype
+        ].min_cols
+        if m < min_rows or m % min_rows or n < min_cols or n % min_cols:
+            to_pad_m, to_pad_n = -m % min_rows, -n % min_cols  # padding needed to reach the next multiple
+            return torch.nn.functional.pad(original_tensor, (0, to_pad_n, 0, to_pad_m))
+        else:
+            return original_tensor
+
     @classmethod
     def __torch_dispatch__(cls, func, types, args, kwargs) -> Any:
         """Overload __torch_dispatch__ to use torch._sparse_semi_structured_linear.
@@ -303,24 +307,35 @@ def __torch_dispatch__(cls, func, types, args, kwargs) -> Any:
         if func is torch.ops.aten.mm.default:
             input_A, input_B = args

+            # first element sparse
             if isinstance(input_A, cls) and not input_A.transposed:
+                input_B_padded = cls._pad_tensor_for_matmul(input_B)
+                row, col = input_B.shape
                 if input_A.compressed_tensor_cusparselt is None:
                     assert input_A.sparse_tensor_cutlass is not None and input_A.meta_tensor_cutlass is not None
-                    return torch._sparse_semi_structured_linear(
-                        input_B.t(), input_A.sparse_tensor_cutlass, input_A.meta_tensor_cutlass
+                    res = torch._sparse_semi_structured_linear(
+                        input_B_padded.t(), input_A.sparse_tensor_cutlass, input_A.meta_tensor_cutlass
                     ).t()
                 else:
-                    return torch._cslt_sparse_mm(
-                        input_A.compressed_tensor_cusparselt, input_B, None  # type: ignore[arg-type]
+                    res = torch._cslt_sparse_mm(
+                        input_A.compressed_tensor_cusparselt, input_B_padded, None  # type: ignore[arg-type]
                     )
+                return res[:row, :col]
+
+            # second element sparse
             elif isinstance(input_B, cls) and input_B.transposed:
+                input_A_padded = cls._pad_tensor_for_matmul(input_A)
+                row, col = input_A.shape
+
                 if input_B.compressed_tensor_cusparselt is None:
                     assert input_B.sparse_tensor_cutlass is not None and input_B.meta_tensor_cutlass is not None
-                    return torch._sparse_semi_structured_linear(
-                        input_A, input_B.sparse_tensor_cutlass, input_B.meta_tensor_cutlass
+                    res = torch._sparse_semi_structured_linear(
+                        input_A_padded, input_B.sparse_tensor_cutlass, input_B.meta_tensor_cutlass
                     )
                 else:
-                    return torch._cslt_sparse_mm(input_B.compressed_tensor_cusparselt, input_A.T, None).t()  # type: ignore[arg-type]
+                    res = torch._cslt_sparse_mm(input_B.compressed_tensor_cusparselt, input_A_padded.T, None).t()  # type: ignore[arg-type]
+
+                return res[:row, :col]

         # When torch is run with inference mode, pytorch does not decompose torch.ops.aten.linear into a .t() and addmm(),
         # so we must match the aten.linear op. In this case, we need to explicitly handle collapsing to 2d matmul
@@ -328,22 +343,30 @@ def __torch_dispatch__(cls, func, types, args, kwargs) -> Any:

         if func is torch.ops.aten.linear.default:
             input_tensor, weight, bias = args
             shape = input_tensor.shape
+
+            input_tensor_2d = input_tensor.view(-1, shape[-1])
+            row, col = input_tensor_2d.shape
+            # this is a no-op if no padding is needed
+            input_tensor_2d_padded = cls._pad_tensor_for_matmul(input_tensor_2d)
+
             if isinstance(weight, cls):
                 if weight.compressed_tensor_cusparselt is None:
                     assert weight.sparse_tensor_cutlass is not None and weight.meta_tensor_cutlass is not None
-                    return torch._sparse_semi_structured_linear(
-                        input_tensor,
+                    res = torch._sparse_semi_structured_linear(
+                        input_tensor_2d_padded,
                         weight.sparse_tensor_cutlass,
                         weight.meta_tensor_cutlass,
                         bias=bias
                     )
                 else:
-                    return torch._cslt_sparse_mm(
+                    res = torch._cslt_sparse_mm(
                         weight.compressed_tensor_cusparselt,  # type: ignore[arg-type]
-                        input_tensor.view(-1, shape[-1]).t(),
+                        input_tensor_2d_padded.t(),
                         bias
                     ).t().view(*shape[:-1], -1)
+                return res[:row, :col]
+
         # handle values
         if func is torch.ops.aten.values.default:
             if args[0].compressed_tensor_cusparselt is None: