Added live_traceplot function #1934

Merged 56 commits on Mar 27, 2017

davidbrochart
Contributor

This adds live trace plots in the notebook, which is useful when you want to see what is going on while the simulation runs.

@twiecki
Member

twiecki commented Mar 22, 2017

This is really cool. Will try to take a look soon.

@twiecki
Member

twiecki commented Mar 22, 2017

I wonder if perhaps the API should be:
pm.sample(..., live=True)

Member

@ColCarroll ColCarroll left a comment

This looks super cool, and a really nice use of the iter_sample. Will take a more substantive look later.

@@ -16,7 +16,7 @@
import sys
sys.setrecursionlimit(10000)

-__all__ = ['sample', 'iter_sample', 'sample_ppc', 'init_nuts']
+__all__ = ['assign_step_methods', 'sample', 'iter_sample', 'sample_ppc', 'init_nuts']
Member

intentional?

Contributor Author

Yes because iter_sample needs a step function, so if you don't have one already...

varnames = get_default_varnames(trace, plot_transformed)
x0 = skip_first
elif (it - skip_first) % refresh_every == 0:
for i, v in enumerate(varnames):
Member

I haven't looked closely, but do you think it is possible for this code to be factored out and shared with the base traceplot code?

Contributor Author

Probably yes, it's a first shot.
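
For reference, the refresh condition in the snippet under review can be looked at in isolation. This is only a minimal sketch, not the PR's code: `should_refresh` is a hypothetical helper, and the `it >= skip_first` guard is an assumption, added because Python's `%` would otherwise also return 0 for some iterations before `skip_first`.

```python
def should_refresh(it, skip_first=0, refresh_every=100):
    """Return True when iteration `it` should trigger a plot redraw.

    Hypothetical helper: redraws start at `skip_first` and then repeat
    every `refresh_every` iterations.
    """
    if it < skip_first:
        # Assumed guard: nothing is drawn during the skipped warm-up draws.
        return False
    return (it - skip_first) % refresh_every == 0


# Iterations that would redraw during a 450-draw run.
refresh_points = [it for it in range(450) if should_refresh(it, 100, 100)]
```

With `skip_first=100` and `refresh_every=100`, the redraws fall on iterations 100, 200, 300 and 400.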

@davidbrochart
Contributor Author

Yes, that could be an option, I wasn't sure.

@fonnesbeck
Member

I get a TypeError when I try to run the notebook:

TypeError: ufunc 'true_divide' output (typecode 'd') could not be coerced to provided output parameter (typecode 'l') according to the casting rule ''same_kind''

I agree with @twiecki that this should be a flag in sample rather than a new sampling function (even if a separate function is ultimately called underneath). The required additional arguments could be passed as kwargs.
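
A rough sketch of the shape suggested here, using toy functions rather than the real pymc3 internals: a single `sample` entry point with a `live_plot` flag that forwards extra keyword arguments to a separate function underneath. Every name below (`_sample_plain`, `_sample_live`, `on_refresh`) is hypothetical and only illustrates the dispatch pattern.

```python
def _sample_plain(draws):
    # Toy sampler: the "trace" is just the list of iteration numbers.
    return list(range(draws))


def _sample_live(draws, refresh_every=100, on_refresh=None):
    # Toy live sampler: calls `on_refresh` with the trace so far
    # after every `refresh_every` draws.
    trace = []
    for it in range(draws):
        trace.append(it)
        if on_refresh is not None and (it + 1) % refresh_every == 0:
            on_refresh(trace)
    return trace


def sample(draws, live_plot=False, **live_kwargs):
    # Single public entry point; the flag picks the implementation.
    if live_plot:
        return _sample_live(draws, **live_kwargs)
    if live_kwargs:
        raise TypeError("live plotting options require live_plot=True")
    return _sample_plain(draws)


refresh_sizes = []
trace = sample(250, live_plot=True, refresh_every=100,
               on_refresh=lambda t: refresh_sizes.append(len(t)))
```

Callers who do not want live plotting keep calling `sample(draws)` unchanged, and the plotting options only need to be documented in one place.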

@aseyboldt
Member

aseyboldt commented Mar 22, 2017

Maybe this is also a good time to mention that iter_sample has a bit of a performance problem. In each iteration it slices the whole trace. In the ndarray backend that is bad, but not terrible (when sampling a single normal it takes about 3 times longer), but in the sqlite and the hdf5 backend that means loading the whole thing into memory from scratch at each iteration.
Edit: Wrong place for this, sorry. It's #1935. Really neat idea with the live update by the way.
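
To make the cost concrete, here is a toy illustration (not pymc3 code) of what happens when a generator returns a fresh slice of the whole trace on every iteration and the backend has to materialize each slice. `CountingBackend` and `iter_sample_sketch` are hypothetical stand-ins; the only point is that the number of rows read grows quadratically with the number of draws.

```python
class CountingBackend:
    """Toy stand-in for a disk-backed trace that counts rows read per slice."""

    def __init__(self):
        self.rows = []
        self.rows_loaded = 0

    def record(self, value):
        self.rows.append(value)

    def __getitem__(self, key):
        # Slicing copies the requested rows, as a disk-backed store would
        # have to load them into memory.
        loaded = self.rows[key]
        self.rows_loaded += len(loaded)
        return loaded


def iter_sample_sketch(backend, draws):
    # After each draw, yield a slice holding everything sampled so far.
    for i in range(draws):
        backend.record(float(i))
        yield backend[:i + 1]


backend = CountingBackend()
for _ in iter_sample_sketch(backend, 100):
    pass
total_rows_loaded = backend.rows_loaded
```

For 100 draws this reads 1 + 2 + ... + 100 = 5050 rows in total, compared with 100 rows if the trace were only read once at the end.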

@twiecki
Member

twiecki commented Mar 22, 2017

@aseyboldt I hadn't realized that. Any ideas on how this could be improved?

@aseyboldt aseyboldt mentioned this pull request Mar 22, 2017
@davidbrochart
Contributor Author

@fonnesbeck Sorry I merged a bit quickly, I'll fix and update with the feedback I get.

@davidbrochart
Contributor Author

@fonnesbeck I checked with Python 2.7 and Python 3.5, no error on my side. Has anyone else run the notebook?

@aseyboldt
Member

I get the same TypeError.

@aseyboldt
Member

It works with this change:

diff --git a/pymc3/plots/utils.py b/pymc3/plots/utils.py
index b4ff6f5d..15694817 100644
--- a/pymc3/plots/utils.py
+++ b/pymc3/plots/utils.py
@@ -99,6 +99,6 @@ def fast_kde(x):

     norm_factor = n * dx * (2 * np.pi * std_x ** 2 * scotts_factor ** 2) ** 0.5

-    grid /= norm_factor
+    grid = grid / norm_factor

     return grid, xmin, xmax
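
A small standalone reproduction of why the change above helps, assuming that `grid` was an integer array at that point (the typecode `'l'` in the error message suggests this, but it is an assumption). In-place division has to write a float result back into the integer buffer, which NumPy's `same_kind` casting rule rejects; plain division rebinds the name to a new float array instead. The helper names are hypothetical.

```python
import numpy as np


def normalize_in_place(grid, norm_factor):
    # Mirrors `grid /= norm_factor`: the result must fit grid's dtype.
    grid /= norm_factor
    return grid


def normalize_rebinding(grid, norm_factor):
    # Mirrors `grid = grid / norm_factor`: allocates a new float array.
    return grid / norm_factor


counts = np.array([2, 4, 6])  # integer dtype, like a histogram of counts

try:
    normalize_in_place(counts.copy(), 2.0)
    in_place_failed = False
except TypeError:
    # Float output cannot be cast back to an integer array in place.
    in_place_failed = True

normalized = normalize_rebinding(counts, 2.0)
```

The rebinding form leaves the original integer array untouched and returns a float array, which is what the normalization needs.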

@davidbrochart
Contributor Author

Yes, I get the error after upgrading NumPy from 1.11.3 to 1.12.1.

@twiecki
Member

twiecki commented Mar 23, 2017

@davidbrochart You get the error with the fix @aseyboldt posted (which should already be on master, I think @aloctavodia did that)?

@davidbrochart
Contributor Author

@twiecki No, with @aseyboldt 's fix there is no error any more.

@twiecki
Member

twiecki commented Mar 23, 2017

Are you not on pymc3 master?

@davidbrochart
Contributor Author

Yes I am now, it works fine.

@davidbrochart
Contributor Author

I changed the API to what @twiecki suggested, and took into account @ColCarroll 's suggestion on sharing the traceplot code.

@fonnesbeck
Member

How does this work with a large model? It might be worth allowing a list of the subset of parameters to monitor.

@davidbrochart
Contributor Author

davidbrochart commented Mar 25, 2017

You can still pass a varnames list of variables to be plotted, all the parameters of the traceplot function are accepted.

@@ -7,7 +7,7 @@

def traceplot(trace, varnames=None, transform=identity_transform, figsize=None, lines=None,
combined=False, plot_transformed=False, grid=False, alpha=0.35, priors=None,
-prior_alpha=1, prior_style='--', ax=None):
+prior_alpha=1, prior_style='--', ax=None, live_plot=False, skip_first=0, refresh_every=100, roll_over=1000):
Member

Better formatting, please: add line breaks so we don't exceed 80 characters.

@twiecki
Member

twiecki commented Mar 27, 2017

LGTM. Seems like you need to rebase.

Junpeng Lao and others added 4 commits March 27, 2017 10:21
* Expand sampler-stats.ipynb example

include model diagnose from case study example in Stan http://mc-stan.org/documentation/case-studies/divergences_and_bias.html

* Sampler Diagnose for NUTS

* descriptive annotation and axis labels

* Fix typos

* PEP8 styling

* minor updates

1, add example to examples.rst
2, original content in Markdown code block
fonnesbeck and others added 24 commits March 27, 2017 10:21
)

* refactor module, add histogram

* add more tests

* refactor some code concerning AEVB histogram

* fix test for histogram

* use mean as deterministic point in Histogram

* remove unused import

* change names of shortcuts

* add names to shared params

* add new line at the end of `approximations.py`
* fix some svgd problems

* switch -> ifelse

* except in record
* add docs

* delete redundant code

* add usage example

* remove unused import
* use only free RVs from trace

* use memoize in Histogram.histogram_logp

* Change tests for histogram
@davidbrochart
Contributor Author

Not sure what I did here! Let me know if that's okay.

@twiecki
Member

twiecki commented Mar 27, 2017

@davidbrochart Seems like something went wrong, maybe this is helpful: https://github.com/edx/edx-platform/wiki/How-to-Rebase-a-Pull-Request

@twiecki
Member

twiecki commented Mar 27, 2017

Oh actually this seems fine, my bad.

@twiecki twiecki merged commit d2bca90 into pymc-devs:master Mar 27, 2017
@twiecki
Member

twiecki commented Mar 27, 2017

This is awesome, thanks!

@jake-westfall jake-westfall mentioned this pull request Apr 5, 2017