In [1]:
import os
from pathlib import Path
import pickle
from typing import Literal, NamedTuple

import chex
chex.set_n_cpu_devices(6)
os.environ['ARTEFACT_DIR'] = '/tmp/digits'

import jax
import haiku as hk

import sys
sys.path.append('../../src')

from gabenet.nets import MultinomialDirichletBelieve, PoissonGammaBelieve
from gabenet.random import PRNGSequence
from gabenet.utils import perplexity, get_hidden_unit_sizes
from gabenet._surgery import (
    copy_to_larger_net,
    prune_network,
    determine_number_to_prune,
)
import haiku as hk
import jax
from jax import config
config.update("jax_debug_nans", True)
from jax import random
from jax.tree_util import tree_map
import jax.numpy as jnp
from jaxtyping import PRNGKeyArray

from dataset import load_digits


# Root directory for checkpoints and other artefacts. Derived from the
# ARTEFACT_DIR environment variable (set in the import cell) so the path is
# defined in a single place; falls back to the previous hard-coded default.
_ARTEFACT_DIR = Path(os.environ.get('ARTEFACT_DIR', '/tmp/digits'))
_ARTEFACT_DIR.mkdir(parents=True, exist_ok=True)


class TrainState(NamedTuple):
    """Snapshot of a sampler run, as stored in a pickled checkpoint file."""

    # Haiku parameters of the transformed model.
    params: hk.Params
    # Haiku state; in this notebook its leaves carry one slice per chain along
    # the leading axis (they are indexed with `at_i` and mapped with pmap).
    state: hk.State
    # PRNG key used to seed `hk.PRNGSequence` when sampling resumes.
    key: jax.Array  # type: ignore
    # Presumably the sampling step at which the checkpoint was written.
    step: int
    # Which belief network produced the checkpoint.
    model_name: Literal["multinomial_dirichlet_believe", "poisson_gamma_believe"]
    # Width of each hidden layer. Any number of layers is allowed, hence the
    # variadic `tuple[int, ...]` (`tuple[int]` would mean exactly one int).
    hidden_layer_sizes: tuple[int, ...]


def infer_last_checkpoint_number(checkpoint_dir: Path) -> int:
    """Return the largest checkpoint number found in ``checkpoint_dir``.

    Checkpoints are files named ``checkpoint_<n>.pkl``. ``<n>`` is compared
    as an integer, so ``checkpoint_10`` ranks above ``checkpoint_9``. A plain
    sort of the filenames is lexicographic and gets this wrong.

    Args:
        checkpoint_dir: Directory to scan (not recursively).

    Returns:
        The largest ``<n>``, or -1 when no numbered checkpoint exists. This
        includes the case of a missing directory. Files matching
        ``checkpoint_*.pkl`` whose suffix is not a number are ignored.
    """
    numbers = [
        int(suffix)
        for suffix in (
            path.stem.rpartition("_")[2]
            for path in Path(checkpoint_dir).glob("checkpoint_*.pkl")
        )
        if suffix.isdecimal()
    ]
    return max(numbers, default=-1)



def load_last_checkpoint(
    model_name, hidden_layer_sizes, source_dir=_ARTEFACT_DIR
) -> TrainState | None:
    """Load the most recent checkpoint of a model from disk.

    Checkpoints are looked up as
    ``<source_dir>/<model_name>/<h1-h2-...>/checkpoints/checkpoint_<n>.pkl``.
    The file with the largest integer ``<n>`` is unpickled.

    Args:
        model_name: Model sub-directory, e.g. ``"multinomial_dirichlet_believe"``.
        hidden_layer_sizes: Layer widths. They are joined with ``-`` to form
            the architecture sub-directory, so ``(50,)`` becomes ``"50"``.
        source_dir: Root artefact directory (``str`` or ``Path``).

    Returns:
        The unpickled object (expected to be a ``TrainState``), or ``None``
        when no numbered checkpoint exists.
    """
    architecture = "-".join(map(str, hidden_layer_sizes))
    checkpoint_dir = Path(source_dir) / model_name / architecture / "checkpoints"
    # Rank by the parsed number, not the filename: a lexicographic sort puts
    # "checkpoint_10.pkl" before "checkpoint_9.pkl".
    numbered = [
        (int(path.stem.rpartition("_")[2]), path)
        for path in checkpoint_dir.glob("checkpoint_*.pkl")
        if path.stem.rpartition("_")[2].isdecimal()
    ]
    if not numbered:
        print("No checkpoints found.")
        return None
    i_checkpoint, last_file = max(numbered, key=lambda item: item[0])
    # NOTE: pickle can execute arbitrary code while loading. Only point this
    # at checkpoints you produced yourself.
    with open(last_file, "rb") as fi:
        state = pickle.load(fi)
    print(f"Loaded checkpoint i={i_checkpoint}.")
    return state




In [2]:
# Model hyperparameters.
# Which belief network to sample from. The same string is the prefix of the
# model's Haiku state keys (e.g. "multinomial_dirichlet_believe/~/...").
MODEL: Literal[
    "multinomial_dirichlet_believe", "poisson_gamma_believe"
] = "multinomial_dirichlet_believe"
n_topics = 50
# A single hidden layer with `n_topics` units.
HIDDEN_LAYER_SIZES = (n_topics,)
# Passed to the model constructor as `gamma_0` and `eta` respectively.
GAMMA_0 = 1.0
ETA = 0.05
# Haiku name of the bottom (observation) layer; it differs per model.
_bottom_layer_name = "/~/".join(
    (
        MODEL,
        "multinomial_layer"
        if MODEL == "multinomial_dirichlet_believe"
        else "poisson_layer",
    )
)

In [3]:
X_train, X_test = load_digits()
n_features = X_train.shape[1]

@hk.transform_with_state
def kernel(n_hidden_units, X=X_train):
    """Advance the Markov chain by one step.

    Builds the belief network selected by the module-level ``MODEL`` and calls
    it once on ``X``. That single call is one Gibbs sampling step. The function
    returns nothing: callers read the updated chain from the Haiku state that
    ``kernel.apply`` returns.

    Args:
        n_hidden_units: Hidden-layer widths forwarded to the model constructor
            (the notebook passes ``HIDDEN_LAYER_SIZES``).
        X: Observed data. Defaults to ``X_train`` as it was when this cell
            ran, because default values are bound at definition time.
    """
    # Any MODEL other than the multinomial one falls back to Poisson-gamma.
    belief_net = (
        MultinomialDirichletBelieve
        if MODEL == "multinomial_dirichlet_believe"
        else PoissonGammaBelieve
    )
    model = belief_net(n_hidden_units, n_features, gamma_0=GAMMA_0, eta=ETA)
    model(X)


No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


In [5]:
train_state = load_last_checkpoint(MODEL, HIDDEN_LAYER_SIZES, _ARTEFACT_DIR)
# `load_last_checkpoint` returns None when nothing has been saved yet. Fail
# here with a clear message instead of an AttributeError on `train_state.key`.
if train_state is None:
    raise FileNotFoundError(
        f"No checkpoint found for {MODEL} with hidden layers "
        f"{HIDDEN_LAYER_SIZES} under {_ARTEFACT_DIR}."
    )
# One chain per (virtual) device.
n_chains = jax.device_count()
key_seq = hk.PRNGSequence(train_state.key)

Loaded checkpoint i=7.


In [6]:
from functools import partial

# Bind the architecture argument once, so `kernel_fn` is called like a plain
# Haiku apply: kernel_fn(params, state, rng) -> (output, new_state).
kernel_fn = partial(kernel.apply, n_hidden_units=HIDDEN_LAYER_SIZES)


def at_i(i, state):
    """Return a copy of pytree ``state`` with every leaf indexed by ``i``.

    Used here to pull a single chain out of the stacked per-chain state.
    """
    def take(leaf):
        return leaf[i]

    return tree_map(take, state)

In [7]:
import numpy as np

# Sanity check: does any leaf of the loaded sampler state contain a NaN?
tree_map(lambda leaf: np.isnan(leaf).any(), train_state.state)

{'multinomial_dirichlet_believe/~/cap_layer': {'c': False,
  'm_k': False,
  'r': False,
  'theta': False},
 'multinomial_dirichlet_believe/~/multinomial_layer': {'copy[theta(1)]': False,
  'phi': False}}

In [18]:
# Inspect chain 3 on its own. The Haiku key is built from MODEL (as for
# `_bottom_layer_name`), so this cell does not hard-code one model.
# NOTE(review): assumes the top layer is named "cap_layer" in both models.
state_3 = at_i(3, train_state.state)
state_3[f"{MODEL}/~/cap_layer"]["c"].max()

Array(4.4796104, dtype=float32)

In [8]:
# Cap-layer `m_k` of chain 3 (presumably per-topic counts). The key is built
# from MODEL instead of hard-coding the multinomial model's name.
state_3[f"{MODEL}/~/cap_layer"]["m_k"]

Array([     0,   1091,      0,      0,      0,      0,     56,      0,
            0,      0,      0,      0,      0,      0,   1185, 107987,
            0,      0,      1, 120304,      0,      0,  31981,      0,
            0,      0,      0,  12145,      0,   6561,      0,      0,
            0,      0,      0,      0,      0,      0,      0,      0,
            0,      0,      0,      0,      0,      0,      0,      0,
            0,      0], dtype=int32)

In [19]:
# Run one Gibbs step on chain 3 alone (outside pmap). Then block until the
# result has been computed before the next cell inspects it.
_, result_state = kernel_fn(train_state.params, state_3, next(key_seq))
_ = result_state[f"{MODEL}/~/cap_layer"]["r"].block_until_ready()

In [20]:
# Did the single kernel step introduce a NaN anywhere in the state?
tree_map(lambda leaf: np.isnan(leaf).any(), result_state)

{'multinomial_dirichlet_believe/~/cap_layer': {'c': False,
  'm_k': False,
  'r': False,
  'theta': False},
 'multinomial_dirichlet_believe/~/multinomial_layer': {'copy[theta(1)]': False,
  'phi': False}}

: 

In [1]:
# Advance all chains in parallel, one per device. Params are shared
# (in_axes=None); state and PRNG keys are split along their leading axis.
keys = random.split(next(key_seq), num=n_chains)
kmap = jax.pmap(kernel_fn, in_axes=(None, 0, 0))
_, state = kmap(train_state.params, train_state.state, keys)

NameError: name 'train_state' is not defined

In [6]:
# Display the loaded checkpoint (a TrainState when loading succeeded).
# NOTE(review): the recorded TypeError came after a kernel restart (execution
# counter back at In [1]). Presumably `train_state` did not hold the loaded
# TrainState then; re-run the notebook from the top.
train_state

TypeError: tuple indices must be integers or slices, not str