Skip to content
Merged

R0.10 #933

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions tensorflow_probability/python/bijectors/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1521,12 +1521,16 @@ multi_substrate_py_test(
name = "softplus_test",
size = "small",
srcs = ["softplus_test.py"],
jax_size = "medium",
deps = [
":bijector_test_util",
":bijectors",
# absl/testing:parameterized dep,
# numpy dep,
# tensorflow dep,
"//tensorflow_probability/python/internal:test_util",
"//tensorflow_probability/python/math",
# tensorflow/compiler/jit dep,
],
)

Expand Down
29 changes: 27 additions & 2 deletions tensorflow_probability/python/bijectors/softplus.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,31 @@
]


JAX_MODE = False # Overwritten by rewrite script.


# TODO(b/155501444): Remove this when tf.nn.softplus is fixed.
if JAX_MODE:
_stable_grad_softplus = tf.nn.softplus
else:

@tf.custom_gradient
def _stable_grad_softplus(x):
"""A (more) numerically stable softplus than `tf.nn.softplus`."""
x = tf.convert_to_tensor(x)
if x.dtype == tf.float64:
cutoff = -20
else:
cutoff = -9

y = tf.where(x < cutoff, tf.math.log1p(tf.exp(x)), tf.nn.softplus(x))

def grad_fn(dy):
return dy * tf.where(x < cutoff, tf.exp(x), tf.nn.sigmoid(x))

return y, grad_fn


class Softplus(bijector.Bijector):
"""Bijector which computes `Y = g(X) = Log[1 + exp(X)]`.

Expand Down Expand Up @@ -101,9 +126,9 @@ def _is_increasing(cls):

def _forward(self, x):
if self.hinge_softness is None:
return tf.math.softplus(x)
return _stable_grad_softplus(x)
hinge_softness = tf.cast(self.hinge_softness, x.dtype)
return hinge_softness * tf.math.softplus(x / hinge_softness)
return hinge_softness * _stable_grad_softplus(x / hinge_softness)

def _inverse(self, y):
if self.hinge_softness is None:
Expand Down
21 changes: 21 additions & 0 deletions tensorflow_probability/python/bijectors/softplus_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,11 @@

# Dependency imports

from absl.testing import parameterized
import numpy as np
import tensorflow.compat.v2 as tf
from tensorflow_probability.python import bijectors as tfb
from tensorflow_probability.python import math as tfp_math
from tensorflow_probability.python.bijectors import bijector_test_util
from tensorflow_probability.python.internal import test_util

Expand Down Expand Up @@ -149,6 +151,25 @@ def testVariableHingeSoftness(self):
with tf.control_dependencies([hinge_softness.assign(0.)]):
self.evaluate(b.forward(0.5))

@parameterized.named_parameters(
('32bitGraph', np.float32, False),
('64bitGraph', np.float64, False),
('32bitXLA', np.float32, True),
('64bitXLA', np.float64, True),
)
@test_util.numpy_disable_gradient_test
def testLeftTailGrad(self, dtype, do_compile):
x = np.linspace(-50., -8., 1000).astype(dtype)

@tf.function(autograph=False, experimental_compile=do_compile)
def fn(x):
return tf.math.log(tfb.Softplus().forward(x))

_, grad = tfp_math.value_and_gradient(fn, x)

true_grad = 1 / (1 + np.exp(-x)) / np.log1p(np.exp(x))
self.assertAllClose(true_grad, self.evaluate(grad), atol=1e-3)


if __name__ == '__main__':
tf.test.main()
Original file line number Diff line number Diff line change
Expand Up @@ -287,5 +287,7 @@ def _convert_to_dict(x):
if isinstance(x, collections.OrderedDict):
return x
if hasattr(x, '_asdict'):
return x._asdict()
# Wrap with `OrderedDict` to indicate that namedtuples have a well-defined
# order (by default, they convert to just `dict` in Python 3.8+).
return collections.OrderedDict(x._asdict())
return dict(x)
4 changes: 2 additions & 2 deletions tensorflow_probability/python/layers/distribution_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
import pickle

# Dependency imports
from cloudpickle import CloudPickler
from cloudpickle.cloudpickle import CloudPickler
import numpy as np
import six
import tensorflow.compat.v2 as tf
Expand All @@ -47,7 +47,7 @@
from tensorflow_probability.python.distributions import variational_gaussian_process as variational_gaussian_process_lib
from tensorflow_probability.python.internal import distribution_util as dist_util
from tensorflow_probability.python.layers.internal import distribution_tensor_coercible as dtc
from tensorflow_probability.python.layers.internal import tensor_tuple as tensor_tuple
from tensorflow_probability.python.layers.internal import tensor_tuple
from tensorflow.python.keras.utils import tf_utils as keras_tf_utils # pylint: disable=g-direct-tensorflow-import


Expand Down