Conversation
…tervals for final figure
Check out this pull request on See visual diffs & provide feedback on Jupyter Notebooks. Powered by ReviewNB |
View / edit / reply to this conversation on ReviewNB murphyk commented on 2022-04-07T05:59:35Z The colab button is incorrectly pointing to your repo :) |
View / edit / reply to this conversation on ReviewNB murphyk commented on 2022-04-07T05:59:36Z Line #1. def create_train_state(key, X, beta, lr): It seems that beta is ignored? |
View / edit / reply to this conversation on ReviewNB murphyk commented on 2022-04-07T05:59:37Z Line #18. ensemble.append(train(key, X_b, Y_b, beta, lr, epochs)) Maybe it would be clearer (and faster) to first build the dataset for each ensemble member (which is fast), and then vmap the training process over the list of datasets (which could be slow). |
View / edit / reply to this conversation on ReviewNB murphyk commented on 2022-04-07T05:59:38Z Line #8. in_axes=(None, None, None, None, 0, None, None,)) Consider using https://jax.readthedocs.io/en/latest/notebooks/xmap_tutorial.html :) |
Hi @peterchang0414 , I left some comments in https://app.reviewnb.com/probml/probml-notebooks/pull/69/, not sure if you can see them? (I'm new to ReviewNB) |
Hello @murphyk, thank you very much for your thoughtful comments! Following your advice, I made the following revisions for the most recent commit(s):
For creating a vectorized function for computing predictions for an array of different beta values, I studied the Thanks again. |
I am not very familar with xmap either, it just seemed like an easy way to name axes, so you know what is being vmap'd over. But it's just syntactic sugar, and is not needed. |
Description
Continued from PR61.
get_ensemble_predictions()
and the code that generates predictions over a range of beta values usingjax.vmap()
.create_train_state()
method that was surreptitiously leaking a Trace ofbeta
.fill_between
1 and 2 standard deviations, with a largerbeta=30
value. This resulted in a much more clear distinction among the four types of models that closely matches the description given by the original Kaggle notebook:Colab link
https://colab.research.google.com/github/probml/probml-notebooks/blob/main/notebooks/randomized_priors.ipynb
Issue
#708: translate ensembles with random priors demo from keras to JAX
Figures
Gist for Updated Bootstrap/Prior-Dependence Figure
Checklist:
Potential problems/Important remarks