This example demonstrates the customization API for defining new sparse tensor formats and sparsifiers. It shows how to register custom operator and sparsifier implementations for specific input and output tensor formats.

In [1]:
import torch
import sten
import scipy.sparse  # exposes scipy.sparse.csc_matrix used below



We start from the dense implementation of $d = (a + b)c$.

In [2]:
a = torch.randn(10, 20, requires_grad=True)
b = torch.randn(10, 20, requires_grad=True)
c = torch.randn(20, 30, requires_grad=True)
grad_d = torch.randn(10, 30)

d = torch.mm(torch.add(a, b), c)
d.backward(grad_d)

First, we define a custom random fraction sparsifier that behaves the same as `sten.RandomFractionSparsifier`. Its implementation is not defined here, because an implementation depends not only on the sparsifier itself but also on the input and output tensor formats. The sparsifier class only needs to store its configurable parameters.

In [3]:
class MyRandomFractionSparsifier:
    def __init__(self, fraction):
        self.fraction = fraction
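To illustrate why the sparsifier class carries only parameters, here is a hypothetical, self-contained registry (not STen's internals) in which implementations are looked up by the triple (sparsifier class, input format, output format). All names in it (`ToyFractionSparsifier`, `ToyDense`, `ToyCsc`, `register_impl`, `find_impl`) are made up for illustration.

```python
# Hypothetical sketch, not STen's internals: implementations are looked up by
# the triple (sparsifier class, input format, output format).
class ToyFractionSparsifier:
    def __init__(self, fraction):
        self.fraction = fraction  # the only state a sparsifier needs


class ToyDense:
    pass


class ToyCsc:
    pass


_IMPLS = {}


def register_impl(sparsifier, inp, out):
    """Store the decorated function under the (sparsifier, inp, out) key."""
    def decorator(fn):
        _IMPLS[(sparsifier, inp, out)] = fn
        return fn
    return decorator


def find_impl(sparsifier, inp, out):
    """Dispatch on the sparsifier's class plus both tensor formats."""
    key = (type(sparsifier), inp, out)
    if key not in _IMPLS:
        raise LookupError(f"Sparsifier implementation is not registered: {key}")
    return _IMPLS[key]


@register_impl(ToyFractionSparsifier, ToyDense, ToyCsc)
def toy_dense_to_csc(sparsifier, tensor):
    return f"CSC with fraction {sparsifier.fraction}"


s = ToyFractionSparsifier(0.5)
print(find_impl(s, ToyDense, ToyCsc)(s, None))  # CSC with fraction 0.5
```

The same sparsifier class can therefore map to different implementations for different format pairs, and an unregistered pair fails at lookup time, much like the `DispatchError` seen later in this notebook.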

Next, we declare a tensor class in CSC format that uses SciPy's CSC implementation (`scipy.sparse.csc_matrix`) under the hood.

In [4]:
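class MyCscTensor:
    def __init__(self, data):
        # data: a scipy.sparse.csc_matrix holding the nonzero values
        self.data = data

    def to_dense(self):
        # toarray() returns a plain ndarray (todense() returns np.matrix)
        return torch.from_numpy(self.data.toarray())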
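To see what `MyCscTensor` wraps, here is a standalone round trip through `scipy.sparse.csc_matrix` on a toy $2 \times 3$ tensor, independent of STen. It shows the three arrays CSC stores and converts back to a dense `torch.Tensor` the way `to_dense` does.

```python
import numpy as np
import scipy.sparse
import torch

# A small dense tensor with three nonzeros.
dense = torch.tensor([[0.0, 1.5, 0.0],
                      [2.0, 0.0, 3.0]])

# CSC keeps only the nonzeros, stored column by column:
#   data    -> nonzero values in column order
#   indices -> row index of each stored value
#   indptr  -> where each column starts in data/indices
csc = scipy.sparse.csc_matrix(dense.numpy())
print(csc.data, csc.indices, csc.indptr)

# Round trip back to a dense torch.Tensor (toarray() gives a plain ndarray).
restored = torch.from_numpy(csc.toarray())
assert torch.equal(restored, dense)
```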

Next, we make the result of the addition $a + b$ sparse by replacing the addition operator with its sparse counterpart. For simplicity, we do not use an inline sparsifier: the `KeepAll` sparsifier leaves the operator's output as a dense `torch.Tensor`. An external random fraction sparsifier with 0.5 dropout probability then sparsifies this tensor and stores it in the newly defined CSC format. The gradient format uses the same specification, although nothing prevents us from applying a different sparsifier and a different format to the gradient.

In [5]:
sparse_add = sten.sparsified_op(
    orig_op=torch.add,
    out_fmt=(
        (sten.KeepAll(), torch.Tensor,
         MyRandomFractionSparsifier(0.5), MyCscTensor),
    ),
    grad_out_fmt=(
        (sten.KeepAll(), torch.Tensor,
         MyRandomFractionSparsifier(0.5), MyCscTensor),
    ),
)
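For intuition, the output pipeline configured above (inline `KeepAll`, then an external random fraction sparsifier, then CSC) can be mimicked with plain PyTorch and SciPy. This is a sketch, not STen's code: `keep_all` and `random_fraction_mask` are hypothetical helpers, and we assume "0.5 dropout probability" means each element is zeroed independently with probability 0.5. New names (`x`, `y`) keep the notebook's `a` and `b` intact.

```python
import scipy.sparse
import torch


def keep_all(t):
    # Inline stage: like sten.KeepAll, leave the dense result unchanged.
    return t


def random_fraction_mask(t, fraction, generator=None):
    # Hypothetical helper: zero each element independently with probability
    # `fraction` (torch.rand draws uniformly from [0, 1)).
    keep = torch.rand(t.shape, generator=generator) >= fraction
    return torch.where(keep, t, torch.zeros_like(t))


gen = torch.Generator().manual_seed(0)  # local RNG, global state untouched
x = torch.randn(10, 20, generator=gen)
y = torch.randn(10, 20, generator=gen)

dense_sum = keep_all(torch.add(x, y))                      # dense torch.Tensor
masked = random_fraction_mask(dense_sum, 0.5, generator=gen)  # ~half zeroed
csc = scipy.sparse.csc_matrix(masked.numpy())              # stored in CSC format
print(csc.nnz, "of", dense_sum.numel(), "entries kept")
```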

Now let's try to use the new operator.

In [6]:
try:
    d = torch.mm(sparse_add(a, b), c)
except sten.DispatchError as e:
    print(str(e))

Sparsifier implementation is not registered:
@sten.register_sparsifier_implementation(
    sparsifier=MyRandomFractionSparsifier, inp=torch.Tensor, out=MyCscTensor
)
def my_sparsifier_implementation(sparsifier, tensor, grad_fmt=None):
    return sparsified_tensor_wrapper


The error tells us that the sparsifier implementation is not registered, and its message even includes a registration template. Let's provide the implementation.

In [7]:
@sten.register_sparsifier_implementation(
    sparsifier=MyRandomFractionSparsifier, inp=torch.Tensor, out=MyCscTensor
)
def torch_tensor_to_csc_random_fraction(sparsifier, tensor, grad_fmt=None):
    # Zero out a random fraction of the dense tensor's elements.
    masked = sten.random_mask_sparsify(tensor, sparsifier.fraction)
    # Store the result in SciPy's CSC format and wrap it so STen can track it.
    return sten.SparseTensorWrapper.wrapped_from_dense(
        MyCscTensor(scipy.sparse.csc_matrix(masked)),
        tensor,
        grad_fmt,
    )

try:
    d = torch.mm(sparse_add(a, b), c)
except sten.DispatchError as e:
    print(str(e))

@sten.register_fwd_op_impl(
    operator=torch.spmm,
    inp=(MyCscTensor, torch.Tensor),  
    out=None,  # default (dense)
)
def my_operator(ctx, inputs, output_sparsifiers):
    return outputs
Fallback to dense implementation.


Since $a + b$ is now sparse and is used as an input to `torch.mm`, we also need to provide a sparse implementation of `torch.mm`. Without one, STen falls back to the dense implementation, as the message above reports.

In [8]:
@sten.register_fwd_op_impl(
    operator=torch.mm,
    inp=(MyCscTensor, torch.Tensor),
    out=[(sten.KeepAll, torch.Tensor)],
)
def torch_mm_fwd_impl(ctx, inputs, output_sparsifiers):
    input1, input2 = inputs
    # Keep both inputs for the backward pass.
    ctx.save_for_backward(input1, input2)
    # Sparse (scipy CSC) @ dense (numpy) yields a dense ndarray.
    output = torch.from_numpy(input1.wrapped_tensor.data @ input2.numpy())
    return output

d = torch.mm(sparse_add(a, b), c)
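As a sanity check independent of STen, the core of `torch_mm_fwd_impl` (a SciPy CSC matrix times a dense NumPy array) can be compared with `torch.mm` on the same data. This sketch uses fresh names (`x`, `w`) and a local generator so the notebook's `a`, `b`, `c`, `d` are untouched.

```python
import scipy.sparse
import torch

gen = torch.Generator().manual_seed(0)
# Fresh names so the notebook's a, b, c, d are not overwritten.
x = torch.randn(10, 20, generator=gen)
x = torch.where(torch.rand(10, 20, generator=gen) < 0.5, torch.zeros_like(x), x)
w = torch.randn(20, 30, generator=gen)

x_csc = scipy.sparse.csc_matrix(x.numpy())        # plays the role of MyCscTensor.data
y_sparse = torch.from_numpy(x_csc @ w.numpy())    # sparse @ dense -> dense ndarray
y_dense = torch.mm(x, w)                          # reference dense product

print(torch.allclose(y_sparse, y_dense, atol=1e-4))  # True
```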

As expected, the forward pass now works. The next step is to run the backward pass and see what remains to be implemented there.

In [9]:
d = torch.mm(sparse_add(a, b), c)
try:
    d.backward(grad_d)
except sten.DispatchError as e:
    print(str(e))

Sparse operator implementation is not registered (bwd):
@sten.register_bwd_op_impl(
    operator=torch.spmm,
    grad_out=None,  # default (dense)
    grad_inp=None,  # default (dense)
    inp=(MyCscTensor, torch.Tensor),  
)
def my_operator(ctx, grad_outputs, input_sparsifiers):
    return grad_inputs


Next, we register the backward implementation for `torch.mm`.
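The two gradients returned by the backward implementation follow from the chain rule. Write $x = a + b$, so that $d = x c$, and let $g = \partial L / \partial d$ be the incoming gradient (`grad_d`, of shape $10 \times 30$). Then

$$
\frac{\partial L}{\partial x} = g\, c^\top \in \mathbb{R}^{10 \times 20},
\qquad
\frac{\partial L}{\partial c} = x^\top g \in \mathbb{R}^{20 \times 30}.
$$

Here $x$ is `input1` (stored in CSC), $c$ is `input2`, and $g$ is `grad_output`; only the second product involves the sparse operand, which is why it is computed with SciPy.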

In [10]:
@sten.register_bwd_op_impl(
    operator=torch.mm,
    grad_out=[torch.Tensor],
    grad_inp=(
        (sten.KeepAll, torch.Tensor),
        (sten.KeepAll, torch.Tensor),
    ),
    inp=(MyCscTensor, torch.Tensor),
)
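def torch_mm_bwd_impl(ctx, grad_outputs, input_sparsifiers):
    input1, input2 = ctx.saved_tensors
    [grad_output] = grad_outputs
    # d = input1 @ input2, so grad_input1 = grad_output @ input2^T ...
    grad_input1 = torch.mm(grad_output, input2.T)
    # ... and grad_input2 = input1^T @ grad_output (sparse transpose times dense).
    grad_input2 = torch.from_numpy(
        input1.wrapped_tensor.data.transpose() @ grad_output.numpy())
    return grad_input1, grad_input2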

d = torch.mm(sparse_add(a, b), c)
d.backward(grad_d)

Now the backward pass is also fully functional.
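To double-check the backward formulas independently of STen, the same computations used in `torch_mm_bwd_impl` (a dense `torch.mm` for the first gradient, a SciPy sparse transpose product for the second) can be compared against dense autograd. Fresh names (`x`, `w`, `g`) keep the notebook's variables untouched.

```python
import scipy.sparse
import torch

gen = torch.Generator().manual_seed(0)
# Fresh names so the notebook's variables are untouched.
x = torch.randn(10, 20, generator=gen)
x = torch.where(torch.rand(10, 20, generator=gen) < 0.5, torch.zeros_like(x), x)
x.requires_grad_(True)
w = torch.randn(20, 30, generator=gen, requires_grad=True)
g = torch.randn(10, 30, generator=gen)  # upstream gradient, like grad_d

# Reference gradients from dense autograd.
torch.mm(x, w).backward(g)

# The formulas used in torch_mm_bwd_impl, with x held in CSC.
x_csc = scipy.sparse.csc_matrix(x.detach().numpy())
grad_x = torch.mm(g, w.detach().T)                        # g @ w^T
grad_w = torch.from_numpy(x_csc.transpose() @ g.numpy())  # x^T @ g

print(torch.allclose(grad_x, x.grad, atol=1e-4),
      torch.allclose(grad_w, w.grad, atol=1e-4))  # True True
```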