In [1]:
import jax
import jax.numpy as jnp
from tensorflow_probability.substrates import jax as tfp
import numpy as np
from flax.linen.module import compact
import flax.linen as nn
from flax.linen.dtypes import promote_dtype
from typing import Any, List
import dataclasses
import distrax

# Loose placeholder for array-typed annotations; being an alias of `Any`,
# it documents intent only and is not checked statically.
Array = Any

# Reload imported modules automatically before each cell executes (mode 2 =
# all modules), so edits to the local project code are picked up without
# restarting the kernel.
%load_ext autoreload
%autoreload 2


In [2]:
# Conventional short aliases for the TFP (JAX substrate) sub-packages.
tfd, tfb = tfp.distributions, tfp.bijectors

In [3]:
import sys
# Extend the import search path with the parent directory; this must run
# before the project import below.
# NOTE(review): presumably the `modules` package lives one level above this
# notebook and the kernel's working directory is the notebook's directory - confirm.
sys.path.append("../")

from modules.autoregressive import MADE

In [49]:
# Smoke test: run a single MADE conditioner and feed its output into a scalar
# affine bijector. `MADE` is already imported in the path-setup cell above, so
# it is not imported again here.
made = MADE(n_params=5, n_context=0, activation="tanh", hidden_dims=[32, 32], name="made_0")

# Derive independent PRNG streams for the toy inputs and for the parameter
# initialisation; a JAX key should never be consumed more than once.
key = jax.random.PRNGKey(2)
data_key, init_key = jax.random.split(key)
x = jax.random.uniform(key=data_key, shape=(8, 5))

init_params = made.init(init_key, x)  # Flax variables of the network
params = made.apply(init_params, x)   # network output, used below as bijector parameters

# The last axis of the MADE output is indexed as: 0 -> shift, 1 -> log-scale.
forward, log_det = distrax.ScalarAffine(shift=params[..., 0], log_scale=params[..., 1]).forward_and_log_det(x)
forward.shape, log_det.shape

((8, 5), (8, 5))

In [95]:
from modules.autoregressive import MAF

# Exercise the MAF bijector on inputs of shape (2, 4, 5) and evaluate the
# log-density of the transformed distribution.
n_dim = 5
key = jax.random.PRNGKey(2)
made = MADE(n_params=n_dim, n_context=0, activation="tanh", hidden_dims=[32, 32], name="made_0")
x = jax.random.uniform(key=key, shape=(2, 4, n_dim))
# NOTE(review): `key` is used both for the data draw above and for the
# initialisation here; consider jax.random.split when the outputs are regenerated.
init_params = made.init(key, x)


def bijector_fn(x):
    """Conditioner: apply the MADE network with the fixed initial variables."""
    return made.apply(init_params, x)


event_shape = x.shape[1:]

# Standard-normal base distribution over the last (feature) axis.
base_dist = distrax.MultivariateNormalDiag(jnp.zeros(n_dim), jnp.ones(n_dim))
flow = MAF(bijector_fn)

transformed = distrax.Transformed(base_dist, flow)
transformed.log_prob(x)

Array([[-5.813867 , -5.5223155, -5.689406 , -6.288032 ],
       [-5.5090237, -5.5535507, -5.508765 , -5.299556 ]], dtype=float32)

In [96]:
# Draw from the transformed distribution (no sample shape requested).
flow_dist = distrax.Transformed(base_dist, flow)
flow_dist.sample(seed=key)

Array([-0.15508525,  1.5718881 , -0.37886548,  0.17252125, -0.60508585],      dtype=float32)

In [97]:
# Forward pass through the bijector: keep the transformed values for the
# inverse check and print the log-determinant.
forward_out = flow.forward_and_log_det(x)
y, log_det = forward_out
print(log_det)

[[ 0.08049314  0.19977322  0.5477073   0.12634201]
 [-0.01320131  0.0720925   0.16729566 -0.13098422]]


In [98]:
# Inverse pass on the forward result; for a consistent bijector this should
# reproduce `x`, with a log-determinant of opposite sign to the forward one.
x_recovered, inverse_log_det = flow.inverse_and_log_det(y)
(x_recovered, inverse_log_det)

(Array([[[0.08457661, 0.31502724, 0.83997506, 0.9254357 , 0.8644473 ],
         [0.16037524, 0.1584147 , 0.76908594, 0.86303467, 0.21834007],
         [0.9804157 , 0.02629685, 0.48894688, 0.07442081, 0.12056994],
         [0.9569601 , 0.96784085, 0.7532649 , 0.7426135 , 0.75791216]],
 
        [[0.29766083, 0.6187032 , 0.9520696 , 0.13110866, 0.6226797 ],
         [0.28658247, 0.45798576, 0.6337631 , 0.9467083 , 0.41258538],
         [0.4673562 , 0.43952918, 0.9372883 , 0.0879786 , 0.04822491],
         [0.6085954 , 0.9759288 , 0.36098158, 0.21089767, 0.03525542]]],      dtype=float32),
 Array([[-0.08049314, -0.19977322, -0.5477073 , -0.12634201],
        [ 0.01320131, -0.0720925 , -0.16729566,  0.13098422]],      dtype=float32))

In [101]:
class MaskedAutoregressiveFlow(nn.Module):
    """Stack of masked autoregressive transforms exposed as a log-density model.

    For each of the ``n_transforms`` layers, a ``tfb.Permute`` that reverses
    the feature order and a ``MAF`` bijector driven by its own ``MADE`` network
    are appended to a list. The list is wrapped in ``distrax.Chain`` and then
    ``distrax.Inverse``, and paired with a zero-mean, unit-scale diagonal
    Gaussian base distribution.

    Attributes:
        n_dim: Number of features; sizes the base distribution and the
            permutation, and is forwarded to every MADE as ``n_params``.
        n_context: Forwarded to every MADE as ``n_context``.
        n_transforms: Number of (permute, MAF) pairs in the chain.
        hidden_dims: Forwarded to every MADE as ``hidden_dims``.
        activation: Forwarded to every MADE as ``activation``.
        unroll_loop: Forwarded to every MAF as ``unroll_loop``.
    """

    n_dim: int
    n_context: int = 0
    n_transforms: int = 4
    hidden_dims: List[int] = dataclasses.field(default_factory=lambda: [32, 32])
    activation: str = "tanh"
    unroll_loop: bool = False

    def setup(self):
        """Create the MADE conditioners and assemble the chained bijector."""
        # One conditioner per transform, explicitly named made_0 ... made_{n_transforms-1}.
        self.made = [
            MADE(
                n_params=self.n_dim,
                n_context=self.n_context,
                activation=self.activation,
                hidden_dims=self.hidden_dims,
                name="made_{}".format(i),
            )
            for i in range(self.n_transforms)
        ]

        # The reversal permutation is identical for every layer, so build it once.
        reversed_order = list(reversed(range(self.n_dim)))

        bijectors = []
        for made_layer in self.made:
            bijectors.append(tfb.Permute(reversed_order))
            bijectors.append(MAF(bijector_fn=made_layer, unroll_loop=self.unroll_loop))

        self.bijector = distrax.Inverse(distrax.Chain(bijectors))

    def make_flow_model(self):
        """Return the ``(bijector, base_distribution)`` pair of the model.

        Returns:
            A tuple of the chained, inverted bijector built in ``setup`` and a
            diagonal Gaussian with zero mean and unit scale over ``n_dim``
            dimensions.
        """
        flow = self.bijector
        base_dist = distrax.Independent(distrax.MultivariateNormalDiag(jnp.zeros(self.n_dim), jnp.ones(self.n_dim)))

        return flow, base_dist

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        """Evaluate the log-density of ``x`` under the transformed distribution.

        Args:
            x: Input array; presumably its last axis has size ``n_dim`` - confirm
                against the ``MADE``/``MAF`` implementations.

        Returns:
            The result of ``distrax.Transformed(base_dist, flow).log_prob(x)``.
        """
        flow, base_dist = self.make_flow_model()
        return distrax.Transformed(base_dist, flow).log_prob(x)

In [103]:
# Instantiate the stacked flow and initialise its variables on a toy batch.
maf = MaskedAutoregressiveFlow(n_dim=5, unroll_loop=True)

# Derive independent PRNG streams for the toy inputs and for the parameter
# initialisation; a JAX key should never be consumed more than once.
key = jax.random.PRNGKey(2)
data_key, init_key = jax.random.split(key)
x = jax.random.uniform(key=data_key, shape=(3, 5))
params = maf.init(init_key, x)