Skip to content

make jit model args work for chains#694

Merged
neerajprad merged 5 commits intopyro-ppl:masterfrom
fehiepsi:jitarg
Jul 24, 2020
Merged

make jit model args work for chains#694
neerajprad merged 5 commits intopyro-ppl:masterfrom
fehiepsi:jitarg

Conversation

@fehiepsi
Copy link
Member

@fehiepsi fehiepsi commented Jul 23, 2020

Fixes #691. It is surprised to me that there was no test to detect this issue previously.

Test on test_chain_jit_args_smoke

# sequential
warmup: 100%|██████████████| 2/2 [00:06<00:00,  3.18s/it, 3 steps of size 5.89e-01. acc. prob=0.00]
warmup: 100%|██████████████| 2/2 [00:00<00:00, 52.53it/s, 1 steps of size 5.89e-01. acc. prob=0.00]
sample: 100%|█████████████| 5/5 [00:00<00:00, 114.76it/s, 1 steps of size 5.89e-01. acc. prob=0.89]
sample: 100%|█████████████| 5/5 [00:00<00:00, 144.18it/s, 1 steps of size 5.89e-01. acc. prob=0.55]
sample: 100%|██████████████| 5/5 [00:06<00:00,  1.21s/it, 1 steps of size 5.89e-01. acc. prob=0.60]
sample: 100%|█████████████| 5/5 [00:00<00:00, 228.72it/s, 1 steps of size 5.89e-01. acc. prob=0.60]
PASSED
test/test_mcmc.py::test_chain_jit_args_smoke[False-parallel] PASSED
# vectorized
warmup: 100%|██████████████████████████████████████████████████| 2/2 [00:06<00:00,  3.29s/it, None]
sample: 100%|█████████████████████████████████████████████████| 5/5 [00:00<00:00, 977.56it/s, None]
sample: 100%|██████████████████████████████████████████████████| 5/5 [00:09<00:00,  1.83s/it, None]
PASSED
# sequential
warmup: 100%|██████████████| 2/2 [00:04<00:00,  2.43s/it, 3 steps of size 5.89e-01. acc. prob=0.00]
warmup: 100%|█████████████| 2/2 [00:00<00:00, 717.22it/s, 1 steps of size 5.89e-01. acc. prob=0.00]
sample: 100%|█████████████| 5/5 [00:00<00:00, 832.14it/s, 1 steps of size 5.89e-01. acc. prob=0.89]
sample: 100%|████████████| 5/5 [00:00<00:00, 1019.12it/s, 1 steps of size 5.89e-01. acc. prob=0.55]
sample: 100%|████████████| 5/5 [00:00<00:00, 1187.11it/s, 1 steps of size 5.89e-01. acc. prob=0.60]
sample: 100%|████████████| 5/5 [00:00<00:00, 1260.16it/s, 1 steps of size 5.89e-01. acc. prob=0.60]
PASSED
test/test_mcmc.py::test_chain_jit_args_smoke[True-parallel] PASSED
# vectorized
warmup: 100%|██████████████████████████████████████████████████| 2/2 [00:09<00:00,  4.65s/it, None]
sample: 100%|█████████████████████████████████████████████████| 5/5 [00:00<00:00, 816.74it/s, None]
sample: 100%|█████████████████████████████████████████████████| 5/5 [00:00<00:00, 704.69it/s, None]
PASSED

We can observe that sequential and vectorized methods take advantage of jit_model_args=True. For parallel method, it is not clear to me if jit_model_args works or pmap works out-of-the-box: I added some timing for the second run but got similar results for True and False and for both cases, the second runs are much faster than the first runs: 5s vs 20s).

@neerajprad
Copy link
Member

@fehiepsi - Seems like you'll need to merge changes from master.

neerajprad
neerajprad previously approved these changes Jul 24, 2020
Copy link
Member

@neerajprad neerajprad left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for cleaning this up, a big improvement over what we had!

states[self._sample_field] = lax.map(postprocess_fn, states[self._sample_field])
return states, last_state

def _single_chain_jit_args(self, init, collect_fields):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm..I don't remember why we needed this (or whether we needed it then and there have been some changes in jax since). I am also noticing that we can also remove sample_fn_jit_args and sample_fn_nojit_args and directly use partial(sampler_sample, args=, kwargs=) in _get_cached_fn(..) and simplify the state to be HMCState instead of a tuple. Let me know if you want me to push this to your branch.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this method is used to resolved *init vs init as an argument for single_chain_mcmc. I don't remember why we need to distinguish jit_args vs nojit_args neither. For cached_fn, please feel free to push the simplification to this branch. I guess it is needed because if args, kwargs changed, the cached fn is also needed to be changed.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, I was confused. I don't think we can avoid cycling the args, kwargs unfortunately.

@neerajprad neerajprad merged commit e3eaa15 into pyro-ppl:master Jul 24, 2020
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Can't use jit_model_args for multi chains

2 participants