Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added mcmc_traceplots_unigauss.ipynb Fig 11.14 to 11.17 | Book2 #908

Merged
merged 3 commits into from
Jun 11, 2022
Merged

Added mcmc_traceplots_unigauss.ipynb Fig 11.14 to 11.17 | Book2 #908

merged 3 commits into from
Jun 11, 2022

Conversation

karm-patel
Copy link
Collaborator

@karm-patel karm-patel commented Jun 11, 2022

Description

Converted numpyro implementation (notebooks/book2/11/mcmc_traceplots_unigauss_numpyro.ipynb) to blackjax.
Please take a note of the following points

  1. I kept arviz default plots and arviz latexified plots separately since defaults plots look good for the notebook.
  2. For Fig 11.14 & 11.15, I didn't give yticks, since values are too small (like 0.0026) and taking space.

Figure Number

Fig 11.14
Fig 11.15
Fig 11.16
Fig 11.17

Figures

1. Fig 11.14

  • Before PR
    image

  • After PR
    image

2. Fig 11.15

  • Before PR
    image

  • After PR
    image

3. Fig 11.16

  • Before PR
    image

  • After PR
    image

4. Fig 11.17

  • Before PR
    image

  • After PR
    image

Issue

#890

Checklist

  • Performed a self-review of the code
  • Tested on Google Colab.

cc: Dr. @murphyk, Prof. @nipunbatra

@review-notebook-app
Copy link

Check out this pull request on  ReviewNB

See visual diffs & provide feedback on Jupyter Notebooks.


Powered by ReviewNB

@murphyk
Copy link
Member

murphyk commented Jun 11, 2022

LGTM. create_trace_from_states might be worth adding to blackjax repo.

@karm-patel
Copy link
Collaborator Author

LGTM. create_trace_from_states might be worth adding to blackjax repo.

Yes Sir, that's a good idea, I'm thinking to create an issue in blackjax for arviz + blackjax demos, where I can also mention about create_trace_from_states.

@murphyk murphyk merged commit 7222fb7 into probml:master Jun 11, 2022
@@ -0,0 +1,965 @@
{
Copy link
Contributor

@nipunbatra nipunbatra Jun 12, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Line #6.            return states, {"states": states, "info": info}

Not clear what info is? Is this something your implementation adds or something provided out of the box by BlackJax?


Reply via ReviewNB

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, info is returned by blackjax model, I've added comment to clarify it.

@@ -0,0 +1,965 @@
{
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it may be good to add 1-2 lines of comments on why swap axes is needed. Something like -- expected format by arviz is ... and our current format is ...


Reply via ReviewNB

@@ -0,0 +1,965 @@
{
Copy link
Contributor

@nipunbatra nipunbatra Jun 12, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Line #15.                samples[param] = states.position[param]

for dims > 1, you considered burn-in, but not here.


Reply via ReviewNB

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ohh, thanks for pointing out, I've fixed it.

@@ -0,0 +1,965 @@
{
Copy link
Contributor

@nipunbatra nipunbatra Jun 12, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We initialised prior_alpha in the initial cells and now we're using log_prob_alpha etc. Should we reuse variables?


Reply via ReviewNB

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, done.

@@ -0,0 +1,965 @@
{
Copy link
Contributor

@nipunbatra nipunbatra Jun 12, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Line #1.    print(f"Number of divergences (bad prior) = {info.is_divergent[500:,:].sum()}")

What is 500 here? Burn-in? if so, should we create a variable for it?


Reply via ReviewNB

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have created burn_in as global variable now

@@ -0,0 +1,965 @@
{
Copy link
Contributor

@nipunbatra nipunbatra Jun 12, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sensible


Reply via ReviewNB

@nipunbatra
Copy link
Contributor

LGTM. create_trace_from_states might be worth adding to blackjax repo.

I agree too! It would be useful for the larger community!

I have given some small comments on the notebook.

@karm-patel karm-patel mentioned this pull request Jun 12, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants