In [1]:
from paragami.base_patterns import Pattern
from paragami.pattern_containers import register_pattern_json

import jax
import jax.numpy as np
import numpy as onp

import math

import time
import timeit
#from jax import custom_jvp




In [2]:

def assert_equal(x, y, tol=1e-12):
    assert(onp.max(onp.abs(x - y)) < tol)
    
def time_jit(f):
    tic = time.time()
    f()
    print('1st time: ', time.time() - tic)

    tic = time.time()
    f()
    print('2nd time: ', time.time() - tic)

    
def mark_tic(tic, op):
    print(f'{op}:\t{time.time() - tic}')
    return time.time()

## Pattern array

In [38]:
import paragami

k_approx = 30
dim = 4

base_pattern = paragami.PSDSymmetricMatrixPattern(size=dim)
covar_array_pattern = \
        paragami.PatternArray(array_shape = (k_approx, ), \
                    base_pattern = base_pattern)

covar_array = covar_array_pattern.random()

covar_array_flattened = covar_array_pattern.flatten(covar_array, free = False)
covar_array_flattened_free = covar_array_pattern.flatten(covar_array, free = True)

In [11]:
def fun(covar_array): 
    return (covar_array**2).sum()

# flattened function
fun_flattened = paragami.FlattenFunctionInput(original_fun=fun, 
                                patterns = covar_array_pattern,
                                free = False,
                                argnums = 0) 

# flattened and freed function
fun_flattened_free = paragami.FlattenFunctionInput(original_fun=fun, 
                                patterns = covar_array_pattern,
                                free = True,
                               argnums = 0) 

assert_equal(
    fun_flattened(covar_array_flattened),
    fun_flattened_free(covar_array_flattened_free))


In [15]:
grad_fun_flattened = jax.jit(jax.grad(fun_flattened))
grad_fun_flattened_free = jax.jit(jax.grad(fun_flattened_free))

time_jit(lambda: grad_fun_flattened(covar_array_flattened))
time_jit(lambda: grad_fun_flattened_free(covar_array_flattened_free))

1st time:  0.8329370021820068
2nd time:  0.0001614093780517578
1st time:  14.864873886108398
2nd time:  0.0004405975341796875


In [34]:
print(jax.lax.map(lambda x: x + 1, np.arange(0, 10)))


[ 1  2  3  4  5  6  7  8  9 10]


In [45]:

import itertools

__array_shape = (2, 3)
__array_ranges = [range(0, t) for t in __array_shape]
__array_ranges


empty_pattern = base_pattern.empty(valid=True)
__shape = tuple(__array_shape) + empty_pattern.shape

repeated_array = np.array(
    [empty_pattern
     for item in itertools.product(*__array_ranges)])
empty_orig = np.reshape(repeated_array, __shape)


In [61]:
# Works but inefficient, you don't want to create the entries array.
entries = np.array([ i for i in itertools.product(*__array_ranges) ])
repeated_array = jax.lax.map(lambda x: empty_pattern, entries)
empty_jax = np.reshape(repeated_array, __shape)
assert_equal(empty_jax, empty_orig)

## PSD patterns

In [8]:
assert False

AssertionError: 

In [4]:

def _sym_index(k1, k2):
    """
    Get the index of an entry in a folded symmetric array.

    Parameters
    ------------
    k1, k2: int
        0-based indices into a symmetric matrix.

    Returns
    --------
    int
        Return the linear index of the (k1, k2) element of a symmetric
        matrix where the triangular part has been stacked into a vector.
    """
    def ld_ind(k1, k2):
        return int(k2 + k1 * (k1 + 1) / 2)

    if k2 <= k1:
        return ld_ind(k1, k2)
    else:
        return ld_ind(k2, k1)


def _vectorize_ld_matrix(mat):
    """
    Linearize the lower diagonal of a square matrix.
    """
    nrow, ncol = np.shape(mat)
    if nrow != ncol:
        raise ValueError('mat must be square')
    return mat[np.tril_indices(nrow)]


def _unvectorize_ld_matrix(vec):
    """
    Invert the mapping of `_vectorize_ld_matrix`.
    """
    mat_size = int(0.5 * (math.sqrt(1 + 8 * vec.size) - 1))
    if mat_size * (mat_size + 1) / 2 != vec.size:
        raise ValueError('Vector is an impossible size')

    mat = np.zeros((mat_size, mat_size))
    inds = np.tril_indices(mat_size)
    return(jax.ops.index_update(mat, inds, vec))


def _exp_matrix_diagonal(mat):
    assert mat.shape[0] == mat.shape[1]
    dim = mat.shape[0]
    diag_inds = (np.arange(dim), np.arange(dim))
    exp_diags = np.exp(np.diag(mat))
    return(jax.ops.index_update(mat, diag_inds, exp_diags))


def _log_matrix_diagonal(mat):
    assert mat.shape[0] == mat.shape[1]
    dim = mat.shape[0]
    diag_inds = (np.arange(dim), np.arange(dim))
    log_diags = np.log(np.diag(mat))
    return(jax.ops.index_update(mat, diag_inds, log_diags))


def _pack_posdef_matrix(mat, diag_lb=0.0):
    k = mat.shape[0]
    # mat_lb = mat - np.make_diagonal(
    #     np.full(k, diag_lb), offset=0, axis1=-1, axis2=-2)
    mat_lb = mat - np.diag(np.full(k, diag_lb))
    return _vectorize_ld_matrix(
        _log_matrix_diagonal(np.linalg.cholesky(mat_lb)))


def _unpack_posdef_matrix(free_vec, diag_lb=0.0):
    mat_raw = _unvectorize_ld_matrix(free_vec)
    #return mat_raw # 0.20

    mat_chol = _exp_matrix_diagonal(mat_raw)
    #return mat_chol # 0.60
    #mat_chol = mat_raw

    # Doesn't seem to matter much what you do
    #mat = np.einsum('ik,jk->ij', mat_chol, mat_chol)
    mat = np.matmul(mat_chol, np.transpose(mat_chol))
    #return mat # 1.3

    dim = mat.shape[0]
    diag_inds = (np.arange(dim), np.arange(dim))
    new_mat = jax.ops.index_update(mat, diag_inds, np.diag(mat) + diag_lb)
    return new_mat # 1.1 ?!

    
# Convert a vector containing the lower diagonal portion of a symmetric
# matrix into the full symmetric matrix.
#
# This is not currently used but could be useful for a symmetric matrix type.
def _unvectorize_symmetric_matrix(vec_val):
    ld_mat = _unvectorize_ld_matrix(vec_val)
    mat_val = ld_mat + ld_mat.transpose()
    # We have double counted the diagonal.  For some reason the autograd
    # diagonal functions require axis1=-1 and axis2=-2
    # mat_val = mat_val - \
    #     np.make_diagonal(np.diagonal(ld_mat, axis1=-1, axis2=-2),
    #                      axis1=-1, axis2=-2)
    mat_val = mat_val - np.diag(np.diagonal(ld_mat))

    return mat_val


In [5]:
dim = 50
#mat = np.eye(dim) * dim + np.full((dim, dim), 0.1)
foo = onp.random.random((dim, dim))
mat = np.eye(dim) * dim + foo + foo.T
mat = np.array(mat)
vec = _pack_posdef_matrix(mat)
assert_equal(_unpack_posdef_matrix(vec), mat)


In [6]:
pack_grad = jax.jit(jax.jacobian(_pack_posdef_matrix))

unpack_grad = jax.jit(jax.jacobian(
    lambda x, diag_lb: _unpack_posdef_matrix(x, diag_lb=diag_lb)),
                      static_argnums=1)
time_jit(lambda: unpack_grad(vec, 0.0))

tic = time.time()
unpack_grad(vec, 0.0)
tic = mark_tic(tic, 'subsequent')
unpack_grad(vec, 0.1)
tic = mark_tic(tic, 'subsequent but new par')
unpack_grad(vec, 0.1)
tic = mark_tic(tic, 'subsequent 0.0')
unpack_grad(vec, 0.0);
tic = mark_tic(tic, 'subsequent 0.1')
unpack_grad(vec, 0.1);


1st time:  1.1781713962554932
2nd time:  0.0005805492401123047
subsequent:	0.0006048679351806641
subsequent but new par:	1.2682888507843018
subsequent 0.0:	0.00041222572326660156
subsequent 0.1:	0.00035071372985839844


In [7]:
exp_grad = jax.jit(jax.jacobian(_exp_matrix_diagonal))

print('Exp')
time_jit(lambda: exp_grad(mat))

time_jit(lambda: np.matmul(mat, mat.T))
time_jit(lambda: np.matmul(mat, np.transpose(mat)))


pack_grad = jax.jit(jax.jacobian(_vectorize_ld_matrix))
unpack_grad = jax.jit(jax.jacobian(_unvectorize_ld_matrix))

print('Pack')
time_jit(lambda: pack_grad(mat))

print('Unpack')
time_jit(lambda: unpack_grad(vec))

Exp
1st time:  0.8487815856933594
2nd time:  0.0005159378051757812
1st time:  0.0006289482116699219
2nd time:  0.0005216598510742188
1st time:  0.0005910396575927734
2nd time:  0.0005042552947998047
Pack
1st time:  0.1282823085784912
2nd time:  0.00017905235290527344
Unpack
1st time:  0.25592041015625
2nd time:  0.00032973289489746094
