Skip to content

Commit

Permalink
Define forward log det jacobian explicitly in CorrelationCholesky bij…
Browse files Browse the repository at this point in the history
…ector.

The bijector base class definition is not able to handle rank-changing bijectors such as CorrelationCholesky. The problem is that the value of event_ndims for fldj is sent to ildj as is but it should be 'event_ndims + 1'.

For this specific case, the resulting exception only gets triggered when value of 'event_ndims' is not known statically -- for instance when invoked by HamiltonianMonteCarlo in Graph mode. The error is still there in Eager mode IIUC but in this case the final answer is correct as no exception is thrown.

I've added a unit test for this case which fails without this change.

PiperOrigin-RevId: 249770795
  • Loading branch information
bloops authored and tensorflower-gardener committed May 24, 2019
1 parent 912f40f commit 1231703
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,12 @@ def _inverse(self, y):
# transformation.
return fill_triangular.FillTriangular().inverse(x[..., 1:, :-1])

def _forward_log_det_jacobian(self, x):
# TODO(b/133442896): It should be possible to use the fallback
# implementation of _forward_log_det_jacobian in terms of
# _inverse_log_det_jacobian in the base Bijector class.
return -self._inverse_log_det_jacobian(self.forward(x))

def _inverse_log_det_jacobian(self, y):
# The inverse log det jacobian (ILDJ) of the entire mapping is the sum of
# the ILDJs of each row's mapping.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
# Dependency imports
from absl.testing import parameterized
import numpy as np
import tensorflow.compat.v1 as tf1
import tensorflow.compat.v2 as tf

from tensorflow_probability.python import bijectors as tfb
Expand Down Expand Up @@ -156,6 +157,28 @@ def testTheoreticalFldj(self):
atol=1e-5,
rtol=1e-5)

def testBijectorWithVariables(self):
x_ = np.array([1.], dtype=np.float32)
y_ = np.array([[1., 0.], [0.707107, 0.707107]], dtype=np.float32)

x = tf.Variable(x_, dtype=tf.float32)
y = tf.Variable(y_, dtype=tf.float32)
forward_event_ndims = tf.Variable(1, dtype=tf.int32)
inverse_event_ndims = tf.Variable(2, dtype=tf.int32)
self.evaluate(tf1.global_variables_initializer())

bijector = tfb.CorrelationCholesky()
self.assertAllClose(
y_, self.evaluate(bijector.forward(x)), atol=1e-5, rtol=1e-5)
self.assertAllClose(
x_, self.evaluate(bijector.inverse(y)), atol=1e-5, rtol=1e-5)

fldj = bijector.forward_log_det_jacobian(x, event_ndims=forward_event_ndims)
self.assertAllClose(-np.log(2), self.evaluate(fldj))

ildj = bijector.inverse_log_det_jacobian(y, event_ndims=inverse_event_ndims)
self.assertAllClose(np.log(2), ildj)

@parameterized.parameters(itertools.product([2, 3, 4, 5, 6, 7], [1., 2., 3.]))
def testWithLKJSamples(self, dimension, concentration):
bijector = tfb.CorrelationCholesky()
Expand Down

0 comments on commit 1231703

Please sign in to comment.