In [None]:
import flax
import jax
import jax.numpy as jnp
import numpy as np
from flax import jax_utils

import lib.models.model_utils as model_utils
import lib.optimizer.optimizer as optim
import lib.utils.utils as utils
from lib.config.config_mnist import get_config
from lib.models.diffusion_model import CategoricalDiffusionModel
from lib.utils.utils import init_state

In [None]:
# Experiment configuration, taken from lib.config.config_mnist.
config = get_config()

# Build the optimizer and the two networks, then bundle them into the
# categorical diffusion model.
optimizer = optim.build_optimizer(config)
backwd_model = model_utils.build_backwd_model(config)
fwd_model = model_utils.build_fwd_model(config)
model = CategoricalDiffusionModel(config, fwd_model, backwd_model, optimizer)

# Fixed seed for reproducibility; split off one key that is used only for
# initialising the training state.
global_key = jax.random.PRNGKey(42)
global_key, model_key = jax.random.split(global_key)

# Training state (per the original note: step, params, optimizer state and
# EMA state).
state = init_state(config, model, model_key)

# Remember where training starts before the state is replicated below.
init_step = state.step

# Replicate the state across the local devices.
# NOTE(review): the state is replicated, yet the step function below is
# jit-compiled rather than pmapped -- confirm that `model.training_step`
# expects the replicated layout.
state = jax_utils.replicate(state)

# Give every host process its own RNG stream by folding in the process index.
process_rng_key = jax.random.fold_in(global_key, jax.process_index())

# Compiled training step. The multi-device variant would instead wrap the
# step with jax.pmap(..., axis_name="shard").
train_step_fn = jax.jit(model.training_step)

# Schedule queried by the training loop when it reports the learning rate.
lr_schedule = optim.build_lr_schedule(config)

In [None]:



# Main training loop: one optimisation step per iteration, with periodic
# metric reporting, evaluation and checkpointing.
# NOTE(review): `train_ds`, `fn_data_preprocess`, `fn_eval` and `save_model`
# are not defined in the cells visible here; they must be provided by an
# earlier cell before this one is run.
for step in range(init_step + 1, config.total_train_steps + 1):
    batch = fn_data_preprocess(next(train_ds))

    # Fold the step index into the running key so every step uses fresh
    # randomness, then derive the key(s) handed to the training step.
    process_rng_key = jax.random.fold_in(process_rng_key, step)
    step_rng_keys = utils.shard_prng_key(process_rng_key)
    state, aux = train_step_fn(state, step_rng_keys, batch)

    if step % config.log_every_steps == 0:
        # Bring the auxiliary outputs back to the host and attach the
        # current learning rate.
        aux = jax.device_get(flax.jax_utils.unreplicate(aux))
        aux["train/lr"] = lr_schedule(step)
        # Report the metrics; without this the values gathered above are
        # never surfaced anywhere.
        print(f"step {step}: {aux}")

    # Evaluation is optional: skipped when no evaluation function is set.
    if step % config.plot_every_steps == 0 and fn_eval is not None:
        metric = fn_eval(step, state, process_rng_key)

    if step % config.save_every_steps == 0:
        save_model(state, step)