[core][pruning][sparse][feature] SparseSemiStructured tensor subclass #102135
Conversation
This PR integrates cuSPARSELt v0.4.0.7 into PyTorch. It is composed of two elements:

1. A torch custom class that stores and manages the cuSPARSELt constructs needed for sparse matrix multiplication.
2. A tensor subclass that overrides the dispatch of torch.t(), torch.mm(), and torch.addmm() to use the cuSPARSELt sparse matmul, and also stores the custom class state.

For performance and memory-overhead reasons, we'd like to cache the descriptors and compressed matrix used by cuSPARSELt. This makes things a bit tricky, since it means there is state we have to manage. Previously we held this state in a cuSPARSELtLinear module and swapped that module in for nn.Linear. This works fine for Linear, since its forward() is just an addmm, but it doesn't work great when expanding to modules with more complicated forward() functions, since we would need to copy over all of their custom logic.

With tensor subclasses, we can store the state on the tensor itself and retrieve it at dispatch time (a minimal sketch of this mechanism is shown after this description). This essentially defines a custom matmul function for each tensor. Additionally, cuSPARSELt matmul is conceptually closer to torch.addmm/torch.mm, so it makes more sense to do the replacement at that level. It also leads to a cleaner UX: previously a user had to run our pruning flow end to end in order to use `convert`; now all they have to do is get their weights into a 2:4 dense format (with 0s), and accelerated inference is as simple as:

```
from torch.ao.pruning import SemiStructuredSparseTensor

model = Model()
model.linear.weight = nn.Parameter(SemiStructuredSparseTensor(model.linear.weight))
```

Our pruning flow can produce weights in this format via `pruner.squash_mask()`.

I've also added an additional `contiguous_output` flag, which lets us fuse a subsequent transpose operation into the cuSPARSELt matmul call. This is especially useful in distributed settings, since the output of our cuSPARSELt matmul is transposed, and that messes up the collect/gather. With `contiguous_output` set to True, the output will be contiguous, meaning it should perfectly match F.linear and can be used as a drop-in replacement in distributed settings. You can see an example of how to use it below.

```
from torch.ao.pruning import SemiStructuredSparseTensor

model = Model()
model.linear.weight = nn.Parameter(SemiStructuredSparseTensor(model.linear.weight, contiguous_output=True))
```

However, there are some kinks in the workflow we still need to resolve:

1. Integration with CUTLASS - CUTLASS can sometimes be faster and supports more advanced epilogue fusions, like SwiGLU, so we need some way of deciding when to dispatch to CUTLASS vs. cuSPARSELt.
2. No way to go from the compressed matrix back to the sparse matrix - currently this means we do not fully support .t() on SemiStructuredSparseTensors, since we can't transpose the compressed form. We need functionality to recover the dense representation from the compressed form. CUTLASS has this ability, and we may be able to reuse their implementation if they share the same meta/mask layout.
3. Bias propagating the wrong way - when the sparse matrix comes first in addmm, the bias is propagated columnwise instead of rowwise. Note that this is fine when the second matrix is dense, since we transpose the result anyway.
4. Padding - cuSPARSELt currently only supports dimensions that are multiples of 16, 8, or 4 (depending on dtype), so we should add padding to satisfy this constraint for all matrices.

dtypes supported:
- int8
- fp16
- bf16
- fp32

Ops supported:

```
torch.addmm(bias, dense, sparse)
torch.addmm(bias, sparse, dense)
torch.mm(dense, sparse)
torch.mm(sparse, dense)
```

stack-source-id: a683d4725f1c7af0f59f9a35eaec76e3a9cfd265
Pull Request resolved: #97546
[ghstack-poisoned]
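To make the dispatch-override mechanism above concrete, here is a minimal, hypothetical sketch of a wrapper tensor subclass that intercepts `aten.mm` via `__torch_dispatch__` and then falls back to the dense implementation. The class name and the print placeholder are illustrative only; the real SemiStructuredSparseTensor would instead retrieve its cached compressed matrix and descriptors at this point and call the cuSPARSELt kernel.

```python
import torch
from torch.utils._pytree import tree_map

class InterceptMM(torch.Tensor):
    """Illustrative wrapper subclass: holds an inner dense tensor and
    reroutes aten.mm at dispatch time. Not the PR's implementation."""

    @staticmethod
    def __new__(cls, elem):
        # A wrapper subclass carries metadata but no storage of its own;
        # the dense payload lives in self.elem (state on the tensor itself).
        r = torch.Tensor._make_wrapper_subclass(
            cls, elem.shape, dtype=elem.dtype, device=elem.device
        )
        r.elem = elem
        return r

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        def unwrap(t):
            return t.elem if isinstance(t, InterceptMM) else t

        if func is torch.ops.aten.mm.default:
            # The real subclass would call the cuSPARSELt sparse matmul here,
            # using the descriptors/compressed matrix cached on the tensor.
            print("intercepted aten.mm, falling back to dense matmul")
        return func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs or {}))

a = InterceptMM(torch.randn(4, 8))
b = torch.randn(8, 2)
out = torch.mm(a, b)  # prints the message, result matches dense mm
```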
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/102135

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures as of commit b5d528e. This comment was automatically generated by Dr. CI and updates every 15 minutes.
…lass" This PR integrates cuSPARSELt v0.4.0.7 into pytorch. It is composed of two elements: 1. A torch custom class that is used to store and manage the cusparselt constructs needed to do sparse matrix multiplication 2. A tensor subclass that overrides the dispatch of torch.t(), torch.mm() and torch.addmm() to use the cusparselt sparse matmul and also store the custom class state. For performance and memory overhead reasons, we'd like to cache the descriptors and compressed matrix that are used in cusparselt. However this makes it a bit tricky, since this means there's some state that we have to manage. Previously, we were holding this state in a cuSPARSELtLinear module and swapping that module with nn.Linear. This works fine for Linear, since the forward() function is just an addmm, but doesn't work great when expanding to modules that have a more complicated forward() function, since we need to copy over all the custom logic. With tensor subclasses, we can store the state on the tensor itself, and then at dispatch time retrieve it from the tensor. This essentially defines a custom matmul function for each tensor. Additionally, conceptually cusparselt matmul is closer to torch.addmm/torch.mm so it makes more sense to do the replacement at that level. It also leads to a cleaner UX, where previously a user had to use our pruning flow e2e in order to utilize `convert`, now all they have to do is get their weights into a 2:4 dense format (with 0s) and then all they have to do to get accelerated inference is ``` from torch.sparse import SemiStructuredSparseTensor from torch.sparse import to_semi_structured_sparse_tensor model = Model() model.linear.weight = nn.Parameter(to_semi_structured_sparse_tensor(model.linear.weight)) ``` Our pruning flow has functionality to get the weights in this format by using `pruner.squash_mask()` I've also added an addtional `fuse_transpose` flag which lets us fuse a subsequent transpose operation into the cusparselt matmul call. This is especially useful for distributed settings, since the output of our cusparselt matmul is Transposed and that messes up the collect/gather. With fuse_transpose set to True, the output will be contiguous, meaning it should perfectly match F.linear and can be used as a drop-in replacement in distributed settings. You can see an example of how to use it below. ``` from torch.sparse import SemiStructuredSparseTensor from torch.sparse import to_semi_structured_sparse_tensor SemiStructuredSparseTensor.fuse_transpose = True model = Model() model.linear.weight = nn.Parameter(to_semi_structured_sparse_tensor(model.linear.weight)) ``` dtypes supported: ``` - int8 - fp16 - bf16 - fp32 ``` ops supported: ``` torch.addmm(bias, dense, sparse) torch.addmm(bias, sparse, dense) torch.mm(dense, sparse) torch.mm(sparse, dense) aten.linear.default aten.t.default aten.t.detach ``` [ghstack-poisoned]
…lass" This PR integrates cuSPARSELt v0.4.0.7 into pytorch. It is composed of two elements: 1. A torch custom class that is used to store and manage the cusparselt constructs needed to do sparse matrix multiplication 2. A tensor subclass that overrides the dispatch of torch.t(), torch.mm() and torch.addmm() to use the cusparselt sparse matmul and also store the custom class state. For performance and memory overhead reasons, we'd like to cache the descriptors and compressed matrix that are used in cusparselt. However this makes it a bit tricky, since this means there's some state that we have to manage. Previously, we were holding this state in a cuSPARSELtLinear module and swapping that module with nn.Linear. This works fine for Linear, since the forward() function is just an addmm, but doesn't work great when expanding to modules that have a more complicated forward() function, since we need to copy over all the custom logic. With tensor subclasses, we can store the state on the tensor itself, and then at dispatch time retrieve it from the tensor. This essentially defines a custom matmul function for each tensor. Additionally, conceptually cusparselt matmul is closer to torch.addmm/torch.mm so it makes more sense to do the replacement at that level. It also leads to a cleaner UX, where previously a user had to use our pruning flow e2e in order to utilize `convert`, now all they have to do is get their weights into a 2:4 dense format (with 0s) and then all they have to do to get accelerated inference is ``` from torch.sparse import SemiStructuredSparseTensor from torch.sparse import to_semi_structured_sparse_tensor model = Model() model.linear.weight = nn.Parameter(to_semi_structured_sparse_tensor(model.linear.weight)) ``` Our pruning flow has functionality to get the weights in this format by using `pruner.squash_mask()` I've also added an addtional `fuse_transpose` flag which lets us fuse a subsequent transpose operation into the cusparselt matmul call. This is especially useful for distributed settings, since the output of our cusparselt matmul is Transposed and that messes up the collect/gather. With fuse_transpose set to True, the output will be contiguous, meaning it should perfectly match F.linear and can be used as a drop-in replacement in distributed settings. You can see an example of how to use it below. ``` from torch.sparse import SemiStructuredSparseTensor from torch.sparse import to_semi_structured_sparse_tensor SemiStructuredSparseTensor.fuse_transpose = True model = Model() model.linear.weight = nn.Parameter(to_semi_structured_sparse_tensor(model.linear.weight)) ``` dtypes supported: ``` - int8 - fp16 - bf16 - fp32 ``` ops supported: ``` torch.addmm(bias, dense, sparse) torch.addmm(bias, sparse, dense) torch.mm(dense, sparse) torch.mm(sparse, dense) aten.linear.default aten.t.default aten.t.detach ``` [ghstack-poisoned]
This PR integrates cuSPARSELt v0.4.0.7 into pytorch. It is composed of two elements: 1. A torch custom class that is used to store and manage the cusparselt constructs needed to do sparse matrix multiplication 2. A tensor subclass that overrides the dispatch of torch.t(), torch.mm() and torch.addmm() to use the cusparselt sparse matmul and also store the custom class state. For performance and memory overhead reasons, we'd like to cache the descriptors and compressed matrix that are used in cusparselt. However this makes it a bit tricky, since this means there's some state that we have to manage. Previously, we were holding this state in a cuSPARSELtLinear module and swapping that module with nn.Linear. This works fine for Linear, since the forward() function is just an addmm, but doesn't work great when expanding to modules that have a more complicated forward() function, since we need to copy over all the custom logic. With tensor subclasses, we can store the state on the tensor itself, and then at dispatch time retrieve it from the tensor. This essentially defines a custom matmul function for each tensor. Additionally, conceptually cusparselt matmul is closer to torch.addmm/torch.mm so it makes more sense to do the replacement at that level. It also leads to a cleaner UX, where previously a user had to use our pruning flow e2e in order to utilize `convert`, now all they have to do is get their weights into a 2:4 dense format (with 0s) and then all they have to do to get accelerated inference is ``` from torch.sparse import SemiStructuredSparseTensor from torch.sparse import to_semi_structured_sparse_tensor model = Model() model.linear.weight = nn.Parameter(to_semi_structured_sparse_tensor(model.linear.weight)) ``` Our pruning flow has functionality to get the weights in this format by using `pruner.squash_mask()` I've also added an addtional `fuse_transpose` flag which lets us fuse a subsequent transpose operation into the cusparselt matmul call. This is especially useful for distributed settings, since the output of our cusparselt matmul is Transposed and that messes up the collect/gather. With fuse_transpose set to True, the output will be contiguous, meaning it should perfectly match F.linear and can be used as a drop-in replacement in distributed settings. You can see an example of how to use it below. ``` from torch.sparse import SemiStructuredSparseTensor from torch.sparse import to_semi_structured_sparse_tensor SemiStructuredSparseTensor.fuse_transpose = True model = Model() model.linear.weight = nn.Parameter(to_semi_structured_sparse_tensor(model.linear.weight)) ``` dtypes supported: ``` - int8 - fp16 - bf16 - fp32 ``` ops supported: ``` torch.addmm(bias, dense, sparse) torch.addmm(bias, sparse, dense) torch.mm(dense, sparse) torch.mm(sparse, dense) aten.linear.default aten.t.default aten.t.detach ``` ghstack-source-id: bb303a8abaed1284ca124b873951739ce14c9476 Pull Request resolved: #102135
…lass" This PR integrates cuSPARSELt v0.4.0.7 into pytorch. It is composed of two elements: 1. A torch custom class that is used to store and manage the cusparselt constructs needed to do sparse matrix multiplication 2. A tensor subclass that overrides the dispatch of torch.t(), torch.mm() and torch.addmm() to use the cusparselt sparse matmul and also store the custom class state. For performance and memory overhead reasons, we'd like to cache the descriptors and compressed matrix that are used in cusparselt. However this makes it a bit tricky, since this means there's some state that we have to manage. Previously, we were holding this state in a cuSPARSELtLinear module and swapping that module with nn.Linear. This works fine for Linear, since the forward() function is just an addmm, but doesn't work great when expanding to modules that have a more complicated forward() function, since we need to copy over all the custom logic. With tensor subclasses, we can store the state on the tensor itself, and then at dispatch time retrieve it from the tensor. This essentially defines a custom matmul function for each tensor. Additionally, conceptually cusparselt matmul is closer to torch.addmm/torch.mm so it makes more sense to do the replacement at that level. It also leads to a cleaner UX, where previously a user had to use our pruning flow e2e in order to utilize `convert`, now all they have to do is get their weights into a 2:4 dense format (with 0s) and then all they have to do to get accelerated inference is ``` from torch.sparse import SemiStructuredSparseTensor from torch.sparse import to_semi_structured_sparse_tensor model = Model() model.linear.weight = nn.Parameter(to_semi_structured_sparse_tensor(model.linear.weight)) ``` Our pruning flow has functionality to get the weights in this format by using `pruner.squash_mask()` I've also added an addtional `fuse_transpose` flag which lets us fuse a subsequent transpose operation into the cusparselt matmul call. This is especially useful for distributed settings, since the output of our cusparselt matmul is Transposed and that messes up the collect/gather. With fuse_transpose set to True, the output will be contiguous, meaning it should perfectly match F.linear and can be used as a drop-in replacement in distributed settings. You can see an example of how to use it below. ``` from torch.sparse import SemiStructuredSparseTensor from torch.sparse import to_semi_structured_sparse_tensor SemiStructuredSparseTensor.fuse_transpose = True model = Model() model.linear.weight = nn.Parameter(to_semi_structured_sparse_tensor(model.linear.weight)) ``` dtypes supported: ``` - int8 - fp16 - bf16 - fp32 ``` ops supported: ``` torch.addmm(bias, dense, sparse) torch.addmm(bias, sparse, dense) torch.mm(dense, sparse) torch.mm(sparse, dense) aten.linear.default aten.t.default aten.t.detach ``` [ghstack-poisoned]
This PR integrates cuSPARSELt v0.4.0.7 into pytorch. It is composed of two elements: 1. A torch custom class that is used to store and manage the cusparselt constructs needed to do sparse matrix multiplication 2. A tensor subclass that overrides the dispatch of torch.t(), torch.mm() and torch.addmm() to use the cusparselt sparse matmul and also store the custom class state. For performance and memory overhead reasons, we'd like to cache the descriptors and compressed matrix that are used in cusparselt. However this makes it a bit tricky, since this means there's some state that we have to manage. Previously, we were holding this state in a cuSPARSELtLinear module and swapping that module with nn.Linear. This works fine for Linear, since the forward() function is just an addmm, but doesn't work great when expanding to modules that have a more complicated forward() function, since we need to copy over all the custom logic. With tensor subclasses, we can store the state on the tensor itself, and then at dispatch time retrieve it from the tensor. This essentially defines a custom matmul function for each tensor. Additionally, conceptually cusparselt matmul is closer to torch.addmm/torch.mm so it makes more sense to do the replacement at that level. It also leads to a cleaner UX, where previously a user had to use our pruning flow e2e in order to utilize `convert`, now all they have to do is get their weights into a 2:4 dense format (with 0s) and then all they have to do to get accelerated inference is ``` from torch.sparse import SemiStructuredSparseTensor from torch.sparse import to_semi_structured_sparse_tensor model = Model() model.linear.weight = nn.Parameter(to_semi_structured_sparse_tensor(model.linear.weight)) ``` Our pruning flow has functionality to get the weights in this format by using `pruner.squash_mask()` I've also added an addtional `fuse_transpose` flag which lets us fuse a subsequent transpose operation into the cusparselt matmul call. This is especially useful for distributed settings, since the output of our cusparselt matmul is Transposed and that messes up the collect/gather. With fuse_transpose set to True, the output will be contiguous, meaning it should perfectly match F.linear and can be used as a drop-in replacement in distributed settings. You can see an example of how to use it below. ``` from torch.sparse import SemiStructuredSparseTensor from torch.sparse import to_semi_structured_sparse_tensor SemiStructuredSparseTensor.fuse_transpose = True model = Model() model.linear.weight = nn.Parameter(to_semi_structured_sparse_tensor(model.linear.weight)) ``` dtypes supported: ``` - int8 - fp16 - bf16 - fp32 ``` ops supported: ``` torch.addmm(bias, dense, sparse) torch.addmm(bias, sparse, dense) torch.mm(dense, sparse) torch.mm(sparse, dense) aten.linear.default aten.t.default aten.t.detach ``` ghstack-source-id: 6d0ad44bcc38fdbe6de3ce9832c677d63dc0dc2e Pull Request resolved: #102135
…lass" This PR integrates cuSPARSELt v0.4.0.7 into pytorch. It is composed of two elements: 1. A torch custom class that is used to store and manage the cusparselt constructs needed to do sparse matrix multiplication 2. A tensor subclass that overrides the dispatch of torch.t(), torch.mm() and torch.addmm() to use the cusparselt sparse matmul and also store the custom class state. For performance and memory overhead reasons, we'd like to cache the descriptors and compressed matrix that are used in cusparselt. However this makes it a bit tricky, since this means there's some state that we have to manage. Previously, we were holding this state in a cuSPARSELtLinear module and swapping that module with nn.Linear. This works fine for Linear, since the forward() function is just an addmm, but doesn't work great when expanding to modules that have a more complicated forward() function, since we need to copy over all the custom logic. With tensor subclasses, we can store the state on the tensor itself, and then at dispatch time retrieve it from the tensor. This essentially defines a custom matmul function for each tensor. Additionally, conceptually cusparselt matmul is closer to torch.addmm/torch.mm so it makes more sense to do the replacement at that level. It also leads to a cleaner UX, where previously a user had to use our pruning flow e2e in order to utilize `convert`, now all they have to do is get their weights into a 2:4 dense format (with 0s) and then all they have to do to get accelerated inference is ``` from torch.sparse import SemiStructuredSparseTensor from torch.sparse import to_semi_structured_sparse_tensor model = Model() model.linear.weight = nn.Parameter(to_semi_structured_sparse_tensor(model.linear.weight)) ``` Our pruning flow has functionality to get the weights in this format by using `pruner.squash_mask()` I've also added an addtional `fuse_transpose` flag which lets us fuse a subsequent transpose operation into the cusparselt matmul call. This is especially useful for distributed settings, since the output of our cusparselt matmul is Transposed and that messes up the collect/gather. With fuse_transpose set to True, the output will be contiguous, meaning it should perfectly match F.linear and can be used as a drop-in replacement in distributed settings. You can see an example of how to use it below. ``` from torch.sparse import SemiStructuredSparseTensor from torch.sparse import to_semi_structured_sparse_tensor SemiStructuredSparseTensor.fuse_transpose = True model = Model() model.linear.weight = nn.Parameter(to_semi_structured_sparse_tensor(model.linear.weight)) ``` dtypes supported: ``` - int8 - fp16 - bf16 - fp32 ``` ops supported: ``` torch.addmm(bias, dense, sparse) torch.addmm(bias, sparse, dense) torch.mm(dense, sparse) torch.mm(sparse, dense) aten.linear.default aten.t.default aten.t.detach ``` [ghstack-poisoned]
This PR integrates cuSPARSELt v0.4.0.7 into pytorch. It is composed of two elements: 1. A torch custom class that is used to store and manage the cusparselt constructs needed to do sparse matrix multiplication 2. A tensor subclass that overrides the dispatch of torch.t(), torch.mm() and torch.addmm() to use the cusparselt sparse matmul and also store the custom class state. For performance and memory overhead reasons, we'd like to cache the descriptors and compressed matrix that are used in cusparselt. However this makes it a bit tricky, since this means there's some state that we have to manage. Previously, we were holding this state in a cuSPARSELtLinear module and swapping that module with nn.Linear. This works fine for Linear, since the forward() function is just an addmm, but doesn't work great when expanding to modules that have a more complicated forward() function, since we need to copy over all the custom logic. With tensor subclasses, we can store the state on the tensor itself, and then at dispatch time retrieve it from the tensor. This essentially defines a custom matmul function for each tensor. Additionally, conceptually cusparselt matmul is closer to torch.addmm/torch.mm so it makes more sense to do the replacement at that level. It also leads to a cleaner UX, where previously a user had to use our pruning flow e2e in order to utilize `convert`, now all they have to do is get their weights into a 2:4 dense format (with 0s) and then all they have to do to get accelerated inference is ``` from torch.sparse import SemiStructuredSparseTensor from torch.sparse import to_semi_structured_sparse_tensor model = Model() model.linear.weight = nn.Parameter(to_semi_structured_sparse_tensor(model.linear.weight)) ``` Our pruning flow has functionality to get the weights in this format by using `pruner.squash_mask()` I've also added an addtional `fuse_transpose` flag which lets us fuse a subsequent transpose operation into the cusparselt matmul call. This is especially useful for distributed settings, since the output of our cusparselt matmul is Transposed and that messes up the collect/gather. With fuse_transpose set to True, the output will be contiguous, meaning it should perfectly match F.linear and can be used as a drop-in replacement in distributed settings. You can see an example of how to use it below. ``` from torch.sparse import SemiStructuredSparseTensor from torch.sparse import to_semi_structured_sparse_tensor SemiStructuredSparseTensor.fuse_transpose = True model = Model() model.linear.weight = nn.Parameter(to_semi_structured_sparse_tensor(model.linear.weight)) ``` dtypes supported: ``` - int8 - fp16 - bf16 - fp32 ``` ops supported: ``` torch.addmm(bias, dense, sparse) torch.addmm(bias, sparse, dense) torch.mm(dense, sparse) torch.mm(sparse, dense) aten.linear.default aten.t.default aten.t.detach ``` ghstack-source-id: edef56e23332289393bab0f213d0de53920c39de Pull Request resolved: #102135
…lass" This PR integrates cuSPARSELt v0.4.0.7 into pytorch. It is composed of two elements: 1. A torch custom class that is used to store and manage the cusparselt constructs needed to do sparse matrix multiplication 2. A tensor subclass that overrides the dispatch of torch.t(), torch.mm() and torch.addmm() to use the cusparselt sparse matmul and also store the custom class state. For performance and memory overhead reasons, we'd like to cache the descriptors and compressed matrix that are used in cusparselt. However this makes it a bit tricky, since this means there's some state that we have to manage. Previously, we were holding this state in a cuSPARSELtLinear module and swapping that module with nn.Linear. This works fine for Linear, since the forward() function is just an addmm, but doesn't work great when expanding to modules that have a more complicated forward() function, since we need to copy over all the custom logic. With tensor subclasses, we can store the state on the tensor itself, and then at dispatch time retrieve it from the tensor. This essentially defines a custom matmul function for each tensor. Additionally, conceptually cusparselt matmul is closer to torch.addmm/torch.mm so it makes more sense to do the replacement at that level. It also leads to a cleaner UX, where previously a user had to use our pruning flow e2e in order to utilize `convert`, now all they have to do is get their weights into a 2:4 dense format (with 0s) and then all they have to do to get accelerated inference is ``` from torch.sparse import SemiStructuredSparseTensor from torch.sparse import to_semi_structured_sparse_tensor model = Model() model.linear.weight = nn.Parameter(to_semi_structured_sparse_tensor(model.linear.weight)) ``` Our pruning flow has functionality to get the weights in this format by using `pruner.squash_mask()` I've also added an addtional `fuse_transpose` flag which lets us fuse a subsequent transpose operation into the cusparselt matmul call. This is especially useful for distributed settings, since the output of our cusparselt matmul is Transposed and that messes up the collect/gather. With fuse_transpose set to True, the output will be contiguous, meaning it should perfectly match F.linear and can be used as a drop-in replacement in distributed settings. You can see an example of how to use it below. ``` from torch.sparse import SemiStructuredSparseTensor from torch.sparse import to_semi_structured_sparse_tensor SemiStructuredSparseTensor.fuse_transpose = True model = Model() model.linear.weight = nn.Parameter(to_semi_structured_sparse_tensor(model.linear.weight)) ``` dtypes supported: ``` - int8 - fp16 - bf16 - fp32 ``` ops supported: ``` torch.addmm(bias, dense, sparse) torch.addmm(bias, sparse, dense) torch.mm(dense, sparse) torch.mm(sparse, dense) aten.linear.default aten.t.default aten.t.detach ``` [ghstack-poisoned]
This PR integrates cuSPARSELt v0.4.0.7 into pytorch. It is composed of two elements: 1. A torch custom class that is used to store and manage the cusparselt constructs needed to do sparse matrix multiplication 2. A tensor subclass that overrides the dispatch of torch.t(), torch.mm() and torch.addmm() to use the cusparselt sparse matmul and also store the custom class state. For performance and memory overhead reasons, we'd like to cache the descriptors and compressed matrix that are used in cusparselt. However this makes it a bit tricky, since this means there's some state that we have to manage. Previously, we were holding this state in a cuSPARSELtLinear module and swapping that module with nn.Linear. This works fine for Linear, since the forward() function is just an addmm, but doesn't work great when expanding to modules that have a more complicated forward() function, since we need to copy over all the custom logic. With tensor subclasses, we can store the state on the tensor itself, and then at dispatch time retrieve it from the tensor. This essentially defines a custom matmul function for each tensor. Additionally, conceptually cusparselt matmul is closer to torch.addmm/torch.mm so it makes more sense to do the replacement at that level. It also leads to a cleaner UX, where previously a user had to use our pruning flow e2e in order to utilize `convert`, now all they have to do is get their weights into a 2:4 dense format (with 0s) and then all they have to do to get accelerated inference is ``` from torch.sparse import SemiStructuredSparseTensor from torch.sparse import to_semi_structured_sparse_tensor model = Model() model.linear.weight = nn.Parameter(to_semi_structured_sparse_tensor(model.linear.weight)) ``` Our pruning flow has functionality to get the weights in this format by using `pruner.squash_mask()` I've also added an addtional `fuse_transpose` flag which lets us fuse a subsequent transpose operation into the cusparselt matmul call. This is especially useful for distributed settings, since the output of our cusparselt matmul is Transposed and that messes up the collect/gather. With fuse_transpose set to True, the output will be contiguous, meaning it should perfectly match F.linear and can be used as a drop-in replacement in distributed settings. You can see an example of how to use it below. ``` from torch.sparse import SemiStructuredSparseTensor from torch.sparse import to_semi_structured_sparse_tensor SemiStructuredSparseTensor.fuse_transpose = True model = Model() model.linear.weight = nn.Parameter(to_semi_structured_sparse_tensor(model.linear.weight)) ``` dtypes supported: ``` - int8 - fp16 - bf16 - fp32 ``` ops supported: ``` torch.addmm(bias, dense, sparse) torch.addmm(bias, sparse, dense) torch.mm(dense, sparse) torch.mm(sparse, dense) aten.linear.default aten.t.default aten.t.detach ``` ghstack-source-id: e94e61260348e45b5895611b49e7111ad4e091ad Pull Request resolved: #102135
…lass" This PR integrates cuSPARSELt v0.4.0.7 into pytorch. It is composed of two elements: 1. A torch custom class that is used to store and manage the cusparselt constructs needed to do sparse matrix multiplication 2. A tensor subclass that overrides the dispatch of torch.t(), torch.mm() and torch.addmm() to use the cusparselt sparse matmul and also store the custom class state. For performance and memory overhead reasons, we'd like to cache the descriptors and compressed matrix that are used in cusparselt. However this makes it a bit tricky, since this means there's some state that we have to manage. Previously, we were holding this state in a cuSPARSELtLinear module and swapping that module with nn.Linear. This works fine for Linear, since the forward() function is just an addmm, but doesn't work great when expanding to modules that have a more complicated forward() function, since we need to copy over all the custom logic. With tensor subclasses, we can store the state on the tensor itself, and then at dispatch time retrieve it from the tensor. This essentially defines a custom matmul function for each tensor. Additionally, conceptually cusparselt matmul is closer to torch.addmm/torch.mm so it makes more sense to do the replacement at that level. It also leads to a cleaner UX, where previously a user had to use our pruning flow e2e in order to utilize `convert`, now all they have to do is get their weights into a 2:4 dense format (with 0s) and then all they have to do to get accelerated inference is ``` from torch.sparse import SemiStructuredSparseTensor from torch.sparse import to_semi_structured_sparse_tensor model = Model() model.linear.weight = nn.Parameter(to_semi_structured_sparse_tensor(model.linear.weight)) ``` Our pruning flow has functionality to get the weights in this format by using `pruner.squash_mask()` I've also added an addtional `fuse_transpose` flag which lets us fuse a subsequent transpose operation into the cusparselt matmul call. This is especially useful for distributed settings, since the output of our cusparselt matmul is Transposed and that messes up the collect/gather. With fuse_transpose set to True, the output will be contiguous, meaning it should perfectly match F.linear and can be used as a drop-in replacement in distributed settings. You can see an example of how to use it below. ``` from torch.sparse import SemiStructuredSparseTensor from torch.sparse import to_semi_structured_sparse_tensor SemiStructuredSparseTensor.fuse_transpose = True model = Model() model.linear.weight = nn.Parameter(to_semi_structured_sparse_tensor(model.linear.weight)) ``` dtypes supported: ``` - int8 - fp16 - bf16 - fp32 ``` ops supported: ``` torch.addmm(bias, dense, sparse) torch.addmm(bias, sparse, dense) torch.mm(dense, sparse) torch.mm(sparse, dense) aten.linear.default aten.t.default aten.t.detach ``` [ghstack-poisoned]
This PR integrates cuSPARSELt v0.4.0.7 into pytorch. It is composed of two elements: 1. A torch custom class that is used to store and manage the cusparselt constructs needed to do sparse matrix multiplication 2. A tensor subclass that overrides the dispatch of torch.t(), torch.mm() and torch.addmm() to use the cusparselt sparse matmul and also store the custom class state. For performance and memory overhead reasons, we'd like to cache the descriptors and compressed matrix that are used in cusparselt. However this makes it a bit tricky, since this means there's some state that we have to manage. Previously, we were holding this state in a cuSPARSELtLinear module and swapping that module with nn.Linear. This works fine for Linear, since the forward() function is just an addmm, but doesn't work great when expanding to modules that have a more complicated forward() function, since we need to copy over all the custom logic. With tensor subclasses, we can store the state on the tensor itself, and then at dispatch time retrieve it from the tensor. This essentially defines a custom matmul function for each tensor. Additionally, conceptually cusparselt matmul is closer to torch.addmm/torch.mm so it makes more sense to do the replacement at that level. It also leads to a cleaner UX, where previously a user had to use our pruning flow e2e in order to utilize `convert`, now all they have to do is get their weights into a 2:4 dense format (with 0s) and then all they have to do to get accelerated inference is ``` from torch.sparse import SemiStructuredSparseTensor from torch.sparse import to_semi_structured_sparse_tensor model = Model() model.linear.weight = nn.Parameter(to_semi_structured_sparse_tensor(model.linear.weight)) ``` Our pruning flow has functionality to get the weights in this format by using `pruner.squash_mask()` I've also added an addtional `fuse_transpose` flag which lets us fuse a subsequent transpose operation into the cusparselt matmul call. This is especially useful for distributed settings, since the output of our cusparselt matmul is Transposed and that messes up the collect/gather. With fuse_transpose set to True, the output will be contiguous, meaning it should perfectly match F.linear and can be used as a drop-in replacement in distributed settings. You can see an example of how to use it below. ``` from torch.sparse import SemiStructuredSparseTensor from torch.sparse import to_semi_structured_sparse_tensor SemiStructuredSparseTensor.fuse_transpose = True model = Model() model.linear.weight = nn.Parameter(to_semi_structured_sparse_tensor(model.linear.weight)) ``` dtypes supported: ``` - int8 - fp16 - bf16 - fp32 ``` ops supported: ``` torch.addmm(bias, dense, sparse) torch.addmm(bias, sparse, dense) torch.mm(dense, sparse) torch.mm(sparse, dense) aten.linear.default aten.t.default aten.t.detach ``` ghstack-source-id: db88323f8035d45715b6eb21759b18beb091af04 Pull Request resolved: #102135
Merge started. Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
@pytorchbot revert -m 'test_sparse_semi_structured.py::TestSparseSemiStructuredCUDA::test_mm_sparse_first_NT_cuda_int8 is still failing CUDA trunk jobs https://hud.pytorch.org/pytorch/pytorch/commit/aea771de30427998e83010459b69da1ab66f0879' -c landrace

Sorry for being unclear. As this is a landrace, you would need to rebase your PR to surface the error and fix the issue before trying to reland the change.
@pytorchbot successfully started a revert job. Check the current status here.
@jcaip your PR has been successfully reverted.
…subclass (#102135)" This reverts commit aea771d. Reverted #102135 on behalf of https://github.com/huydhn due to test_sparse_semi_structured.py::TestSparseSemiStructuredCUDA::test_mm_sparse_first_NT_cuda_int8 is still failing CUDA trunk jobs https://hud.pytorch.org/pytorch/pytorch/commit/aea771de30427998e83010459b69da1ab66f0879 ([comment](#102135 (comment)))
@pytorchbot rebase
@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here.
Rebase failed due to
Raised by https://github.com/pytorch/pytorch/actions/runs/5385570068
@pytorchbot rebase
@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here.
Rebase failed due to
Raised by https://github.com/pytorch/pytorch/actions/runs/5385648408
Oh well, please do a rebase locally and push on your end then. This is probably ghstack-related.
…or subclass" This PR adds in support for semi-structured sparsity via a tensor subclass. It currently uses the CUTLASS kernels merged in PR #100881. In the future we plan to add in cuSPARSELt support (see the other PRs in the stack), which will give us larger performance gains. This PR adds in 2 things: - a Tensor subclass, `SparseSemiStructuredTensor` to store the sparse tensor in copmressed form and override `__torch_dispatch__`. - a conversion function that takes in a dense tensor and a semi-structured sparse bool mask and creates an instance of the subclass. **SparseSemiStructuredTensor** The subclass stores the dense tensor in a contiguous flattened tensor for future compatability with cuSPARSELt, which expects this format. Note that the CUTLASS kernels do not have this limitation, as the specified values and the metadata are passed separately in `_structured_sparse_linear`. In the future we can use the cuSPARSELT bindings [here](#103700) for faster matmul, better dtype converage, and relaxed shape constraints. Since we currently don't have a way to go back from the sparse representation to the dense representation, and we store the weights in compressed form, we don't have a great way to handle .t(). Instead, we keep track of how often we've called transpose on our tensor, and if it's an unexpected number we throw an error. When the first argument is sparse, we expect an even number of calls to transpose, while when the second argument is sparse, we expect an odd number of calls. This is because we support second argument sparse matrix multiplications by using transpose properties. **to_sparse_semi_structured** This is a conversion function to convert a dense tensor and a semi-structured sparse bool mask into a subclass. Currently, we must pass in a bool mask, since we can't infer it becuase there may be additional zero elements in the dense tensor, so `tensor !=0` is not 2:4 sparse. Once we add either a method to derive the mask from the dense tensor or cuSPARSELt, we no longer need to pass in the mask. cuSPARSELt has it's own helper functions to create the metadata mask. **User Details** We have implemented support for the following ops for `torch.float16` and `torch.int8`: ``` torch.addmm(bias, dense, sparse.t()) torch.mm(dense, sparse) torch.mm(sparse, dense) aten.linear.default aten.t.default aten.t.detach ``` The end user interface to accelerate a nn.Linaer module with the subclass would look like this: ``` from torch.sparse import to_sparse_semi_structured mask = torch.Tensor([0, 0, 1, 1]).tile(128, 32).cuda().bool() linear = Model(128, 128).half().cuda() linear.weight = nn.Parameter(to_sparse_semi_structured(linear.weight, mask=linear.weight.bool()) ``` This also updates tests and the `torch.sparse` module docstring to reflect these changes. cc alexsamardzic nikitaved pearu cpuhrsch amjames bhosmer [ghstack-poisoned]
This PR adds in support for semi-structured sparsity via a tensor subclass. It currently uses the CUTLASS kernels merged in PR #100881. In the future we plan to add in cuSPARSELt support (see the other PRs in the stack), which will give us larger performance gains. This PR adds in 2 things: - a Tensor subclass, `SparseSemiStructuredTensor` to store the sparse tensor in copmressed form and override `__torch_dispatch__`. - a conversion function that takes in a dense tensor and a semi-structured sparse bool mask and creates an instance of the subclass. **SparseSemiStructuredTensor** The subclass stores the dense tensor in a contiguous flattened tensor for future compatability with cuSPARSELt, which expects this format. Note that the CUTLASS kernels do not have this limitation, as the specified values and the metadata are passed separately in `_structured_sparse_linear`. In the future we can use the cuSPARSELT bindings [here](#103700) for faster matmul, better dtype converage, and relaxed shape constraints. Since we currently don't have a way to go back from the sparse representation to the dense representation, and we store the weights in compressed form, we don't have a great way to handle .t(). Instead, we keep track of how often we've called transpose on our tensor, and if it's an unexpected number we throw an error. When the first argument is sparse, we expect an even number of calls to transpose, while when the second argument is sparse, we expect an odd number of calls. This is because we support second argument sparse matrix multiplications by using transpose properties. **to_sparse_semi_structured** This is a conversion function to convert a dense tensor and a semi-structured sparse bool mask into a subclass. Currently, we must pass in a bool mask, since we can't infer it becuase there may be additional zero elements in the dense tensor, so `tensor !=0` is not 2:4 sparse. Once we add either a method to derive the mask from the dense tensor or cuSPARSELt, we no longer need to pass in the mask. cuSPARSELt has it's own helper functions to create the metadata mask. **User Details** We have implemented support for the following ops for `torch.float16` and `torch.int8`: ``` torch.addmm(bias, dense, sparse.t()) torch.mm(dense, sparse) torch.mm(sparse, dense) aten.linear.default aten.t.default aten.t.detach ``` The end user interface to accelerate a nn.Linaer module with the subclass would look like this: ``` from torch.sparse import to_sparse_semi_structured mask = torch.Tensor([0, 0, 1, 1]).tile(128, 32).cuda().bool() linear = Model(128, 128).half().cuda() linear.weight = nn.Parameter(to_sparse_semi_structured(linear.weight, mask=linear.weight.bool()) ``` This also updates tests and the `torch.sparse` module docstring to reflect these changes. ghstack-source-id: 6c5db7b796cfa853809cfa4d27082992054426ea Pull Request resolved: #102135
…or subclass" This PR adds in support for semi-structured sparsity via a tensor subclass. It currently uses the CUTLASS kernels merged in PR #100881. In the future we plan to add in cuSPARSELt support (see the other PRs in the stack), which will give us larger performance gains. This PR adds in 2 things: - a Tensor subclass, `SparseSemiStructuredTensor` to store the sparse tensor in copmressed form and override `__torch_dispatch__`. - a conversion function that takes in a dense tensor and a semi-structured sparse bool mask and creates an instance of the subclass. **SparseSemiStructuredTensor** The subclass stores the dense tensor in a contiguous flattened tensor for future compatability with cuSPARSELt, which expects this format. Note that the CUTLASS kernels do not have this limitation, as the specified values and the metadata are passed separately in `_structured_sparse_linear`. In the future we can use the cuSPARSELT bindings [here](#103700) for faster matmul, better dtype converage, and relaxed shape constraints. Since we currently don't have a way to go back from the sparse representation to the dense representation, and we store the weights in compressed form, we don't have a great way to handle .t(). Instead, we keep track of how often we've called transpose on our tensor, and if it's an unexpected number we throw an error. When the first argument is sparse, we expect an even number of calls to transpose, while when the second argument is sparse, we expect an odd number of calls. This is because we support second argument sparse matrix multiplications by using transpose properties. **to_sparse_semi_structured** This is a conversion function to convert a dense tensor and a semi-structured sparse bool mask into a subclass. Currently, we must pass in a bool mask, since we can't infer it becuase there may be additional zero elements in the dense tensor, so `tensor !=0` is not 2:4 sparse. Once we add either a method to derive the mask from the dense tensor or cuSPARSELt, we no longer need to pass in the mask. cuSPARSELt has it's own helper functions to create the metadata mask. **User Details** We have implemented support for the following ops for `torch.float16` and `torch.int8`: ``` torch.addmm(bias, dense, sparse.t()) torch.mm(dense, sparse) torch.mm(sparse, dense) aten.linear.default aten.t.default aten.t.detach ``` The end user interface to accelerate a nn.Linaer module with the subclass would look like this: ``` from torch.sparse import to_sparse_semi_structured mask = torch.Tensor([0, 0, 1, 1]).tile(128, 32).cuda().bool() linear = Model(128, 128).half().cuda() linear.weight = nn.Parameter(to_sparse_semi_structured(linear.weight, mask=linear.weight.bool()) ``` This also updates tests and the `torch.sparse` module docstring to reflect these changes. cc alexsamardzic nikitaved pearu cpuhrsch amjames bhosmer [ghstack-poisoned]
This PR adds in support for semi-structured sparsity via a tensor subclass. It currently uses the CUTLASS kernels merged in PR #100881. In the future we plan to add in cuSPARSELt support (see the other PRs in the stack), which will give us larger performance gains. This PR adds in 2 things: - a Tensor subclass, `SparseSemiStructuredTensor` to store the sparse tensor in copmressed form and override `__torch_dispatch__`. - a conversion function that takes in a dense tensor and a semi-structured sparse bool mask and creates an instance of the subclass. **SparseSemiStructuredTensor** The subclass stores the dense tensor in a contiguous flattened tensor for future compatability with cuSPARSELt, which expects this format. Note that the CUTLASS kernels do not have this limitation, as the specified values and the metadata are passed separately in `_structured_sparse_linear`. In the future we can use the cuSPARSELT bindings [here](#103700) for faster matmul, better dtype converage, and relaxed shape constraints. Since we currently don't have a way to go back from the sparse representation to the dense representation, and we store the weights in compressed form, we don't have a great way to handle .t(). Instead, we keep track of how often we've called transpose on our tensor, and if it's an unexpected number we throw an error. When the first argument is sparse, we expect an even number of calls to transpose, while when the second argument is sparse, we expect an odd number of calls. This is because we support second argument sparse matrix multiplications by using transpose properties. **to_sparse_semi_structured** This is a conversion function to convert a dense tensor and a semi-structured sparse bool mask into a subclass. Currently, we must pass in a bool mask, since we can't infer it becuase there may be additional zero elements in the dense tensor, so `tensor !=0` is not 2:4 sparse. Once we add either a method to derive the mask from the dense tensor or cuSPARSELt, we no longer need to pass in the mask. cuSPARSELt has it's own helper functions to create the metadata mask. **User Details** We have implemented support for the following ops for `torch.float16` and `torch.int8`: ``` torch.addmm(bias, dense, sparse.t()) torch.mm(dense, sparse) torch.mm(sparse, dense) aten.linear.default aten.t.default aten.t.detach ``` The end user interface to accelerate a nn.Linaer module with the subclass would look like this: ``` from torch.sparse import to_sparse_semi_structured mask = torch.Tensor([0, 0, 1, 1]).tile(128, 32).cuda().bool() linear = Model(128, 128).half().cuda() linear.weight = nn.Parameter(to_sparse_semi_structured(linear.weight, mask=linear.weight.bool()) ``` This also updates tests and the `torch.sparse` module docstring to reflect these changes. ghstack-source-id: 1ebad9ebcf9df1e275f449459aec48ccc3e80639 Pull Request resolved: #102135
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Stack from ghstack (oldest at bottom):
This PR adds in support for semi-structured sparsity via a tensor subclass. It currently uses the CUTLASS kernels merged in PR #100881. In the future we plan to add cuSPARSELt support (see the other PRs in the stack), which will give us larger performance gains.

This PR adds in two things:

- a Tensor subclass, `SparseSemiStructuredTensor`, to store the sparse tensor in compressed form and override `__torch_dispatch__`.
- a conversion function that takes in a dense tensor and a semi-structured sparse bool mask and creates an instance of the subclass.
**SparseSemiStructuredTensor**
The subclass stores the dense tensor in a contiguous flattened tensor for future compatibility with cuSPARSELt, which expects this format. Note that the CUTLASS kernels do not have this limitation, as the specified values and the metadata are passed separately in `_structured_sparse_linear`. In the future we can use the cuSPARSELt bindings [here](#103700) for faster matmul, better dtype coverage, and relaxed shape constraints.
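As a sketch of what "contiguous flattened storage" means here (the helper names and the same-dtype packing are our own illustrative assumptions, not the exact layout cuSPARSELt expects):

```
import torch

def pack(values: torch.Tensor, meta: torch.Tensor) -> torch.Tensor:
    # One contiguous 1-D buffer: specified values followed by metadata.
    # Real metadata has its own dtype and would need reinterpretation
    # (e.g. via .view()) first; same dtype is assumed here for simplicity.
    return torch.cat([values.reshape(-1), meta.reshape(-1)])

def unpack(buf: torch.Tensor, split: int, values_shape, meta_shape):
    # Slice the flat buffer back apart for a kernel that takes the
    # specified values and the metadata separately, like the CUTLASS path.
    return buf[:split].reshape(values_shape), buf[split:].reshape(meta_shape)
```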
Since we currently don't have a way to go back from the sparse representation to the dense representation, and we store the weights in compressed form, we don't have a great way to handle .t(). Instead, we keep track of how often transpose has been called on our tensor, and if it's an unexpected number we throw an error. When the first argument is sparse, we expect an even number of calls to transpose, while when the second argument is sparse, we expect an odd number of calls. This is because we support second-argument sparse matrix multiplications by using transpose properties.
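Here is a toy model of that bookkeeping, under our own naming (this is not the PR's actual code): .t() only flips a counter, and the matmul path checks that the parity matches the side the sparse operand is on:

```
class TransposeBookkeeping:
    # Toy model of the .t() handling described above; names are ours.

    def __init__(self):
        self.num_t_calls = 0

    def t(self):
        # The compressed buffer can't actually be transposed, so .t()
        # only records that a transpose was requested.
        self.num_t_calls += 1
        return self

    def check_before_mm(self, sparse_is_first_arg: bool) -> None:
        # Sparse-first matmuls expect an even count; sparse-second expect
        # odd, because that case is computed as (B.T @ A.T).T == A @ B.
        expects_odd = not sparse_is_first_arg
        if (self.num_t_calls % 2 == 1) != expects_odd:
            raise RuntimeError(
                "unexpected number of transpose calls for 2:4 sparse matmul"
            )
```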
**to_sparse_semi_structured**
This is a conversion function to convert a dense tensor and a semi-structured sparse bool mask into a subclass. Currently, we must pass in a bool mask, since we can't infer it: there may be additional zero elements in the dense tensor, so `tensor != 0` is not necessarily 2:4 sparse. Once we add either a method to derive the mask from the dense tensor or cuSPARSELt, we will no longer need to pass in the mask. cuSPARSELt has its own helper functions to create the metadata mask.
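To make the `tensor != 0` caveat concrete: a bool mask is a valid 2:4 pattern only if every aligned group of four elements along a row selects exactly two positions. A small checker (our own hypothetical helper, not part of the PR) with a failing example:

```
import torch

def is_valid_24_mask(mask: torch.Tensor) -> bool:
    # mask: 2-D bool tensor whose row length is divisible by 4.
    groups = mask.reshape(mask.shape[0], -1, 4)
    return bool((groups.sum(dim=-1) == 2).all())

# A weight with extra zeros yields a mask with fewer than two selected
# positions in some group, so the mask can't be inferred from the weight:
mask = torch.tensor([[1, 0, 1, 0],
                     [1, 0, 0, 0]], dtype=torch.bool)  # second row: 1 of 4 set
print(is_valid_24_mask(mask))  # False
```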
**User Details**
We have implemented support for the following ops for `torch.float16` and `torch.int8`:

```
torch.addmm(bias, dense, sparse.t())
torch.mm(dense, sparse)
torch.mm(sparse, dense)
aten.linear.default
aten.t.default
aten.t.detach
```

The end-user interface to accelerate an nn.Linear module with the subclass would look like this:

```
import torch
from torch import nn
from torch.sparse import to_sparse_semi_structured

mask = torch.Tensor([0, 0, 1, 1]).tile(128, 32).cuda().bool()
linear = nn.Linear(128, 128).half().cuda()
linear.weight = nn.Parameter(to_sparse_semi_structured(linear.weight, mask=mask))
```
This also updates tests and the `torch.sparse` module docstring to reflect these changes.
cc @alexsamardzic @nikitaved @pearu @cpuhrsch @amjames @bhosmer