In [None]:
import numpy as np
from lib.models.diffusion_model import CategoricalDiffusionModel
from lib.config.config_mnist import get_config
import lib.optimizer.optimizer as optim
import lib.utils.bookkeeping as bookkeeping
import lib.datasets.datasets_utils as datasets_utils
from lib.datasets.datasets import get_dataloader

import jax
import jax.numpy as jnp
from tqdm import tqdm
import os


In [None]:
# Locate the saved run: the YAML config and the checkpoint both live under <path>/<date>/.
path = 'SavedModels/MNIST'
date = '2023-10-07'
config_name = 'config_001.yaml'
model_name = 'checkpoint_600'

run_dir = os.path.join(path, date)
config_path = os.path.join(run_dir, config_name)
checkpoint_path = os.path.join(run_dir, model_name)

In [None]:
# Fixed seed: parameter-init and sampling keys are derived deterministically from it.
global_key = jax.random.PRNGKey(11)
train_key, model_key, sample_key = jax.random.split(global_key, 3)

# Restore the run: saved config -> training data iterator -> model -> checkpointed weights.
config = bookkeeping.load_config(config_path)
# NOTE(review): train_ds (and train_key) are not used by the visible cells; kept in case later cells need them.
train_ds = datasets_utils.numpy_iter(get_dataloader(config, "train"))

model = CategoricalDiffusionModel(config)
# A freshly initialised state is built first and then replaced by the one loaded from checkpoint_path
# (presumably init_state supplies the structure load_model fills in — confirm in bookkeeping).
state = model.init_state(model_key)
state = bookkeeping.load_model(checkpoint_path, state)
step = state.step

In [None]:
# Draw n_samples images from the restored model and save them as one grid plot.
n_samples = 16

# Re-running this cell advances sample_key, so each run draws a fresh batch of samples.
sample_key, sub_sample_key = jax.random.split(sample_key)
# Fold in the process index so every JAX process samples with a distinct key.
process_sample_rng_key = jax.random.fold_in(sub_sample_key, jax.process_index())

samples = model.sample_loop(state, process_sample_rng_key, n_samples, conditioner=None)
# NOTE(review): the trailing 1 assumes single-channel (MNIST) images — confirm before reusing for other datasets.
samples = jnp.reshape(samples, (n_samples, config.image_size, config.image_size, 1))

# The plot directory comes from the config saved at training time and may not exist where this
# notebook runs; create it so saving does not fail. An empty path means "current directory".
if config.sample_plot_path:
    os.makedirs(config.sample_plot_path, exist_ok=True)
saving_plot_path = os.path.join(config.sample_plot_path, f"samples_epoch_eval{step}.png")
datasets_utils.plot_mnist_batch(samples, saving_plot_path)