Feature/nitpicks #35
base: master
Conversation
* Use `logging.getLogger` to get a logger to write to instead of calling `logging` directly. Also set the matplotlib logger to only print ERROR-level messages, so we don't get extra output when plotting (see the sketch after this list).
* Create a generic function to plot one or more images using a specified number of columns. This allows the user to easily plot more samples.
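A minimal sketch of the logging change from the first bullet; the calls are all standard library, but the exact placement in the notebook is an assumption:

```python
import logging

logging.basicConfig(level=logging.INFO)

# Module-level logger instead of calling logging.info() etc. on the root logger.
logger = logging.getLogger(__name__)

# Silence matplotlib's chatty DEBUG/INFO messages during plotting;
# only ERROR and above will be shown.
logging.getLogger("matplotlib").setLevel(logging.ERROR)

logger.info("training started")  # goes through the named logger
```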
Passing in the class instead of a string makes `construct_vae` generic enough that it doesn't need changes when you want to play with different implementations.
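As a sketch of what that looks like (the `construct_vae` signature and the stub classes here are assumptions, not the notebook's actual code):

```python
from typing import Type

class Generator: ...
class InferenceNetwork: ...

class GaussianInferenceNetwork(InferenceNetwork):
    def __init__(self, latent_dim: int) -> None:
        self.latent_dim = latent_dim

class VAE:
    def __init__(self, generator: Generator, inference_net: InferenceNetwork) -> None:
        self.generator = generator
        self.inference_net = inference_net

def construct_vae(inference_net_cls: Type[InferenceNetwork], **kwargs) -> VAE:
    # Passing the class keeps construct_vae generic: no string-to-class
    # mapping to update when a new inference network is added.
    return VAE(generator=Generator(), inference_net=inference_net_cls(**kwargs))

vae = construct_vae(GaussianInferenceNetwork, latent_dim=32)
```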
The `kl_divergence` parameter in the `GaussianVAE` was never used, and it seems `diagonal_gaussian_kl` should be private to `GaussianInferenceNetwork`, so remove it.
This is a nice clean-up of the notebook. Thanks for doing this. I only have one remark on the changes, which I left inline. Sorry for taking so long to get back to you.
@@ -525,11 +549,8 @@
 "\n",
 " def __init__(self,\n",
 " generator: Generator,\n",
 " inference_net: GaussianInferenceNetwork,\n",
-" kl_divergence: Callable) -> None:\n",
Why are you removing the KL divergence? It is needed to properly construct the ELBO.
I have thought about this a lot before, and one alternative is to rewrite the ELBO as E_q[log p(x, z)] + H(q(z)). Then we could provide the entropy as a method of the inference net. However, it's out of whack with the slides in that case. That's why, for the time being, I would just pass in the `kl_divergence` as an argument.
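For reference, the identity behind that rewrite (standard, not specific to this repo):

$$
\mathrm{ELBO}
= \mathbb{E}_{q(z)}[\log p(x \mid z)] - \mathrm{KL}\big(q(z) \,\|\, p(z)\big)
= \mathbb{E}_{q(z)}[\log p(x, z)] + \mathrm{H}\big(q(z)\big),
$$

which follows because $\mathrm{KL}(q \,\|\, p) = \mathbb{E}_q[\log q(z)] - \mathbb{E}_q[\log p(z)]$ and $\mathrm{H}(q) = -\mathbb{E}_q[\log q(z)]$, so the prior term folds into the joint $p(x, z) = p(x \mid z)\, p(z)$.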
I removed it because it wasn't actually being used (as far as I can see): the `GaussianInferenceNetwork` calls `diagonal_gaussian_kl()` directly. I've been implementing inference networks with different distributions with Wilker, and those require different ways of calculating the KL divergence. It seemed to me that the KL divergence is conceptually tied more to the inference network than to the VAE as a whole, which is why I opted to remove the parameter rather than change `GaussianInferenceNetwork` to stop hardcoding the function.
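A hedged sketch of that design, with the KL as a method on the inference network (`diagonal_gaussian_kl` and `GaussianInferenceNetwork` are from the PR; the `kl` method name and the PyTorch details are hypothetical):

```python
import torch

def diagonal_gaussian_kl(mean: torch.Tensor, log_var: torch.Tensor) -> torch.Tensor:
    # KL( N(mean, diag(exp(log_var))) || N(0, I) ), summed over latent dims.
    return 0.5 * torch.sum(log_var.exp() + mean ** 2 - 1.0 - log_var, dim=-1)

class GaussianInferenceNetwork:
    # ... parameters that produce mean and log_var from x ...

    def kl(self, mean: torch.Tensor, log_var: torch.Tensor) -> torch.Tensor:
        # Each inference network knows its own KL; a posterior from another
        # family would override this method with its own formula, so the
        # VAE never needs a kl_divergence argument.
        return diagonal_gaussian_kl(mean, log_var)
```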
Some bits I changed while trying to figure things out and implement a PyTorch version. They mostly consist of a flexible `plot_images` function that lets you display an arbitrary number of images.
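The function itself isn't quoted in this thread, so this is only a sketch of what such a `plot_images` helper might look like (the signature and grid layout are assumptions):

```python
import math
import matplotlib.pyplot as plt
import numpy as np

def plot_images(images: np.ndarray, n_cols: int = 4) -> None:
    """Plot any number of (H, W) grayscale images on a grid with n_cols columns."""
    n_rows = math.ceil(len(images) / n_cols)
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(2 * n_cols, 2 * n_rows))
    axes = np.atleast_1d(axes).ravel()
    for ax, image in zip(axes, images):
        ax.imshow(image, cmap="gray")
    for ax in axes:
        ax.axis("off")  # hide ticks, including on unused grid cells
    plt.show()

# e.g. plot_images(samples[:10], n_cols=5)
```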