In [2]:
import sys
sys.path.append("../")

import jax
import jax.numpy as np

from flax.core import FrozenDict

from models.diffusion import VariationalDiffusionModel
from models.diffusion_utils import generate, loss_vdm

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [32]:
# Transformer args
transformer_dict = FrozenDict({"d_model":256, "d_mlp":512, "n_layers":16, "n_heads":4, "induced_attention":False, "n_inducing_points":300})

# Instantiate model
vdm = VariationalDiffusionModel(gamma_min=-6.0, gamma_max=6.0,  # Min and max initial log-SNR in the noise schedule
          d_feature=4,  # Number of features per set element
          transformer_dict=transformer_dict,  # Score-prediction transformer parameters
          noise_schedule="learned_linear",  # Noise schedule; "learned_linear" or "scalar"
          n_layers=3,  # Layers in encoder/decoder element-wise ResNets
          d_embedding=8,  # Dim to encode the per-element features to
          d_hidden_encoding=64,  # Hidden dim used in various contexts (for embedding context, 4 * for encoding/decoding in ResNets)
          timesteps=300,  # Number of diffusion steps
          d_t_embedding=16,  # Timestep embedding dimension
          noise_scale=1e-3,  # Data noise model
          n_classes=0)  # Number of data classes. If >0, the first element of the conditioning vector is assumed to be the integer class.

In [33]:
n_points = 5000

rng = jax.random.PRNGKey(42)

x = jax.random.normal(rng, (4, n_points, 4))
mask = jax.random.randint(rng, (4, n_points), 0, 2)
conditioning = jax.random.normal(rng, (4, 6))

# Call to get losses
(loss_diff, loss_klz, loss_recon), params = vdm.init_with_output({"sample": rng, "params": rng, "uncond":rng}, x, conditioning, mask);

In [34]:
# Compute full loss, accounting for masking
loss_vdm(params, vdm, rng, x, conditioning, mask)

Array(5328009., dtype=float32)

In [35]:
# Sample from model

mask_sample = jax.random.randint(rng, (4, n_points), 0, 2)
conditionink_sample = jax.random.normal(rng, (4, 6))

x_samples = generate(vdm, params, rng, (4, n_points), conditionink_sample, mask_sample)
x_samples.mean().shape  # Mean of decoded Normal distribution

(4, 5000, 4)

In [31]:
%%timeit
jax.jit(loss_vdm, static_argnums=1)(params, vdm, rng, x, conditioning, mask)

23 ms ± 4.04 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [37]:
%%timeit
jax.jit(loss_vdm, static_argnums=1)(params, vdm, rng, x, conditioning, mask)

148 ms ± 87.4 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
