problems with tfp.sts.fit_with_hmc #348
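I've also just noticed that tfp.sts.Seasonal doesn't have a parameter for the observation noise scale, so I'm really not sure what it is doing now. Should it raise an error when I call tfp.sts.fit_with_hmc?
>>> print([p.name for p in seasonal.parameters])
['drift_scale']
>>> print([p.name for p in tfp.sts.Sum([seasonal]).parameters])
['observation_noise_scale', 'seasonal/_drift_scale']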
Hi Jeff, thanks for the questions and feedback!
Taking 2+ minutes to run a few HMC steps is definitely a bug; we'll look
into that. At a high level, there are still quite a few quirks with fitting
STS models in TF2 -- we're working to get everything fixed up, but for the
next few weeks you'll likely have a better experience with TF1 and graph
mode.
Re your second Q, in the current setup, observation noise is added only by
the Sum component (this is admittedly a little bit non-intuitive). The
Seasonal component by itself represents a noise-free seasonal process, and
similarly for the other components. Fitting a 'bare' component by itself
might work, if your data are consistent with a noise-free process (even
then it's possible you'll hit numerical issues), but usually you'll want to
wrap it in a Sum so that it gets a noise model, even if it's just a trivial
one-component sum.
Dave
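To make the split of responsibilities concrete, here is a minimal NumPy sketch of a simplified generative process. It is an illustration under stated assumptions, not TFP's implementation: a latent seasonal signal whose effects drift with scale `drift_scale` plays the role of the bare Seasonal component, and independent Gaussian noise with scale `observation_noise_scale` plays the role of the observation model that Sum adds. The function name `simulate_seasonal_sum` and the rule that each season's effect takes one random-walk step at the end of that season are hypothetical simplifications.

```python
import numpy as np


def simulate_seasonal_sum(num_timesteps, num_seasons, num_steps_per_season,
                          drift_scale, observation_noise_scale,
                          initial_effects, seed=0):
    """Hypothetical simplified simulator; not tfp.sts.Seasonal's exact model.

    Returns (latent, observed): the noise-free seasonal signal, and that
    signal plus i.i.d. Gaussian observation noise.
    """
    rng = np.random.default_rng(seed)
    effects = np.array(initial_effects, dtype=float)
    latent = np.empty(num_timesteps)
    for t in range(num_timesteps):
        season = (t // num_steps_per_season) % num_seasons
        latent[t] = effects[season]
        # Assumption: a season's effect drifts once, when that season ends.
        if (t + 1) % num_steps_per_season == 0:
            effects[season] += rng.normal(0.0, drift_scale)
    # Observation model: the part that Sum contributes in TFP.
    observed = latent + rng.normal(0.0, observation_noise_scale,
                                   size=num_timesteps)
    return latent, observed


latent, observed = simulate_seasonal_sum(
    num_timesteps=100, num_seasons=2, num_steps_per_season=10,
    drift_scale=1.0, observation_noise_scale=0.2,
    initial_effects=[10.0, 10.0])
```

Setting `observation_noise_scale=0.0` makes `observed` identical to `latent`, which corresponds to fitting a bare component: every data point must then be explained exactly by the seasonal process itself.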
@davmre thanks, that's really helpful! I'll have another try tomorrow and report back.
I've tried to re-organise the code a bit (mainly to help with my understanding) and run it with TF1, but this also errors :-(. Here is the code:
import tensorflow as tf
import tensorflow_probability as tfp
import time as tm

tf.random.set_random_seed(42)
tfd = tfp.distributions

num_observed_seasons = 10
num_seasons = 2
num_steps_per_season = 10
num_timesteps = num_observed_seasons * num_steps_per_season

# initial prior on latent state
initial_state_prior = tfd.MultivariateNormalLinearOperator(
    loc=tf.fill([num_seasons], 10.0),
    scale=tf.linalg.LinearOperatorScaledIdentity(num_seasons, 5.0),
)

# parameter priors
seasonal_drift_scale_prior = tfd.HalfNormal(scale=1.0)
observation_noise_scale_prior = tfd.HalfNormal(scale=1.0)

# parameter values
seasonal_drift_scale = 1.0
observation_noise_scale = 0.2

seasonal_component = tfp.sts.Seasonal(
    num_seasons=num_seasons,
    num_steps_per_season=num_steps_per_season,
    drift_scale_prior=seasonal_drift_scale_prior,
    initial_effect_prior=initial_state_prior,
    name="seasonal",
)
model = tfp.sts.Sum(
    [seasonal_component], observation_noise_scale_prior=observation_noise_scale_prior
)
state_space_model = model.make_state_space_model(
    num_timesteps, [observation_noise_scale, seasonal_drift_scale]
)
sampled_time_series = state_space_model.sample()

mcmc, kernel_results = tfp.sts.fit_with_hmc(
    model=model,
    observed_time_series=sampled_time_series,
    num_results=3,
    num_warmup_steps=2,
    num_variational_steps=1,
)

with tf.Session() as sess:
    start = tm.time()
    mcmc_, kernel_results_ = sess.run([mcmc, kernel_results])
    end = tm.time()
which outputs:
Traceback (most recent call last):
File "/home/jeff/.virtualenvs/tf1/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1337, in _do_call
return fn(*args)
File "/home/jeff/.virtualenvs/tf1/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1322, in _run_fn
options, feed_dict, fetch_list, target_list, run_metadata)
File "/home/jeff/.virtualenvs/tf1/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1410, in _call_tf_sessionrun
run_metadata)
tensorflow.python.framework.errors_impl.FailedPreconditionError: Error while reading resource variable make_variational/observation_noise_scale_loc from Container: localhost. This could mean that the variable was uninitialized. Not found: Container localhost does not exist. (Could not find resource: localhost/make_variational/observation_noise_scale_loc)
[[{{node fit_with_hmc/make_variational/build_factored_variational_loss/Normal/ReadVariableOp}}]]
[[fit_with_hmc/mcmc_sample_chain/scan/while/smart_for_loop/while/transformed_kernel_one_step/mh_one_step/hmc_kernel_one_step/while/LoopCond/_247]]
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/home/jeff/workspace/misc/fit_with_hmc_issue.py", line 61, in <module>
mcmc_, kernel_results_ = sess.run([mcmc, kernel_results])
File "/home/jeff/.virtualenvs/tf1/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 932, in run
run_metadata_ptr)
File "/home/jeff/.virtualenvs/tf1/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1155, in _run
feed_dict_tensor, options, run_metadata)
File "/home/jeff/.virtualenvs/tf1/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1331, in _do_run
run_metadata)
File "/home/jeff/.virtualenvs/tf1/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1351, in _do_call
raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.FailedPreconditionError: Error while reading resource variable make_variational/observation_noise_scale_loc from Container: localhost. This could mean that the variable was uninitialized. Not found: Container localhost does not exist. (Could not find resource: localhost/make_variational/observation_noise_scale_loc)
[[node fit_with_hmc/make_variational/build_factored_variational_loss/Normal/ReadVariableOp (defined at /home/jeff/.virtualenvs/tf1/lib/python3.6/site-packages/tensorflow_probability/python/distributions/normal.py:132) ]]
[[fit_with_hmc/mcmc_sample_chain/scan/while/smart_for_loop/while/transformed_kernel_one_step/mh_one_step/hmc_kernel_one_step/while/LoopCond/_247]]
Errors may have originated from an input operation.
Input Source operations connected to node fit_with_hmc/make_variational/build_factored_variational_loss/Normal/ReadVariableOp:
make_variational/observation_noise_scale_loc (defined at /home/jeff/.virtualenvs/tf1/lib/python3.6/site-packages/tensorflow_probability/python/sts/fitting.py:110)
Original stack trace for 'fit_with_hmc/make_variational/build_factored_variational_loss/Normal/ReadVariableOp':
File "<stdin>", line 1, in <module>
File "/home/jeff/workspace/misc/fit_with_hmc_issue.py", line 51, in <module>
num_variational_steps=1,
File "/home/jeff/.virtualenvs/tf1/lib/python3.6/site-packages/tensorflow_probability/python/sts/fitting.py", line 522, in fit_with_hmc
_, variational_distributions = make_variational()
File "/home/jeff/.virtualenvs/tf1/lib/python3.6/site-packages/tensorflow/python/ops/template.py", line 380, in __call__
return self._call_func(args, kwargs)
File "/home/jeff/.virtualenvs/tf1/lib/python3.6/site-packages/tensorflow/python/ops/template.py", line 343, in _call_func
result = self._func(*args, **kwargs)
File "/home/jeff/.virtualenvs/tf1/lib/python3.6/site-packages/tensorflow_probability/python/sts/fitting.py", line 520, in make_variational
init_batch_shape=chain_batch_shape, seed=seed())
File "/home/jeff/.virtualenvs/tf1/lib/python3.6/site-packages/tensorflow_probability/python/sts/fitting.py", line 279, in build_factored_variational_loss
q = _build_trainable_posterior(param, initial_loc_fn=initial_loc_fn)
File "/home/jeff/.virtualenvs/tf1/lib/python3.6/site-packages/tensorflow_probability/python/sts/fitting.py", line 116, in _build_trainable_posterior
q = tfd.Normal(loc=loc, scale=scale)
File "/home/jeff/.virtualenvs/tf1/lib/python3.6/site-packages/tensorflow_probability/python/distributions/normal.py", line 132, in __init__
loc = tf.convert_to_tensor(loc, name="loc", dtype=dtype)
File "/home/jeff/.virtualenvs/tf1/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 1053, in convert_to_tensor
return convert_to_tensor_v2(value, dtype, preferred_dtype, name)
File "/home/jeff/.virtualenvs/tf1/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 1111, in convert_to_tensor_v2
as_ref=False)
File "/home/jeff/.virtualenvs/tf1/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 1190, in internal_convert_to_tensor
ret = conversion_func(value, dtype=dtype, name=name, as_ref=as_ref)
File "/home/jeff/.virtualenvs/tf1/lib/python3.6/site-packages/tensorflow/python/ops/resource_variable_ops.py", line 1435, in _dense_var_to_tensor
return var._dense_var_to_tensor(dtype=dtype, name=name, as_ref=as_ref) # pylint: disable=protected-access
File "/home/jeff/.virtualenvs/tf1/lib/python3.6/site-packages/tensorflow/python/ops/resource_variable_ops.py", line 1390, in _dense_var_to_tensor
return self.value()
File "/home/jeff/.virtualenvs/tf1/lib/python3.6/site-packages/tensorflow/python/ops/resource_variable_ops.py", line 789, in value
return self._read_variable_op()
File "/home/jeff/.virtualenvs/tf1/lib/python3.6/site-packages/tensorflow/python/ops/resource_variable_ops.py", line 865, in _read_variable_op
self._dtype)
File "/home/jeff/.virtualenvs/tf1/lib/python3.6/site-packages/tensorflow/python/ops/gen_resource_variable_ops.py", line 587, in read_variable_op
"ReadVariableOp", resource=resource, dtype=dtype, name=name)
File "/home/jeff/.virtualenvs/tf1/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py", line 788, in _apply_op_helper
op_def=op_def)
File "/home/jeff/.virtualenvs/tf1/lib/python3.6/site-packages/tensorflow/python/util/deprecation.py", line 507, in new_func
return func(*args, **kwargs)
File "/home/jeff/.virtualenvs/tf1/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 3545, in create_op
op_def=op_def)
File "/home/jeff/.virtualenvs/tf1/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 1969, in __init__
self._traceback = tf_stack.extract_stack()
Some version info:
>>> print(tf.__version__)
1.14.1-dev20190405
>>> print(tfp.__version__)
0.6.0-dev
Argh, yeah, the downside with running in graph mode is that you have to
deal with graph mode annoyances. :-( In this case you need
sess.run(tf.global_variables_initializer())
before the sess.run([mcmc, kernel_results]) that runs the sampler.
That's because `fit_with_hmc` creates some internal variables that need to
be initialized (these define the variational model used to
initialize/precondition HMC, as well as the step size adaptation, though
the latter should go variable-less soon).
Dave
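The failure mode described above can be reproduced without STS at all. The sketch below is a minimal stand-in, assuming TensorFlow 2.x with the `tf.compat.v1` graph/session API available; the toy variable named `loc` is hypothetical and only plays the role of `fit_with_hmc`'s internal variables such as `make_variational/observation_noise_scale_loc`. Evaluating anything that reads it before the initializer has run raises the same `FailedPreconditionError`; running `tf.compat.v1.global_variables_initializer()` first fixes it.

```python
import tensorflow as tf


def read_doubled_variable(run_initializer):
    """Builds a tiny TF1-style graph and evaluates it in a Session."""
    graph = tf.Graph()
    with graph.as_default():  # graph mode, as in TF1
        # Hypothetical stand-in for a variable created inside fit_with_hmc.
        loc = tf.Variable(10.0, name="loc")
        doubled = loc * 2.0
        init_op = tf.compat.v1.global_variables_initializer()
    with tf.compat.v1.Session(graph=graph) as sess:
        if run_initializer:
            sess.run(init_op)  # must run before anything reads `loc`
        return sess.run(doubled)


print(read_doubled_variable(run_initializer=True))  # 20.0
try:
    read_doubled_variable(run_initializer=False)
except tf.errors.FailedPreconditionError as err:
    print("uninitialized:", type(err).__name__)
```

In the TF1 script above, the corresponding change is to call `sess.run(tf.global_variables_initializer())` inside the `with tf.Session() as sess:` block, before `sess.run([mcmc, kernel_results])`.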
Ah, thanks! Sorry, I've been using TF 2 for a few weeks and totally forgot about the global variables initializer. I have that running in about 8 seconds in graph mode, and the default number of iterations takes about 2 minutes and gives sensible-looking results, too :-) I'll leave this open for now since this doesn't seem to work with TF2 +
@davmre The reported problem with @tf.function still happens. Are we expecting to fix that? Also, @jeffpollock9: do you need this to work in TF2 under @tf.function? For the record, here's a slightly reduced reproduction:
Hi @axch, thanks for the example; it is much nicer than mine. I'm travelling and not able to test anything at the moment, but I think this needs to work with @tf.function, otherwise it won't be very usable with TF2 (as it is really slow).
I'm having similar issues, and seeing as we are quite a few months on from the last response, I thought I would post. I'm using tfp.sts.fit_with_hmc and coming across the following issues:
@tf.function(autograph=False)
samples, kernel = fit()
samples, kernel_results = sts.fit_with_hmc(model, y_label, num_results=1, num_warmup_steps=1)
Worth noting that I have a time series with 200k observations, and a large proportion of my y label values are zero (it's basically a Poisson count distribution). However, I don't think this is related to the run time issue. I'm running in Google Colab using a GPU. TF/TFP versions below:
Any advice?
I'm not sure your runtime issues are so adding an extra layer of wrapping externally shouldn't speed it up any further. (Put differently, it should already be fast even without additional wrapping.) The Kalman filtering algorithm that we use to compute each iteration of HMC fitting is sequential, so the runtime scales with the total length of the series (not the number of nonzero entries). I'd expect anything with 200k points to be inherently pretty slow. I've been working on a parallel implementation, which may make this substantially faster in the next ~month. That said, hours to run a couple of MCMC steps seems like a lot even with the current code. Can you measure how long it takes to evaluate the likelihood of your series just once? E.g.,
@tf.function(autograph=False)
def series_log_prob(param_samples):
    d = model.make_state_space_model(
        num_timesteps=len(y_label),
        param_vals=param_samples)
    return d.log_prob(y_label)

lp = series_log_prob([p.prior.sample() for p in model.parameters])
My hypothesis is that this will already be quite slow; if it's not, that would indicate something going wrong in the fitting code. Alternately, if you can replicate the issue in a colab, I'd be happy to take a look. For what it's worth, the immediate error you're seeing (
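To illustrate why the filtering cost grows with the full series length, here is a minimal NumPy sketch of the Kalman filter log-likelihood for a local-level model (a random-walk level plus Gaussian observation noise). This is an assumed, deliberately simple model rather than TFP's implementation, and the function name `local_level_log_likelihood` and its parameters are hypothetical. Each observation needs one predict/update step that depends on the previous step's result, so the loop runs once per time step regardless of how many values are zero.

```python
import math
import time

import numpy as np


def local_level_log_likelihood(y, level_scale, observation_noise_scale,
                               initial_mean=0.0, initial_scale=1.0):
    """Kalman filter log p(y) for the (hypothetical) local-level model:
    level[0] ~ N(initial_mean, initial_scale^2),
    level[t+1] = level[t] + N(0, level_scale^2),
    y[t] = level[t] + N(0, observation_noise_scale^2).
    """
    q = level_scale ** 2
    r = observation_noise_scale ** 2
    mean, var = initial_mean, initial_scale ** 2  # predictive for level[0]
    log_lik = 0.0
    for obs in y:  # strictly sequential: step t needs the result of step t-1
        pred_var = var + r  # predictive variance of y[t]
        resid = obs - mean
        log_lik += -0.5 * (math.log(2.0 * math.pi * pred_var)
                           + resid ** 2 / pred_var)
        gain = var / pred_var  # update the level with y[t]
        mean = mean + gain * resid
        var = (1.0 - gain) * var
        var = var + q  # predict level[t+1]
    return log_lik


# Rough timing: cost should grow about linearly with n (machine-dependent).
rng = np.random.default_rng(0)
for n in (10_000, 100_000):
    y = np.cumsum(rng.normal(size=n)) + rng.normal(scale=0.5, size=n)
    start = time.perf_counter()
    local_level_log_likelihood(y, 1.0, 0.5)
    print(n, f"{time.perf_counter() - start:.3f}s")
```

HMC evaluates the log-likelihood gradient at every leapfrog step, so whatever a single pass like this costs for a 200k-point series is multiplied many times over during fitting.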
Hello, I am trying to use impute_missing_values, but it needs parameter_samples, so I just followed the code as follows:
The data I am using have this:
The null count is 3809:
It's been running for more than one hour but still hasn't finished. I have it in a Google Colaboratory notebook.
I'm having a few issues with tfp.sts.fit_with_hmc, so I thought I would reach out. I am using a very recent tfp-nightly and TensorFlow 2 alpha.
Firstly, the function signature has num_variational_steps=150, but the documentation has "Default value: 200".
Secondly, I've tried a simple model to try and understand things, but tfp.sts.fit_with_hmc is taking over 2 minutes for 3 MCMC samples with 2 warmup steps and 1 variational step. Is this expected? forward_filter only takes about 0.6 seconds, so this surprised me.
Here is the code - am I doing something stupid?
which outputs (alongside many warnings):
Finally, in an attempt to speed this up, I tried wrapping tfp.sts.fit_with_hmc with @tf.function, but this gives the following error:
output:
Many thanks for any replies and all your amazing work on TensorFlow Probability!