In [14]:
import autograd
from autograd import numpy as np
import scipy as sp
import paragami
from paragami import autograd_supplement_lib

from autograd.test_util import check_grads
from numpy.testing import assert_array_almost_equal


In [21]:
# Problem size: ``n_groups`` groups with ``n_per_group`` observations each.
n_groups = 10
n_per_group = 5
n_obs = n_groups * n_per_group

# Seed the global RNG so that Restart & Run All reproduces the same draws.
np.random.seed(42)
x = np.random.random(n_obs)

# Group label of every observation: each label in 0..n_groups-1 is repeated
# n_per_group times in a row.
groups = np.repeat(np.arange(0, n_groups), n_per_group)

In [22]:
# Reference answer: the sum of ``x`` within each group label.
bincount_result = np.bincount(groups, weights=x)

In [23]:
# Sparse indicator matrix with a single 1 per column: entry (g, i) is set
# when observation i carries group label g.
rows, cols = groups, np.arange(n_obs)
data = np.ones(n_obs)
grouping_mat = sp.sparse.csr_matrix(
    (data, (rows, cols)), shape=(np.max(groups) + 1, n_obs))

# Multiplying by the indicator matrix must reproduce the bincount result.
assert_array_almost_equal(grouping_mat @ x, bincount_result)

In [24]:
# Only the forward product is used here.  The second returned function is
# bound to an ordinary throwaway name rather than ``_``: in IPython, assigning
# to ``_`` overrides the last-output variable and stops it being updated.
get_grouped_sum, _unused_transpose_product = (
    autograd_supplement_lib.get_sparse_product(grouping_mat))
assert_array_almost_equal(get_grouped_sum(x), bincount_result)

In [25]:
# Run autograd's ``check_grads`` on the grouped sum with its default settings,
# using the 1d input ``x``.
check_grads(get_grouped_sum)(x)

In [34]:
# 2d test input: one row per observation, three columns.  It is defined here
# so that this cell does not depend on a later cell having been run first.
x2 = np.random.random((n_obs, 3))

# Sparse multiplication doesn't fully work for 2d arrays: only the
# forward-mode check below passes.
check_grads(get_grouped_sum, modes=['fwd'])(x2) # ok

In [88]:
from autograd.core import primitive, defvjp, defjvp

def get_sparse_product(z_mat):
    """
    Build ``autograd``-compatible products with a fixed matrix.

    Two functions closing over ``z_mat`` are returned, one computing
    ``z_mat @ b`` and one computing ``z_mat.T @ b``.  Each is registered as
    an ``autograd`` primitive, and its forward- and reverse-mode derivative
    rules are written in terms of the pair itself.

    Parameters
    ------------
    z_mat: A 2d matrix
        The matrix by which to multiply.  A dense matrix is accepted, but
        the only reason to use ``get_sparse_product`` is a sparse ``z_mat``,
        since ``autograd`` supports dense matrix multiplication natively.

    Returns
    -----------
    z_mult:
        A function such that ``z_mult(b) = z_mat @ b``.
    zt_mult:
        A function such that ``zt_mult(b) = z_mat.T @ b``.
    Unlike standard sparse matrix multiplication, ``z_mult`` and ``zt_mult``
    can be used with ``autograd``.

    Raises
    -----------
    ValueError
        If ``z_mat.ndim`` is not 2.  The returned functions raise
        ``ValueError`` when their argument has more than two dimensions.
    """

    if z_mat.ndim != 2:
        raise ValueError(
            'get_sparse_product can only be used with 2d arrays.')

    def _validated(operand):
        # Promote scalars to 1d, then accept only 1d and 2d operands.
        operand = np.atleast_1d(operand)
        if operand.ndim <= 2:
            return operand
        raise ValueError('The argument must be at most two dimensional.')

    @primitive
    def z_mult(b):
        return z_mat @ _validated(b)

    @primitive
    def zt_mult(b):
        return z_mat.T @ _validated(b)

    # Forward mode: the tangent is pushed through the same product.
    defjvp(z_mult, lambda g, ans, b: z_mult(g))
    defjvp(zt_mult, lambda g, ans, b: zt_mult(g))

    # Reverse mode: the cotangent is pulled back through the other product
    # of the pair (``z_mat.T @ g`` for ``z_mult``, ``z_mat @ g`` for
    # ``zt_mult``).
    defvjp(z_mult, lambda ans, b: (lambda g: zt_mult(g)))
    defvjp(zt_mult, lambda ans, b: (lambda g: z_mult(g)))

    return z_mult, zt_mult

# Build the locally defined product pair for the grouping matrix:
# ``z_mult`` multiplies by ``grouping_mat`` and ``zt_mult`` by its transpose.
z_mult, zt_mult = get_sparse_product(grouping_mat)

In [89]:
# Shapes for the forward product: operand, matrix, and their product.
print(x2.shape, grouping_mat.shape, (grouping_mat @ x2).shape, sep='\n')

# Check ``z_mult`` derivatives up to fourth order on the 2d operand.
check_grads(z_mult, order=4)(x2)

# Repeat for the transposed product, whose operand has one row per row of
# ``grouping_mat``.
y2 = np.random.random((grouping_mat.shape[0], 3))
print(y2.shape, (grouping_mat.T).shape, sep='\n')
check_grads(zt_mult, order=4)(y2)


(50, 3)
(10, 50)
(10, 3)
(10, 3)
(50, 10)


In [91]:
# A 3d operand is rejected: the recorded output of this cell is the ValueError
# raised by the dimensionality check that ``z_mult`` applies to its argument.
x3 = np.random.random((grouping_mat.shape[1], 3, 3))
z_mult(x3)

ValueError: The argument must be at most two dimensional.

In [36]:
# Reverse-mode, first-order check of ``get_grouped_sum`` on the 2d input
# ``x2``; the recorded output of this cell is an AssertionError.
# NOTE(review): ``x2`` is assigned in a cell that appears later in this file.
check_grads(get_grouped_sum, modes=['rev'], order=1)(x2)

AssertionError: 

In [31]:
# 2d test input: one row per observation, three columns.
x2 = np.random.random((n_obs, 3))
# Both products are evaluated only to confirm that they run on a 2d operand;
# neither result is stored, and neither is the cell's last expression.
grouping_mat @ x2
get_grouped_sum(x2)

# Gradient check with default settings on the 2d input.  The recorded output
# of this cell is a ValueError: shapes (50,3) (3,50) (50,3) do not broadcast.
check_grads(get_grouped_sum)(x2)

ValueError: operands could not be broadcast together with shapes (50,3) (3,50) (50,3) 