
mlpPriorsDemo2 #304

Merged: 2 commits merged on Mar 26, 2021
Conversation

@Abdelrahman350 (Contributor)

Hello, this is a Python implementation of mlpPriorsDemo2. The figures are slightly different from the book's because NumPy generates different random vectors than MATLAB does.

@mjsML (Member) commented Mar 21, 2021

I think your browser doesn't paste images from the file manager (I smell Ubuntu).
Try opening the screenshot in the image app and then copy the image from there.

@Abdelrahman350 (Contributor, Author)

[two figures attached]

@Abdelrahman350 (Contributor, Author)

[figure attached]

@Abdelrahman350 (Contributor, Author)

[figure attached]

@Abdelrahman350 (Contributor, Author)

> I think your browser doesn't paste images from the file manager (I smell Ubuntu).
> Try opening the screenshot in the image app and then copy the image from there.

Yes, it is Ubuntu :). I uploaded figures 4, 1, 2, 3.

@murphyk (Member) commented Mar 21, 2021

This looks good, but all the book-keeping code for flattening / unflattening parameters complicates things.
Please can you rewrite it using jax.tree_util, following this example. This should make it short and sweet :)
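For reference, a minimal sketch of how jax.tree_util can remove the manual flatten/unflatten bookkeeping, assuming each parameter is drawn from a zero-mean Gaussian N(0, 1/alpha); the function and variable names here are illustrative, not taken from the PR's code:

```python
import jax
import jax.numpy as jnp

def sample_params(key, params, alpha):
    # Draw every leaf of the parameter pytree from N(0, 1/alpha),
    # keeping each leaf's original shape; no index bookkeeping needed.
    leaves, treedef = jax.tree_util.tree_flatten(params)
    keys = jax.random.split(key, len(leaves))
    new_leaves = [jax.random.normal(k, leaf.shape) / jnp.sqrt(alpha)
                  for k, leaf in zip(keys, leaves)]
    return jax.tree_util.tree_unflatten(treedef, new_leaves)

# Example: a one-hidden-layer MLP stored as a plain dict (a pytree).
params = {"w1": jnp.zeros((1, 12)), "b1": jnp.zeros(12),
          "w2": jnp.zeros((12, 1)), "b2": jnp.zeros(1)}
sampled = sample_params(jax.random.PRNGKey(0), params, alpha=1.0)
```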

@Abdelrahman350 (Contributor, Author)

> This looks good, but all the book-keeping code for flattening / unflattening parameters complicates things.
> Please can you rewrite it using jax.tree_util, following this example. This should make it short and sweet :)

Ok, I will try doing it.

@murphyk (Member) commented Mar 21, 2021

> > This looks good, but all the book-keeping code for flattening / unflattening parameters complicates things.
> > Please can you rewrite it using jax.tree_util, following this example. This should make it short and sweet :)
>
> Ok, I will try doing it.

On second thoughts, it might be useful to have a 'vanilla' and a jax version of this demo, just as an example of how to do things using both approaches. So please check in your jax version as a separate file, called mlpPriorsDemo2_jax.py. Ideally this will use the same random seed as your current code, so both scripts will produce identical output. Please make a separate PR for the jax version, and I can merge the current one as is (once I check it).

@Abdelrahman350 (Contributor, Author)

> > > This looks good, but all the book-keeping code for flattening / unflattening parameters complicates things.
> > > Please can you rewrite it using jax.tree_util, following this example. This should make it short and sweet :)
> >
> > Ok, I will try doing it.
>
> On second thoughts, it might be useful to have a 'vanilla' and a jax version of this demo, just as an example of how to do things using both approaches. So please check in your jax version as a separate file, called mlpPriorsDemo2_jax.py. Ideally this will use the same random seed as your current code, so both scripts will produce identical output. Please make a separate PR for the jax version, and I can merge the current one as is (once I check it).

Ok, I agree it is better to write the JAX version as a new file. I will do my best.

@murphyk (Member) left a review comment


  • It seems self.beta is not used.

  • rename aw1, ab1, aw2, ab2 to alpha1, beta1, alpha2, beta2 to match the text (even though this does not match the original matlab).

  • MLP function assigns random weights, but does not draw from the user's prior. This should call MLP_init to create initial weights.

  • You can eliminate all the logic about indices and pack/unpack by simply sampling 4 times for w1, b1, w2, b2, just like you do in the MLP function. This is slightly repetitious but is easier to read (since the code already hard-codes assumption that there is only 1 hidden layer). Then you can replace the Prior class by just a tuple (alpha1, beta1, alpha2, beta2). The general case can be handled by your jax version.
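For illustration, a minimal sketch of what the last suggestion could look like, assuming the prior tuple holds Gaussian precisions and each weight group is drawn from N(0, 1/precision); all names here are hypothetical, not the PR's code:

```python
import numpy as np

def sample_from_prior(prior, n_in, n_hidden, n_out, rng=None):
    # prior is a plain tuple of precisions instead of a Prior class;
    # each group is sampled separately, so no pack/unpack indices are needed.
    rng = np.random.default_rng() if rng is None else rng
    alpha1, beta1, alpha2, beta2 = prior
    w1 = rng.normal(0.0, 1.0 / np.sqrt(alpha1), size=(n_in, n_hidden))
    b1 = rng.normal(0.0, 1.0 / np.sqrt(beta1), size=n_hidden)
    w2 = rng.normal(0.0, 1.0 / np.sqrt(alpha2), size=(n_hidden, n_out))
    b2 = rng.normal(0.0, 1.0 / np.sqrt(beta2), size=n_out)
    return w1, b1, w2, b2

w1, b1, w2, b2 = sample_from_prior((1.0, 1.0, 1.0, 1.0), n_in=1, n_hidden=12, n_out=1)
```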

@Abdelrahman350 (Contributor, Author) commented Mar 22, 2021

> * It seems self.beta is not used.
> * rename aw1, ab1, aw2, ab2 to alpha1, beta1, alpha2, beta2 to match the text (even though this does not match the original matlab).
> * MLP function assigns random weights, but does not draw from the user's prior. This should call MLP_init to create initial weights.
> * You can eliminate all the logic about indices and pack/unpack by simply sampling 4 times for w1, b1, w2, b2, just like you do in the MLP function. This is slightly repetitious but is easier to read (since the code already hard-codes assumption that there is only 1 hidden layer). Then you can replace the Prior class by just a tuple (alpha1, beta1, alpha2, beta2). The general case can be handled by your jax version.

  • On the first point, yes, self.beta is not used; unlike self.alpha, it is not used in MLP_init, and this is exactly the case in the MATLAB code as well.
  • The second point is done.
  • On the third point, what I understand from your requested change is that you want MLP_init to initialize the weights from the prior directly, instead of initializing them with random values and then applying the prior in MLP_init inside the for loop. Am I right?
  • On the fourth point, what I understand is to discard all of (mark1, extra, ...) and use alpha1, alpha2, etc. directly instead of the packed alpha and indx. Am I right?

@murphyk (Member) commented Mar 22, 2021

> > * It seems self.beta is not used.
> > * rename aw1, ab1, aw2, ab2 to alpha1, beta1, alpha2, beta2 to match the text (even though this does not match the original matlab).
> > * MLP function assigns random weights, but does not draw from the user's prior. This should call MLP_init to create initial weights.
> > * You can eliminate all the logic about indices and pack/unpack by simply sampling 4 times for w1, b1, w2, b2, just like you do in the MLP function. This is slightly repetitious but is easier to read (since the code already hard-codes assumption that there is only 1 hidden layer). Then you can replace the Prior class by just a tuple (alpha1, beta1, alpha2, beta2). The general case can be handled by your jax version.
>
> • On the first point, yes, self.beta is not used; unlike self.alpha, it is not used in MLP_init, and this is exactly the case in the MATLAB code as well.
> • The second point is done.
> • On the third point, what I understand from your requested change is that you want MLP_init to initialize the weights from the prior directly, instead of initializing them with random values and then applying the prior in MLP_init inside the for loop. Am I right?

It's a small thing, but currently the MLP constructor assigns random parameters, and then MLP_init overwrites them with new random values. This is a bit redundant. Also, only MLP_init uses the specified alpha/beta hyper-parameters, so the initial random parameters are 'wrong' in some sense.

> • On the fourth point, what I understand is to discard all of (mark1, extra, ...) and use alpha1, alpha2, etc. directly instead of the packed alpha and indx. Am I right?

Exactly.
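For illustration, a minimal sketch of the structure being suggested, in which the constructor delegates to MLP_init so the only weights ever created are draws from the prior; the class and method bodies here are hypothetical, not the PR's actual code:

```python
import numpy as np

class MLP:
    # Hypothetical sketch: the constructor calls MLP_init directly,
    # so no throwaway random weights are ever assigned.
    def __init__(self, n_in, n_hidden, n_out, prior, rng=None):
        self.n_in, self.n_hidden, self.n_out = n_in, n_hidden, n_out
        self.rng = np.random.default_rng() if rng is None else rng
        self.MLP_init(prior)

    def MLP_init(self, prior):
        # prior precisions as a plain tuple; each group drawn from N(0, 1/precision)
        alpha1, beta1, alpha2, beta2 = prior
        self.w1 = self.rng.normal(0.0, 1.0 / np.sqrt(alpha1), (self.n_in, self.n_hidden))
        self.b1 = self.rng.normal(0.0, 1.0 / np.sqrt(beta1), self.n_hidden)
        self.w2 = self.rng.normal(0.0, 1.0 / np.sqrt(alpha2), (self.n_hidden, self.n_out))
        self.b2 = self.rng.normal(0.0, 1.0 / np.sqrt(beta2), self.n_out)
```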

@mjsML linked an issue on Mar 22, 2021 that may be closed by this pull request
@Abdelrahman350 (Contributor, Author)

I have made the requested changes: I changed the names and the packing algorithm.

@murphyk (Member) commented Mar 26, 2021

Much improved, thanks!

@murphyk merged commit 4662232 into probml:master on Mar 26, 2021
@Abdelrahman350 (Contributor, Author)

> Much improved, thanks!

You are welcome. I will start working on the jax version.


Successfully merging this pull request may close the following issue: Convert mlpPriorsDemo2 to python