Commit

Update on "[core][pruning][sparse][feature] SparseSemiStructured tensor subclass"

This PR adds in support for semi-structured sparsity via a tensor
subclass. It currently uses the CUTLASS kernels merged in PR #100881.

In the future we plan to add in cuSPARSELt support (see the other PRs in
the stack), which will give us larger performance gains.

This PR adds in 2 things:
- a Tensor subclass, `SparseSemiStructuredTensor`, that stores the
  sparse tensor in compressed form and overrides `__torch_dispatch__`
  (a minimal sketch of this pattern follows the list).
- a conversion function that takes in a dense tensor and a
  semi-structured sparse bool mask and creates an instance of the
  subclass.
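For context, here is a minimal sketch of the wrapper-subclass pattern this builds on. It is illustrative only: the class and attribute names are placeholders, and while `_make_wrapper_subclass` is the private PyTorch helper commonly used for such subclasses, the details below are assumptions rather than this PR's exact code.

```
import torch

class CompressedTensor(torch.Tensor):
    """Illustrative wrapper subclass: stores a packed buffer, intercepts ops."""

    @staticmethod
    def __new__(cls, shape, packed):
        # Advertise the dense shape, but keep the real data in `packed`.
        return torch.Tensor._make_wrapper_subclass(
            cls, shape, dtype=packed.dtype, device=packed.device
        )

    def __init__(self, shape, packed):
        self.packed = packed  # compressed values (+ metadata) in one flat buffer

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        # Route supported ops (mm, addmm, linear, t, detach, ...) to sparse
        # kernels here; raise for anything else.
        raise NotImplementedError(f"{func} is not supported on CompressedTensor")
```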

** SparseSemiStructuredTensor **

The subclass stores the dense tensor in a contiguous flattened tensor
for future compatibility with cuSPARSELt, which expects this format.
Note that the CUTLASS kernels do not have this limitation, as the
specified values and the metadata are passed separately to
`_structured_sparse_linear`. In the future we can use the cuSPARSELt bindings
[here](#103700) for faster matmul, better dtype coverage, and relaxed shape
constraints.

Since we currently don't have a way to go back from the sparse
representation to the dense representation, and we store the weights in
compressed form, we don't have a great way to handle .t().

Instead, we keep track of how often we've called transpose on our
tensor, and if it's an unexpected number we throw an error. When the first
argument is sparse, we expect an even number of calls to transpose,
while when the second argument is sparse, we expect an odd number of
calls. This is because we support second argument sparse matrix
multiplications by using transpose properties.
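A rough sketch of that bookkeeping (illustrative only; the helper names below are not from this PR, and the import path for the subclass is an assumption):

```
import torch
from torch.sparse import SparseSemiStructuredTensor  # assumed import path

def handle_transpose(sparse):
    # .t() does no data movement: the compressed buffer stays as-is,
    # only the advertised shape and the `transposed` flag change.
    return SparseSemiStructuredTensor(
        sparse.shape[::-1],
        sparse.compressed_tensor,
        transposed=not sparse.transposed,
    )

def check_transpose_state(sparse, sparse_is_first_arg):
    # First-argument-sparse matmuls expect an even number of .t() calls
    # (flag back to False); second-argument-sparse matmuls, rewritten via
    # transpose identities, expect an odd number (flag set to True).
    expected_flag = not sparse_is_first_arg
    if sparse.transposed != expected_flag:
        raise NotImplementedError(
            "unexpected transpose state for semi-structured matmul"
        )
```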

** to_sparse_semi_structured **

This is a conversion function to convert a dense tensor and a
semi-structured sparse bool mask into a subclass. Currently, we must
pass in a bool mask, since we can't infer it: there may be additional
zero elements in the dense tensor, so `tensor != 0` is not necessarily 2:4
sparse.

Once we add either a method to derive the mask from the dense tensor or
cuSPARSELt support, we will no longer need to pass in the mask. cuSPARSELt has its
own helper functions to create the metadata mask.
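For illustration only, one way such a mask could be derived from a dense weight is to keep the two largest-magnitude entries in each group of four along the last dimension; `naive_2to4_mask` below is a hypothetical helper, not something this PR adds:

```
import torch

def naive_2to4_mask(dense: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: keep the 2 largest-magnitude values in every
    # contiguous group of 4 columns (the column count must be divisible by 4).
    r, c = dense.shape
    groups = dense.abs().reshape(r, c // 4, 4)
    # Rank each element within its group (0 = largest magnitude).
    ranks = groups.argsort(dim=-1, descending=True).argsort(dim=-1)
    return (ranks < 2).reshape(r, c)
```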

** User Details **

We have implemented support for the following ops for `torch.float16`
and `torch.int8`:
```
torch.addmm(bias, dense, sparse.t())
torch.mm(dense, sparse)
torch.mm(sparse, dense)
aten.linear.default
aten.t.default
aten.t.detach
```
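As a rough illustration of these forms (the shapes, dtypes, and repeated `[0, 0, 1, 1]` pattern are assumptions, a CUDA device with the CUTLASS kernels is required, and whether each call runs as-is depends on the transpose bookkeeping described above):

```
import torch
from torch.sparse import to_sparse_semi_structured

# Assumed setup: a 128x128 fp16 weight and a mask repeating the 2:4 pattern.
mask = torch.Tensor([0, 0, 1, 1]).tile((128, 32)).cuda().bool()
weight = torch.rand(128, 128).half().cuda()
sparse = to_sparse_semi_structured(weight, mask=mask)

dense = torch.rand(64, 128).half().cuda()
dense2 = torch.rand(128, 64).half().cuda()
bias = torch.rand(128).half().cuda()

out_addmm = torch.addmm(bias, dense, sparse.t())  # bias + dense @ sparse.t()
out_right = torch.mm(dense, sparse)               # sparse as the second argument
out_left = torch.mm(sparse, dense2)               # sparse as the first argument
```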

The end user interface to accelerate an nn.Linear module with the
subclass would look like this:

```
import torch
import torch.nn as nn
from torch.sparse import to_sparse_semi_structured

mask = torch.Tensor([0, 0, 1, 1]).tile((128, 32)).cuda().bool()
linear = nn.Linear(128, 128).half().cuda()

linear.weight = nn.Parameter(to_sparse_semi_structured(linear.weight,
                                                       mask=mask))
```
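After the swap, the module can be used as before (assuming a CUDA device with the supported kernels); `aten.linear.default` is intercepted by the subclass:

```
x = torch.rand(64, 128).half().cuda()
output = linear(x)
```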

This also updates tests and the `torch.sparse` module docstring to
reflect these changes.

[ghstack-poisoned]
jcaip committed Jun 17, 2023
2 parents ea54d48 + 5ea84df commit e67e8fc
Showing 2 changed files with 58 additions and 9 deletions.
4 changes: 2 additions & 2 deletions benchmarks/sparse/benchmark_semi_structured_sparsity.py
@@ -62,7 +62,7 @@ def test_linear(m, k, n, dtype, contiguous, backend):
dense_output = model(input_tensor)

# sparsify weights
model.linear.weight = nn.Parameter(to_sparse_semi_structured(sparse_weight))
model.linear.weight = nn.Parameter(to_sparse_semi_structured(sparse_weight, mask=mask.bool()))

sparse_output = model(input_tensor)

@@ -93,7 +93,7 @@ def test_tensor(m, k, n, dtype, contiguous, backend):
B = torch.zeros(k, n).to(dtype).cuda()
bias = torch.rand(n).to(dtype).cuda()

sA = to_sparse_semi_structured(A)
sA = to_sparse_semi_structured(A, mask=A.bool())

# torch.mm calculation
if dtype is not torch.int8:
63 changes: 56 additions & 7 deletions torch/sparse/semi_structured.py
@@ -20,8 +20,7 @@


class SparseSemiStructuredTensor(torch.Tensor):
"""
This class implementes semi-structured sparsity as a Tensor subclass.
"""This class implementes semi-structured sparsity as a Tensor subclass.
Semi-structured sparsity describes a sparsity pattern where n in every 2n elements are sparse,
depending on the datatype. It is also referred to as 2:4 sparsity or fine-grained
@@ -44,7 +43,22 @@ class SparseSemiStructuredTensor(torch.Tensor):
"""

@staticmethod
def __new__(cls, custom_shape, compressed_tensor, transposed):
def __new__(cls, custom_shape: torch.Size, compressed_tensor: torch.Tensor, transposed: bool = False) -> torch.Tensor:

"""
Create a new instance of the class.
Args:
custom_shape (tuple): The custom shape for the new instance.
compressed_tensor (torch.Tensor): The compressed tensor to use for the new instance.
transposed (bool): Indicates whether the tensor is transposed.
Returns:
torch.Tensor: A torch.Tensor wrapper subclass.
Raises:
None
"""
kwargs = {}
kwargs["device"] = compressed_tensor.device
kwargs["dtype"] = compressed_tensor.dtype

@@ -68,11 +82,22 @@ def __init__(
Returns:
None
Raises:
None
"""
self.compressed_tensor = compressed_tensor
self.transposed = transposed

def __repr__(self):
def __repr__(self) -> str:
"""Return string representation of SparseSemiStructuredTensor
Returns:
str: String representation
Raises:
None
"""
return (
f"SparseSemiStructuredTensor(shape={self.shape}, "
f"transposed={self.transposed}"
@@ -83,7 +108,24 @@ def __repr__(self):
__torch_function__ = torch._C._disabled_torch_function_impl

@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs):
def __torch_dispatch__(cls, func, types, args, kwargs) -> torch.Tensor:
"""Overload __torch_dispatch__ to use torch._structred_sparse_linear.
`torch.structured_sparse_linear` uses accelerated sparse CUTLASS kernels.
In the future we plan to also add in support for cuSPARSELt kernels.
Args:
func: The function being dispatched.
types: The types of the arguments.
args: The arguments passed to the function.
kwargs: The keyword arguments passed to the function.
Returns:
torch.Tensor: The result of the dispatched operation.
Raises:
NotImplementedError: If the dispatched operation is not implemented.
"""
# since we create a new compressed tensor, the tensor will already be detached
# this effectively functions as a no-op.
if func is torch.ops.aten.detach.default:
@@ -190,12 +232,19 @@ def to_sparse_semi_structured(
- torch.float16  (r, c) must be >= 64 and a multiple of 64
- torch.int8     (r, c) must be >= 128 and a multiple of 128
Args::
Args:
original_tensor (Tensor): the dense tensor to convert
mask (Optional BoolTensor): boolean mask to apply to the original tensor
transposed (bool, optional): whether the dense tensor is transposed
Example::
Returns:
SparseSemiStructuredTensor: A sparse semi-structured tensor created from the given original_tensor and mask
Raises:
NotImplementedError: If ``mask=None``, as we currently do not support inferring a mask from the dense tensor.
RuntimeError: If original_tensor is not a supported dtype, dim, shape, or device.
Example:
>>> from torch.sparse import to_sparse_semi_structured
>>> A = torch.Tensor([0, 0, 1, 1]).tile((128, 32)).half().cuda()
tensor([[0., 0., 1., ..., 0., 1., 1.],
