GAN Inference #1164
Hi @sameerkhurana10 we've discussed adding a discriminator example following Mohamed, Lakshminarayanan (2016) and Tran, Ranganath, Blei (2017), but we have no immediate plans. If you'd like to add something yourself, @eb8680 or @martinjankowiak could suggest where to start. |
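The core idea in the papers mentioned above is density-ratio estimation: train a binary classifier to separate samples from two distributions, and its logits recover the log density ratio. A minimal NumPy sketch (not from this thread; all names and the two Gaussians are illustrative) that recovers the known ratio log N(z; 1, 1) - log N(z; 0, 1) = z - 0.5 with a linear logistic classifier:

```python
import numpy as np

rng = np.random.default_rng(0)
n = 20000

# Samples from q = N(1, 1) (label 1) and p = N(0, 1) (label 0).
z_q = rng.normal(1.0, 1.0, n)
z_p = rng.normal(0.0, 1.0, n)
z = np.concatenate([z_q, z_p])
y = np.concatenate([np.ones(n), np.zeros(n)])

# Logistic classifier d(z) = w*z + b, trained by full-batch gradient descent
# on the usual binary cross-entropy loss.
w, b = 0.0, 0.0
lr = 0.5
for _ in range(3000):
    prob = 1.0 / (1.0 + np.exp(-(w * z + b)))
    w -= lr * np.mean((prob - y) * z)
    b -= lr * np.mean(prob - y)

# At the optimum the logit approximates log q(z) - log p(z) = z - 0.5,
# so w should be close to 1 and b close to -0.5.
print(w, b)
```

In adversarial inference this same trick replaces the intractable log q(z|x) - log p(z) term of the ELBO with a classifier's logit.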
thanks @fritzo i could have a go at it. Some pointers will be great. |
@sameerkhurana10 most such algorithms should be pretty easy to implement with Pyro's existing tools (though if you find that's not the case, feel free to open another issue!). In my admittedly limited experience, the algorithms themselves (independent of Pyro) are extremely brittle and sensitive to hyperparameters, so there's no compelling reason for us to add whole algorithms to Pyro rather than just providing tools for implementing them concisely. That said, we'd definitely welcome contributions of examples or tutorials, or even a contributed library.

Here's an almost-complete idiomatic Pyro implementation of a VAE with an implicit variational distribution and an Adversarial Variational Bayes loss optimized with simultaneous gradient descent, which seems like the simplest GAN inference variant:

import torch
import torch.nn as nn
import pyro
import pyro.distributions as dist
import pyro.optim
import pyro.poutine as poutine

# only these and the data are missing
decoder = nn.Sequential(...)
encoder = nn.Sequential(...)  # should have some internal randomness, e.g. a call to torch.randn()
discriminator = nn.Sequential(...)

def model():
    z = pyro.sample("z", dist.Normal(torch.zeros(10), torch.ones(10)))
    loc_h, scale_h = pyro.module("decoder", decoder)(z)
    return pyro.sample("x", dist.Normal(loc_h, scale_h))

def constrained_model(x):
    # note the trailing call: pyro.condition returns a function
    return pyro.condition(model, data={"x": x})()

def guide(x):
    # encoder has internal randomness not exposed to Pyro, so we register
    # its output at the sample site with a Delta
    z = pyro.module("encoder", encoder)(x)
    return pyro.sample("z", dist.Delta(z))

def loss(model, guide, *args, **kwargs):
    pyro.module("discriminator", discriminator)
    guide_tr = poutine.trace(guide).get_trace(*args, **kwargs)
    model_tr = poutine.trace(poutine.replay(model, trace=guide_tr)).get_trace(*args, **kwargs)
    prior_tr = poutine.trace(model).get_trace(*args, **kwargs)
    z_q = guide_tr.nodes["z"]["value"]
    z_p = prior_tr.nodes["z"]["value"]
    # main loss: negative ELBO, with the discriminator's logit standing in for
    # log q(z|x) - log p(z) (full AVB also conditions the discriminator on x)
    elbo = model_tr.nodes["x"]["fn"].log_prob(model_tr.nodes["x"]["value"]).sum()
    elbo -= discriminator(z_q).sum()
    main_loss = -elbo
    # discriminator loss: logistic regression separating guide samples from
    # prior samples; detach z_q so this term doesn't update the encoder
    aux_loss = -torch.log(torch.sigmoid(discriminator(z_q.detach()))).sum()
    aux_loss -= torch.log(1. - torch.sigmoid(discriminator(z_p))).sum()
    return main_loss, aux_loss

main_optim = pyro.optim.Adam({"lr": 0.001})
aux_optim = pyro.optim.Adam({"lr": 0.001})

...  # load data
for batch in data:
    with poutine.trace(param_only=True) as param_capture:
        main_loss, aux_loss = loss(constrained_model, guide, batch)
    # since discriminator is an nn.Module, could also use:
    #   aux_params = dict(discriminator.named_parameters())
    aux_params = {name: node["value"].unconstrained()
                  for name, node in param_capture.trace.nodes.items()
                  if "discriminator" in name}
    # since encoder/decoder are nn.Modules, could also use:
    #   main_params = dict(encoder.named_parameters())
    #   main_params.update(decoder.named_parameters())  # assuming names are different
    main_params = {name: node["value"].unconstrained()
                  for name, node in param_capture.trace.nodes.items()
                  if "discriminator" not in name}
    for main_param in main_params.values():
        if main_param.grad is not None:
            main_param.grad.zero_()
    main_loss.backward()
    main_optim(main_params.values())  # PyroOptim instances are called, not .step()ed
    for aux_param in aux_params.values():
        if aux_param.grad is not None:
            aux_param.grad.zero_()
    aux_loss.backward()
    aux_optim(aux_params.values()) |
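A side note on the "simultaneous gradient descent" and "brittle" remarks above: even on the simplest min-max objective, plain simultaneous updates can spiral away from the equilibrium rather than converge to it. A tiny stdlib-Python illustration (not from this thread) on f(u, v) = u * v, whose saddle point is (0, 0):

```python
import math

# Simultaneous gradient descent/ascent on f(u, v) = u * v:
#   u <- u - eta * df/du = u - eta * v   (minimizing player)
#   v <- v + eta * df/dv = v + eta * u   (maximizing player)
eta = 0.1
u, v = 1.0, 0.0
radii = [math.hypot(u, v)]
for _ in range(100):
    # tuple assignment updates both players from the OLD values (simultaneous)
    u, v = u - eta * v, v + eta * u
    radii.append(math.hypot(u, v))

# Each step is a rotation scaled by sqrt(1 + eta**2) > 1, so the distance
# from the saddle point grows monotonically instead of shrinking.
print(radii[0], radii[-1])
```

This is one reason these objectives are so sensitive to step sizes and update schedules in practice.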
Great, thanks @eb8680 this should be very helpful. |
Hi sameerkhurana, |
sorry, did not have time to work on it. Won't be able to get to it anytime
soon.
…On Thu, Jun 7, 2018 at 2:37 PM, ibulu ***@***.***> wrote:
Hi sameerkhurana,
have you been able to get this to work?
|
@eb8680 @sameerkhurana10 Can I take this issue? If yes, am I right that all I need to implement is a new class inherited from ELBO (say, AdversarialELBO) that would introduce a classifier for every latent variable? |
Sure, go for it! To make it easier to get started, I would recommend first
implementing a self-contained example taken directly from a single paper
rather than a general piece of machinery - maybe the first toy example in
the Adversarial Variational Bayes paper?
|
@eb8680 I have read the discussion here once again; am I right that there is actually no need for separate functionality, like, say, an AdversarialELBO class, but rather for an example of using adversarial methods in Pyro? Looking at your example, it seems like computing a KL divergence with a classifier requires a user to "dig into the guts" of Pyro. By saying this, I mean a user should at least use a |
@varenick it would be easy enough to turn my example code above into a simple generic loss implementation. I'm not opposed to having an AdversarialELBO. Re: adding various ELBO term methods, that's an interesting idea but seems distinct from the discussion here. Feel free to open a separate issue to discuss further. |
@eb8680: I've worked a fair bit on implementing this paper, using PyTorch, on a reasonably sized dataset (100,000 dims, 5,000-50,000 samples). It involves a GAN-like objective and ties in nicely with a causal inference problem in genomics. I've been meaning to port my code to Pyro but may need some help since I'm a beginner with PPLs. Mind if I have a go? Also, this could be a nice project for GSoC 2020. |
@deepaks4077 sure! I'd encourage you to first make sure your existing code is correct and produces the results you expect on your data. Once you've done that, it should be easy to try porting it to Pyro following the discussion in this issue and in our custom objectives tutorial. If you get stuck, please don't hesitate to ask questions or open a PR with incomplete or incorrect code so that we can help you get it finished. |
@eb8680 : Great, I'll have a go at this soon. Does pyro have anything akin to ImplicitKLqp in Edward 1? Edit: Nevermind, I believe that is what we are trying to implement here. |
Hi,
thanks for releasing Pyro.
Any plans to have GAN-like inference in Pyro?
Thanks