[core][pruning][sparse][feature] SparseSemiStructured tensor subclass
This PR adds support for semi-structured sparsity via a tensor subclass. It currently uses the CUTLASS kernels merged in PR #100881. In the future we plan to add cuSPARSELt support (see the other PRs in the stack), which will give us larger performance gains.

This PR adds two things:
- A tensor subclass, `SparseSemiStructuredTensor`, which stores the sparse tensor in compressed form and overrides `__torch_dispatch__`.
- A conversion function that takes in a dense tensor and a semi-structured sparse bool mask and creates an instance of the subclass.

The subclass stores the dense tensor in a contiguous flattened tensor for future compatibility with cuSPARSELt, which expects this format. Note that the CUTLASS kernels do not have this limitation, as the specified values and the metadata are passed separately in `_structured_sparse_linear`. In the future we can use the cuSPARSELt bindings [here](#103700) for faster matmul, better dtype coverage, and relaxed shape constraints.

Since we currently don't have a way to go back from the sparse representation to the dense representation, and we store the weights in compressed form, we don't have a great way to handle `.t()`. Instead, we keep track of how often transpose has been called on our tensor, and if it's an unexpected number we throw an error. When the first argument is sparse, we expect an even number of calls to transpose; when the second argument is sparse, we expect an odd number. This is because we support second-argument sparse matrix multiplications by using transpose properties.
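The transpose property in question is the identity `(A @ B).T == B.T @ A.T`: a product with the sparse operand in the second position can be computed by a kernel that only accepts the sparse operand first, by transposing both inputs and the result. A minimal numpy sketch of the identity (illustrative only, not the PR's implementation):

```python
import numpy as np

# Second-argument-sparse matmul via the transpose identity:
#   dense @ sparse == (sparse.T @ dense.T).T
rng = np.random.default_rng(0)
dense = rng.standard_normal((4, 8))
sparse = rng.standard_normal((8, 6))  # stands in for the sparse operand

out_direct = dense @ sparse
out_via_transpose = (sparse.T @ dense.T).T

print(np.allclose(out_direct, out_via_transpose))  # True
```

This is why the expected transpose count differs by parity: the extra `.t()` calls introduced by this rewrite are what the subclass counts.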
We have implemented support for the following ops for `torch.float16` and `torch.int8`:

```
torch.addmm(bias, dense, sparse.t())
torch.mm(dense, sparse)
torch.mm(sparse, dense)
aten.linear.default
aten.t.default
aten.t.detach
```

The end user interface to accelerate an `nn.Linear` module with the subclass would look like this:

```python
from torch.sparse import to_sparse_semi_structured

mask = torch.Tensor([0, 0, 1, 1]).tile(128, 32).cuda().bool()
linear = nn.Linear(128, 128).half().cuda()
linear.weight = nn.Parameter(to_sparse_semi_structured(linear.weight, mask=mask))
```

This also updates tests and the `torch.sparse` module docstring to reflect these changes.

ghstack-source-id: bef558914004d60a149964afcad7a0b87ae199df
Pull Request resolved: #102135
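For intuition about the compressed form the conversion produces, here is a hypothetical numpy sketch of 2:4 semi-structured compression: every group of four elements keeps exactly its two nonzero values, plus their positions within the group as metadata. The `compress_24` helper and the layout below are illustrative assumptions, not the actual CUTLASS or cuSPARSELt storage formats:

```python
import numpy as np

def compress_24(row, mask):
    """Compress a row under a 2:4 semi-structured mask:
    every group of 4 entries has exactly 2 kept values."""
    values, indices = [], []
    for g in range(0, len(row), 4):
        kept = np.flatnonzero(mask[g:g + 4])
        assert len(kept) == 2, "mask must be 2:4 semi-structured"
        values.extend(row[g + kept])
        indices.extend(kept)  # positions within the group (fit in 2 bits each)
    return np.array(values), np.array(indices, dtype=np.uint8)

row = np.array([1.0, 2.0, 0.0, 0.0, 0.0, 3.0, 0.0, 4.0])
mask = row != 0
vals, meta = compress_24(row, mask)
print(vals)  # [1. 2. 3. 4.]
print(meta)  # [0 1 1 3]
```

This halves the storage for the values; note that recovering the dense tensor requires both arrays, which is why the subclass currently cannot convert back to dense form.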