In [6]:
import itertools
import math
import time
import timeit

import jax
import jax.numpy as np
import numpy as onp

import paragami
from paragami.base_patterns import Pattern
from paragami.pattern_containers import register_pattern_json

#from jax import custom_jvp


In [2]:

def assert_equal(x, y, tol=1e-12):
    """Assert that ``x`` and ``y`` agree elementwise to within ``tol``.

    Raises ``AssertionError`` unless the largest absolute entry of
    ``x - y`` is strictly smaller than ``tol``.
    """
    largest_gap = onp.max(onp.abs(x - y))
    assert largest_gap < tol
    
def time_jit(f):
    """Call ``f`` twice and print how long each call took.

    The first call of a jitted function includes tracing/compilation, the
    second does not, so the two printed numbers show compile cost versus
    steady-state cost.

    Parameters
    ------------
    f: callable
        Zero-argument callable to time.

    Notes
    ------
    JAX dispatches work asynchronously, so a call can return before the
    computation has finished.  When the value returned by ``f`` exposes
    ``block_until_ready`` it is invoked before the clock is stopped, so the
    reported time covers the actual computation.
    """
    def _timed_call():
        # perf_counter is monotonic and has higher resolution than time.time.
        tic = time.perf_counter()
        result = f()
        wait = getattr(result, 'block_until_ready', None)
        if callable(wait):
            wait()
        return time.perf_counter() - tic

    print('1st time: ', _timed_call())
    print('2nd time: ', _timed_call())

    
def mark_tic(tic, op):
    """Print the seconds elapsed since ``tic``, labelled ``op``.

    Returns a fresh ``time.time()`` timestamp so calls can be chained:
    ``tic = mark_tic(tic, 'step')``.
    """
    elapsed = time.time() - tic
    print(f'{op}:\t{elapsed}')
    return time.time()

## Pattern array

In [22]:
class EmptyClass:
    """Featureless object; instances only serve as a holder for attributes
    assigned after construction."""

    
k_approx = 30
dim = 4
base_pattern = paragami.PSDSymmetricMatrixPattern(size=dim)

############

# `self` is a plain attribute holder standing in for the object that the
# module-level `fold` / `_stacked_obs_slice` functions below read from.
self = EmptyClass()

self.__array_shape = (2, 3)
self.__array_ranges = [range(extent) for extent in self.__array_shape]
# Every multi-index of the (2, 3) array, one per row.
self.__array_indices = np.array(
    list(itertools.product(*self.__array_ranges)))

self.__base_pattern = base_pattern

# Full folded shape = array shape followed by the shape of one base value.
empty_val = base_pattern.empty(valid=True)
self.__shape = (*self.__array_shape, *empty_val.shape)


def _stacked_obs_slice(self, item, flat_length):
    assert len(item) == len(self.__array_shape)
    linear_item = onp.ravel_multi_index(item, self.__array_shape) * flat_length
    return np.arange(linear_item, linear_item + flat_length)
    #return slice(linear_item, linear_item + flat_length)

free = True
flat_length = base_pattern.flat_length(free=free)
# One row of flat indices per array entry, in the order of __array_indices.
self.__slices_array = np.array(
    [_stacked_obs_slice(self, item, flat_length)
     for item in self.__array_indices])


def fold(self, flat_val, free=True, validate_value=False, op=1):
    """Fold a flat vector into an array of folded base-pattern values.

    Parameters
    ------------
    self
        Object providing ``__base_pattern``, ``__array_shape``, ``__shape``
        and, depending on ``op``, ``__slices_array`` (op 1) or
        ``__array_ranges`` (op 2).  When ``free`` is false it must also
        provide ``validate_folded``.
    flat_val
        1d vector holding the stacked flat values of every array entry.
    free: bool
        Passed through to the base pattern's ``flat_length`` and ``fold``.
    validate_value: bool
        Passed through to the base pattern's ``fold`` and to
        ``validate_folded``.
    op: int
        Which implementation to use: ``1`` indexes with the precomputed
        ``__slices_array``; ``2`` recomputes each index set with
        ``_stacked_obs_slice``.  Defaults to ``1``.

    Returns
    --------
    The folded values reshaped to ``self.__shape``.

    Raises
    -------
    ValueError
        If ``flat_val`` is not 1d, has the wrong length, ``op`` is not 1 or
        2, or (for ``free=False``) ``validate_folded`` reports a problem.
    """
    flat_val = np.atleast_1d(flat_val)
    if flat_val.ndim != 1:
        raise ValueError('The argument to fold must be a 1d vector.')

    flat_length = self.__base_pattern.flat_length(free)

    # Check the length up front: indexing a JAX array past its end does not
    # raise, so a wrong-sized input would otherwise fold to garbage silently.
    expected_size = int(onp.prod(self.__array_shape)) * flat_length
    if flat_val.size != expected_size:
        raise ValueError(
            'Wrong size for parameter.  Expected {}, got {}'.format(
                str(expected_size), str(flat_val.size)))

    print(f'op: {op}')
    if op == 1:
        folded_array = np.array([
            self.__base_pattern.fold(
                flat_val[item_slice],
                free=free, validate_value=validate_value)
            for item_slice in self.__slices_array])
    elif op == 2:
        folded_array = np.array([
            self.__base_pattern.fold(
                flat_val[_stacked_obs_slice(self, item, flat_length)],
                free=free, validate_value=validate_value)
            for item in itertools.product(*self.__array_ranges)])
    else:
        raise ValueError('op must be 1 or 2, got {}'.format(op))

    folded_val = np.reshape(folded_array, self.__shape)

    if not free:
        valid, msg = self.validate_folded(
            folded_val, validate_value=validate_value)
        if not valid:
            raise ValueError(msg)
    return folded_val


# Random flat vector with one chunk of `flat_length` values per array entry.
flat_val = onp.random.random(flat_length * onp.prod(self.__array_shape))
fold(self, flat_val, validate_value=False, free=True)

op: 1


DeviceArray([[[[2.78443999, 0.06417093, 1.36966589, 0.50874879],
               [0.06417093, 1.73970146, 0.84172978, 0.68382472],
               [1.36966589, 0.84172978, 3.95157879, 1.0164845 ],
               [0.50874879, 0.68382472, 1.0164845 , 6.30171981]],

              [[3.55208217, 1.47830364, 0.36675653, 1.2873947 ],
               [1.47830364, 4.09454905, 0.7771385 , 1.22717463],
               [0.36675653, 0.7771385 , 6.94947158, 1.1761852 ],
               [1.2873947 , 1.22717463, 1.1761852 , 2.96973594]],

              [[5.51920234, 1.99217454, 0.70656138, 1.35060106],
               [1.99217454, 3.26181202, 0.97521351, 1.5022506 ],
               [0.70656138, 0.97521351, 5.43759648, 1.91238427],
               [1.35060106, 1.5022506 , 1.91238427, 3.50536329]]],


             [[[1.40957826, 1.08127177, 0.35390947, 0.54044772],
               [1.08127177, 4.03914488, 0.30937159, 0.9131805 ],
               [0.35390947, 0.30937159, 1.60242157, 0.18413326],
               [0

In [23]:
# Jitted Jacobian of the fold with respect to the flat vector.
fold_grad_jit = jax.jit(jax.jacobian(
    lambda flat_arg: fold(self, flat_arg, free=True, validate_value=False)
))

time_jit(lambda: fold_grad_jit(flat_val))

op: 1
1st time:  2.5252795219421387
2nd time:  0.00039315223693847656


```
op: 1
1st time:  2.5252795219421387
2nd time:  0.00039315223693847656

op: 2
1st time:  2.73453950881958
2nd time:  0.0005466938018798828
```

In [4]:
import paragami

k_approx = 30
dim = 4

base_pattern = paragami.PSDSymmetricMatrixPattern(size=dim)
covar_array_pattern = paragami.PatternArray(
    array_shape=(k_approx, ), base_pattern=base_pattern)

covar_array = covar_array_pattern.random()
print(covar_array.shape)

# Flat (non-free and free) versions of covar_array.  The cells below pass
# these to fun_flattened / fun_flattened_free and their gradients, so they
# must be defined here.
covar_array_flattened = covar_array_pattern.flatten(covar_array, free = False)
covar_array_flattened_free = covar_array_pattern.flatten(covar_array, free = True)

(30, 4, 4)


In [None]:
def fun(covar_array):
    """Return the sum of the squared entries of ``covar_array``."""
    squared = covar_array ** 2
    return squared.sum()

# flattened function
fun_flattened = paragami.FlattenFunctionInput(
    original_fun=fun,
    patterns=covar_array_pattern,
    free=False,
    argnums=0)

# flattened and freed function
fun_flattened_free = paragami.FlattenFunctionInput(
    original_fun=fun,
    patterns=covar_array_pattern,
    free=True,
    argnums=0)

# Both wrappers evaluate the same function, so the values must match.
assert_equal(fun_flattened(covar_array_flattened),
             fun_flattened_free(covar_array_flattened_free))


In [None]:
# Jitted gradients of the non-free and free flattened functions.
grad_fun_flattened, grad_fun_flattened_free = (
    jax.jit(jax.grad(flat_fun))
    for flat_fun in (fun_flattened, fun_flattened_free))

time_jit(lambda: grad_fun_flattened(covar_array_flattened))
time_jit(lambda: grad_fun_flattened_free(covar_array_flattened_free))

In [None]:
# Minimal jax.lax.map example: add one to each of 0..9.
print(jax.lax.map(lambda entry: entry + 1, np.arange(10)))

In [None]:
__array_shape = (2, 3)
__array_ranges = [range(extent) for extent in __array_shape]
__array_ranges


empty_pattern = base_pattern.empty(valid=True)
__shape = (*__array_shape, *empty_pattern.shape)

# One copy of the empty base value per array entry, built with a Python loop.
repeated_array = np.array(
    [empty_pattern for _ in itertools.product(*__array_ranges)])
empty_orig = repeated_array.reshape(__shape)


In [None]:
# Works but inefficient, you don't want to create the entries array.
__array_indices = np.array(list(itertools.product(*__array_ranges)))
repeated_array = jax.lax.map(lambda _index: empty_pattern, __array_indices)
empty_jax = repeated_array.reshape(__shape)
# The lax.map construction must reproduce the Python-loop result.
assert_equal(empty_jax, empty_orig)

In [None]:
# NOTE(review): this is a two-argument call, but the `_stacked_obs_slice`
# defined earlier in the notebook takes (self, item, flat_length); the
# two-argument definition only appears in a later cell, so a top-to-bottom
# run reaches this line before that definition exists.
_stacked_obs_slice(__array_indices[2], flat_length)

In [None]:
# Index a JAX array with an integer index array (positions 2, 3, 4).
bar = np.arange(2, 5)
foo = np.array(onp.random.random(10))
foo[bar]

In [5]:
# Plain-NumPy table of every multi-index, one per row.
__array_indices = onp.array(list(itertools.product(*__array_ranges)))


def _stacked_obs_slice(item, flat_length):
    """Return the flat-vector indices belonging to array entry ``item``.

    Uses the module-level ``__array_shape`` to turn the multi-index into a
    linear position, then scales by ``flat_length``.
    """
    assert len(item) == len(__array_shape)
    # Man maybe we just need to do this ourselves
    first = flat_length * onp.ravel_multi_index(item, __array_shape)
    # An explicit index array is returned instead of slice(first, last).
    return np.arange(first, first + flat_length)


flat_length = base_pattern.flat_length(True)
# Random flat input, converted to a JAX array straight away.
flat_val = np.array(
    onp.random.random(flat_length * onp.prod(__array_shape)))
__slices = [
    _stacked_obs_slice(item, flat_length)
    for item in __array_indices
]
# Inspect the first index set and the values it selects.
print(__slices[0])
print(type(__slices[0]))
print(flat_val[__slices[0]])
print('----------------')


def fold_item(item):
    """Fold the chunk of the global ``flat_val`` for array entry ``item``.

    Prints the type of ``item`` (debugging aid), looks up the entry's flat
    indices with ``_stacked_obs_slice`` and returns the result of
    ``base_pattern.fold`` on that chunk (``free=True``, no validation).
    """
    print('type in fold: ', type(item))
    item_slice = _stacked_obs_slice(item, flat_length)
    #print(flat_val[item_slice])
    # No trailing comma after the call: return the folded value itself,
    # not a one-element tuple wrapping it.
    return base_pattern.fold(
        flat_val[item_slice],
        free=True, validate_value=False)

# Exercise fold_item once, on the third row of __array_indices.
print(type(__array_indices[2]))
fold_item(__array_indices[2])
print('----------------')


# folded_array = jax.lax.map(
#     fold_item,
#     __array_indices
# )


def fold_item2(flat_val, item_slice):
    """Fold the entries of ``flat_val`` selected by ``item_slice``.

    Returns the result of ``base_pattern.fold`` (``free=True``, no
    validation) on ``flat_val[item_slice]``.
    """
#     print('type in fold: ', type(item_slice))
#     print('val in fold:  ', flat_val[item_slice])
#     print('len in fold:  ', len(flat_val[item_slice]), len(item_slice))
    # No trailing comma after the call: return the folded value itself,
    # not a one-element tuple wrapping it.
    return base_pattern.fold(
        flat_val[item_slice],
        free=True, validate_value=False)

#fold_item2(flat_val, __slices[0])

def fold_with_map(flat_val):
    """Apply ``fold_item2`` to every row of ``__slices`` via ``jax.lax.map``."""
    def fold_one(item_slice):
        return fold_item2(flat_val, item_slice)

    return jax.lax.map(fold_one, np.array(__slices))


def fold_with_for(flat_val):
    """Apply ``fold_item2`` to every entry of ``__slices`` in a Python loop
    and stack the results into one array."""
    folded_items = []
    for item_slice in __slices:
        folded_items.append(fold_item2(flat_val, item_slice))
    return np.array(folded_items)

# Compare the lax.map fold against the Python-loop fold.
# (Original author's note: "Same speed. :(")
#
# The lambdas must actually CALL the jitted functions; returning the function
# object would time nothing and never trigger compilation.

fold_with_map_jit = jax.jit(fold_with_map)
time_jit(lambda: fold_with_map_jit(flat_val))

fold_with_for_jit = jax.jit(fold_with_for)
time_jit(lambda: fold_with_for_jit(flat_val))



NameError: name 'itertools' is not defined

## PSD patterns

In [None]:
# NOTE(review): unconditional failure — presumably a deliberate stop so that
# "Run All" halts before the PSD-pattern cells below; confirm before removing.
assert False

In [None]:

def _sym_index(k1, k2):
    """
    Get the index of an entry in a folded symmetric array.

    Parameters
    ------------
    k1, k2: int
        0-based indices into a symmetric matrix.

    Returns
    --------
    int
        Return the linear index of the (k1, k2) element of a symmetric
        matrix where the triangular part has been stacked into a vector.
    """
    def ld_ind(k1, k2):
        return int(k2 + k1 * (k1 + 1) / 2)

    if k2 <= k1:
        return ld_ind(k1, k2)
    else:
        return ld_ind(k2, k1)


def _vectorize_ld_matrix(mat):
    """
    Linearize the lower diagonal of a square matrix.
    """
    nrow, ncol = np.shape(mat)
    if nrow != ncol:
        raise ValueError('mat must be square')
    return mat[np.tril_indices(nrow)]


def _unvectorize_ld_matrix(vec):
    """
    Invert the mapping of `_vectorize_ld_matrix`.
    """
    mat_size = int(0.5 * (math.sqrt(1 + 8 * vec.size) - 1))
    if mat_size * (mat_size + 1) / 2 != vec.size:
        raise ValueError('Vector is an impossible size')

    mat = np.zeros((mat_size, mat_size))
    inds = np.tril_indices(mat_size)
    return(jax.ops.index_update(mat, inds, vec))


def _exp_matrix_diagonal(mat):
    assert mat.shape[0] == mat.shape[1]
    dim = mat.shape[0]
    diag_inds = (np.arange(dim), np.arange(dim))
    exp_diags = np.exp(np.diag(mat))
    return(jax.ops.index_update(mat, diag_inds, exp_diags))


def _log_matrix_diagonal(mat):
    assert mat.shape[0] == mat.shape[1]
    dim = mat.shape[0]
    diag_inds = (np.arange(dim), np.arange(dim))
    log_diags = np.log(np.diag(mat))
    return(jax.ops.index_update(mat, diag_inds, log_diags))


def _pack_posdef_matrix(mat, diag_lb=0.0):
    """Map a matrix to an unconstrained vector.

    Subtracts ``diag_lb`` from the diagonal, takes the Cholesky factor,
    replaces the factor's diagonal by its log and returns the factor's
    lower triangle as a vector.
    """
    n_rows = mat.shape[0]
    lower_bound_diag = np.diag(np.full(n_rows, diag_lb))
    chol = np.linalg.cholesky(mat - lower_bound_diag)
    return _vectorize_ld_matrix(_log_matrix_diagonal(chol))


def _unpack_posdef_matrix(free_vec, diag_lb=0.0):
    """Map an unconstrained vector to a matrix (reverse of
    ``_pack_posdef_matrix``).

    Unvectorizes ``free_vec`` into a lower-triangular matrix, exponentiates
    its diagonal, multiplies it by its transpose, and adds ``diag_lb`` to
    the diagonal of the product.

    The trailing numbers on the commented-out early returns below are the
    original author's notes (presumably timings of the partial pipeline).
    """
    mat_raw = _unvectorize_ld_matrix(free_vec)
    #return mat_raw # 0.20

    mat_chol = _exp_matrix_diagonal(mat_raw)
    #return mat_chol # 0.60
    #mat_chol = mat_raw

    # Doesn't seem to matter much what you do
    #mat = np.einsum('ik,jk->ij', mat_chol, mat_chol)
    mat = np.matmul(mat_chol, np.transpose(mat_chol))
    #return mat # 1.3

    dim = mat.shape[0]
    diag_inds = (np.arange(dim), np.arange(dim))
    # `.at[...].set(...)` is the supported functional update; the older
    # jax.ops.index_update helper is no longer available in JAX.
    new_mat = mat.at[diag_inds].set(np.diag(mat) + diag_lb)
    return new_mat # 1.1 ?!

    
def _unvectorize_symmetric_matrix(vec_val):
    """Convert a vector containing the lower diagonal portion of a symmetric
    matrix into the full symmetric matrix.

    This is not currently used but could be useful for a symmetric matrix
    type.
    """
    lower = _unvectorize_ld_matrix(vec_val)
    # lower + lower^T counts the diagonal twice, so take one copy back out.
    mirrored = lower + lower.transpose()
    return mirrored - np.diag(np.diagonal(lower))


In [None]:
dim = 50
#mat = np.eye(dim) * dim + np.full((dim, dim), 0.1)
# Symmetric random matrix with `dim` added to the diagonal.
foo = onp.random.random((dim, dim))
mat = np.array(np.eye(dim) * dim + foo + foo.T)
vec = _pack_posdef_matrix(mat)
# Packing then unpacking must give back the original matrix.
assert_equal(_unpack_posdef_matrix(vec), mat)


In [None]:
pack_grad = jax.jit(jax.jacobian(_pack_posdef_matrix))

# diag_lb is declared static, so each new diag_lb value seen by the jitted
# function is a separate compilation.
unpack_grad = jax.jit(jax.jacobian(
    lambda x, diag_lb: _unpack_posdef_matrix(x, diag_lb=diag_lb)),
                      static_argnums=1)
time_jit(lambda: unpack_grad(vec, 0.0))

# Each mark_tic reports the time of the call immediately before it, so every
# label names the diag_lb value used by that preceding call.
tic = time.time()
unpack_grad(vec, 0.0)
tic = mark_tic(tic, 'subsequent')
unpack_grad(vec, 0.1)
tic = mark_tic(tic, 'subsequent but new par')
unpack_grad(vec, 0.1)
tic = mark_tic(tic, 'subsequent 0.1')
unpack_grad(vec, 0.0);
tic = mark_tic(tic, 'subsequent 0.0')
unpack_grad(vec, 0.1);
tic = mark_tic(tic, 'subsequent 0.1 again')


In [None]:
# Jitted Jacobians of the individual building blocks.
exp_grad = jax.jit(jax.jacobian(_exp_matrix_diagonal))
pack_grad = jax.jit(jax.jacobian(_vectorize_ld_matrix))
unpack_grad = jax.jit(jax.jacobian(_unvectorize_ld_matrix))

print('Exp')
time_jit(lambda: exp_grad(mat))

# Two spellings of the same (un-jitted) product mat @ mat^T.
time_jit(lambda: np.matmul(mat, mat.T))
time_jit(lambda: np.matmul(mat, np.transpose(mat)))

print('Pack')
time_jit(lambda: pack_grad(mat))

print('Unpack')
time_jit(lambda: unpack_grad(vec))