In [None]:
from copy import deepcopy
from functools import partial
import os
from pathlib import Path
import pickle
from typing import Literal, NamedTuple

from gabenet.nets import MultinomialDirichletBelieve, PoissonGammaBelieve
from gabenet._surgery import copy_to_larger_net
from gabenet.utils import get_layer_name, get_model_name
import haiku as hk
import jax
from jax import random
from jaxtyping import PyTree

from dataset import load_mutation_spectrum, COSMIC_WEIGHTS


# Root directory for all outputs; override it with the ARTEFACT_DIR env variable.
ARTEFACT_DIR = Path(os.getenv("ARTEFACT_DIR", "/mnt/output/"))
ARTEFACT_DIR.mkdir(exist_ok=True, parents=True)


# Model hyperparameters.
MODEL: Literal[
    "multinomial_dirichlet_believe", "poisson_gamma_believe"
] = "poisson_gamma_believe"
# One topic per entry of COSMIC_WEIGHTS.
n_topics = len(COSMIC_WEIGHTS)
# Two hidden layers, each as wide as the number of topics.
HIDDEN_LAYER_SIZES = [n_topics] * 2
GAMMA_0 = 10.0
# Echo the hyperparameters (one "name = value" line each) for logging.
print(
    "\n".join(
        (
            f"MODEL = {MODEL}",
            f"n_topics = {n_topics}",
            f"HIDDEN_LAYER_SIZES = {HIDDEN_LAYER_SIZES}",
            f"GAMMA_0 = {GAMMA_0}",
        )
    )
)

# NOTE(review): presumably the number of mutation contexts per sample in the
# spectrum returned by load_mutation_spectrum — TODO confirm against X_train.
n_features = 96
X_train, X_test = load_mutation_spectrum()

# Checkpoint counter; load_last_checkpoint sets it to one past the index of
# the checkpoint it loads.
i_checkpoint = 0

class TrainState(NamedTuple):
    """Sampler snapshot that is pickled to and unpickled from checkpoint files.

    NOTE(review): instances are stored with ``pickle``, so renaming this class
    or reordering/renaming its fields breaks compatibility with existing
    checkpoint files.
    """

    # Haiku parameters (hk.Params) of the sampler network.
    params: hk.Params
    # Haiku state (hk.State) of the sampler network.
    state: hk.State
    # PRNG key from which subsequent random keys are drawn.
    key: jax.Array  # type: ignore
    # Step counter; presumably counts sampling steps (reset to 0 for a freshly
    # grafted network).
    step: int

In [None]:
@hk.transform_with_state
def kernel(X=X_train, freeze_phi=True):
    """Advance the Markov chain of the configured belief network by one step.

    Args:
        X: Data fed to the model's forward call. The default is ``X_train`` as
            it was when this cell was executed (Python binds defaults once).
        freeze_phi: If true, the last entry of ``model.layers.layers`` is put
            out of training mode with ``set_training(False)``.
    """
    # Pick the network class from MODEL; anything other than the multinomial
    # variant falls back to the Poisson-gamma network.
    network_cls = (
        MultinomialDirichletBelieve
        if MODEL == "multinomial_dirichlet_believe"
        else PoissonGammaBelieve
    )
    network = network_cls(HIDDEN_LAYER_SIZES, n_features, gamma_0=GAMMA_0)
    if freeze_phi:
        # NOTE(review): by the flag's name this presumably freezes the phi
        # weights of that layer — confirm in gabenet.
        network.layers.layers[-1].set_training(False)
    # A single forward call performs one Gibbs sampling step.
    network(X)

In [None]:
def load_last_checkpoint(
    hidden_layer_sizes, artefact_dir=None, model=None
) -> "TrainState | None":
    """Load the highest-numbered checkpoint of a given architecture.

    Checkpoints are looked up as
    ``<artefact_dir>/<model>/<architecture>/checkpoints/checkpoint_<i>.pkl``,
    where ``<architecture>`` is ``hidden_layer_sizes`` joined by ``-``. The file
    with the largest integer ``<i>`` is loaded; files whose suffix is not an
    integer are ignored.

    Args:
        hidden_layer_sizes: Iterable of hidden layer sizes identifying the run.
        artefact_dir: Root output directory; defaults to ``ARTEFACT_DIR``.
        model: Model sub-directory name; defaults to ``MODEL``.

    Returns:
        The unpickled checkpoint object, or ``None`` if no checkpoint exists.

    Side effects:
        Sets the global ``i_checkpoint`` to one past the loaded index.
    """
    global i_checkpoint
    if artefact_dir is None:
        artefact_dir = ARTEFACT_DIR
    if model is None:
        model = MODEL
    architecture = "-".join(map(str, hidden_layer_sizes))
    checkpoint_dir = Path(artefact_dir) / model / architecture / "checkpoints"

    # Sort numerically on the index suffix: a lexicographic sort would order
    # checkpoint_10 before checkpoint_9 when indices are not zero-padded.
    indexed_files = []
    for path in checkpoint_dir.glob("checkpoint_*.pkl"):
        suffix = path.stem.split("_")[-1]
        if suffix.isdecimal():
            indexed_files.append((int(suffix), path))
    if not indexed_files:
        print("No checkpoints found.")
        return None
    last_index, last_file = max(indexed_files, key=lambda item: item[0])

    # SECURITY: pickle.load can execute arbitrary code; only load checkpoints
    # from trusted locations.
    with open(last_file, "rb") as fi:
        state = pickle.load(fi)
    i_checkpoint = last_index
    print(f"Loaded checkpoint i={i_checkpoint}.")
    i_checkpoint += 1
    return state

In [None]:
def migrate_v005_to_v006(state: PyTree) -> PyTree:
    """Rename the theta caches of a v0.0.5 state to their v0.0.6 names.

    In the layer whose key ends with the bottom-layer name, ``"theta"`` becomes
    ``"copy[theta(1)]"``. In every other layer, ``"theta_tplus1"`` becomes
    ``"copy[theta(t+1)]"``. Layers lacking the old key are left untouched.

    The input is not modified; a deep copy with the renamed keys is returned.
    """
    bottom_layer = get_layer_name("bottom", get_model_name(state))
    migrated = deepcopy(state)
    # Only the inner per-layer mappings are mutated, so iterating the outer
    # mapping while renaming is safe.
    for layer_name, layer_entries in migrated.items():
        if layer_name.endswith(bottom_layer):
            old_key, new_key = "theta", "copy[theta(1)]"
        else:
            old_key, new_key = "theta_tplus1", "copy[theta(t+1)]"
        if old_key in layer_entries:
            layer_entries[new_key] = layer_entries.pop(old_key)
    return migrated

In [None]:
# Load the most recent checkpoint of the single-layer ([n_topics]) source run.
train_state = load_last_checkpoint(hidden_layer_sizes=[n_topics])
if train_state is None:
    # Fail with an actionable message instead of an AttributeError below.
    raise FileNotFoundError(
        f"No source checkpoint found for hidden_layer_sizes={[n_topics]} "
        f"under {ARTEFACT_DIR / MODEL}; it is required to initialise the "
        "larger network."
    )
state_source = train_state.state
params_source = train_state.params
key_source = train_state.key
key_seq = hk.PRNGSequence(key_source)

# One key per pmapped replica. jax.pmap needs at least this many local
# devices. NOTE(review): presumably this must also match the number of chains
# stored in the source checkpoint — confirm.
n_chains = 4
keys = random.split(next(key_seq), n_chains)
# Initialise the larger target network without freezing phi (freeze_phi=False).
init_fn = partial(kernel.init, freeze_phi=False)
params_target, state_target = jax.pmap(init_fn)(keys)

In [None]:
# Rename the v0.0.5 theta caches so the source state uses the v0.0.6 keys.
migrated_state_source = migrate_v005_to_v006(state_source)
# Copy the single-layer source network into the freshly initialised target.
new_params, new_state = copy_to_larger_net(
    params_source,
    migrated_state_source,
    params_target,
    state_target,
)
# Start the enlarged network at step 0 with a fresh key from the sequence.
train_state = TrainState(
    step=0,
    key=next(key_seq),
    state=new_state,
    params=new_params,
)

In [None]:
# Save the enlarged network as checkpoint 0 of the deeper architecture, using
# the same directory layout that load_last_checkpoint reads from.
architecture = "-".join(map(str, HIDDEN_LAYER_SIZES))
target_dir = ARTEFACT_DIR / MODEL / architecture / "checkpoints"
target_dir.mkdir(parents=True, exist_ok=True)
target_file = target_dir / "checkpoint_0000.pkl"
# load_last_checkpoint resumes from the highest-ordered checkpoint, so any
# other checkpoint already present would typically be picked instead of the
# state written here. Make that visible rather than failing silently.
other_checkpoints = [
    path for path in target_dir.glob("checkpoint_*.pkl") if path != target_file
]
if other_checkpoints:
    print(
        f"WARNING: {target_dir} already contains {len(other_checkpoints)} other "
        "checkpoint(s) that may take precedence over checkpoint_0000.pkl."
    )
with open(target_file, "wb") as fo:
    pickle.dump(train_state, fo)