Merged
Changes from all commits
39 commits
e74b9cb  Track parameters in STS subclasses of `LinearGaussianStateSpaceModel`. (davmre, Nov 10, 2020)
16c6093  Merge pull request #4863 from apaszke:pmap-in-axes (tensorflower-gardener, Nov 11, 2020)
1d5b092  Rewrite RunningMean as a CompositeTensor with methods, instead of the (axch, Nov 11, 2020)
3602bfe  Bug fix: Parabolic was incorrectly mapped to Epanechnikov, lacked a t… (brianwa84, Nov 11, 2020)
c901a98  Add a general-purpose bijector to numerically invert any bijective sc… (davmre, Nov 11, 2020)
c7915d6  Add Chandrupatla's method for finding roots of scalar functions. (davmre, Nov 11, 2020)
c90dbee  Update the style guide to provide guidance on overloaded operators. (davmre, Nov 12, 2020)
24c6c1a  Relax tolerance on `HalfStudentT` test due to JAX mode failure. (emilyfertig, Nov 12, 2020)
fb6706b  Fix log_prob example in RealNVP docstring (hartikainen, Nov 12, 2020)
0b863a8  Fix ContinuousBernoulli quantile and sampler so that it does not emit (srvasude, Nov 12, 2020)
34144a1  Give `tfp.math.value_and_gradient` more natural calling convention fo… (Nov 12, 2020)
8fbc527  Cleaning pass on tests of potential scale reduction reducer. (axch, Nov 13, 2020)
35c7e57  Change the RationalQuadraticSpline example to use non-unit batch (hartikainen, Nov 13, 2020)
4700190  Fix bin position/slope reshaping in RationalQuadraticSpline example (hartikainen, Nov 13, 2020)
6161342  Cleaning pass on tests of covariance reducer. (axch, Nov 16, 2020)
06166b6  Use Python rather than numpy to compute constants in HalfNormal so (axch, Nov 16, 2020)
748c222  Use list instead of tuple for nvp.log_prob inputs (hartikainen, Nov 17, 2020)
ecb9e43  Bug fix for additive coupling and test to catch this bug. (a-googler, Nov 17, 2020)
13884d1  Add hyp2f1 for computation of Gauss' Hypergeometric Function for inpu… (srvasude, Nov 17, 2020)
a88a623  Delete redundant tests of tracing reducer. (axch, Nov 17, 2020)
698d3e5  Remove redundant tests of expectations reducer (axch, Nov 17, 2020)
58332f9  Make the tensorfloat32 warning more targeted to only the mcmc package. (brianwa84, Nov 17, 2020)
9a51412  Remove deprecated tfb.ScaleTrilL. Use tfb.FillScaleTriL instead. (jburnim, Nov 17, 2020)
c9eb03a  Rewrite RunningCentralMoments as a CompositeTensor with methods, inst… (axch, Nov 17, 2020)
d8f67b3  Oryx: Allow reloading the distributions/bijectors modules. (SiegeLordEx, Nov 18, 2020)
29d9e54  Fix a numeric bug that sometimes caused Chandrapatla's method to fail… (davmre, Nov 18, 2020)
9b93ab9  [JAX] Delete jax.source_info_util. (hawkinsp, Nov 19, 2020)
dcde695  [JAX] Delete jax.lax_linalg. (hawkinsp, Nov 19, 2020)
105ccec  Use prefer_static in ordered_logistic. (brianwa84, Nov 19, 2020)
d649037  Add tfp.random.spherical_uniform and make it robust in dimensions 1 a… (srvasude, Nov 19, 2020)
249adeb  Fix two broadcast-not-yet-supported errors tickled by the addition of… (brianwa84, Nov 19, 2020)
1ca2d42  Catch another class of "rank not supported" errors from Argmin/Argmax. (brianwa84, Nov 19, 2020)
04cd7a9  Update numpy/jax generated tensorshape rewrite, which changed upstrea… (brianwa84, Nov 19, 2020)
972387d  Increase test size. (brianwa84, Nov 19, 2020)
d1795f8  Improve numerics for bessel_ive, bessel_kve by moving more computatio… (srvasude, Nov 20, 2020)
33881c0  Allow RunningVariance to be initialized from num_counts, mean, variance. (ColCarroll, Nov 20, 2020)
3e2ebc4  DPP has a non-trainable parameter `eigenvectors` (unless/until we add… (brianwa84, Nov 20, 2020)
48ab5b0  Bump atol for vmap Beta log_prob. (brianwa84, Nov 20, 2020)
1fd985d  Set the version for the TFP 0.12-rc2 release. (jburnim, Nov 20, 2020)
22 changes: 17 additions & 5 deletions STYLE_GUIDE.md
@@ -109,7 +109,7 @@ they supersede all previous conventions.
* Definitely use named args for 2nd args onward in docstrings.

1. Use names which describe semantics, not computation or mathematics, e.g.,
-   avoid `xp1 = x+1` or `tfd.Normal(loc=mu, scale=sigma)`.
+   avoid `xp1 = x + 1` or `tfd.Normal(loc=mu, scale=sigma)`.

1. Prefer inlining intermediates which are used once.

@@ -157,16 +157,16 @@ they supersede all previous conventions.

1. Prefer using the most specific TF operator. E.g,

-   * Use `tf.squared_difference(x,y)` over `(x-y)**2`.
-   * Use `tf.rsqrt` over `1./tf.sqrt(x)`.
+   * Use `tf.squared_difference(x, y)` over `(x - y)**2`.
+   * Use `tf.rsqrt` over `1. / tf.sqrt(x)`.

1. Worry about gradients! (It's often not automatic for API builders!)

1. When forced to choose between FLOPS and numerical accuracy, prefer numerical
accuracy.

-1. Avoid tf.cast if possible. Eg, prefer `tf.where(cond, a, b)` over
-   `tf.cast(cond,dtype=a.dtype)*a + (1-tf.cast(cond,dtype=b.dtype)*b`
+1. Avoid tf.cast if possible. Eg, prefer `tf.where(pred, a, b)` over
+   `tf.cast(cond, dtype=a.dtype) * a + (1 - tf.cast(cond, dtype=b.dtype)) * b`

1. Preserve static shape hints.

@@ -217,3 +217,15 @@ they supersede all previous conventions.
`Tensor`s, and Numpy objects. When converting a user-provided literal to a
`Tensor` (see e.g. `Distribution._call_log_prob`), specify the dtype to
`tf.convert_to_tensor` if it is available.

+1. Prefer overloaded operators on `Tensor`s (`+`, `-`, etc.) to explicit
+   method calls (`tf.add`, `tf.sub`, etc.). Exceptions:
+
+   * Prefer `tf.equal` to `==` when checking element-wise equality, because the
+     semantics of the latter are inconsistent between eager and graph
+     (`tf.function`) modes.
+   * Use `&` and `|` only if you want bitwise logic. Note that these are
+     equivalent to logical ops only if all inputs are `bool`s or are in
+     `{0, 1}`.
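
For illustration (editor's sketch, not part of the diff), the behavior this new guidance points at:

```python
import tensorflow as tf

x = tf.constant([1., 2., 3.])
y = tf.constant([1., 0., 3.])

# Overloaded operators are preferred for ordinary arithmetic.
z = x + y - 2. * x

# Element-wise equality: use tf.equal rather than `==`.
eq = tf.equal(x, y)  # => [True, False, True]

# `&`/`|` are bitwise; they agree with logical and/or only when all
# inputs are bools (or 0/1-valued).
both = tf.constant([True, False]) & tf.constant([True, True])  # => [True, False]
```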


17 changes: 4 additions & 13 deletions spinoffs/oryx/oryx/bijectors/__init__.py
@@ -18,24 +18,15 @@
 from oryx.bijectors import bijector_extensions
 from tensorflow_probability.substrates import jax as tfp

-__all__ = [
-    'bijector_extensions'
-]

 tfb = tfp.bijectors

-_bijectors = {}
+__all__ = tfb.__all__

-for name in dir(tfb):
+for name in __all__:
   bij = getattr(tfb, name)
   if inspect.isclass(bij) and issubclass(bij, tfb.Bijector):
     if bij is not tfb.Bijector:
       bij = bijector_extensions.make_type(bij)
-      _bijectors[name] = bij
-
-
-for key, val in _bijectors.items():
-  locals()[key] = val
+  locals()[name] = bij

-del _bijectors
+del tfb
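
The net effect (a reading of the diff, not part of it): the module now exports exactly `tfb.__all__`, wrapping each `Bijector` subclass via `bijector_extensions.make_type`, and drops the intermediate `_bijectors` dict so the module body can be re-executed, which is what the commit message's "allow reloading" refers to. Usage stays the same; a sketch:

```python
from oryx import bijectors as bx

exp_bijector = bx.Exp()  # a TFP-on-JAX bijector wrapped for Oryx
```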
4 changes: 2 additions & 2 deletions spinoffs/oryx/oryx/core/interpreters/harvest.py
@@ -335,8 +335,8 @@ def process_higher_order_primitive(self, primitive, f, tracers, params,
      params = params.copy()
      new_params = dict(
          params,
-          mapped_invars=(True,) * len(tree_util.tree_leaves(plants)) +
-          params['mapped_invars'])
+          in_axes=(0,) * len(tree_util.tree_leaves(plants)) +
+          params['in_axes'])
    else:
      new_params = dict(params)
    all_args, all_tree = tree_util.tree_flatten((plants, vals))
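For context (not part of the diff): this tracks JAX PR #4863, which replaced the per-input booleans of `mapped_invars` with `in_axes` entries that are either a mapped-axis index or `None`. A rough correspondence under that convention, with hypothetical values:

```python
# Old convention: one bool per input; True meant "mapped over the leading axis".
mapped_invars = (True, True, False)

# New convention: an integer names the mapped axis, None means unmapped;
# the old True corresponds to axis 0.
in_axes = tuple(0 if mapped else None for mapped in mapped_invars)
assert in_axes == (0, 0, None)
```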
4 changes: 2 additions & 2 deletions spinoffs/oryx/oryx/core/interpreters/inverse/core.py
@@ -373,8 +373,8 @@ def remove_slice(cell):
  flat_vals, in_tree = tree_util.tree_flatten((mapped_incells, mapped_outcells))
  f, aux = flat_propagate(f, in_tree)
  # Assume all invars as mapped
-  new_mapped_invars = (True,) * len(flat_vals)
-  new_params = dict(params, mapped_invars=new_mapped_invars)
+  new_in_axes = (0,) * len(flat_vals)
+  new_params = dict(params, in_axes=new_in_axes)
  if 'donated_invars' in params:
    new_params['donated_invars'] = (False,) * len(flat_vals)
  subenv_vals = prim.bind(f, *flat_vals, **new_params)
21 changes: 7 additions & 14 deletions spinoffs/oryx/oryx/core/interpreters/unzip.py
@@ -34,9 +34,9 @@
 from jax import core as jax_core
 from jax import custom_derivatives as cd
 from jax import linear_util as lu
-from jax import source_info_util
 from jax import tree_util
 from jax import util as jax_util
+from jax._src import source_info_util
 from jax.interpreters import partial_eval as pe
 import numpy as onp
@@ -282,14 +282,13 @@ def handle_call_primitive(self, call_primitive, f, tracers, params, is_map):
      return current_custom_rules()[call_primitive](self, f, *tracers, **params)
    if call_primitive in pe.call_partial_eval_rules:
      raise NotImplementedError
-    in_pvs, in_consts = jax_util.unzip2(t.pval for t in tracers)
+    in_pvals = [t.pval for t in tracers]
    if is_map:
-      pvs = [
-          None if pv is None else mapped_aval(params['axis_size'], pv)
-          for pv in in_pvs
-      ]
-    else:
-      pvs = in_pvs
+      unknown = pe.PartialVal.unknown
+      in_pvals = [pval if pval.is_known() or in_axis is None else
+                  unknown(mapped_aval(params['axis_size'], in_axis, pval[0]))
+                  for pval, in_axis in zip(in_pvals, params['in_axes'])]
+    pvs, in_consts = jax_util.unzip2(in_pvals)
    keys = tuple(t.is_key() for t in tracers)
    new_settings = UnzipSettings(settings.tag, call_primitive in block_registry)
    fun, aux = unzip_eval(f, self, keys, tuple(pvs), new_settings)
@@ -360,12 +359,6 @@ def _bound_output_tracers(self, primitive, params, jaxpr, consts, env,
        for pv, const, key in safe_zip(out_pvs, out_consts, out_keys)
    ]
    new_params = dict(params, name=name, call_jaxpr=lifted_jaxpr)
-    if is_map:
-      new_params = dict(
-          new_params,
-          mapped_invars=tuple([True] * len(const_tracers) +
-                              [False] * len(env_tracers) +
-                              [True] * len(in_tracers)))
    if 'donated_invars' in params:
      new_donated_invars = (
          (False,) * len(const_tracers) + (False,) * len(env_tracers) +
19 changes: 4 additions & 15 deletions spinoffs/oryx/oryx/distributions/__init__.py
@@ -16,23 +16,12 @@
 from oryx.distributions import distribution_extensions
 from tensorflow_probability.substrates import jax as tfp

-__all__ = [
-    'distribution_extensions'
-]

 tfd = tfp.distributions

-_distributions = {}
+__all__ = tfd.__all__

-for name in dir(tfd):
+for name in __all__:
   dist = getattr(tfd, name)
-  _distributions[name] = dist
-
-
-for key, val in _distributions.items():
-  locals()[key] = val
+  locals()[name] = dist

-del _distributions
+del distribution_extensions  # Only needed for registration.
 del tfd
2 changes: 1 addition & 1 deletion spinoffs/oryx/oryx/experimental/nn/normalization_test.py
@@ -171,7 +171,7 @@ def test_check_grads(self):
    net = net_init.init(net_rng, state.Shape(in_shape))

    x = random.normal(data_rng, in_shape)
-    jtu.check_grads(net, (x,), 2)
+    jtu.check_grads(net.call, (x,), 2)


def mse(x, y):
24 changes: 15 additions & 9 deletions tensorflow_probability/python/__init__.py
@@ -18,20 +18,23 @@
 from __future__ import division
 from __future__ import print_function

+import functools
+
 from tensorflow_probability.python.internal import all_util
 from tensorflow_probability.python.internal import lazy_loader


-# Ensure TensorFlow is importable and its version is sufficiently recent. This
-# needs to happen before anything else, since the imports below will try to
-# import tensorflow, too.
 # pylint: disable=g-import-not-at-top
-def _ensure_tf_install():
-  """Attempt to import tensorflow, and ensure its version is sufficient.
+def _validate_tf_environment(package):
+  """Check TF version and (depending on package) warn about TensorFloat32.
+
+  Args:
+    package: Python `str` indicating which package is being imported. Used for
+      package-dependent warning about TensorFloat32.

   Raises:
     ImportError: if either tensorflow is not importable or its version is
-    inadequate.
+      inadequate.
   """
   try:
     import tensorflow.compat.v1 as tf
@@ -62,9 +65,10 @@ def _ensure_tf_install():
          required=required_tensorflow_version,
          present=tf.__version__))

-  if tf.config.experimental.tensor_float_32_execution_enabled():
+  if (package == 'mcmc' and
+      tf.config.experimental.tensor_float_32_execution_enabled()):
    # Must import here, because symbols get pruned to __all__.
-    import warnings  # pylint: disable=g-import-not-at-top
+    import warnings
    warnings.warn(
        'TensorFloat-32 matmul/conv are enabled for NVIDIA Ampere+ GPUs. The '
        'resulting loss of precision may hinder MCMC convergence. To turn off, '
@@ -94,6 +98,8 @@ def _ensure_tf_install():
 for pkg in _allowed_symbols:
   globals()[pkg] = lazy_loader.LazyLoader(
       pkg, globals(), 'tensorflow_probability.python.{}'.format(pkg),
-      on_first_access=_ensure_tf_install)
+      # These checks need to happen before lazy-loading, since the modules
+      # themselves will try to import tensorflow, too.
+      on_first_access=functools.partial(_validate_tf_environment, pkg))

 all_util.remove_undocumented(__name__, _allowed_symbols)
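
With this change (see also commit 58332f9 above), the TensorFloat-32 warning fires only when the `mcmc` subpackage is first accessed. Independently of TFP, users who want full float32 precision can disable TensorFloat-32 themselves; a sketch using the standard TF switch:

```python
import tensorflow as tf

# Disable TensorFloat-32 matmul/conv on Ampere+ GPUs to recover full
# float32 precision, e.g. before running MCMC.
tf.config.experimental.enable_tensor_float_32_execution(False)
```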
2 changes: 0 additions & 2 deletions tensorflow_probability/python/bijectors/__init__.py
@@ -40,7 +40,6 @@
 from tensorflow_probability.python.bijectors.expm1 import Log1p
 from tensorflow_probability.python.bijectors.ffjord import FFJORD
 from tensorflow_probability.python.bijectors.fill_scale_tril import FillScaleTriL
-from tensorflow_probability.python.bijectors.fill_scale_tril import ScaleTriL
 from tensorflow_probability.python.bijectors.fill_triangular import FillTriangular
 from tensorflow_probability.python.bijectors.frechet_cdf import FrechetCDF
 from tensorflow_probability.python.bijectors.generalized_pareto import GeneralizedPareto
@@ -159,7 +158,6 @@
     "ScaleMatvecLinearOperatorBlock",
     "ScaleMatvecLU",
     "ScaleMatvecTriL",
-    "ScaleTriL",
     "Shift",
     "ShiftedGompertzCDF",
     "Sigmoid",
@@ -77,7 +77,6 @@
     'ScaleMatvecTriL',
     'Shift',
     'ShiftedGompertzCDF',
-    'ScaleTriL',
     'Sigmoid',
     'Sinh',
     'SinhArcsinh',
55 changes: 0 additions & 55 deletions tensorflow_probability/python/bijectors/fill_scale_tril.py
@@ -26,12 +26,10 @@
 from tensorflow_probability.python.bijectors import transform_diagonal
 from tensorflow_probability.python.internal import dtype_util
 from tensorflow_probability.python.internal import tensor_util
-from tensorflow.python.util import deprecation  # pylint: disable=g-direct-tensorflow-import


 __all__ = [
     'FillScaleTriL',
-    'ScaleTriL',
 ]


@@ -127,56 +125,3 @@ def __init__(self,
         validate_args=validate_args,
         parameters=parameters,
         name=name)
-
-
-class ScaleTriL(chain.Chain):
-  """DEPRECATED. Please use `tfp.bijectors.FillScaleTriL`."""
-
-  @deprecation.deprecated(
-      '2020-01-01',
-      '`ScaleTriL` has been deprecated and renamed `FillScaleTriL`; please use '
-      'that symbol instead.')
-  def __init__(self,
-               diag_bijector=None,
-               diag_shift=1e-5,
-               validate_args=False,
-               name='scale_tril'):
-    """Instantiates the `ScaleTriL` bijector.
-
-    Args:
-      diag_bijector: `Bijector` instance, used to transform the output diagonal
-        to be positive.
-        Default value: `None` (i.e., `tfb.Softplus()`).
-      diag_shift: Float value broadcastable and added to all diagonal entries
-        after applying the `diag_bijector`. Setting a positive
-        value forces the output diagonal entries to be positive, but
-        prevents inverting the transformation for matrices with
-        diagonal entries less than this value.
-        Default value: `1e-5`.
-      validate_args: Python `bool` indicating whether arguments should be
-        checked for correctness.
-        Default value: `False` (i.e., arguments are not validated).
-      name: Python `str` name given to ops managed by this object.
-        Default value: `scale_tril`.
-    """
-    parameters = dict(locals())
-    with tf.name_scope(name) as name:
-      if diag_bijector is None:
-        diag_bijector = softplus.Softplus(validate_args=validate_args)
-
-      if diag_shift is not None:
-        dtype = dtype_util.common_dtype([diag_bijector, diag_shift], tf.float32)
-        diag_shift = tensor_util.convert_nonref_to_tensor(diag_shift,
-                                                          name='diag_shift',
-                                                          dtype=dtype)
-        diag_bijector = chain.Chain([
-            shift.Shift(diag_shift),
-            diag_bijector
-        ])
-
-      super(ScaleTriL, self).__init__(
-          [transform_diagonal.TransformDiagonal(diag_bijector=diag_bijector),
-           fill_triangular.FillTriangular()],
-          validate_args=validate_args,
-          parameters=parameters,
-          name=name)
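
Migration for the removed symbol is a rename, as commit 9a51412 notes; a sketch assuming default arguments:

```python
import tensorflow_probability as tfp
tfb = tfp.bijectors

# Before (removed in this release):
#   b = tfb.ScaleTriL(diag_shift=1e-5)
# After: same constructor arguments under the new name.
b = tfb.FillScaleTriL(diag_shift=1e-5)

# Maps n*(n+1)/2 unconstrained values to an n x n lower-triangular
# matrix with positive diagonal (here n = 2).
m = b.forward([0.5, -0.3, 0.1])
```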
5 changes: 2 additions & 3 deletions tensorflow_probability/python/bijectors/glow.py
@@ -581,7 +581,7 @@ def bijector_fn(inputs, ignored_input):
          output = this_shift(this_scale)
        elif target_shape[-1] == output_shape[-1]:

-          output = shift.Shift(possible_output[..., c:])
+          output = shift.Shift(possible_output[..., :c])
        else:
          raise ValueError('Shape inconsistent with input. Expected shape'
                           '{0} or {1} but tensor was shape {2}'.format(
@@ -676,7 +676,7 @@ def bijector_fn(inputs, ignored_input):
          output = this_shift(this_scale)
        elif input_shape[-1] == output_shape[-1]:

-          output = shift.Shift(possible_output[..., c:])
+          output = shift.Shift(possible_output[..., :c])
        else:
          raise ValueError('Shape inconsistent with input. Expected shape'
                           '{0} or {1} but tensor was shape {2}'.format(
@@ -860,4 +860,3 @@ def __init__(self, input_shape, output_chan, kernel_shape=3):
    super(GlowDefaultExitNetwork, self).__init__([
        tfkl.Input(input_shape),
        conv(this_nchan, kernel_shape)])
-
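Why the one-character fix above matters (an illustration, assuming `c` is the number of target channels, as the surrounding branches suggest): in the additive-coupling branch the network emits only `c` channels, so `possible_output[..., c:]` handed `Shift` an empty tensor, while `[..., :c]` is the whole output:

```python
import numpy as np

c = 4                                  # number of target channels (hypothetical)
possible_output = np.zeros((8, 8, c))  # network output with c channels

assert possible_output[..., c:].shape == (8, 8, 0)   # the bug: empty slice
assert possible_output[..., :c].shape == (8, 8, 4)   # the fix: full output
```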
29 changes: 29 additions & 0 deletions tensorflow_probability/python/bijectors/glow_test.py
@@ -351,5 +351,34 @@ def float64_exit(input_shape, output_chan):
    self.assertAllFinite(self.evaluate(z))
    self.assertAllFinite(self.evaluate(zf64))

+  def testBijectorFn(self):
+    """Test if the bijector function works for additive coupling."""
+    ims = self._make_images()
+    def shiftfn(input_shape):
+      input_nchan = input_shape[-1]
+      return tf.keras.Sequential([
+          tf.keras.layers.Input(input_shape),
+          tf.keras.layers.Conv2D(
+              input_nchan, 3, padding='same')])
+
+    def shiftexitfn(input_shape, output_chan):
+      return tf.keras.Sequential([
+          tf.keras.layers.Input(input_shape),
+          tf.keras.layers.Conv2D(
+              output_chan, 3, padding='same')])
+
+    shiftonlyglow = tfb.Glow(
+        output_shape=self.output_shape,
+        num_glow_blocks=2,
+        num_steps_per_block=1,
+        coupling_bijector_fn=shiftfn,
+        exit_bijector_fn=shiftexitfn,
+        grab_after_block=[0.5, 0.5]
+    )
+    z = shiftonlyglow.inverse(ims)
+    self.evaluate([v.initializer for v in shiftonlyglow.variables])
+    self.assertAllFinite(self.evaluate(z))


if __name__ == '__main__':
  tf.test.main()
3 changes: 0 additions & 3 deletions tensorflow_probability/python/bijectors/hypothesis_testlib.py
@@ -227,9 +227,6 @@ def bijector_supports():
    'ScaleMatvecTriL':
        BijectorSupport(Support.VECTOR_UNCONSTRAINED,
                        Support.VECTOR_UNCONSTRAINED),
-    'ScaleTriL':
-        BijectorSupport(Support.VECTOR_SIZE_TRIANGULAR,
-                        Support.MATRIX_LOWER_TRIL_POSITIVE_DEFINITE),
    'Shift':
        BijectorSupport(Support.SCALAR_UNCONSTRAINED,
                        Support.SCALAR_UNCONSTRAINED),