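## This notebook illustrates how the argument to `prob` / `log_prob` is broadcast against the shape of a `tfd.Distribution`

The cells below call `prob`, which follows the same shape rules as `log_prob`. For these scalar distributions (`event_shape=()`), the result's shape is the NumPy broadcast of the argument's shape with the distribution's `batch_shape`. For example, an argument of shape `(3,)` against `batch_shape=(2, 1)` gives a result of shape `(2, 3)`.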

See the section "Broadcasting, aka Why Is This So Confusing?" in the TensorFlow Distributions tutorial: https://github.com/tensorflow/probability/blob/master/tensorflow_probability/examples/jupyter_notebooks/TensorFlow_Distributions_Tutorial.ipynb

See the graphical depiction of broadcasting rules at: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html#general-broadcasting-rules

More shape examples are in the "Understanding TensorFlow Distributions Shapes" notebook:
https://github.com/tensorflow/probability/blob/master/tensorflow_probability/examples/jupyter_notebooks/Understanding_TensorFlow_Distributions_Shapes.ipynb

In [7]:
# Imports for the whole notebook: SciPy's `norm` supplies reference values;
# `np` and `tfd` are used by the TFP broadcasting examples in later cells.
# NOTE(review): the `tf.Tensor(...)` outputs below imply eager execution
# (the TF2 default) -- confirm if this is run under TF1.
import numpy as np
import tensorflow_probability as tfp
from scipy.stats import norm

tfd = tfp.distributions

# SciPy reference densities of N(loc=1, scale=2) at x = 0, 2, 4. They should
# equal the values for the loc=1, scale=2 batch member in the TFP cells below.
print(norm.pdf(0, loc=1, scale=2))
print(norm.pdf(2, loc=1, scale=2))
print(norm.pdf(4, loc=1, scale=2))

0.17603266338214976
0.17603266338214976
0.06475879783294587


In [8]:
# Two scalar Normals, N(0, 1) and N(1, 2), giving batch_shape=(2,).
# The argument also has shape (2,), so element i is scored under batch
# member i; nothing needs to be broadcast.
arg = np.asarray([0., 2.])
n = tfd.Normal(loc=[0., 1], scale=[1., 2])
print(n)
print('arg_shape={}'.format(arg.shape))
print(n.prob(arg))

tfp.distributions.Normal("Normal/", batch_shape=(2,), event_shape=(), dtype=float32)
arg_shape=(2,)
tf.Tensor([0.3989423  0.17603266], shape=(2,), dtype=float32)


In [9]:
# The same two Normals laid out as a column: batch_shape=(2, 1).
# A (2, 1) column argument again lines up with the batch one-to-one.
arg = np.array([0., 2.]).reshape(2, 1)
n = tfd.Normal(loc=[[0.], [1]], scale=[[1.], [2]])
print(n)
print('arg_shape={}'.format(arg.shape))
print(n.prob(arg))

tfp.distributions.Normal("Normal/", batch_shape=(2, 1), event_shape=(), dtype=float32)
arg_shape=(2, 1)
tf.Tensor(
[[0.3989423 ]
 [0.17603266]], shape=(2, 1), dtype=float32)


In [10]:
# batch_shape=(2, 1) against an argument of shape (3,): trailing dimensions
# are aligned NumPy-style, so (2, 1) with (3,) broadcasts to (2, 3).
# Row i of the result is batch member i evaluated at every point 0, 2, 4.
arg = np.asarray([0., 2., 4.])
n = tfd.Normal(loc=[[0.], [1]], scale=[[1.], [2]])
print(n)
print('arg_shape={}'.format(arg.shape))
print(n.prob(arg))

tfp.distributions.Normal("Normal/", batch_shape=(2, 1), event_shape=(), dtype=float32)
arg_shape=(3,)
tf.Tensor(
[[3.9894229e-01 5.3990960e-02 1.3383021e-04]
 [1.7603266e-01 1.7603266e-01 6.4758793e-02]], shape=(2, 3), dtype=float32)


In [11]:
# batch_shape=(2,) against a (3, 1) column argument: (3, 1) with (2,)
# broadcasts to (3, 2). Row j of the result is point j (0, 2 or 4)
# evaluated under both batch members.
arg = np.array([0., 2., 4.]).reshape(3, 1)
n = tfd.Normal(loc=[0., 1], scale=[1., 2])
print(n)
print('arg_shape={}'.format(arg.shape))
print(n.prob(arg))

tfp.distributions.Normal("Normal/", batch_shape=(2,), event_shape=(), dtype=float32)
arg_shape=(3, 1)
tf.Tensor(
[[3.9894229e-01 1.7603266e-01]
 [5.3990960e-02 1.7603266e-01]
 [1.3383021e-04 6.4758793e-02]], shape=(3, 2), dtype=float32)
