In [1]:
import paragami
import autograd
import autograd.numpy as np

import time

In [2]:
# Problem dimensions: `num_groups` independent groups, each with
# `group_size` parameters, for `d` parameters in total.
num_groups = 10
group_size = 3
d = num_groups * group_size

# Fix the global NumPy RNG so the random matrices and evaluation point
# drawn below are reproducible.
np.random.seed(42)

def get_pd_mat(d):
    """Return a random symmetric positive-definite ``d`` x ``d`` matrix.

    The matrix is built as ``a a^T + I`` where ``a`` has entries drawn from
    the global NumPy RNG (uniform on [0, 1)).  ``a a^T`` is symmetric
    positive semi-definite, so adding the identity makes every eigenvalue
    at least 1.

    The previous construction, ``a + a.T + I``, was symmetric but not
    guaranteed to be positive definite (it can have negative eigenvalues).
    This version draws the same number of random values (``d * d``), so
    random draws made after calling it are unaffected.

    Parameters
    ----------
    d : int
        Number of rows/columns of the returned matrix.

    Returns
    -------
    array of shape ``(d, d)``
        A symmetric positive-definite matrix.
    """
    a = np.random.random((d, d))
    # Gram matrix (PSD) plus identity => strictly positive definite.
    return np.dot(a, a.T) + np.eye(d)

# One random matrix per group, stacked along the leading axis so that
# group_mats[n] is the (group_size x group_size) matrix for group n.
group_mats = np.array([get_pd_mat(group_size) for _ in range(num_groups)])

# Pattern describing the parameter array: one row of `group_size`
# values for each of the `num_groups` groups.
pattern = paragami.NumericArrayPattern((num_groups, group_size))
def f(x_array, w):
    """Weighted sum of per-group quadratic forms.

    Computes ``0.5 * sum_n w[n] * x_array[n] @ group_mats[n] @ x_array[n]``
    using the module-level ``group_mats``.

    Parameters
    ----------
    x_array : 2-d array, indexed as ``[n, i]`` (group, within-group entry).
    w : 1-d array, indexed as ``[n]``; one weight per group.

    Returns
    -------
    The scalar objective value.
    """
    # Entries of x_array only interact within the same group index n.
    quad_form = np.einsum('n,nij,ni,nj', w, group_mats, x_array, x_array)
    return 0.5 * quad_form

# Wrap `f` so that its first argument is passed as a flat, free
# (unconstrained) vector described by `pattern`.
f_flat = paragami.FlattenFunctionInput(
    f,
    argnums=0,
    free=True,
    patterns=pattern,
)

# Evaluation point (random) and unit group weights.
x = np.random.random((num_groups, group_size))
w = np.ones(num_groups)
x_flat = pattern.flatten(x, free=True)

# Smoke-test the objective on the unflattened input; the value is discarded.
f(x, w)

# Gradient and dense Hessian of the flattened objective with respect to
# the flat parameter vector (argument 0).
f_grad = autograd.grad(f_flat, argnum=0)
f_hess = autograd.hessian(f_flat, argnum=0)

# Time one evaluation of the dense autograd Hessian.
# time.perf_counter() is used instead of time.time(): it is a monotonic,
# high-resolution clock intended for measuring short intervals, whereas
# the wall clock can be adjusted while the cell is running.
hess_start = time.perf_counter()
h0 = f_hess(x_flat, w)
hess_time = time.perf_counter() - hess_start
print('Hessian time: ', hess_time)

Hessian time:  0.011384963989257812


In [3]:
def _group_flat_indices(group):
    """Return the flat free-parameter indices of row `group` of the pattern."""
    # Boolean mask shaped like the parameter array, True only on this row.
    mask = pattern.empty_bool(False)
    mask[group, :] = True
    return pattern.flat_indices(mask, free=True)


# One row of flat indices per group.
inds = np.array([_group_flat_indices(group) for group in range(num_groups)])

In [6]:
# Build the block Hessian from the per-group index sets in `inds`.
# `f` only couples entries of x within the same group, so the Hessian
# has no cross-group terms.
sparse_hess = paragami.SparseBlockHessian(f_flat, inds)
block_hess = sparse_hess.get_block_hessian(x_flat, w)

# Compare against the dense autograd Hessian `h0`; the norm of the
# difference is shown as the cell output.
hess_diff = block_hess - h0
np.linalg.norm(hess_diff)

0.0