In [1]:
import jax
import jax.numpy as jnp
import numpy as np

from tensorflow_probability.substrates import jax as tfp

from flax.linen.module import compact
import flax.linen as nn
from typing import Any

# Permissive alias used only in annotations below; being `Any`, it imposes no type checking.
Array = Any

In [2]:
# Short aliases for the TFP (JAX substrate) distribution and bijector namespaces.
tfd, tfb = tfp.distributions, tfp.bijectors

In [13]:
class MaskedDense(nn.Dense):
    """Dense layer whose kernel is multiplied element-wise by a mask.

    Adapted from ``flax.linen.Dense``; the only functional difference is the
    optional masking of the kernel and the optional concatenation of a
    conditioning ``context`` to the inputs.

    Attributes:
        use_context: When True, a ``context`` passed to ``__call__`` is
            concatenated to ``inputs`` along the last axis before the linear
            map. When False (the default) any ``context`` is ignored.
    """
    use_context: bool = False

    @compact
    def __call__(self, inputs: Array, context=None, mask=None) -> Array:
        """Applies a masked linear transformation along the last dimension.

        Args:
            inputs: The nd-array to be transformed.
            context: Optional conditioning array. Only used when
                ``use_context`` is True; its leading (non-feature) dimensions
                must equal those of ``inputs``.
            mask: Optional array multiplied into the kernel. Its first axis
                must equal the number of input features (counted after the
                context has been concatenated). When None, no masking is
                applied and the layer acts as a plain dense layer.

        Returns:
            The transformed input.

        Raises:
            ValueError: If ``context`` and ``inputs`` disagree on their
                leading dimensions, or if ``mask`` does not match the number
                of input features.
        """
        inputs = jnp.asarray(inputs, self.dtype)
        if context is not None and self.use_context:
            context = jnp.asarray(context, self.dtype)
            # Compare every leading axis rather than only axis 0: for an
            # unbatched 1-D input, axis 0 is the feature axis, not the batch.
            if inputs.shape[:-1] != context.shape[:-1]:
                raise ValueError(
                    "inputs and context must have the same batch size"
                )
            # Join on the feature axis. `hstack` would use axis 1, which is
            # only the feature axis for inputs of rank <= 2.
            inputs = jnp.concatenate([inputs, context], axis=-1)

        in_features = inputs.shape[-1]
        if mask is not None and mask.shape[0] != in_features:
            # Fail early with a readable message instead of an opaque
            # shape error inside dot_general.
            raise ValueError(
                f"mask has {mask.shape[0]} input rows but the layer received "
                f"{in_features} input features"
            )

        kernel = self.param(
            "kernel", self.kernel_init, (in_features, self.features)
        )
        kernel = jnp.asarray(kernel, self.dtype)
        if mask is not None:
            kernel = kernel * mask  # major difference from flax.linen.Dense
        # Contract the last axis of `inputs` with the first axis of `kernel`.
        y = jax.lax.dot_general(
            inputs,
            kernel,
            (((inputs.ndim - 1,), (0,)), ((), ())),
            precision=self.precision,
        )
        if self.use_bias:
            bias = self.param("bias", self.bias_init, (self.features,))
            bias = jnp.asarray(bias, self.dtype)
            # Reshape to (1, ..., 1, features) so the bias broadcasts over
            # every leading axis of `y`.
            y += jnp.reshape(bias, (1,) * (y.ndim - 1) + (-1,))
        return y

In [14]:
def make_shift_and_scale(params):
    """Build a shift-and-scale bijector from a parameter array.

    Args:
        params: Array whose last axis holds the shift at index 0 and the
            log-scale at index 1.

    Returns:
        The result of calling the ``tfb.Shift`` bijector on the ``tfb.Scale``
        bijector, i.e. their composition.
    """
    shift = params[..., 0]
    log_scale = params[..., 1]
    shift_bijector = tfb.Shift(shift)
    scale_bijector = tfb.Scale(log_scale=log_scale)
    return shift_bijector(scale_bijector)

In [54]:
class MaskedAutoregressiveBijector(nn.Module):
    """Masked MLP that maps an input to a bijector via autoregressive masks.

    Attributes:
        bijector_fn: Callable that receives the network output reshaped to
            ``(..., n_inputs, n_params)`` and returns the bijector.
        n_params: Number of parameters produced per input dimension.
        hidden_dims: Widths of the hidden masked layers.
    """
    bijector_fn: Any = make_shift_and_scale
    n_params: Any = 2
    # Annotated so that Flax registers it as a module field (and it can be set
    # through the constructor); a tuple avoids a shared mutable default.
    hidden_dims: Any = (8, 8)

    @compact
    def __call__(self, y):
        """Run the masked network on ``y`` and build the bijector.

        Args:
            y: Array whose last axis is the event dimension.

        Returns:
            Whatever ``bijector_fn`` returns for the parameter array of shape
            ``y.shape[:-1] + (n_inputs, n_params)``.
        """
        n_inputs = y.shape[-1]
        broadcast_dims = y.shape[:-1]

        # NOTE: relies on a private TFP helper; it may change between releases.
        masks = tfb.masked_autoregressive._make_dense_autoregressive_masks(
            self.n_params,
            n_inputs,
            list(self.hidden_dims),
            input_order='left-to-right'
        )

        # Hidden layers use tanh; the final layer is left linear.
        *hidden_masks, output_mask = masks
        h = y
        for mask in hidden_masks:
            h = MaskedDense(features=mask.shape[-1])(h, mask=mask)
            h = jax.nn.tanh(h)
        h = MaskedDense(features=output_mask.shape[-1])(h, mask=output_mask)

        # Unravel the flat output into per-input parameter vectors.
        params = h.reshape(broadcast_dims + (n_inputs, self.n_params))

        return self.bijector_fn(params)

In [55]:
# Instantiate the network with its default settings.
bij = MaskedAutoregressiveBijector()

# NOTE(review): `key` is consumed twice below (to draw `x` and to initialise
# the parameters). JAX recommends splitting keys instead of reusing them; it is
# left as-is here so the recorded outputs further down remain valid.
key = jax.random.PRNGKey(2)
x = jax.random.uniform(key, shape=(2,))

params = bij.init(key, x)

In [56]:
def apply_fn(x):
    """Apply the initialised network `bij` to `x` and return its result."""
    variables = {"params": params['params']}
    return bij.apply(variables, x)

In [57]:
# Build the bijector for `x`, then call it on a length-2 vector of ones.
apply_fn(x)(jnp.ones(2))

DeviceArray([1.        , 0.75378555], dtype=float32)

In [58]:
n_transforms = 4

# NOTE(review): every layer receives the same `apply_fn`, so all layers share
# one set of network parameters and no permutation is inserted between them.
# Confirm this is intended before using the flow for anything beyond a demo.
bijectors = [
    tfb.MaskedAutoregressiveFlow(bijector_fn=apply_fn)
    for _ in range(n_transforms)
]
bijector = tfb.Chain(bijectors)

In [59]:
# Base distribution: a Normal(0, 1) wrapped in `tfd.Sample` with sample_shape
# [2], pushed through the chained bijector.
maf = tfd.TransformedDistribution(
    distribution=tfd.Sample(tfd.Normal(loc=0.0, scale=1.0), sample_shape=[2]),
    bijector=bijector,
)

In [60]:
# Evaluate the log-density on a (4, 2) array of ones (four identical rows).
maf.log_prob(jnp.ones((4,2)))

DeviceArray([-22.546316, -22.546316, -22.546316, -22.546316], dtype=float32)

In [61]:
# Draw 8 samples from the flow using a fixed PRNG seed.
maf.sample(sample_shape=(8,), seed=jax.random.PRNGKey(0))

DeviceArray([[ 0.08482574,  1.3161952 ],
             [ 0.29561743, -0.67382956],
             [ 0.33432344, -2.570653  ],
             [ 0.6481277 , -2.022002  ],
             [-0.7824839 ,  3.252925  ],
             [ 0.6297971 , -2.1548283 ],
             [-0.32787678,  0.72380227],
             [-1.6607416 ,  6.293631  ]], dtype=float32)