In [1]:
import jax
import jax.numpy as jnp
import numpy as np
from jax import grad, hessian, jit, vmap

def logprob_fun(mu, x):
    """Return 0.5 * sum((mu - x) ** 2).

    Despite the name, this is the *negative* log density of N(x; mu, I) up to
    an additive constant. Its gradient in `mu` is `mu - x` and its Hessian is
    the identity.

    Written with `jax.numpy` (not plain NumPy) so that `grad`, `hessian` and
    `jit` can trace it directly instead of relying on NumPy dispatching
    `np.sum` to the tracer's `.sum` method.
    """
    return jnp.sum(0.5 * (mu - x) ** 2)


# Compiled first and second derivatives with respect to `mu` (argument 0).
grad_fun = jit(grad(logprob_fun))
hess_fun = jit(hessian(logprob_fun))

# Paragami debugging

In [2]:
import paragami
import copy
import unittest
from numpy.testing import assert_array_almost_equal
import scipy as sp
import scipy as osp

import itertools
import json
import collections


In [3]:
# A pattern that matches no actual types for causing errors to test.
class BadTestPattern(paragami.base_patterns.Pattern):
    """Stub pattern whose JSON typename matches no real pattern.

    The tests use it to check two things: a real pattern class refuses to
    load this stub's JSON, and a real pattern never compares equal to it.
    Every method just returns a fixed placeholder value.
    """

    def __init__(self):
        # Intentionally does not call the base-class initializer.
        pass

    def __str__(self):
        return 'BadTestPattern'

    def as_dict(self):
        """Serializable description carrying only the bogus typename."""
        return dict(pattern='bad_test_pattern')

    def fold(self, flat_val, validate_value=None):
        """Placeholder: ignores its input."""
        placeholder = 0
        return placeholder

    def flatten(self, flat_val, validate_value=None):
        """Placeholder: ignores its input."""
        placeholder = 0
        return placeholder

    def empty(self):
        """Placeholder empty value."""
        placeholder = 0
        return placeholder

    def validate_folded(self, folded_val, validate_value=None):
        """Accept everything, with an empty message."""
        is_valid, message = True, ''
        return is_valid, message

    def flat_indices(self, folded_bool, free):
        """No flat indices."""
        return list()


def _check_fold_round_trip(testcase, pattern, valid_value, check_equal):
    """Check flatten/fold round trips for every `free` / `free_default` pair.

    When both `free` and `free_default` are None, `flatten` and `fold` must
    raise ValueError. The pattern's `free_default` is restored afterwards so
    the caller's pattern is left as it was.
    """
    original_free_default = pattern.free_default
    # Flattened with an explicit `free`, so the "cannot resolve free" check
    # below does not depend on a value left over from an earlier iteration.
    nonfree_flat_val = pattern.flatten(valid_value, free=False)
    try:
        for free in [True, False, None]:
            for free_default in [True, False, None]:
                pattern.free_default = free_default
                if (free_default is None) and (free is None):
                    with testcase.assertRaises(ValueError):
                        pattern.flatten(valid_value, free=free)
                    with testcase.assertRaises(ValueError):
                        pattern.fold(nonfree_flat_val, free=free)
                    continue

                flat_val = pattern.flatten(valid_value, free=free)
                testcase.assertEqual(len(flat_val), pattern.flat_length(free))
                folded_val = pattern.fold(flat_val, free=free)
                check_equal(valid_value, folded_val)
                if hasattr(valid_value, 'shape'):
                    testcase.assertEqual(valid_value.shape, folded_val.shape)
    finally:
        pattern.free_default = original_free_default


def _check_json_round_trip(testcase, pattern):
    """Check JSON serialization round trips and rejects a foreign pattern."""
    pattern.as_dict()
    json_typename = pattern.json_typename()
    json_string = pattern.to_json()
    json_dict = json.loads(json_string)
    testcase.assertIn('pattern', json_dict)
    testcase.assertEqual(json_dict['pattern'], json_typename)
    new_pattern = paragami.get_pattern_from_json(json_string)
    testcase.assertTrue(new_pattern == pattern)

    # Test that you cannot convert from a different pattern.
    bad_test_pattern = BadTestPattern()
    bad_json_string = bad_test_pattern.to_json()
    testcase.assertFalse(pattern == bad_test_pattern)
    with testcase.assertRaises(ValueError):
        pattern.__class__.from_json(bad_json_string)


def _check_jacobians(testcase, pattern, valid_value, jacobian_ad_test):
    """Check the freeing/unfreeing Jacobians in both sparse and dense form.

    They must have the right shapes and be mutual inverses in the free flat
    space. When `jacobian_ad_test` is True, they must also match
    `jax.jacobian` of the corresponding fold/flatten transforms.
    """
    def freeing_transform(flat_val):
        return pattern.flatten(
            pattern.fold(flat_val, free=False), free=True)

    def unfreeing_transform(free_flat_val):
        return pattern.flatten(
            pattern.fold(free_flat_val, free=True), free=False)

    # None of these depend on `sparse`, so compute them once.
    flat_val = pattern.flatten(valid_value, free=False)
    free_flat_val = pattern.flatten(valid_value, free=True)
    flat_len = pattern.flat_length(free=False)
    free_flat_len = pattern.flat_length(free=True)
    if jacobian_ad_test:
        ad_freeing_jac = jax.jacobian(freeing_transform)(flat_val)
        ad_unfreeing_jac = jax.jacobian(unfreeing_transform)(free_flat_val)

    for sparse in [True, False]:
        freeing_jac = pattern.freeing_jacobian(valid_value, sparse)
        unfreeing_jac = pattern.unfreeing_jacobian(valid_value, sparse)

        # Freeing maps the flat (non-free) space to the free space, so its
        # Jacobian is (free_flat_len, flat_len); unfreeing is the transpose shape.
        testcase.assertEqual(freeing_jac.shape, (free_flat_len, flat_len))
        testcase.assertEqual(unfreeing_jac.shape, (flat_len, free_flat_len))

        if sparse:
            product = np.array((freeing_jac @ unfreeing_jac).todense())
            freeing_dense = np.array(freeing_jac.todense())
            unfreeing_dense = np.array(unfreeing_jac.todense())
        else:
            product = freeing_jac @ unfreeing_jac
            freeing_dense = freeing_jac
            unfreeing_dense = unfreeing_jac

        # The Jacobians should be inverses of one another and full rank
        # in the free flat space.
        assert_array_almost_equal(np.eye(free_flat_len), product)
        if jacobian_ad_test:
            assert_array_almost_equal(ad_freeing_jac, freeing_dense)
            assert_array_almost_equal(ad_unfreeing_jac, unfreeing_dense)


def _test_pattern(testcase, pattern, valid_value,
                  check_equal=assert_array_almost_equal,
                  jacobian_ad_test=True):
    """Run a battery of consistency checks on a paragami pattern.

    Args:
        testcase: a `unittest.TestCase` instance whose assert* methods are
            used for the checks.
        pattern: the pattern under test.
        valid_value: a folded value that is valid for `pattern`.
        check_equal: callable `(expected, actual)` that raises if two folded
            values differ. Defaults to `assert_array_almost_equal`.
        jacobian_ad_test: if True, also compare the pattern's freeing and
            unfreeing Jacobians against `jax.jacobian` of the fold/flatten
            transforms.

    Raises:
        AssertionError: if any check fails.
    """
    print('Testing pattern {}'.format(pattern))

    # Execute required methods.
    pattern.flatten(pattern.empty(valid=True), free=False)
    pattern.empty(valid=False)
    pattern.flatten(pattern.random(), free=False)
    str(pattern)
    pattern.empty_bool(True)

    # `==` is customized on patterns, so exercise it explicitly.
    testcase.assertTrue(pattern == pattern)

    _check_fold_round_trip(testcase, pattern, valid_value, check_equal)
    _check_json_round_trip(testcase, pattern)
    _check_jacobians(testcase, pattern, valid_value, jacobian_ad_test)


In [4]:
# Compare `diags` built from an ndarray vs. from a plain list. A cell only
# displays its last expression, so show both results together.
diag_from_array = osp.sparse.diags(np.array([1.]))
diag_from_list = osp.sparse.diags([1.])
diag_from_array, diag_from_list

<1x1 sparse matrix of type '<class 'numpy.float64'>'
	with 1 stored elements (1 diagonals) in DIAgonal format>

In [5]:
def isdense(x):
    """Report whether `x` is a NumPy ndarray (subclasses included)."""
    dense = isinstance(x, np.ndarray)
    return dense



def isscalarlike(x):
    """Return True for a Python/NumPy scalar or a 0-dimensional ndarray."""
    if np.isscalar(x):
        return True
    return isinstance(x, np.ndarray) and x.ndim == 0


# Rebuild the sparse diagonal freeing Jacobian by hand for a length-1 array.
# NOTE(review): presumably mirrors what `NumericArrayPattern.freeing_jacobian`
# does internally -- confirm.
test_shape = (1, )
valid_value = np.random.random(test_shape)
pattern = paragami.NumericArrayPattern(test_shape)

sparse = False

flat_val = pattern.flatten(valid_value, free=False)
freeflat_val = pattern.flatten(valid_value, free=True)

jac_array = \
    paragami.numeric_array_patterns._unconstrain_array_jacobian(
        valid_value, pattern._lb, pattern._ub)
jac_array = np.atleast_1d(jac_array).flatten()

# `jac_array` holds the entries of a single diagonal, so pair it with one
# scalar offset. `diags` pairs a *sequence* of offsets with a *sequence* of
# diagonals, so an array of per-element zero offsets is the wrong shape of
# argument once `jac_array` has more than one entry.
osp.sparse.diags(jac_array, offsets=0)
#print(osp.sparse.diags(freeing_jac.flatten()))
# unfreeing_jac = pattern.unfreeing_jacobian(valid_value, sparse)
# free_len = pattern.flat_length(free=False)
# flatfree_len = pattern.flat_length(free=True)




<1x1 sparse matrix of type '<class 'numpy.float64'>'
	with 1 stored elements (1 diagonals) in DIAgonal format>

In [6]:
class DummyTest(unittest.TestCase):
    """Empty TestCase; instances are only used for their assert* methods."""


for test_shape in ((1, ), (2, ), (2, 3), (2, 3, 4)):
    # A fresh random valid value and unconstrained pattern for each shape.
    valid_value = np.random.random(test_shape)
    pattern = paragami.NumericArrayPattern(test_shape)
    _test_pattern(testcase=DummyTest(), pattern=pattern,
                  valid_value=valid_value)


Testing pattern NumericArrayPattern (1,) (lb=-inf, ub=inf)
Testing pattern NumericArrayPattern (2,) (lb=-inf, ub=inf)
Testing pattern NumericArrayPattern (2, 3) (lb=-inf, ub=inf)
Testing pattern NumericArrayPattern (2, 3, 4) (lb=-inf, ub=inf)


In [7]:
# NOTE(review): presumably a deliberate stop. `assert False` always raises
# AssertionError (unless Python runs with -O), so "Run All" halts here and
# the scratch JAX cells below are not executed.
assert False

AssertionError: 

# Generic JAX experiments

In [None]:
# `np` is plain NumPy in this notebook.
x = np.random.random(100)
mu = np.random.random(100)

# `jax.device_put` returns a new array on the target device; it does not move
# its argument in place, so keep the returned arrays.
# NOTE(review): the original author found this did not silence the warning
# they were seeing; presumably that warning is emitted when the JAX backend is
# first initialized, which device_put cannot undo -- confirm.
cpu_device = jax.devices('cpu')[0]
x = jax.device_put(x, cpu_device)
mu = jax.device_put(mu, cpu_device)

In [None]:
# Derivatives of 0.5 * sum((mu - x) ** 2) with respect to mu.
g = grad_fun(mu, x)
h = hess_fun(mu, x)
# The analytic gradient is mu - x and the Hessian is the identity, so both
# printed maximum errors should be (near) zero. The identity is sized from
# `mu` rather than a hard-coded 100 so the check follows the data.
print(np.max(np.abs(g - (mu - x))))
print(np.max(np.abs(h - np.eye(len(mu)))))

In [None]:
# https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html

In [None]:
from jax import custom_jvp
import jax.numpy as jnp

# f :: a -> b
@custom_jvp
def f(x):
    """sin(x), with its JVP rule supplied by `f_jvp`."""
    return jnp.sin(x)

# f_jvp :: (a, T a) -> (b, T b)
def f_jvp(primals, tangents):
    """JVP rule for `f`: returns (sin(x), cos(x) * t)."""
    x, = primals
    t, = tangents
    return f(x), jnp.cos(x) * t

f.defjvp(f_jvp)

print(type(f(0.5)))

# Compare what `jax.numpy.asarray` and NumPy's `asarray` produce from a JAX
# result. `np` is plain NumPy in this notebook.
print('Use jax')
foo = jax.numpy.asarray(f(0.5) + 3)
print(foo, type(foo))
print(isinstance(foo, np.ndarray))
print(isinstance(foo, jax.numpy.ndarray))

print('Use numpy')
foo = np.asarray(f(0.5) + 3)
print(isinstance(foo, np.ndarray))
print(isinstance(foo, jax.numpy.ndarray))


In [None]:
def _f_tangent_wrt_x(x_dot, primal_out, x, y):
    # d/dx [sin(x) * y] = cos(x) * y
    return jnp.cos(x) * x_dot * y


def _f_tangent_wrt_y(y_dot, primal_out, x, y):
    # d/dy [sin(x) * y] = sin(x)
    return jnp.sin(x) * y_dot


@custom_jvp
def f(x, y):
    """sin(x) * y, with one JVP rule per argument registered via `defjvps`."""
    return jnp.sin(x) * y


f.defjvps(_f_tangent_wrt_x, _f_tangent_wrt_y)

In [None]:
# `jax.ops.index_update` has been removed from JAX; the functional replacement
# is `x.at[idx].set(values)`, which returns a new array. NumPy arrays have no
# `.at`, so convert with `jnp.asarray` first. Indices are passed as arrays
# because newer JAX versions reject plain Python lists as indices.
foo = np.array([1, 2, 3])
updated_vec = jnp.asarray(foo).at[jnp.array([1, 2])].set(jnp.array([10, 20]))
print(updated_vec)

# On a 2-D array a 1-D index array selects whole rows, so the values are
# broadcast across each selected row. Row 2 is out of bounds for a 2x2 array;
# mode='drop' makes it explicit that that update is discarded.
foo = np.array([[1, 2], [3, 4]])
updated_rows = jnp.asarray(foo).at[jnp.array([1, 2])].set(
    jnp.array([10, 20]), mode='drop')
print(updated_rows)

# A tuple of (row, col) index arrays addresses individual elements.
inds = np.triu_indices(2)
print(inds)
updated_triu = jnp.asarray(foo).at[inds].set(jnp.array([10, 20, 30]))
print(updated_triu)


In [None]:
#jax.sp.logsumexp

In [None]:
# The float32 vector [1, 2, 3] and the diagonal matrix built from it.
vec = np.arange(1, 4, dtype=np.float32)
np.diag(vec)

In [None]:

def _exp_matrix_diagonal(mat):
    assert mat.shape[0] == mat.shape[1]
    dim = mat.shape[0]
    diag_inds = (np.arange(dim), np.arange(dim))
    exp_diags = np.exp(np.diag(mat))
    return(jax.ops.index_update(mat, diag_inds, exp_diags))

def _log_matrix_diagonal(mat):
    assert mat.shape[0] == mat.shape[1]
    dim = mat.shape[0]
    diag_inds = (np.arange(dim), np.arange(dim))
    log_diags = np.log(np.diag(mat))
    return(jax.ops.index_update(mat, diag_inds, log_diags))

# Random 3x3 test matrix (`np` is plain NumPy in this notebook).
mat = np.random.random((3, 3))
print(mat)
print(_exp_matrix_diagonal(mat))
print(jax.jacobian(_exp_matrix_diagonal)(mat))


In [None]:
def pack_vec(vec, dim):
    """Scatter `vec` into the lower triangle (with diagonal) of a dim x dim matrix.

    Entries are written in `np.tril_indices(dim)` order, i.e. row by row over
    the lower triangle; the strict upper triangle stays zero.

    Args:
        vec: 1-D array of length dim * (dim + 1) // 2.
        dim: matrix size; must be a concrete Python int (static under `jit`).

    Returns:
        A JAX array of shape (dim, dim).

    Raises:
        AssertionError: if `len(vec)` does not match `dim`.
    """
    # Integer division keeps the length check exact.
    assert len(vec) == dim * (dim + 1) // 2
    inds = np.tril_indices(dim)
    mat = jnp.zeros((dim, dim))
    # `.at[...].set` replaces the removed `jax.ops.index_update`.
    return mat.at[inds].set(vec)

# The float32 vector [1, ..., 6] fills the lower triangle of a 3x3 matrix.
vec = np.arange(1, 7, dtype=np.float32)
print(vec.dtype)
print(pack_vec(vec, 3))

# Jacobian of pack_vec with respect to `vec`, evaluated eagerly.
print('Raw:')
raw_jac = jax.jacobian(pack_vec)(vec, 3)
print(raw_jac)

# The same Jacobian compiled; `dim` (argument 1) must be static for jit.
print('JIT:')
jac_fun = jax.jit(jax.jacobian(pack_vec), static_argnums=1)
print(jac_fun(vec, 3))


In [None]:

# Fails

# @custom_jvp
# def replace_ind(x, v, i):
#     x[i] = v
#     return x

# replace_ind.defjvps(
#     lambda x_dot, ans, x, v, i: replace_ind(x_dot, 0.0, i),
#     lambda v_dot, ans, x, v, i: replace_ind(ans, v_dot, i),
#     None)

# x = np.array([1.0, 2.0, 3.5])
# replace_ind(x, 10.0, 1)