In [1]:
import autograd
from autograd import numpy as np
import scipy as sp
import paragami
from paragami import autograd_supplement_lib

from autograd.test_util import check_grads
from numpy.testing import assert_array_almost_equal


In [16]:
# Problem size: n_groups groups with n_per_group observations in each.
n_groups = 4
n_per_group = 2
n_obs = n_groups * n_per_group

# Seed the global RNG so that the random inputs drawn in this notebook are
# the same on every Restart & Run All.
SEED = 42
np.random.seed(SEED)

# One random value per observation, and the group label of each observation
# (labels are 0..n_groups-1, each repeated n_per_group times).
x = np.random.random(n_obs)
groups = np.repeat(np.arange(0, n_groups), n_per_group)

In [17]:
# Reference answer: the sum of x within each group, via a weighted bincount.
bincount_result = np.bincount(groups, weights=x)

In [18]:
# cols = np.arange(0, n_obs)
# data = np.ones(n_obs)
# rows = groups
# grouping_mat = sp.sparse.csr_matrix((data, (rows, cols)), (np.max(groups) + 1, n_obs))
# assert_array_almost_equal(grouping_mat @ x, bincount_result)

In [19]:
# Disabled sparse-product variant, kept for reference:
# get_grouped_sum, _ = autograd_supplement_lib.get_sparse_product(grouping_mat) 
# assert_array_almost_equal(get_grouped_sum(x), bincount_result)

# Build an aggregator for these group labels with paragami's helper.
# NOTE(review): presumably it sums its input within each group, like
# bincount_result above -- confirm against paragami's implementation.
aggregate = autograd_supplement_lib.get_grouped_aggregator(groups)

In [20]:
# Check autograd's derivatives of `aggregate` at the 1-D input x.
check_grads(aggregate)(x)

In [21]:
# Repeat the derivative check with a 2-D input: one row per observation
# and 3 columns.
x2 = np.random.random((n_obs, 3))
check_grads(aggregate)(x2)

In [28]:
# Well, this works: autograd can take the Jacobian of indexing by `groups`.

def ungroup(x):
    """Expand a per-group array to one slice per observation.

    Reads the notebook-level `groups` and `n_groups`; `x` must have
    `n_groups` entries along its first axis and at least two dimensions.
    """
    assert len(x) == n_groups
    return x[groups, :]

# Use a dedicated name for the demo input so the 1-D `x` defined earlier is
# not overwritten (earlier cells that use `x` stay re-runnable).
x_by_group = np.random.random((n_groups, 2, 2))
print(autograd.jacobian(ungroup)(x_by_group).shape)
print(ungroup(x_by_group).shape)
print(x_by_group.shape[1:])

(8, 2, 2, 4, 2, 2)
(8, 2, 2)
(2, 2)


In [29]:
# For a 1-D array, shape[1:] is the empty tuple.
np.random.random(5).shape[1:]

()

In [40]:
from autograd.core import primitive, defvjp, defjvp


@primitive
def grouped_sum(x, groups, num_groups=None):
    """Sum the slices of ``x`` along its first axis within each group.

    A ``bincount``-style reduction that also accepts multi-dimensional
    ``x``: row ``g`` of the result is the sum of every ``x[n]`` for which
    ``groups[n] == g``.

    Parameters
    ----------
    x : array-like
        Values to aggregate. Promoted to at least 1-D; the first axis
        indexes observations.
    groups : array-like of int
        Group label of each observation. Promoted to at least 1-D and cast
        to int64; must be a vector of non-negative labels.
    num_groups : int, optional
        Length of the first axis of the result. Defaults to
        ``max(groups) + 1`` (0 when ``groups`` is empty).

    Returns
    -------
    numpy.ndarray
        Array of shape ``(num_groups,) + x.shape[1:]``. Groups with no
        observations are zero.

    Raises
    ------
    ValueError
        If ``groups`` is not a vector, if its length differs from
        ``x.shape[0]``, if any label is negative, or if the largest label
        is ``>= num_groups``.
    """
    x = np.atleast_1d(x)
    groups = np.atleast_1d(groups).astype('int64')
    if groups.ndim > 1:
        raise ValueError('groups must be a vector.')

    n_obs = len(groups)
    if x.shape[0] != n_obs:
        raise ValueError('The first dimension of x must match the length of groups')

    if n_obs > 0:
        # A negative label would silently wrap around to the last groups.
        if np.min(groups) < 0:
            raise ValueError('groups must be non-negative.')
        max_group = np.max(groups)
    else:
        # np.max raises on empty input; with no observations there is no
        # largest label, so any num_groups is admissible.
        max_group = -1

    if num_groups is None:
        num_groups = max_group + 1
    elif max_group >= num_groups:
        raise ValueError(
            'The largest group is >= the number of groups.')

    result = np.zeros((num_groups, ) + x.shape[1:])
    for n in range(n_obs):
        # Index only the first axis so that 1-D x works as well as N-D x.
        result[groups[n]] += x[n]
    return result


# Reference: apply a weighted bincount to each column of x2 separately and
# put the per-column results back side by side.
expected_sums = np.array(
    [np.bincount(groups, x2[:, col]) for col in range(x2.shape[1])]).T
assert_array_almost_equal(grouped_sum(x2, groups), expected_sums)



In [41]:
def ungroup(v, groups):
    """Expand a per-group array to one slice per observation.

    Returns ``v`` indexed along its first axis by ``groups``, so entry ``n``
    of the result is ``v[groups[n]]``. Only the first axis is indexed, which
    makes this work for 1-D ``v`` as well as for higher-dimensional arrays.
    """
    return v[groups]

def grouped_sum_vjp(ans, x, groups, num_groups=None):
    """Build the reverse-mode rule for ``grouped_sum``.

    The returned callable takes the output cotangent ``v`` and hands each
    observation the cotangent of its group via ``ungroup``. ``ans``, ``x``
    and ``num_groups`` are accepted but not used.
    """
    return lambda v: ungroup(v, groups)


# Register the rule above as grouped_sum's vector-Jacobian product.
defvjp(grouped_sum, grouped_sum_vjp)

def grouped_sum_jvp(v, ans, x, groups, num_groups=None):
    """Forward-mode rule for ``grouped_sum``.

    Applies ``grouped_sum`` itself to the input tangent ``v`` with the same
    ``groups`` and ``num_groups``. ``ans`` and ``x`` are accepted but not
    used.
    """
    tangent_out = grouped_sum(v, groups, num_groups=num_groups)
    return tangent_out


# Register the rule above as grouped_sum's Jacobian-vector product.
defjvp(grouped_sum, grouped_sum_jvp)



In [43]:
# Check the derivative rules registered for grouped_sum: first with the
# default number of groups, then with num_groups larger than the number of
# distinct labels, so the trailing groups receive no observations.
check_grads(grouped_sum)(x2, groups)
check_grads(grouped_sum)(x2, groups, num_groups=n_groups + 4)