[sparse] Fix semi-structured sparse shape mismatch bug (#110420)
Summary:

Currently, PyTorch incorrectly calculates the size of the returned
matrix when we pass a non-contiguous batched (>2d) input to the
semi-structured sparse subclass.

This is most common in MLP layers, where we have 2 linear layers back to back.

This leads to an error like the following:
```
RuntimeError: shape '[20, 64, 64, 3072]' is invalid for input of size 62914560
```
The size of the sparse matmul result is off because the output shape is inferred from the wrong tensor shape.
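
For reference, here is a hypothetical standalone reproduction in the spirit of the test added below. It is a sketch, not part of this commit, and assumes a CUDA device with 2:4 sparse kernel support and the public `to_sparse_semi_structured` API:
```
import torch
import torch.nn as nn
from torch.sparse import to_sparse_semi_structured

# Two back-to-back linear layers, as in a transformer MLP block.
model = nn.Sequential(nn.Linear(768, 3072), nn.Linear(3072, 768)).half().cuda()

for layer in model:
    # Keep 2 out of every 4 weights so each row satisfies the 2:4 pattern
    # (hypothetical mask; the new test uses rand_sparse_semi_structured_mask).
    mask = torch.zeros_like(layer.weight, dtype=torch.bool)
    mask[:, ::2] = True
    layer.weight = nn.Parameter(
        to_sparse_semi_structured((layer.weight * mask).detach())
    )

x = torch.rand(64, 768, 768, device="cuda").half()  # batched (>2d) input
out = model(x)  # before this fix: RuntimeError about an invalid shape
```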

This happens because of a bug where we did not update the subclass tensor shape when transposing. For semi-structured sparsity, transposing is a no-op that just flips a boolean flag, but the reported tensor shape needs to be updated as well.
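
Concretely, the expectation after this fix can be sketched as follows (a hedged illustration, not code from the commit; it also assumes a CUDA device with 2:4 sparse support):
```
import torch
from torch.sparse import to_sparse_semi_structured

w = torch.zeros(3072, 768, device="cuda").half()
w[:, ::2] = 1.0  # keep 2 of every 4 elements per row (2:4-compatible)
sw = to_sparse_semi_structured(w)

# Transposing only flips an internal flag, but the reported shape must swap too,
# since matmul output shapes are inferred from it.
assert sw.shape == torch.Size([3072, 768])
assert sw.t().shape == torch.Size([768, 3072])  # was still [3072, 768] before this fix
```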

Note that this error goes away in inference mode, since we avoid
decomposing the aten.linear op and handle shape folding ourselves,
which changes the execution path.
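
That is, running the same model under inference mode sidesteps the bug (sketch, reusing the hypothetical model and input from the reproduction above):
```
with torch.inference_mode():
    out = model(x)  # aten.linear is not decomposed here, so the shape error does not occur
```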

Alternatively, setting TORCH_FLATTEN_LINEAR_3D=True also avoids this error.
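
For example, assuming the flag is read as an environment variable (the commit message does not spell out how it is consumed):
```
TORCH_FLATTEN_LINEAR_3D=True python test/test_sparse_semi_structured.py -k test_mlp
```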

Test Plan:
```
python test/test_sparse_semi_structured.py -k test_mlp
```

Reviewers:

Subscribers:

Tasks:

Tags:
Pull Request resolved: #110420
Approved by: https://github.com/alexsamardzic, https://github.com/cpuhrsch
jcaip authored and pytorchmergebot committed Oct 10, 2023
1 parent 468a73f commit f10aab0
Showing 2 changed files with 36 additions and 3 deletions.
30 changes: 30 additions & 0 deletions test/test_sparse_semi_structured.py
```
@@ -294,6 +294,36 @@ def test_linear(self, inference_mode, device, backend):
 
         assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
 
+    @parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
+    def test_mlp(self, device, backend):
+        SparseSemiStructuredTensor._FORCE_CUTLASS = backend == "cutlass"
+        input = torch.rand(64, 768, 768, device=device).half()
+        model = (
+            nn.Sequential(
+                nn.Linear(768, 3072),
+                nn.Linear(3072, 768),
+            )
+            .half()
+            .to(device)
+        )
+
+        for i in range(2):
+            m, n = model[i].weight.shape
+            mask = rand_sparse_semi_structured_mask(
+                m, n, device=device, dtype=torch.bool
+            )
+            # set masked weight
+            model[i].weight = nn.Parameter(model[i].weight * mask)
+
+        dense_result = model(input)
+
+        for i in range(2):
+            model[i].weight = nn.Parameter(to_sparse_semi_structured(model[i].weight))
+
+        sparse_result = model(input)
+
+        assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
+
     @parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
     def test_values(self, backend):
         SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
```
9 changes: 6 additions & 3 deletions torch/sparse/semi_structured.py
```
@@ -52,7 +52,7 @@ class SparseSemiStructuredTensor(torch.Tensor):
     """
 
     _FUSE_TRANSPOSE = False
-    _FORCE_CUTLASS = False
+    _FORCE_CUTLASS = True
    _WARNING_SHOWN = False
 
    @staticmethod
@@ -268,7 +268,8 @@ def __torch_dispatch__(cls, func, types, args, kwargs) -> Any:
        if func is torch.ops.aten.t.default:
            return SparseSemiStructuredTensor(
                args[0].original_tensor,
-                original_shape=args[0].shape,
+                # transpose shape
+                original_shape=torch.Size([args[0].shape[1], args[0].shape[0]]),
                compressed_tensor_cusparselt=args[0].compressed_tensor_cusparselt,
                sparse_tensor_cutlass=args[0].sparse_tensor_cutlass,
                meta_tensor_cutlass=args[0].meta_tensor_cutlass,
@@ -438,4 +439,6 @@ def to_sparse_semi_structured(
        [-4370, -4370, -4370, ..., -4370, -4370, -4370]], device='cuda:0',
        dtype=torch.int16))
    """
-    return SparseSemiStructuredTensor(original_tensor, original_shape=original_tensor.shape, transposed=transposed)
+    return SparseSemiStructuredTensor(
+        original_tensor, original_shape=original_tensor.shape, transposed=transposed
+    )
```
