In [2]:
import distrax
import jax
import jax.numpy as jnp


In [6]:
key = jax.random.PRNGKey(1234)
mu = jnp.array([-1., 0., 1.])
sigma = jnp.array([0.1, 0.2, 0.3])

normal_dist = distrax.MultivariateNormalDiag(mu, sigma)
samples, log_proba = normal_dist.sample_and_log_prob(seed=key)
samples, log_proba

(DeviceArray([-1.0019782 , -0.01461947,  0.6765263 ], dtype=float32),
 DeviceArray(1.7750063, dtype=float32))

# Understanding some terminologies...

This step is important since distrax is inspired by the TensorFlow Probability (tfp) library. We borrow these notes from tfp or distrax docs.

## 1. distribution shapes

There are three important concepts associated with TensorFlow Distributions shapes:

* Event shape describes the shape of a single draw from the distribution; it may be dependent across dimensions. For scalar distributions, the event shape is `[]`. For a 5-dimensional MultivariateNormal, the event shape is `[5]`.
* Batch shape describes independent, not identically distributed draws, aka a "batch" of distributions.
* Sample shape describes independent, identically distributed draws of batches from the distribution family.

The event shape and the batch shape are properties of a Distribution object, whereas the sample shape is associated with a specific call to sample or log_prob.

Some experiments now...

In [25]:
uniform_distributions = [
    distrax.Uniform(low=0., high=1.),
    distrax.Uniform(low=[0., 0., 0.], high=[1., 1., 1.]),
    distrax.Uniform(low=jnp.zeros((2,3)), high=jnp.ones((2,3))),
    distrax.Uniform(low=[0.], high=[1.]),
    distrax.Uniform(low=[[0.]], high=[[1.]]),
    distrax.Uniform(low=0., high=jnp.ones((2, 2)))
] # event shape is [], batch shape decided by low, high

In [22]:
uniform_distributions[0].sample(seed=key) # one uniform scalar batch

DeviceArray(0.29453015, dtype=float32)

In [17]:
uniform_distributions[1].sample(seed=key) # three uniforms

DeviceArray([0.49210894, 0.4708643 , 0.14046204], dtype=float32)

In [18]:
uniform_distributions[2].sample(seed=key) # 2 by 3 uniforms

DeviceArray([[0.8177053 , 0.17224324, 0.24385035],
             [0.03261805, 0.6770656 , 0.8112081 ]], dtype=float32)

In [20]:
uniform_distributions[3].sample(seed=key) # vector batch

DeviceArray([0.29453015], dtype=float32)

In [23]:
uniform_distributions[4].sample(seed=key) # expanded batch

DeviceArray([[0.29453015]], dtype=float32)

In [26]:
uniform_distributions[5].sample(seed=key) # parameter broadcast

DeviceArray([[0.49210894, 0.44287562],
             [0.14046204, 0.10368097]], dtype=float32)

The basic rule is that when we sample from a distribution, the resulting Tensor has shape `[sample_shape, batch_shape, event_shape]`, where batch_shape and event_shape are provided by the Distribution object, and sample_shape is provided by the call to sample. For scalar distributions, event_shape = `[]`, so the Tensor returned from sample will have shape `[sample_shape, batch_shape]`.

In [31]:
def describe_sample_tensor_shape(sample_shape, distribution):
    print('Sample shape:', sample_shape)
    print('Returned sample tensor shape:',
          distribution.sample(sample_shape=sample_shape, seed=key).shape)

In [32]:
describe_sample_tensor_shape((3,2), uniform_distributions[0])

Sample shape: (3, 2)
Returned sample tensor shape: (3, 2)


In [35]:
def describe_sample_tensor_shapes(distributions, sample_shapes):
    started = False
    for distribution in distributions:
      print(distribution)
      for sample_shape in sample_shapes:
        describe_sample_tensor_shape(sample_shape, distribution)

sample_shapes = [(), 1, 2, [1, 5], [3, 4, 5]]
describe_sample_tensor_shapes(poisson_distributions, sample_shapes)

<distrax._src.distributions.uniform.Uniform object at 0x13c271090>
Sample shape: ()
Returned sample tensor shape: (1,)
Sample shape: 1
Returned sample tensor shape: (1, 1)
Sample shape: 2
Returned sample tensor shape: (2, 1)
Sample shape: [1, 5]
Returned sample tensor shape: (1, 5, 1)
Sample shape: [3, 4, 5]
Returned sample tensor shape: (3, 4, 5, 1)
<distrax._src.distributions.uniform.Uniform object at 0x13c159d20>
Sample shape: ()
Returned sample tensor shape: (3,)
Sample shape: 1
Returned sample tensor shape: (1, 3)
Sample shape: 2
Returned sample tensor shape: (2, 3)
Sample shape: [1, 5]
Returned sample tensor shape: (1, 5, 3)
Sample shape: [3, 4, 5]
Returned sample tensor shape: (3, 4, 5, 3)
<distrax._src.distributions.uniform.Uniform object at 0x160367e50>
Sample shape: ()
Returned sample tensor shape: (2, 3)
Sample shape: 1
Returned sample tensor shape: (1, 2, 3)
Sample shape: 2
Returned sample tensor shape: (2, 2, 3)
Sample shape: [1, 5]
Returned sample tensor shape: (1, 5, 2, 

We note that in case where sample and/or event shape is (), the corresponding slot in the overall Tensor shape will be ignored.

Now let's take a look at log_prob, which is somewhat trickier. log_prob takes as input a (non-empty) tensor representing the location(s) at which to compute the log_prob for the distribution. In the most straightforward case, this tensor will have a shape of the form `[sample_shape, batch_shape, event_shape]`, where batch_shape and event_shape match the batch and event shapes of the distribution. Recall once more that for scalar distributions, event_shape = `[]`, so the input tensor has shape `[sample_shape, batch_shape]`. In this case, we get back a tensor of shape `[sample_shape, batch_shape]`:

In [42]:
_, log_proba = uniform_distributions[0].sample_and_log_prob(sample_shape=(), seed=key)
log_proba.shape

()

In [43]:
_, log_proba = uniform_distributions[0].sample_and_log_prob(sample_shape=(1), seed=key)
log_proba.shape

(1,)

In [44]:
_, log_proba = uniform_distributions[0].sample_and_log_prob(sample_shape=(1,2,3), seed=key)
log_proba.shape

(1, 2, 3)

In [45]:
_, log_proba = uniform_distributions[1].sample_and_log_prob(sample_shape=(1), seed=key)
log_proba.shape

(1, 3)

okay, that's straightforward enough! Now onto reading about the MaskedCoupling layer in the distrax library.

The masking layer has forward operation of 
$$
y = (1-m) \cdot f(x; g(m\cdot x)) + m\cdot x,
$$
where $m$ is the binary mask array of the same dimension as $x$, $\cdot$ is the elementwise multiplication op, and $g$ is the conditioner function that transforms on the masked input and produces output to be conditioned on in the inner bijector $f$. 

In the RealNVP model, the conditioner is typically a neural network. For simplicity we use an MLP here:

In [49]:
from typing import Sequence
Array = jnp.array
import haiku as hk
import numpy as np

    
def make_conditioner(event_shape: Sequence[int],
                     hidden_sizes: Sequence[int], 
                     num_bijector_params: int) -> hk.Sequential:
    return hk.Sequential([
        hk.Flatten(preserve_dims=-len(event_shape)), # so flatten all event dimensions
        hk.nets.MLP(hidden_sizes, activate_final=True), # core MLP
        # final projection, set weight=0 to start from the identity flow
        hk.Linear(
          np.prod(event_shape) * num_bijector_params,
          w_init=jnp.zeros,
          b_init=jnp.zeros),
        hk.Reshape(tuple(event_shape) + (num_bijector_params,), preserve_dims=-1), # unflatten the last dim to event shape
    ])

# affine bijector as used in RealNVP
def bijector_fn(params: Array):
    shift, log_scale = params[..., 0], params[..., 1]
    return distrax.ScalarAffine(shift=shift,
                               log_scale=log_scale)

def make_flow(num_layers = 5,
              event_shape = [2],
              hidden_sizes = [4, 4, event_shape[0]*2]) -> distrax.Transformed:

    # Alternating binary mask.
    mask = jnp.arange(0, np.prod(event_shape)) % 2
    mask = jnp.reshape(mask, event_shape)
    mask = mask.astype(bool)

    layers = []
    for _ in range(num_layers):
        layer = distrax.MaskedCoupling(
            mask=mask,
            bijector=bijector_fn,
            conditioner=make_conditioner(event_shape, hidden_sizes, 2)
        )
        layers.append(layer)
        mask = jnp.logical_not(mask) 
    
        flow = distrax.Inverse(distrax.Chain(layers))
        base_dist = distrax.MultivariateNormalDiag(loc=jnp.zeros(shape=event_shape),
                                                   scale_diag=jnp.ones(shape=event_shape))
        
        return distrax.Transformed(base_dist, flow)
