Observations that are deterministic functions of latents #773

Closed
innuo opened this issue Feb 14, 2018 · 18 comments

@innuo

innuo commented Feb 14, 2018

This is a follow-up to
#568

As recommended in the discussion of that issue, observations that are deterministic transformations of latent variables can be modeled with a delta or with a normal with small variance. Even so, they have another consequence when performing variational inference.

Consider the following model, where we want to learn something about mu given the observation x = x0 (a is known and f is a deterministic function):

z ~ Normal(mu, a)
x ~ Normal(f(z), sigma=epsilon, obs=x0)

In the guide, the latents can either be sampled from the parameterized prior or modeled as small deviations around (an approximate inverse of) the observed values.

Either the guide looks something like

mu = param
z ~ Normal(mu, a)

or

z ~ Normal(approx_f_inverse(x0), sigma=epsilon)

In the first case, because x0 has negligible probability given a guide-sampled z, convergence is very slow.

The second case doesn't allow learning about mu at all because the sample z doesn't depend on mu.
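A small PyTorch sketch (not part of the original post) of why the first guide struggles: it draws z from a prior-shaped guide and measures the effective sample size (ESS) of the likelihood weights for x0. The choices f(z) = 2 * z, mu = 0, a = 1 and x0 = 1 are illustrative assumptions, and effective_sample_size is a hypothetical helper.

```python
import torch
from torch.distributions import Normal


def effective_sample_size(epsilon, num_samples=10000, mu=0.0, a=1.0, x0=1.0):
    """ESS of likelihood weights when z is drawn from Normal(mu, a), as in the first guide.

    Illustrative assumptions: f(z) = 2 * z and x ~ Normal(f(z), epsilon).
    """
    torch.manual_seed(0)  # same z samples for every epsilon
    z = Normal(mu, a).sample((num_samples,))
    log_w = Normal(2 * z, epsilon).log_prob(torch.tensor(x0))
    # Normalize the weights in log space for numerical stability
    w = torch.softmax(log_w, dim=0)
    # ESS = (sum w)^2 / sum w^2, and sum w = 1 after normalization
    return 1.0 / (w ** 2).sum().item()


for eps in (1.0, 0.1, 0.01):
    print(eps, effective_sample_size(eps))
```

For small epsilon the ESS shrinks roughly in proportion to epsilon, so almost every guide sample carries negligible weight, which is consistent with the slow convergence described above.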

Perhaps this is not an issue but only a misunderstanding on my part, but how can I model a situation like the one above? Do the likelihood-free methods mentioned in #568 handle this?

Any pointers on how inference in the above model can be accomplished with Pyro currently?

@fritzo
Member

fritzo commented Feb 14, 2018

If f is invertible you can implement it as a TransformedDistribution:

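import torch
import pyro
import pyro.distributions as dist
from torch.distributions import constraints

class MyTransform(torch.distributions.Transform):
    # TransformedDistribution needs the domain, codomain and bijectivity declared
    domain = constraints.real
    codomain = constraints.real
    bijective = True

    def _call(self, x):
        ...  # f(x)
    def _inverse(self, y):
        ...  # f^{-1}(y)
    def log_abs_det_jacobian(self, x, y):
        ...  # log |df/dx|

def model(x0, mu, a):
    pyro.sample("x", dist.TransformedDistribution(dist.Normal(mu, a), MyTransform()), obs=x0)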

See http://pytorch.org/docs/master/distributions.html#module-torch.distributions.transforms
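As a concrete instance of the template, here is a sketch assuming f(z) = 2 * z (the same f used later in this thread), written with torch.distributions, which pyro.distributions wraps. DoubleTransform is a hypothetical name; for this f the result should agree with the closed form Normal(2 * mu, 2 * a).

```python
import math

import torch
from torch.distributions import Normal, TransformedDistribution, constraints
from torch.distributions.transforms import Transform


class DoubleTransform(Transform):
    """Hypothetical bijection f(z) = 2 * z, used only to fill in the template."""
    domain = constraints.real
    codomain = constraints.real
    bijective = True

    def _call(self, x):
        return 2 * x

    def _inverse(self, y):
        return y / 2

    def log_abs_det_jacobian(self, x, y):
        # |df/dz| = 2 everywhere
        return torch.full_like(x, math.log(2.0))


mu, a = torch.tensor(1.0), torch.tensor(0.5)
x_dist = TransformedDistribution(Normal(mu, a), DoubleTransform())
x0 = torch.tensor([0.3, 2.0, 4.0])
print(x_dist.log_prob(x0))
print(Normal(2 * mu, 2 * a).log_prob(x0))  # closed form for comparison
```

Inside a Pyro model the same object would be used as pyro.sample("x", dist.TransformedDistribution(dist.Normal(mu, a), DoubleTransform()), obs=x0).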

@innuo
Author

innuo commented Feb 14, 2018

Thanks. That solves the problem when f is invertible, by turning the deterministic function into a sample statement.

Is there a way to do this in case f is not invertible?

This is not an uncommon scenario. For example, consider the following model of wind readings from a Met tower:

x component of wind vector  = u ~ Normal(mu_x, sigma_x)
y component of wind vector  = v ~ Normal(mu_y, sigma_y)
wind_speed = sqrt(u^2 + v^2)

where we would like to infer E[u | wind_speed = w_o].

In general are non-invertible transformations disallowed inside Pyro models?

Please see also my related question on the pyro forum: https://forum.pyro.ai/t/variational-inference-with-non-bijective-stochastic-functions/108

@fritzo
Member

fritzo commented Feb 14, 2018

are non-invertible transformations disallowed inside Pyro models?

Correct, Pyro does not allow observation of non-invertible transformation of a random variable. As a workaround you can implement a Distribution object that performs your non-invertible transformation and computes .log_prob() by integrating over any marginalized-out variables. E.g. in your case I believe you could use a Chi2 distribution instead of the 2-norm of two Normals; in this case the implementation of Chi2.log_prob() effectively integrates over the marginalized-out angle.
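A minimal sketch of that workaround for the wind example, assuming u and v are independent Normals. With zero means and unit variances the squared speed is exactly Chi2 with 2 degrees of freedom; for general means and scales, the sketch below instead integrates the joint density over the marginalized-out wind direction numerically. The class name WindSpeed and the num_angles parameter are hypothetical, and the rectangle rule over the angle is just one reasonable quadrature choice.

```python
import math

import torch
from torch.distributions import Distribution, Normal, constraints


class WindSpeed(Distribution):
    """Hypothetical distribution of sqrt(u**2 + v**2) with independent
    u ~ Normal(mu_x, sigma_x) and v ~ Normal(mu_y, sigma_y) (scalar parameters only)."""
    arg_constraints = {"mu_x": constraints.real, "mu_y": constraints.real,
                       "sigma_x": constraints.positive, "sigma_y": constraints.positive}
    support = constraints.positive

    def __init__(self, mu_x, mu_y, sigma_x, sigma_y, num_angles=512, validate_args=None):
        self.mu_x, self.mu_y, self.sigma_x, self.sigma_y = (
            torch.as_tensor(v, dtype=torch.float64) for v in (mu_x, mu_y, sigma_x, sigma_y))
        # Uniform grid over the direction angle; the integrand is periodic,
        # so the rectangle rule converges very quickly
        self.theta = torch.linspace(0, 2 * math.pi, num_angles + 1, dtype=torch.float64)[:-1]
        super().__init__(batch_shape=torch.Size(), validate_args=validate_args)

    def log_prob(self, value):
        if self._validate_args:
            self._validate_sample(value)
        r = torch.as_tensor(value, dtype=torch.float64).unsqueeze(-1)
        # Polar change of variables:
        # p(r) = r * integral over [0, 2 pi) of p_u(r cos t) * p_v(r sin t) dt
        log_integrand = (Normal(self.mu_x, self.sigma_x).log_prob(r * torch.cos(self.theta))
                         + Normal(self.mu_y, self.sigma_y).log_prob(r * torch.sin(self.theta)))
        d_theta = 2 * math.pi / len(self.theta)
        return (torch.log(r.squeeze(-1)) + torch.logsumexp(log_integrand, dim=-1)
                + math.log(d_theta))

    def sample(self, sample_shape=torch.Size()):
        with torch.no_grad():
            u = Normal(self.mu_x, self.sigma_x).sample(sample_shape)
            v = Normal(self.mu_y, self.sigma_y).sample(sample_shape)
            return torch.sqrt(u ** 2 + v ** 2)


speed = WindSpeed(3.0, 4.0, 2.0, 2.0)
w = speed.sample((5,))
print(speed.log_prob(w))
```

To use this at a Pyro sample site, e.g. pyro.sample("wind_speed", WindSpeed(...), obs=w_o), it would need to derive from pyro.distributions.TorchDistribution (itself a torch.distributions.Distribution subclass) rather than from the plain torch base class.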

@innuo
Author

innuo commented Feb 14, 2018

Is the non-invertibility explicitly checked for? Or does the inference just proceed with possibly wrong answers?

The reason I ask is that I was modeling a situation where one latent variable was the maximum of two other latent variables (I did something like z3 ~ Normal(torch.max(z1, z2), sigma)), and inference proceeded without errors but gave nonsensical results.

@eb8680
Member

eb8680 commented Feb 14, 2018

@innuo your example in your first post will function correctly with SVI. If you want to learn about mu you'll have to put a prior on it and include a mu term in the guide or wrap it in pyro.param in the model.

However, the SVI approximation to the posterior is biased, and may produce incorrect results. Have you tried pyro.infer.HMC?

@innuo
Author

innuo commented Feb 14, 2018

@eb8680 I didn't try HMC, but I did have a prior on mu.

Could you comment further on my question about the choice of guide? If x is just a delta distribution around f(z), how can I pick a guide that works for inference?

@innuo
Author

innuo commented Feb 14, 2018

@fritzo What I don't really understand is how Pyro goes about computing the log_pdf when there are non-invertible transformations.

Essentially my model specifies sampling according to some P_i(x_i | f(x_1, x_2, ..., x_{i-1})), where f is some arbitrary composition. Could you give me some intuition about how the log-pdf log P(x_1, x_2, ..., x_{i-1}, x_i) is computed?

@eb8680
Member

eb8680 commented Feb 14, 2018

@innuo I think you may be getting caught up in notation. If x is a delta around f(z) and you give it a Gaussian likelihood with a small sigma, that's an ABC likelihood for x, which introduces a new random variable y in Pyro. Pyro never computes the marginal log density of x (which is not possible in general), just the sum of the prior log p(z; mu), the trivial delta conditional of x given z, and the Gaussian conditional of y given x, log p_y(x_observed | x). This is a Monte Carlo ABC approximation to log p_x(x_observed; mu), which is the quantity you're actually after.

Currently, random variables at sample statements are expected to have tractable conditional densities, so that the log-joint is just a sum of all the tractable log-conditionals plus implicit deltas on all intermediate deterministic computation.

Re: guide choice, your intuition about choosing an approximate inverse of f is correct. The second case in your first post actually does allow learning mu through the prior term in the ELBO log p(z; mu) provided mu is wrapped in pyro.param in the model. However, since you're using an approximate likelihood, it may be difficult to get inference to work reliably. Your best bet is to follow @fritzo's suggestion and use or implement another distribution representing the true likelihood.

@innuo
Author

innuo commented Feb 14, 2018

@eb8680 I don't think I follow yet. I must be missing some crucial insight.

Here is runnable Pyro code that demonstrates what I am doing. I would appreciate it if you could point out why neither of these two guides (which follow the patterns in my original post) works for inferring mu in my model.

(I tried to find the simplest example to illustrate the issue. I understand that I could eliminate the need for the z_i variables in my code by reparameterizing x_i.)

import numpy as np
import torch

import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

def model(data, sigma):
    mu = pyro.sample("mu", dist.Uniform(0., 20.))

    for i in range(len(data)):
        z_i = pyro.sample("z_{}".format(i), dist.Normal(mu, sigma))

        # x = 2 * z, which is modeled with a normal with small variance
        pyro.sample("x_{}".format(i), dist.Normal(2 * z_i, 0.01), obs=data[i])


def guide1(data, sigma):
    mu_mean = pyro.param("mu_mean", torch.tensor(15.))
    log_sigma = pyro.param("log_sigma", torch.tensor(0.1).log())

    mu = pyro.sample("mu", dist.Normal(mu_mean, torch.exp(log_sigma)))

    for i in range(len(data)):
        # z_i is sampled from its prior given the guide's mu
        pyro.sample("z_{}".format(i), dist.Normal(mu, sigma))


def guide2(data, sigma):
    mu_mean = pyro.param("mu_mean", torch.tensor(15.))
    log_sigma = pyro.param("log_sigma", torch.tensor(0.1).log())

    mu = pyro.sample("mu", dist.Normal(mu_mean, torch.exp(log_sigma)))

    for i in range(len(data)):
        # z_i is a small deviation around the inverse of x = 2 * z
        pyro.sample("z_{}".format(i), dist.Normal(0.5 * data[i], 0.01))

if __name__ == '__main__':
    # Generate some data
    mu = 10  # unknown
    sigma = 2  # known
    N = 20  # size of the data
    # Observations are x = 2 * z with z ~ Normal(mu, sigma), matching the model
    data = torch.tensor(2 * np.random.normal(mu, sigma, N), dtype=torch.float32)

    optimizer = Adam({"lr": 0.01, "betas": (0.90, 0.999)})

    svi = SVI(model, guide2, optimizer, loss=Trace_ELBO(num_particles=10))

    losses = []
    for step in range(1000):
        losses.append(svi.step(data, sigma))

    for name, value in pyro.get_param_store().items():
        print(name, value)

@eb8680
Member

eb8680 commented Feb 15, 2018

I understand that I could eliminate the need for the z_i variables in my code by reparameterizing x_i

But this is probably the primary reason it's not working - you're doing biased approximate inference in a noisy approximation to your true model. If you want higher quality inference, you'll have to reduce bias and variance. You can remove bias in inference by using HMC, and you can remove the variance in your model approximation entirely by writing the true likelihood as a Pyro Distribution and eliminating those z_i variables.

Otherwise, you can try using more particles in SVI by passing num_particles=N to SVI or tuning hyperparameters like parameter initialization, ABC kernel bandwidth, and learning rate. SVI can unfortunately be rather finicky, even in systems like Stan which limit the universe of possible models and focus much more on truly automated inference.

@innuo
Author

innuo commented Feb 15, 2018

@fritzo @eb8680 Thanks for your replies. I think I have a much better understanding of the limitations of variational inference and how to circumvent them in Pyro.

I think it might make sense to add a section on modeling "Observations that are deterministic functions of latents" or some such to the Pyro tutorials. This is an important class of models and other users may have the same questions I did.

It might also be useful to provide some guidance in the tutorials on non-invertible transformations in Pyro models.

@eb8680
Member

eb8680 commented Feb 15, 2018

It might also be useful to provide some guidance in the tutorials on non-invertible transformations in Pyro models.

Agreed, we may do that as we add algorithms aimed at that class of problems (e.g. LFVI). Contributions are welcome in the meantime!

eb8680 closed this as completed Feb 15, 2018
@rachtsingh

I think I roughly understand the above discussion (@eb8680, your comment about ABC made a lot of sense to me).

However, @fritzo's example as given here (copied below) no longer works, since Pyro has changed a lot. Could someone translate that example?

def add_one_or_two(guess):
    init = Variable(torch.Tensor([2]))
    choice = pyro.sample("choice", dist.categorical, ps=guess,vs=[False,True])
    if choice:
        outcome = init + 1
    else:
        outcome = init + 2
    return pyro.sample("outcome", dist.normal, outcome, 0.1 * ng_ones(1))

guess = Variable(torch.Tensor([0.5,0.5]))
conditioned = pyro.condition(add_one_or_two, data={"outcome": Variable(torch.Tensor([4]))})
marginal = pyro.infer.Marginal(pyro.infer.Importance(conditioned, num_samples=100), sites=["choice"])

marginal(guess)

Things I struggled with when trying to convert it:

  1. Categorical no longer has vs, so the line if choice: keeps failing, since choice seems to vary between multidimensional and 0-dimensional at runtime.
  2. Trying to sidestep that issue by doing something a little incorrect (conditioning on choice.sum()) raises an error in the model, which I think might be a bug.

If it helps, I tried to use HMC here, since that's what @eb8680 recommended a few posts up. Here's the stack trace from 2:

...pyro/infer/mcmc/util.py in _populate_cache(self, model_trace)
    170             raise ValueError("Finite value required for `max_plate_nesting` when model "
    171                              "has discrete (enumerable) sites.")
--> 172         model_trace.compute_log_prob()
    173         model_trace.pack_tensors()
    174         for name, site in model_trace.nodes.items():

...pyro/poutine/trace_struct.py in compute_log_prob(self, site_filter)
    161                 if "log_prob" not in site:
    162                     try:
--> 163                         log_p = site["fn"].log_prob(site["value"], *site["args"], **site["kwargs"])
    164                     except ValueError:
    165                         _, exc_value, traceback = sys.exc_info()

TypeError: log_prob() takes 2 positional arguments but 3 were given

Thanks!

@fritzo
Member

fritzo commented Jan 9, 2019

I haven't tried running this, but here is an attempted translation:

import torch
import pyro
import pyro.distributions as dist

def add_one_or_two(guess):
    init = torch.tensor(2.)
    choice = pyro.sample("choice", dist.Categorical(probs=guess))
    if choice:
        outcome = init + 1
    else:
        outcome = init + 2
    return pyro.sample("outcome", dist.Normal(outcome, 0.1))

guess = torch.tensor([0.5, 0.5])

I'm not sure about the rest; @eb8680, does this look right?

conditioned = pyro.condition(add_one_or_two, data={"outcome": torch.tensor(4.)})
importance = pyro.infer.Importance(conditioned, num_samples=100).run(guess)
marginal = pyro.infer.EmpiricalMarginal(importance, sites="choice")
marginal.sample()

@rachtsingh

bump @eb8680 :)

I would really appreciate an example here; every version of this that I've tried has failed, and while I know it's not exactly an intended use case, it would be really helpful.

@eb8680
Member

eb8680 commented Jan 14, 2019

@rachtsingh I'm not sure what you're looking for; can you be more specific? What do you mean by failed? In general, to apply the ABC trick, just add an auxiliary variable in your model and proceed in the usual way as if that auxiliary variable were the true likelihood:

def model(data):
    ...
    deterministic_value = deterministic_fn(stuff)
    abc_bandwidth = torch.tensor(1e-2)
    abc_dist = dist.Normal(loc=deterministic_value, scale=abc_bandwidth)
    pyro.sample("constrained_value", abc_dist, obs=data)
    ...

The model in the example you linked to would become something like that with

deterministic_value = (init + 1) if pyro.sample("choice", dist.Bernoulli(guess)).item() else (init + 2)

Note that if data is very different from deterministic_value, inference will necessarily be difficult: the data is very unlikely under the approximate model, so inference will take a long time to converge and will generally be noisy and unstable. If deterministic_fn is invertible, you should use a TransformedDistribution to compute the likelihood term p(data | deterministic_value, ...) exactly.

@rachtsingh

Ok, I think I pretty much understand now. Here's a complete working example:

import torch
import pyro
import pyro.distributions as dist

def model(data):
    init = torch.tensor(2.)
    choice = pyro.sample("choice", dist.Categorical(probs=torch.tensor([0.5, 0.5])))
    if choice:
        outcome = init + 1
    else:
        outcome = init + 2
    abc_dist = dist.Normal(loc=outcome, scale=torch.tensor(0.1))
    return pyro.sample("outcome", abc_dist, obs=data)

data = torch.tensor(4.)
importance = pyro.infer.Importance(model, num_samples=100).run(data)
marginal = pyro.infer.EmpiricalMarginal(importance, "choice")

If I understand correctly, the guide (here just the prior) provides a distribution that we sample from; we pass each sample through the deterministic function and then evaluate the ABC likelihood on the data. Here that is easy, since it's ~0 for choice = 1 and positive for choice = 0. Just to clarify, importance then corresponds to the distribution that samples from {0, 1} with importance sampling weights (I think I confirmed this with some experiments). I don't quite understand how this works with HMC, but I can probably figure it out.
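A plain-PyTorch sketch of the computation described above (an illustration, not Pyro's Importance implementation): sample choice from the prior, push it through the deterministic function, weight each sample by the ABC likelihood of the data, and self-normalize. Because choice is discrete, the exact posterior of the ABC model can also be enumerated for comparison. The helper outcome_of is hypothetical.

```python
import torch
from torch.distributions import Categorical, Normal

torch.manual_seed(0)
prior = Categorical(probs=torch.tensor([0.5, 0.5]))
data, bandwidth, init = torch.tensor(4.), 0.1, torch.tensor(2.)


def outcome_of(choice):
    # Deterministic function from the model: choice 1 -> init + 1, choice 0 -> init + 2
    return torch.where(choice.bool(), init + 1, init + 2)


# Self-normalized importance sampling with the prior as the proposal
choices = prior.sample((100,))
log_w = Normal(outcome_of(choices), bandwidth).log_prob(data)
weights = torch.softmax(log_w, dim=0)
p0_is = weights[choices == 0].sum()

# Exact posterior of the ABC model by enumerating both values of choice
values = torch.arange(2)
log_post = prior.log_prob(values) + Normal(outcome_of(values), bandwidth).log_prob(data)
p0_exact = torch.softmax(log_post, dim=0)[0]
print(p0_is, p0_exact)  # both essentially 1
```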

Sorry about the lack of a specific example until now; I had one in mind, but I wanted to understand the general principle first. Essentially, there is a parameter p I'd like to fit, such that x ~ Dist(param=p) for some distribution Dist, and then we observe f(x), where f is a deterministic function. We should be able to get a posterior distribution over p using ABC here, right? I just need to add an abc_dist = Normal(f(x), bandwidth) and observe that.

The point you make above is essentially saying that if the right p has low probability under the prior, then we will be largely blind to what's possible. If performance isn't an issue, then in 1D we can counteract this by taking many samples from the prior.

Thanks for the help! Please let me know if I've got something wrong.

@eb8680
Member

eb8680 commented Jan 14, 2019

Yeah, you've got it.

I don't quite understand how this works with HMC

The latent variable in this model is discrete, so HMC wouldn't apply here, but if it were continuous you could use HMC instead of importance sampling for better results in high dimensions.

... in 1D we can counteract this by taking many samples from the prior

Yes, and note that this corresponds to the way ABC is typically presented, as a kernel density estimate of the marginal distribution of f(.) for a fixed value of p.
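A sketch of that kernel-density reading, assuming for illustration z ~ Normal(p, a) and f(z) = 2 * z so that the exact answer is known: averaging the ABC kernel over prior samples of f(z) is a Gaussian KDE of the distribution of f(z), evaluated at the data. The function name abc_log_likelihood is hypothetical.

```python
import math

import torch
from torch.distributions import Normal


def abc_log_likelihood(p, data, bandwidth, a=1.0, num_samples=200_000):
    """Monte Carlo ABC estimate of log p(data; p): a Gaussian KDE of f(z) evaluated at data."""
    z = Normal(p, a).sample((num_samples,))
    fz = 2 * z  # illustrative deterministic function
    log_kernel = Normal(fz, bandwidth).log_prob(data)
    # Log of the average kernel value, computed stably in log space
    return torch.logsumexp(log_kernel, dim=0) - math.log(num_samples)


torch.manual_seed(0)
p, data, h = 1.0, torch.tensor(2.5), 0.5
estimate = abc_log_likelihood(p, data, h)
# For this linear f, the ABC model is exactly Normal(2 p, sqrt(4 a^2 + h^2)) with a = 1
exact = Normal(2 * p, math.sqrt(4 * 1.0 + h ** 2)).log_prob(data)
print(estimate, exact)
```

Shrinking the bandwidth moves this estimate toward the likelihood of the true model but needs more samples for the same accuracy, which is the trade-off discussed above.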

We should be able to get a posterior distribution over p using ABC here, right?

Yes, that's right, but also remember that this posterior is coming from an approximate model, and should be interpreted with care. In practice abc_bandwidth is sometimes annealed toward zero over the course of training to counteract this source of bias.
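A sketch of such an ABC posterior over p, under illustrative assumptions: p ~ Uniform(0, 5), x ~ Normal(p, 1), and the observed quantity is the non-invertible f(x) = |x|. Sampling (p, x) from the prior and weighting by the ABC kernel gives a weighted sample from the approximate posterior. For this f the exact likelihood N(y; p, 1) + N(-y; p, 1) is known, so the result can be checked on a grid. The loop over decreasing bandwidths only illustrates the bias trade-off; it is not a tuned annealing schedule.

```python
import torch
from torch.distributions import Normal, Uniform

torch.manual_seed(0)
y_obs, num_samples = torch.tensor(2.0), 400_000


def abc_posterior_mean(bandwidth):
    p = Uniform(0.0, 5.0).sample((num_samples,))
    x = Normal(p, 1.0).sample()
    # ABC kernel on the observed summary f(x) = |x|
    log_w = Normal(x.abs(), bandwidth).log_prob(y_obs)
    w = torch.softmax(log_w, dim=0)
    return (w * p).sum()


# Exact posterior mean on a grid, using the likelihood of |x| = y
grid = torch.linspace(0.0, 5.0, 5001)
lik = Normal(grid, 1.0).log_prob(y_obs).exp() + Normal(grid, 1.0).log_prob(-y_obs).exp()
exact_mean = (grid * lik).sum() / lik.sum()

for h in (1.0, 0.1, 0.02):
    print(h, abc_posterior_mean(h).item(), exact_mean.item())
```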

A related caveat, which may not apply to your case: if there are other downstream computations in your model that consume deterministic_value, you should pass the noisy value (or the data, when it's observed) to them instead:

deterministic_value = deterministic_fn(stuff)
...  # as above
noisy_value = pyro.sample("noisy_value", abc_dist, obs=data)  # == data when data is not None
# correct
more_computation(value=noisy_value)
# incorrect
more_computation(value=deterministic_value)


4 participants