Merged
Changes from all commits
58 commits
68872f2
Add transition kernel for online adaptation of a diagonal mass matrix…
ColCarroll Nov 20, 2020
6c5f081
Disable problematic pytype checks in a few places. Add type hints to …
csuter Nov 20, 2020
c879b80
Add `tfp.math.reduce_kahan_sum`, supported for tf, jax, numpy + xla.
brianwa84 Nov 20, 2020
1d7bf08
[Oryx] Fix bug in reshape ILDJ rule
sharadmv Nov 20, 2020
52f511e
Remove duplicate `distributions` type annotation from python/__init__.py
csuter Nov 20, 2020
0352e20
Introduce Composition base-class for pipelined bijectors.
a-googler Nov 20, 2020
4f20221
Enable use of `auto_composite_tensor` without altering extra class hi…
csuter Nov 20, 2020
8b4d6bf
Rename `tfp.experimental.stats.RunningVariance.init_from_stats` to ju…
ColCarroll Nov 23, 2020
598e8a3
Fix ildj gradients in ScalarFunctionWithInferredInverse bijector.
davmre Nov 23, 2020
b98a9b2
Improve numeric stability of HalfNormal log probs.
davmre Nov 24, 2020
b0d2145
Merge pull request #4904 from apaszke:pmap-out-axes
tensorflower-gardener Nov 24, 2020
f3d84ff
Prune redundant sample_stats tests, and use tf.while throughout inste…
axch Nov 24, 2020
11f10f8
Add utility that abstracts over tracing a callable in JAX and TF.
davmre Nov 24, 2020
e187e29
Remove inaccurate comment about parameter properties API.
davmre Nov 25, 2020
4c2fed9
Merge pull request #1142 from blacksde:fix-geo-exp-under-zero
tensorflower-gardener Nov 25, 2020
1e7ae5b
Merge pull request #1140 from blacksde:gev_dist_pr
tensorflower-gardener Nov 25, 2020
603ced5
Make GradientBasedTrajectoryAdaptation infer the dtype from log-accep…
SiegeLordEx Nov 25, 2020
2e2d7ac
Make tfb.LKJ.experimental_default_event_space_bijector.forward output…
SiegeLordEx Nov 25, 2020
424a35f
Add the TFP_HYPOTHESIS_TIMEOUT_SECS environment variable to control t…
axch Nov 25, 2020
51659e0
fix validation of batched initial_inverse_hessian_estimate
jeffpollock9 Nov 27, 2020
1540820
Use legacy pip resolver when looking for valid tf-nightly version
sharadmv Dec 1, 2020
6177e18
Fix bug in `Distribution.__str__` affecting multipart bijector `Trans…
emilyfertig Dec 1, 2020
14d3143
Switch to GitHub Actions for the OSS CI.
SiegeLordEx Dec 1, 2020
c0c8b1e
Display fixes to "Auto-Batched Joint Distributions: A Gentle Tutorial".
ColCarroll Dec 1, 2020
fffd1c6
Actually randomize Hypothesis runs started with TFP_RANDOMIZE_HYPOTHE…
axch Dec 1, 2020
ac771a7
Relax `rtol` in `tfb.Sigmoid` test.
emilyfertig Dec 1, 2020
5cc3c2b
Use PyPI package JSON to find valid tf-nightly packages
sharadmv Dec 1, 2020
55528b6
Override the super's `parameters` field in ShardedSample.
SiegeLordEx Dec 2, 2020
c48fe71
Ignore non-positive definite examples in hypothesis tests.
ColCarroll Dec 2, 2020
25d8d5f
Fix typo (ijber -> über) in gram_schmidt.py
theblackfly Dec 2, 2020
66704cf
Fix docstring rendering error for tfd.Empirical.
davmre Dec 2, 2020
60714a3
Move tfb.Restructure examples to module docstring.
ColCarroll Dec 2, 2020
e755efc
Make API for `JointDistribution.sample` and `JDAutoBatched.sample` ma…
ColCarroll Dec 2, 2020
c07a92a
Add `bracket_root` method to automatically initialize bounds for a ro…
davmre Dec 2, 2020
95e25d8
[Oryx] Fix name change from key -> closure
sharadmv Dec 2, 2020
1c919ef
Update CONTRIBUTING.md to mention Github Actions rather than Travis.
SiegeLordEx Dec 2, 2020
c444ea2
Merge pull request #1186 from theblackfly:theblackfly-patch-1
tensorflower-gardener Dec 2, 2020
899c07e
Adds a test for non-list sequences of multi-part state. Makes a small…
brianwa84 Dec 2, 2020
8f94d90
Add a test that verifies the Kahan correction term can be nonzero.
brianwa84 Dec 2, 2020
6e4318e
Allow using arbitrary structures for make_sharded_log_prob_parts.
SiegeLordEx Dec 2, 2020
e3564c2
Undo accidental test disabling.
brianwa84 Dec 3, 2020
034a632
Add `experimental_use_kahan_sum` argument to `tfd.Independent` and `t…
brianwa84 Dec 3, 2020
44305d7
Use new spherical uniform sampler in tfd.SphericalUniform and mcmc.Sl…
csuter Dec 3, 2020
ccd9f72
Invert bijector works with Chain bijector a little more nicely.
ColCarroll Dec 4, 2020
60e608b
Add exponentially-modified Gaussian distribution to TensorFlow Probab…
a-googler Dec 4, 2020
90209b4
Add `experimental_use_kahan_sum` option to auto-batched joint distrib…
brianwa84 Dec 4, 2020
9b57197
Replace local rademachers with tfp.random.rademacher
hartikainen Dec 4, 2020
85aab1f
Fix {conv,dense}_variational_test rademachers
hartikainen Dec 4, 2020
c41de6f
Add new MCMC driver `run_kernel` for conveniently using the features …
axch Dec 4, 2020
574dbb0
Filter positive definite errors in precision test.
ColCarroll Dec 7, 2020
10872eb
Implement stddev and variance for some TransformedDistributions with …
davmre Dec 7, 2020
866cfbc
Remove inaccurate comment about numpy.cov.
a-googler Dec 8, 2020
8afd0c6
Add `prefer_static.reverse`.
emilyfertig Dec 8, 2020
9bebb90
Attempt to fix rendering error in JointMap docstring.
davmre Dec 8, 2020
50a1ca9
Add another structured state test, then make it pass with a small cha…
brianwa84 Dec 8, 2020
60edbd3
Numpy fix for Multinomial sampling.
brianwa84 Dec 8, 2020
988a3f8
Apply _convert_to_tensor in exp implementation in NumPy/JAX backends.
jburnim Dec 8, 2020
3de3fe0
Set the version for the TFP 0.12-rc4 release.
jburnim Dec 8, 2020
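
A note on the compensated-summation commits above (c879b80, 8f94d90, 034a632): Kahan summation carries a small correction term that recovers the low-order bits lost when a large running total absorbs a small addend. The following is a minimal pure-Python sketch of the technique only; it is not the `tfp.math.reduce_kahan_sum` implementation, which is vectorized and runs under TF, JAX, and NumPy.

def kahan_sum(xs):
  # Compensated summation over a 1-D iterable of floats.
  total = 0.0
  comp = 0.0  # running compensation: the error left over from prior additions
  for x in xs:
    y = x - comp            # re-inject the previously lost low-order bits
    t = total + y           # big + small: low-order bits of y may be dropped
    comp = (t - total) - y  # algebraically zero; numerically, the lost part
    total = t
  return total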
70 changes: 70 additions & 0 deletions .github/workflows/continuous-integration.yml
@@ -0,0 +1,70 @@
# Copyright 2020 The TensorFlow Probability Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
name: Tests
on: [push, pull_request]
env:
TEST_VENV_PATH: ~/test_virtualenv
jobs:
lints:
name: Lints
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [3.7]
steps:
- name: Checkout
uses: actions/checkout@v1
with:
fetch-depth: 20
- name: Setup Python
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Setup virtualenv
run: |
sudo apt install virtualenv
virtualenv -p python${{ matrix.python-version }} ${TEST_VENV_PATH}
- name: Lints
run: |
source ${TEST_VENV_PATH}/bin/activate
./testing/run_github_lints.sh
tests:
name: Tests
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [3.7]
shard: [0, 1, 2, 3, 4]
env:
TEST_VENV_PATH: ~/test_virtualenv
SHARD: ${{ matrix.shard }}
NUM_SHARDS: 5
steps:
- name: Checkout
uses: actions/checkout@v1
with:
fetch-depth: 1
- name: Setup Python
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Setup virtualenv
run: |
sudo apt install virtualenv
virtualenv -p python${{ matrix.python-version }} ${TEST_VENV_PATH}
- name: Tests
run: |
source ${TEST_VENV_PATH}/bin/activate
./testing/run_github_tests.sh
56 changes: 0 additions & 56 deletions .travis.yml

This file was deleted.

19 changes: 9 additions & 10 deletions CONTRIBUTING.md
@@ -32,20 +32,19 @@ repository (with credit to the original author) and closes the pull request.

## Continuous Integration

-We use [Travis CI](https://travis-ci.org/tensorflow/probability) to do automated
-style checking and run unit-tests (discussed in more detail below). A build
-will be triggered when you open a pull request, or update the pull request by
-adding a commit, rebasing etc.
+We use [GitHub Actions](https://github.com/tensorflow/probability/actions) to do
+automated style checking and run unit-tests (discussed in more detail below). A
+build will be triggered when you open a pull request, or update the pull request
+by adding a commit, rebasing etc.

-We test against TensorFlow nightly on Python 2.7 and 3.6. We shard our tests
+We test against TensorFlow nightly on Python 3.7. We shard our tests
across several build jobs (identified by the `SHARD` environment variable).
-Linting, in particular, is only done on the first shard, so look at that shard's
-logs for lint errors if any.
+Lints are also done in a separate job.

All pull-requests will need to pass the automated lint and unit-tests before
-being merged. As Travis-CI tests can take a bit of time, see the following
-sections on how to run the lint checks and unit-tests locally while you're
-developing your change.
+being merged. As the tests can take a bit of time, see the following sections
+on how to run the lint checks and unit-tests locally while you're developing
+your change.

## Style

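The `SHARD` and `NUM_SHARDS` variables mentioned above are set per-job in the new workflow file and consumed by `./testing/run_github_tests.sh`. As a hypothetical sketch of how such a partition might look (the real script's logic may differ), assuming test targets are split round-robin across jobs:

import os

def shard_targets(targets):
  # Hypothetical round-robin partition keyed on the CI environment variables.
  shard = int(os.environ.get('SHARD', '0'))
  num_shards = int(os.environ.get('NUM_SHARDS', '1'))
  return [t for i, t in enumerate(sorted(targets)) if i % num_shards == shard]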
2 changes: 1 addition & 1 deletion discussion/fun_mcmc/prefab.py
@@ -347,7 +347,7 @@ def kernel(adaptive_hmc_state):
hmc_state.state,
axis=tuple(range(chain_ndims)) if chain_ndims else None,
window_size=int(np.prod(hmc_state.target_log_prob.shape)) *
-variance_window_steps)
+variance_window_steps) # pytype: disable=wrong-arg-types

if num_adaptation_steps is not None:
# Take care of adaptation for variance and step size.
5 changes: 1 addition & 4 deletions spinoffs/inference_gym/inference_gym/BUILD
@@ -16,7 +16,7 @@
# A package for target densities and benchmarking of inference algorithms
# against the same.

-# [internal] load pytype.bzl (pytype_library, pytype_strict_library)
+# [internal] load pytype.bzl (pytype_strict_library)
# [internal] load dummy dependency

package(
@@ -42,7 +42,6 @@ py_library(
],
)

-# pytype
py_library(
name = "using_numpy",
srcs = ["using_numpy.py"],
@@ -56,7 +55,6 @@ py_library(
],
)

-# pytype
py_library(
name = "using_jax",
srcs = ["using_jax.py"],
@@ -71,7 +69,6 @@ py_library(
],
)

-# pytype
py_library(
name = "using_tensorflow",
srcs = ["using_tensorflow.py"],
15 changes: 10 additions & 5 deletions spinoffs/oryx/oryx/core/interpreters/harvest.py
@@ -333,22 +333,27 @@ def process_higher_order_primitive(self, primitive, f, tracers, params,
if is_map:
# TODO(sharadmv): figure out if invars are mapped or unmapped
params = params.copy()
+out_axes_thunk = params['out_axes_thunk']
+@jax_util.as_hashable_function(closure=('harvest', out_axes_thunk))
+def new_out_axes_thunk():
+out_axes = out_axes_thunk()
+assert all(out_axis == 0 for out_axis in out_axes)
+return (0,) * out_tree().num_leaves
new_params = dict(
params,
-in_axes=(0,) * len(tree_util.tree_leaves(plants)) +
-params['in_axes'])
+in_axes=(0,) * len(tree_util.tree_leaves(plants)) + params['in_axes'],
+out_axes_thunk=new_out_axes_thunk)
else:
new_params = dict(params)
all_args, all_tree = tree_util.tree_flatten((plants, vals))
num_plants = len(all_args) - len(vals)
if 'donated_invars' in params:
new_params['donated_invars'] = ((False,) * num_plants
+ params['donated_invars'])
-f, aux = harvest_eval(f, self, context.settings, all_tree)
+f, out_tree = harvest_eval(f, self, context.settings, all_tree)
out_flat = primitive.bind(
f, *all_args, **new_params, name=jax_util.wrap_name(name, 'harvest'))
-out_tree = aux()
-out, reaps = tree_util.tree_unflatten(out_tree, out_flat)
+out, reaps = tree_util.tree_unflatten(out_tree(), out_flat)
out_tracers = safe_map(self.pure, out)
reap_tracers = tree_util.tree_map(self.pure, reaps)
if primitive is nest_p and reap_tracers:
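The `as_hashable_function` decorator above (and the `jax_util.HashableFunction` call in the next file) exists because JAX caches traced computations keyed in part on transform parameters such as `out_axes_thunk`; a freshly created closure would never compare equal across calls and would defeat that cache. Below is a minimal sketch of the pattern, assuming roughly this shape for the JAX utility; the actual implementation may differ.

class HashableFunction:
  # Wraps a callable so hashing and equality follow an explicit `closure`
  # key rather than object identity.
  def __init__(self, fn, closure):
    self.fn = fn
    self.closure = closure

  def __call__(self, *args, **kwargs):
    return self.fn(*args, **kwargs)

  def __hash__(self):
    return hash(self.closure)

  def __eq__(self, other):
    return (isinstance(other, HashableFunction) and
            self.closure == other.closure)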
10 changes: 8 additions & 2 deletions spinoffs/oryx/oryx/core/interpreters/inverse/core.py
@@ -178,7 +178,7 @@ def wrapped(*args, **kwargs):
flat_incells = [InverseAndILDJ.unknown(aval) for aval in flat_forward_avals]
flat_outcells = safe_map(InverseAndILDJ.new, flat_args)
env = propagate.propagate(InverseAndILDJ, ildj_registry, jaxpr.jaxpr,
-flat_constcells, flat_incells, flat_outcells)
+flat_constcells, flat_incells, flat_outcells) # pytype: disable=wrong-arg-types
flat_incells = [env.read(invar) for invar in jaxpr.jaxpr.invars]
if any(not flat_incell.top() for flat_incell in flat_incells):
raise ValueError('Cannot invert function.')
@@ -332,7 +332,7 @@ def hop_inverse_rule(prim):
def initial_ildj(incells, outcells, *, jaxpr, num_consts, **_):
const_cells, incells = jax_util.split_list(incells, [num_consts])
env = propagate.propagate(InverseAndILDJ, ildj_registry, jaxpr, const_cells,
-incells, outcells)
+incells, outcells) # pytype: disable=wrong-arg-types
new_incells = [env.read(invar) for invar in jaxpr.invars]
new_outcells = [env.read(outvar) for outvar in jaxpr.outvars]
return const_cells + new_incells, new_outcells, None
@@ -377,6 +377,12 @@ def remove_slice(cell):
new_params = dict(params, in_axes=new_in_axes)
if 'donated_invars' in params:
new_params['donated_invars'] = (False,) * len(flat_vals)
+if 'out_axes' in params:
+assert all(out_axis == 0 for out_axis in params['out_axes'])
+new_params['out_axes_thunk'] = jax_util.HashableFunction(
+lambda: (0,) * aux().num_leaves,
+closure=('ildj', params['out_axes']))
+del new_params['out_axes']
subenv_vals = prim.bind(f, *flat_vals, **new_params)
subenv_tree = aux()
subenv = tree_util.tree_unflatten(subenv_tree, subenv_vals)
8 changes: 8 additions & 0 deletions spinoffs/oryx/oryx/core/interpreters/inverse/inverse_test.py
@@ -249,6 +249,14 @@ def f(x, y):
onp.testing.assert_allclose(y, np.ones(2))
onp.testing.assert_allclose(ildj_, 0., atol=1e-6, rtol=1e-6)

+def test_inverse_of_reshape(self):
+def f(x):
+return np.reshape(x, (4,))
+f_inv = core.inverse_and_ildj(f, np.ones((2, 2)))
+x, ildj_ = f_inv(np.ones(4))
+onp.testing.assert_allclose(x, np.ones((2, 2)))
+onp.testing.assert_allclose(ildj_, 0.)

def test_sigmoid_ildj(self):
def naive_sigmoid(x):
# This is the default JAX implementation of sigmoid.
3 changes: 1 addition & 2 deletions spinoffs/oryx/oryx/core/interpreters/inverse/rules.py
@@ -166,9 +166,8 @@ def reshape_ildj(incells, outcells, **params):
))], None
elif outcell.top() and not incell.top():
val = outcell.val
-ndslice = NDSlice.new(np.reshape(val, incell.aval.shape))
new_incells = [
-InverseAndILDJ(incell.aval, [ndslice])
+InverseAndILDJ.new(np.reshape(val, incell.aval.shape))
]
return new_incells, outcells, None
return incells, outcells, None
46 changes: 32 additions & 14 deletions spinoffs/oryx/oryx/core/interpreters/unzip.py
@@ -288,19 +288,29 @@ def handle_call_primitive(self, call_primitive, f, tracers, params, is_map):
in_pvals = [pval if pval.is_known() or in_axis is None else
unknown(mapped_aval(params['axis_size'], in_axis, pval[0]))
for pval, in_axis in zip(in_pvals, params['in_axes'])]
+out_axes_thunk = params['out_axes_thunk']
+@jax_util.as_hashable_function(closure=('unzip', out_axes_thunk))
+def new_out_axes_thunk():
+out_axes = out_axes_thunk()
+assert all(out_axis == 0 for out_axis in out_axes)
+_, num_outputs, _ = aux()
+return (0,) * num_outputs
+new_params = dict(params, out_axes_thunk=new_out_axes_thunk)
+else:
+new_params = params
pvs, in_consts = jax_util.unzip2(t.pval for t in tracers)
keys = tuple(t.is_key() for t in tracers)
new_settings = UnzipSettings(settings.tag, call_primitive in block_registry)
fun, aux = unzip_eval(f, self, keys, tuple(pvs), new_settings)
-out_flat = call_primitive.bind(fun, *in_consts, **params)
-success, results = aux()
+out_flat = call_primitive.bind(fun, *in_consts, **new_params)
+success, _, results = aux()
if not success:
out_pvs, out_keys, jaxpr, env = results
out_pv_consts, consts = jax_util.split_list(out_flat, [len(out_pvs)])
-out_tracers = self._bound_output_tracers(call_primitive, params, jaxpr,
-consts, env, tracers, out_pvs,
-out_pv_consts, out_keys, name,
-is_map)
+out_tracers = self._bound_output_tracers(call_primitive, new_params,
+jaxpr, consts, env, tracers,
+out_pvs, out_pv_consts,
+out_keys, name, is_map)
return out_tracers
init_name = jax_util.wrap_name(name, 'init')
apply_name = jax_util.wrap_name(name, 'apply')
@@ -319,15 +329,16 @@ def handle_call_primitive(self, call_primitive, f, tracers, params, is_map):
[len(apply_pvs)])

variable_tracers = self._bound_output_tracers(
-call_primitive, params, init_jaxpr, init_consts, init_env, key_tracers,
-init_pvs, init_pv_consts, [True] * len(init_pvs), init_name, is_map)
+call_primitive, new_params, init_jaxpr, init_consts, init_env,
+key_tracers, init_pvs, init_pv_consts, [True] * len(init_pvs),
+init_name, is_map)

unflat_variables = tree_util.tree_unflatten(variable_tree, variable_tracers)
if call_primitive is harvest.nest_p:
variable_dict = harvest.sow(
dict(safe_zip(variable_names, unflat_variables)),
tag=settings.tag,
-name=params['scope'],
+name=new_params['scope'],
mode='strict')
unflat_variables = tuple(variable_dict[name] for name in variable_names)
else:
@@ -342,7 +353,7 @@ def handle_call_primitive(self, call_primitive, f, tracers, params, is_map):
variable_tracers = tree_util.tree_leaves(unflat_variables)

out_tracers = self._bound_output_tracers(
-call_primitive, params, apply_jaxpr, apply_consts, apply_env,
+call_primitive, new_params, apply_jaxpr, apply_consts, apply_env,
variable_tracers + abstract_tracers, apply_pvs, apply_pv_consts,
apply_keys, apply_name, is_map)
return out_tracers
@@ -365,6 +376,11 @@ def _bound_output_tracers(self, primitive, params, jaxpr, consts, env,
tuple(v for v, t in zip(params['donated_invars'], in_tracers)
if not t.pval.is_known()))
new_params['donated_invars'] = new_donated_invars
+if is_map:
+out_axes = params['out_axes_thunk']()
+assert all(out_axis == 0 for out_axis in out_axes)
+new_params['out_axes'] = (0,) * len(out_tracers)
+del new_params['out_axes_thunk']
eqn = pe.new_eqn_recipe(
tuple(const_tracers + env_tracers + in_tracers), out_tracers, primitive,
new_params, source_info_util.current()) # pytype: disable=wrong-arg-types
@@ -442,14 +458,16 @@ def unzip_eval_wrapper(pvs, *consts):
out = (
tuple(init_pv_consts) + tuple(init_consts) + tuple(apply_pv_consts) +
tuple(apply_consts))
-yield out, (success, ((init_pvs, len(init_consts), apply_pvs),
-(init_jaxpr, apply_jaxpr), (init_env,
-apply_env), metadata))
+yield out, (success, len(out),
+((init_pvs, len(init_consts), apply_pvs),
+(init_jaxpr, apply_jaxpr),
+(init_env, apply_env),
+metadata))
else:
jaxpr, (out_pvals, out_keys, consts, env) = result
out_pvs, out_consts = jax_util.unzip2(out_pvals)
out = tuple(out_consts) + tuple(consts)
-yield out, (success, (out_pvs, out_keys, jaxpr, env))
+yield out, (success, len(out), (out_pvs, out_keys, jaxpr, env))


@lu.transformation
2 changes: 1 addition & 1 deletion spinoffs/oryx/oryx/core/state/module.py
@@ -168,4 +168,4 @@ def variables(self) -> Dict[str, Any]:

@ppl.log_prob.register(Module)
def module_log_prob(module, *args, **kwargs):
-return log_prob.log_prob(module, *args, **kwargs)
+return log_prob.log_prob(module, *args, **kwargs) # pytype: disable=wrong-arg-count