In [1]:
# ! pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
# ! pip install tensorflow-probability
# ! pip install --upgrade tensorflow

In [1]:
# ! pip install --upgrade pandas yaml tensorflow tensorflow-probability ml-collections jraph

In [2]:
import sys, os
sys.path.append("../")

import pandas as pd
import yaml
import jax
import jax.numpy as np
import optax
import flax
from flax.core import FrozenDict
from flax.training import train_state, checkpoints
from ml_collections.config_dict import ConfigDict
import numpy as vnp
import matplotlib.pyplot as plt
import tensorflow as tf

# Hide every GPU from TensorFlow so that it cannot reserve all of the GPU memory
tf.config.set_visible_devices([], "GPU")

from tqdm import tqdm, trange

# Notebook-level shortcuts to the two flax.jax_utils helpers
from flax.jax_utils import replicate, unreplicate

from models.diffusion import VariationalDiffusionModel
from models.diffusion_utils import loss_vdm, sigma2, generate
from models.train_utils import create_input_iter, param_count, train_step
from datasets import load_data

# Small positive constant; not referenced by any of the cells shown here
EPS = 1e-7

# Reload edited project modules automatically so changes are picked up live
%load_ext autoreload
%autoreload 2



## Dirs

In [3]:
# Cluster-specific absolute paths; adjust them when running on another system.
# NOTE(review): data_dir is not passed to anything in the cells shown here.
data_dir = "/n/holyscratch01/iaifi_lab/ccuesta/data_for_sid/"

## Carol's runs (alternative logging_dir and run names, kept for reference)
# logging_dir = "/n/holyscratch01/iaifi_lab/ccuesta/checkpoints/"

# run_name = "blooming-puddle-230"  # Has weird likelihoods
# run_name = "chocolate-cloud-122"  # The one in the paper
# run_name = "silver-breeze-231"  # Only rotations
# run_name = "hopeful-bush-231"  # Only translations

## My new runs
# Leave exactly one logging_dir and one run_name uncommented: together they
# determine which run's config.yaml is loaded below.
logging_dir = "/n/holystore01/LABS/iaifi_lab/Users/smsharma/set-diffuser/logging/cosmology-augmentations-guidance/"
# logging_dir = "/n/holystore01/LABS/iaifi_lab/Users/smsharma/set-diffuser/logging/cosmology-augmentations/"
# run_name = "efficient-firefly-111"  # No translations or rotations
# run_name = "worldly-voice-112"  # Only translations
# run_name = "stilted-oath-118" # Both; longer run
# run_name = "blooming-waterfall-120"  # Different config; larger batch size etc
# run_name = "solar-pond-123"  # k=100, batch size 16
# run_name = "treasured-resonance-125"
# run_name = "peach-gorge-130"
# run_name = "glowing-rain-139"  # Run with unconditional dropout
run_name = "toasty-dream-140"  # Run without unconditional dropout

## Load cluster run

In [4]:
# Read the YAML configuration stored in the selected run's directory and wrap
# it in a ConfigDict so that its entries can be read as attributes.
config_file = f"{logging_dir}/{run_name}/config.yaml"

with open(config_file, "r") as file:
    config = ConfigDict(yaml.safe_load(file))

In [5]:
# Load the training split of the "nbody" dataset together with its
# normalisation statistics (norm_dict is read for "mean" and "std" further down).
# NOTE(review): the positional arguments are hard-coded and unnamed; presumably
# they are n_features=3, n_particles=5000, batch_size=8 and seed=42 — confirm
# against datasets.load_data (and against the values in the loaded config).
train_ds, norm_dict = load_data(
        "nbody",
        3,
        5000,
        8,
        42,
        shuffle=True,
        split="train",
    )

# Iterator over batches; each next(batches) yields an (x, conditioning, mask) triple
batches = create_input_iter(train_ds)

In [8]:
# Draw one batch and keep only index 0 along the leading axis of each component
# (presumably the per-device axis added by create_input_iter — TODO confirm).
x, conditioning, mask = (component[0] for component in next(batches))

In [9]:
# Normalisation statistics and box settings that are handed to the model.
# apply_pbcs (presumably "periodic boundary conditions") is forced off here,
# overriding whatever the run's config file specified.
config.data.apply_pbcs = False

if config.data.apply_pbcs:
    box_size = config.data.box_size
    unit_cell = tuple(tuple(row) for row in config.data.unit_cell)
else:
    box_size = None
    unit_cell = None

# Convert the per-feature statistics to tuples of plain Python floats
x_mean = tuple(float(value) for value in norm_dict["mean"])
x_std = tuple(float(value) for value in norm_dict["std"])

norm_dict_input = FrozenDict(
    {"x_mean": x_mean, "x_std": x_std, "box_size": box_size, "unit_cell": unit_cell}
)

In [21]:
print(f"{jax.device_count()} devices visible")

# Immutable copies of the score, encoder and decoder sections of the run config
score_dict, encoder_dict, decoder_dict = (
    FrozenDict(section) for section in (config.score, config.encoder, config.decoder)
)

# Diffusion model
# Every hyperparameter is read from the loaded run config (config.data,
# config.vdm, config.score); norm_dict_input carries the normalisation
# statistics and box settings assembled earlier in the notebook.
vdm = VariationalDiffusionModel(
        d_feature=config.data.n_features,
        timesteps=config.vdm.timesteps,
        noise_schedule=config.vdm.noise_schedule,
        noise_scale=config.vdm.noise_scale,
        d_t_embedding=config.vdm.d_t_embedding,
        gamma_min=config.vdm.gamma_min,
        gamma_max=config.vdm.gamma_max,
        score=config.score.score,
        score_dict=score_dict,
        embed_context=config.vdm.embed_context,
        d_context_embedding=config.vdm.d_context_embedding,
        n_classes=config.vdm.n_classes,
        use_encdec=config.vdm.use_encdec,
        encoder_dict=encoder_dict,
        decoder_dict=decoder_dict,
        norm_dict=norm_dict_input,
)

# Pass a test batch through the model to initialize its parameters.
x_batch, conditioning_batch, mask_batch = next(batches)

# JAX PRNG keys are meant to be consumed once. Derive three independent keys
# from the seed: one for parameter initialisation, one for the "sample" stream
# of this init pass, and `rng`, which later cells keep using.
rng, params_rng, sample_rng = jax.random.split(jax.random.PRNGKey(42), 3)
_, params = vdm.init_with_output(
    {"sample": sample_rng, "params": params_rng},
    x_batch[0],
    conditioning_batch[0],
    mask_batch[0],
)

print(f"Params: {param_count(params):,}")

# Training config and state: AdamW whose learning rate follows a warm-up
# followed by cosine decay, with all hyperparameters taken from the run config.
schedule = optax.warmup_cosine_decay_schedule(
    init_value=0.0,
    peak_value=config.optim.learning_rate,
    warmup_steps=config.training.warmup_steps,
    decay_steps=config.training.n_train_steps,
)
tx = optax.adamw(learning_rate=schedule, weight_decay=config.optim.weight_decay)
state = train_state.TrainState.create(apply_fn=vdm.apply, params=params, tx=tx)

1 devices visible
Params: 535,101


In [22]:
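# NOTE: checkpoint restoring is currently disabled, so `restored_state` is never
# defined; re-enable the lines below before running anything that relies on it.
# ckpt_dir = "{}/{}/".format(logging_dir, run_name)  # Load SLURM run
# restored_state = checkpoints.restore_checkpoint(ckpt_dir=ckpt_dir, target=state)

# if state is restored_state:
#     raise FileNotFoundError(f"Did not load checkpoint correctly")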

In [23]:
@jax.jit
def get_loss(params, x, cond, mask):
    """Sum each of the model's three loss outputs and add them into one value.

    Applies ``vdm`` to ``(x, cond, mask)`` using ``params``. The notebook-level
    ``rng`` key is supplied for both the "sample" and "params" RNG streams.
    """
    rng_streams = {"sample": rng, "params": rng}
    first, second, third = vdm.apply(params, x, cond, mask, rngs=rng_streams)
    partial_total = first.sum() + second.sum()
    return partial_total + third.sum()

In [24]:
# get_loss(x_batch[0][:8], conditioning_batch[0][:8], mask_batch[0][:8])

# with jax.profiler.trace("./profile/jax-trace", create_perfetto_link=False, create_perfetto_trace=True):
#   get_loss(x_batch[0][:8], conditioning_batch[0][:8], mask_batch[0][:8])

In [25]:
n_nodes = 5000
n_batch = 4

In [26]:
%%timeit
jax.value_and_grad(get_loss)(params, x_batch[0][:n_batch, :n_nodes], conditioning_batch[0][:n_batch, :n_nodes], mask_batch[0][:n_batch, :n_nodes])

2023-08-23 18:06:38.476184: E external/xla/xla/service/slow_operation_alarm.cc:65] Constant folding an instruction is taking > 1s:

  %fusion.367 = (pred[4,250000]{1,0}, s32[4,250000,2]{2,1,0}) fusion(s32[2]{0} %constant.1365, s32[4,250000]{1,0} %constant.2606, s32[4,125000]{1,0} %constant.2631, s32[4,125000]{1,0} %constant.2632), kind=kInput, calls=%fused_computation.367, metadata={op_name="jit(get_loss)/jit(main)/VariationalDiffusionModel/VariationalDiffusionModel.diffusion_loss/score_model/vmap(GraphConvNet_0)/jit(_take)/reduce_and[axes=(2,)]" source_file="/n/home11/smsharma/.local/lib/python3.10/site-packages/jraph/_src/models.py" source_line=178}

This isn't necessarily a bug; constant-folding is inherently a trade-off between compilation time and speed at runtime. XLA has some guards that attempt to keep constant folding from taking too long, but fundamentally you'll always be able to come up with an input program that takes a long time.

If you'd like to file a bug, run with env

169 ms ± 29.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [27]:
# Steps per second for a 74 ms step (1000 ms / 74 ms).
# NOTE(review): 74 presumably comes from an earlier timing run; the %%timeit
# output recorded in this notebook reports 169 ms per loop.
1000 / 74

13.513513513513514

In [28]:
from models.graph_utils import nearest_neighbors

In [30]:
# %%timeit
# Map nearest_neighbors over the leading axis of the positions and of the mask,
# while the second argument (50, presumably the number of neighbours — TODO
# confirm) is shared by every batch element.
batched_nearest_neighbors = jax.vmap(nearest_neighbors, in_axes=(0, None, 0))
sources, targets, dist = batched_nearest_neighbors(
    x_batch[0][:n_batch, :n_nodes],
    50,
    mask_batch[0][:n_batch, :n_nodes],
)

In [31]:
# Display the three arrays returned by the batched nearest_neighbors call
sources, targets, dist

(Array([[   0,    0,    0, ..., 4999, 4999, 4999],
        [   0,    0,    0, ..., 4999, 4999, 4999],
        [   0,    0,    0, ..., 4999, 4999, 4999],
        [   0,    0,    0, ..., 4999, 4999, 4999]], dtype=int32),
 Array([[   0,  193,  194, ..., 4102,   67, 4855],
        [   0,  530, 3426, ...,  686, 1499, 2872],
        [   0, 4500, 3226, ..., 3098, 3821, 4536],
        [   0, 4079,  685, ..., 4322,  384, 3414]], dtype=int32),
 Array([[0.        , 0.00088743, 0.01018956, ..., 0.17113404, 0.17649578,
         0.1828176 ],
        [0.        , 0.03095622, 0.03634575, ..., 0.18830384, 0.19302878,
         0.20054212],
        [0.        , 0.03018067, 0.03622621, ..., 0.17344593, 0.17392701,
         0.17541279],
        [0.        , 0.00714167, 0.01390084, ..., 0.1959911 , 0.1978073 ,
         0.19785741]], dtype=float32))

In [None]:
# Same display as the previous cell (this copy has not been executed)
sources, targets, dist

In [18]:
# get_loss takes the parameters as its first argument; they were missing from
# the original calls, which therefore fail against the current signature.
profile_args = (params, x_batch[0][:8], conditioning_batch[0][:8], mask_batch[0][:8])

# Run once outside the trace so that any jit compilation for these input
# shapes happens before profiling starts.
get_loss(*profile_args).block_until_ready()

# The context manager stops the trace even if the call raises, and blocking on
# the result keeps the asynchronously dispatched computation inside the trace.
with jax.profiler.trace("./profile/tensorboard"):
    get_loss(*profile_args).block_until_ready()



In [None]:
# Forward pass on the first batch drawn earlier (x, conditioning, mask).
# The original referenced `cond` and `restored_state`, neither of which is
# defined at notebook level: the batch variable is `conditioning`, and the
# checkpoint-restoring code is commented out. `state.params` (the freshly
# initialised parameters) is used instead; switch to `restored_state.params`
# once a checkpoint is actually loaded.
vdm.apply(state.params, x, conditioning, mask, rngs={"sample": rng, "params": rng})