mlpPriorsDemo2 #304
Conversation
I think your browser doesn't paste images from the file manager (I smell Ubuntu).
Yes, it is Ubuntu :). I uploaded figures 4, 1, 2, 3.
This looks good, but all the bookkeeping code for flattening/unflattening the parameters complicates things.
Ok, I will try doing it.
On second thought, it might be useful to have a 'vanilla' version and a JAX version of this demo, as an example of how to do things using both approaches. So please check in your JAX version as a separate file, called mlpPriorsDemo2_jax.py. Ideally it will use the same random seed as your current code, so both scripts produce identical output. Please make a separate PR for the JAX version, and I can merge the current one as is (once I check it).
OK, I agree it is better to write a new file in JAX. I will do my best.
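For reference, a minimal sketch of what such a JAX version might look like (hypothetical: the shapes, hyper-parameter values, and function names are illustrative, not the contents of the eventual mlpPriorsDemo2_jax.py). One caveat on seeds: jax.random uses a counter-based PRNG, so even with the same integer seed it will not reproduce NumPy's draws bit-for-bit; truly identical output would require porting the sampling order, not just the seed.

```python
import jax
import jax.numpy as jnp

def sample_params(key, n_in, n_hidden, n_out, prior):
    """Draw one-hidden-layer MLP weights from Gaussian priors with the given precisions."""
    alpha1, beta1, alpha2, beta2 = prior  # Gaussian precisions (illustrative names)
    k1, k2, k3, k4 = jax.random.split(key, 4)
    # Precision alpha corresponds to standard deviation 1/sqrt(alpha).
    w1 = jax.random.normal(k1, (n_in, n_hidden)) / jnp.sqrt(alpha1)
    b1 = jax.random.normal(k2, (n_hidden,)) / jnp.sqrt(beta1)
    w2 = jax.random.normal(k3, (n_hidden, n_out)) / jnp.sqrt(alpha2)
    b2 = jax.random.normal(k4, (n_out,)) / jnp.sqrt(beta2)
    return w1, b1, w2, b2

def mlp_forward(params, x):
    w1, b1, w2, b2 = params
    h = jnp.tanh(x @ w1 + b1)   # single hidden layer, tanh activation
    return h @ w2 + b2          # linear output

key = jax.random.PRNGKey(0)     # fixed seed so reruns give identical figures
params = sample_params(key, 1, 12, 1, prior=(1.0, 1.0, 1.0, 1.0))
x = jnp.linspace(-1.0, 1.0, 200).reshape(-1, 1)
y = mlp_forward(params, x)      # one function drawn from the prior
```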
- It seems self.beta is not used.
- Rename aw1, ab1, aw2, ab2 to alpha1, beta1, alpha2, beta2 to match the text (even though this does not match the original MATLAB).
- The MLP function assigns random weights but does not draw from the user's prior; it should call MLP_init to create the initial weights.
- You can eliminate all the logic about indices and pack/unpack by simply sampling four times, once each for w1, b1, w2, b2, just as you do in the MLP function. This is slightly repetitive but easier to read, since the code already hard-codes the assumption that there is only one hidden layer. You can then replace the Prior class with a plain tuple (alpha1, beta1, alpha2, beta2); the general case can be handled by your JAX version. See the sketch after this list.
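A minimal sketch of the suggested simplification, assuming Gaussian priors parameterized by precision and a tanh hidden layer (class and method names are illustrative, not the PR's actual code):

```python
import numpy as np

class MLP:
    """One-hidden-layer MLP whose initial weights are drawn from the prior."""

    def __init__(self, n_in, n_hidden, n_out, prior):
        self.shapes = {'w1': (n_in, n_hidden), 'b1': (n_hidden,),
                       'w2': (n_hidden, n_out), 'b2': (n_out,)}
        self.init(prior)  # draw initial weights from the user's prior

    def init(self, prior):
        # The prior is just a tuple of Gaussian precisions; no Prior class,
        # no index bookkeeping, no pack/unpack of a flat parameter vector.
        alpha1, beta1, alpha2, beta2 = prior
        self.w1 = np.random.randn(*self.shapes['w1']) / np.sqrt(alpha1)
        self.b1 = np.random.randn(*self.shapes['b1']) / np.sqrt(beta1)
        self.w2 = np.random.randn(*self.shapes['w2']) / np.sqrt(alpha2)
        self.b2 = np.random.randn(*self.shapes['b2']) / np.sqrt(beta2)

    def forward(self, x):
        h = np.tanh(x @ self.w1 + self.b1)
        return h @ self.w2 + self.b2
```

Sampling each parameter group separately is slightly repetitive, but it removes the index arithmetic entirely and makes the single-hidden-layer assumption explicit.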
It's a small thing, but currently the MLP constructor assigns random parameters, and then MLP_init overwrites them with new random values. This is a bit redundant. Also, only MLP_init uses the specified alpha/beta hyper-parameters, so the initial random parameters are 'wrong' in some sense.
Exactly.
I made the requested changes: I renamed the variables and simplified the packing logic.
Much improved, thanks! |
You are welcome. I will start working on the JAX version.
Hello, this is a Python implementation of mlpPriorsDemo2. The figures are slightly different from those in the book because NumPy generates different random numbers than MATLAB.