GAN Inference #1164

Open
sameerkhurana10 opened this issue May 27, 2018 · 13 comments
Labels
enhancement, good first issue, help wanted, question

Comments

@sameerkhurana10

Hi,

thanks for releasing Pyro.

Any plans to have GAN-like inference in Pyro?

Thanks

@fritzo
Member

fritzo commented May 27, 2018

Hi @sameerkhurana10 we've discussed adding a discriminator example following Mohamed, Lakshminarayanan (2016) and Tran, Ranganath, Blei (2017), but we have no immediate plans. If you'd like to add something yourself, @eb8680 or @martinjankowiak could suggest where to start.

@sameerkhurana10
Author

sameerkhurana10 commented May 27, 2018

Thanks @fritzo!

I could have a go at it. Some pointers would be great.

@eb8680
Member

eb8680 commented May 27, 2018

@sameerkhurana10 most such algorithms should be pretty easy to implement with Pyro's existing tools (though if you find that's not the case, feel free to open another issue!). In my admittedly limited experience, the algorithms themselves (independent of Pyro) are extremely brittle and sensitive to hyperparameters, so there's not a compelling reason for us to add whole algorithms to Pyro instead of just providing tools for implementing them concisely. That said, we'd definitely welcome any contributions of examples or tutorials, or even a contributed library like pyro.contrib.gp.

Here's an almost-complete idiomatic Pyro implementation of a VAE with an implicit variational distribution and an Adversarial Variational Bayes loss optimized with simultaneous gradient descent, which seems like the simplest GAN inference variant:

import torch
import torch.nn as nn

import pyro
import pyro.distributions as dist
import pyro.optim
import pyro.poutine as poutine


# only these and the data are missing
decoder = nn.Sequential(...)
encoder = nn.Sequential(...)  # should have some internal randomness, e.g. call to torch.randn()
discriminator = nn.Sequential(...)


def model():
    z = pyro.sample("z", dist.Normal(torch.zeros(10), torch.ones(10)))
    loc_h, scale_h = pyro.module("decoder", decoder)(z)
    return pyro.sample("x", dist.Normal(loc_h, scale_h))


def constrained_model(x):
    # condition the model on observed data and actually run it
    return pyro.condition(model, data={"x": x})()


def guide(x):
    # encoder should have some internal randomness not exposed to pyro
    return pyro.sample("z", pyro.module("encoder", encoder), x)


def loss(model, guide, *args, **kwargs):
    pyro.module("discriminator", discriminator)

    guide_tr = poutine.trace(guide).get_trace(*args, **kwargs)
    model_tr = poutine.trace(poutine.replay(model, trace=guide_tr)).get_trace(*args, **kwargs)
    prior_tr = poutine.trace(model).get_trace(*args, **kwargs)

    # main loss: negative ELBO, where the discriminator output estimates
    # the log density ratio log q(z|x) - log p(z)
    elbo = model_tr.nodes["x"]["fn"].log_prob(model_tr.nodes["x"]["value"]).sum()
    elbo -= discriminator(guide_tr.nodes["z"]["value"],
                          *guide_tr.nodes["z"]["args"]).sum()
    main_loss = -elbo

    # discriminator loss: binary cross-entropy distinguishing
    # posterior samples from prior samples
    aux_loss = -torch.log(torch.sigmoid(discriminator(guide_tr.nodes["z"]["value"],
                                                      *guide_tr.nodes["z"]["args"]))).sum()
    aux_loss -= torch.log(1. - torch.sigmoid(discriminator(prior_tr.nodes["z"]["value"],
                                                           prior_tr.nodes["x"]["value"]))).sum()

    return main_loss, aux_loss


main_optim = pyro.optim.Adam({"lr": 0.001})
aux_optim = pyro.optim.Adam({"lr": 0.001})

...  # load data

for batch in data:
    with poutine.trace(param_only=True) as param_capture:
        main_loss, aux_loss = loss(constrained_model, guide, batch)

    # since discriminator is an nn.Module, could also use:
    # aux_params = dict(discriminator.named_parameters())
    aux_params = {name: node["value"].unconstrained()
                  for name, node in param_capture.trace.nodes.items()
                  if "discriminator" in name}

    # since encoder/decoder are nn.Modules, could also use:
    # main_params = dict(encoder.named_parameters())
    # main_params.update(dict(decoder.named_parameters()))  # assuming names differ
    main_params = {name: node["value"].unconstrained()
                   for name, node in param_capture.trace.nodes.items()
                   if "discriminator" not in name}

    # zero and step the encoder/decoder parameters; retain the graph
    # because aux_loss shares intermediate values with main_loss
    for main_param in main_params.values():
        if main_param.grad is not None:
            main_param.grad.fill_(0)
    main_loss.backward(retain_graph=True)
    main_optim(main_params.values())

    # zero and step the discriminator parameters
    for aux_param in aux_params.values():
        if aux_param.grad is not None:
            aux_param.grad.fill_(0)
    aux_loss.backward()
    aux_optim(aux_params.values())

@sameerkhurana10
Author

Great, thanks @eb8680

This should be very helpful.

@ibulu

ibulu commented Jun 7, 2018

Hi @sameerkhurana10,
have you been able to get this to work?

@sameerkhurana10
Author

sameerkhurana10 commented Jun 7, 2018 via email

@eb8680 added and then removed the help wanted and good first issue labels Oct 17, 2018
@varenick
Contributor

@eb8680 @sameerkhurana10 Can I take this issue? If so, am I right that all I need to implement is a new class inherited from ELBO (say, AdversarialELBO) that introduces a classifier for every latent variable?

@eb8680
Member

eb8680 commented Feb 24, 2019 via email

@varenick
Contributor

@eb8680 I have read the discussion here once again; am I right that there is actually no need for separate functionality like, say, an AdversarialELBO class, but rather a need for an example of using adversarial methods in Pyro?

Looking at your example, it seems like computing a KL divergence with a classifier requires a user to dig into the guts of Pyro. By this I mean a user should at least use the pyro.poutine module, which seems quite low-level. Maybe it would be convenient to provide a class that implements expected_log_likelihood, entropy, and cross_entropy methods separately, together with get_prior_samples and get_posterior_samples methods. Such a class would also be convenient for KL annealing, which is a popular technique.
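
To make the suggestion concrete, here is a rough sketch of such an interface (all class and method names here are hypothetical, not an existing Pyro API):

class DecomposedELBO:
    """Hypothetical ELBO variant that exposes its terms separately,
    so that e.g. the KL term can be annealed or replaced."""

    def expected_log_likelihood(self, model, guide, *args, **kwargs):
        ...  # E_q [ log p(x | z) ]

    def entropy(self, guide, *args, **kwargs):
        ...  # H(q) = -E_q [ log q(z | x) ]

    def cross_entropy(self, model, guide, *args, **kwargs):
        ...  # H(q, p) = -E_q [ log p(z) ]

    def get_prior_samples(self, model, *args, **kwargs):
        ...

    def get_posterior_samples(self, guide, *args, **kwargs):
        ...

    def loss(self, model, guide, *args, beta=1.0, **kwargs):
        # KL(q || p) = cross_entropy - entropy; beta < 1 anneals the KL term
        kl = (self.cross_entropy(model, guide, *args, **kwargs)
              - self.entropy(guide, *args, **kwargs))
        return -(self.expected_log_likelihood(model, guide, *args, **kwargs)
                 - beta * kl)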

@eb8680
Member

eb8680 commented Feb 26, 2019

@varenick it would be easy enough to turn my example code above into a simple generic AdversarialELBO class. I suppose I'm biased, but it seems pretty readable to me, especially if paired with our custom SVI objectives tutorial :)
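
For concreteness, here's a minimal sketch of what such a class might look like, just wrapping the loss function above (the class name, the (z, x) discriminator signature, and the fixed site names "z" and "x" are assumptions carried over from the example, not a committed API):

import torch
import pyro
import pyro.poutine as poutine


class AdversarialELBO:
    """Minimal sketch: the hand-written AVB loss above wrapped in a class.
    Assumes a single latent site "z" and a single observed site "x"."""

    def __init__(self, discriminator):
        self.discriminator = discriminator

    def loss(self, model, guide, *args, **kwargs):
        pyro.module("discriminator", self.discriminator)

        guide_tr = poutine.trace(guide).get_trace(*args, **kwargs)
        model_tr = poutine.trace(poutine.replay(model, trace=guide_tr)).get_trace(*args, **kwargs)
        prior_tr = poutine.trace(model).get_trace(*args, **kwargs)

        z_q = guide_tr.nodes["z"]["value"]   # posterior sample
        z_p = prior_tr.nodes["z"]["value"]   # prior sample
        x = model_tr.nodes["x"]["value"]     # observed data

        # negative ELBO, with the discriminator output standing in for
        # the log density ratio log q(z|x) - log p(z)
        main_loss = -model_tr.nodes["x"]["fn"].log_prob(x).sum()
        main_loss += self.discriminator(z_q, x).sum()

        # standard GAN discriminator loss on posterior vs prior samples
        aux_loss = -torch.log(torch.sigmoid(self.discriminator(z_q, x))).sum()
        aux_loss -= torch.log(1. - torch.sigmoid(self.discriminator(z_p, x))).sum()

        return main_loss, aux_loss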

I'm not opposed to having an AdversarialELBO class, but starting with an example makes sense to me for two reasons: first, it will help familiarize you with the APIs you'd use to write a generic version, and second, you'll get a better sense for how difficult these algorithms are to tune, especially if you haven't worked with them before. If you dig into the gritty details of even the simplest nontrivial examples, like the MNIST experiments in the Adversarial Variational Bayes paper, you'll find there are always several layers of hacks ("Adaptive Contrast" in this case) required to get the optimization to converge, and these may not generalize usefully beyond those examples in practice. Starting with an example or two would help guide the design of a generic implementation that works more reliably.

Re: adding various ELBO term methods, that's an interesting idea but seems distinct from the discussion here. Feel free to open a separate issue to discuss further.

@deepaks4077

deepaks4077 commented Dec 31, 2019

@eb8680: I've worked a fair bit on implementing this paper, using PyTorch, on a reasonably sized dataset (100,000 dims, 5,000-50,000 samples). It involves a GAN-like objective and ties in nicely with a causal inference problem in genomics. I've been meaning to port my code to Pyro but may need some help since I'm a beginner with PPLs. Mind if I have a go?

Also, this could be a nice project for GSoC 2020.

@eb8680
Member

eb8680 commented Jan 1, 2020

Mind if I have a go?

@deepaks4077 sure! I'd encourage you to first make sure your existing code is correct and produces the results you expect on your data. Once you've done that, it should be easy to try porting it to Pyro following the discussion in this issue and in our custom objectives tutorial. If you get stuck, please don't hesitate to ask questions or open a PR with incomplete or incorrect code so that we can help you get it finished.

@deepaks4077

deepaks4077 commented Jan 2, 2020

@eb8680: Great, I'll have a go at this soon. Does Pyro have anything akin to ImplicitKLqp in Edward 1?

Edit: Nevermind, I believe that is what we are trying to implement here.
