Skip to content

Commit

Permalink
Make tfb.Reshape work with JAX omnistaging.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 332064660
  • Loading branch information
SiegeLordEx authored and tensorflower-gardener committed Sep 16, 2020
1 parent 0f5c693 commit 782d0c6
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 21 deletions.
1 change: 1 addition & 0 deletions tensorflow_probability/python/bijectors/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -712,6 +712,7 @@ multi_substrate_py_library(
"//tensorflow_probability/python/internal:assert_util",
"//tensorflow_probability/python/internal:dtype_util",
"//tensorflow_probability/python/internal:nest_util",
"//tensorflow_probability/python/internal:prefer_static",
"//tensorflow_probability/python/internal:tensor_util",
"//tensorflow_probability/python/internal:tensorshape_util",
],
Expand Down
37 changes: 16 additions & 21 deletions tensorflow_probability/python/bijectors/reshape.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from tensorflow_probability.python.bijectors import bijector
from tensorflow_probability.python.internal import assert_util
from tensorflow_probability.python.internal import dtype_util
from tensorflow_probability.python.internal import prefer_static as ps
from tensorflow_probability.python.internal import tensor_util
from tensorflow_probability.python.internal import tensorshape_util

Expand Down Expand Up @@ -150,9 +151,11 @@ def __init__(self, event_shape_out, event_shape_in=(-1,),
dtype = dtype_util.common_dtype(
[event_shape_out, event_shape_in], dtype_hint=tf.int32)
event_shape_out = tensor_util.convert_nonref_to_tensor(
event_shape_out, name='event_shape_out', dtype=dtype)
event_shape_out, name='event_shape_out', dtype=dtype,
as_shape_tensor=True)
event_shape_in = tensor_util.convert_nonref_to_tensor(
event_shape_in, name='event_shape_in', dtype=dtype)
event_shape_in, name='event_shape_in', dtype=dtype,
as_shape_tensor=True)

forward_min_event_ndims_ = _rank_from_shape(event_shape_in)
if forward_min_event_ndims_ is None:
Expand Down Expand Up @@ -189,15 +192,15 @@ def _parameter_control_dependencies(self, is_init):

def _forward(self, x):
  """Reshapes `x` so its event dimensions match the output event shape.

  Computes the target shape with `ps.shape` (prefer_static) rather than
  `tf.shape` so the shape stays a static value where possible — this is
  what makes the bijector work under JAX omnistaging (see commit title).

  Args:
    x: Input `Tensor`; presumably its trailing dimensions match
      `self._event_shape_in` — TODO(review): confirm against callers.

  Returns:
    `x` reshaped per `self._event_shape_out`, with the best statically
    known `TensorShape` set on the result.
  """
  # NOTE(review): the diff rendering had interleaved the removed
  # `tf.shape(x)` argument line with its replacement; only the
  # `ps.shape(x)` call belongs in the committed code.
  output_shape, output_tensorshape = _replace_event_shape_in_shape_tensor(
      ps.shape(x), self._event_shape_in, self._event_shape_out,
      self.validate_args)
  y = tf.reshape(x, output_shape)
  # Propagate static shape info that tf.reshape alone may lose.
  tensorshape_util.set_shape(y, output_tensorshape)
  return y

def _inverse(self, y):
  """Reshapes `y` back so its event dimensions match the input event shape.

  Mirror of `_forward`: swaps the roles of `self._event_shape_out` and
  `self._event_shape_in`, and uses `ps.shape` (prefer_static) so the
  shape computation stays static where possible under JAX omnistaging.

  Args:
    y: Input `Tensor`; presumably its trailing dimensions match
      `self._event_shape_out` — TODO(review): confirm against callers.

  Returns:
    `y` reshaped per `self._event_shape_in`, with the best statically
    known `TensorShape` set on the result.
  """
  # NOTE(review): the diff rendering had interleaved the removed
  # `tf.shape(y)` argument line with its replacement; only the
  # `ps.shape(y)` call belongs in the committed code.
  output_shape, output_tensorshape = _replace_event_shape_in_shape_tensor(
      ps.shape(y), self._event_shape_out, self._event_shape_in,
      self.validate_args)
  x = tf.reshape(y, output_shape)
  # Propagate static shape info that tf.reshape alone may lose.
  tensorshape_util.set_shape(x, output_tensorshape)
  return x
Expand Down Expand Up @@ -259,27 +262,19 @@ def _replace_event_shape_in_shape_tensor(
event_shape_in,
event_shape_out)

# TODO(b/124240153): Remove map(tf.identity, deps) once tf.function
# correctly supports control_dependencies.
validation_dependencies = (
map(tf.identity, (event_shape_in, event_shape_out))
if validate_args else ())

if (tensorshape_util.is_fully_defined(output_tensorshape) and
(is_validated or not validate_args)):
with tf.control_dependencies(validation_dependencies):
output_shape = tf.convert_to_tensor(
tensorshape_util.as_list(output_tensorshape), name='output_shape',
dtype_hint=tf.int32)
output_shape = ps.convert_to_shape_tensor(
tensorshape_util.as_list(output_tensorshape), name='output_shape',
dtype_hint=tf.int32)
return output_shape, output_tensorshape

with tf.control_dependencies(validation_dependencies):
event_shape_in_ndims = (
tf.size(event_shape_in)
if tensorshape_util.num_elements(event_shape_in.shape) is None else
tensorshape_util.num_elements(event_shape_in.shape))
input_non_event_shape, input_event_shape = tf.split(
input_shape, num_or_size_splits=[-1, event_shape_in_ndims])
event_shape_in_ndims = (
tf.size(event_shape_in)
if tensorshape_util.num_elements(event_shape_in.shape) is None else
tensorshape_util.num_elements(event_shape_in.shape))
input_non_event_shape, input_event_shape = tf.split(
input_shape, num_or_size_splits=[-1, event_shape_in_ndims])

additional_assertions = []
if is_validated:
Expand Down

0 comments on commit 782d0c6

Please sign in to comment.