Handle dynamic support #241

Closed
fehiepsi opened this issue Jul 11, 2019 · 9 comments · Fixed by #268
Labels
bug Something isn't working enhancement New feature or request

Comments

@fehiepsi
Member

Just discovered an edge case we had not thought about before. Consider the model

```python
import numpyro.distributions as dist
from numpyro import sample

def model():
    x = sample('x', dist.Normal(0, 1))
    y = sample('y', dist.Uniform(x, x + 1))
```

The support of the y site depends on x, so initialize_model will return different transforms for different initial RNG seeds.
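To see the dependence concretely, here is a minimal plain-Python sketch. The helper `interval_bijector` is hypothetical; it just mimics how `biject_to(constraints.interval(lo, hi))` composes a sigmoid with an affine map:

```python
import math

def interval_bijector(lo, hi):
    """Hypothetical stand-in for biject_to(constraints.interval(lo, hi)):
    a sigmoid into (0, 1) followed by an affine map into (lo, hi)."""
    def forward(z):
        return lo + (hi - lo) / (1.0 + math.exp(-z))
    return forward

# The bijection for y ~ Uniform(x, x + 1) depends on the sampled x,
# so different initial draws of x yield different transforms.
t1 = interval_bijector(0.0, 1.0)  # x = 0.0
t2 = interval_bijector(2.0, 3.0)  # x = 2.0
print(t1(0.0), t2(0.0))  # 0.5 2.5: same unconstrained value, different supports
```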

@neerajprad
Member

neerajprad commented Jul 11, 2019

Hmm, interesting. This will also be an issue with Pyro, and it does seem serious because we will silently give wrong results. Maybe we should just have transform_fn take in both constraints and params.

EDIT: This will cause issues with JIT. I think if we just throw an error in initialize_model for such cases, that will be a good start.

@fehiepsi
Member Author

Yup, this affects both autoguide and HMC inference, because both do inference in unconstrained space. As you said, this is indeed a serious problem. I think we can resolve it with the non-centered reparameterization technique, that is, define each such distribution as a transformed distribution with an affine transform, and do inference in the unconstrained domain of the base_dist's support. It also seems to have the advantage that inference brings all variables to unit scale (hence it is easier to adapt the mass matrix). The disadvantage is that we have to take two steps to get the samples (but that is cheap): first transform the unconstrained value back to the base dist's support, then trace the model again with these in-base-support values. I have not put more thought into this, but the above solution seems doable. We might leverage handlers to avoid modifying our distribution implementations. And we only need to do it for interval/half-interval domains with constant dist args.
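As a sanity check on the affine reparameterization above, here is a small plain-Python sketch (with assumed values) verifying that the log density of Uniform(a, a + scale) matches the Uniform(0, 1) base log density plus the change-of-variables correction:

```python
import math

a, scale = 2.0, 3.0            # y ~ Uniform(a, a + scale)
y = 3.7                        # a point inside the support
u = (y - a) / scale            # pull back through AffineTransform(a, scale)

log_prob_direct = -math.log(scale)                  # Uniform(a, a + scale) log density
log_prob_base = 0.0                                 # Uniform(0, 1) log density at u
log_prob_reparam = log_prob_base - math.log(scale)  # subtract log|det| of the affine map
assert abs(log_prob_direct - log_prob_reparam) < 1e-12
```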

@neerajprad
Member

I think we can resolve it by using the non-centered reparameterization technique, that is, defining each distribution as a transformed distribution with an affine transform.

You mean the user has to use TransformedDistribution in their model, or we do it behind the scenes? It seems to me that if it is the latter, it will have the same issue of having to detect the correct loc and scale for distributions with dynamic support. Maybe an example will help clarify.

The alternative seems even trickier, because it would involve finding a new transform at every single step of the integrator, and also storing the transform for each sample so that the unconstrained samples can later be transformed into constrained ones.

@fehiepsi
Member Author

fehiepsi commented Jul 11, 2019

I think we can just do it behind the scenes. What I had in mind is to compute log_density as follows:

```python
def potential_energy(params, base_dist_inv_transforms):
    # step 1: transform each unconstrained value to the base dist's support
    base_dist_samples = {}
    for name, value in params.items():
        base_dist_samples[name] = base_dist_inv_transforms[name](value)

    # step 2: substitute the base_dist samples and get the trace
    # (we can modify the behaviour of handlers to make this job easy)
    model_trace = get_trace(...)

    log_joint = 0.
    for name, site in model_trace.items():
        if isinstance(site['fn'], TransformedDistribution):
            log_joint += site['fn'].base_dist.log_prob(base_dist_samples[name])
        else:
            log_joint += site['fn'].log_prob(base_dist_samples[name])

    # adjust log_det_jacobian in the same way as the current behaviour, using
    # base_dist_inv_transforms
```

This way, we only need to revise Uniform/TruncatedCauchy/TruncatedNormal so that they are TransformedDistributions. Other transformed distributions, including user-defined ones, should work (an edge case is when the user defines a transformed distribution whose base dist is itself a transformed distribution; edit: we just need to merge the transforms in the init method). In some sense, this makes inference involving TransformedDistribution faster. For example, consider a LogNormal potential_fn:

```python
def f(unconstrained_x):
    d = dist.LogNormal(0, 1)
    t = ExpTransform()  # the same as biject_to(constraints.positive)
    x = t(unconstrained_x)
    return d.log_prob(x) + t.log_abs_det_jacobian(unconstrained_x, x)
```

against the above proposal

```python
def f(unconstrained_x):
    d = dist.LogNormal(0, 1)
    return d.base_dist.log_prob(unconstrained_x)  # base_dist is Normal(0, 1)
```
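The equivalence of the two versions can be checked numerically with a hand-rolled plain-Python sketch (the densities are written out from the standard Normal and LogNormal formulas):

```python
import math

def normal_logpdf(z):
    # standard Normal(0, 1) log density
    return -0.5 * math.log(2 * math.pi) - 0.5 * z * z

def lognormal_logpdf(x):
    # LogNormal(0, 1) log density
    return -math.log(x) + normal_logpdf(math.log(x))

z = 0.7                        # unconstrained value
x = math.exp(z)                # ExpTransform
f1 = lognormal_logpdf(x) + z   # constrained log_prob + log|d exp(z)/dz|, which is z
f2 = normal_logpdf(z)          # base-distribution log_prob directly
assert abs(f1 - f2) < 1e-12
```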

Does this proposal sound reasonable to you?

@neerajprad
Member

This seems very reasonable!

I'm not quite sure why we need site['fn'].base_dist.log_prob(base_dist_samples[name]). e.g. once a Uniform(x, x + 2.) is changed to TransformedDistribution(Uniform(0, 1), AffineTransform(x, 2.)), can't we simply use the transformed distribution's log_prob on the site values obtained from model_trace? That should also account for the Jacobian adjustment from having to stretch/contract from the unit interval to the site's support. There is of course another Jacobian adjustment from base_dist_inv_transforms that will apply to all sites.

We can probably just test this out on a model that is user transformed to see that the logic in potential_fn works as expected, and then do this automatically using handlers.

@neerajprad
Member

Also, this probably implies that we will need to do something like predictive and run the model forward to get constrained samples?

@fehiepsi
Copy link
Member Author

fehiepsi commented Jul 11, 2019

can't we simply use the transformed distribution's log_prob on the site values obtained from model_trace?

I think this way is somewhat redundant: we compute the log-Jacobian of AffineTransform(x, 2.) and then cancel it out. For example, the potential of Uniform(x, x + 2.) would be

```python
def f(unconstrained_u):
    ta = AffineTransform(x, 2.)
    d = TransformedDistribution(Uniform(0, 1), ta)
    ts = SigmoidTransform()
    u = ts(unconstrained_u)
    # d.log_prob(ta(u)) expands to the base log_prob minus the affine Jacobian...
    log_prob_u = d.base_dist.log_prob(u) - ta.log_abs_det_jacobian(u, ta(u))
    # ...and then the same Jacobian term is added straight back
    return log_prob_u + ta.log_abs_det_jacobian(u, ta(u)) + ts.log_abs_det_jacobian(unconstrained_u, u)
```

comparing to the above proposal

```python
def f(unconstrained_u):
    ta = AffineTransform(x, 2.)
    d = TransformedDistribution(Uniform(0, 1), ta)
    ts = SigmoidTransform()
    u = ts(unconstrained_u)
    return d.base_dist.log_prob(u) + ts.log_abs_det_jacobian(unconstrained_u, u)
```

Basically, we don't care what the dist's support is; we just care about base_dist and base_dist_inv_transforms to compute the log_prob of each site's unconstrained value. But the transforms are useful for tracing the model forward.
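The cancellation can be verified numerically with a small plain-Python sketch (Uniform(0, 1) has log density 0 on its support, and the sigmoid Jacobian is u * (1 - u)):

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

scale = 2.0                    # AffineTransform(x, 2.) has log|det| = log(2)
z = -0.3                       # unconstrained value
u = sigmoid(z)                 # in (0, 1), where Uniform(0, 1) log_prob is 0

log_det_sigmoid = math.log(u * (1.0 - u))
log_det_affine = math.log(scale)

# first version: subtract the affine Jacobian inside log_prob, then add it back
f1 = (0.0 - log_det_affine) + log_det_affine + log_det_sigmoid
# second version: skip the affine Jacobian entirely
f2 = 0.0 + log_det_sigmoid
assert abs(f1 - f2) < 1e-12
```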

Also, this probably implies that we will need to do something like predictive and run the model forward to get constrained samples?

Yes, that is a disadvantage of the above proposal (but I guess it is cheap compared to running an MCMC trajectory).
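The two-step recovery of constrained samples could look like this sketch (plain Python; x, z, and the sigmoid/affine pair are assumed values for a site y ~ Uniform(x, x + 2)):

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

x = 1.5                        # value of the upstream site
z = 0.4                        # unconstrained HMC sample for y
# step 1: transform to the base distribution's support (0, 1)
u = sigmoid(z)
# step 2: run the model forward, applying the site's own AffineTransform(x, 2.)
y = x + 2.0 * u
assert x < y < x + 2.0
```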

@fehiepsi
Member Author

We can probably just test this out on a model that is user transformed to see that the logic in potential_fn works as expected, and then do this automatically using handlers.

Yup, I'll work on this next week but will definitely need your help on the handlers stuff. There are several pending (but interesting) PRs I need to finish in the coming days.

@neerajprad
Member

Thanks for explaining, this is pretty neat! 😄

Yes, that is a disadvantage of the above proposal (but I guess it is cheap compared to running an MCMC trajectory).

I suppose this will only matter if we have TransformedDistributions in the model; otherwise, the computation should remain the same.

Yup, I'll work on this next week but will definitely need your help on the handlers stuff. There are several pending (but interesting) PRs I need to finish in the coming days.

Please take your time. I'll work on the handlers part that you can build on top of.

@fehiepsi fehiepsi added bug Something isn't working and removed low priority labels Jul 13, 2019