[Feature Request] Analogue to TraceELBO class, but with MMD instead of KL #1780

Open
varenick opened this issue Mar 2, 2019 · 12 comments

@varenick
Contributor

varenick commented Mar 2, 2019

Feature:

A new MMDTraceELBO class that implements a Maximum Mean Discrepancy between samples from the guide and from the model, instead of the KL divergence used in the TraceELBO class.

Motivation:

The ELBO is the sum of an expected log-likelihood and the negative KL divergence between the posterior distribution and the prior. To compute the KL term, we must either be able to evaluate the log-probabilities of both the prior and the posterior at posterior samples, or train a classifier to distinguish prior samples from posterior samples. The second alternative has not been implemented in Pyro yet; moreover, using a classifier to compute density ratios leads to a minimax-game objective and seems quite unreliable.
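For reference, the decomposition being described, in standard notation (a sketch, with q_phi the variational posterior and p(z) the prior):

```latex
\mathrm{ELBO}(\theta, \phi)
  = \mathbb{E}_{q_\phi(z \mid x)}\!\left[\log p_\theta(x \mid z)\right]
  - \mathrm{KL}\!\left(q_\phi(z \mid x) \,\|\, p(z)\right)
```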

In the Wasserstein Auto-Encoder paper (https://arxiv.org/abs/1711.01558) the authors propose two alternatives for penalizing the discrepancy between the prior and posterior distributions: the first is training a classifier, as discussed above; the second is using a Maximum Mean Discrepancy (MMD) instead of KL.

Advantages of MMD:

  1. Requires only samples from the prior and posterior distributions, not explicit log-probabilities (see the estimator sketch below);
  2. Does not produce a minimax-game objective.

The main disadvantage of using MMD instead of KL is that the former does not give a valid variational lower bound on the evidence. However, it does yield an approximation to an optimal transport cost between the training data distribution and the model distribution.
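To make point 1 above concrete, here is a minimal sketch of a sample-based MMD estimator in plain PyTorch; the RBF kernel and fixed bandwidth are illustrative choices, not part of the proposal:

```python
import torch

def rbf_gram(x, y, bandwidth=1.0):
    # Gram matrix K[i, j] = exp(-||x_i - y_j||^2 / (2 * bandwidth^2))
    sq_dists = torch.cdist(x, y) ** 2
    return torch.exp(-sq_dists / (2.0 * bandwidth ** 2))

def mmd_squared(x, y, bandwidth=1.0):
    """Biased empirical estimate of MMD^2 between p and q, given
    samples x ~ p and y ~ q of shape (N, dim) and (M, dim).

    Note that only samples are required -- no log-probabilities of
    either distribution are evaluated.
    """
    k_xx = rbf_gram(x, x, bandwidth).mean()
    k_yy = rbf_gram(y, y, bandwidth).mean()
    k_xy = rbf_gram(x, y, bandwidth).mean()
    return k_xx + k_yy - 2.0 * k_xy

# Example: MMD^2 between prior samples and (stand-in) posterior samples.
prior_z = torch.randn(64, 2)
posterior_z = torch.randn(64, 2) + 1.0
print(mmd_squared(prior_z, posterior_z))
```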

If this looks acceptable, I would like to try implementing it.

@eb8680
Member

eb8680 commented Mar 2, 2019

@varenick sure, PRs are welcome! You should be able to use some of the existing kernels in pyro.contrib.gp. How are you thinking of computing the MMD for models with multiple latent variables? Are you planning to use one kernel per latent variable and combine them with a sum?

@karalets
Collaborator

karalets commented Mar 3, 2019 via email

@varenick
Contributor Author

varenick commented Mar 3, 2019

@eb8680 Thanks for the tip about the existing kernels; I didn't know about them.

I was thinking of using a (generally different) kernel k_i(·,·) per latent variable z_i and combining them with a (weighted) sum. Since the latent variables live in different spaces, the kernel k(·,·) for the joint latent variable z = (z_1, ..., z_n) decomposes into a sum of per-variable kernels: k(z, z') = Sum_{i=1}^n k_i(z_i, z_i'). And since c·k(·,·) is a kernel whenever k(·,·) is a kernel and c ≥ 0, a weighted sum is also valid: k(z, z') = Sum_{i=1}^n c_i k_i(z_i, z_i').
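A sketch of that combination using the existing GP kernels (the site names, weights, and dimensions below are made up for illustration, and I assume each site's samples are gathered into an (N, dim_i) tensor):

```python
import torch
import pyro.contrib.gp as gp

# One kernel per latent site; input_dim matches that site's dimensionality.
kernels = {
    "z_1": gp.kernels.RBF(input_dim=2),
    "z_2": gp.kernels.RBF(input_dim=5),
}
weights = {"z_1": 1.0, "z_2": 0.5}

def joint_gram(samples_a, samples_b):
    """Gram matrix of the joint kernel k(z, z') = Sum_i c_i k_i(z_i, z_i'),
    where samples_a and samples_b map site names to (N, dim_i) tensors."""
    return sum(
        weights[site] * kernels[site](samples_a[site], samples_b[site])
        for site in kernels
    )

# Example: Gram matrix between two batches of 8 joint samples.
a = {"z_1": torch.randn(8, 2), "z_2": torch.randn(8, 5)}
b = {"z_1": torch.randn(8, 2), "z_2": torch.randn(8, 5)}
print(joint_gram(a, b).shape)  # torch.Size([8, 8])
```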

@varenick
Contributor Author

varenick commented Apr 9, 2019

@eb8680 I've recently made a working prototype and am planning to open a PR soon. I have a small problem: I don't know what to name the corresponding class.

Candidates:

  1. MMD_ELBO. Intuitive, but incorrect: it is not a valid variational lower bound on the evidence.
  2. MMD_PseudoELBO. Better, but PseudoELBO is not a commonly used term.
  3. MMD_VAE_Loss. Refers to the Ermon Group blog post based on the InfoVAE paper. Not great, since it explicitly mentions the VAE model.
  4. MMD_Based_ELBO_Approximation. Formally correct, but mentions ELBO explicitly and looks too long.
  5. MMD_Based_Evidence_Variational_Approximation. Formally correct, but way too long.

Could you please suggest a name? I've only seen such an objective in the context of VAEs: see the Ermon Group blog post, where it is called MMD-VAE, and the InfoVAE paper, where it is called InfoVAE.

@fritzo
Member

fritzo commented Apr 9, 2019

How about Trace_MMD?

@eb8680
Member

eb8680 commented Apr 9, 2019

@varenick great! Looking forward to seeing your PR. I agree with @fritzo's suggestion of Trace_MMD.

@varenick
Contributor Author

@fritzo @eb8680 Hmm, Trace_MMD suggests that the objective contains only an MMD term; however, there are two terms: the expected log-likelihood of the observed data, and the MMD between the marginal variational posterior and the prior. Maybe Trace_MMD_Variational_Loss or Trace_MMD_Variational_Objective?
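For concreteness, the two-term objective I mean, in the spirit of WAE/InfoVAE (a sketch; the weighting coefficient lambda is added here only for illustration, and q_phi(z) denotes the marginal variational posterior):

```latex
\mathcal{L}(\theta, \phi)
  = \mathbb{E}_{q_\phi(z \mid x)}\!\left[\log p_\theta(x \mid z)\right]
  - \lambda \, \mathrm{MMD}\!\left(q_\phi(z), p(z)\right)
```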

@varenick
Contributor Author

@fritzo @eb8680 I've tried to push my branch to the remote repo, but git returns error 403: permission denied.

@eb8680
Member

eb8680 commented Apr 10, 2019 via email

@wthrif

wthrif commented Aug 8, 2019

Hey @varenick, would you be willing to share an example VAE code with your Trace_MMD class?
I'm getting a shape error that I don't understand when I try to run it: ValueError: Shape mismatch inside plate('num_particles_vectorized') at site obs dim -2, 32 vs 1024
Full disclosure, I'm new to Pyro and don't know what I'm doing. I set num_particles to be the same as my batch size, 32, and used an RBF kernel with input dimension equal to my latent space (2). The 1024 is num_particles * batch_size.
Otherwise I'm just replacing the usual Trace_ELBO with Trace_MMD in a regular VAE code that already works.
The code runs if I set num_particles to 1, although it doesn't converge to a useful latent space the way the same InfoVAE implemented in plain PyTorch does.
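Concretely, the swap I'm making looks roughly like this (a sketch; the Trace_MMD constructor arguments are my best reading of the interface, so please check the actual signature, and "z" stands in for my latent site name):

```python
import pyro.contrib.gp as gp
from pyro.infer import SVI, Trace_MMD
from pyro.optim import Adam

# Assumed interface: a dict mapping each latent site name to a kernel
# over that site's space; input_dim=2 matches my 2-d latent.
kernel = {"z": gp.kernels.RBF(input_dim=2)}
loss = Trace_MMD(kernel=kernel, num_particles=32, vectorize_particles=True)

# model/guide are the existing VAE model/guide pair, wired exactly as
# they would be with Trace_ELBO:
# svi = SVI(model, guide, Adam({"lr": 1e-3}), loss=loss)
# svi.step(x)  # x: a mini-batch of 32 observations
```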

@eb8680
Member

eb8680 commented Aug 8, 2019

Although it doesn't converge to a useful latent space the way the same InfoVAE implemented in plain PyTorch does.

@wthrif this is to be expected, since the Trace_MMD loss is not the same as the InfoVAE loss. See #1818 for discussion. Unfortunately, implementing general-purpose inference algorithms is difficult, and I'm not sure Trace_MMD as it currently exists is very useful. Without additional examples I'm inclined to remove it (at least temporarily) before the upcoming 0.4 release, since we don't have spare cycles to maintain or improve it.

@varenick also wrote a nice example notebook in #1818 with a version of Trace_MMD that specifically replicates the InfoVAE loss, but which is not correct for arbitrary models. If you're up for it (and @varenick doesn't mind), you could take that notebook, turn it into an example script similar to the others in the examples/ folder, and submit it as a PR - I think @varenick did a great job with that notebook, and lots of users would appreciate it as an example of a more complicated custom loss function.

@wthrif

wthrif commented Aug 9, 2019

Thanks for the link, @eb8680, I'll work on implementing it.
