The aim of this notebook is to understand the three different shapes used in TFP (TensorFlow Probability): the **event** shape, which describes the "dimensionality" of $x$ in $p(x)$; the **batch** shape, which describes a batch of independent distributions; and the **sample** shape, which describes the samples drawn from those distributions.

### Links
See also:  

* https://github.com/tensorflow/probability/blob/master/tensorflow_probability/examples/jupyter_notebooks/Understanding_TensorFlow_Distributions_Shapes.ipynb
* https://www.tensorflow.org/probability/api_docs/python/tfp/distributions/Distribution



In [83]:
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

import tensorflow as tf
import tensorflow_probability as tfp

sns.reset_defaults()
#sns.set_style('whitegrid')
sns.set_context(context='talk',font_scale=0.7)

%matplotlib inline

tfd = tfp.distributions
print("TFP Version", tfp.__version__)
print("TF  Version",tf.__version__)

TFP Version 0.7.0-dev
TF  Version 2.0.0-alpha0


### Event Shape
    
The shape of a single draw (event) from a distribution, e.g. the shape of $x$ in $p(x)$. For a scalar (one-dimensional) $x$ the event shape is empty: `event_shape=()`.

In [84]:
n1 = tfd.Uniform(1,2)
n1

<tfp.distributions.Uniform 'Uniform/' batch_shape=() event_shape=() dtype=float32>

In [85]:
multivariate_normal = tfd.MultivariateNormalDiag(loc=[1., 2., 3.], scale_identity_multiplier=[1.])
multivariate_normal,multivariate_normal.mean()


(<tfp.distributions.MultivariateNormalDiag 'MultivariateNormalDiag/' batch_shape=(1,) event_shape=(3,) dtype=float32>,
 <tf.Tensor: id=1931, shape=(1, 3), dtype=float32, numpy=array([[1., 2., 3.]], dtype=float32)>)

### Batch Shape
The batch shape describes a collection of independent, not identically distributed distributions, aka a "batch" of distributions, each with its own parameters.

In [86]:
n2 = tfd.Uniform([1.,2.],2) #We have 2 independent distributions
n2

<tfp.distributions.Uniform 'Uniform/' batch_shape=(2,) event_shape=() dtype=float32>

In [87]:
two_multivariate_normals = tfd.MultivariateNormalDiag(loc=[[1., 2., 3.],[1., 2., 3.]], scale_identity_multiplier=[1.,2.])
print(two_multivariate_normals)

# Using broadcasting
two_multivariate_normals2 = tfd.MultivariateNormalDiag(loc=[1., 2., 3.], scale_identity_multiplier=[1.,2])
two_multivariate_normals2.mean(),two_multivariate_normals.mean()


tfp.distributions.MultivariateNormalDiag("MultivariateNormalDiag/", batch_shape=(2,), event_shape=(3,), dtype=float32)


(<tf.Tensor: id=2038, shape=(2, 3), dtype=float32, numpy=
 array([[1., 2., 3.],
        [1., 2., 3.]], dtype=float32)>,
 <tf.Tensor: id=1948, shape=(2, 3), dtype=float32, numpy=
 array([[1., 2., 3.],
        [1., 2., 3.]], dtype=float32)>)

### Sample Shape

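We draw iid samples from the distribution(s) described by the batch and event shapes. The sample shape comes first in the result, followed by the batch shape and then the event shape.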

In [88]:
n1.sample([1,3]).numpy(),n1.sample([1,3]).shape

(array([[1.6436319, 1.4246452, 1.2373348]], dtype=float32),
 TensorShape([1, 3]))

In [89]:
n2.sample([1,3]).shape, n2

(TensorShape([1, 3, 2]),
 <tfp.distributions.Uniform 'Uniform/' batch_shape=(2,) event_shape=() dtype=float32>)
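The shapes returned by `sample` above follow a simple concatenation rule: sample shape, then batch shape, then event shape. Here is a small plain-Python sketch of that bookkeeping; `sampled_shape` is a hypothetical helper written for this note and is not part of TFP.

```python
def sampled_shape(sample_shape, batch_shape, event_shape):
    """Shape expected from dist.sample(sample_shape) (hypothetical helper)."""
    return tuple(sample_shape) + tuple(batch_shape) + tuple(event_shape)


# n1: batch_shape=(), event_shape=()
print(sampled_shape([1, 3], (), ()))      # (1, 3)
# n2: batch_shape=(2,), event_shape=()
print(sampled_shape([1, 3], (2,), ()))    # (1, 3, 2)
# two_multivariate_normals: batch_shape=(2,), event_shape=(3,)
print(sampled_shape([1, 3], (2,), (3,)))  # (1, 3, 2, 3)
```

The first two results match the `TensorShape` values printed for `n1` and `n2`; the third is what the same rule predicts for a batch of two three-dimensional normals.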

## Manipulating the shape

### Independent Distributions
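The `Independent` distribution is used to treat a collection of independent, not-necessarily-identical distributions (aka a batch of distributions) as a single distribution. A typical use case, in my opinion, is a set of data points: each data point has its own distribution, and we want one joint log-probability for all of them.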

In [100]:
y = [1.1,2.2,3.3]
y_true = np.asarray([1.,2.,3.],dtype=np.float32)

In [101]:
n3 = tfd.Normal(loc=y,scale=0.5)

In [115]:
n3.log_prob(y_true) # Three numbers, one log-probability per batch member

<tf.Tensor: id=2186, shape=(3,), dtype=float32, numpy=array([-0.24579135, -0.30579138, -0.40579128], dtype=float32)>

Reinterpreting the batch dimensions: if we want a (single) batch dimension to be interpreted as an event dimension, we set `reinterpreted_batch_ndims=1`. If there is more than one batch dimension, then: "The reinterpreted_batch_ndims parameter controls the number of batch dims which are absorbed as event dims". Taken from: https://github.com/tensorflow/probability/blob/master/tensorflow_probability/g3doc/api_docs/python/tfp/distributions/Independent.md



In [116]:
n3e = tfd.Independent(
    tfd.Normal(loc=y,scale=0.5),
    reinterpreted_batch_ndims=1 
)
n3e

<tfp.distributions.Independent 'IndependentNormal/' batch_shape=() event_shape=(3,) dtype=float32>
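To make the shape bookkeeping concrete, here is a minimal plain-Python sketch of how `reinterpreted_batch_ndims` moves dimensions from the batch shape to the event shape. It assumes that the rightmost batch dimensions are the ones absorbed; `reinterpret_batch_dims` is a hypothetical helper for illustration, not a TFP function.

```python
def reinterpret_batch_dims(batch_shape, event_shape, reinterpreted_batch_ndims):
    """Return (new_batch_shape, new_event_shape) (hypothetical helper)."""
    batch_shape, event_shape = tuple(batch_shape), tuple(event_shape)
    n = reinterpreted_batch_ndims
    if not 0 <= n <= len(batch_shape):
        raise ValueError("reinterpreted_batch_ndims exceeds the batch rank")
    # The rightmost n batch dims become the leading event dims.
    split = len(batch_shape) - n
    return batch_shape[:split], batch_shape[split:] + event_shape


# n3 -> n3e: batch_shape=(3,), event_shape=() with one reinterpreted dim
print(reinterpret_batch_dims((3,), (), 1))    # ((), (3,))
# A batch with two dimensions, absorbing one or both of them
print(reinterpret_batch_dims((2, 3), (), 1))  # ((2,), (3,))
print(reinterpret_batch_dims((2, 3), (), 2))  # ((), (2, 3))
```

The first call reproduces the `batch_shape=() event_shape=(3,)` shown for `n3e`.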

In [119]:
n3e.log_prob(y_true) # A single number: the sum of all three log-likelihoods from above

<tf.Tensor: id=2230, shape=(), dtype=float32, numpy=-0.957374>
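The relationship between the two `log_prob` results can be checked by hand. The sketch below recomputes the normal log-density with the standard library only (no TFP) and sums the three per-point values; `normal_log_prob` is a hypothetical helper name, and the computation runs in float64 rather than float32, so the last digits may differ slightly from the outputs above.

```python
import math


def normal_log_prob(x, loc, scale):
    """Log-density of a univariate normal distribution at x."""
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2.0 * math.pi)


y = [1.1, 2.2, 3.3]       # locations, as in the notebook
y_true = [1.0, 2.0, 3.0]  # observed values

# One log-probability per batch member, as returned by n3.log_prob(y_true)
per_point = [normal_log_prob(x, loc, 0.5) for x, loc in zip(y_true, y)]
# Joint log-probability, as returned by n3e.log_prob(y_true)
joint = sum(per_point)

print(per_point)
print(joint)
```

The sum agrees with the `-0.957374` reported for `n3e` up to floating-point precision.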