In [148]:
import jax
import jax.numpy as jnp
import jax.lax as lax
import jax.random as jrandom
import numpy as np
import optax
import distrax
import haiku as hk

import tensorflow as tf
import tensorflow_datasets as tfds

from lfiax.flows.nsf import make_nsf

from typing import (
    Any,
    Iterator,
    Mapping,
    Optional,
    Tuple,
)

Array = jnp.ndarray
PRNGKey = Array
Batch = Mapping[str, np.ndarray]
OptState = Any


def sim_linear_jax(d: Array, priors: Array, key: PRNGKey):
    # Keys for the appropriate functions
    keys = jrandom.split(key, 3)

    # sample random normal dist
    noise_shape = (1,)

    mu_noise = jnp.zeros(noise_shape)
    sigma_noise = jnp.ones(noise_shape)

    n_n = distrax.Independent(
        distrax.MultivariateNormalDiag(mu_noise, sigma_noise)
    ).sample(seed=keys[0], sample_shape=[len(d), len(priors)])

    # sample random gamma noise
    n_g = distrax.Gamma(2.0, 1.0 / 2.0).sample(
        seed=keys[1], sample_shape=[len(d), len(priors)]
    )

    # perform forward pass
    y = jnp.broadcast_to(priors[:, 0], (len(d), len(priors)))
    y = y + jnp.expand_dims(d, 1) @ jnp.expand_dims(priors[:, 1], 0)
    y = y + n_g + jnp.squeeze(n_n)
    ygrads = priors[:, 1]

    return y, ygrads


# ----------------------------------------
# Helper functions to simulate data
# ----------------------------------------
def load_dataset(split: tfds.Split, batch_size: int) -> Iterator[Batch]:
    ds = split
    ds = ds.shuffle(buffer_size=10 * batch_size)
    ds = ds.batch(batch_size)
    ds = ds.prefetch(buffer_size=1000)
    ds = ds.repeat()
    return iter(tfds.as_numpy(ds))


def sim_data(d: Array, priors: Array, key: PRNGKey):
    """
    Returns data in a format suitable for normalizing flow training.
    Data will be in shape [y, thetas]. The `y` variable can vary in size.
    """
    keys = jrandom.split(key, 2)

    theta_shape = (2,)

    mu = jnp.zeros(theta_shape)
    sigma = (3**2) * jnp.ones(theta_shape)

    base_distribution = distrax.Independent(  # Should this be independent?
        distrax.MultivariateNormalDiag(mu, sigma)
    )

    priors = base_distribution.sample(seed=keys[0], sample_shape=[num_samples])

    # ygrads allows to be compared to other implementations (Kleinegesse et)
    y, ygrads = sim_linear_jax(d, priors, keys[1])

    return jnp.column_stack((y.T, jnp.squeeze(priors), jnp.broadcast_to(d, (num_samples, len(d)))))


def prepare_data(batch: Batch, prng_key: Optional[PRNGKey] = None) -> Array:
    # Batch is [y, thetas, d]
    data = batch.astype(np.float32)
    # Handling the scalar case
    if data.shape[1] <= 3:
        y = jnp.expand_dims(data[:, :-2], -1)
    # Use known length of x to split up the cond_data
    # data_shape = data.shape
    # start = [0, 0]
    # stop = [data_shape[0], len_x]
    # y = lax.dynamic_slice(data, start, stop)
    y = data[:, :len_x]
    cond_data = data[:, len_x:]
    theta = cond_data[:, :-len_x]
    x = cond_data[:, -len_x:-len_xi]
    xi = cond_data[:, -len_xi:]
    # return x, cond_data
    # breakpoint()
    return y, theta, x, xi


# ----------------------------
# Haiku transform functions for training and evaluation
# ----------------------------
@hk.without_apply_rng
@hk.transform
def log_prob(data: Array, cond_data: Array) -> Array:
    # Get batch
    shift = data.mean(axis=0)
    scale = data.std(axis=0) + 1e-14
    
    model = make_nsf(
        event_shape=EVENT_SHAPE,
        cond_info_shape=cond_info_shape,
        num_layers=flow_num_layers,
        hidden_sizes=[hidden_size] * mlp_num_layers,
        num_bins=num_bins,
        standardize_x=True,
        standardize_z=True,
        use_resnet=True,
        event_dim=EVENT_DIM,
        shift=shift,
        scale=scale,
    )
    return model.log_prob(data, cond_data)


@hk.without_apply_rng
@hk.transform
def model_sample(key: PRNGKey, num_samples: int, cond_data: Array) -> Array:
    model = make_nsf(
        event_shape=EVENT_SHAPE,
        cond_info_shape=cond_info_shape,
        num_layers=flow_num_layers,
        hidden_sizes=[hidden_size] * mlp_num_layers,
        num_bins=num_bins,
    )
    z = jnp.repeat(cond_data, num_samples, axis=0)
    z = jnp.expand_dims(z, -1)
    return model._sample_n(key=key, n=[num_samples], z=z)


def loss_fn(params: hk.Params, prng_key: PRNGKey, y: Array, theta: Array, x: Array, xi: Array) -> Array:
    # x, cond_data = prepare_data(batch, prng_key)
    # I wonder if this will work...
    cond_data = jnp.concatenate((theta, x, xi), axis=1)
    # Loss is average negative log likelihood.
    loss = -jnp.mean(log_prob.apply(params, y, cond_data))
    return loss


@jax.jit
def eval_fn(params: hk.Params, batch: Batch) -> Array:
    y, theta, x, xi = prepare_data(batch)
    cond_data = jnp.concatenate((theta, x, xi), axis=1)
    # loss = -jnp.mean(log_prob.apply(params, x, cond_data))
    loss = -jnp.mean(log_prob.apply(params, y, cond_data))
    return loss


@jax.jit
def update(
    params: hk.Params, prng_key: PRNGKey, opt_state: OptState, batch: Batch
) -> Tuple[hk.Params, OptState]:
    """Single SGD update step."""
    # x, cond_data = prepare_data(batch, prng_key)
    y, theta, x, xi = prepare_data(batch)
    grads = jax.grad(loss_fn)(params, prng_key, y, theta, x, xi)
    grads_d = jax.grad(loss_fn, argnums=5)(params, prng_key, y, theta, x, xi)
    updates, new_opt_state = optimizer.update(grads, opt_state)
    new_params = optax.apply_updates(params, updates)
    return new_params, new_opt_state, grads_d
    

In [118]:
seed = 123
key = jrandom.PRNGKey(seed)
d_prop = jrandom.uniform(key, shape=(1,), minval=-10., maxval=10.)

In [119]:
d_obs = jnp.array([1.])

In [124]:
jnp.concatenate((d_obs, d_prop), axis=0).shape

(2,)

In [149]:
# TODO: Put this in hydra config file
seed = 1231
key = jrandom.PRNGKey(seed)

# d = jnp.array([-10.0, 0.0, 5.0, 10.0])
# d = jnp.array([1., 2.])
# d = jnp.array([1.])
d_obs = jnp.array([1.])
d_prop = jrandom.uniform(key, shape=(1,), minval=-10., maxval=10.)
d = jnp.concatenate((d_obs, d_prop), axis=0)
len_x = len(d_obs) + len(d_prop)
len_xi = len(d_prop)
num_samples = 100

# Params and hyperparams
theta_shape = (2,)
EVENT_SHAPE = (len(d),)
# EVENT_DIM is important for the normalizing flow's block.
EVENT_DIM = 1
cond_info_shape = (theta_shape[0] + len(d),)

batch_size = 128
flow_num_layers = 10
mlp_num_layers = 4
hidden_size = 500
num_bins = 4
learning_rate = 1e-4

training_steps = 10  # 00
eval_frequency = 100

optimizer = optax.adam(learning_rate)

# Simulating the data to be used to train the flow.
num_samples = 10000
# TODO: put this function in training since d will be changing.
X = sim_data(d, num_samples, key)

# Create tf dataset from sklearn dataset
dataset = tf.data.Dataset.from_tensor_slices(X)

# Splitting into train/validate ds
train = dataset.skip(2000)
val = dataset.take(2000)

# load_dataset(split: tfds.Split, batch_size: int)
train_ds = load_dataset(train, 512)
valid_ds = load_dataset(val, 512)

# Training
prng_seq = hk.PRNGSequence(42)
params = log_prob.init(
    next(prng_seq),
    np.zeros((1, *EVENT_SHAPE)),
    np.zeros((1, *cond_info_shape)),
)
opt_state = optimizer.init(params)

for step in range(training_steps):
    params, opt_state, grads_d = update(params, next(prng_seq), opt_state, next(train_ds))

    if step % eval_frequency == 0:
        val_loss = eval_fn(params, next(valid_ds))
        print(f"STEP: {step:5d}; Validation loss: {val_loss:.3f}")


STEP:     0; Validation loss: 7.679


## Example using a more simple linear regression model.

In [159]:
import jax
import jax.numpy as jnp
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split


# create our dataset
X, y = make_regression(n_features=3)
X, X_test, y, y_test = train_test_split(X, y)


# model weights
params = {
    'w': jnp.zeros(X.shape[1:]),
    'b': 0.
}


def forward(params, X):
    return jnp.dot(X, params['w']) + params['b']


def loss_fn(params, X, y):
    err = forward(params, X) - y
    return jnp.mean(jnp.square(err))  # mse


grad_fn = jax.grad(loss_fn)
grad_xs = jax.grad(loss_fn, argnums=1)


def update(params, grads):
    return jax.tree_map(lambda p, g: p - 0.05 * g, params, grads)


# the main training loop
for _ in range(100):
    loss = loss_fn(params, X_test, y_test)
    print(loss)

    grads = grad_fn(params, X, y)
    # print(grads)
    grads_x = grad_xs(params, X, y)
    # print(grads_x)
    params = update(params, grads)
# jvp = jvp_fn(params, X, y)
# vjp_fn, grad_inputs = jvp_fn(params, X, y)
# print(grad_inputs)


2181.5369
1858.5502
1585.1548
1353.422
1156.7448
989.6083
847.4021
726.26587
622.9615
534.76965
459.40234
394.93176
339.73135
292.42667
251.85446
217.0293
187.11482
161.40063
139.28238
120.24537
103.85087
89.72437
77.54581
67.04163
57.97752
50.15276
43.395264
37.557304
32.51204
28.15043
24.378696
21.116192
18.293394
15.850478
13.735813
11.904947
10.31944
8.946174
7.7565503
6.725816
5.8326335
5.058549
4.3875713
3.8059084
3.3016088
2.8643444
2.4851553
2.156312
1.8710963
1.6236962
1.4090894
1.2229112
1.0613791
0.92123014
0.7996186
0.6940952
0.6025214
0.5230455
0.45407078
0.39420485
0.34224513
0.29714483
0.25799254
0.22400692
0.19450277
0.16889036
0.1466547
0.12734972
0.110588394
0.09603701
0.08339975
0.07242788
0.06290057
0.054628
0.047444887
0.041206427
0.03578893
0.03108498
0.026999
0.023450699
0.020369269
0.017693063
0.015368364
0.01334941
0.011596036
0.01007299
0.008750262
0.007601229
0.0066029932
0.0057360493
0.0049831304
0.004328905
0.0037607104
0.0032672042
0.0028383692
0.002466018

In [166]:
grads_x

DeviceArray([[ 3.52432905e-03,  2.47096978e-02,  5.91810495e-02],
             [ 4.47242710e-05,  3.13569850e-04,  7.51016545e-04],
             [-1.65014213e-03, -1.15694404e-02, -2.77094282e-02],
             [ 2.14535417e-03,  1.50414603e-02,  3.60250995e-02],
             [-3.33245192e-03, -2.33644154e-02, -5.59590235e-02],
             [-2.63125426e-03, -1.84481926e-02, -4.41844091e-02],
             [-4.31130687e-03, -3.02273389e-02, -7.23960921e-02],
             [-1.45995826e-03, -1.02360277e-02, -2.45158337e-02],
             [-6.79470249e-04, -4.76388726e-03, -1.14097642e-02],
             [-1.30786747e-03, -9.16969217e-03, -2.19619032e-02],
             [-2.38773972e-03, -1.67408697e-02, -4.00952771e-02],
             [ 5.24106342e-03,  3.67460325e-02,  8.80087093e-02],
             [-1.81098015e-03, -1.26971044e-02, -3.04102451e-02],
             [-7.59501301e-04, -5.32499887e-03, -1.27536571e-02],
             [ 2.56988191e-04,  1.80178997e-03,  4.31538327e-03],
          

In [74]:
print(jnp.min(grads_x))

-0.0


In [68]:
grads.shape

(75, 3)

In [69]:
X.shape

(75, 3)

In [70]:
X_test.shape

(25, 3)

In [58]:
forward_grad = jax.jacfwd(forward)

# Compute the gradients of the forward pass with respect to X
forward_grads_X = forward_grad(params, X)

# Compute the Jacobian of the loss function with respect to X using the chain rule
loss_grads_X = jnp.dot(grad_fn(params['w'], X, y), forward_grads_X)


TypeError: Using a non-tuple sequence for multidimensional indexing is not allowed; use `arr[tuple(seq)]` instead of `arr[seq]`. See https://github.com/google/jax/issues/4564 for more information.