In [1]:
import jax
import jax.numpy as jnp
import numpy as np
from jax import grad, hessian, jit, vmap

def logprob_fun(mu, x):
    return np.sum(0.5 * (mu - x)**2)

grad_fun = jit(grad(logprob_fun))
hess_fun = jit(hessian(logprob_fun))

# Paragami debugging

In [2]:
import paragami
import copy
import unittest
from numpy.testing import assert_array_almost_equal
import scipy as sp
import scipy as osp

import itertools
import json
import collections

import time




In [3]:
def jax_random(shape):
    key = jax.random.PRNGKey(42)
    if not type(shape) is tuple:
        shape = (shape, )
    return jax.random.uniform(key, shape=shape, dtype='float64')


# Profiling

In [4]:
class DummyTest(unittest.TestCase):
    pass
testcase = DummyTest()
self = DummyTest()
check_equal = assert_array_almost_equal

# array_pattern = paragami.NumericArrayPattern(
#     shape=(4, ), lb=-1, ub=10.0)
# pattern = paragami.PatternArray((2, 3), array_pattern)

pattern = paragami.PSDSymmetricMatrixPattern(100)

#pattern = paragami.NumericArrayPattern(shape=(500, 10, 10), lb=-1, ub=10.0)

In [14]:
import cProfile
import pstats
valid_value = pattern.random()

# Execute required methods.
empty_val = pattern.empty(valid=True)
flat_val = pattern.flatten(empty_val, free=True, validate_value=False)
profile = True

if profile:
    pr = cProfile.Profile()
    pr.enable()

jit_flatten = jax.jit(lambda val: pattern.flatten(empty_val, free=True, validate_value=True))
jit_fold = jax.jit(lambda flat_val: pattern.fold(flat_val, free=True, validate_value=True))

jit_flatten(empty_val)
jit_fold(flat_val)

tic = time.time()
for _ in range(500):
    #pattern.flatten(empty_val, free=False, validate_value=False)
    jit_flatten(empty_val)
    #jit_fold(flat_val)
    
tic = time.time() - tic; print(tic)

if profile:
    pr.disable()

0.0442051887512207


In [9]:
ps = pstats.Stats(pr).strip_dirs().sort_stats('cumulative')
ps.print_stats()

         386160 function calls (386158 primitive calls) in 0.490 seconds

   Ordered by: cumulative time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        9    0.000    0.000    0.489    0.054 interactiveshell.py:3302(run_code)
        9    0.000    0.000    0.489    0.054 {built-in method builtins.exec}
        1    0.019    0.019    0.479    0.479 <ipython-input-5-959f729a1ed8>:21(<module>)
     5001    0.005    0.000    0.460    0.000 <ipython-input-5-959f729a1ed8>:14(<lambda>)
     5001    0.007    0.000    0.456    0.000 psdmatrix_patterns.py:295(flatten)
     5001    0.007    0.000    0.437    0.000 lax_numpy.py:1153(ravel)
     5002    0.004    0.000    0.407    0.000 lax_numpy.py:1109(reshape)
     5002    0.011    0.000    0.403    0.000 lax_numpy.py:1138(_reshape_method)
     5002    0.006    0.000    0.388    0.000 lax_numpy.py:1126(_reshape)
     5002    0.015    0.000    0.378    0.000 lax.py:657(reshape)
     5002    0.006    0.000    0.317 

<pstats.Stats at 0x7f986a134320>

In [7]:


empty_val = pattern.empty(valid=False)


random_val = pattern.random()
pattern.flatten(random_val, free=False)

str(pattern)

pattern.empty_bool(True)

# Make sure to test != using a custom test.
testcase.assertTrue(pattern == pattern)




###############################
# Test folding and unfolding.
for free in [True, False, None]:
    for free_default in [True, False, None]:
        pattern.free_default = free_default
        if (free_default is None) and (free is None):
            with testcase.assertRaises(ValueError):
                flat_val = pattern.flatten(valid_value, free=free)
            with testcase.assertRaises(ValueError):
                folded_val = pattern.fold(flat_val, free=free)
        else:
            flat_val = pattern.flatten(valid_value, free=free)
            testcase.assertEqual(len(flat_val), pattern.flat_length(free))
            folded_val = pattern.fold(flat_val, free=free)
            check_equal(valid_value, folded_val)
            if hasattr(valid_value, 'shape'):
                testcase.assertEqual(valid_value.shape, folded_val.shape)



In [8]:
assert False

AssertionError: 

# Generic Jax stuff

In [None]:
x = onp.random.random(100)
mu = onp.random.random(100)

# I would have expected this to get rid of the annoying warning but it does not.
cpu_device = jax.devices('cpu')[0]
jax.device_put(x, cpu_device);
jax.device_put(mu, cpu_device);

In [None]:
g = grad_fun(mu, x)
h = hess_fun(mu, x)
print(np.max(np.abs(g - (mu - x))))
print(np.max(np.abs(h - np.eye(100))))

In [None]:
# https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html

In [None]:
from jax import custom_jvp
import jax.numpy as jnp

# f :: a -> b
@custom_jvp
def f(x):
    return jnp.sin(x)

# f_jvp :: (a, T a) -> (b, T b)
def f_jvp(primals, tangents):
    x, = primals
    t, = tangents
    return f(x), jnp.cos(x) * t

f.defjvp(f_jvp)

print(type(f(0.5)))

print('Use jax')
foo = jax.numpy.asarray(f(0.5) + 3)
print(foo, type(foo))
print(isinstance(foo, onp.ndarray))
print(isinstance(foo, jax.numpy.ndarray))

print('Use numpy')
foo = onp.asarray(f(0.5) + 3)
print(isinstance(foo, onp.ndarray))
print(isinstance(foo, jax.numpy.ndarray))


In [None]:
@custom_jvp
def f(x, y):
    return jnp.sin(x) * y

f.defjvps(lambda x_dot, primal_out, x, y: jnp.cos(x) * x_dot * y,
          lambda y_dot, primal_out, x, y: jnp.sin(x) * y_dot)

In [None]:
foo = np.array([1, 2, 3])
print(jax.ops.index_update(foo, [1, 2], [10, 20]))

foo = np.array([[1, 2], [3, 4]])
print(jax.ops.index_update(foo, [1, 2], [10, 20]))

inds = np.triu_indices(2)
print(inds)
print(jax.ops.index_update(foo, inds, [10, 20, 30]))


In [None]:
#jax.sp.logsumexp

In [None]:
vec = np.arange(0, 3, dtype=np.float32) + 1
np.diag(vec)

In [None]:

def _exp_matrix_diagonal(mat):
    assert mat.shape[0] == mat.shape[1]
    dim = mat.shape[0]
    diag_inds = (np.arange(dim), np.arange(dim))
    exp_diags = np.exp(np.diag(mat))
    return(jax.ops.index_update(mat, diag_inds, exp_diags))

def _log_matrix_diagonal(mat):
    assert mat.shape[0] == mat.shape[1]
    dim = mat.shape[0]
    diag_inds = (np.arange(dim), np.arange(dim))
    log_diags = np.log(np.diag(mat))
    return(jax.ops.index_update(mat, diag_inds, log_diags))

mat = onp.random.random((3, 3))
print(mat)
print(_exp_matrix_diagonal(mat))
print(jax.jacobian(_exp_matrix_diagonal)(mat))


In [None]:
np.triu_indices(5)

def pack_vec(vec, dim):
    assert len(vec) == dim * (dim + 1) / 2
    mat = np.zeros((dim, dim))
    inds = np.tril_indices(dim)
    return(jax.ops.index_update(mat, inds, vec))

vec = np.arange(0, 6, dtype=np.float32) + 1
print(vec.dtype)
print(pack_vec(vec, 3))

print('Raw:')
print(jax.jacobian(pack_vec)(vec, 3))

print('JIT:')
jac_fun = jit(jax.jacobian(pack_vec), static_argnums=1)
print(jac_fun(vec, 3))


In [None]:

# Fails

# @custom_jvp
# def replace_ind(x, v, i):
#     x[i] = v
#     return x

# replace_ind.defjvps(
#     lambda x_dot, ans, x, v, i: replace_ind(x_dot, 0.0, i),
#     lambda v_dot, ans, x, v, i: replace_ind(ans, v_dot, i),
#     None)

# x = np.array([1.0, 2.0, 3.5])
# replace_ind(x, 10.0, 1)

## Why is jax slow?

In [None]:
import timeit
foo = np.random.random((100, 10, 7)) + 1
number = 1000


def _unconstrain_array(array, lb, ub, np):
    # Assume that the inputs obey the constraints, lb < ub and
    # lb <= array <= ub, which are checked in the pattern.
    if ub == float("inf"):
        if lb == -float("inf"):
            # For consistent behavior, never return a reference.
            # Note that deepcopy will cause jax to fail.
            return copy.copy(array)
        else:
            return np.log(array - lb)
    else:  # the upper bound is finite
        if lb == -float("inf"):
            return -1 * np.log(ub - array)
        else:
            return np.log(array - lb) - np.log(ub - array)

setup_str = "from __main__ import jnp, np, foo, _unconstrain_array"
print(timeit.timeit('_unconstrain_array(foo, lb=0.5, ub=10.0, np=jnp)', setup=setup_str, number=number))
print(timeit.timeit('_unconstrain_array(foo, lb=0.5, ub=10.0, np=np)', setup=setup_str, number=number))


In [None]:
import timeit
foo = np.random.random((100, 10, 7)) + 1
number = 5000

setup_str = "from __main__ import jnp, np, foo"
print(timeit.timeit('jnp.atleast_1d(foo)', setup=setup_str, number=number))
print(timeit.timeit('np.atleast_1d(foo)', setup=setup_str, number=number))

In [None]:
import timeit
foo = np.random.random((100, 10, 7)) + 1
number = 5000

setup_str = "from __main__ import jnp, np, foo"
print(timeit.timeit('jnp.all(foo < 0.)', setup=setup_str, number=number))
print(timeit.timeit('np.all(foo < 0.)', setup=setup_str, number=number))

In [None]:
import jax
import jax.numpy as jnp
import numpy as np

import timeit
foo = [ np.random.random((100, )) for _ in range(10) ]
number = 500

@jax.jit
def jitted_fun(x):
    return jnp.hstack(x)

setup_str = "from __main__ import jnp, np, foo, jitted_fun"
print(timeit.timeit('y = jnp.hstack(foo)', setup=setup_str, number=number))
print(timeit.timeit('y = jitted_fun(foo)', setup=setup_str, number=number))
print(timeit.timeit('y = np.hstack(foo)', setup=setup_str, number=number))

# Is this even going to work

In [None]:
class BasePattern():
    def __init__(self, scale):
        self._scale = scale + 1.0
    
    def fun(self, val):
        return val

class Pattern(BasePattern):
    
    def get_scale(self):
        return self._scale

    def fun(self, val):
        if self._scale < 0.0:
            return 0.0
        return self.get_scale() * val
    
pattern = Pattern(2.1)
      
jit_fun = jax.jit(pattern.fun)
print(jit_fun(5.0))

jit_grad = jax.jit(jax.grad(pattern.fun))
print(jit_grad(5.0))


# Profiling psd functions

All these are actually comparable between jax and autograd.

In [None]:
import cProfile
import pstats
valid_value = pattern.random()

# Execute required methods.
profile = True

if profile:
    pr = cProfile.Profile()
    pr.enable()

from paragami.psdmatrix_patterns import \
    _vectorize_ld_matrix, _unvectorize_ld_matrix, \
    _pack_posdef_matrix, _unpack_posdef_matrix

mat = jax_random((30, 30))
mat = mat @ mat + 100 * np.eye(30)
val = _vectorize_ld_matrix(mat)
#jit_fun = jax.jit(_vectorize_ld_matrix); jit_fun(mat)
#jit_fun = jax.jit(_unvectorize_ld_matrix); jit_fun(val)
#jit_fun = jax.jit(_pack_posdef_matrix); jit_fun(mat)
jit_fun = jax.jit(_unpack_posdef_matrix); jit_fun(val)

tic = time.time()
for _ in range(100):
    jit_fun(val)
    
tic = time.time() - tic; print(tic)

if profile:
    pr.disable()

In [None]:
# ps = pstats.Stats(pr).strip_dirs().sort_stats('cumulative')
# ps.print_stats()