In [13]:
%matplotlib inline
import tensorflow as tf
from matplotlib import pyplot as plt
import numpy as np
# easier access to tensorflow distributions
ds = tf.contrib.distributions

An advanced inference tool should support some tweaking and introspection by the user, so a composable design is important. An ideal interface would have high-level constructs for the model and the parameters, plus utilities to compute the most relevant statistical quantities. Further down the line, serialization will be important for sharing, importing, and exporting statistical models.

Statistical models can be divided into three types:
1. Those for which we can analytically evaluate the normalized pdf (e.g. a multi-bin Poisson likelihood, a Gaussian, or a bijective transform of a distribution with an analytical pdf).
2. Those for which we can compute the unnormalized pdf but do not know the normalization analytically.
3. Those for which we cannot compute even the unnormalized pdf, because they are complex and/or probabilistic/stochastic.

In HEP, we normally deal with the first type. A realistic statistical model for each event is of the third type, but since we cannot work with that directly we use sample summary statistics or very simplified pdf modelling, both of which correspond to type 1.

I could not think of an example use of type 2 statistical models in HEP, but in principle they are equivalent to type 1 as long as you compute the normalization integral.

Therefore, the initial focus of the tools should be type 1 statistical models. Other tools, such as Edward and Pyro, already deal with probabilistic models.
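The remark that a type 2 model becomes a type 1 model once the normalization integral is computed can be illustrated with a minimal sketch. It assumes a one-dimensional Gaussian kernel as the unnormalized pdf and a plain trapezoidal rule for the integral; the function names (`unnormalized_pdf`, `trapezoid`, `normalized_pdf`) and the integration range are illustrative choices, not part of any library.

```python
import math


def unnormalized_pdf(x):
    # Illustrative assumption: a Gaussian kernel without its 1/sqrt(2*pi) factor.
    return math.exp(-0.5 * x * x)


def trapezoid(f, lo, hi, n=20000):
    # Composite trapezoidal rule on n equal-width intervals over [lo, hi].
    h = (hi - lo) / n
    total = 0.5 * (f(lo) + f(hi))
    for i in range(1, n):
        total += f(lo + i * h)
    return total * h


# Numerical estimate of the normalization integral over a range that
# contains essentially all of the probability mass.
norm = trapezoid(unnormalized_pdf, -10.0, 10.0)


def normalized_pdf(x):
    # Dividing by the integral turns the type 2 model into a type 1 model.
    return unnormalized_pdf(x) / norm


print(norm, math.sqrt(2.0 * math.pi))
```

For this kernel the integral is known to be sqrt(2*pi), which gives a direct check of the numerical result. In more than a few dimensions such a grid-based integral quickly becomes impractical, so this only demonstrates the principle.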

In [16]:
ds.Poisson([2.,2.]).batch_shape

TensorShape([Dimension(2)])

In [18]:
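# reinterpret the batch dimension as an event dimension, so the two
# Poisson bins form a single 2-dimensional event
poisson_ind = ds.Independent(distribution=ds.Poisson([3.,4.]),
                            reinterpreted_batch_ndims=1)
poisson_ind.batch_shape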

TensorShape([])

In [32]:
sess = tf.Session()
sess.run(poisson_ind.sample())

array([3., 6.], dtype=float32)

In [34]:
# each row is one 2-bin observation; one joint log-probability per row
sess.run(poisson_ind.log_prob([[2.,1.],[1.,1.], [2.,1.]]))

array([-4.109628 , -4.5150933, -4.109628 ], dtype=float32)
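As a cross-check of the output above, the joint log-probability of independent Poisson bins can be computed by hand as the sum of the per-bin log-pmfs, log p(k; rate) = k log(rate) - rate - log(k!). The sketch below uses only the Python standard library; `poisson_log_pmf` and `joint_log_prob` are hypothetical helper names, not TensorFlow functions.

```python
import math


def poisson_log_pmf(k, rate):
    # log of rate**k * exp(-rate) / k!, using lgamma(k + 1) = log(k!)
    return k * math.log(rate) - rate - math.lgamma(k + 1.0)


def joint_log_prob(counts, rates):
    # Independent bins: the joint log-probability is the sum over bins.
    return sum(poisson_log_pmf(k, r) for k, r in zip(counts, rates))


rates = [3.0, 4.0]
observations = [[2.0, 1.0], [1.0, 1.0], [2.0, 1.0]]
log_probs = [joint_log_prob(obs, rates) for obs in observations]
print(log_probs)
```

The three values agree with the `log_prob` output above to float32 precision, which is what one expects if summing over the reinterpreted batch dimension is all that `Independent` adds on top of the per-bin Poisson log-pmf.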