[core][pruning][sparse][feature] SparseSemiStructured tensor subclass #102135

Closed

jcaip wants to merge 66 commits

Conversation

@jcaip (Contributor) commented May 24, 2023

Stack from ghstack (oldest at bottom):

This PR adds support for semi-structured sparsity via a tensor
subclass. It currently uses the CUTLASS kernels merged in PR #100881.

In the future we plan to add cuSPARSELt support (see the other PRs in
the stack), which will give us larger performance gains.

This PR adds two things:

  • a Tensor subclass, SparseSemiStructuredTensor, to store the
    sparse tensor in compressed form and override __torch_dispatch__.
  • a conversion function that takes in a dense tensor and a
    semi-structured sparse bool mask and creates an instance of the
    subclass.

SparseSemiStructuredTensor

The subclass stores the dense tensor in a contiguous flattened tensor
for future compatibility with cuSPARSELt, which expects this format.
Note that the CUTLASS kernels do not have this limitation, as the
specified values and the metadata are passed separately to
_structured_sparse_linear. In the future we can use the cuSPARSELt
bindings here for faster matmul, better dtype coverage, and relaxed
shape constraints.
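
To make the storage savings concrete (a back-of-the-envelope sketch, not the exact packed layout used here): 2:4 sparsity keeps half of the values plus a small metadata tensor encoding which two of each four positions were kept.

```python
# Rough accounting for a 128x128 fp16 weight (illustrative numbers only):
dense_bytes = 128 * 128 * 2                  # 32768 bytes for the dense weight
kept_values_bytes = 128 * 64 * 2             # 16384 bytes: half the values survive
metadata_bytes = 128 * 64 * 2 // 8           # 2048 bytes, assuming 2 bits per kept value
print(dense_bytes, kept_values_bytes + metadata_bytes)  # 32768 18432
```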

Since we currently don't have a way to go back from the sparse
representation to the dense representation, and we store the weights in
compressed form, we don't have a great way to handle .t().

Instead, we keep track of how many times transpose has been called on
our tensor and throw an error if the count is unexpected. When the first
argument is sparse, we expect an even number of calls to transpose,
while when the second argument is sparse, we expect an odd number of
calls. This is because we support second-argument-sparse matrix
multiplication via transpose identities, e.g. dense @ sparse = (sparse.t() @ dense.t()).t().
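
A minimal sketch of this bookkeeping (illustrative only; the class name, fields, and the elided kernel call are placeholders, not the actual SparseSemiStructuredTensor implementation):

```python
import torch

class _TransposeCountingSketch(torch.Tensor):
    """Toy wrapper subclass that only tracks transpose parity."""

    @staticmethod
    def __new__(cls, compressed, shape, num_transposes=0):
        # Wrapper subclass: the 2:4 compressed values live in `compressed`,
        # the wrapper itself only advertises the logical shape/dtype/device.
        return torch.Tensor._make_wrapper_subclass(
            cls, shape, dtype=compressed.dtype, device=compressed.device
        )

    def __init__(self, compressed, shape, num_transposes=0):
        self.compressed = compressed
        self.num_transposes = num_transposes

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        if func is torch.ops.aten.t.default:
            (a,) = args
            # .t() just bumps the counter; the compressed payload is untouched.
            return cls(a.compressed, a.shape[::-1], a.num_transposes + 1)
        if func is torch.ops.aten.mm.default:
            a, b = args
            if isinstance(a, cls) and a.num_transposes % 2 != 0:
                raise RuntimeError("sparse @ dense: expected an even transpose count")
            if isinstance(b, cls) and b.num_transposes % 2 != 1:
                raise RuntimeError("dense @ sparse: expected an odd transpose count")
            raise NotImplementedError("sparse kernel call elided in this sketch")
        raise NotImplementedError(f"{func} is not handled in this sketch")
```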

to_sparse_semi_structured

This is a conversion function that turns a dense tensor and a
semi-structured sparse bool mask into an instance of the subclass.
Currently, we must pass in a bool mask, since we can't infer it: there
may be additional zero elements in the dense tensor, so tensor != 0 is
not necessarily 2:4 sparse.

Once we add either a method to derive the mask from the dense tensor or
cuSPARSELt support, we will no longer need to pass in the mask;
cuSPARSELt has its own helper functions to create the metadata mask.
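
For illustration (toy values, not from the PR), a valid 2:4 mask keeps exactly two elements in every group of four, which `weight != 0` does not guarantee:

```python
import torch

weight = torch.tensor([[1.0, 0.0, 2.0, 0.0],
                       [0.0, 0.0, 0.0, 4.0]])

# weight != 0 keeps 2 elements in row 0 but only 1 in row 1,
# so it cannot be used as a 2:4 mask directly.
print((weight != 0).sum(dim=-1))  # tensor([2, 1])

# A proper 2:4 mask keeps exactly 2 of every 4 elements.
mask = torch.tensor([[True, False, True, False],
                     [False, True, False, True]])
assert mask.view(-1, 4).sum(dim=-1).eq(2).all()
```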

User Details

We have implemented support for the following ops for torch.float16
and torch.int8:

torch.addmm(bias, dense, sparse.t())
torch.mm(dense, sparse)
torch.mm(sparse, dense)
aten.linear.default
aten.t.default
aten.t.detach

The end-user interface to accelerate an nn.Linear module with the
subclass looks like this:

import torch
from torch import nn
from torch.sparse import to_sparse_semi_structured

# 2:4 bool mask: two nonzero positions in every group of four along each row
mask = torch.Tensor([0, 0, 1, 1]).tile(128, 32).cuda().bool()
linear = nn.Linear(128, 128).half().cuda()

linear.weight = nn.Parameter(to_sparse_semi_structured(linear.weight,
                                                       mask=mask))
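
Continuing the snippet above, a regular forward pass then goes through the subclass's overridden ops (aten.linear.default in the list above); the batch size here is chosen arbitrarily for illustration:

```python
x = torch.rand(64, 128, dtype=torch.float16, device="cuda")
out = linear(x)  # dispatches to the 2:4 sparse kernels via __torch_dispatch__
```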

This also updates tests and the torch.sparse module docstring to
reflect these changes.

cc @alexsamardzic @nikitaved @pearu @cpuhrsch @amjames @bhosmer

This PR integrates cuSPARSELt v0.4.0.7 into pytorch.

It is composed of two elements:
1. A torch custom class that is used to store and manage the cusparselt
constructs needed to do sparse matrix multiplication
2. A tensor subclass that overrides the dispatch of torch.t(),
torch.mm() and torch.addmm() to use the cuSPARSELt sparse matmul and
also stores the custom class state.

For performance and memory-overhead reasons, we'd like to cache the
descriptors and the compressed matrix used by cuSPARSELt. However, this
makes things a bit tricky, since it means there is some state that we
have to manage.

Previously, we were holding this state in a cuSPARSELtLinear module and
swapping that module with nn.Linear. This works fine for Linear, since
the forward() function is just an addmm, but doesn't work great when
expanding to modules that have a more complicated forward() function,
since we need to copy over all the custom logic.

With tensor subclasses, we can store the state on the tensor itself, and
then at dispatch time retrieve it from the tensor. This essentially
defines a custom matmul function for each tensor.

Additionally, cuSPARSELt matmul is conceptually closer to
torch.addmm/torch.mm, so it makes more sense to do the replacement at
that level. It also leads to a cleaner UX: previously a user had to run
our pruning flow end-to-end in order to use `convert`; now all they have
to do is get their weights into a 2:4 dense format (with zeros), and
accelerated inference is just:

```
from torch.ao.pruning import SemiStructuredSparseTensor

model = Model()
model.linear.weight = nn.Parameter(SemiStructuredSparseTensor(model.linear.weight))
```

Our pruning flow has functionality to get the weights in this format by
using `pruner.squash_mask()`
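
For reference, one common way to put a dense weight into that 2:4 format (a hedged sketch using plain magnitude selection, not the pruning-flow code) is to keep the two largest-magnitude values in every group of four:

```python
import torch

w = torch.randn(128, 128)
groups = w.reshape(-1, 4)                       # groups of 4 along each row
idx = groups.abs().topk(2, dim=-1).indices      # 2 largest magnitudes per group
mask = torch.zeros_like(groups, dtype=torch.bool)
mask.scatter_(-1, idx, torch.ones_like(idx, dtype=torch.bool))
w_24 = (groups * mask).reshape_as(w)            # 2:4 dense format (with 0s)
assert mask.sum(dim=-1).eq(2).all()
```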

I've also added an additional `contiguous_output` flag which lets us fuse
a subsequent transpose operation into the cuSPARSELt matmul call. This
is especially useful in distributed settings, since the output of our
cuSPARSELt matmul is transposed and that messes up the collect/gather.

With contiguous_output set to True, the output will be contiguous,
meaning it should perfectly match F.linear and can be used as a drop-in
replacement in distributed settings. You can see an example of how to
use it below.

```
from torch.ao.pruning import SemiStructuredSparseTensor

model = Model()
model.linear.weight = nn.Parameter(
    SemiStructuredSparseTensor(model.linear.weight, contiguous_output=True)
)
```

However, there are still some kinks in the workflow that we need to
resolve.

1. Integration with CUTLASS - sometimes CUTLASS can be faster and
supports more advanced epilogue fusions, like SwiGLU, so we need some
way of deciding when to dispatch to CUTLASS vs. cuSPARSELt.

2. No way to go from the compressed matrix back to a dense matrix -
currently this means we do not fully support .t() on
SemiStructuredSparseTensors, since we can't transpose the compressed
form. We need some functionality to go back to the dense representation
from the compressed form. CUTLASS has this ability and we may be able to
reuse their implementation if they share the same meta/mask layout.

3. Bias propagating the wrong way - when the sparse matrix is the first
argument to addmm, the bias is propagated column-wise instead of
row-wise. Note that this is fine when the second matrix is dense, since
we transpose the result anyway.

4. Padding - currently cuSPARSELt only supports dimensions that are
multiples of 16, 8, or 4 (depending on dtype), so we should add padding
so that we can satisfy this constraint for all matrices.
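
A sketch of what such padding could look like (an illustrative helper, not part of this PR): right-pad the inner dimension with zeros up to the required multiple before the matmul.

```python
import torch
import torch.nn.functional as F

def pad_last_dim_to_multiple(x: torch.Tensor, multiple: int) -> torch.Tensor:
    """Right-pad the last dimension of `x` with zeros to a multiple of `multiple`."""
    rem = x.size(-1) % multiple
    if rem == 0:
        return x
    return F.pad(x, (0, multiple - rem))  # (left, right) padding of the last dim

x = torch.randn(7, 70)
print(pad_last_dim_to_multiple(x, 8).shape)  # torch.Size([7, 72])
```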

dtypes supported:
- int8
- fp16
- bf16
- fp32

Ops supported:
```
torch.addmm(bias, dense, sparse)
torch.addmm(bias, sparse, dense)
torch.mm(dense, sparse)
torch.mm(sparse, dense)
```
stack-source-id: a683d4725f1c7af0f59f9a35eaec76e3a9cfd265
Pull Request resolved: #97546

[ghstack-poisoned]

@pytorch-bot (bot) commented May 24, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/102135

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit b5d528e:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

jcaip added a commit that referenced this pull request May 24, 2023
jcaip added a commit that referenced this pull request May 26, 2023
jcaip added a commit that referenced this pull request May 27, 2023
jcaip added a commit that referenced this pull request May 27, 2023
jcaip added a commit that referenced this pull request May 30, 2023
@jcaip changed the title from "[core][pruning][feature] cuSPARSELT bindings + tensor subclass" to "core][pruning][feature] cuSPARSELT bindings + tensor subclass" on May 30, 2023
jcaip added a commit that referenced this pull request May 30, 2023
…lass"

This PR integrates cuSPARSELt v0.4.0.7 into pytorch.

It is composed of two elements:
1. A torch custom class that is used to store and manage the cusparselt
constructs needed to do sparse matrix multiplication
2. A tensor subclass that overrides the dispatch of torch.t(),
torch.mm() and torch.addmm() to use the cusparselt sparse matmul and
also store the custom class state.

For performance and memory overhead reasons, we'd like to cache the
descriptors and
compressed matrix that are used in cusparselt. However this makes it a
bit tricky, since this means there's some state that we have to manage.

Previously, we were holding this state in a cuSPARSELtLinear module and
swapping that module with nn.Linear. This works fine for Linear, since
the forward() function is just an addmm, but doesn't work great when
expanding to modules that have a more complicated forward() function,
since we need to copy over all the custom logic.

With tensor subclasses, we can store the state on the tensor itself, and
then at dispatch time retrieve it from the tensor. This essentially
defines a custom matmul function for each tensor.

Additionally, cusparselt matmul is conceptually closer to
torch.addmm/torch.mm, so it makes more sense to do the replacement at
that level. It also leads to a cleaner UX: previously a user had to run
our pruning flow end-to-end in order to utilize `convert`; now all they
have to do to get accelerated inference is get their weights into a 2:4
dense format (with 0s) and then run:

```
from torch.sparse import SemiStructuredSparseTensor
from torch.sparse import to_semi_structured_sparse_tensor
import torch.nn as nn

model = Model()  # placeholder for any module with a linear layer

model.linear.weight = nn.Parameter(
    to_semi_structured_sparse_tensor(model.linear.weight)
)

```

Our pruning flow has functionality to get the weights in this format by
using `pruner.squash_mask()`.

I've also added an additional `fuse_transpose` flag which lets us fuse
a subsequent transpose operation into the cusparselt matmul call. This
is especially useful for distributed settings, since the output of our
cusparselt matmul is transposed, which messes up the collect/gather.

With fuse_transpose set to True, the output will be contiguous,
meaning it should perfectly match F.linear and can be used as a drop-in
replacement in distributed settings. You can see an example of how to
use it below.

```
from torch.sparse import SemiStructuredSparseTensor
from torch.sparse import to_semi_structured_sparse_tensor
import torch.nn as nn

SemiStructuredSparseTensor.fuse_transpose = True

model = Model()  # placeholder for any module with a linear layer

model.linear.weight = nn.Parameter(
    to_semi_structured_sparse_tensor(model.linear.weight)
)

```
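
A hedged way to check the drop-in claim, continuing the snippet above and
assuming `dense_weight` holds a copy of the original 2:4 dense weight
(both `model` and `dense_weight` are assumptions carried over from the
example, and a CUDA device is assumed):

```
import torch
import torch.nn.functional as F

x = torch.randn(64, model.linear.in_features, dtype=torch.float16, device="cuda")
sparse_out = F.linear(x, model.linear.weight, model.linear.bias)
dense_out = F.linear(x, dense_weight, model.linear.bias)
# With fuse_transpose=True the sparse path should line up with the dense
# reference, up to fp16 numerics.
assert sparse_out.is_contiguous()
assert torch.allclose(sparse_out, dense_out, atol=1e-2)
```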

dtypes supported:
```
- int8
- fp16
- bf16
- fp32
```

ops supported:
```
torch.addmm(bias, dense, sparse)
torch.addmm(bias, sparse, dense)
torch.mm(dense, sparse)
torch.mm(sparse, dense)
aten.linear.default
aten.t.default
aten.t.detach
```

[ghstack-poisoned]
@jcaip jcaip changed the title core][pruning][feature] cuSPARSELT bindings + tensor subclass [core][pruning][feature] cuSPARSELT bindings + tensor subclass May 30, 2023
jcaip added a commit that referenced this pull request May 30, 2023
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@huydhn
Contributor

huydhn commented Jun 27, 2023

@pytorchbot revert -m 'test_sparse_semi_structured.py::TestSparseSemiStructuredCUDA::test_mm_sparse_first_NT_cuda_int8 is still failing CUDA trunk jobs https://hud.pytorch.org/pytorch/pytorch/commit/aea771de30427998e83010459b69da1ab66f0879' -c landrace

Sorry for being unclear. As this is a landrace, you would need to rebase your PR to surface the error and fix the issue before trying to reland the change. Also, @pytorchbot merge -r would be useful here to rebase and merge in one go.

@huydhn huydhn reopened this Jun 27, 2023
@pytorchmergebot
Collaborator

@pytorchbot successfully started a revert job. Check the current status here.
Questions? Feedback? Please reach out to the PyTorch DevX Team

@pytorchmergebot
Collaborator

@jcaip your PR has been successfully reverted.

pytorchmergebot added a commit that referenced this pull request Jun 27, 2023
…subclass (#102135)"

This reverts commit aea771d.

Reverted #102135 on behalf of https://github.com/huydhn due to test_sparse_semi_structured.py::TestSparseSemiStructuredCUDA::test_mm_sparse_first_NT_cuda_int8 is still failing CUDA trunk jobs https://hud.pytorch.org/pytorch/pytorch/commit/aea771de30427998e83010459b69da1ab66f0879 ([comment](#102135 (comment)))
@huydhn
Contributor

huydhn commented Jun 27, 2023

@pytorchbot rebase

@pytorchmergebot
Collaborator

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot
Collaborator

Rebase failed due to

Aborting rebase because rebasing the branch resulted in the same sha as the target branch.
This usually happens because the PR has already been merged.  Please rebase locally and push.

Raised by https://github.com/pytorch/pytorch/actions/runs/5385570068

@huydhn
Contributor

huydhn commented Jun 27, 2023

@pytorchbot rebase

@pytorchmergebot
Collaborator

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot
Collaborator

Rebase failed due to

Aborting rebase because rebasing the branch resulted in the same sha as the target branch.
This usually happens because the PR has already been merged.  Please rebase locally and push.

Raised by https://github.com/pytorch/pytorch/actions/runs/5385648408

@huydhn
Contributor

huydhn commented Jun 27, 2023

Rebase failed due to

Aborting rebase because rebasing the branch resulted in the same sha as the target branch.
This usually happens because the PR has already been merged.  Please rebase locally and push.

Raised by https://github.com/pytorch/pytorch/actions/runs/5385648408

Oh well, please do a rebase locally and push on your end then. This is probably ghstack-related.

…or subclass"

This PR adds in support for semi-structured sparsity via a tensor
subclass. It currently uses the CUTLASS kernels merged in PR #100881.

In the future we plan to add in cuSPARSELt support (see the other PRs in
the stack), which will give us larger performance gains.

This PR adds in 2 things:
- a Tensor subclass, `SparseSemiStructuredTensor`, to store the
  sparse tensor in compressed form and override `__torch_dispatch__`.
- a conversion function that takes in a dense tensor and a
  semi-structured sparse bool mask and creates an instance of the
  subclass.

**SparseSemiStructuredTensor**

The subclass stores the dense tensor in a contiguous flattened tensor
for future compatibility with cuSPARSELt, which expects this format.
Note that the CUTLASS kernels do not have this limitation, as the
specified values and the metadata are passed separately in
`_structured_sparse_linear`. In the future we can use the cuSPARSELt bindings
[here](#103700) for faster matmul, better dtype coverage, and relaxed shape
constraints.

Since we currently don't have a way to go back from the sparse
representation to the dense representation, and we store the weights in
compressed form, we don't have a great way to handle .t().

Instead, we keep track of how often we've called transpose on our
tensor, and if it's an unexpected number we throw an error. When the first
argument is sparse, we expect an even number of calls to transpose,
while when the second argument is sparse, we expect an odd number of
calls. This is because we support second argument sparse matrix
multiplications by using transpose properties.
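
As a small illustration of that bookkeeping (the function and argument
names here are made up, not the subclass's actual internals), the parity
rule amounts to:

```
def check_transpose_parity(sparse_position, num_transposes):
    # sparse_position: 0 if the sparse operand is the first matmul argument,
    # 1 if it is the second.
    expected_parity = 0 if sparse_position == 0 else 1
    if num_transposes % 2 != expected_parity:
        raise RuntimeError(
            f"unexpected number of .t() calls ({num_transposes}) "
            f"for a sparse operand in position {sparse_position}"
        )

check_transpose_parity(0, 0)  # mm(sparse, dense): no transpose, ok
check_transpose_parity(1, 1)  # mm(dense, sparse.t()): one transpose, ok
```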

**to_sparse_semi_structured**

This is a conversion function to convert a dense tensor and a
semi-structured sparse bool mask into a subclass. Currently, we must
pass in a bool mask, since we can't infer it because there may be
additional zero elements in the dense tensor, so `tensor != 0` is not
guaranteed to be 2:4 sparse.

Once we add either a method to derive the mask from the dense tensor or
cuSPARSELt, we no longer need to pass in the mask. cuSPARSELt has its
own helper functions to create the metadata mask.
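
Until such a helper exists, one hedged, purely illustrative way to derive
a valid 2:4 bool mask from a dense weight is magnitude-based: keep the two
largest-magnitude entries in each group of four along a row. The helper
name and the row-wise grouping are assumptions, not necessarily the layout
the kernels require.

```
import torch

def make_24_mask(weight):
    # Keep the 2 largest-magnitude values in every contiguous group of 4 along the last dim.
    rows, cols = weight.shape
    assert cols % 4 == 0, "last dim must be divisible by 4 for 2:4 sparsity"
    groups = weight.abs().reshape(rows, cols // 4, 4)
    topk = groups.topk(2, dim=-1).indices
    mask = torch.zeros(rows, cols // 4, 4)
    mask.scatter_(-1, topk, 1.0)
    return mask.bool().reshape(rows, cols)

w = torch.randn(128, 128)
mask = make_24_mask(w)
assert mask.sum() == w.numel() // 2  # exactly 2 of every 4 entries kept
```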

**User Details**

We have implemented support for the following ops for `torch.float16`
and `torch.int8`:
```
torch.addmm(bias, dense, sparse.t())
torch.mm(dense, sparse)
torch.mm(sparse, dense)
aten.linear.default
aten.t.default
aten.t.detach
```

The end user interface to accelerate an nn.Linear module with the
subclass would look like this:

```
import torch
import torch.nn as nn
from torch.sparse import to_sparse_semi_structured

mask = torch.Tensor([0, 0, 1, 1]).tile(128, 32).cuda().bool()
linear = nn.Linear(128, 128).half().cuda()

linear.weight = nn.Parameter(
    to_sparse_semi_structured(linear.weight, mask=mask)
)

```
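
Continuing that snippet, a forward pass would then be expected to route
through the sparse kernels transparently (the input shape and the use of
inference_mode are illustrative assumptions):

```
x = torch.randn(64, 128, dtype=torch.float16, device="cuda")
with torch.inference_mode():
    y = linear(x)  # dispatches through aten.linear.default into the sparse matmul
print(y.shape)     # torch.Size([64, 128])
```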

This also updates tests and the `torch.sparse` module docstring to
reflect these changes.

cc alexsamardzic nikitaved pearu cpuhrsch amjames bhosmer

[ghstack-poisoned]
jcaip added a commit that referenced this pull request Jun 27, 2023
@jcaip
Contributor Author

jcaip commented Jun 27, 2023

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here
