
Code vs paper #19

Closed
dev7mt opened this issue Jul 9, 2019 · 5 comments


dev7mt commented Jul 9, 2019

Hello,

thank you for publishing your code - outstanding work :)
However, I have a question regarding the JSD/GAN-based estimators and the differences between the implementation and the formulation in your paper:
Eq. 4 (the JSD-based MI estimator, as I read it from the paper):

    I^(JSD)(X; E(X)) := E_P[-sp(-T(x, E(x)))] - E_{P×P~}[sp(T(x', E(x)))],  where sp(z) = log(1 + e^z)

Eq. 7 (prior matching, the GAN-style objective):

    arg min_ψ arg max_φ  E_V[log D_φ(y)] + E_P[log(1 - D_φ(E_ψ(x)))]
At the same time, in your code:
(for the JSD estimator)

if mode == 'fd':
    loss = fenchel_dual_loss(l_enc, m_enc, measure=measure)
[...]
E_pos = get_positive_expectation(u, measure, average=False).mean(2).mean(2)
E_neg = get_negative_expectation(u, measure, average=False).mean(2).mean(2)
[...]
Ep = log_2 - F.softplus(-p_samples)  # Note JSD will be shifted
Eq = F.softplus(-q_samples) + q_samples - log_2  # Note JSD will be shifted
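For reference, the JSD branches of those helpers can be written as a small self-contained sketch (function names are mine and mirror the snippet above; the actual repo dispatches on `measure` and supports other f-divergences):

```python
import math

import torch
import torch.nn.functional as F

log_2 = math.log(2.0)

def jsd_positive_expectation(p_samples):
    # E_p term of the (shifted) JSD estimator: log 2 - sp(-T(x, y))
    return log_2 - F.softplus(-p_samples)

def jsd_negative_expectation(q_samples):
    # E_q term: sp(-T) + T - log 2, which refactors to sp(T) - log 2
    return F.softplus(-q_samples) + q_samples - log_2
```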

While I do know where the log_2 comes from [Nowozin et al., 2016], the addition of q_samples in Eq is a bit more mysterious :D
And then for the prior matching:

    if not loss_type or loss_type == 'minimax':
        return get_negative_expectation(q_samples, measure)
    elif loss_type == 'non-saturating':
        return -get_positive_expectation(q_samples, measure)

It seems like you are using just half of Equation 7 to obtain the loss value.

Could you clarify those differences (maybe I am missing something in the code)? I have been trying to merge DIM with my existing code (a slightly different setup, but one that should work together properly) and cannot get it to work well.

Thanks in advance!

rdevon (Owner) commented Jul 9, 2019

So the log 2 in the code is an artifact of another model (BGAN) where I just wanted all the f-divergences to align in a way that the estimate of p/q ended up being e^T. This doesn't change anything from a learning perspective, but it does shift the JSD.

q_samples are just the negative samples, so in the case of MI estimation these would come from the product of marginals. If you write out the Eq part as log(1 + exp(-q_samples)) + log(exp(q_samples)), then refactor, you get something of the form of Eq. (4).
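The refactoring mentioned here can be checked numerically - `sp(-q) + q` and `sp(q)` are the same function (a quick sanity check, not code from the repo):

```python
import torch
import torch.nn.functional as F

# log(1 + exp(-q)) + q == log((1 + exp(-q)) * exp(q)) == log(exp(q) + 1)
q = torch.randn(1000)
assert torch.allclose(F.softplus(-q) + q, F.softplus(q), atol=1e-5)
```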

For the prior matching, this is just done as in GAN, where the second term is all that's used for training the generator in the case of minmax. Or if you do non-saturating, you replace the first term of the f-divergence with q_samples (from the generator).
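In standard GAN terms, the two generator-side conventions being described look roughly like this (a sketch of the textbook formulation, not the repo's exact code or sign conventions):

```python
import torch
import torch.nn.functional as F

def generator_loss(d_logits_fake, loss_type='minimax'):
    """d_logits_fake: discriminator logits T on generated samples,
    with D(x) = sigmoid(T). Both losses are minimized by the generator."""
    if loss_type == 'minimax':
        # minimize E[log(1 - D(G(z)))] = E[-softplus(T)]
        return -F.softplus(d_logits_fake).mean()
    elif loss_type == 'non-saturating':
        # minimize E[-log D(G(z))] = E[softplus(-T)]
        return F.softplus(-d_logits_fake).mean()
    raise ValueError(loss_type)
```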

Does that make sense?

dev7mt (Author) commented Jul 9, 2019

Okay, that makes sense! Thank you for the clarification :)

One question - why would you use log(1 + exp(-q_samples)) + q_samples instead of just log(1 + exp(q_samples))? Is it more stable during training?

rdevon (Owner) commented Jul 9, 2019

I think it's an artifact of the olden days of Theano; I'm guessing that in PyTorch these are equivalent.

@dev7mt dev7mt closed this as completed Jul 9, 2019

SkyeLu commented Jul 24, 2019

Hi devon, thanks for your interesting work! I am having some trouble reading the prior-matching part of your code. As you said:

For the prior matching, this is just done as in GAN, where the second term is all that's used for training the generator in the case of minmax.

But besides training the generator, we also need to train the discriminator, that is, to maximize Equation (7) in your paper, but I didn't find any code for this part. Am I missing something, or have I misunderstood anything?
Thanks a lot!


SkyeLu commented Jul 24, 2019

Oh sorry, I think you have done this maximization procedure in the `Discriminator` class. Sorry to disturb.
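For other readers landing here: the discriminator side of Eq. (7), maximizing `E_V[log D(y)] + E_P[log(1 - D(E(x)))]`, can be sketched in standard GAN form as follows (an illustration, not the repo's exact `Discriminator` code):

```python
import torch
import torch.nn.functional as F

def discriminator_loss(logits_real, logits_fake):
    # Maximizing E[log D(real)] + E[log(1 - D(fake))] with D = sigmoid(T)
    # is the same as minimizing softplus(-T_real) + softplus(T_fake).
    return F.softplus(-logits_real).mean() + F.softplus(logits_fake).mean()
```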
