Add Trace_MMD class, add tests that MMD correctly fits distributions #1818
Conversation
@varenick thanks for the PR! I have a few questions/comments. Would you also be able to add a simple standalone example script, either here or in a followup PR?
or a dict that maps latent variable names to instances of
:class:`pyro.contrib.gp.kernels.kernel.Kernel`.
In the latter case, different kernels are used for different latent variables.

:param mmd_scale: A scaling factor for MMD terms.
Does this need to be separate? Shouldn't scaling be handled within the kernels (e.g. via `pyro.contrib.gp.kernels.VerticalScaling`) or in the learning rate during optimization?
It is ok to handle scaling at initialization (simply with the `variance` argument of a `Kernel` instance), but I am not sure it would be convenient to change the scale during optimization by accessing kernel arguments. Also, not every `Kernel` instance has a `variance` property.
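To make the trade-off concrete, here is a minimal, hypothetical sketch (plain Python, not Pyro's actual `Kernel` API) of how an output-scale parameter like `variance` multiplies an RBF kernel, and therefore scales any MMD term built from it by the same factor:

```python
import math

def rbf(x, y, variance=1.0, lengthscale=1.0):
    """Scalar RBF kernel with an output-scale `variance` parameter.

    `variance` plays the same role as the `variance` argument of a
    Pyro GP kernel: it multiplies the kernel output, so any MMD term
    computed from this kernel is scaled by the same factor.
    """
    return variance * math.exp(-((x - y) ** 2) / (2.0 * lengthscale ** 2))

# Scaling the kernel by c scales every kernel evaluation by c,
# so an MMD estimate (a linear combination of kernel values) scales by c too.
print(rbf(0.0, 0.0))                # 1.0
print(rbf(0.0, 0.0, variance=3.0))  # 3.0
```

A separate `mmd_scale` factor is therefore mathematically redundant with a kernel-level output scale, but it may still be convenient for kernels that lack a `variance` parameter.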
…xample of MMD-VAE loss
@eb8680 Here's an example Jupyter notebook with an MMD-VAE model implemented with
pyro/infer/trace_mmd.py
model_samples = independent_model_site['value']
guide_samples = guide_site['value']
model_samples = model_samples.view(
    -1, *[model_samples.size(j) for j in range(-independent_model_site['fn'].event_dim, 0)]
)
Hmm, I don't think what you've implemented here, where you're computing MMD across plate slices, is valid for models and guides with any nontrivial dependency structure. It happens to correspond to the objective in the InfoVAE paper for the InfoVAE graphical model because all the latent variables in that model are independent, but is not correct in general.
You'll need to do what you were doing before and treat only the particle dimension as a batch dimension, but compute it explicitly following my suggestion:
particle_dim = -self.max_plate_nesting - independent_model_site["fn"].event_dim
model_samples = independent_model_site['value']
model_samples = model_samples.transpose(-model_samples.dim(), particle_dim)
model_samples = model_samples.view(model_samples.shape[0], -1)
# and similar for guide_samples
Unfortunately, this correct objective (which treats a probabilistic program as a big black box with no internal structure) does not correspond to the one in the InfoVAE paper. I suspect a general algorithm for MMD computation between arbitrary graphical models and mean-field guides that exploits all available conditional independence structure in the model, which would recover the InfoVAE objective as a trivial special case, would require kernels that decompose over Markov blankets in the model, rather than over individual variables.
It would probably look somewhat similar to this message-passing algorithm for Stein variational gradient descent or this message-passing algorithm for Jensen-Shannon divergences with neural density ratio estimators.
That's way beyond the scope of this PR, though - in fact, if you work out the general case you could write a nice paper about it. Note also that @fritzo and I are working on a new backend for Pyro that should make implementing such algorithms significantly easier, though it won't be ready for a few months.
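For intuition, here is a small, self-contained sketch (plain Python with hypothetical names, not the PR's implementation) of the unbiased squared-MMD estimator computed across a single sample dimension, which is the "particle dimension only" objective suggested above:

```python
import math

def rbf(x, y, lengthscale=1.0):
    # Scalar RBF kernel; in the PR each sample would be a flattened
    # particle (one row of the reshaped `model_samples`), not a scalar.
    return math.exp(-((x - y) ** 2) / (2.0 * lengthscale ** 2))

def mmd2_unbiased(xs, ys, kernel=rbf):
    """Unbiased estimator of squared MMD between two sample sets.

    Only the sample (particle) index is treated as a batch dimension;
    no internal plate structure of the model is exploited.
    """
    m, n = len(xs), len(ys)
    # Within-set terms exclude the diagonal i == j to keep the estimator unbiased.
    kxx = sum(kernel(xs[i], xs[j]) for i in range(m) for j in range(m) if i != j) / (m * (m - 1))
    kyy = sum(kernel(ys[i], ys[j]) for i in range(n) for j in range(n) if i != j) / (n * (n - 1))
    kxy = sum(kernel(x, y) for x in xs for y in ys) / (m * n)
    return kxx + kyy - 2.0 * kxy

# Nearby sample sets give a squared MMD near zero; well-separated ones do not.
close = mmd2_unbiased([0.0, 0.1, -0.1, 0.05], [0.02, -0.05, 0.08, 0.0])
far = mmd2_unbiased([0.0, 0.1, -0.1, 0.05], [5.0, 5.1, 4.9, 5.05])
```

Treating each particle as one sample in this way is correct for any model and guide, because it views the program as a black box; exploiting plate structure to get per-variable MMD terms is exactly the part that only works in special cases like the InfoVAE model.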
You are right, I see that my implementation doesn't work correctly in the general case. I see two possible ways forward:
- Explicitly document the special cases for which my class does work correctly.
- Work out the general case.
I would like to choose the 2nd one; however, my contribution is part of an educational project for my Master's degree, so it is better for me to reach a stopping point before the deadline, which is the 7th of May (maybe plus several days). However, if the PR isn't approved by that deadline, there will be no catastrophe; I will just receive a lower mark.
Is it acceptable to choose the 1st alternative? If it is, it seems to me that my class works correctly if every batch dimension is marked by `plate`. Am I right?
> Is it acceptable to choose the 1st alternative? If it is, it seems to me, that my class works correctly if every batch dimension is marked by plate. Am I right?
Sorry, I'm not sure I understand your disclaimer. It would probably be helpful for you to write out the math and describe exactly the computations you're performing in your current implementation, the assumptions you're implicitly making about the model and guide, and the precise circumstances under which your MMD estimators are unbiased.
Without something like that to refer to, I don't think we can accept the PR as is, given that it's probably not correct for any models with nontrivial plate structure. Compare that with the situation of `TraceMeanField_ELBO`, which only works with mean-field guides but is otherwise provably correct for all reparametrizable models with static control flow. I also don't think it's reasonable for us to expect you to work out (or even promise to work out) the general case for MMD between arbitrary graphical models just to get a first PR merged, though I would certainly encourage you to try outside this PR if you're interested.
Instead of having to do a lot of extra work, then, here are two simpler alternatives that would build on all the good work you've done already to get your code merged by May 7:
- Follow my suggestion above and only compute MMD across the particle dimension, not all plate dimensions; that's guaranteed to be correct.
- Put the code you wrote here into the nice example notebook you've already written and repurpose it as an advanced tutorial/example on implementing custom model-specific objectives in Pyro.
Either of these would be fine with me and would make great first contributions. I suspect other Pyro users would really appreciate a more thorough and well-motivated custom objective tutorial (specialized to the InfoVAE) that we could feature prominently on our example web site, since that's a very common problem faced by researchers and the only relevant tutorial we have now is not very detailed or thorough and does not contain a working end-to-end example.
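As a starting point for writing out the math requested above, the standard population-level squared MMD between two distributions $p$ and $q$ under a kernel $k$ is:

```math
\mathrm{MMD}^2(p, q) = \mathbb{E}_{x, x' \sim p}\left[k(x, x')\right] - 2\,\mathbb{E}_{x \sim p,\, y \sim q}\left[k(x, y)\right] + \mathbb{E}_{y, y' \sim q}\left[k(y, y')\right]
```

The unbiased estimator replaces each expectation with an average over particles, excluding the diagonal $i = j$ terms in the two within-distribution sums; the assumptions to spell out are which dimensions of the trace are treated as the sample index when forming those averages.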
It seems that I have already missed the deadline, however, I still want to get the job done. I have modified the class following your suggestion.
@varenick sorry for the delayed review, I'll try to get to this sometime this week
@varenick sorry again for the delay. I'm going to go ahead and merge this. Would you mind (1) updating your example notebook to use this version of the loss, (2) confirming that it produces similar reconstructions/samples with sufficiently high `num_particles`, (3) converting it to a `.py` script, and (4) opening another PR with the example script and a corresponding `.rst` stub in the `examples` directory? That way we can include it on the examples webpage.
Adds a feature #1780