In [1]:
import os
from pathlib import Path
import pickle


import chex

# Ask chex for two CPU devices.
# NOTE(review): chex documents that this only takes effect before JAX
# initialises its backends, which is presumably why it precedes the remaining
# imports -- confirm before reordering this cell.
chex.set_n_cpu_devices(2)

import sys
# Add ../src (relative to the kernel's working directory) to the module search
# path; presumably this is what makes the local `gabenet` package importable.
sys.path.append("../src")


# Import the remaining JAX-related packages and the scikit-learn data utilities.
from gabenet.mcmc import sample_markov_chain
from gabenet.nets import MultinomialDirichletBelieve
import haiku as hk
import jax
import jax.numpy as jnp
from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import CountVectorizer



In [2]:
# Fetch the 20 Newsgroups corpus. Note that subset='all' is requested, so
# `files_train` holds every document, not only the training split.
files_train = fetch_20newsgroups(subset='all')

# Bag-of-words counts: keep terms occurring in at least 10 documents, capped at
# the 2000 most frequent terms.
cv = CountVectorizer(min_df=10, max_features=2_000)
term_counts = cv.fit_transform(files_train.data)

# Densify with `.toarray()` (plain ndarray) rather than `.todense()`, which
# returns the discouraged `np.matrix` type, then hand a float32 copy to JAX.
X_train = jnp.array(term_counts.toarray().astype(jnp.float32))

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


In [3]:
# Keep only the first 100 rows (all columns) so the profiling runs stay small.
X_train = X_train[:100, :]

In [4]:
# Dimensions of the count matrix: rows first, then columns.
m_samples, n_features = X_train.shape

# Seeded stream of PRNG keys; each `next(key_seq)` hands out a fresh key.
key_seq = hk.PRNGSequence(42)

In [5]:
@hk.transform_with_state
def kernel(n_hidden_units=(200,)):
    """Perform a single Gibbs sampling step of the Markov chain on ``X_train``.

    Args:
        n_hidden_units: Forwarded as the first argument of
            ``MultinomialDirichletBelieve`` (presumably the hidden-layer
            sizes -- confirm against the model class).

    Nothing is returned; the result of calling the model is discarded.
    NOTE(review): the sampled quantities are presumably exposed through the
    Haiku state managed by ``transform_with_state`` -- confirm.
    """
    # Build the network for `n_features` inputs and immediately run one step.
    MultinomialDirichletBelieve(n_hidden_units, n_features)(X_train)

In [6]:
# Initialise the transformed kernel with a fresh key from the sequence.
init_key = next(key_seq)
params, state = kernel.init(init_key)

 114. 115. 116. 117. 118. 119. 120. 121. 122. 123. 124. 125. 126. 127.
 128. 129. 130. 131. 132. 133. 134. 135. 136. 137. 138. 139. 140. 141.
 142. 143. 144. 145. 146. 147. 148. 149. 150. 151. 152. 153. 154. 155.
 156. 157. 158. 159. 160. 161. 162. 163. 164. 165. 166. 167. 168. 169.
 170. 171. 172. 173. 174. 175. 176. 177. 178. 179. 180. 181. 182. 183.
 184. 185. 186. 187. 188. 189. 190. 191. 192. 193. 194. 195. 196. 197.
 198. 199.]


In [12]:
# Pull the two entries of the multinomial layer's state used by the cells that
# follow.
multinomial_state = state['multinomial_dirichlet_believe/~/multinomial_layer']
phi = multinomial_state['phi']
theta = multinomial_state['copy[theta(1)]']

In [22]:
# Use the JAX substrate of TensorFlow Probability.
from tensorflow_probability.substrates import jax as tfp  # type: ignore

# Short alias for the distributions module.
tfd = tfp.distributions


In [18]:
from gabenet.random import augmented_poisson

# Broadcast `theta` (new axis at position 1) against `phi.T` (new leading axis)
# to form the elementwise product.
# NOTE(review): assumes theta is (samples, hidden) and phi is
# (hidden, features), giving rate of shape (samples, features, hidden) -- confirm.
rate = jnp.expand_dims(theta, 1) * jnp.expand_dims(phi.T, 0)

# Left as the final expression so the notebook displays the result.
augmented_poisson(next(key_seq), rate, X_train)

Array([[[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        ...,
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]],

       [[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        ...,
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]],

       [[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        ...,
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]],

       ...,

       [[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        ...,
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]],

       [[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        ...,
        [0, 0, 0, ..., 

In [23]:
# Profile the individual steps of the augmentation: build the rates, normalise
# them over the last axis into probabilities, then draw multinomial counts.
with jax.profiler.trace("/tmp/trace-augment-internals"):
    rate = theta[:, jnp.newaxis, :] * phi.T[jnp.newaxis, ...]
    rate_norm = jnp.sum(rate, axis=-1, keepdims=True)
    # Where the normaliser is zero, use probability 0 instead of the 0/0 result.
    zeta = jnp.where(rate_norm == 0, 0, rate / rate_norm)
    x_augmented = tfd.Multinomial(total_count=X_train, probs=zeta).sample(seed=next(key_seq))
    # JAX dispatches work asynchronously; wait for the result before the trace
    # context closes so the computation is actually captured (the other
    # profiling cell in this notebook blocks in the same way).
    x_augmented.block_until_ready()

In [None]:
# Look up the multinomial layer's state. The module prefix must be
# 'multinomial_dirichlet_believe', as in every other state lookup in this
# notebook; the previous 'multi_dirichlet_believe' key does not match them.
# NOTE(review): this rebinds `rate` to the layer's whole state mapping rather
# than to a rate array -- confirm that is intended.
rate = state['multinomial_dirichlet_believe/~/multinomial_layer']

In [7]:
# Warm-up pass: run one step outside the profiler and wait for it to finish,
# presumably so one-off start-up costs are not attributed to a profiled step.
warmup_key = next(key_seq)
_, state = kernel.apply(params, state, warmup_key)
cap_layer_state = state['multinomial_dirichlet_believe/~/cap_layer']
_ = cap_layer_state['r'].block_until_ready()

In [8]:
# Record a profiler trace of exactly one kernel step.
with jax.profiler.trace("/tmp/jax-trace"):
    step_key = next(key_seq)
    _, state = kernel.apply(params, state, step_key)
    # Block on an output of the step so the work completes inside the trace.
    r_value = state['multinomial_dirichlet_believe/~/cap_layer']['r']
    _ = r_value.block_until_ready()