# TensorFlow Distributions Shapes

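This notebook is mostly copied from the TensorFlow Probability tutorial "Understanding TensorFlow Distributions Shapes":
https://github.com/tensorflow/probability/blob/master/tensorflow_probability/examples/jupyter_notebooks/Understanding_TensorFlow_Distributions_Shapes.ipynb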

In [1]:
import os

# Select the GPU through environment variables before TensorFlow is imported
# in the next cell: enumerate devices in PCI bus order and expose only the
# device with index 7 (a machine-specific choice).
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "7",
})

In [2]:
import tensorflow as tf
# Turn on eager execution right after importing TensorFlow, before any ops are
# created. The public top-level API is used instead of the `tf.contrib.eager`
# alias because `tf.contrib` was removed in TensorFlow 2.x. The guard skips the
# call when eager mode is already active (for example when it is the default).
if not tf.executing_eagerly():
    tf.enable_eager_execution()
import tensorflow_probability as tfp
tfd = tfp.distributions  # short alias for the distributions module
tfb = tfp.bijectors      # short alias for the bijectors module

import matplotlib.pyplot as plt

# Record the library versions this notebook was run with.
print('TensorFlow version {}, TF Probability version {}.'.format(tf.__version__, tfp.__version__))

TensorFlow version 1.10.0, TF Probability version 0.4.0.


In [3]:
# Five Poisson distributions that differ only in the shape of `rate`.
# A Poisson event is a scalar (event_shape=()), so the shape of `rate`
# shows up as the batch shape of the distribution.
poisson_distributions = [
    # scalar rate -> batch_shape=()
    tfd.Poisson(name='One Poisson Scalar Batch', rate=1.),
    # vector of three rates -> batch_shape=(3,)
    tfd.Poisson(name='Three Poissons', rate=[1., 10., 100.]),
    # 2x3 matrix of rates -> batch_shape=(2, 3)
    tfd.Poisson(name='Two-by-Three Poissons',
                rate=[[1., 10., 100.],
                      [2., 20., 200.]]),
    # vector holding a single rate -> batch_shape=(1,)
    tfd.Poisson(name='One Poisson Vector Batch', rate=[1.]),
    # 1x1 matrix holding a single rate -> batch_shape=(1, 1)
    tfd.Poisson(name='One Poisson Expanded Batch', rate=[[1.]]),
]

# Show the string form of every distribution, one per line.
print('\n'.join(map(str, poisson_distributions)))

tf.distributions.Poisson("One Poisson Scalar Batch/", batch_shape=(), event_shape=(), dtype=float32)
tf.distributions.Poisson("Three Poissons/", batch_shape=(3,), event_shape=(), dtype=float32)
tf.distributions.Poisson("Two-by-Three Poissons/", batch_shape=(2, 3), event_shape=(), dtype=float32)
tf.distributions.Poisson("One Poisson Vector Batch/", batch_shape=(1,), event_shape=(), dtype=float32)
tf.distributions.Poisson("One Poisson Expanded Batch/", batch_shape=(1, 1), event_shape=(), dtype=float32)


In [4]:
# A batch of three scalar Poisson distributions, one per rate:
# batch_shape=(3,), event_shape=().
three_poissons = tfd.Poisson(rate=[1., 10., 100.], name='Three Poissons')
three_poissons

<tf.distributions.Poisson 'Three Poissons/' batch_shape=(3,) event_shape=() dtype=float32>

In [5]:
# input has shape (1,); it broadcasts against batch_shape (3,), so the value 10
# is evaluated under each of the three Poissons and the result has shape (3,)
three_poissons.log_prob([10.])

<tf.Tensor: id=33, shape=(3,), dtype=float32, numpy=array([-16.104412 ,  -2.0785599, -69.052704 ], dtype=float32)>

In [6]:
# input has shape (2, 2, 1); the trailing dimension of size 1 broadcasts against
# batch_shape (3,), so each of the four values is evaluated under all three
# Poissons and the result has shape (2, 2, 3)
three_poissons.log_prob([[[1.], [10.]], [[100.], [1000.]]])

<tf.Tensor: id=53, shape=(2, 2, 3), dtype=float32, numpy=
array([[[-1.0000000e+00, -7.6974149e+00, -9.5394829e+01],
        [-1.6104412e+01, -2.0785599e+00, -6.9052704e+01]],

       [[-3.6473938e+02, -1.4348087e+02, -3.2223511e+00],
        [-5.9131279e+03, -3.6195427e+03, -1.4069575e+03]]], dtype=float32)>

In [7]:
# input has shape (2, 2); its last dimension (2) cannot be broadcast against
# batch_shape (3,), so log_prob raises an InvalidArgumentError.
# The error is the point of this cell, so it is caught and displayed instead of
# being left uncaught, which would abort "Restart Kernel -> Run All".
try:
    three_poissons.log_prob([[1., 10.], [100., 1000.]])
except tf.errors.InvalidArgumentError as err:
    print('{}: {}'.format(type(err).__name__, err.message))

InvalidArgumentError: Incompatible shapes: [2,2] vs. [3] [Op:Mul] name: Three Poissons/log_prob/mul/

### tfb.Reshape

Using the bijector *tfb.Reshape*, we can change the event shape of a distribution.

In [8]:
# A single multinomial over six categories with a total count of 1000:
# batch_shape=(), event_shape=(6,), i.e. one event is a length-6 vector.
six_way_multinomial = tfd.Multinomial(total_count=1000., probs=[.3, .25, .2, .15, .08, .02])
six_way_multinomial

<tf.distributions.Multinomial 'Multinomial/' batch_shape=() event_shape=(6,) dtype=float32>

In [9]:
# Wrap the multinomial in a Reshape bijector so that its events have shape
# (2, 3) instead of (6,); the batch shape stays ().
transformed_multinomial = tfd.TransformedDistribution(
    distribution=six_way_multinomial,
    bijector=tfb.Reshape(event_shape_out=[2, 3]))
transformed_multinomial

<tf.distributions.TransformedDistribution 'reshapeMultinomial/' batch_shape=() event_shape=(2, 3) dtype=float32>

In [10]:
# The same six counts (summing to 1000), once as a flat length-6 event and once
# laid out in the (2, 3) event shape of the reshaped distribution.
event = [500., 100., 100., 150., 100., 50.]
event_ = [[500., 100., 100.],
          [150., 100., 50.]]

# Reshaping the event must leave its log-probability unchanged.
assert (six_way_multinomial.log_prob(event).numpy() ==
        transformed_multinomial.log_prob(event_).numpy())

### tfd.Independent

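Using *tfd.Independent*, we can reinterpret the rightmost `reinterpreted_batch_ndims` batch dimensions of a distribution as event dimensions, i.e. move them from batch_shape to event_shape.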

In [11]:
# A 2x5 grid of scalar Bernoulli distributions:
# batch_shape=(2, 5), event_shape=().
two_by_five_bernoulli = tfd.Bernoulli(
    probs=[[.05, .1, .15, .2, .25], [.3, .35, .4, .45, .5]])
# Reinterpret the last batch dimension as an event dimension:
# batch_shape=(2,), event_shape=(5,).
two_sets_of_five = tfd.Independent(
    distribution=two_by_five_bernoulli,
    reinterpreted_batch_ndims=1)
# Display both distributions to compare their shapes.
two_by_five_bernoulli, two_sets_of_five

(<tf.distributions.Bernoulli 'Bernoulli/' batch_shape=(2, 5) event_shape=() dtype=int32>,
 <tf.distributions.Independent 'IndependentBernoulli/' batch_shape=(2,) event_shape=(5,) dtype=int32>)

In [12]:
# One 0/1 outcome for each of the 2x5 Bernoullis.
event = [[1., 0., 0., 1., 0.], [0., 0., 1., 1., 1.]]
# Per-Bernoulli log-probabilities, shape (2, 5).
print(two_by_five_bernoulli.log_prob(event))
# With Independent, the five log-probabilities in each row are summed into a
# single log-probability per row, shape (2,).
print(two_sets_of_five.log_prob(event))

tf.Tensor(
[[-2.9957323  -0.10536052 -0.16251893 -1.609438   -0.28768206]
 [-0.35667494 -0.43078288 -0.9162907  -0.79850775 -0.6931472 ]], shape=(2, 5), dtype=float32)
tf.Tensor([-5.160732  -3.1954036], shape=(2,), dtype=float32)
