# Library Nodes

`LibraryNode`s facilitate the abstraction of common operations, enabling easy reuse in different SDFGs and Data-Centric programs. This tutorial covers how to create `LibraryNode`s with different implementations (called *expansions* or `ExpandTransformation`s) and how to use them in SDFGs or Data-Centric programs.

For this tutorial, we use as an example the SDDMM (sampled dense-dense matrix multiplication) operation:
$$\bm{D} = \bm{A} \odot \left(\bm{B} \times \bm{C}\right)$$
$\bm{A}$ is a sparse matrix, while $\bm{B}$ and $\bm{C}$ are dense matrices. The output $\bm{D}$ is the Hadamard (element-wise) product of $\bm{A}$ with the matrix product of $\bm{B}$ and $\bm{C}$, and it has the same sparsity pattern as $\bm{A}$. Effectively, $\bm{A}$ *samples* (or filters) the dense product $\bm{B} \times \bm{C}$. Assuming $\bm{A}$ is stored in CSR format, the SDDMM algorithm is as follows:

```python
import numpy as np

# A (and D) has shape (M, N) with nnz non-zero values, stored in CSR format:
#   A_data (D_data) holds the non-zero values of A (D)
#   A_indices (D_indices) holds the column indices of the non-zero values
#   A_indptr (D_indptr) holds the row pointers, with length M + 1
# B has shape (M, K)
# C has shape (K, N)
D_data = np.zeros_like(A_data)
D_indices = np.copy(A_indices)
D_indptr = np.copy(A_indptr)
for i in range(M):
    for j in range(A_indptr[i], A_indptr[i + 1]):
        for k in range(K):
            D_data[j] += B[i, k] * C[k, A_indices[j]]
        D_data[j] *= A_data[j]
```
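
As a sanity check, the loop nest above can be wrapped into a function and compared against the dense definition of SDDMM. The following is a minimal pure-Python sketch that needs neither NumPy nor DaCe. The function names `sddmm_csr` and `sddmm_dense` and the small test matrices are illustrative assumptions and are not part of DaCe.

```python
def sddmm_csr(A_data, A_indices, A_indptr, B, C):
    """Reference SDDMM: D = A ⊙ (B × C), with A and D in CSR format."""
    M = len(A_indptr) - 1
    K = len(C)
    D_data = [0.0] * len(A_data)
    for i in range(M):
        for j in range(A_indptr[i], A_indptr[i + 1]):
            col = A_indices[j]
            acc = 0.0
            for k in range(K):
                acc += B[i][k] * C[k][col]
            D_data[j] = acc * A_data[j]
    # D shares the sparsity pattern of A
    return D_data, list(A_indices), list(A_indptr)


def sddmm_dense(A, B, C):
    """Dense definition of SDDMM, used only to validate the CSR version."""
    M, N, K = len(A), len(A[0]), len(C)
    return [[A[i][j] * sum(B[i][k] * C[k][j] for k in range(K))
             for j in range(N)] for i in range(M)]


# A = [[1, 0, 2],
#      [0, 3, 0]] in CSR format
A_data, A_indices, A_indptr = [1.0, 2.0, 3.0], [0, 2, 1], [0, 2, 3]
A_dense = [[1.0, 0.0, 2.0], [0.0, 3.0, 0.0]]
B = [[1.0, 2.0], [3.0, 4.0]]            # shape (M, K) = (2, 2)
C = [[1.0, 0.0, 1.0], [0.0, 1.0, 1.0]]  # shape (K, N) = (2, 3)

D_data, D_indices, D_indptr = sddmm_csr(A_data, A_indices, A_indptr, B, C)
D_dense = sddmm_dense(A_dense, B, C)

# Every stored value of D must match the corresponding dense entry
for i in range(len(A_dense)):
    for j in range(D_indptr[i], D_indptr[i + 1]):
        assert D_data[j] == D_dense[i][D_indices[j]]
```

Only the `nnz` sampled entries of $\bm{B} \times \bm{C}$ are ever computed in the CSR version, which is the point of the operation.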

We start by creating a `LibraryNode` that represents the SDDMM operation. The class inherits from `dace.sdfg.nodes.LibraryNode` and is decorated with `@dace.library.node`. It must include an `implementations` dictionary and a `default_implementation` string, both of which we discuss later. The `LibraryNode`'s initialization method must call the initialization method of the superclass and pass the node's name, location, inputs, and outputs. The inputs and outputs are the names of the node's connectors.

In [3]:
import dace

from dace import library
from dace.sdfg import nodes
from dace.transformation import ExpandTransformation
from typing import Dict


@library.node
class MySDDMM(nodes.LibraryNode):

    # We will fill these in later
    implementations: Dict[str, ExpandTransformation] = {}
    default_implementation: str = None

    def __init__(self, name, location=None):
        super().__init__(name,
                         location=location,
                         inputs={'_a_data', '_a_indices', '_a_indptr', '_b', '_c'},
                         outputs={'_d_data', '_d_indices', '_d_indptr'})


A `LibraryNode` can have different implementations (expansions), either generic or specialized for specific architectures. These implementations can use the SDFG API, but they can also be written as Data-Centric programs. We start by creating a *pure* expansion, which is an implementation that does not use any components external to DaCe, such as libraries. We write this expansion as a Data-Centric Python program:

In [4]:
@library.expansion
class MySDDMMPureExpansion(ExpandTransformation):

    environments = []

    @staticmethod
    def expansion(node, state, sdfg):

        # Find shapes and datatypes of inputs and outputs

        # A matrix
        a_indptr_name = list(state.in_edges_by_connector(node, '_a_indptr'))[0].data.data
        a_indptr_arr = sdfg.arrays[a_indptr_name]
        a_data_name = list(state.in_edges_by_connector(node, '_a_data'))[0].data.data
        a_data_arr = sdfg.arrays[a_data_name]
        a_rowsp1 = a_indptr_arr.shape[0]
        a_nnz = a_data_arr.shape[0]
        a_dtype = a_data_arr.dtype

        # B matrix
        b_name = list(state.in_edges_by_connector(node, '_b'))[0].data.data
        b_arr = sdfg.arrays[b_name]
        b_rows = b_arr.shape[0]
        b_cols = b_arr.shape[1]
        b_dtype = b_arr.dtype

        # C matrix
        c_name = list(state.in_edges_by_connector(node, '_c'))[0].data.data
        c_arr = sdfg.arrays[c_name]
        c_rows = c_arr.shape[0]
        c_cols = c_arr.shape[1]
        c_dtype = c_arr.dtype

        # D matrix
        # We assume that it has the same shape and datatype as A

        @dace.program
        def sddmm_pure(_a_data: a_dtype[a_nnz], _a_indices: dace.int32[a_nnz], _a_indptr: dace.int32[a_rowsp1],
                       _b: b_dtype[b_rows, b_cols], _c: c_dtype[c_rows, c_cols],
                       _d_data: a_dtype[a_nnz], _d_indices: dace.int32[a_nnz], _d_indptr: dace.int32[a_rowsp1]):

            _d_data[:] = 0
            _d_indices[:] = _a_indices
            _d_indptr[:] = _a_indptr

            for i in dace.map[0:a_rowsp1 - 1]:
                for j in dace.map[_a_indptr[i]:_a_indptr[i + 1]]:
                    for k in dace.map[0:b_cols]:
                        _d_data[j] += _b[i, k] * _c[k, _a_indices[j]]
                    _d_data[j] *= _a_data[j]

        return sddmm_pure.to_sdfg()



To enable the above expansion, we add it to the `implementations` dictionary:

In [5]:
@library.node
class MySDDMM(nodes.LibraryNode):

    implementations: Dict[str, ExpandTransformation] = {'pure': MySDDMMPureExpansion}
    # Expansion used when a node does not select an implementation explicitly
    default_implementation: str = 'pure'

    def __init__(self, name, location=None):
        super().__init__(name,
                         location=location,
                         inputs={'_a_data', '_a_indices', '_a_indptr', '_b', '_c'},
                         outputs={'_d_data', '_d_indices', '_d_indptr'})
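
The role of the two class attributes can be pictured with a toy model. The sketch below is plain Python and does not use DaCe. `ToyLibraryNode`, `ToyExpansion`, and `select_expansion` are hypothetical names, and the lookup order shown (a per-node `implementation` setting first, then the class-level `default_implementation`) is an assumed simplification of what happens when a library node is expanded, not DaCe's actual code.

```python
from typing import Dict, Optional


class ToyExpansion:
    """Stand-in for an ExpandTransformation subclass (hypothetical)."""


class ToyLibraryNode:
    """Stand-in for a LibraryNode subclass (hypothetical, not DaCe's class)."""
    implementations: Dict[str, type] = {'pure': ToyExpansion}
    default_implementation: Optional[str] = 'pure'

    def __init__(self, implementation: Optional[str] = None):
        # Per-node choice; None means "use the class default"
        self.implementation = implementation


def select_expansion(node: ToyLibraryNode) -> type:
    """Pick the expansion class: per-node choice first, then the class default."""
    name = node.implementation or type(node).default_implementation
    if name is None:
        raise ValueError('No implementation or default found')
    if name not in type(node).implementations:
        raise KeyError(f'Unknown implementation: {name}')
    return type(node).implementations[name]


# No per-node choice, so the class default ('pure') is used
chosen = select_expansion(ToyLibraryNode())
```

In this model, registering a second expansion amounts to adding another entry to `implementations`, and each node can then choose between the entries by name.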


Now that there is at least one expansion for the `LibraryNode`, we can use it in an SDFG like any other `CodeNode`, such as a `Tasklet`. However, it is also possible to automate its use in Data-Centric Python programs. As an example, we use the inference formula for a single layer of the Vanilla Attention (VA) Graph Neural Network (GNN):
$$\bm{H}^\prime = \sigma\left(\bm{A} \odot \left(\bm{H} \times \bm{H}^T\right) \times \bm{H} \times \bm{W}\right)$$
We implement the above formula as a Data-Centric Python program:

In [10]:
import numpy as np

# A is N x N, H is N x K0, W is K0 x K1, H' is N x K1
N, K0, K1, NNZ = (dace.symbol(s) for s in ('N', 'K0', 'K1', 'NNZ'))

@dace.program
def va_inference_layer(A_data: dace.float32[NNZ], A_indices: dace.int32[NNZ], A_indptr: dace.int32[N + 1],
                       H: dace.float32[N, K0],
                       W: dace.float32[K0, K1],
                       H_prime: dace.float32[N, K1]):
    
    # S = A \odot (H \times H^T)
    # S_data = np.empty_like(A_data)
    # S_indices = np.empty_like(A_indices)
    # S_indptr = np.empty_like(A_indptr)
    # dace.sddmm_op(A_data, A_indices, A_indptr, H, np.transpose(H), S_data, S_indices, S_indptr)
    S_data, S_indices, S_indptr = dace.sddmm_op(A_data, A_indices, A_indptr, H, np.transpose(H))

    H_prime[:] = np.maximum(0, dace.csrmm_op(S_data, S_indices, S_indptr, H) @ W)
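
To validate the output of the program above on small inputs, a plain reference implementation is useful. The following pure-Python sketch computes the same layer step by step (SDDMM, then CSR-times-dense, then the dense product with $\bm{W}$ and the ReLU that `np.maximum(0, ...)` implements). The function name `va_layer_reference` and the test matrices are illustrative assumptions, and the code is independent of DaCe.

```python
def va_layer_reference(A_data, A_indices, A_indptr, H, W):
    """H' = ReLU((A ⊙ (H × Hᵀ)) × H × W), with A in CSR format."""
    N, K0, K1 = len(H), len(W), len(W[0])
    # SDDMM: S has the sparsity pattern of A, so only S_data is computed
    S_data = [0.0] * len(A_data)
    for i in range(N):
        for j in range(A_indptr[i], A_indptr[i + 1]):
            col = A_indices[j]
            dot = sum(H[i][k] * H[col][k] for k in range(K0))  # (H × Hᵀ)[i, col]
            S_data[j] = A_data[j] * dot
    # CSRMM: T = S × H
    T = [[0.0] * K0 for _ in range(N)]
    for i in range(N):
        for j in range(A_indptr[i], A_indptr[i + 1]):
            for k in range(K0):
                T[i][k] += S_data[j] * H[A_indices[j]][k]
    # Dense product with W, followed by ReLU
    return [[max(0.0, sum(T[i][k] * W[k][c] for k in range(K0)))
             for c in range(K1)] for i in range(N)]


# A = [[1, 0],
#      [2, 1]] in CSR format; H is 2 x 2, W is 2 x 1
A_data, A_indices, A_indptr = [1.0, 2.0, 1.0], [0, 0, 1], [0, 1, 3]
H = [[1.0, 0.0], [1.0, 1.0]]
W = [[-1.0], [3.0]]

H_prime = va_layer_reference(A_data, A_indices, A_indptr, H, W)
print(H_prime)  # [[0.0], [2.0]]
```

The first row is clamped to zero by the ReLU (its pre-activation value is -1), while the second row passes through unchanged.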


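To convert the above program to an SDFG, we need to define `dace.sddmm_op` and `dace.csrmm_op`. We register them as replacements with `op_repository.replaces`; each replacement adds the corresponding `LibraryNode` to the current state and returns the names of its output arrays: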

In [17]:
from dace.frontend.common import op_repository 


@op_repository.replaces('dace.sddmm_op')
def sddmm_libnode(pv: 'ProgramVisitor',
                  sdfg: dace.SDFG,
                  state: dace.SDFGState,
                  A_data: str,
                  A_indices: str,
                  A_indptr: str,
                  B: str,
                  C: str):
    # Input access nodes
    A_data_acc, A_indices_acc, A_indptr_acc, B_acc, C_acc = (
        state.add_access(n) for n in (A_data, A_indices, A_indptr, B, C))
    # Output D
    A_data_arr = sdfg.arrays[A_data]
    A_indices_arr = sdfg.arrays[A_indices]
    A_indptr_arr = sdfg.arrays[A_indptr]
    D_data, D_data_arr = sdfg.add_temp_transient_like(A_data_arr)
    D_indices, D_indices_arr = sdfg.add_temp_transient_like(A_indices_arr)
    D_indptr, D_indptr_arr = sdfg.add_temp_transient_like(A_indptr_arr)
    D_data_acc, D_indices_acc, D_indptr_acc = (state.add_access(n) for n in (D_data, D_indices, D_indptr))

    libnode = MySDDMM('sddmm')
    state.add_node(libnode)

    # Connect nodes
    state.add_edge(A_indptr_acc, None, libnode, '_a_indptr', dace.Memlet(A_indptr))
    state.add_edge(A_indices_acc, None, libnode, '_a_indices', dace.Memlet(A_indices))
    state.add_edge(A_data_acc, None, libnode, '_a_data', dace.Memlet(A_data))
    state.add_edge(B_acc, None, libnode, '_b', dace.Memlet(B))
    state.add_edge(C_acc, None, libnode, '_c', dace.Memlet(C))
    state.add_edge(libnode, '_d_data', D_data_acc, None, dace.Memlet(D_data))
    state.add_edge(libnode, '_d_indices', D_indices_acc, None, dace.Memlet(D_indices))
    state.add_edge(libnode, '_d_indptr', D_indptr_acc, None, dace.Memlet(D_indptr))

    return [D_data, D_indices, D_indptr]


@op_repository.replaces('dace.csrmm_op')
def csrmm_libnode(pv: 'ProgramVisitor',
                  sdfg: dace.SDFG,
                  state: dace.SDFGState,
                  A_data: str,
                  A_indices: str,
                  A_indptr: str,
                  B: str):
    # Input access nodes
    A_data_acc, A_indices_acc, A_indptr_acc, B_acc = (state.add_access(n) for n in (A_data, A_indices, A_indptr, B))
    # Output C
    A_indptr_arr = sdfg.arrays[A_indptr]
    rows = A_indptr_arr.shape[0] - 1
    cols = sdfg.arrays[B].shape[1]
    A_data_arr = sdfg.arrays[A_data]
    dtype = A_data_arr.dtype
    C, C_arr = sdfg.add_temp_transient([rows, cols], dtype)
    C_acc = state.add_write(C)

    from dace.libraries.sparse import CSRMM
    libnode = CSRMM('csrmm')
    state.add_node(libnode)

    # Connect nodes
    state.add_edge(A_indptr_acc, None, libnode, '_a_rows', dace.Memlet(A_indptr))
    state.add_edge(A_indices_acc, None, libnode, '_a_cols', dace.Memlet(A_indices))
    state.add_edge(A_data_acc, None, libnode, '_a_vals', dace.Memlet(A_data))
    state.add_edge(B_acc, None, libnode, '_b', dace.Memlet(B))
    state.add_edge(libnode, '_c', C_acc, None, dace.Memlet(C))

    return [C]


In [18]:
sdfg = va_inference_layer.to_sdfg()

In [19]:
sdfg.save('sddmm_tutorial.sdfg')

'663bc3186b54f9811d87ff2ae3ec4f58847e7efad341b9672d2e1c3fc6c9ce0c'