## This notebook illustrates how the argument to `prob` / `log_prob` is broadcast against the batch and event shapes of a `tfd.Distribution`

See "Broadcasting, aka Why Is This So Confusing?" at: https://github.com/tensorflow/probability/blob/master/tensorflow_probability/examples/jupyter_notebooks/TensorFlow_Distributions_Tutorial.ipynb

See the graphical depiction of broadcasting rules at: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html#general-broadcasting-rules

More examples at: 
https://github.com/tensorflow/probability/blob/master/tensorflow_probability/examples/jupyter_notebooks/Understanding_TensorFlow_Distributions_Shapes.ipynb

In [1]:
import tensorflow as tf

# Eager mode makes `prob` / `log_prob` return concrete values that can be
# printed directly. TF 2.x is already eager and no longer exports
# `enable_eager_execution` at the top level, so opt in through the compat
# alias, and only when eager execution is not active yet.
if not tf.executing_eagerly():
    tf.compat.v1.enable_eager_execution()


In [2]:
import numpy as np
import tensorflow_probability as tfp

tfd = tfp.distributions

# Two scalar Normals held in one object: batch_shape=(2,), event_shape=().
n = tfd.Normal(scale=[1., 2], loc=[0., 1])
print(n)

# One evaluation point per batch member: an argument of shape (2,) pairs up
# element-wise with the batch shape (2,), so the result has shape (2,).
arg = np.asarray([0., 2.])
print('arg_shape={}'.format(np.shape(arg)))
print(n.prob(arg))


For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
If you depend on functionality not listed there, please file an issue.

tfp.distributions.Normal("Normal/", batch_shape=(2,), event_shape=(), dtype=float32)
arg_shape=(2,)
tf.Tensor([0.3989423  0.17603266], shape=(2,), dtype=float32)


In [3]:
# The same two Normals, now arranged as a column: batch_shape=(2, 1).
n = tfd.Normal(scale=[[1.], [2]], loc=[[0.], [1]])
print(n)

# The argument has exactly the batch shape (2, 1), so no broadcasting is
# needed and the result is (2, 1) as well.
arg = np.asarray([[0.], [2.]])
print('arg_shape={}'.format(np.shape(arg)))
print(n.prob(arg))

tfp.distributions.Normal("Normal/", batch_shape=(2, 1), event_shape=(), dtype=float32)
arg_shape=(2, 1)
tf.Tensor(
[[0.3989423 ]
 [0.17603266]], shape=(2, 1), dtype=float32)


In [4]:
# Column of two Normals: batch_shape=(2, 1).
n = tfd.Normal(scale=[[1.], [2]], loc=[[0.], [1]])
print(n)

# A (3,) argument against a (2, 1) batch broadcasts to (2, 3):
# row i holds the density of Normal i at each of the three points.
arg = np.asarray([0., 2., 4.])
print('arg_shape={}'.format(np.shape(arg)))
print(n.prob(arg))

tfp.distributions.Normal("Normal/", batch_shape=(2, 1), event_shape=(), dtype=float32)
arg_shape=(3,)
tf.Tensor(
[[3.9894229e-01 5.3990960e-02 1.3383021e-04]
 [1.7603266e-01 1.7603266e-01 6.4758793e-02]], shape=(2, 3), dtype=float32)


In [5]:
# Two Normals laid out along a single axis: batch_shape=(2,).
n = tfd.Normal(scale=[1., 2], loc=[0., 1])
print(n)

# A (3, 1) argument against a (2,) batch broadcasts to (3, 2):
# row j is evaluation point j, column i is Normal i.
arg = np.asarray([[0.], [2.], [4.]])
print('arg_shape={}'.format(np.shape(arg)))
print(n.prob(arg))

tfp.distributions.Normal("Normal/", batch_shape=(2,), event_shape=(), dtype=float32)
arg_shape=(3, 1)
tf.Tensor(
[[3.9894229e-01 1.7603266e-01]
 [5.3990960e-02 1.7603266e-01]
 [1.3383021e-04 6.4758793e-02]], shape=(3, 2), dtype=float32)


In [6]:
from scipy.stats import norm

# Parameters of the three Normals, defined once so the TFP distribution and
# the SciPy reference below are guaranteed to use the same values.
locs = [0., 1., 2.]
scales = [1., 2., 3.]

# Column of three scalar Normals: batch_shape=(3, 1), event_shape=().
# Built from plain Python lists, as before, so the printed dtype stays float32.
n = tfd.Normal(loc=[[m] for m in locs], scale=[[s] for s in scales])
print(n)

# Independent reinterprets the rightmost batch dimension as an event dimension:
# batch_shape (3, 1) / event_shape () becomes batch_shape (3,) / event_shape (1,).
n_shape_shift = tfd.Independent(
    distribution=n,
    reinterpreted_batch_ndims=1
)
print(n_shape_shift)

# Shape (3, 1): one length-1 event per batch member, so the result is (3,).
arg = np.array([[0.], [2.], [4.]])
print('arg={}'.format(arg))
shifted_prob = n_shape_shift.prob(arg)
print('n_shape_shift.prob(arg): {}'.format(shifted_prob))

# SciPy reference: one scalar pdf per (point, loc, scale) triple.
reference = [
    norm.pdf(x, loc=m, scale=s)
    for x, m, s in zip(arg.ravel(), locs, scales)
]
for value in reference:
    print(value)

# Verify the agreement numerically instead of comparing printed digits by eye.
# TFP computes in float32 while SciPy uses float64, hence the loose tolerance.
np.testing.assert_allclose(shifted_prob.numpy(), reference, rtol=1e-5)

tfp.distributions.Normal("Normal/", batch_shape=(3, 1), event_shape=(), dtype=float32)
tfp.distributions.Independent("IndependentNormal/", batch_shape=(3,), event_shape=(1,), dtype=float32)
arg=[[0.]
 [2.]
 [4.]]
n_shape_shift.prob(arg): [0.3989423  0.17603266 0.10648265]
0.3989422804014327
0.17603266338214976
0.10648266850745075


## Computing probabilities for a batch of distributions at several sample locations

See: https://nbviewer.jupyter.org/github/tensorflow/probability/blob/master/tensorflow_probability/examples/jupyter_notebooks/Understanding_TensorFlow_Distributions_Shapes.ipynb#Multivariate-distributions


In [24]:
# One shared 3-element loc combined with two scale multipliers gives a batch
# of two 3-dimensional distributions: batch_shape=(2,), event_shape=(3,).
two_multivariate_normals = tfd.MultivariateNormalDiag(
    scale_identity_multiplier=[1., 2.],
    loc=[1., 2., 3.],
)
print('distribution: {}'.format(two_multivariate_normals))

# Four sample points, each wrapped in a singleton axis: shape (4, 1, 3).
# The last axis is consumed as the event; the remaining (4, 1) broadcasts
# against the batch shape (2,), so log_prob has shape (4, 2).
arg = np.asarray([
    [[1., 2., 3.]],
    [[3., 4., 5.]],
    [[6., 7., 8.]],
    [[9., 10., 11.]],
])
print('arg.shape: {}'.format(np.shape(arg)))
print('log_prob.shape: {}'.format(two_multivariate_normals.log_prob(arg).shape))



distribution: tfp.distributions.MultivariateNormalDiag("MultivariateNormalDiag/", batch_shape=(2,), event_shape=(3,), dtype=float32)
arg.shape: (4, 1, 3)
log_prob.shape: (4, 2)


In [23]:
# Imported in this cell; a later cell uses `tf.newaxis`, so the import must stay.
import tensorflow as tf

# Start from a plain (4, 3) stack of sample points ...
arg = np.asarray([
    [1., 2., 3.],
    [3., 4., 5.],
    [6., 7., 8.],
    [9., 10., 11.],
])
print('arg.shape: {}'.format(np.shape(arg)))

# ... and insert a singleton axis in the middle to obtain (4, 1, 3).
arg = np.expand_dims(arg, axis=1)
print('arg.shape: {}'.format(np.shape(arg)))


arg.shape: (4, 3)
arg.shape: (4, 1, 3)


In [31]:
# A length-1 loc with two scale multipliers: batch_shape=(2,), event_shape=(1,).
two_multivariate_normals = tfd.MultivariateNormalDiag(
    scale_identity_multiplier=[1., 2.],
    loc=[1.],
)
print('distribution: {}'.format(two_multivariate_normals))

# Three scalar points with two trailing singleton axes: shape (3, 1, 1).
# The last axis is the length-1 event; (3, 1) broadcasts against the batch
# shape (2,), so log_prob has shape (3, 2).
arg = np.array([1., 2., 3.])[..., tf.newaxis, tf.newaxis]
print('arg.shape: {}'.format(np.shape(arg)))
print('log_prob.shape: {}'.format(two_multivariate_normals.log_prob(arg).shape))


distribution: tfp.distributions.MultivariateNormalDiag("MultivariateNormalDiag/", batch_shape=(2,), event_shape=(1,), dtype=float32)
arg.shape: (3, 1, 1)
log_prob.shape: (3, 2)


In [34]:
# Show the actual values behind the (3, 2) result: row i is sample point i,
# column j is its log-density under batch member j.
print('arg: {}'.format(arg))
log_prob_values = two_multivariate_normals.log_prob(arg)
print('two_multivariate_normals.log_prob(arg): {}'.format(log_prob_values))


arg: [[[1.]]

 [[2.]]

 [[3.]]]
two_multivariate_normals.log_prob(arg): [[-0.9189385 -1.6120857]
 [-1.4189385 -1.7370857]
 [-2.9189386 -2.1120858]]


In [42]:
# Scalar Normal with the parameters of batch member 0 (loc=1, scale=1).
# Evaluated at the same three points, flattened to shape (3,), it should
# reproduce the first column of the (3, 2) log_prob printed earlier.
first_batch = tfd.Normal(scale=1., loc=1.)
print('first_batch: {}'.format(first_batch))
print(first_batch.log_prob(np.squeeze(arg)))

first_batch: tfp.distributions.Normal("Normal/", batch_shape=(), event_shape=(), dtype=float32)
tf.Tensor([-0.9189385 -1.4189385 -2.9189386], shape=(3,), dtype=float32)


In [43]:
# Scalar Normal with the parameters of batch member 1 (loc=1, scale=2).
# Evaluated at the same three points, flattened to shape (3,), it should
# reproduce the second column of the (3, 2) log_prob printed earlier.
second_batch = tfd.Normal(scale=2., loc=1.)
print('second_batch: {}'.format(second_batch))
print(second_batch.log_prob(np.squeeze(arg)))

second_batch: tfp.distributions.Normal("Normal/", batch_shape=(), event_shape=(), dtype=float32)
tf.Tensor([-1.6120857 -1.7370857 -2.1120858], shape=(3,), dtype=float32)
