This repository has been archived by the owner on Jun 20, 2022. It is now read-only.

Update Random Priors Demo #69

Merged · merged 7 commits on Apr 9, 2022

Conversation

@petergchang (Contributor) commented Apr 6, 2022

Description

Continued from PR #61.

  1. Made the Adam optimizer learning rate a parameter of the training functions.
  2. Vectorized get_ensemble_predictions() and the code that generates predictions over a range of beta values, using jax.vmap() (a sketch of the pattern follows below).
  3. Fixed an issue in the implementation of the create_train_state() method that was silently leaking a JAX Tracer for beta.
  4. Modified the final plot comparing with/without bootstrap and/or prior to shade one and two standard deviations with fill_between, using a larger value beta=30. This resulted in a much clearer distinction among the four types of models that closely matches the description given in the original Kaggle notebook:

Generally, turning bootstrapping off reduces uncertainty the most (upper right corner), as opposed to turning priors off (lower left corner), though this can vary a bit across seeds. With both bootstrapping and priors off, there is still a little disagreement between ensemble members due to random initialization of weights, but it is much less than in the other models.
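For reference, a minimal sketch of the vmap pattern in item 2; the model, argument names, and shapes below are toy stand-ins, not the notebook's actual code:

```python
import jax
import jax.numpy as jnp

# Toy stand-in for get_ensemble_predictions(): each ensemble member is a
# linear model plus a beta-scaled fixed random prior. Illustrative only.
def get_ensemble_predictions(weights, prior_weights, X, beta):
    trainable = X @ weights.T        # (N, num_members)
    prior = X @ prior_weights.T      # (N, num_members)
    return trainable + beta * prior

key = jax.random.PRNGKey(0)
k1, k2, k3 = jax.random.split(key, 3)
weights = jax.random.normal(k1, (5, 3))        # 5 ensemble members, 3 features
prior_weights = jax.random.normal(k2, (5, 3))
X = jax.random.normal(k3, (10, 3))             # 10 test points

betas = jnp.array([0.0, 1.0, 10.0, 30.0])
# vmap over the beta axis only; every other argument is broadcast (in_axes=None).
preds = jax.vmap(get_ensemble_predictions, in_axes=(None, None, None, 0))(
    weights, prior_weights, X, betas
)
print(preds.shape)  # (4, 10, 5): one block of ensemble predictions per beta
```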

Colab link

https://colab.research.google.com/github/probml/probml-notebooks/blob/main/notebooks/randomized_priors.ipynb

Issue

#708: translate the ensembles-with-random-priors demo from Keras to JAX

Figures

Gist for Updated Bootstrap/Prior-Dependence Figure

Checklist:

  • Performed a self-review of the code.
  • Tested on Google Colab.

Potential problems/Important remarks

@review-notebook-app

Check out this pull request on ReviewNB to see visual diffs and provide feedback on Jupyter notebooks.

@review-notebook-app bot commented Apr 7, 2022 (murphyk on ReviewNB):

The colab button is incorrectly pointing to your repo :)


@review-notebook-app bot commented Apr 7, 2022 (murphyk on ReviewNB):

Line #1.    def create_train_state(key, X, beta, lr):

It seems that beta is ignored?


@review-notebook-app bot commented Apr 7, 2022 (murphyk on ReviewNB):

Line #18.            ensemble.append(train(key, X_b, Y_b, beta, lr, epochs))

Maybe it would be clearer (and faster) to first build the dataset for each ensemble member (which is fast), and then vmap the training process over the list of datasets (which could be slow).
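A minimal sketch of the suggested pattern, with hypothetical stand-ins for the notebook's bootstrap resampling and train() function:

```python
import jax
import jax.numpy as jnp

def bootstrap(key, X, Y):
    # Draw one bootstrap resample of the training data (the fast part).
    idx = jax.random.choice(key, X.shape[0], shape=(X.shape[0],), replace=True)
    return X[idx], Y[idx]

def train_one(X_b, Y_b):
    # Stand-in for the notebook's train(): fit a linear model via the
    # normal equations (the real training loop would be the slow part).
    return jnp.linalg.solve(X_b.T @ X_b, X_b.T @ Y_b)

key = jax.random.PRNGKey(0)
kx, ky = jax.random.split(key)
X = jax.random.normal(kx, (100, 3))
Y = X @ jnp.array([1.0, -2.0, 0.5])

keys = jax.random.split(ky, 8)  # one PRNG key per ensemble member
# Step 1: build all the bootstrapped datasets, stacked along a leading axis.
Xs, Ys = jax.vmap(bootstrap, in_axes=(0, None, None))(keys, X, Y)
# Step 2: vmap the (expensive) training over the stacked datasets.
ensemble = jax.vmap(train_one)(Xs, Ys)  # shape (8, 3): one weight vector each
```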


@review-notebook-app bot commented Apr 7, 2022 (murphyk on ReviewNB):

Line #8.                                 in_axes=(None, None, None, None, 0, None, None,))

Consider using https://jax.readthedocs.io/en/latest/notebooks/xmap_tutorial.html :)
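For context, a rough sketch of what this suggestion might look like, following the 2022-era experimental API from the linked tutorial (xmap has since been deprecated and removed from JAX; all names below are illustrative):

```python
import jax.numpy as jnp
from jax.experimental.maps import xmap

w = jnp.ones(3)
X_test = jnp.ones((5, 3))

def predict_for_beta(beta):
    # Toy stand-in for the notebook's prediction function; the fixed
    # arguments are closed over, so only the mapped axis remains.
    return X_test @ w + beta

# The mapped axis carries a name instead of a bare positional index,
# which is the self-documenting property the suggestion is after.
predict_over_betas = xmap(
    predict_for_beta,
    in_axes=['beta', ...],   # axis 0 of the input is named 'beta'
    out_axes=['beta', ...],  # and reappears as axis 0 of the output
)

betas = jnp.array([0.0, 1.0, 10.0, 30.0])
print(predict_over_betas(betas).shape)  # (4, 5)
```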


@murphyk (Member) commented Apr 7, 2022

Hi @peterchang0414, I left some comments at https://app.reviewnb.com/probml/probml-notebooks/pull/69/; I'm not sure if you can see them? (I'm new to ReviewNB.)

@petergchang (Contributor, Author)

Hello @murphyk, thank you very much for your thoughtful comments!

Following your advice, I made the following revisions in the most recent commit(s):

  1. Updated the "Open in Colab" button to point to the probml repo (though I had to use a workaround to link to the remote repo, since I don't have write access).
  2. Deleted the unnecessary beta parameter from the create_train_state() function (a sketch follows after this list).
  3. Modified build_ensemble() so that it first stacks the (possibly bootstrapped) training sets for all ensemble members, then vmaps the train() function over the resulting stack.
  4. This modification led to a noticeable speedup: e.g., %%timeit for get_ensemble_predictions() (2000 epochs) improved from 13 s per loop (best of 5) to 5.08 s per loop (best of 5).
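A hypothetical sketch of what revisions 1 (from the PR description) and 2 above amount to; the notebook's actual model and training code may differ:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn
import optax
from flax.training import train_state

# Illustrative architecture only; the notebook's model may differ.
class MLP(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.relu(nn.Dense(16)(x))
        return nn.Dense(1)(x)

# beta no longer appears in the signature, so no traced value can leak
# out of create_train_state(); lr is now an explicit argument.
def create_train_state(key, X, lr):
    model = MLP()
    params = model.init(key, X)["params"]
    return train_state.TrainState.create(
        apply_fn=model.apply, params=params, tx=optax.adam(lr)
    )

state = create_train_state(jax.random.PRNGKey(0), jnp.ones((1, 1)), lr=1e-3)
```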

As for creating a vectorized function that computes predictions for an array of different beta values: I studied the xmap tutorial you linked and perused the source code to see how it might help, but I failed to see the benefit in this case (my xmap experiments were even more cluttered). I'm new to JAX and may not yet grok xmap well enough to understand your advice. If you have more hints as to how xmap would help here, I'd greatly appreciate it!

Thanks again.

@murphyk (Member) commented Apr 9, 2022

I am not very familiar with xmap either; it just seemed like an easy way to name axes, so you know what is being vmap'd over. But it's just syntactic sugar and is not needed.

@murphyk murphyk merged commit 268fd60 into probml:main Apr 9, 2022