make jit model args work for chains #694
@fehiepsi - Seems like you'll need to merge changes from master.
neerajprad left a comment
Thanks for cleaning this up, a big improvement over what we had!
        states[self._sample_field] = lax.map(postprocess_fn, states[self._sample_field])
        return states, last_state

    def _single_chain_jit_args(self, init, collect_fields):
Hmm, I don't remember why we needed this (or whether we needed it at the time and JAX has changed since). I also noticed that we can remove `sample_fn_jit_args` and `sample_fn_nojit_args`, use `partial(sampler_sample, args=, kwargs=)` directly in `_get_cached_fn(..)`, and simplify the state to be an `HMCState` instead of a tuple. Let me know if you want me to push this to your branch.
I think this method is used to resolve `*init` vs `init` as an argument for `single_chain_mcmc`. I don't remember why we need to distinguish `jit_args` from `nojit_args` either. For `cached_fn`, please feel free to push the simplification to this branch. I guess it is needed because if `args` or `kwargs` change, the cached fn also needs to change.
Sorry, I was confused. I don't think we can avoid cycling the args, kwargs unfortunately.
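One reading of the thread above, as a minimal pure-Python sketch rather than NumPyro's actual code: binding `args` and `kwargs` with `partial` bakes them into the compiled function, so the cache needs a new entry whenever they change, while passing them through as inputs lets one cached function serve every call. Everything here is a hypothetical stand-in: `fake_compile` only counts calls (it is not `jax.jit`), and the body of `sampler_sample` and the cache layout are invented for illustration.

```python
from functools import partial

stats = {"compiles": 0}  # how many times the stand-in "compiler" has run


def fake_compile(fn):
    """Hypothetical stand-in for jit: counts compilations, returns fn unchanged."""
    stats["compiles"] += 1
    return fn


def sampler_sample(state, args, kwargs):
    # Toy sampler step (assumption): shift the state by the model arguments.
    return state + sum(args) + kwargs.get("offset", 0)


def make_closed_cache():
    cache = {}

    def get(args, kwargs):
        # Arguments are bound via partial, so they must be part of the cache key.
        key = (args, tuple(sorted(kwargs.items())))
        if key not in cache:
            cache[key] = fake_compile(
                partial(sampler_sample, args=args, kwargs=kwargs)
            )
        return cache[key]

    return get


def make_threaded_cache():
    cache = {}

    def get():
        # Arguments travel with each call, so one cached function is enough.
        if "fn" not in cache:
            cache["fn"] = fake_compile(sampler_sample)
        return cache["fn"]

    return get


arg_sets = [
    ((1, 2), {"offset": 10}),
    ((3, 4), {"offset": 10}),
    ((1, 2), {"offset": 10}),  # repeats the first set
]

closed = make_closed_cache()
closed_results = [closed(a, k)(0) for a, k in arg_sets]
closed_compiles = stats["compiles"]

stats["compiles"] = 0
threaded = make_threaded_cache()
threaded_results = [threaded()(0, a, k) for a, k in arg_sets]
threaded_compiles = stats["compiles"]
```

Both variants return the same values; the closed-over variant compiles once per distinct argument set, the threaded variant once in total.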
Fixes #691. It is surprising to me that there was no test to detect this issue previously.
Tested with `test_chain_jit_args_smoke`.
We can observe that the `sequential` and `vectorized` methods take advantage of `jit_model_args=True`. For the `parallel` method, it is not clear to me whether `jit_model_args` works or `pmap` works out-of-the-box: I added some timing for the second run but got similar results for `True` and `False`, and in both cases the second runs are much faster than the first runs (5s vs 20s).
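A timing check like the one described above can be sketched as follows: time repeated calls and compare the first duration with the later ones. This is a pure-Python illustration under stated assumptions; `time_calls` and `make_slow_first_call` are hypothetical helpers, and the `sleep` merely simulates a one-off compilation cost. With JAX, the timed function would also need to wait for results to be ready before the clock stops, since dispatch is asynchronous.

```python
import time


def time_calls(fn, n_calls=2):
    """Call fn n_calls times and return the wall-clock duration of each call."""
    durations = []
    for _ in range(n_calls):
        start = time.perf_counter()
        fn()
        durations.append(time.perf_counter() - start)
    return durations


def make_slow_first_call(setup_seconds=0.05):
    """Hypothetical function whose first call pays a one-off setup cost."""
    state = {"warm": False}

    def fn():
        if not state["warm"]:
            time.sleep(setup_seconds)  # simulated compilation
            state["warm"] = True
        return 42

    return fn


durations = time_calls(make_slow_first_call(), n_calls=3)
```

Rerunning with identical arguments may not distinguish the two settings; a second run with different model arguments of the same shape is probably a more telling comparison.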