In [10]:
from functools import partial
# Run the following before any XLA modules such as JAX:
import chex

chex.set_n_cpu_devices(2)

# Add the project's `src` directory (two levels up) to the Python path.
import sys
sys.path.append("../../src")

# Import the remaining JAX-related modules.
import arviz as az
from gabenet.mcmc import sample_markov_chain
from gabenet.nets import MultinomialDirichletBelieve
from gabenet.utils import freeze_trainable_states,perplexity
import haiku as hk
import jax
from jax import random
import jax.numpy as jnp

from dataset import load_mutation_spectrum, COSMIC_WEIGHTS

In [2]:
# Haiku PRNG sequence seeded with 43; later cells draw fresh keys from it
# with `next(key_seq)`.
key_seq = hk.PRNGSequence(43)

# Load the mutation spectra and unpack the two returned datasets.
mutation_split = load_mutation_spectrum()
X_train, X_test = mutation_split

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


In [3]:
# Keep only the first ten entries of each dataset (presumably to keep this
# demo run short -- confirm before using the numbers below for anything else).
head = slice(None, 10)
X_train, X_test = X_train[head], X_test[head]

In [4]:
# Model dimensions: one topic per entry of COSMIC_WEIGHTS and 96 input
# features (presumably the 96 channels of a mutation spectrum -- confirm
# against the dataset module).
n_topics, n_features = len(COSMIC_WEIGHTS), 96

@hk.transform_with_state
def kernel(X=X_train, freeze_phi=True):
    """Advance the Markov chain by one step.

    Builds a `MultinomialDirichletBelieve` network from `[n_topics]` and
    `n_features` and calls it once on `X`; that single call is one Gibbs
    sampling step. The function itself returns nothing.

    Args:
        X: Observations fed to the network. Defaults to `X_train` as it was
            bound when this cell was executed.
        freeze_phi: When truthy, `set_training(False)` is called on
            `layers.layers[1]` of the network before the step (presumably the
            layer holding `phi` -- confirm against gabenet).
    """
    net = MultinomialDirichletBelieve([n_topics], n_features)
    if freeze_phi:
        phi_layer = net.layers.layers[1]
        phi_layer.set_training(False)
    # A single call of the network performs the Gibbs update.
    net(X)
    
def probability(params, state):
    """Return `theta @ phi` for the multinomial layer.

    `phi` is taken from `params` when the multinomial layer has a "phi" entry
    there, and from `state` otherwise. `theta` is always read from `state`.

    Args:
        params: Mapping from layer name to that layer's parameters. The
            multinomial layer may be missing from it.
        state: Mapping from layer name to that layer's state. Must contain
            the multinomial layer with a "theta" entry.

    Returns:
        The matrix product of `theta` and `phi`.
    """
    layer_name = "multinomial_dirichlet_believe/~/multinomial_layer"
    layer_params = params.get(layer_name, {})
    layer_state = state[layer_name]
    # The state copy of phi is only the fallback; a phi stored in the
    # parameters takes precedence.
    phi_fallback = layer_state.get("phi")
    phi = layer_params.get("phi", phi_fallback)
    return layer_state["theta"] @ phi

Inference using COSMIC weights.

In [11]:
# One PRNG key per available device.
keys = random.split(next(key_seq), jax.device_count())

# Initialise the kernel once per key: the keys are mapped over their leading
# axis, while the same X_train is passed unchanged to every initialisation.
init_fn = partial(kernel.init, freeze_phi=False)
batched_init = jax.vmap(init_fn, in_axes=[0, None])
params, state = batched_init(keys, X_train)

# Rebuild params/state from the state alone with `phi` marked as frozen; the
# params returned by the initialisation above are discarded here.
params, state = freeze_trainable_states(state, variable_names=['phi'])

# Pin phi of the multinomial layer to the COSMIC weights.
cosmic_phi = jnp.array(COSMIC_WEIGHTS)
params['multinomial_dirichlet_believe/~/multinomial_layer']['phi'] = cosmic_phi



Perplexity before training.

In [7]:
# Average the probabilities over the leading axis of the freshly initialised
# state, then score them against the test data.
initial_probs = probability(params, state)
probs = initial_probs.mean(axis=0)
perplexity(X_test, probs)

Array(107.320724, dtype=float32)

In [8]:
# Run the sampler with one chain per available device. The values in the
# trailing comments (1_000 burn-in steps, 400 samples) are presumably the
# settings for a full-length run -- confirm before relying on them.
params, states = sample_markov_chain(
    next(key_seq),
    kernel=kernel,
    params=params,
    initial_state=state,
    n_chains=jax.device_count(),
    n_burnin_steps=50,  # 1_000
    n_samples=40,  # 400
)
# Wait until the sampled `theta` of the cap layer has actually been computed,
# so the work is finished before this cell returns.
cap_theta = states['multinomial_dirichlet_believe/~/cap_layer']['theta']
_ = cap_theta.block_until_ready()

In [12]:
# Perplexity of the sampled model on the data it was fitted to (X_train) and
# on the held-out data (X_test). The probabilities are averaged over the two
# leading axes of the collected states first.
# NOTE(review): the first label reads "fitted test set" although it is
# computed on X_train.
probs = probability(params, states).mean(axis=(0, 1))
for label, counts in (
    ('Perplexity on fitted test set', X_train),
    ('Perplexity on unseen observations', X_test),
):
    print(label, perplexity(counts, probs))

Perplexity on fitted test set 74.03703
Perplexity on unseen observations 73.699


In [13]:
# Wrap the sampled `c` of the cap layer in an ArviZ InferenceData object so
# the ArviZ diagnostics below can be applied to it.
cap_states = states['multinomial_dirichlet_believe/~/cap_layer']
c_cap = cap_states['c']
idata = az.convert_to_inference_data(dict(c=c_cap))

In [16]:
# Show only the r_hat column of the ArviZ summary. Values close to 1 mean the
# chains agree with each other; clearly larger values indicate that the
# chains have not converged yet.
summary_table = az.summary(idata)
print(summary_table['r_hat'])

c    2.07
Name: r_hat, dtype: float64


In [None]:
# Trace plot of the sampled variables in `idata`. The trailing semicolon
# keeps the notebook from echoing the returned array of Axes objects below
# the figure.
az.plot_trace(idata);