In [1]:
import numpy as np

In [2]:
import pymc4 as pm
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

In [3]:
config = tf.ConfigProto()

In [4]:
config.graph_options.optimizer_options.global_jit_level = (
    tf.OptimizerOptions.ON_1)
config.intra_op_parallelism_threads = 1
config.inter_op_parallelism_threads = 1

## Eager

In [5]:
tf.enable_eager_execution()

In [6]:
@pm.model
def t_test(sd_prior='half_normal'):
    mu = pm.Normal('mu', 0, 1)
    sd = pm.HalfNormal('sd', 1)
    pm.Normal('y_0', 0, 2 * sd)
    pm.Normal('y_1', mu, 2 * sd)

model = t_test.configure()

model._forward_context.vars

[<pymc4.random_variables.continuous.Normal at 0xb31a6a048>,
 <pymc4.random_variables.continuous.HalfNormal at 0xb31a6a208>,
 <pymc4.random_variables.continuous.Normal at 0x11321c668>,
 <pymc4.random_variables.continuous.Normal at 0xb31a6a898>]

In [8]:
func = model.make_log_prob_function()

mu = tf.ones((10,))
sd = tf.ones((10,))
y_0 = tf.ones((10,))
y_1 = tf.ones((10,))
%timeit logp = func(mu, sd, y_0, y_1)
func(mu, sd, y_0, y_1)

8.88 ms ± 433 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


<tf.Tensor: id=329987, shape=(), dtype=float32, numpy=-76.88429>

In [9]:
logp_func_defun = tf.contrib.eager.defun(func)


For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
If you depend on functionality not listed there, please file an issue.



In [10]:
mu = tf.ones((10,))
sd = tf.ones((10,))
y_0 = tf.ones((10,))
y_1 = tf.ones((10,))
%timeit logp = logp_func_defun(mu, sd, y_0, y_1)
logp_func_defun(mu, sd, y_0, y_1)

287 µs ± 67.9 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


<tf.Tensor: id=330307, shape=(), dtype=float32, numpy=-76.88429>

In [11]:
mu = tf.ones((10,))
sd = tf.ones((10,))
y_0 = tf.ones((10,))
y_1 = tf.ones((10,))

with tf.GradientTape() as tape:
    tape.watch(mu)
    tape.watch(sd)
    tape.watch(y_0)
    tape.watch(y_1)

    logp = logp_func_defun(mu, sd, y_0, y_1)

tape.gradient(logp, mu)

<tf.Tensor: id=330677, shape=(10,), dtype=float32, numpy=array([-1., -1., -1., -1., -1., -1., -1., -1., -1., -1.], dtype=float32)>

In [12]:
%%timeit
with tf.GradientTape() as tape:
    tape.watch(mu)
    tape.watch(sd)
    tape.watch(y_0)
    tape.watch(y_1)
    logp = logp_func_defun(mu, sd, y_0, y_1)

tape.gradient(logp, [mu, sd, y_0, y_1])

823 µs ± 34.6 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [13]:
def logp_and_grad(*args):
    logp = func(*args)
    return logp, tf.gradients(logp, args)

logp_grad_func_defun = tf.contrib.eager.defun(logp_and_grad)

mu = tf.ones((10,))
sd = tf.ones((10,))
y_0 = tf.ones((10,))
y_1 = tf.ones((10,))
%timeit logp = logp_grad_func_defun(mu, sd, y_0, y_1)
logp_grad_func_defun(mu, sd, y_0, y_1)

498 µs ± 77.7 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


(<tf.Tensor: id=639539, shape=(), dtype=float32, numpy=-76.88429>,
 [<tf.Tensor: id=639540, shape=(10,), dtype=float32, numpy=array([-1., -1., -1., -1., -1., -1., -1., -1., -1., -1.], dtype=float32)>,
  <tf.Tensor: id=639541, shape=(10,), dtype=float32, numpy=
  array([-8.139055, -8.139055, -8.139055, -8.139055, -8.139055, -8.139055,
         -8.139055, -8.139055, -8.139055, -8.139055], dtype=float32)>,
  <tf.Tensor: id=639542, shape=(10,), dtype=float32, numpy=
  array([-0.25, -0.25, -0.25, -0.25, -0.25, -0.25, -0.25, -0.25, -0.25,
         -0.25], dtype=float32)>,
  <tf.Tensor: id=639543, shape=(10,), dtype=float32, numpy=array([-0., -0., -0., -0., -0., -0., -0., -0., -0., -0.], dtype=float32)>])

In [14]:
from tensorflow.contrib.compiler import xla

In [16]:
# Doesn't work
array = tf.ones(40)

@tf.contrib.eager.defun
def logp_wrapper(array):
    mu = array[:10]
    sd = array[10:20]
    y_0 = array[20:30]
    y_1 = array[30:40]
    logp = func(mu, sd, y_0, y_1)
    grad = tf.gradients(logp, array)
    return logp, grad

@tf.contrib.eager.defun
def logp_wrapper_xla(array):
    logp, grad = xla.compile(logp_wrapper, inputs=[array])
    return logp, grad

logp_wrapper_xla(array)

ValueError: Op type not registered 'XlaClusterOutput' in binary running on ip-192-168-0-206.ec2.internal. Make sure the Op and Kernel are registered in the binary running in this process. Note that if you are loading a saved graph which used ops from tf.contrib, accessing (e.g.) `tf.contrib.resampler` should be done before importing the graph, as contrib ops are lazily registered when the module is first accessed. while building NodeDef 'output0'

In [35]:
%timeit logp_wrapper(array)

350 µs ± 49.5 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


## Comparison to PyMC3

In [17]:
import pymc3 as pm3

In [18]:
with pm3.Model() as model:
    mu = pm3.Normal('mu', 0, 1, shape=10)
    sd = pm3.HalfNormal('sd', sd=1, transform=None, shape=10)
    pm3.Normal('y_0', 0, 2 * sd, shape=10)
    pm3.Normal('y_1', mu, 2 * sd, shape=10)

In [19]:
func_pm3 = model.logp_dlogp_function()

In [20]:
x0 = np.ones(func_pm3.size)

In [21]:
func_pm3.set_extra_values({})
%timeit func_pm3(x0)

40.1 µs ± 4.08 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
