In [1]:
import tensorflow as tf
import tensorflow.keras as keras
print(tf.__version__)

2.4.0


# Sparse Matrix Multiplication
TensorFlow has a function [`tf.sparse.sparse_dense_matmul`](https://www.tensorflow.org/api_docs/python/tf/sparse/sparse_dense_matmul) that multiplies a sparse matrix by a dense matrix, in exactly this order:

$$
C_{den} = A_{sp} \cdot B_{den} 
$$
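To make the semantics concrete, here is a minimal pure-Python sketch of $C_{den} = A_{sp} \cdot B_{den}$ for a sparse matrix stored in coordinate (COO) form, i.e., the same `indices`/`values`/`dense_shape` triple that a `tf.sparse.SparseTensor` holds. It only illustrates the arithmetic and is not how TensorFlow implements the op. The helper name `coo_dense_matmul` is hypothetical.

```python
from typing import List, Sequence, Tuple

Matrix = List[List[float]]


def coo_dense_matmul(
    indices: Sequence[Tuple[int, int]],
    values: Sequence[float],
    dense_shape: Tuple[int, int],
    b: Matrix,
) -> Matrix:
    """Compute C = A @ B where A is given in COO form and B is dense (hypothetical helper)."""
    n_rows, n_inner = dense_shape
    if len(b) != n_inner:
        raise ValueError("inner dimensions do not match")
    n_cols = len(b[0])
    c = [[0.0] * n_cols for _ in range(n_rows)]
    # Each stored entry A[i, k] adds A[i, k] * B[k, :] to row i of C;
    # the zero entries of A are never visited.
    for (i, k), a_ik in zip(indices, values):
        for j in range(n_cols):
            c[i][j] += a_ik * b[k][j]
    return c


# Same numbers as the toy example used later in this notebook: W (3x3) times a column vector.
W_indices = [(0, 1), (1, 1), (1, 2), (2, 0), (2, 2)]
W_values = [1.0, 2.0, 3.0, 4.0, 5.0]
h_col = [[1.0], [2.0], [3.0]]
print(coo_dense_matmul(W_indices, W_values, (3, 3), h_col))  # [[2.0], [13.0], [19.0]]
```

The work is proportional to the number of stored entries times the number of columns of $B$, which is why keeping $A$ sparse pays off.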

This allows Matrix-Vector multiplications that are common in Neural Networks, e.g.,

$$
h_t = f(W_{sp} \cdot h_{t-1})
$$

However, it is just as common to write it as a Vector-Matrix multiplication. For example, tf2/Keras layers put the input row vector on the left: `keras.layers.Dense` computes `inputs @ kernel` with a kernel of shape `(input_dim, units)`:

$$
h_t = f(h_{t-1} \cdot W_{sp})
$$
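The two conventions are linked by the transpose rule $(X \cdot Y)^\top = Y^\top \cdot X^\top$. Treating $h_{t-1}$ as a row vector, the Vector-Matrix form can be rewritten as a Matrix-Vector product with the *transposed* sparse matrix, which is the order `sparse_dense_matmul` accepts:

$$
h_{t-1} \cdot W_{sp} = \left( W_{sp}^\top \cdot h_{t-1}^\top \right)^\top
$$

Note that $W_{sp}^\top$, not $W_{sp}$, appears on the right-hand side; the two products only coincide when $W_{sp}$ is symmetric.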

However, `tf.sparse` does not implement a `dense_sparse_matmul` yet, so in the meantime we write a wrapper function.

## Toy Example

In [2]:
h = tf.constant([1.0, 2.0, 3.0])
W = tf.sparse.SparseTensor(
    indices=([0, 1], [1, 1], [1, 2], [2, 0], [2, 2]),
    values=[1.0, 2.0, 3.0, 4.0, 5.0],
    dense_shape=(3,3))

tf.print(h, "\n")
tf.print(tf.sparse.to_dense(W))

[1 2 3] 

[[0 1 0]
 [0 2 3]
 [4 0 5]]
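How the `indices`/`values` pairs map onto the printed dense matrix can be reproduced with a few lines of plain Python. The helper `coo_to_dense` below is hypothetical; it mirrors what `tf.sparse.to_dense` returns for this toy example (zeros everywhere except the listed coordinates) but is not TensorFlow's implementation:

```python
from typing import List, Sequence, Tuple


def coo_to_dense(
    indices: Sequence[Tuple[int, int]],
    values: Sequence[float],
    dense_shape: Tuple[int, int],
) -> List[List[float]]:
    """Scatter COO entries into a zero-initialised dense matrix (hypothetical helper)."""
    n_rows, n_cols = dense_shape
    dense = [[0.0] * n_cols for _ in range(n_rows)]
    for (row, col), value in zip(indices, values):
        dense[row][col] = value  # each index pair [row, col] addresses exactly one cell
    return dense


W_dense = coo_to_dense(
    [(0, 1), (1, 1), (1, 2), (2, 0), (2, 2)],
    [1.0, 2.0, 3.0, 4.0, 5.0],
    (3, 3),
)
for row in W_dense:
    print(row)
# [0.0, 1.0, 0.0]
# [0.0, 2.0, 3.0]
# [4.0, 0.0, 5.0]
```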


## Sparse-Dense Multiplication
The dense vector `h` is not a "matrix": it has `shape=(3,)`, whereas `sparse_dense_matmul` expects a rank-2 tensor, here a column vector with `shape=(3,1)`.

In [3]:
h

<tf.Tensor: shape=(3,), dtype=float32, numpy=array([1., 2., 3.], dtype=float32)>

In [4]:
h_col = tf.reshape(h, (-1, 1))
h_col

<tf.Tensor: shape=(3, 1), dtype=float32, numpy=
array([[1.],
       [2.],
       [3.]], dtype=float32)>

The result `net` will also be a column vector with `shape=(3,1)`.

In [5]:
net = tf.sparse.sparse_dense_matmul(W, h_col)
tf.print(net)

[[2]
 [13]
 [19]]
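Checking the result by hand, each entry of `net` is the dot product of one row of $W$ with $h$:

$$
W \cdot h =
\begin{pmatrix} 0 & 1 & 0 \\ 0 & 2 & 3 \\ 4 & 0 & 5 \end{pmatrix}
\begin{pmatrix} 1 \\ 2 \\ 3 \end{pmatrix}
=
\begin{pmatrix} 0 \cdot 1 + 1 \cdot 2 + 0 \cdot 3 \\ 0 \cdot 1 + 2 \cdot 2 + 3 \cdot 3 \\ 4 \cdot 1 + 0 \cdot 2 + 5 \cdot 3 \end{pmatrix}
=
\begin{pmatrix} 2 \\ 13 \\ 19 \end{pmatrix}
$$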


## Dense-Sparse Multiplication
Most NN libraries use row vectors instead (a batch is stored with one row per sample), e.g.,

In [6]:
h_row = tf.reshape(h, (1, -1))
tf.print(h_row)
h_row

[[1 2 3]]


<tf.Tensor: shape=(1, 3), dtype=float32, numpy=array([[1., 2., 3.]], dtype=float32)>

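Since $h \cdot W = (W^\top \cdot h^\top)^\top$, we transpose the row vector into a column vector, multiply it by the *transposed* sparse matrix, and transpose the result back into a row vector. Transposing only `h_row` would instead compute $h \cdot W^\top = (W \cdot h^\top)^\top$, i.e., the column-vector result from above. (Alternatively, passing `adjoint_a=True` to `sparse_dense_matmul` transposes `W` inside the op for real-valued tensors.)

In [7]:
net = tf.transpose(tf.sparse.sparse_dense_matmul(tf.sparse.transpose(W), tf.transpose(h_row)))
tf.print(net)

[[12 5 21]]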
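As a plain-Python cross-check of the row-vector product $h \cdot W$ with the toy values, the product can also be computed directly from the COO entries without any transposes: every stored entry $W_{ij}$ adds $h_i \cdot W_{ij}$ to output entry $j$. The helper `row_coo_matmul` is a hypothetical reference for testing, not part of `tf.sparse`:

```python
from typing import List, Sequence, Tuple


def row_coo_matmul(
    h: Sequence[float],
    indices: Sequence[Tuple[int, int]],
    values: Sequence[float],
    dense_shape: Tuple[int, int],
) -> List[float]:
    """Compute the row vector h @ W for W in COO form (hypothetical reference helper)."""
    n_rows, n_cols = dense_shape
    if len(h) != n_rows:
        raise ValueError("len(h) must equal the number of rows of W")
    out = [0.0] * n_cols
    # Scatter: the entry W[i, j] multiplies h[i] and accumulates into output column j.
    for (i, j), w_ij in zip(indices, values):
        out[j] += h[i] * w_ij
    return out


indices = [(0, 1), (1, 1), (1, 2), (2, 0), (2, 2)]
values = [1.0, 2.0, 3.0, 4.0, 5.0]
print(row_coo_matmul([1.0, 2.0, 3.0], indices, values, (3, 3)))  # [12.0, 5.0, 21.0]
```

Because $W$ is not symmetric, $h \cdot W = (12, 5, 21)$ differs from $W \cdot h = (2, 13, 19)$; a correct dense-sparse product has to reproduce the former.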


## Function Wrapper
See `keras_tweaks.dense_sparse_matmul`.

In [8]:
def dense_sparse_matmul(denV: tf.Tensor, spW: tf.SparseTensor) -> tf.Tensor:
    # reshape a single vector into a batch of one row vector if necessary
    if denV.shape.ndims == 1:
        denV = tf.reshape(denV, (1, -1))
    # denV @ spW == transpose(spW^T @ denV^T): transpose -> multiply -> transpose back
    return tf.transpose(
        tf.sparse.sparse_dense_matmul(tf.sparse.transpose(spW), tf.transpose(denV)))

In [9]:
net = dense_sparse_matmul(h, W)
tf.print(net)

[[12 5 21]]
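The wrapper's strategy, moving the sparse operand to the left via transposes, can be mirrored in plain Python for a batch of row vectors (the 1-D reshape in the TF wrapper corresponds to wrapping a single vector in a one-row batch here). The sketch below uses hypothetical helper names: it transposes the COO matrix by swapping every index pair, multiplies it with the transposed batch, transposes back, and checks the result against a naive dense product:

```python
from typing import List, Sequence, Tuple

Matrix = List[List[float]]


def transpose(m: Matrix) -> Matrix:
    """Dense transpose of a list-of-rows matrix."""
    return [list(col) for col in zip(*m)]


def coo_dense_matmul(indices: Sequence[Tuple[int, int]], values: Sequence[float],
                     dense_shape: Tuple[int, int], b: Matrix) -> Matrix:
    """Sparse (COO) @ dense, the only operand order the TF op supports."""
    c = [[0.0] * len(b[0]) for _ in range(dense_shape[0])]
    for (i, k), a_ik in zip(indices, values):
        for j in range(len(b[0])):
            c[i][j] += a_ik * b[k][j]
    return c


def dense_coo_matmul(batch: Matrix, indices: Sequence[Tuple[int, int]],
                     values: Sequence[float], dense_shape: Tuple[int, int]) -> Matrix:
    """batch @ W computed as (W^T @ batch^T)^T, mirroring the wrapper above."""
    t_indices = [(j, i) for i, j in indices]  # W^T: swap every (row, col) pair
    t_shape = (dense_shape[1], dense_shape[0])
    return transpose(coo_dense_matmul(t_indices, values, t_shape, transpose(batch)))


def naive_matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain dense product, used only as a reference."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


indices = [(0, 1), (1, 1), (1, 2), (2, 0), (2, 2)]
values = [1.0, 2.0, 3.0, 4.0, 5.0]
W_dense = [[0.0, 1.0, 0.0], [0.0, 2.0, 3.0], [4.0, 0.0, 5.0]]
batch = [[1.0, 2.0, 3.0], [0.0, 1.0, 0.0]]  # two row vectors, shape (2, 3)

result = dense_coo_matmul(batch, indices, values, (3, 3))
print(result)  # [[12.0, 5.0, 21.0], [0.0, 2.0, 3.0]]
assert result == naive_matmul(batch, W_dense)
```

The second row, a one-hot vector selecting row 1, simply returns that row of $W$, which makes it an easy sanity check for any dense-sparse implementation.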
