In [None]:
from copy import deepcopy
from functools import partial
import os
from pathlib import Path
import pickle
from typing import Literal, NamedTuple

from gabenet.nets import MultinomialDirichletBelieve, PoissonGammaBelieve
from gabenet._surgery import determine_number_to_prune, prune_network
from gabenet.utils import get_hidden_unit_sizes
import haiku as hk
import jax
from jax.tree_util import tree_map
import jax.numpy as jnp
from jax import random
from jaxtyping import PyTree

from dataset import COSMIC_WEIGHTS


# Root directory for all artefacts. Taken from the ARTEFACT_DIR environment
# variable when set, otherwise the default below; created on the spot if it
# does not exist yet.
ARTEFACT_DIR = Path(os.environ.get("ARTEFACT_DIR", "/mnt/output/"))
ARTEFACT_DIR.mkdir(parents=True, exist_ok=True)


# Model hyperparameters.
# MODEL selects which network class `kernel` instantiates and is also a path
# component of the checkpoint directory.
MODEL: Literal[
    "multinomial_dirichlet_believe", "poisson_gamma_believe"
] = "multinomial_dirichlet_believe"
# One topic per entry of COSMIC_WEIGHTS.
n_topics = len(COSMIC_WEIGHTS)
# Layer sizes of the source network; joined with "-" they name the directory
# from which the source checkpoint is loaded.
HIDDEN_LAYER_SIZES = [40, n_topics]
# Print out model hyperparameters for logging.
print(f"MODEL = {MODEL}")
print(f"n_topics = {n_topics}")
print(f"HIDDEN_LAYER_SIZES = {HIDDEN_LAYER_SIZES}")


# Global checkpoint counter. `load_last_checkpoint` overwrites it with one past
# the index of the checkpoint it loaded.
i_checkpoint = 0


class TrainState(NamedTuple):
    """Bundle of everything that is pickled into (and read back from) a checkpoint.

    Instances are written with `pickle.dump` and restored with `pickle.load`
    in this notebook, so the field names and their order should be kept stable.
    """

    # Haiku parameters; passed as the first argument to `kernel.apply`.
    params: hk.Params
    # Haiku state; in this notebook every leaf is indexed along a leading
    # chain axis (`x[chain]`).
    state: hk.State
    # PRNG key; used to seed `hk.PRNGSequence` after loading.
    key: jax.Array  # type: ignore
    # Step counter; set to 0 when the pruned state is created.
    step: int

In [None]:
def load_last_checkpoint(hidden_layer_sizes) -> "TrainState | None":
    """Load the most recent checkpoint of a given architecture from disk.

    Checkpoints are searched in
    ``ARTEFACT_DIR / MODEL / <sizes joined by "-"> / "checkpoints"`` and must be
    named ``checkpoint_<number>.pkl``. The file with the largest *numeric*
    suffix is loaded, so ``checkpoint_10000.pkl`` wins over
    ``checkpoint_9999.pkl`` (a lexicographic sort of the paths would pick the
    latter). Files whose suffix is not a decimal number are ignored.

    Side effect: on success the module-level ``i_checkpoint`` is set to one
    past the index of the loaded checkpoint; it is left untouched when nothing
    is found.

    Args:
        hidden_layer_sizes: Iterable of layer sizes identifying the
            architecture sub-directory.

    Returns:
        The unpickled checkpoint object, or ``None`` when the directory holds
        no usable checkpoint (a message is printed in that case).
    """
    global i_checkpoint
    architecture = '-'.join(map(str, hidden_layer_sizes))
    checkpoint_dir = ARTEFACT_DIR / MODEL / architecture / "checkpoints"

    # Pair every checkpoint file with its integer index so that the newest one
    # is chosen by number rather than by string order.
    numbered = []
    for path in checkpoint_dir.glob("checkpoint_*.pkl"):
        suffix = path.stem.split("_")[-1]
        if suffix.isdecimal():
            numbered.append((int(suffix), path))
    if not numbered:
        print("No checkpoints found.")
        return None

    last_index, last_path = max(numbered, key=lambda pair: pair[0])
    # SECURITY: unpickling executes arbitrary code embedded in the file; only
    # load checkpoints from a trusted ARTEFACT_DIR.
    with open(last_path, "rb") as fi:
        state = pickle.load(fi)
    i_checkpoint = last_index
    print(f"Loaded checkpoint i={i_checkpoint}.")
    # The next checkpoint to be written gets the following index.
    i_checkpoint += 1
    return state

In [None]:
# Number of chains; the leaves of the loaded state are indexed with
# `x[chain]` for chain in range(n_chains) further down.
# NOTE(review): assumed to match the leading axis of the checkpointed state — confirm.
n_chains = 4
train_state = load_last_checkpoint(hidden_layer_sizes=HIDDEN_LAYER_SIZES)
if train_state is None:
    # `load_last_checkpoint` returns None when no checkpoint exists. Fail here
    # with an actionable message instead of an AttributeError on `.state`.
    raise FileNotFoundError(
        f"No checkpoint found for MODEL={MODEL!r} with "
        f"HIDDEN_LAYER_SIZES={HIDDEN_LAYER_SIZES} under {ARTEFACT_DIR}; "
        "there is no source network to prune."
    )
state_source = train_state.state
params_source = train_state.params
key_source = train_state.key
# All randomness used below is drawn from the checkpointed key.
key_seq = hk.PRNGSequence(key_source)

In [None]:
# Display the shape of every leaf of the loaded (unpruned) state.
tree_map(jnp.shape, state_source)

In [None]:
# Prune every chain separately, then stack the per-chain results back into a
# single state with a leading chain axis.
# `tree_map` from `jax.tree_util` (imported at the top of the file) is used
# instead of the top-level alias `jax.tree_map`, which is deprecated and no
# longer available in recent JAX releases.
states_pruned = []
# Computed once from the stacked state, so every chain is pruned with the same
# `n_prune`.
# NOTE(review): presumably this keeps the pruned shapes identical across chains,
# which `jnp.stack` below requires — confirm.
n_prune = determine_number_to_prune(state_source)
for chain in range(n_chains):
    # Slice out the state belonging to this chain.
    state_i = tree_map(lambda x: x[chain], state_source)
    state_i_pruned = prune_network(state_i, n_prune)
    states_pruned.append(state_i_pruned)
state_pruned = tree_map(lambda *s: jnp.stack(s), *states_pruned)
# Layer sizes of the pruned network, as reported by `get_hidden_unit_sizes`.
NEW_HIDDEN_LAYER_SIZES = get_hidden_unit_sizes(state_pruned)

In [None]:
# Display the shape of every leaf of the pruned state.
tree_map(jnp.shape, state_pruned)

In [None]:
# Assemble the state to checkpoint: the loaded parameters, the pruned network
# state, a fresh key drawn from the sequence, and the step counter at zero.
train_state = TrainState(
    params=params_source,
    state=state_pruned,
    key=next(key_seq),
    step=0,
)

In [None]:
from dataset import load_mutation_spectrum, COSMIC_WEIGHTS

# Number of input features handed to the model constructor in `kernel`.
# NOTE(review): presumably the number of columns of X_train — TODO confirm.
n_features = 96
# `X_train` is the default input of `kernel`; `X_test` is not used in the
# cells visible here.
X_train, X_test = load_mutation_spectrum()
# Passed as `gamma_0` to whichever model class `kernel` constructs.
GAMMA_0 = 10.0



@hk.transform_with_state
def kernel(X=X_train, freeze_phi=True):
    """Run a single Gibbs sampling step of the pruned network on `X`.

    The network class is chosen by the global `MODEL` and is built with
    `NEW_HIDDEN_LAYER_SIZES`, `n_features` and `GAMMA_0`. When `freeze_phi`
    is true, the last entry of `model.layers.layers` is switched out of
    training mode before the step. Nothing is returned; the effect of the
    step is carried by the Haiku state.
    """
    # Pick the class first so the (identical) constructor call appears once.
    model_cls = (
        MultinomialDirichletBelieve
        if MODEL == "multinomial_dirichlet_believe"
        else PoissonGammaBelieve
    )
    model = model_cls(NEW_HIDDEN_LAYER_SIZES, n_features, gamma_0=GAMMA_0)

    if freeze_phi:
        last_layer = model.layers.layers[-1]
        last_layer.set_training(False)

    # Advance the Markov chain by one step.
    model(X)

# Test new network config.
# Runs one kernel step per chain on the pruned state; the resulting state is
# kept in `new_state` but is not what gets checkpointed.
# One PRNG key per chain.
keys = random.split(next(key_seq), num=n_chains)
# in_axes=(None, 0, 0): the parameters are not mapped, while the state and the
# keys are mapped over their leading (chain) axis.
# NOTE(review): jax.pmap needs at least `n_chains` local devices — confirm this
# holds in the environment where the notebook is run.
kernel_fn = jax.pmap(kernel.apply, in_axes=(None, 0, 0))
_, new_state = kernel_fn(params_source, state_pruned, keys)

In [None]:
# Write `train_state` as checkpoint 0 of the pruned architecture. The target
# directory is named after the layer sizes of the pruned state.
NEW_HIDDEN_LAYER_SIZES = get_hidden_unit_sizes(state_pruned)
architecture = '-'.join(str(size) for size in NEW_HIDDEN_LAYER_SIZES)
target_dir = ARTEFACT_DIR / MODEL / architecture / "checkpoints"
target_dir.mkdir(parents=True, exist_ok=True)
checkpoint_path = target_dir / "checkpoint_0000.pkl"
with checkpoint_path.open("wb") as fo:
    pickle.dump(train_state, fo)