In [9]:
import jax
import jax.numpy as jnp
from tensorflow_probability.substrates import jax as tfp
import numpy as np
from flax.linen.module import compact
import flax.linen as nn
from flax.linen.dtypes import promote_dtype
from typing import Any, List
import dataclasses
import distrax

Array = Any

%load_ext autoreload
%autoreload 2


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [10]:
tfd = tfp.distributions
tfb = tfp.bijectors

In [11]:
import sys
sys.path.append("../")

from modules.autoregressive import  MaskedAutoregressiveBijector
from modules.utils import make_shift_and_scale

In [12]:
bij = MaskedAutoregressiveBijector(bijector_fn=make_shift_and_scale, n_params=2, hidden_dims=[8, 8], name="autoregressive_1")

key = jax.random.PRNGKey(2)
x = jax.random.uniform(key=key, shape=(2,))

params = bij.init(key, x)

In [13]:
apply_fn = lambda x: bij.apply({"params":params['params']}, x)

In [14]:
apply_fn(x)(jnp.ones(2))

DeviceArray([1.        , 0.75378555], dtype=float32)

In [15]:
n_transforms = 4

bijectors = []
for _ in range(n_transforms):
    bijectors.append(tfb.MaskedAutoregressiveFlow(bijector_fn=apply_fn)) 

bijector = distrax.Chain(bijectors)

In [16]:
base_dist = distrax.Independent(
            distrax.MultivariateNormalFullCovariance(
                loc=jnp.zeros(2),
                covariance_matrix=jnp.eye(2),
        ))

maf = distrax.Transformed(base_dist, bijector)

In [17]:
maf.log_prob(jnp.ones((4,2)))

DeviceArray([-22.546316, -22.546316, -22.546316, -22.546316], dtype=float32)

In [18]:
maf.sample(sample_shape=(8,), seed=jax.random.PRNGKey(0))

DeviceArray([[ 0.08482574,  1.3161952 ],
             [ 0.29561743, -0.67382956],
             [ 0.33432344, -2.570653  ],
             [ 0.6481277 , -2.022002  ],
             [-0.7824839 ,  3.252925  ],
             [ 0.6297971 , -2.1548283 ],
             [-0.32787678,  0.72380227],
             [-1.6607416 ,  6.293631  ]], dtype=float32)

In [114]:
class MaskedAutoregressiveFlow(nn.Module):
    n_transforms: int = 10

    def setup(self):

        self.bij = [MaskedAutoregressiveBijector(bijector_fn=make_shift_and_scale, n_params=2, hidden_dims=[8, 8], name="made_{}".format(i)) for i in range(self.n_transforms)]

        # Need to unroll loop since Jax transforms and Flax models cannot be mixed
        bijectors = []
        for i in range(self.n_transforms):
            bijectors.append(tfb.MaskedAutoregressiveFlow(bijector_fn=self.bij[i]))
        self.bijector: distrax.Bijector = distrax.Chain(bijectors)

    def make_flow(self):
        
        flow = self.bijector
        base_dist = distrax.Independent(
                    distrax.MultivariateNormalFullCovariance(
                        loc=jnp.zeros(2),
                        covariance_matrix=jnp.eye(2),
                ))

        return flow, base_dist

    def __call__(self, x: jnp.array) -> jnp.array:
        flow, base_dist = self.make_flow()
        return distrax.Transformed(base_dist, flow).log_prob(x)

In [115]:
maf = MaskedAutoregressiveFlow(n_transforms=10)

key = jax.random.PRNGKey(2)
x = jax.random.uniform(key=key, shape=(2,))

params = maf.init(key, x)

In [116]:
sum(x.size for x in jax.tree_leaves(params))

  sum(x.size for x in jax.tree_leaves(params))


1320