In [45]:
import jax
import jax.numpy as jnp
import jax.lax as lax
import jax.random as jrandom
import numpy as np
import optax
import distrax
import haiku as hk

import tensorflow as tf
import tensorflow_datasets as tfds

from lfiax.flows.nsf import make_nsf

from typing import (
    Any,
    Iterator,
    Mapping,
    Optional,
    Tuple,
)

Array = jnp.ndarray
PRNGKey = Array
Batch = Mapping[str, np.ndarray]
OptState = Any


def sim_linear_jax(d: Array, priors: Array, key: PRNGKey):
    # Keys for the appropriate functions
    keys = jrandom.split(key, 3)

    # sample random normal dist
    noise_shape = (1,)

    mu_noise = jnp.zeros(noise_shape)
    sigma_noise = jnp.ones(noise_shape)

    n_n = distrax.Independent(
        distrax.MultivariateNormalDiag(mu_noise, sigma_noise)
    ).sample(seed=keys[0], sample_shape=[len(d), len(priors)])

    # sample random gamma noise
    n_g = distrax.Gamma(2.0, 1.0 / 2.0).sample(
        seed=keys[1], sample_shape=[len(d), len(priors)]
    )

    # perform forward pass
    y = jnp.broadcast_to(priors[:, 0], (len(d), len(priors)))
    y = y + jnp.expand_dims(d, 1) @ jnp.expand_dims(priors[:, 1], 0)
    y = y + n_g + jnp.squeeze(n_n)
    ygrads = priors[:, 1]

    return y, ygrads


# ----------------------------------------
# Helper functions to simulate data
# ----------------------------------------
def load_dataset(split: tfds.Split, batch_size: int) -> Iterator[Batch]:
    ds = split
    ds = ds.shuffle(buffer_size=10 * batch_size)
    ds = ds.batch(batch_size)
    ds = ds.prefetch(buffer_size=1000)
    ds = ds.repeat()
    return iter(tfds.as_numpy(ds))


def sim_data(d: Array, priors: Array, key: PRNGKey):
    """
    Returns data in a format suitable for normalizing flow training.
    Data will be in shape [y, thetas]. The `y` variable can vary in size.
    """
    keys = jrandom.split(key, 2)

    theta_shape = (2,)

    mu = jnp.zeros(theta_shape)
    sigma = (3**2) * jnp.ones(theta_shape)

    base_distribution = distrax.Independent(  # Should this be independent?
        distrax.MultivariateNormalDiag(mu, sigma)
    )

    priors = base_distribution.sample(seed=keys[0], sample_shape=[num_samples])

    # ygrads allows to be compared to other implementations (Kleinegesse et)
    y, ygrads = sim_linear_jax(d, priors, keys[1])

    return jnp.column_stack((y.T, jnp.squeeze(priors), jnp.broadcast_to(d, (num_samples, len(d)))))


def prepare_data(batch: Batch, prng_key: Optional[PRNGKey] = None) -> Array:
    # Batch is [y, thetas, d]
    data = batch.astype(np.float32)
    # Handling the scalar case
    if data.shape[1] <= 3:
        x = jnp.expand_dims(data[:, :-2], -1)
    # Use known length of x to split up the cond_data
    data_shape = data.shape
    start = [0, 0]
    stop = [data_shape[0], len_x]
    x = lax.dynamic_slice(data, start, stop)
    cond_data = data[:, len_x:]
    # breakpoint()
    return x, cond_data


# ----------------------------
# Haiku transform functions for training and evaluation
# ----------------------------
@hk.without_apply_rng
@hk.transform
def log_prob(data: Array, cond_data: Array) -> Array:
    # Get batch
    shift = data.mean(axis=0)
    scale = data.std(axis=0) + 1e-14
    
    model = make_nsf(
        event_shape=EVENT_SHAPE,
        cond_info_shape=cond_info_shape,
        num_layers=flow_num_layers,
        hidden_sizes=[hidden_size] * mlp_num_layers,
        num_bins=num_bins,
        standardize_x=True,
        standardize_z=True,
        use_resnet=True,
        event_dim=EVENT_DIM,
        shift=shift,
        scale=scale,
    )
    return model.log_prob(data, cond_data)


@hk.without_apply_rng
@hk.transform
def model_sample(key: PRNGKey, num_samples: int, cond_data: Array) -> Array:
    model = make_nsf(
        event_shape=EVENT_SHAPE,
        cond_info_shape=cond_info_shape,
        num_layers=flow_num_layers,
        hidden_sizes=[hidden_size] * mlp_num_layers,
        num_bins=num_bins,
    )
    z = jnp.repeat(cond_data, num_samples, axis=0)
    z = jnp.expand_dims(z, -1)
    return model._sample_n(key=key, n=[num_samples], z=z)


def loss_fn(params: hk.Params, prng_key: PRNGKey, batch: Batch) -> Array:
    x, cond_data = prepare_data(batch, prng_key)
    # Loss is average negative log likelihood.
    loss = -jnp.mean(log_prob.apply(params, x, cond_data))
    return loss


@jax.jit
def eval_fn(params: hk.Params, batch: Batch) -> Array:
    x, cond_data = prepare_data(batch)
    loss = -jnp.mean(log_prob.apply(params, x, cond_data))
    return loss


@jax.jit
def update(
    params: hk.Params, prng_key: PRNGKey, opt_state: OptState, batch: Batch
) -> Tuple[hk.Params, OptState]:
    """Single SGD update step."""
    grads = jax.grad(loss_fn)(params, prng_key, batch)
    updates, new_opt_state = optimizer.update(grads, opt_state)
    new_params = optax.apply_updates(params, updates)
    return new_params, new_opt_state

In [46]:
# TODO: Put this in hydra config file
seed = 1231
key = jrandom.PRNGKey(seed)

d = jnp.array([-10.0, 0.0, 5.0, 10.0])
# d = jnp.array([1., 2.])
# d = jnp.array([1.])
len_x = len(d)
num_samples = 100

# Params and hyperparams
theta_shape = (2,)
EVENT_SHAPE = (len(d),)
# EVENT_DIM is important for the normalizing flow's block.
EVENT_DIM = 1
cond_info_shape = (theta_shape[0] + len(d),)

batch_size = 128
flow_num_layers = 10
mlp_num_layers = 4
hidden_size = 500
num_bins = 4
learning_rate = 1e-4

training_steps = 10  # 00
eval_frequency = 100

optimizer = optax.adam(learning_rate)

# Simulating the data to be used to train the flow.
num_samples = 10000
X = sim_data(d, num_samples, key)

# Create tf dataset from sklearn dataset
dataset = tf.data.Dataset.from_tensor_slices(X)

# Splitting into train/validate ds
train = dataset.skip(2000)
val = dataset.take(2000)

# load_dataset(split: tfds.Split, batch_size: int)
train_ds = load_dataset(train, 512)
valid_ds = load_dataset(val, 512)

# Training
prng_seq = hk.PRNGSequence(42)
params = log_prob.init(
    next(prng_seq),
    np.zeros((1, *EVENT_SHAPE)),
    np.zeros((1, *cond_info_shape)),
)
opt_state = optimizer.init(params)

for step in range(training_steps):
    params, opt_state = update(params, next(prng_seq), opt_state, next(train_ds))

    if step % eval_frequency == 0:
        val_loss = eval_fn(params, next(valid_ds))
        print(f"STEP: {step:5d}; Validation loss: {val_loss:.3f}")


TypeError: Shapes must be 1D sequences of concrete values of integer type, got Traced<ShapedArray(int32[2])>with<DynamicJaxprTrace(level=0/1)>.
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.

In [90]:
import jax
import jax.numpy as jnp
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split


# create our dataset
X, y = make_regression(n_features=3)
X, X_test, y, y_test = train_test_split(X, y)


# model weights
params = {
    'w': jnp.zeros(X.shape[1:]),
    'b': 0.
}


def forward(params, X):
    return jnp.dot(X, params['w']) + params['b']


def loss_fn(params, X, y):
    err = forward(params, X) - y
    return jnp.mean(jnp.square(err))  # mse


def jvp_fn(params, xs, y):
    # return jax.jacfwd(lambda params: loss_fn(params, xs, y))(params)
    return jax.vjp(loss_fn)(params, xs, y)


grad_fn = jax.grad(loss_fn)
grad_xs = jax.grad(loss_fn, argnums=1)

# loss, grad_fn = jax.value_and_grad(loss_fn, argnums=(1, 2))

# def loss_and_grads_fn(params, xs, y):
#     loss, grads = jax.value_and_grad(loss_fn)(params, xs, y)
#     vjp_fn = jax.vjp(loss_fn, (params, xs, y))
#     inputs_grads = vjp_fn(params)
#     return loss, grads, inputs_grads

# def loss_and_grads_fn(params, xs, y):
#     loss, grads = jax.value_and_grad(loss_fn)(params, xs, y)
#     vjp_fn = jax.vjp(lambda params: loss_fn(params, xs, y))
#     inputs_grads = vjp_fn(params)
#     return loss, grads, inputs_grads

# def loss_and_grads_fn(params, xs, y):
#     loss, grads = jax.value_and_grad(loss_fn)(params, xs, y)
#     vjp_fn = jax.vjp(lambda params, xs, y: loss_fn(params, xs, y))
#     inputs_grads = vjp_fn(params)
#     return loss, grads, inputs_grads


# def loss_and_grads_fn(params, xs, y):
#     loss, grads = jax.value_and_grad(lambda params, xs, y: loss_fn(params, xs, y))(params, xs, y)
#     vjp_fn = jax.vjp(lambda params, xs, y: loss_fn(params, xs, y))
#     inputs_grads = vjp_fn(params, xs, y)
#     return loss, grads, inputs_grads

# def loss_and_grads_fn(params, xs, y):
#     loss, grads = jax.value_and_grad(loss_fn)(params, xs, y)
#     vjp_fn = jax.vjp(lambda params, xs, y: loss_fn(params, xs, y))
#     inputs_grads = vjp_fn(params, xs, y)
#     return loss, grads, inputs_grads

# def loss_and_grads_fn(params, xs, y):
#     loss, grads = jax.value_and_grad(loss_fn)(params, xs, y)
#     vjp_fn = jax.vjp(lambda p, x, y: loss_fn(p, x, y))
#     inputs_grads = vjp_fn(params, xs, y)
#     return loss, grads, inputs_grads


# def loss_and_grads_vjp(params, xs, y):
#     def loss_and_grads(params):
#         loss, grads = jax.value_and_grad(loss_fn)(params, xs, y)
#         return loss, grads
#     return jax.vjp(loss_and_grads)(params)


# loss, grads, grads_x = loss_and_grads_fn(params, X, y)


def update(params, grads):
    return jax.tree_map(lambda p, g: p - 0.05 * g, params, grads)


# the main training loop
# for _ in range(50):
loss = loss_fn(params, X_test, y_test)
print(loss)

grads = grad_fn(params, X, y)
print(grads)
grads_x = grad_xs(params, X, y)
print(grads_x)
params = update(params, grads)
# jvp = jvp_fn(params, X, y)
# vjp_fn, grad_inputs = jvp_fn(params, X, y)
# print(grad_inputs)


11523.323
{'b': DeviceArray(21.859552, dtype=float32, weak_type=True), 'w': DeviceArray([-179.33102 , -201.21684 ,  -60.641083], dtype=float32)}
[[ 0.  0.  0.]
 [-0. -0. -0.]
 [ 0.  0.  0.]
 [-0. -0. -0.]
 [-0. -0. -0.]
 [-0. -0. -0.]
 [ 0.  0.  0.]
 [ 0.  0.  0.]
 [ 0.  0.  0.]
 [-0. -0. -0.]
 [-0. -0. -0.]
 [ 0.  0.  0.]
 [-0. -0. -0.]
 [ 0.  0.  0.]
 [ 0.  0.  0.]
 [-0. -0. -0.]
 [-0. -0. -0.]
 [ 0.  0.  0.]
 [ 0.  0.  0.]
 [ 0.  0.  0.]
 [-0. -0. -0.]
 [-0. -0. -0.]
 [-0. -0. -0.]
 [ 0.  0.  0.]
 [ 0.  0.  0.]
 [-0. -0. -0.]
 [ 0.  0.  0.]
 [-0. -0. -0.]
 [ 0.  0.  0.]
 [-0. -0. -0.]
 [ 0.  0.  0.]
 [ 0.  0.  0.]
 [ 0.  0.  0.]
 [ 0.  0.  0.]
 [-0. -0. -0.]
 [ 0.  0.  0.]
 [-0. -0. -0.]
 [-0. -0. -0.]
 [ 0.  0.  0.]
 [-0. -0. -0.]
 [ 0.  0.  0.]
 [ 0.  0.  0.]
 [-0. -0. -0.]
 [ 0.  0.  0.]
 [ 0.  0.  0.]
 [-0. -0. -0.]
 [-0. -0. -0.]
 [-0. -0. -0.]
 [ 0.  0.  0.]
 [ 0.  0.  0.]
 [ 0.  0.  0.]
 [-0. -0. -0.]
 [ 0.  0.  0.]
 [ 0.  0.  0.]
 [ 0.  0.  0.]
 [ 0.  0.  0.]
 [ 0.  0.  0.]


In [74]:
print(jnp.min(grads_x))

-0.0


In [68]:
grads.shape

(75, 3)

In [69]:
X.shape

(75, 3)

In [70]:
X_test.shape

(25, 3)

In [58]:
forward_grad = jax.jacfwd(forward)

# Compute the gradients of the forward pass with respect to X
forward_grads_X = forward_grad(params, X)

# Compute the Jacobian of the loss function with respect to X using the chain rule
loss_grads_X = jnp.dot(grad_fn(params['w'], X, y), forward_grads_X)


TypeError: Using a non-tuple sequence for multidimensional indexing is not allowed; use `arr[tuple(seq)]` instead of `arr[seq]`. See https://github.com/google/jax/issues/4564 for more information.