In [1]:
%load_ext autoreload
%autoreload 2
import numpyro
import numpyro.distributions as dist
import jax.numpy as np
import numpy as onp
import jax
from jax.flatten_util import ravel_pytree
import argparse
import boundingmachine as bm
import mcdboundingmachine as mcdbm
import opt
from model_handler import load_model
import pickle
import ml_collections.config_flags
import wandb
from absl import app, flags
from utils import flatten_nested_dict, update_config_dict, setup_training, make_grid, W2_distance
from jax import scipy as jscipy
from configs.base import LR_DICT, get_config

import os
import ot
# Set XLA_PYTHON_CLIENT_PREALLOCATE flag
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

# tf.config.experimental.set_visible_devices([], "GPU")
# python main.py --config.model funnel --config.boundmode MCD_ULA

# Boundmodes
# 	- ULA uses MCD_ULA
# 	- MCD uses MCD_ULA_sn
#	- UHA uses UHA
# 	- LDVI uses MCD_U_a-lp-sn
#   - CAIS uses MCD_CAIS_sn
#   - CAIS_UHA uses MCD_CAIS_UHA_sn

config = get_config()

config.model = "funnel"
config.boundmode = "MCD_ULA"

2023-09-22 18:23:45.751385: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2023-09-22 18:23:45.751454: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory
2023-09-22 18:23:48.653728: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:433] Loaded runtime CuDNN library: 8.5.0 but source was compiled with: 8.6.0.  CuDNN library needs to have matching major version and equal or higher minor version. If using a binary install, upgrade your CuDNN library.  If building from sources, make sure the library loaded at runtime is compatible with the version specified during compile configuration.


In [None]:
print(config)

config.pretrain_mfvi = False
config.iters = 1000


wandb.init()
iters_base=config.iters
log_prob_model, dim, sample_from_target_fn = load_model(config.model, config)
rng_key_gen = jax.random.PRNGKey(config.seed)

train_rng_key_gen, eval_rng_key_gen = jax.random.split(rng_key_gen)

# Train initial variational distribution to maximize the ELBO
trainable=('vd',)
params_flat, unflatten, params_fixed = bm.initialize(dim=dim, nbridges=0, trainable=trainable)


grad_and_loss = jax.jit(jax.grad(bm.compute_bound, 1, has_aux = True), static_argnums = (2, 3, 4))
if not config.pretrain_mfvi:
    mfvi_iters = 1
    vdparams_init = unflatten(params_flat)[0]['vd']
else:
    mfvi_iters = config.mfvi_iters
    losses, diverged, params_flat, tracker = opt.run(
        config, config.mfvi_lr, mfvi_iters, params_flat, unflatten, params_fixed,
        log_prob_model, grad_and_loss, trainable, train_rng_key_gen, log_prefix='pretrain')
    vdparams_init = unflatten(params_flat)[0]['vd']

    elbo_init = -np.mean(np.array(losses[-500:]))
    print('Done training initial parameters, got ELBO %.2f.' % elbo_init)
    wandb.log({'elbo_init': onp.array(elbo_init)})

if config.boundmode == 'UHA':
    trainable = ('eta', 'mgridref_y')
    if config.train_eps:
        trainable = trainable + ('eps',)
    if config.train_vi:
        trainable = trainable + ('vd',)
    params_flat, unflatten, params_fixed = bm.initialize(dim=dim, nbridges=config.nbridges, eta=config.init_eta, eps = config.init_eps,
        lfsteps=config.lfsteps, vdparams=vdparams_init, trainable=trainable)
    grad_and_loss = jax.jit(jax.grad(bm.compute_bound, 1, has_aux = True), static_argnums = (2, 3, 4))

    loss_fn = jax.jit(bm.compute_bound, static_argnums = (2, 3, 4))

elif 'MCD' in config.boundmode:
    trainable = ('eta', 'gamma', 'mgridref_y')
    if config.train_eps:
        trainable = trainable + ('eps',)
    if config.train_vi:
        trainable = trainable + ('vd',)
    
    print(trainable)
    params_flat, unflatten, params_fixed = mcdbm.initialize(dim=dim, nbridges=config.nbridges, vdparams=vdparams_init, eta=config.init_eta, eps = config.init_eps,
        trainable=trainable, mode=config.boundmode)
    grad_and_loss = jax.jit(jax.grad(mcdbm.compute_bound, 1, has_aux = True), static_argnums = (2, 3, 4))

    loss_fn = jax.jit(mcdbm.compute_bound, static_argnums = (2, 3, 4))

else:
    raise NotImplementedError('Mode %s not implemented.' % config.boundmode)

# Average over 30 seeds, 500 samples each after training is done.
n_samples = config.n_samples
n_input_dist_seeds = config.n_input_dist_seeds

target_samples = sample_from_target_fn(jax.random.PRNGKey(1), n_samples * n_input_dist_seeds)

losses, diverged, params_flat, tracker = opt.run(config, config.lr, config.iters, params_flat, unflatten, params_fixed, log_prob_model, grad_and_loss,
    trainable, train_rng_key_gen, log_prefix='train', target_samples=target_samples)


eval_losses, samples = opt.sample(
    config, n_samples, n_input_dist_seeds, params_flat, unflatten, params_fixed, log_prob_model, loss_fn,
    eval_rng_key_gen, log_prefix='eval')

# (n_input_dist_seeds, n_samples)
eval_losses = np.array(eval_losses)

# Calculate mean and std of ELBOs over 30 seeds
final_elbos = -np.mean(eval_losses, axis=1)
final_elbo = np.mean(final_elbos)
final_elbo_std = np.std(final_elbos)

# Calculate mean and std of log Zs over 30 seeds
ln_numsamp = np.log(n_samples)

final_ln_Zs = jscipy.special.logsumexp(-np.array(eval_losses), axis=1)  - ln_numsamp

final_ln_Z = np.mean(final_ln_Zs)
final_ln_Z_std = np.std(final_ln_Zs)

print('Done training, got ELBO %.2f.' % final_elbo)
print('Done training, got ln Z %.2f.' % final_ln_Z)

wandb.log({
    'elbo_final': onp.array(final_elbo),
    'final_ln_Z': onp.array(final_ln_Z),
    'elbo_final_std': onp.array(final_elbo_std),
    'final_ln_Z_std': onp.array(final_ln_Z_std)
    })

# Plot samples
if config.model in ["nice", "funnel"]:
    other_target_samples = sample_from_target_fn(jax.random.PRNGKey(2), samples.shape[0])

    w2_dists, self_w2_dists = [], []
    for i in range(n_input_dist_seeds):
        
        samples_i = samples[i * n_samples : (i + 1) * n_samples, ...]
        target_samples_i = target_samples[i * n_samples : (i + 1) * n_samples, ...]
        other_target_samples_i = other_target_samples[i * n_samples : (i + 1) * n_samples, ...]
        w2_dists.append(W2_distance(samples_i, target_samples_i))
        self_w2_dists.append(W2_distance(target_samples_i, other_target_samples_i))

    if config.model == "nice":
        make_grid(samples, config.im_size, n=64, wandb_prefix="images/sample")
    
    wandb.log({"w2_dist": onp.mean(onp.array(w2_dists)),
                "w2_dist_std": onp.std(onp.array(w2_dists)),
                "self_w2_dist": onp.mean(onp.array(self_w2_dists)),	
                "self_w2_dist_std": onp.std(onp.array(self_w2_dists))})

params_train, params_notrain = unflatten(params_flat)
params = {**params_train, **params_notrain}

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


N: 5
alpha: 0.05
boundmode: MCD_ULA
funnel_clipy: 11
funnel_d: 10
funnel_sig: 3
hidden_dim: 1000
id: -1
im_size: 14
init_eps: 1.0e-05
init_eta: 0.0
iters: 150000
lfsteps: 1
lr: 0.0001
mfvi_iters: 15000
mfvi_lr: 0.0001
model: funnel
n_bits: 3
n_input_dist_seeds: 30
n_samples: 500
nbridges: 8
pretrain_mfvi: true
run_cluster: 0
seed: 1
train_eps: true
train_vi: true
wandb:
  code_dir: /home/sp2058/CAIS/src
  entity: shreyaspadhy
  log: true
  log_artifact: true
  name: ''
  project: cais



[34m[1mwandb[0m: Currently logged in as: [33mshreyaspadhy[0m. Use [1m`wandb login --relogin`[0m to force relogin


('eta', 'gamma', 'mgridref_y', 'eps', 'vd')
No score network needed by the method.


100%|██████████| 1000/1000 [00:17<00:00, 58.42it/s]


Done training, got ELBO -2.22.
Done training, got ln Z -0.84.


In [20]:
other_target_samples = sample_from_target_fn(jax.random.PRNGKey(2), samples.shape[0])

import matplotlib.pyplot as plt
from utils import sinkhorn_divergence
from sinkhorn import sinkhorn

plt.plot(other_target_samples[:, 0], other_target_samples[:, 1], 'o')

plt.plot(samples[:, 0],  samples[:, 1], 'o')

plt.plot(target_samples[:, 0], target_samples[:, 1], 'o')
plt.plot()

w2_dists, self_w2_dists = [], []

# fig, ax = plt.subplots(3, 3, figsize=(9, ,9))
for i in range(9):
    
    samples_i = samples[i * n_samples : (i + 1) * n_samples, ...]
    target_samples_i = target_samples[i * n_samples : (i + 1) * n_samples, ...]
    other_target_samples_i = other_target_samples[i * n_samples : (i + 1) * n_samples, ...]
    print(samples_i.shape, target_samples_i.shape, other_target_samples_i.shape)
    w2_dists.append(sinkhorn(samples_i, target_samples_i))
    self_w2_dists.append(sinkhorn_divergence(target_samples_i, other_target_samples_i))
    # row, col = 
    plt.plot(other_target_samples_i[:, 0], other_target_samples_i[:, 1], 'o', label='other targets', alpha=0.5)
    plt.plot(target_samples_i[:, 0], target_samples_i[:, 1], 'o', label='targets', alpha=0.5)
    plt.plot(samples_i[:, 0],  samples_i[:, 1], 'x', label='samples', alpha=0.5)
    plt.title(f'w2 dist : {w2_dists[-1]}, seld w2 dist : {self_w2_dists[-1]}')
    plt.legend()
    plt.show()

w2_dists, self_w2_dists = np.array(w2_dists), np.array(self_w2_dists)

print(w2_dists, self_w2_dists)
print(f'self w2 dist: {np.mean(self_w2_dists)},std: {np.std(self_w2_dists)}')
print(f'w2 dist: {np.mean(w2_dists)},std: {np.std(w2_dists)}')
# w2_dists, self_w2_dists = [], []
# for i in range(n_input_dist_seeds):
    
#     samples_i = samples[i * n_samples : (i + 1) * n_samples, ...]
#     target_samples_i = target_samples[i * n_samples : (i + 1) * n_samples, ...]
#     other_target_samples_i = other_target_samples[i * n_samples : (i + 1) * n_samples, ...]
#     w2_dists.append(W2_distance(samples_i, target_samples_i))
#     self_w2_dists.append(W2_distance(target_samples_i, other_target_samples_i))

ModuleNotFoundError: No module named 'torch'