Skip to content

Commit

Permalink
Use dtype_util inside LGSSM to infer dtype.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 472602722
  • Loading branch information
srvasude authored and tensorflower-gardener committed Sep 7, 2022
1 parent 24301c5 commit 14db88b
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 11 deletions.
13 changes: 11 additions & 2 deletions tensorflow_probability/python/distributions/linear_gaussian_ssm.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,8 +393,17 @@ def __init__(self,
mask, dtype_hint=tf.bool, name='mask')
self._experimental_parallelize = experimental_parallelize

# TODO(b/78475680): Friendly dtype inference.
dtype = initial_state_prior.dtype
dtype_list = [initial_state_prior,
observation_matrix,
transition_matrix,
transition_noise,
observation_noise]

# Infer dtype from time invariant objects. This list will be non-empty
# since it will always include `initial_state_prior`.
dtype = dtype_util.common_dtype(
list(filter(lambda x: not callable(x), dtype_list)),
dtype_hint=tf.float32)

# Internally, the transition and observation matrices are
# canonicalized as callables returning a LinearOperator. This
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,20 +55,23 @@ def _build_iid_normal_model(self, num_timesteps, latent_size,
observation_variance):
"""Build a model whose outputs are IID normal by construction."""

transition_variance = self._build_placeholder(transition_variance)
observation_variance = self._build_placeholder(observation_variance)
transition_variance = self._build_placeholder(
self.dtype(transition_variance))
observation_variance = self._build_placeholder(
self.dtype(observation_variance))

# Use orthogonal matrices to project a (potentially
# high-dimensional) latent space of IID normal variables into a
# low-dimensional observation that is still IID normal.
random_orthogonal_matrix = lambda: np.linalg.qr(
np.random.randn(latent_size, latent_size))[0][:observation_size, :]
observation_matrix = self._build_placeholder(random_orthogonal_matrix())
observation_matrix = self._build_placeholder(
random_orthogonal_matrix().astype(self.dtype))

model = lgssm.LinearGaussianStateSpaceModel(
num_timesteps=num_timesteps,
transition_matrix=self._build_placeholder(
np.zeros([latent_size, latent_size])),
np.zeros([latent_size, latent_size]).astype(self.dtype)),
transition_noise=mvn_diag.MultivariateNormalDiag(
scale_diag=tf.sqrt(transition_variance) *
tf.ones([latent_size], dtype=self.dtype)),
Expand Down Expand Up @@ -389,23 +392,27 @@ def testExcessiveConcretizationOfParams(self):
transition_std = 3.0
observation_std = 5.0

dtype = np.float32

num_timesteps = tfp_hps.defer_and_count_usage(
tf.Variable(1, name='num_timesteps'))
transition_matrix = tfp_hps.defer_and_count_usage(
tf.Variable(np.eye(latent_size), name='transition_matrix'))
tf.Variable(
np.eye(latent_size).astype(dtype), name='transition_matrix'))
transition_noise_scale = tfp_hps.defer_and_count_usage(
tf.Variable(
np.full([latent_size], transition_std),
np.full([latent_size], transition_std).astype(dtype),
name='transition_noise_scale'))
observation_matrix = tfp_hps.defer_and_count_usage(
tf.Variable(np.eye(latent_size), name='observation_matrix'))
tf.Variable(
np.eye(latent_size).astype(dtype), name='observation_matrix'))
observation_noise_scale = tfp_hps.defer_and_count_usage(
tf.Variable(
np.full([latent_size], observation_std),
np.full([latent_size], observation_std).astype(dtype),
name='observation_noise_scale'))
initial_state_prior_scale = tfp_hps.defer_and_count_usage(
tf.Variable(
np.full([latent_size], observation_std),
np.full([latent_size], observation_std).astype(dtype),
name='initial_state_prior_scale'))

model = lgssm.LinearGaussianStateSpaceModel(
Expand Down

0 comments on commit 14db88b

Please sign in to comment.