Skip to content

Commit

Permalink
Change stats/, mcmc/, monte_carlo/ optimizer/ and vi/ to use most spe…
Browse files Browse the repository at this point in the history
…cific imports.

  - Modify moving_stats_test to work in the numpy backend (via stateless sampling).

PiperOrigin-RevId: 471616046
  • Loading branch information
srvasude authored and tensorflower-gardener committed Sep 1, 2022
1 parent de75e2d commit 0990bc9
Show file tree
Hide file tree
Showing 57 changed files with 2,234 additions and 1,892 deletions.
1 change: 1 addition & 0 deletions tensorflow_probability/python/distributions/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -2284,6 +2284,7 @@ multi_substrate_py_library(
":kullback_leibler",
# numpy dep,
# tensorflow dep,
"//tensorflow_probability/python/bijectors:invert",
"//tensorflow_probability/python/bijectors:sigmoid",
"//tensorflow_probability/python/bijectors:softmax_centered",
"//tensorflow_probability/python/bijectors:square",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ multi_substrate_py_library(
deps = [
# numpy dep,
# tensorflow dep,
"//tensorflow_probability/python/distributions:mvn_tril",
"//tensorflow_probability/python/internal:prefer_static",
"//tensorflow_probability/python/internal:samplers",
"//tensorflow_probability/python/math:scan_associative",
],
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@
import numpy as np
import tensorflow.compat.v2 as tf

from tensorflow_probability.python import math as tfp_math
from tensorflow_probability.python.distributions import mvn_tril
from tensorflow_probability.python.internal import prefer_static as ps
from tensorflow_probability.python.internal import samplers
from tensorflow_probability.python.math.scan_associative import scan_associative


__all__ = ['kalman_filter',
Expand Down Expand Up @@ -230,13 +230,14 @@ def sample_walk(transition_matrix,
observation_mean=observation_mean))

s1, s2, s3 = samplers.split_seed(seed, n=3)
updates = tfp_math.scan_associative(
updates = scan_associative(
combine_walk,
AffineUpdate(transition_matrix=time_dep.transition_matrix[:-1],
mean=mvn_tril.MultivariateNormalTriL(
loc=time_dep.transition_mean[:-1],
scale_tril=time_dep.transition_scale_tril[:-1]
).sample(seed=s1)))
AffineUpdate(
transition_matrix=time_dep.transition_matrix[:-1],
mean=mvn_tril.MultivariateNormalTriL(
loc=time_dep.transition_mean[:-1],
scale_tril=time_dep.transition_scale_tril[:-1]).sample(
seed=s1)))
x0 = mvn_tril.MultivariateNormalTriL(
loc=time_indep.initial_mean,
scale_tril=time_indep.initial_scale_tril).sample(seed=s2)
Expand Down Expand Up @@ -596,10 +597,9 @@ def kalman_filter(transition_matrix,
mask=observation.mask)

# Run Kalman filter.
filtered = tfp_math.scan_associative(combine_filter_elements,
filter_elements(time_indep,
time_dep,
observation))
filtered = scan_associative(
combine_filter_elements,
filter_elements(time_indep, time_dep, observation))
filtered_means = filtered.posterior_mean
filtered_covs = filtered.posterior_cov
log_likelihoods = None
Expand Down
1 change: 1 addition & 0 deletions tensorflow_probability/python/experimental/vi/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ multi_substrate_py_library(
"//tensorflow_probability/python/distributions:independent",
"//tensorflow_probability/python/distributions:joint_distribution_auto_batched",
"//tensorflow_probability/python/distributions:joint_distribution_coroutine",
"//tensorflow_probability/python/distributions:markov_chain",
"//tensorflow_probability/python/distributions:sample",
"//tensorflow_probability/python/distributions:transformed_distribution",
"//tensorflow_probability/python/distributions:truncated_normal",
Expand Down
1 change: 1 addition & 0 deletions tensorflow_probability/python/internal/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -607,6 +607,7 @@ multi_substrate_py_test(
# optax dep,
# tensorflow dep,
"//tensorflow_probability",
"//tensorflow_probability/python/experimental/util",
"//tensorflow_probability/python/internal:test_util",
"//tensorflow_probability/python/internal:trainable_state_util",
],
Expand Down
Loading

0 comments on commit 0990bc9

Please sign in to comment.