In [149]:
import jax
import jax.numpy as jnp
from tensorflow_probability.substrates import jax as tfp
import numpy as np
from flax.linen.module import compact
import flax.linen as nn
from flax.linen.dtypes import promote_dtype
from typing import Any, List
import dataclasses
import distrax

# Loose alias used in annotations for array-valued arguments; deliberately unconstrained (Any).
Array = Any

%load_ext autoreload
%autoreload 2


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [150]:
# Short aliases for the TFP (JAX substrate) distribution and bijector namespaces.
tfd, tfb = tfp.distributions, tfp.bijectors

In [151]:
import sys
from pathlib import Path

# Make the notebook directory's parent importable so that `modules.*` below resolves.
# Resolve to an absolute path once, so a later change of working directory does not
# break imports. The membership guard keeps re-runs of this cell from appending duplicates.
PROJECT_ROOT = str(Path("..").resolve())
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

from modules.autoregressive import MADE
from modules.utils import make_shift_and_scale

In [158]:
class MaskedAutoregressiveFlow(nn.Module):
    """Masked autoregressive flow (MAF) density model built from MADE conditioners.

    For each of ``n_transforms`` layers, the bijector list holds a ``tfb.Permute``
    followed by a ``tfb.MaskedAutoregressiveFlow`` whose ``bijector_fn`` is a MADE
    submodule. The list is wrapped in a ``distrax.Chain``. Calling the module returns
    ``distrax.Transformed(base_dist, chain).log_prob(x)``, where ``base_dist`` is a
    zero-mean, unit-scale diagonal Gaussian in ``n_dim`` dimensions.

    Attributes:
        n_dim: Event size of the modelled variable; also passed to MADE as ``n_params``.
        n_context: Forwarded to each MADE as ``n_context``.
            NOTE(review): ``__call__`` takes no context argument, so conditioning
            inputs cannot currently be supplied through this module. Confirm intent.
        n_transforms: Number of (permutation, MAF) layer pairs.
        hidden_dims: Hidden layer widths forwarded to each MADE.
        activation: Activation name forwarded to each MADE.
        use_random_permutations: If True, each layer's permutation is drawn from
            ``rng_key``; otherwise every layer reverses the dimension order.
        rng_key: PRNG key used only to draw the random permutations.
    """

    n_dim: int
    n_context: int = 0
    n_transforms: int = 1
    hidden_dims: List[int] = dataclasses.field(default_factory=lambda: [32, 32])
    activation: str = "tanh"
    use_random_permutations: bool = False
    rng_key: jnp.ndarray = jax.random.PRNGKey(0)

    def setup(self):
        # One MADE conditioner per layer, named made_0, made_1, ... (these names key the params).
        self.bij = [
            MADE(
                bijector_fn=make_shift_and_scale,
                n_params=self.n_dim,
                n_context=self.n_context,
                activation=self.activation,
                hidden_dims=self.hidden_dims,
                name=f"made_{i}",
            )
            for i in range(self.n_transforms)
        ]

        # Build the chain with a plain Python loop: JAX transforms and Flax
        # submodules cannot be mixed here, so the layers are unrolled explicitly.
        bijectors = []
        key = self.rng_key
        for made in self.bij:
            if self.use_random_permutations:
                # Split *before* drawing so that no key is ever consumed twice
                # (JAX keys must not be reused after being passed to a sampler).
                key, perm_key = jax.random.split(key)
                permutation = jax.random.permutation(perm_key, self.n_dim)
            else:
                permutation = list(reversed(range(self.n_dim)))
            bijectors.append(tfb.Permute(permutation))
            bijectors.append(tfb.MaskedAutoregressiveFlow(bijector_fn=made))

        self.bijector: distrax.Bijector = distrax.Chain(bijectors)

    def make_flow_model(self):
        """Return ``(bijector, base_distribution)`` for this flow.

        The base distribution is ``distrax.MultivariateNormalDiag`` with zero
        location and unit scale, both of length ``n_dim``.
        """
        flow = self.bijector
        base_dist = distrax.MultivariateNormalDiag(jnp.zeros(self.n_dim), jnp.ones(self.n_dim))
        return flow, base_dist

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        """Return the flow log-density of ``x``.

        Args:
            x: Inputs whose last axis presumably has size ``n_dim``; leading axes are
                presumably batch dimensions. TODO confirm.

        Returns:
            ``distrax.Transformed(base_dist, flow).log_prob(x)``.
        """
        flow, base_dist = self.make_flow_model()
        return distrax.Transformed(base_dist, flow).log_prob(x)

In [159]:
# Small smoke test: initialise the flow on random dummy inputs.
N_DIM = 5        # event size of the flow
N_SAMPLES = 8    # number of dummy inputs used for initialisation

maf = MaskedAutoregressiveFlow(n_dim=N_DIM, n_context=0)

# Derive separate keys for the dummy data and for parameter initialisation;
# JAX PRNG keys must not be reused across independent random operations.
key = jax.random.PRNGKey(2)
data_key, init_key = jax.random.split(key)
x = jax.random.uniform(key=data_key, shape=(N_SAMPLES, N_DIM))

# NOTE(review): the recorded run of this cell raised
# `TypeError: log_prob() got an unexpected keyword argument 'context'`.
# The raising frame is not visible here. Presumably it lies in
# modules.autoregressive / modules.utils. Investigate before relying on `params`.
params = maf.init(init_key, x)

TypeError: log_prob() got an unexpected keyword argument 'context'