In [11]:
%load_ext autoreload
%autoreload 2

from typing import List, Dict

import sys
import yaml
sys.path.append('.')
sys.path.append('..')

import time
from tqdm import tqdm
from pathlib import Path

import numpy as np
import jax
import jax.numpy as jnp
import matplotlib as mpl
import matplotlib.pyplot as plt
import datasets
import eval
import models.diffusion
from models.diffusion_utils import generate
from models.train_utils import create_input_iter

from ml_collections.config_dict import ConfigDict

%matplotlib inline
plt.style.use('/mnt/home/tnguyen/default.mplstyle')

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [12]:
path_to_model = Path("/mnt/ceph/users/tnguyen/dark_camels/point-cloud-diffusion-logging/"\
    "cosmology/effortless-meadow-54/")

with open(path_to_model / "config.yaml", "r") as file:
    config = yaml.safe_load(file)
config = ConfigDict(config)

train_ds, norm_dict = datasets.load_data(
    config.data.dataset,
    config.data.dataset_root,
    config.data.dataset_name,
    config.data.n_features,
    config.data.n_particles,
    config.training.batch_size,
    config.seed,
    shuffle=True,
    repeat=False,
    split="train",
    conditioning_parameters=config.data.conditioning_parameters,
)

vdm, params = models.diffusion.VariationalDiffusionModel.from_path_to_model(
    path_to_model=path_to_model, norm_dict=norm_dict)



In [14]:
rng = jax.random.PRNGKey(42)
batch_size = config.training.batch_size
n_particles = config.data.n_particles
steps = 500
boxsize = 1

batches = create_input_iter(train_ds)

true_samples = []
generated_samples = []
conditioning_samples = []

for i, batch in enumerate(batches):
    t0 = time.time()
    x_batch, conditioning_batch, mask_batch = batch
    true_samples.append(x_batch[0] * norm_dict["std"] + norm_dict["mean"])
    generated_samples.append(
        eval.generate_samples(
            vdm=vdm,
            params=params,
            rng=rng,
            n_samples=len(conditioning_batch[0]),
            n_particles=n_particles,
            conditioning=conditioning_batch[0],
            mask=mask_batch[0],
            steps=steps,
            norm_dict=norm_dict,
            boxsize=boxsize,
        )
    )
    conditioning_samples.append(conditioning_batch[0])
    print(f"Iteration {i} takes {time.time() - t0} seconds")

# convert back to Numpy arrays because the person who wrote this code hates JAX
true_samples = np.array(true_samples)
generated_samples = np.array(generated_samples)
conditioning_samples = np.array(conditioning_samples)

Iteration 0 takes 4.399191856384277 seconds
Iteration 1 takes 4.00584077835083 seconds
Iteration 2 takes 2.6433539390563965 seconds


  true_samples = np.array(true_samples)
  generated_samples = np.array(generated_samples)
  conditioning_samples = np.array(conditioning_samples)


In [15]:
true_samples = np.vstack(true_samples)
generated_samples = np.vstack(generated_samples)
conditioning_samples = np.vstack(conditioning_samples)

# save the samples
np.save(path_to_model / "true_samples.npy", true_samples)
np.save(path_to_model / "generated_samples.npy", generated_samples)
np.save(path_to_model / "conditioning_samples.npy", conditioning_samples)

del vdm # free up memory