In [None]:
import numpy as np
from lib.models.diffusion_model import CategoricalDiffusionModel
from lib.config.config_mnist import get_config
import lib.models.model_utils as model_utils
import lib.optimizer.optimizer as optim
import lib.networks.networks_utils as networks_utils
import lib.utils.bookkeeping as bookkeeping
import lib.datasets.datasets_utils as datasets_utils
from lib.datasets.datasets import get_dataloader
import lib.utils.utils as utils
import flax
import jax
import jax.numpy as jnp
from tqdm import tqdm

In [None]:
# --- Training setup: build the model, initialise its state, prepare the data ---

# Set True only once checkpoint restoring is implemented (see the guard below).
resume_train = False
config = get_config()

# Assemble the model from its components: forward model, network, a backward
# model built from both, and the optimizer.
optimizer = optim.build_optimizer(config)
fwd_model = model_utils.build_fwd_model(config)
net = networks_utils.build_network(config)
backwd_model = model_utils.build_backwd_model(config, fwd_model, net)

model = CategoricalDiffusionModel(config, fwd_model, backwd_model, optimizer)

SEED = 42
global_key = jax.random.PRNGKey(SEED)
global_key, model_key = jax.random.split(global_key, 2)

# Training state (per the original author: step, params, optimizer state, EMA state).
state = model_utils.init_state(config, model, model_key)
# `jax.tree_leaves` is deprecated and removed in recent JAX releases;
# `jax.tree_util.tree_leaves` is the supported spelling.
num_params = sum(leaf.size for leaf in jax.tree_util.tree_leaves(state.params))
print("Number of parameters:", num_params)

if resume_train:
    # Previously this branch only assigned `load_dir = ''` and then silently
    # trained from scratch (potentially overwriting saved checkpoints).
    # Fail loudly until restoring `state` from `load_dir` is wired up.
    load_dir = ''
    raise NotImplementedError(
        "resume_train=True but checkpoint loading is not implemented; "
        "restore `state` from `load_dir` before training."
    )

# Plain Python int so it can be used directly in `range(...)` in the loop cell.
init_step = int(state.step)

# Single-device setup: the state is NOT replicated and the train step is
# jit-compiled (not pmap-ped). Each process still gets its own RNG stream.
process_rng_key = jax.random.fold_in(global_key, jax.process_index())

train_step_fn = jax.jit(model.training_step)
# Only needed for learning-rate logging, which is currently disabled.
lr_schedule = optim.build_lr_schedule(config)

# Number of samples drawn each time the loop hits `config.sample_every_steps`.
num_samples = 16

train_ds = datasets_utils.numpy_iter(get_dataloader(config, "train"))

In [None]:

# --- Training loop ---
# Runs steps init_step+1 .. total_train_steps. The tqdm bar reports progress,
# so no per-step print is needed (the old print also showed `step + 1`,
# which was off by one).
for step in tqdm(range(init_step + 1, config.total_train_steps + 1), desc="train"):
    batch = next(train_ds)

    # Derive this step's key from the fixed per-process key without overwriting
    # it, so the stream for a given step does not depend on loop history or on
    # re-running this cell. Split it so training and sampling never share a key.
    step_key = jax.random.fold_in(process_rng_key, step)
    train_key, sample_key = jax.random.split(step_key)
    # Per-device keys, as `utils.shard_prng_key` produces them
    # (NOTE(review): confirm the jit-ed training_step expects this shape).
    step_rng_keys = utils.shard_prng_key(train_key)

    state, aux = train_step_fn(state, step_rng_keys, batch)

    # TODO: re-enable metric logging every `config.log_every_steps`, adding
    # `lr_schedule(step)` under "train/lr" in `aux`.

    if step % config.sample_every_steps == 0:
        x0 = model.sample_loop(state, sample_key, num_samples, conditioner=None)

    if step % config.save_every_steps == 0:
        bookkeeping.save_model(state, step)

# Make sure the final state is persisted even when total_train_steps is not a
# multiple of save_every_steps (otherwise the last steps of training are lost).
final_step = config.total_train_steps
if final_step > init_step and final_step % config.save_every_steps != 0:
    bookkeeping.save_model(state, final_step)