In [140]:
import jax
import jax.numpy as jnp
import jax.lax as lax
import jax.random as jrandom
import numpy as np
import optax
import distrax
import haiku as hk

import tensorflow as tf
import tensorflow_datasets as tfds

from lfiax.flows.nsf import make_nsf

from typing import (
    Any,
    Iterator,
    Mapping,
    Optional,
    Tuple,
)

Array = jnp.ndarray
PRNGKey = Array
Batch = Mapping[str, np.ndarray]
OptState = Any


def sim_linear_jax(d: Array, priors: Array, key: PRNGKey):
    # Keys for the appropriate functions
    keys = jrandom.split(key, 3)

    # sample random normal dist
    noise_shape = (1,)

    mu_noise = jnp.zeros(noise_shape)
    sigma_noise = jnp.ones(noise_shape)

    n_n = distrax.Independent(
        distrax.MultivariateNormalDiag(mu_noise, sigma_noise)
    ).sample(seed=keys[0], sample_shape=[len(d), len(priors)])

    # sample random gamma noise
    n_g = distrax.Gamma(2.0, 1.0 / 2.0).sample(
        seed=keys[1], sample_shape=[len(d), len(priors)]
    )

    # perform forward pass
    y = jnp.broadcast_to(priors[:, 0], (len(d), len(priors)))
    y = y + jnp.expand_dims(d, 1) @ jnp.expand_dims(priors[:, 1], 0)
    y = y + n_g + jnp.squeeze(n_n)
    ygrads = priors[:, 1]

    return y, ygrads


# ----------------------------------------
# Helper functions to simulate data
# ----------------------------------------
def load_dataset(split: tfds.Split, batch_size: int) -> Iterator[Batch]:
    ds = split
    ds = ds.shuffle(buffer_size=10 * batch_size)
    ds = ds.batch(batch_size)
    ds = ds.prefetch(buffer_size=1000)
    ds = ds.repeat()
    return iter(tfds.as_numpy(ds))


def sim_data(d: Array, priors: Array, key: PRNGKey):
    """
    Returns data in a format suitable for normalizing flow training.
    Data will be in shape [y, thetas]. The `y` variable can vary in size.
    """
    keys = jrandom.split(key, 2)

    theta_shape = (2,)

    mu = jnp.zeros(theta_shape)
    sigma = (3**2) * jnp.ones(theta_shape)

    base_distribution = distrax.Independent(  # Should this be independent?
        distrax.MultivariateNormalDiag(mu, sigma)
    )

    priors = base_distribution.sample(seed=keys[0], sample_shape=[num_samples])

    # ygrads allows to be compared to other implementations (Kleinegesse et)
    y, ygrads = sim_linear_jax(d, priors, keys[1])

    return jnp.column_stack((y.T, jnp.squeeze(priors), jnp.broadcast_to(d, (num_samples, len(d)))))


def prepare_data(batch: Batch, prng_key: Optional[PRNGKey] = None) -> Array:
    # Batch is [y, thetas, d]
    data = batch.astype(np.float32)
    # Handling the scalar case
    if data.shape[1] <= 3:
        y = jnp.expand_dims(data[:, :-2], -1)
    # Use known length of x to split up the cond_data
    # data_shape = data.shape
    # start = [0, 0]
    # stop = [data_shape[0], len_x]
    # y = lax.dynamic_slice(data, start, stop)
    y = data[:, :len_x]
    cond_data = data[:, len_x:]
    theta = cond_data[:, :-len_x]
    x = cond_data[:, -len_x:len_xi]
    xi = cond_data[:, -len_xi:]
    # return x, cond_data
    return y, theta, x, xi


# ----------------------------
# Haiku transform functions for training and evaluation
# ----------------------------
@hk.without_apply_rng
@hk.transform
def log_prob(data: Array, cond_data: Array) -> Array:
    # Get batch
    shift = data.mean(axis=0)
    scale = data.std(axis=0) + 1e-14
    
    model = make_nsf(
        event_shape=EVENT_SHAPE,
        cond_info_shape=cond_info_shape,
        num_layers=flow_num_layers,
        hidden_sizes=[hidden_size] * mlp_num_layers,
        num_bins=num_bins,
        standardize_x=True,
        standardize_z=True,
        use_resnet=True,
        event_dim=EVENT_DIM,
        shift=shift,
        scale=scale,
    )
    return model.log_prob(data, cond_data)


@hk.without_apply_rng
@hk.transform
def model_sample(key: PRNGKey, num_samples: int, cond_data: Array) -> Array:
    model = make_nsf(
        event_shape=EVENT_SHAPE,
        cond_info_shape=cond_info_shape,
        num_layers=flow_num_layers,
        hidden_sizes=[hidden_size] * mlp_num_layers,
        num_bins=num_bins,
    )
    z = jnp.repeat(cond_data, num_samples, axis=0)
    z = jnp.expand_dims(z, -1)
    return model._sample_n(key=key, n=[num_samples], z=z)


def loss_fn(params: hk.Params, prng_key: PRNGKey, y: Array, theta: Array, x: Array, xi: Array) -> Array:
    # x, cond_data = prepare_data(batch, prng_key)
    # I wonder if this will work...
    cond_data = jnp.column_stack((theta, x, xi))
    # Loss is average negative log likelihood.
    loss = -jnp.mean(log_prob.apply(params, y, cond_data))
    return loss


@jax.jit
def eval_fn(params: hk.Params, batch: Batch) -> Array:
    y, theta, x, xi = prepare_data(batch)
    # loss = -jnp.mean(log_prob.apply(params, x, cond_data))
    loss = -jnp.mean(log_prob.apply(params, y, (theta, x, xi)))
    return loss


@jax.jit
def update(
    params: hk.Params, prng_key: PRNGKey, opt_state: OptState, batch: Batch
) -> Tuple[hk.Params, OptState]:
    """Single SGD update step."""
    # x, cond_data = prepare_data(batch, prng_key)
    y, theta, x, xi = prepare_data(batch)
    grads = jax.grad(loss_fn)(params, prng_key, y, theta, x, xi)
    grads_d = jax.grad(loss_fn, argnums=5)(params, prng_key, y, theta, x, xi)
    updates, new_opt_state = optimizer.update(grads, opt_state)
    new_params = optax.apply_updates(params, updates)
    return new_params, new_opt_state


In [118]:
seed = 123
key = jrandom.PRNGKey(seed)
d_prop = jrandom.uniform(key, shape=(1,), minval=-10., maxval=10.)

In [119]:
d_obs = jnp.array([1.])

In [124]:
jnp.concatenate((d_obs, d_prop), axis=0).shape

(2,)

In [141]:
# TODO: Put this in hydra config file
seed = 1231
key = jrandom.PRNGKey(seed)

# d = jnp.array([-10.0, 0.0, 5.0, 10.0])
# d = jnp.array([1., 2.])
# d = jnp.array([1.])
d_obs = jnp.array([1.])
d_prop = jrandom.uniform(key, shape=(1,), minval=-10., maxval=10.)
d = jnp.concatenate((d_obs, d_prop), axis=0)
len_x = len(d_obs) + len(d_prop)
len_xi = len(d_prop)
num_samples = 100

# Params and hyperparams
theta_shape = (2,)
EVENT_SHAPE = (len(d),)
# EVENT_DIM is important for the normalizing flow's block.
EVENT_DIM = 1
cond_info_shape = (theta_shape[0] + len(d),)

batch_size = 128
flow_num_layers = 10
mlp_num_layers = 4
hidden_size = 500
num_bins = 4
learning_rate = 1e-4

training_steps = 10  # 00
eval_frequency = 100

optimizer = optax.adam(learning_rate)

# Simulating the data to be used to train the flow.
num_samples = 10000
# TODO: put this function in training since d will be changing.
X = sim_data(d, num_samples, key)

# Create tf dataset from sklearn dataset
dataset = tf.data.Dataset.from_tensor_slices(X)

# Splitting into train/validate ds
train = dataset.skip(2000)
val = dataset.take(2000)

# load_dataset(split: tfds.Split, batch_size: int)
train_ds = load_dataset(train, 512)
valid_ds = load_dataset(val, 512)

# Training
prng_seq = hk.PRNGSequence(42)
params = log_prob.init(
    next(prng_seq),
    np.zeros((1, *EVENT_SHAPE)),
    np.zeros((1, *cond_info_shape)),
)
opt_state = optimizer.init(params)

for step in range(training_steps):
    params, opt_state = update(params, next(prng_seq), opt_state, next(train_ds))

    if step % eval_frequency == 0:
        val_loss = eval_fn(params, next(valid_ds))
        print(f"STEP: {step:5d}; Validation loss: {val_loss:.3f}")


ValueError: 'conditioner_module/mlp/~/linear_0/w' with retrieved shape (6, 500) does not match shape=[5, 500] dtype=dtype('float32')

## Example using a more simple linear regression model.

In [102]:
import jax
import jax.numpy as jnp
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split


# create our dataset
X, y = make_regression(n_features=3)
X, X_test, y, y_test = train_test_split(X, y)


# model weights
params = {
    'w': jnp.zeros(X.shape[1:]),
    'b': 0.
}


def forward(params, X):
    return jnp.dot(X, params['w']) + params['b']


def loss_fn(params, X, y):
    err = forward(params, X) - y
    return jnp.mean(jnp.square(err))  # mse


grad_fn = jax.grad(loss_fn)
grad_xs = jax.grad(loss_fn, argnums=1)


def update(params, grads):
    return jax.tree_map(lambda p, g: p - 0.05 * g, params, grads)


# the main training loop
for _ in range(100):
    loss = loss_fn(params, X_test, y_test)
    print(loss)

    grads = grad_fn(params, X, y)
    # print(grads)
    grads_x = grad_xs(params, X, y)
    # print(grads_x)
    params = update(params, grads)
# jvp = jvp_fn(params, X, y)
# vjp_fn, grad_inputs = jvp_fn(params, X, y)
# print(grad_inputs)


8141.749
7026.248
6101.3345
5327.9395
4675.831
4121.5415
3646.7786
3237.2224
2881.5994
2570.981
2298.2488
2057.681
1844.6412
1655.3422
1486.6582
1335.9877
1201.1442
1080.2745
971.79266
874.3323
786.706
707.87585
636.92914
573.0588
515.5483
463.7596
417.122
375.12442
337.30768
303.2594
272.60788
245.01825
220.18892
197.8472
177.74779
159.66864
143.40965
128.79007
115.64697
103.83316
93.215836
83.67557
75.1043
67.404594
60.48888
54.27826
48.701363
43.694206
39.19906
35.16405
31.542425
28.292137
25.37536
22.758108
20.409803
18.302986
16.412924
14.717433
13.196639
11.832554
10.609089
9.511839
8.527853
7.6454344
6.8541503
6.144577
5.50838
4.9379315
4.426497
3.9679482
3.5568316
3.1882825
2.857861
2.5616286
2.2960923
2.0580654
1.8446738
1.6533967
1.4819403
1.3282497
1.1904765
1.0669953
0.95632344
0.85711676
0.76819175
0.68849045
0.6170555
0.55302685
0.49564597
0.4442091
0.39811057
0.35679317
0.319765
0.28658006
0.2568317
0.23017737
0.20628563
0.18487637
0.16568606
0.14848359


In [101]:
grads_x

DeviceArray([[ 6.62432110e-04,  6.52491627e-03,  4.58412757e-03],
             [ 2.30209646e-03,  2.26755124e-02,  1.59308463e-02],
             [ 6.78365293e-04,  6.68185716e-03,  4.69438732e-03],
             [-2.70380464e-04, -2.66323122e-03, -1.87107245e-03],
             [ 5.83249319e-04,  5.74496994e-03,  4.03617043e-03],
             [ 6.56638294e-05,  6.46784727e-04,  4.54403315e-04],
             [ 3.39906896e-04,  3.34806228e-03,  2.35220557e-03],
             [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00],
             [ 5.02135139e-04,  4.94600087e-03,  3.47484881e-03],
             [ 1.54503126e-04,  1.52184651e-03,  1.06918428e-03],
             [ 2.39479850e-04,  2.35886197e-03,  1.65723567e-03],
             [ 1.12545874e-03,  1.10857002e-02,  7.78833916e-03],
             [ 5.52348676e-04,  5.44060115e-03,  3.82233388e-03],
             [-8.92255572e-04, -8.78866389e-03, -6.17453968e-03],
             [ 6.74744151e-05,  6.64618856e-04,  4.66932834e-04],
          

In [74]:
print(jnp.min(grads_x))

-0.0


In [68]:
grads.shape

(75, 3)

In [69]:
X.shape

(75, 3)

In [70]:
X_test.shape

(25, 3)

In [58]:
forward_grad = jax.jacfwd(forward)

# Compute the gradients of the forward pass with respect to X
forward_grads_X = forward_grad(params, X)

# Compute the Jacobian of the loss function with respect to X using the chain rule
loss_grads_X = jnp.dot(grad_fn(params['w'], X, y), forward_grads_X)


TypeError: Using a non-tuple sequence for multidimensional indexing is not allowed; use `arr[tuple(seq)]` instead of `arr[seq]`. See https://github.com/google/jax/issues/4564 for more information.