Revert "Add numpy/jax rewrite for cumulative_logsumexp."
This reverts commit 82520ca for compatibility with TF 2.11.
emilyfertig committed Dec 5, 2022
1 parent 17ddf7c commit 9600b80
Showing 6 changed files with 4 additions and 39 deletions.
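
For context, tf.math.cumulative_logsumexp and TFP's log_cumsum_exp both compute a running log-sum-exp, i.e. log(cumsum(exp(x))) along an axis, in a numerically stable way. A minimal NumPy sketch of that quantity (the helper name is illustrative; this is reference semantics, not the TFP implementation, though np.logaddexp.accumulate is the same primitive the reverted NumPy-backend shim used):

import numpy as np

def running_logsumexp(x, axis=-1):
    # Reference semantics: log(cumsum(exp(x))) along `axis`, accumulated
    # pairwise so that large magnitudes do not overflow exp().
    return np.logaddexp.accumulate(np.asarray(x, dtype=np.float64), axis=axis)

log_probs = np.log([0.1, 0.2, 0.3, 0.4])
print(np.exp(running_logsumexp(log_probs)))  # ~[0.1, 0.3, 0.6, 1.0], a running CDF
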
1 change: 1 addition & 0 deletions tensorflow_probability/python/experimental/mcmc/BUILD
@@ -704,6 +704,7 @@ multi_substrate_py_library(
"//tensorflow_probability/python/internal:prefer_static",
"//tensorflow_probability/python/internal:tensor_util",
"//tensorflow_probability/python/internal:tensorshape_util",
"//tensorflow_probability/python/math:generic",
"//tensorflow_probability/python/math:gradient",
"//tensorflow_probability/python/mcmc/internal:util",
],
@@ -21,6 +21,7 @@
from tensorflow_probability.python.distributions import uniform
from tensorflow_probability.python.internal import distribution_util as dist_util
from tensorflow_probability.python.internal import prefer_static as ps
from tensorflow_probability.python.math.generic import log_cumsum_exp
from tensorflow_probability.python.math.gradient import value_and_gradient
from tensorflow_probability.python.mcmc.internal import util as mcmc_util

@@ -134,7 +135,7 @@ def _resample_using_log_points(log_probs, sample_shape, log_points, name=None):
tf.zeros(points_shape, dtype=tf.int32)],
axis=-1)
log_marker_positions = tf.broadcast_to(
tf.math.cumulative_logsumexp(log_probs, axis=-1),
log_cumsum_exp(log_probs, axis=-1),
markers_shape)
log_markers_and_points = ps.concat(
[log_marker_positions, log_points], axis=-1)
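
The only functional change in this hunk swaps the built-in tf.math.cumulative_logsumexp back to TFP's log_cumsum_exp when computing the log marker positions. Both are intended to return the log of the running total of the unnormalized probabilities; a hedged sketch of that equivalence using only public APIs (illustrative values, not part of the change):

import tensorflow as tf
import tensorflow_probability as tfp

log_probs = tf.math.log(tf.constant([0.1, 0.2, 0.3, 0.4]))

built_in = tf.math.cumulative_logsumexp(log_probs, axis=-1)  # path removed by this revert
tfp_op = tfp.math.log_cumsum_exp(log_probs, axis=-1)         # path restored by this revert

# Both should be close to log([0.1, 0.3, 0.6, 1.0]).
tf.debugging.assert_near(built_in, tfp_op)
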
2 changes: 1 addition & 1 deletion tensorflow_probability/python/internal/backend/numpy/BUILD
@@ -469,7 +469,7 @@ py_test(
"--test_mode=xla",
# TODO(b/168718272): reduce_*([nan, nan], axis=0) (GPU)
# histogram_fixed_width_bins fails with f32([0.]), [0.0, 0.0], 2
"--xla_disabled=math.cumulative_logsumexp,math.reduce_min,math.reduce_max,histogram_fixed_width_bins",
"--xla_disabled=math.reduce_min,math.reduce_max,histogram_fixed_width_bins",
],
main = "numpy_test.py",
shard_count = 11,
23 changes: 0 additions & 23 deletions tensorflow_probability/python/internal/backend/numpy/numpy_math.py
@@ -61,7 +61,6 @@
'count_nonzero',
'cumprod',
'cumsum',
'cumulative_logsumexp',
'digamma',
'divide',
'divide_no_nan',
@@ -261,23 +260,6 @@ def _cumop(op, x, axis=0, exclusive=False, reverse=False, name=None,
_cumsum = utils.partial(_cumop, np.cumsum, initial_value=0.)


def _cumulative_logsumexp(x, axis=0, exclusive=False, reverse=False, name=None):
  del name
  axis = int(axis)
  if axis < 0:
    axis = axis + len(x.shape)
  if JAX_MODE:
    op = jax.lax.cumlogsumexp
  else:
    op = np.logaddexp.accumulate
  return _cumop(
      op, x,
      axis=axis,
      exclusive=exclusive,
      reverse=reverse,
      initial_value=-np.inf)


def _equal(x, y, name=None):
del name
x = _convert_to_tensor(x)
@@ -578,11 +560,6 @@ def _unsorted_segment_sum(data, segment_ids, num_segments, name=None):
'tf.math.cumsum',
_cumsum)

cumulative_logsumexp = utils.copy_docstring(
'tf.math.cumulative_logsumexp',
_cumulative_logsumexp)


digamma = utils.copy_docstring(
'tf.math.digamma',
lambda x, name=None: scipy_special.digamma(x))
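
The deleted backend shim dispatched to jax.lax.cumlogsumexp in JAX mode and to np.logaddexp.accumulate otherwise. A small standalone check of that correspondence (assumes a working JAX install and sidesteps the TFP backend machinery entirely):

import numpy as np
import jax.numpy as jnp
from jax import lax

x = np.array([0.0, -1.0, 2.0, 0.5], dtype=np.float32)

via_numpy = np.logaddexp.accumulate(x)               # NumPy-mode path of the deleted shim
via_jax = lax.cumlogsumexp(jnp.asarray(x), axis=0)   # JAX-mode path of the deleted shim

np.testing.assert_allclose(via_numpy, np.asarray(via_jax), rtol=1e-5)
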
10 changes: 0 additions & 10 deletions tensorflow_probability/python/internal/backend/numpy/numpy_test.py
@@ -1203,16 +1203,6 @@ def _not_implemented(*args, **kwargs):
hps.booleans()).map(lambda x: x[0] + (x[1], x[2]))
],
xla_const_args=(1, 2, 3)),
TestCase(
'math.cumulative_logsumexp', [
hps.tuples(
array_axis_tuples(
elements=floats(min_value=-1e12, max_value=1e12)),
hps.booleans(),
hps.booleans()).map(lambda x: x[0] + (x[1], x[2]))
],
rtol=6e-5,
xla_const_args=(1, 2, 3)),
]

NUMPY_TEST_CASES += [ # break the array for pylint to not timeout.
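
The dropped test case fed hypothesis-generated float arrays plus two booleans that appear to map onto the op's exclusive and reverse keyword arguments, with xla_const_args=(1, 2, 3) pinning the axis and those flags as compile-time constants. For reference, a small TF snippet illustrating what the flags do (values are illustrative):

import tensorflow as tf

x = tf.math.log(tf.constant([1.0, 2.0, 3.0]))

print(tf.exp(tf.math.cumulative_logsumexp(x)))                  # ~[1., 3., 6.]
print(tf.exp(tf.math.cumulative_logsumexp(x, reverse=True)))    # ~[6., 5., 3.]
print(tf.exp(tf.math.cumulative_logsumexp(x, exclusive=True)))  # ~[0., 1., 3.] (leading log-identity)
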
4 changes: 0 additions & 4 deletions tensorflow_probability/python/math/generic.py
@@ -30,7 +30,6 @@
from tensorflow_probability.python.internal import tensorshape_util
from tensorflow_probability.python.internal import variadic_reduce
from tensorflow_probability.python.math.scan_associative import scan_associative
from tensorflow.python.util import deprecation # pylint: disable=g-direct-tensorflow-import


__all__ = [
@@ -90,9 +89,6 @@ def log_combinations(n, counts, name='log_combinations'):

# TODO(b/154562929): Remove this once the built-in op supports XLA.
# TODO(b/156297366): Derivatives of this function may not always be correct.
@deprecation.deprecated('2023-03-01',
'`log_cumsum_exp` is deprecated; '
' Use `tf.math.cumulative_logsumexp` instead.')
def log_cumsum_exp(x, axis=-1, name=None):
"""Computes log(cumsum(exp(x))).
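
The lines removed from generic.py drop TF's internal deprecation decorator from log_cumsum_exp, so calling it no longer warns that tf.math.cumulative_logsumexp should be used instead. A rough sketch of what that decorator does, applied to a hypothetical stand-in function (it relies on the same private tensorflow.python.util import the deleted code used, which may change between TF releases):

from tensorflow.python.util import deprecation  # pylint: disable=g-direct-tensorflow-import

@deprecation.deprecated('2023-03-01',
                        'Use `tf.math.cumulative_logsumexp` instead.')
def old_log_cumsum_exp(x):
  # Stand-in for the decorated function; the decorator logs a one-time
  # warning (mentioning the 2023-03-01 removal date) on the first call.
  return x

old_log_cumsum_exp(1.0)
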
