In [1]:
# Add the local source tree (../src, which holds the end-to-end package) to
# the import path so that `gabenet` can be imported below.
import sys
from pathlib import Path

SRC_DIR = str(Path('../src').absolute())
# Guard the append so re-running this cell does not keep growing `sys.path`
# with duplicate entries.
if SRC_DIR not in sys.path:
    sys.path.append(SRC_DIR)

In [2]:
from sklearn import datasets

# Configure the number of host CPU devices. Note that `import chex` itself
# already imports JAX; what matters (per the chex docs) is that
# `set_n_cpu_devices` runs before the XLA backend is initialised, i.e. before
# any JAX computation or device query. It is therefore kept ahead of the
# gabenet/haiku imports below, in case any of them touch the backend at
# import time. NOTE(review): the choice of 2 devices is not explained here.
import chex

chex.set_n_cpu_devices(2)

# Import the remaining JAX-related modules (after the device count is fixed).
from gabenet.mcmc import sample_markov_chain
from gabenet.nets import PoissonGammaBelieve

import haiku as hk
import jax
from jax import random
import jax.numpy as jnp



In [3]:
# Load the scikit-learn digits images and flatten each image into a single
# row, giving one feature per pixel.
digits = datasets.load_digits()
n_samples = digits.images.shape[0]
X_train = digits.images.reshape(n_samples, -1)

In [4]:
# Seeded pseudo-random number generator sequence; each `next(key_seq)` yields
# a fresh PRNG key.
key_seq = hk.PRNGSequence(42)

# Number of training examples and features (pixels) per example.
m_samples, n_features = X_train.shape

def build_gamma_belief_net(n_hidden_units=(10,)):
    """Build the Poisson gamma belief network decoder.

    Args:
        n_hidden_units: Hidden-layer sizes, passed unchanged to
            `PoissonGammaBelieve` (presumably one entry per hidden layer --
            confirm against `gabenet.nets`). The default `(10,)` reproduces
            the original configuration: a single hidden size of 10.

    Returns:
        A `PoissonGammaBelieve` module over `n_features` observed features.
        It is meant to be constructed inside a Haiku transform (as in
        `kernel` / `forward`).
    """
    return PoissonGammaBelieve(n_hidden_units, n_features)

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


In [5]:
@hk.transform_with_state
def kernel():
    """Perform one Markov chain transition on the training data.

    Nothing is returned: the result of the step is read from the Haiku state
    produced by `kernel.apply`.
    """
    build_gamma_belief_net()(X_train)

In [6]:
@hk.transform_with_state
def forward():
    """Draw `m_samples` forward samples from the network.

    Used only through `forward.init` to obtain the initial parameters and
    Markov chain state.
    """
    net = build_gamma_belief_net()
    samples = net.forward(m_samples)
    return samples

init_key = next(key_seq)
param_init, state_init = forward.init(init_key)

In [7]:
# with jax.profiler.trace("/tmp/forward-trace"):
#     X, _ = forward.apply(param_init, state_init, next(key_seq))
#     X.block_until_ready()
#     # state_init['poisson_gamma_believe/~/cap_layer']['r'].block_until_ready()

In [8]:
# One un-jitted sampler step, starting from the forward-sampled initial state;
# the cap layer's `r` is shown once it has been computed.
_, state = kernel.apply(param_init, state_init, next(key_seq))
cap_r = state['poisson_gamma_believe/~/cap_layer']['r']
cap_r.block_until_ready()

Array([4.0652275e+00, 4.4300509e-04, 6.0318822e-01, 2.3464619e-03,
       1.1366474e-03, 1.5800984e+00, 5.6379624e-05, 4.6082416e+00,
       2.7244008e-04, 1.8254912e-01], dtype=float32)

In [9]:
# Profile one un-jitted (op-by-op) sampler step. JAX dispatches work
# asynchronously, so wait for *every* leaf of the returned state before the
# trace is closed: waiting only on `r` can leave work that `r` does not depend
# on still running, truncating the profile.
with jax.profiler.trace("/tmp/backward-trace"):
    _, state = kernel.apply(param_init, state_init, next(key_seq))
    jax.tree_util.tree_map(
        lambda leaf: leaf.block_until_ready() if hasattr(leaf, 'block_until_ready') else leaf,
        state,
    )

In [10]:
# JIT-compile the sampler step. This first call pays the tracing/compilation
# cost, so later calls with same-shaped arguments reuse the compiled version.
kernel_jit = jax.jit(kernel.apply)
_, state = kernel_jit(param_init, state_init, next(key_seq))
cap_r = state['poisson_gamma_believe/~/cap_layer']['r']
cap_r.block_until_ready()

Array([4.2245402e+00, 2.2497942e-04, 5.9250891e-01, 4.5619719e-03,
       1.5383102e-03, 1.5621212e+00, 8.0437101e-05, 4.6301651e+00,
       6.8614125e-04, 1.7477462e-01], dtype=float32)

In [12]:
# A second jitted step (already compiled above), again from the initial state.
_, state = kernel_jit(param_init, state_init, next(key_seq))
cap_layer_state = state['poisson_gamma_believe/~/cap_layer']
cap_layer_state['r'].block_until_ready()

Array([3.9794190e+00, 9.3176265e-12, 6.1930192e-01, 3.2035506e-03,
       5.4359541e-04, 1.5460705e+00, 9.8756704e-05, 4.6776247e+00,
       2.2216747e-04, 1.8224584e-01], dtype=float32)

In [13]:
# Profile one compiled sampler step. As in the eager profile, wait on every
# leaf of the returned state (not just `r`) so all dispatched work has
# finished before the trace is closed.
with jax.profiler.trace("/tmp/backward-compiled-trace"):
    _, state = kernel_jit(param_init, state_init, next(key_seq))
    jax.tree_util.tree_map(
        lambda leaf: leaf.block_until_ready() if hasattr(leaf, 'block_until_ready') else leaf,
        state,
    )