In [1]:
%load_ext autoreload
%autoreload 2

from typing import List, Dict

import sys
import yaml
sys.path.append('.')
sys.path.append('..')

import time
from tqdm import tqdm
from pathlib import Path

import numpy as np
import jax
import jax.numpy as jnp
import matplotlib as mpl
import matplotlib.pyplot as plt
import datasets
import eval
from ml_collections.config_dict import ConfigDict

import models.diffusion
from models.diffusion_utils import generate
from models.train_utils import create_input_iter
from analysis_utils import envs

%matplotlib inline
# Apply the custom matplotlib style sheet when it is available. The path
# points into one user's home directory, so on any other machine the file is
# missing; in that case fall back to matplotlib's defaults instead of failing
# the whole setup cell over a purely cosmetic setting.
_mplstyle_path = Path('/mnt/home/tnguyen/default.mplstyle')
if _mplstyle_path.is_file():
    plt.style.use(str(_mplstyle_path))
else:
    print(f'matplotlib style sheet not found, using defaults: {_mplstyle_path}')

2024-02-25 20:56:58.012550: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-02-25 20:56:58.012585: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-02-25 20:56:58.013783: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [3]:
# --- Run configuration -------------------------------------------------------
# Model to sample from; resolved relative to envs.DEFAULT_LOGGING_DIR.
model_name = 'wdm/floral-wildflower-72'

# Dataset overrides. When left as None, the values stored in the model's
# config.yaml are used instead.
dataset_root = dataset_name = None

# Sampling settings.
batch_size = 128     # batch size handed to the dataloader
num_repeats = 1      # how many times each batch row is repeated before sampling
num_particles = 100  # forwarded as `n_particles` to eval.generate_samples
steps = 500          # forwarded as `steps` to eval.generate_samples

# Base PRNG key used for sample generation.
rng = jax.random.PRNGKey(42)

In [4]:
# Read the training configuration stored alongside the model.
path_to_model = envs.DEFAULT_LOGGING_DIR / model_name
with open(path_to_model / "config.yaml", "r") as file:
    config = ConfigDict(yaml.safe_load(file))

# Use the dataset recorded in the config unless one was set explicitly above.
dataset_root = config.data.dataset_root if dataset_root is None else dataset_root
dataset_name = config.data.dataset_name if dataset_name is None else dataset_name

# Load the data arrays together with the normalization statistics.
x, mask, conditioning, norm_dict = datasets.get_nbody_data(
    dataset_root, dataset_name,
    config.data.n_features, config.data.n_particles,
    conditioning_parameters=config.data.conditioning_parameters,
)

# Un-normalize x with the stored std and mean.
# NOTE(review): presumably get_nbody_data returns standardized x -- confirm.
x = x * norm_dict['std'] + norm_dict['mean']

# Restore the diffusion model and its parameters from the same directory.
vdm, params = models.diffusion.VariationalDiffusionModel.from_path_to_model(
    path_to_model=path_to_model,
    norm_dict=norm_dict,
)



In [5]:
# Accumulators for the generation loop: one list entry per processed batch.
truth_samples, gen_samples, gen_cond, gen_mask = [], [], [], []

# Single pass over the data (no shuffling, no repetition), wrapped by
# create_input_iter for consumption in the loop.
dset = create_input_iter(
    datasets.make_dataloader(
        x,
        conditioning,
        mask,
        batch_size=batch_size,
        shuffle=False,
        repeat=False,
    )
)

In [11]:


# Generate samples for every batch yielded by `dset`.
#
# The accumulators are (re)initialised here because the names are rebound to
# concatenated arrays at the end of this cell; without the reset, re-running
# the cell would fail when calling `.append` on an array.
truth_samples, gen_samples, gen_cond, gen_mask = [], [], [], []

for batch_idx, batch in enumerate(tqdm(dset)):
    x_batch, cond_batch, mask_batch = batch[0], batch[1], batch[2]

    # Take the first entry along the leading axis, then repeat every row
    # `num_repeats` times so several samples can be drawn per conditioning.
    # NOTE(review): the leading axis is presumably a device axis added by
    # create_input_iter -- confirm.
    x_batch = jnp.repeat(x_batch[0], num_repeats, axis=0)
    cond_batch = jnp.repeat(cond_batch[0], num_repeats, axis=0)
    mask_batch = jnp.repeat(mask_batch[0], num_repeats, axis=0)

    # Derive a distinct key for each batch. Passing the same `rng` on every
    # iteration would reuse one key for all batches, which correlates their
    # samples. fold_in is deterministic, so results stay reproducible and
    # `rng` itself is left untouched.
    batch_rng = jax.random.fold_in(rng, batch_idx)

    gen_samples.append(
        eval.generate_samples(
            vdm=vdm,
            params=params,
            rng=batch_rng,
            n_samples=len(cond_batch),
            n_particles=num_particles,
            conditioning=cond_batch,
            mask=mask_batch,
            steps=steps,
            norm_dict=norm_dict,
            boxsize=1,  # doesn't matter
        )
    )
    gen_cond.append(cond_batch)
    gen_mask.append(mask_batch)
    truth_samples.append(x_batch)

# Stack the per-batch results along the sample axis.
gen_samples = jnp.concatenate(gen_samples, axis=0)
gen_cond = jnp.concatenate(gen_cond, axis=0)
gen_mask = jnp.concatenate(gen_mask, axis=0)
truth_samples = jnp.concatenate(truth_samples, axis=0)

2024-02-25 16:55:57.480874: W tensorflow/core/kernels/data/cache_dataset_ops.cc:858] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
5it [27:22, 328.59s/it]


KeyboardInterrupt: 

In [12]:
# Sanity check on how much has been accumulated: gen_cond is a list with one
# entry per processed batch until the concatenation step has run, and a single
# array (so this is its number of rows) afterwards.
len(gen_cond)

In [None]:
# Save the generated samples together with their conditioning, masks and the
# matching truth samples into a single .npz archive.
out_path = Path(envs.DEFAULT_OUTPUT_DIR) / (model_name + '.npz')

# model_name may contain a sub-directory (e.g. 'wdm/...'); make sure it exists
# so the write does not fail with FileNotFoundError after a long sampling run.
out_path.parent.mkdir(parents=True, exist_ok=True)

np.savez(
    out_path, samples=gen_samples, cond=gen_cond, 
    mask=gen_mask, truth=truth_samples)