Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Circular Reparameterization #1080

Merged
merged 32 commits into from Jul 23, 2021
Merged

Conversation

alexlyttle
Copy link
Contributor

@alexlyttle alexlyttle commented Jul 1, 2021

Add circular reparameterization to improve the sampling of the VonMises distribution re. issue #1070.

To Do:

  • Add CircularReparam class
  • Add circular = _Circular() constraint
    • Must be in the range -pi < x < pi
  • Change support for VonMises to circular constraint
  • Modify existing tests for VonMises distribution
  • Add tests for CircularReparam
  • Update docs

Issues so far:

  • VonMises does not work without using CircularReparam and so test_distribution_constraints fails with NotImplementedError: <numpyro.distributions.constraints._Circular object at 0x13129f520> not implemented..
    • Do we want to change these tests and change the docs so that it tells the user to use VonMises with CircularReparam?
    • I understand there are plans to make reparameterization automatic in the future, so this shouldn't be an issue?

@alexlyttle alexlyttle marked this pull request as draft July 1, 2021 09:48
@fritzo
Copy link
Member

fritzo commented Jul 1, 2021

Do we want to change these tests and change the docs so that it tells the user to use VonMises with CircularReparam?

Yes I think so. I believe you can follow the pattern for the ProjectedNormal distribution and constraints.sphere, including a new case in helpful_support_errors().

I understand there are plans to make reparameterization automatic in the future, so this shouldn't be an issue?

Yes, our vague vision is for reparametrizers to handle cases where bijections don't suffice. We started out with bijections registered with biject_to(), but those don't handle many-to-one transforms. It's still unclear to me how much functionality should be in each of biject_to and handlers.reparam, but I would like to make handlers.reparam as nearly-automatic as our constraints registered with biject_to.

@fritzo fritzo added the enhancement New feature or request label Jul 1, 2021
@alexlyttle
Copy link
Contributor Author

Yes I think so. I believe you can follow the pattern for the ProjectedNormal distribution and constraints.sphere, including a new case in helpful_support_errors().

Great, thanks! I have added a VonMises docstring and to helpful_support_errors() (following ProjectedNormal) and added to test_distributions to account for the new circular constraint.

Yes, our vague vision is for reparametrizers to handle cases where bijections don't suffice. We started out with bijections registered with biject_to(), but those don't handle many-to-one transforms. It's still unclear to me how much functionality should be in each of biject_to and handlers.reparam, but I would like to make handlers.reparam as nearly-automatic as our constraints registered with biject_to.

This sounds good!

I still need to come up with a test for CircularReparam but this shouldn't take too long.

@fritzo
Copy link
Member

fritzo commented Jul 2, 2021

@alexlyttle looks great, do you mind if I port this to Pyro quick for our upcoming 1.7 release?

@alexlyttle
Copy link
Contributor Author

@alexlyttle looks great, do you mind if I port this to Pyro quick for our upcoming 1.7 release?

Thanks, go for it!

@fritzo
Copy link
Member

fritzo commented Jul 2, 2021

Hmm on closer inspection it looks like PyTorch simply uses constraints.real for the VonMises constraint, which ends up working fine during inference (both HMC and SVI). I think that's actually a decent solution, and since changing pyro.distributions.VonMises.support = circular would break backwards compatibility, I think it may be better to just leave the Pyro/PyTorch version as-is.

@fehiepsi do you think it would break anyone's code to simply relax

numpyro.distributions.VonMises.support = constraints.real

as an improper distribution, thereby addressing @alexlyttle's original issue in one line?

@fehiepsi
Copy link
Member

fehiepsi commented Jul 2, 2021

I'm not sure what's the best action here. Matching PyTorch behavior is nice but I think we need to discuss this a bit more. After reading through the previous discussions, here is what I understand (assuming that posterior is single-modal on the circle):

using VonMises with real support

  • the distribution is improper, the sample method does not match the support and the density
  • posterior is multi-modal. MCMC chain diagnostics will be off because the chains will likely stay at different 2pi periods
  • seems to fit well for 1-chain MCMC
  • match PyTorch behavior

using VonMises with interval(-pi, pi) support

  • IIUC, posterior is 2-modal if it concentrates at the boundary (this is the original issue.) A "non-automatic" solution for this is to shift the center by pi (or simply by an estimated value of loc) and then shift the samples back. I guess it is a fair trade-off.

using VonMises with circular support

  • the distribution is proper and the sample method matches the support and the density
  • posterior is multi-modal. This seems to face similar issues as VonMises with real support.
  • reparam is needed

We might add a docstring mentioning different interpretations and expectations w.r.t. different supports. Then users can just simply make a wrapper of VonMises and change the support. Personally, I prefer using the interval(-pi, pi) support.

break anyone's code to simply relax

@OlaRonning Do you expect this change will affect your work?

@OlaRonning
Copy link
Member

@fehiepsi, regarding #1063 the support would become constraints.independent(constraints.circular, 1), and a reparam is needed regardless. For #1055 a circular support would make it easy to check the antecedent distribution has the correct support.

My personal opinion is a circular support is the best solution; neither real nor interval(-pi, pi) encodes the periodic boundary.

@alexlyttle
Copy link
Contributor Author

Apologies for slow progress. From the discussion above it sounds like this idea is still good to go? I'm working on the tests and changes today.

@OlaRonning

My personal opinion is a circular support is the best solution; neither real nor interval(-pi, pi) encodes the periodic boundary.

If people are happy with this that's great! Technically, the circular support I have put in doesn't really encode the periodic boundary, it's using the CircularReparam which does that. The circular support is more representative of the distribution being circular (similar to sphere constraint for ProjectedNormal) and bound by -pi and +pi. In #1070 I discussed this with @fritzo and we decided to define the circular constraint as similar to interval(-pi, +pi). I hope that makes sense!

@fehiepsi
Copy link
Member

fehiepsi commented Jul 8, 2021

I just have a random thought that we can have circular support that behaves as constraints.interval(...). This allows users to control the behavior as they want without having to redefine the class with some default behavior:

  • using circular without reparam: this uses biject_to(interval) for transformation and behaves as interval(-pi, pi) as previously. I think we can achieve this by simply making _Circular a subclass of _Interval and changing the raise ValueError to a warning, so users get informed about the behavior under the hood.
  • using circular with reparam: as the current state of this PR

@alexlyttle
Copy link
Contributor Author

@fehiepsi so have circular inherit constraints.interval? I think I had it like that first, but @fritzo suggested that the sigmoid transform in biject_to(interval) might interfere with CircularReparam or do an unnecessary transformation. What do you think?

@alexlyttle
Copy link
Contributor Author

I've been looking into this more. Attempting to recreate the ProjectedNormalReparam test but for CircularReparam, it doesn't produce the same distribution. This is because in CircularReparam we sample N(0, 1) noise and then make the transformation. So when reparam is used in the following way we just get the N(0, 1) distribution.

import numpy as np
import numpyro
import numpyro.distributions as dist
import arviz as az

from numpyro.infer.reparam import CircularReparam

def model(loc, concentration, n):
    with numpyro.plate_stack("plates", shape):
        with numpyro.plate("particles", n):
            numpyro.sample("x", dist.VonMises(loc, concentration))

shape = ()
loc = np.zeros(shape) + 2.0
conc = np.ones(shape)

with numpyro.handlers.trace() as trace_exp:
    with numpyro.handlers.seed(rng_seed=0):
        model(loc, conc, 10000)

with numpyro.handlers.reparam(config={"x": CircularReparam()}):
    with numpyro.handlers.trace() as trace_act:
        with numpyro.handlers.seed(rng_seed=0):
            model(loc, conc, 10000)

az.plot_trace(np.array(trace_exp['x']['value']))
az.plot_trace(np.array(trace_act['x']['value']))

image

image

@fritzo why did you suggest sampling parameter-free noise in CircularReparam here over something more like the following code? Using the code below with VonMises.support = constraints.real, when doing the above test you get the same distributions, and it works with MCMC sampler. This way reparam would only be required for MCMC. Or, is there something I am not understanding here?

class CircularReparam(Reparam):
    def __call__(self, name, fn, obs):

        # Draw from VonMises with constraints.real support
        value = numpyro.sample(
            f"{name}_unwrapped",
            fn,
            obs=obs,
        )

        # Differentiably transform.
        value = jnp.remainder(value + math.pi, 2 * math.pi) - math.pi

        numpyro.factor(f"{name}_factor", fn.log_prob(value))
        return None, value

@fritzo
Copy link
Member

fritzo commented Jul 8, 2021

have circular inherit constraints.interval?

I like @fehiepsi's suggestion, and this inheritance now seems like the right design choice. I had been worried about allowing a default registration for biject_to(circular), but it seems ok to allow that registration but emit a warning in helpful_support_errors().

@alexlyttle the motivation for my suggestion of a masked normal

numpyro.sample(f"{name}_unwrapped", dist.Normal(0,1).mask(False), obs=obs)

is that that sample statement gets the constraint correct and works during inference. If we tried to use the original fn as in

numpyro.sample(f"{name}_unwrapped", fn, obs=obs)

then the original fn.support would be used and we'd get nowhere. Note that when reparametrizing via Normal(0,1).mask(False), the density and posterior distributions will be correct, but sampling from the prior will be incorrect (as in your snippet). I'm not sure whether sampling from the prior matters (except during initialization), and if so I'm not sure how to fix it (@fehiepsi perhaps by providing an optional sampler arg to ImproperUniform?).

@fehiepsi
Copy link
Member

fehiepsi commented Jul 9, 2021

I don't have a better idea. Having sampler argument seems reasonable to me. For now, I think it is fine to just use ImproperUniform(constraints.real, ...) and by default, it will raise an error if a user wants to get samples from it. To get samples from priors, users can run the non-reparam model.

Alternatively, we can just create a new instance of fn with support is real. But I think it is unnecessary.

@alexlyttle
Copy link
Contributor Author

I have updated the docs and will push soon. I am getting less confident in the remainder solution for the reparam. When testing CircularReparam in my work I found that arctan2(sin(x), cos(x)) performed better during inference than the remainder method. By running better, I mean for the same model, data, random key, samples, warmup, chains etc. remainder resulted in more divergences and higher r_hat statistics. I looked into it and found that

jnp.remainder(x + math.pi, 2 * math.pi) - math.pi is slightly different to jnp.arctan2(jnp.sin(x), jnp.cos(x)). See the plots below for the difference.

This difference is tiny, but could be affecting inference. I'm not sure which method is right, my instinct says arctan because it performs better. Comparing the speed of the two in a jit compiled function they are as fast as each other, so I don't see much speed-advantage to choosing remainder over arctan.

(32-bit precision)
image

(64-bit precision)
image

@fehiepsi
Copy link
Member

Hi @alexlyttle, for now, I don't have much intuition for the difference between (a + pi) % 2pi - pi and arctan2(sin, cos). I'll take a closer look sometime this weekend and will let you know.

@fehiepsi
Copy link
Member

fehiepsi commented Jul 18, 2021

@alexlyttle It seems to me that we lost precision with arctan

import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt

def f1(x):
    return (x + jnp.pi) % (2 * jnp.pi) - jnp.pi

def f2(x):
    return jnp.arctan2(jnp.sin(x), jnp.cos(x))

def f3(x):
    return (x.copy() + np.pi) % (2 * np.pi) - np.pi

x = jnp.linspace(-10., 10., 1000)
y1 = f1(x)
y2 = f2(x)
y3 = f3(x)

print("numpy vs remainder:", abs(y1 - y3).max())
print("numpy vs arctan:", abs(y2 - y3).max())
plt.plot(x, y2 - y3, "o");
numpy vs remainder: 0.0
numpy vs arctan: 7.1525574e-07

more divergences and higher r_hat statistics

I am a bit curious. Does this happen persistently across random seeds? How much different the results are? Maybe MCMC is happier with imprecise computation in this case. >__<

Btw, you can find and resolve lint issues by running make format or make lint.

@alexlyttle
Copy link
Contributor Author

@fehiepsi Thanks for investigating, I found the same thing and couldn't understand why arctan seemed to produced better inference. I will check in case it was to do with the random seed.

Btw, you can find and resolve lint issues by running make format or make lint.

Thanks! I forgot to do this last time round!

@alexlyttle
Copy link
Contributor Author

@fehiepsi It seems it is the choice of random key that was the issue, I should remember to vary it more often! The remainder method is perfectly fine.

Just to summary changes I am about to push:

When CircularReparam was not used, MCMC didn't just fall back to using Interval(-pi, +pi) support as I thought. We want this to happen because no reparam is still useful in some cases. To get it to do this, I changed the circular constraint to circular = _Interval(-math.pi, math.pi) rather than its own class. VonMises now samples without reparam -- e.g. this can be useful if your VonMises prior is ~ uniform and you want to do a prior=predictive check.

We wanted a warning to encourage the user to consider CircularReparam, so I added this to helpful_support_errors. Since VonMises.support no longer throws a NotImplementedError, I added the warning before the try/except. I'm not sure this is the perfect solution, because this warning shows up several times if running MCMC with many chains. We probably only want the warning to occur on model initialisation, so I could add a separate helpful_support_warnings context manager which is only called once?

Otherwise, this work should be nearly done!

@fehiepsi fehiepsi marked this pull request as ready for review July 21, 2021 03:01
@@ -962,10 +962,21 @@ def single_loglik(samples):

@contextmanager
def helpful_support_errors(site):
# Warnings
name = site["name"]
support = getattr(site["fn"], "support", None)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for completeness, you can add

if isinstance(support, constraints.independent):
    support = support.base_constraint

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense, I've added that now

fehiepsi
fehiepsi previously approved these changes Jul 21, 2021
Copy link
Member

@fehiepsi fehiepsi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks @alexlyttle! Regarding the multiple warnings, one way is to add a boolean argument raise_warnings to helpful_support_errors and raise it here.

@alexlyttle
Copy link
Contributor Author

I've added the raise_warnings flag and the warning now only occurs once as hoped. It looks like we're good to go! Thanks for everyone's help on this one 😊

Copy link
Member

@fehiepsi fehiepsi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Happy to merge this! Thanks so much for addressing this tricky issue.

@fehiepsi fehiepsi merged commit e55f0d4 into pyro-ppl:master Jul 23, 2021
@alexlyttle alexlyttle requested a review from fehiepsi July 23, 2021 08:13
@alexlyttle alexlyttle deleted the circular-reparam branch July 23, 2021 08:23
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants