In [2]:
%load_ext autoreload
%autoreload 2

from typing import List, Dict

import os
import sys
import yaml
sys.path.append('.')
sys.path.append('..')

import time
from tqdm import tqdm
from pathlib import Path
from functools import partial

import numpy as np
import jax
import jax.numpy as jnp
import flax
import flax.linen as nn
import matplotlib as mpl
import matplotlib.pyplot as plt
import datasets
import eval
from ml_collections.config_dict import ConfigDict

import models.diffusion
from models.diffusion_utils import generate
from models.train_utils import create_input_iter
from models.flows import nsf, maf

%matplotlib inline
# Apply a custom matplotlib style sheet. The path is absolute and
# user-specific, so this line only resolves on the original machine.
plt.style.use('/mnt/home/tnguyen/default.mplstyle')

# Show which devices JAX can see (the recorded output lists a single CUDA device).
print(jax.devices())

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
[cuda(id=0)]


In [3]:
# --- Trained model checkpoints (sub-directories of `root`) ---
root = '/mnt/ceph/users/tnguyen/dark_camels/point-cloud-diffusion-logging'
vdm_name = 'vdm/dazzling-leaf-75'
flows_name = 'flows/earthy-dream-105'

# --- Dataset ---
dataset_root = '/mnt/ceph/users/tnguyen/dark_camels/point-cloud-diffusion-datasets/processed_datasets/'
dataset_name = 'mw_zooms-wdm-dmprop/nmax100-vmaxtilde-pad'
conditioning_parameters = [
    'halo_mvir',
    'inv_wdm_mass',
    'log_sn1',
    'log_sn2',
    'log_agn1',
]
num_particles = 100  # particle slots per sample (also the upper clip for sampled counts)
num_features = 8     # passed to datasets.get_nbody_data

# --- Sampling ---
rng = jax.random.PRNGKey(42)  # base key for all stochastic steps in this notebook
steps = 500                   # forwarded to eval.generate_samples as `steps`
batch_size = 256              # dataloader batch size
num_repeats = 100             # each dataset row is repeated this many times before sampling

In [4]:
# Read the dataset: the data array `x`, its `mask`, the `conditioning`
# values for `conditioning_parameters`, and the normalization statistics
# `norm_dict` (keys 'mean' and 'std' are used below).
x, mask, conditioning, norm_dict = datasets.get_nbody_data(
    dataset_root, dataset_name, num_features, num_particles,
    conditioning_parameters=conditioning_parameters
)
# Un-normalize x: scale by the stored std and shift by the stored mean.
# NOTE(review): presumably get_nbody_data returns x standardized with these
# statistics -- confirm against the datasets module.
x = x * norm_dict['std'] + norm_dict['mean']

In [5]:
# Checkpoint directories of the two trained models, both located under `root`.
path_to_vdm = Path(root) / vdm_name
path_to_flows = Path(root) / flows_name

# Restore the variational diffusion model (given the dataset normalization
# statistics) and the neural spline flow, each with its parameters.
vdm, vdm_params = models.diffusion.VariationalDiffusionModel.from_path_to_model(
    path_to_model=path_to_vdm,
    norm_dict=norm_dict,
)
flows, flows_params = nsf.NeuralSplineFlow.from_path_to_model(
    path_to_model=path_to_flows,
)



In [6]:
def sample_from_flow(context, n_samples=10_000, key=None):
    """Draw samples from the notebook-level flow for every context row.

    The function maps over the leading axis of ``context`` and ``key``; each
    row is sampled with its own PRNG key using the global ``flows`` module
    and ``flows_params``.

    Previously the function was decorated with
    ``jax.vmap(in_axes=(0, None, 0))``, which required all three arguments to
    be passed positionally and made the declared defaults unusable. The
    mapping is now applied internally, so positional calls behave as before
    while keyword arguments and defaults also work.

    Args:
        context: Batched conditioning values; axis 0 indexes the batch.
        n_samples: Number of samples drawn per context row (a plain Python
            int; it is not mapped over).
        key: Batch of PRNG keys with the same leading length as ``context``.
            If ``None``, one key per row is derived from seed 42.

    Returns:
        Whatever ``flows.sample`` returns, stacked along a new leading batch
        axis (presumably ``(batch, n_samples, dim)`` -- TODO confirm against
        the flow implementation).
    """
    if key is None:
        # One independent key per context row.
        key = jax.random.split(jax.random.PRNGKey(42), len(context))

    def _sample_one(single_context, single_key):
        def sample_fn(flow_module):
            # Tile the context so every requested sample sees the same row.
            return flow_module.sample(
                num_samples=n_samples, rng=single_key,
                context=single_context * jnp.ones((n_samples, 1)))

        # Run sample_fn against the flow module bound to the loaded parameters.
        return nn.apply(sample_fn, flows)(flows_params)

    return jax.vmap(_sample_one)(context, key)

@partial(jax.vmap, in_axes=(0, None))
def create_mask(n, num_particles):
    """Build a float mask whose first ``n`` slots are 1 and the rest 0.

    Mapped over the leading axis of ``n``; ``num_particles`` is shared by
    every row, so the result has shape ``(len(n), num_particles)``.
    """
    # Slot positions 0 .. num_particles-1; a slot is active when its position is below n.
    slot_ids = jnp.arange(num_particles)
    is_active = jnp.less(slot_ids, n)
    return is_active.astype(jnp.float32)

In [7]:
# Iterate over the dataset and generate samples with the flow + VDM pipeline.
truth_samples = []
truth_mask = []
gen_samples = []
gen_cond = []
gen_mask = []

dset = datasets.make_dataloader(
    x, conditioning, mask, batch_size=batch_size, shuffle=False, repeat=False)
dset = create_input_iter(dset)

# Thread a loop-local key so every batch (and each model within a batch) gets
# fresh randomness. The global `rng` is left untouched, so re-running this
# cell reproduces the same samples.
loop_rng = rng

for batch in tqdm(dset):
    x_batch, cond_batch, mask_batch = batch[0], batch[1], batch[2]
    # `[0]` takes the first entry along the leading axis of each array.
    # NOTE(review): assumes create_input_iter adds a leading (device) axis
    # and only one device is in use -- confirm.
    x_batch = jnp.repeat(x_batch[0], num_repeats, axis=0)
    cond_batch = jnp.repeat(cond_batch[0], num_repeats, axis=0)
    # Keep the dataset mask so the padded entries of the truth samples can be
    # identified later; it used to be overwritten by the generated mask.
    truth_mask_batch = jnp.repeat(mask_batch[0], num_repeats, axis=0)

    # Never reuse a key: derive one for the flow and one for the VDM.
    loop_rng, flow_rng, vdm_rng = jax.random.split(loop_rng, 3)

    # Generate the number of particles using the flows. The flow output is
    # exponentiated with base 10 (presumably it models log10 of the count).
    flow_keys = jax.random.split(flow_rng, len(cond_batch))
    # Explicit reshape instead of squeeze() so a batch of length 1 stays 1-D.
    num_subhalos = 10**sample_from_flow(
        cond_batch, 1, flow_keys).reshape(len(cond_batch))
    num_subhalos = jnp.clip(num_subhalos, 1, num_particles)
    num_subhalos = jnp.round(num_subhalos).astype(jnp.int32)
    mask_batch = create_mask(num_subhalos, num_particles)

    # Generate using the VDM.
    gen_samples.append(
        eval.generate_samples(
            vdm=vdm,
            params=vdm_params,
            rng=vdm_rng,
            n_samples=len(cond_batch),
            n_particles=num_particles,
            conditioning=cond_batch,
            mask=mask_batch,
            steps=steps,
            norm_dict=norm_dict,
            boxsize=1,  # doesn't matter
        )
    )
    gen_cond.append(cond_batch)
    gen_mask.append(mask_batch)
    truth_samples.append(x_batch)
    truth_mask.append(truth_mask_batch)

if not gen_samples:
    # jnp.concatenate([]) would raise a much less helpful ValueError.
    raise ValueError(
        'The dataloader yielded no batches, so no samples were generated; '
        'check the dataset and batch_size.')

gen_samples = jnp.concatenate(gen_samples, axis=0)
gen_cond = jnp.concatenate(gen_cond, axis=0)
gen_mask = jnp.concatenate(gen_mask, axis=0)
truth_samples = jnp.concatenate(truth_samples, axis=0)
truth_mask = jnp.concatenate(truth_mask, axis=0)

0it [00:00, ?it/s]

In [None]:
# Save the generated samples, their conditioning and masks, and the truth
# samples into a single .npz archive named after the two model runs.
out_root = '/mnt/home/tnguyen/ceph/dark_camels/point-cloud-diffusion-outputs'
out_path = os.path.join(
    out_root, f'vdm-flows/{os.path.basename(vdm_name)}_{os.path.basename(flows_name)}.npz')
# np.savez does not create missing parent directories, so make sure the
# `vdm-flows` output folder exists before writing.
os.makedirs(os.path.dirname(out_path), exist_ok=True)
np.savez(
    out_path, samples=gen_samples, cond=gen_cond,
    mask=gen_mask, truth=truth_samples)