In this example, we show how to build a sparse model from scratch, using a simple MLP.
Along the way, we also explain how users can provide tensors in their own formats.
As a reference, we use the following implementation of a dense MLP:

In [1]:
import torch
class MLP(torch.nn.Module):
    """Dense multilayer perceptron: ``Linear`` layers separated by ``ReLU``.

    Args:
        channel_sizes: feature width at every stage. Each consecutive pair
            ``(channel_sizes[i], channel_sizes[i + 1])`` becomes one
            ``torch.nn.Linear``. For example ``MLP([a, b, c])`` builds
            ``Linear(a, b) -> ReLU -> Linear(b, c)``. No activation follows
            the last ``Linear``. Fewer than two sizes give an empty stack
            that returns its input unchanged.
    """

    def __init__(self, channel_sizes):
        super().__init__()
        modules = []
        widths = list(zip(channel_sizes[:-1], channel_sizes[1:]))
        for position, (fan_in, fan_out) in enumerate(widths):
            # The nonlinearity goes *between* Linear layers, never before the first.
            if position > 0:
                modules.append(torch.nn.ReLU())
            modules.append(torch.nn.Linear(fan_in, fan_out))
        # Sequential(*modules) registers children as "0", "1", ... in list order.
        self.layers = torch.nn.Sequential(*modules)

    def forward(self, input):
        """Run ``input`` through the layer stack and return the result."""
        return self.layers(input)
# Sanity check: push one random batch of 15 samples x 50 features through the
# dense reference network and print the resulting shape.
model = MLP(channel_sizes=[50, 40, 30, 20, 30, 10])
output = model(torch.randn(15, 50))
print(output.shape)

torch.Size([15, 10])


We are going to replace `torch.nn.Linear` with our custom `SparseLinear` module, which will call our sparse implementation of `torch.nn.functional.linear`.

In [2]:
import sten
class SparseLinear(torch.nn.Module):
    """``torch.nn.Linear`` replacement whose weight is held as a sten CSC tensor.

    The weight is sampled with ``torch.randn``. It is then masked by
    ``sten.random_mask_sparsify(..., frac=weight_sparsity)``, converted with
    ``sten.CscTensor.from_dense`` and wrapped in ``sten.SparseParameterWrapper``.
    The bias is an ordinary dense ``torch.nn.Parameter`` sampled with
    ``torch.rand``.

    Each parameter carries a ``grad_fmt`` 4-tuple describing the format of its
    gradient in the backward pass. NOTE(review): the tuple layout is assumed
    to be (sparsifier, intermediate format, sparsifier, final format) per
    sten's conventions; confirm against the sten docs.

    Args:
        input_features: size of the last dimension of the input.
        output_features: size of the last dimension of the output.
        weight_sparsity: forwarded as ``frac`` to the weight mask and as the
            fraction for the weight-gradient ``RandomFractionSparsifier``.
    """

    def __init__(self, input_features, output_features, weight_sparsity):
        super().__init__()
        self.weight_sparsity = weight_sparsity

        # Weight: random dense matrix -> random mask -> CSC storage.
        masked = sten.random_mask_sparsify(
            torch.randn(output_features, input_features), frac=weight_sparsity
        )
        csc_weight = sten.CscTensor.from_dense(masked)
        self.weight = sten.SparseParameterWrapper(csc_weight)
        # Weight gradient: kept whole as a dense tensor, then re-sparsified
        # with the same random fraction and stored as CSC.
        self.weight.grad_fmt = (
            sten.KeepAll(), torch.Tensor,
            sten.RandomFractionSparsifier(self.weight_sparsity), sten.CscTensor,
        )

        # Bias (sampled after the weight, preserving RNG order) and its
        # gradient both stay dense.
        self.bias = torch.nn.Parameter(torch.rand(output_features))
        self.bias.grad_fmt = (
            sten.KeepAll(), torch.Tensor,
            sten.KeepAll(), torch.Tensor,
        )

    def forward(self, input):
        """Return ``linear(input, weight, bias)`` computed via a sten-sparsified op.

        Both the op's output and the gradient flowing back into that output
        are specified as dense ``torch.Tensor`` (``KeepAll`` sparsifiers).
        """

        def dense_spec():
            # A fresh spec per call so out_fmt and grad_out_fmt share no objects.
            return (sten.KeepAll(), torch.Tensor, sten.KeepAll(), torch.Tensor)

        linear_op = sten.sparsified_op(
            orig_op=torch.nn.functional.linear,
            out_fmt=(dense_spec(),),
            grad_out_fmt=(dense_spec(),),
        )
        return linear_op(input, self.weight, self.bias)

The key aspect is the use of `SparseParameterWrapper` to hold the data of sparse tensors. The code above also configures the formats of the gradients that appear in the backward pass. The weight gradient is re-sparsified with `RandomFractionSparsifier` and stored as a `CscTensor`. The bias gradient and the gradient of the layer output are kept dense in this example.

In [3]:
class SparseMLP(torch.nn.Module):
    """Sparse counterpart of ``MLP``: ``SparseLinear`` layers separated by ``ReLU``.

    Args:
        channel_sizes: feature width at every stage. Each consecutive pair
            defines the (in, out) sizes of one ``SparseLinear``. No activation
            follows the last layer.
        weight_sparsity: passed unchanged to every ``SparseLinear``.
    """

    def __init__(self, channel_sizes, weight_sparsity):
        super().__init__()
        stages = []
        widths = list(zip(channel_sizes[:-1], channel_sizes[1:]))
        for position, (fan_in, fan_out) in enumerate(widths):
            # ReLU only between layers, not in front of the first one.
            if position:
                stages.append(torch.nn.ReLU())
            stages.append(SparseLinear(fan_in, fan_out, weight_sparsity))
        # Children are registered as "0", "1", ... in construction order.
        self.layers = torch.nn.Sequential(*stages)

    def forward(self, input):
        """Run ``input`` through the sparse layer stack and return the result."""
        return self.layers(input)

Finally, after replacing `torch.nn.Linear` with `SparseLinear` in the `MLP` implementation (giving `SparseMLP`), we call the model and observe the expected output shape.

In [4]:
# Same layer widths as the dense reference. weight_sparsity=0.8 is handed to
# every SparseLinear layer. The output shape should match the dense model's.
model = SparseMLP(channel_sizes=[50, 40, 30, 20, 30, 10], weight_sparsity=0.8)
output = model(torch.randn(15, 50))
print(output.shape)

torch.Size([15, 10])
