In [1]:
import torch
import sten
import scipy.sparse  # makes scipy.sparse.csc_matrix available (also binds `scipy`)

We start from the dense implementation of $d = (a + b)c$.

In [2]:
a = torch.randn(10, 20, requires_grad=True)
b = torch.randn(10, 20, requires_grad=True)
c = torch.randn(20, 30, requires_grad=True)
grad_d = torch.randn(10, 30)

d = torch.mm(torch.add(a, b), c)
d.backward(grad_d)
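For reference, the backward implementations registered later in this notebook must reproduce the usual chain-rule gradients of this computation. Write $x = a + b$ and $g = \partial L / \partial d$ for the incoming gradient (`grad_d` above), with $a, b, x \in \mathbb{R}^{10 \times 20}$, $c \in \mathbb{R}^{20 \times 30}$ and $g \in \mathbb{R}^{10 \times 30}$:

$$
\begin{aligned}
\frac{\partial L}{\partial x} &= g\,c^{\top} \in \mathbb{R}^{10 \times 20}, &
\frac{\partial L}{\partial c} &= x^{\top} g \in \mathbb{R}^{20 \times 30}, \\
\frac{\partial L}{\partial a} &= \frac{\partial L}{\partial x}, &
\frac{\partial L}{\partial b} &= \frac{\partial L}{\partial x}.
\end{aligned}
$$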

First, we define a custom random fraction sparsifier that behaves the same as `sten.RandomFractionSparsifier`.

In [3]:
class MyRandomFractionSparsifier:
    def __init__(self, fraction):
        # Fraction of elements to zero out (the dropout probability).
        self.fraction = fraction
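The class above only stores a parameter; the actual masking is supplied later through `sten.random_mask_sparsify`. As a rough model of what such a sparsifier does, here is a NumPy sketch that assumes each element is dropped independently with probability `fraction`. The helper `random_fraction_sparsify` is hypothetical, and sten's own implementation may differ in details such as how exactly the fraction is enforced.

```python
import numpy as np


def random_fraction_sparsify(x: np.ndarray, fraction: float, rng: np.random.Generator) -> np.ndarray:
    """Zero out each element of `x` independently with probability `fraction`.

    Hypothetical stand-in for a random fraction sparsifier, not sten's code.
    """
    if not 0.0 <= fraction <= 1.0:
        raise ValueError("fraction must lie in [0, 1]")
    # rng.random draws from [0, 1), so keep is True with probability 1 - fraction.
    keep = rng.random(x.shape) >= fraction
    return np.where(keep, x, 0.0)


rng = np.random.default_rng(0)
x = rng.standard_normal((10, 20))
y = random_fraction_sparsify(x, 0.5, rng)
print("fraction of zeroed entries:", (y == 0).mean())  # close to 0.5 on average
```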

Then we declare a tensor in CSC format that uses the SciPy CSC implementation under the hood.

In [4]:
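class MyCscTensor:
    """Thin wrapper around a scipy.sparse.csc_matrix."""

    def __init__(self, data):
        self.data = data  # scipy.sparse.csc_matrix

    @staticmethod
    def from_dense(tensor):
        return MyCscTensor(scipy.sparse.csc_matrix(tensor))

    def to_dense(self):
        # toarray() returns a plain ndarray, whereas todense() returns np.matrix.
        return torch.from_numpy(self.data.toarray())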
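`MyCscTensor` delegates all storage to SciPy. To make the CSC layout concrete, the sketch below builds the three CSC arrays by hand (non-zero values in column-major order, their row indices, and column pointers) and compares them with what `scipy.sparse.csc_matrix` produces. The helper `dense_to_csc_arrays` is hypothetical and exists only for illustration.

```python
import numpy as np
import scipy.sparse


def dense_to_csc_arrays(x: np.ndarray):
    """Build the (data, indices, indptr) triple of the CSC format by hand."""
    data, indices, indptr = [], [], [0]
    for j in range(x.shape[1]):         # walk the columns left to right
        rows = np.nonzero(x[:, j])[0]    # rows of the non-zeros in column j, ascending
        data.extend(x[rows, j])
        indices.extend(rows)
        indptr.append(len(data))        # column j occupies data[indptr[j]:indptr[j + 1]]
    return np.array(data, dtype=x.dtype), np.array(indices), np.array(indptr)


x = np.array([[1.0, 0.0, 2.0],
              [0.0, 0.0, 3.0],
              [4.0, 5.0, 0.0]])
data, indices, indptr = dense_to_csc_arrays(x)
# data = [1, 4, 5, 2, 3], indices = [0, 2, 2, 0, 1], indptr = [0, 2, 3, 5]
m = scipy.sparse.csc_matrix(x)
print(m.data, m.indices, m.indptr)
```

Column-wise storage makes column slicing and products of the form $x^{\top} g$ cheap, which is what the backward pass of `torch.mm` needs later.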

Now make the result of the addition $a + b$ sparse. To achieve this, we replace the addition operator with its sparse counterpart. For simplicity, we do not use an inline sparsifier, so the operator itself outputs a dense `torch.Tensor` after applying the `KeepAll` sparsifier. We then apply an external random fraction sparsifier with a dropout probability of 0.5 and output the tensor in the newly defined CSC format. The same specification is assigned to the gradient format, but nothing prevents us from applying a different sparsifier and using a different format for the gradient.

In [5]:
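sparse_add = sten.sparsified_op(
    orig_op=torch.add,
    # Each entry: (inline sparsifier, intermediate type, external sparsifier, output type).
    out_fmt=(
        (sten.KeepAll(), torch.Tensor, MyRandomFractionSparsifier(0.5), MyCscTensor),
    ),
    # Same pipeline for the gradient flowing back into the output of the addition.
    grad_out_fmt=(
        (sten.KeepAll(), torch.Tensor, MyRandomFractionSparsifier(0.5), MyCscTensor),
    ),
)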
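The 4-tuple in `out_fmt` describes a small pipeline. The sketch below mimics that data flow with NumPy and SciPy for the configuration chosen here: a `KeepAll` inline sparsifier that leaves the dense result untouched, an external sparsifier that zeroes each entry with probability 0.5, and a final conversion to CSC. All function names are hypothetical; this is a conceptual model of the flow described in the paragraph above, not sten's internal code.

```python
import numpy as np
import scipy.sparse


def keep_all(x):
    # Inline sparsifier slot: KeepAll drops nothing, so the result stays dense.
    return x


def random_fraction(x, fraction, rng):
    # External sparsifier slot: zero each entry independently with probability `fraction`.
    return np.where(rng.random(x.shape) >= fraction, x, 0.0)


def apply_out_fmt(dense_result, fraction, rng):
    """Conceptual flow of one out_fmt entry: (KeepAll, dense, random fraction, CSC)."""
    intermediate = keep_all(dense_result)                  # intermediate type: dense array
    masked = random_fraction(intermediate, fraction, rng)  # external sparsifier
    return scipy.sparse.csc_matrix(masked)                 # output type: CSC matrix


rng = np.random.default_rng(0)
a_np = rng.standard_normal((10, 20))
b_np = rng.standard_normal((10, 20))
out = apply_out_fmt(a_np + b_np, 0.5, rng)
print(out.format, out.shape, out.nnz)  # nnz is roughly half of the 200 entries
```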

Next, we try to use `sparse_add`.

In [6]:
d = torch.mm(sparse_add(a, b), c)



The first error message indicates that the required operator implementation is not registered. Here we register it and call the operator again.

In [7]:
@sten.register_fwd_op_impl(
    operator=torch.add,
    inp=(torch.Tensor, torch.Tensor, None, None),
    out=tuple([(sten.KeepAll, torch.Tensor)]),
)
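def sparse_add_fwd_impl(ctx, inputs, output_sparsifiers):
    input, other, alpha, out = inputs
    # alpha and out are None here (see the `inp` signature); torch.add rejects
    # alpha=None, so fall back to its default of 1.
    alpha = 1 if alpha is None else alpha
    return torch.add(input, other, alpha=alpha, out=out)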
d = torch.mm(sparse_add(a, b), c)
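The errors in this walkthrough come from a lookup that fails when no implementation matches the exact combination of operator, input types and output formats. The following pure-Python sketch models that idea with a decorator-based registry. `register_fwd`, `dispatch_fwd` and `NotRegisteredError` are hypothetical names; sten's real registry and its `sten.DispatchError` are more elaborate.

```python
import operator
from typing import Callable, Dict, Tuple


class NotRegisteredError(LookupError):
    """Hypothetical stand-in for the dispatch error shown in this notebook."""


# Registry keyed by (operator, input types, output formats).
_FWD_IMPLS: Dict[Tuple, Callable] = {}


def register_fwd(op, inp, out):
    """Hypothetical decorator: store `fn` under its exact type signature."""
    def decorator(fn):
        _FWD_IMPLS[(op, inp, out)] = fn
        return fn
    return decorator


def dispatch_fwd(op, inputs, out):
    # Describe each runtime input by its exact type; None arguments stay None,
    # mirroring how `inp=(torch.Tensor, torch.Tensor, None, None)` is written above.
    inp = tuple(None if x is None else type(x) for x in inputs)
    try:
        impl = _FWD_IMPLS[(op, inp, out)]
    except KeyError:
        raise NotRegisteredError(
            f"implementation is not registered. op: {op!r} inp: {inp} out: {out}"
        ) from None
    return impl(inputs)


@register_fwd(operator.add, inp=(int, int), out=(int,))
def add_ints(inputs):
    x, y = inputs
    return x + y


print(dispatch_fwd(operator.add, (2, 3), (int,)))  # 5
try:
    dispatch_fwd(operator.add, (2.0, 3), (int,))  # (float, int) was never registered
except NotRegisteredError as e:
    print(e)
```

Matching on exact types is what forces a separate registration for every new combination, such as the `(MyCscTensor, torch.Tensor)` inputs of `torch.mm` that appear once $a + b$ becomes sparse.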



This time the error reports that the sparsifier implementation is not registered. Let's provide it.

In [8]:
@sten.register_sparsifier_implementation(
    sparsifer=MyRandomFractionSparsifier, inp=torch.Tensor, out=MyCscTensor
)
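def random_fraction_sparsifier_dense_csc(sparsifier, tensor):
    # Randomly zero out a fraction of the dense entries, then store the result
    # in CSC format inside a sten.SparseTensorWrapper.
    return sten.SparseTensorWrapper(
        MyCscTensor.from_dense(
            sten.random_mask_sparsify(tensor, frac=sparsifier.fraction)
        )
    )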
d = torch.mm(sparse_add(a, b), c)



Since $a + b$ is now sparse and is used as an input to `torch.mm`, we need to provide a sparse implementation of `torch.mm` as well.

In [9]:
@sten.register_fwd_op_impl(
    operator=torch.mm,
    inp=(MyCscTensor, torch.Tensor),
    out=tuple([(sten.KeepAll, torch.Tensor)]),
)
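def torch_mm_fwd_impl(ctx, inputs, output_sparsifiers):
    input1, input2 = inputs  # input1: wrapped MyCscTensor, input2: dense torch.Tensor
    ctx.save_for_backward(input1, input2)
    # SciPy sparse @ dense NumPy array yields a dense array. detach() avoids the
    # error torch raises when calling numpy() on a tensor that requires grad.
    output = torch.from_numpy(input1.wrapped_tensor.data @ input2.detach().numpy())
    return output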
d = torch.mm(sparse_add(a, b), c)
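`torch_mm_fwd_impl` relies on SciPy: multiplying a sparse CSC matrix by a dense NumPy array returns a dense array, which `torch.from_numpy` then wraps. The sketch below uses only NumPy and SciPy, with random stand-ins for $a + b$ and $c$, and checks that this sparse-dense product matches the dense one.

```python
import numpy as np
import scipy.sparse

rng = np.random.default_rng(0)

# Dense stand-in for a + b with roughly half of its entries zeroed, and a dense c.
x_dense = rng.standard_normal((10, 20)).astype(np.float32)
x_dense[rng.random(x_dense.shape) < 0.5] = 0.0
c_np = rng.standard_normal((20, 30)).astype(np.float32)

x_csc = scipy.sparse.csc_matrix(x_dense)
d_sparse = x_csc @ c_np   # sparse @ dense -> dense ndarray of shape (10, 30)
d_dense = x_dense @ c_np  # reference dense product

print(type(d_sparse).__name__, d_sparse.shape)
print("max abs difference:", np.max(np.abs(d_sparse - d_dense)))
```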

As expected, the forward pass now runs. The next step is to call the backward pass and see what remains to be implemented there.

In [10]:
d = torch.mm(sparse_add(a, b), c)
try:
    d.backward(grad_d)
except sten.DispatchError as e:
    print(str(e))

Sparse operator implementation is not registered (bwd). op: <built-in method mm of type object at 0x7f033a216ea0> grad_out: (<class 'torch.Tensor'>,) grad_inp: ((<class 'sten.KeepAll'>, <class 'torch.Tensor'>), (<class 'sten.KeepAll'>, <class 'torch.Tensor'>)) inp: (<class '__main__.MyCscTensor'>, <class 'torch.Tensor'>).


Next, we register the backward implementation for `torch.mm`.

In [11]:
@sten.register_bwd_op_impl(
    operator=torch.mm,
    grad_out=(torch.Tensor,),
    grad_inp=(
        (sten.KeepAll, torch.Tensor),
        (sten.KeepAll, torch.Tensor),
    ),
    inp=(MyCscTensor, torch.Tensor),
)
def torch_mm_bwd_impl(ctx, grad_outputs, input_sparsifiers):
    input1, input2 = ctx.saved_tensors
    [grad_output] = grad_outputs
    # For d = x @ c: dL/dx = grad_d @ c^T and dL/dc = x^T @ grad_d.
    grad_input1 = torch.mm(grad_output, input2.T)
    grad_input2 = torch.from_numpy(
        input1.wrapped_tensor.data.transpose() @ grad_output.detach().numpy()
    )
    return grad_input1, grad_input2

d = torch.mm(sparse_add(a, b), c)
try:
    d.backward(grad_d)
except sten.DispatchError as e:
    print(str(e))


Sparse operator implementation is not registered (bwd). op: <built-in method add of type object at 0x7f033a216ea0> grad_out: (<class '__main__.MyCscTensor'>,) grad_inp: ((<class 'sten.KeepAll'>, <class 'torch.Tensor'>), (<class 'sten.KeepAll'>, <class 'torch.Tensor'>), None, None) inp: (<class 'torch.Tensor'>, <class 'torch.Tensor'>, None, None).
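To check that the formulas in `torch_mm_bwd_impl`, $g\,c^{\top}$ and $x^{\top} g$, are right, one can compare them against central finite differences of the scalar loss $L = \sum_{ij} g_{ij} (x c)_{ij}$, whose gradient with respect to $d = x c$ is exactly $g$. The sketch uses NumPy and SciPy in place of PyTorch and small shapes; `mm_bwd` and `numerical_grad` are hypothetical helpers, not part of sten.

```python
import numpy as np
import scipy.sparse


def mm_bwd(x_csc, c, grad_out):
    """Same formulas as torch_mm_bwd_impl, written with NumPy/SciPy only."""
    grad_x = grad_out @ c.T                # dL/dx = g c^T (dense)
    grad_c = x_csc.transpose() @ grad_out  # dL/dc = x^T g (sparse^T @ dense -> dense)
    return grad_x, np.asarray(grad_c)


def numerical_grad(f, arr, eps=1e-6):
    """Central finite differences of the scalar function f at arr."""
    grad = np.zeros_like(arr)
    for idx in np.ndindex(*arr.shape):
        step = np.zeros_like(arr)
        step[idx] = eps
        grad[idx] = (f(arr + step) - f(arr - step)) / (2 * eps)
    return grad


rng = np.random.default_rng(0)
x = rng.standard_normal((4, 5))
x[rng.random(x.shape) < 0.5] = 0.0  # sparse stand-in for a + b
c = rng.standard_normal((5, 3))
g = rng.standard_normal((4, 3))     # stand-in for grad_d


def loss(x_dense, c_dense):
    # L = sum(g * (x c)), so dL/d(x c) = g.
    return float(np.sum(g * (x_dense @ c_dense)))


grad_x, grad_c = mm_bwd(scipy.sparse.csc_matrix(x), c, g)
num_grad_x = numerical_grad(lambda xx: loss(xx, c), x)
num_grad_c = numerical_grad(lambda cc: loss(x, cc), c)
print(np.max(np.abs(grad_x - num_grad_x)), np.max(np.abs(grad_c - num_grad_c)))
```

Because $L$ is linear in each argument, the central differences are exact up to rounding, so the printed discrepancies should be tiny.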


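The remaining error concerns the backward pass of `torch.add`. With the default `alpha` of 1, the gradient with respect to each input equals the incoming gradient, which arrives in `MyCscTensor` format and is converted back to dense: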

In [12]:
@sten.register_bwd_op_impl(
    operator=torch.add,
    grad_out=(MyCscTensor,),
    grad_inp=(
        (sten.KeepAll, torch.Tensor),
        (sten.KeepAll, torch.Tensor),
        None,
        None,
    ),
    inp=(torch.Tensor, torch.Tensor, None, None),
)
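def torch_add_bwd_impl(ctx, grad_outputs, input_sparsifiers):
    [grad_output] = grad_outputs
    # With the default alpha of 1, both inputs receive the incoming gradient,
    # converted from MyCscTensor back to a dense torch.Tensor.
    dense_output = grad_output.wrapped_tensor.to_dense()
    return dense_output, dense_output, None, None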

d = torch.mm(sparse_add(a, b), c)
d.backward(grad_d)

Now the backward pass is also fully functional.
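`torch_add_bwd_impl` hands the same dense gradient to both inputs, which is exact for the default `alpha` of 1. Because `torch.add(input, other, alpha=alpha)` computes $\text{input} + \alpha \cdot \text{other}$, an implementation that honoured a non-default `alpha` would have to scale the second gradient by $\alpha$. A minimal NumPy sketch of that rule, where `add_bwd` is a hypothetical helper and not part of sten:

```python
import numpy as np


def add_bwd(grad_out, alpha=None):
    """Gradients of out = input + alpha * other; alpha=None means torch.add's default of 1."""
    alpha = 1 if alpha is None else alpha
    grad_input = grad_out          # d out / d input is the identity
    grad_other = alpha * grad_out  # d out / d other is alpha times the identity
    return grad_input, grad_other


g = np.arange(6.0).reshape(2, 3)
print(add_bwd(g)[1])             # equals g: the case torch_add_bwd_impl handles
print(add_bwd(g, alpha=2.0)[1])  # 2 * g
```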