New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Circular Reparameterization #1080
Conversation
Yes I think so. I believe you can follow the pattern for the
Yes, our vague vision is for reparametrizers to handle cases where bijections don't suffice. We started out with bijections registered with |
Great, thanks! I have added a
This sounds good! I still need to come up with a test for |
@alexlyttle looks great, do you mind if I port this to Pyro quick for our upcoming 1.7 release? |
Thanks, go for it! |
Hmm on closer inspection it looks like PyTorch simply uses @fehiepsi do you think it would break anyone's code to simply relax numpyro.distributions.VonMises.support = constraints.real as an improper distribution, thereby addressing @alexlyttle's original issue in one line? |
I'm not sure what's the best action here. Matching PyTorch behavior is nice but I think we need to discuss this a bit more. After reading through the previous discussions, here is what I understand (assuming that posterior is single-modal on the circle): using VonMises with real support
using VonMises with interval(-pi, pi) support
using VonMises with circular support
We might add a docstring mentioning different interpretations and expectations w.r.t. different supports. Then users can just simply make a wrapper of
@OlaRonning Do you expect this change will affect your work? |
@fehiepsi, regarding #1063 the support would become My personal opinion is a circular support is the best solution; neither |
Apologies for slow progress. From the discussion above it sounds like this idea is still good to go? I'm working on the tests and changes today.
If people are happy with this that's great! Technically, the |
I just have a random thought that we can have
|
I've been looking into this more. Attempting to recreate the import numpy as np
import numpyro
import numpyro.distributions as dist
import arviz as az
from numpyro.infer.reparam import CircularReparam
def model(loc, concentration, n):
with numpyro.plate_stack("plates", shape):
with numpyro.plate("particles", n):
numpyro.sample("x", dist.VonMises(loc, concentration))
shape = ()
loc = np.zeros(shape) + 2.0
conc = np.ones(shape)
with numpyro.handlers.trace() as trace_exp:
with numpyro.handlers.seed(rng_seed=0):
model(loc, conc, 10000)
with numpyro.handlers.reparam(config={"x": CircularReparam()}):
with numpyro.handlers.trace() as trace_act:
with numpyro.handlers.seed(rng_seed=0):
model(loc, conc, 10000)
az.plot_trace(np.array(trace_exp['x']['value']))
az.plot_trace(np.array(trace_act['x']['value'])) @fritzo why did you suggest sampling parameter-free noise in class CircularReparam(Reparam):
def __call__(self, name, fn, obs):
# Draw from VonMises with constraints.real support
value = numpyro.sample(
f"{name}_unwrapped",
fn,
obs=obs,
)
# Differentiably transform.
value = jnp.remainder(value + math.pi, 2 * math.pi) - math.pi
numpyro.factor(f"{name}_factor", fn.log_prob(value))
return None, value |
I like @fehiepsi's suggestion, and this inheritance now seems like the right design choice. I had been worried about allowing a default registration for @alexlyttle the motivation for my suggestion of a masked normal numpyro.sample(f"{name}_unwrapped", dist.Normal(0,1).mask(False), obs=obs) is that that sample statement gets the constraint correct and works during inference. If we tried to use the original numpyro.sample(f"{name}_unwrapped", fn, obs=obs) then the original |
I don't have a better idea. Having Alternatively, we can just create a new instance of |
I have updated the docs and will push soon. I am getting less confident in the
This difference is tiny, but could be affecting inference. I'm not sure which method is right, my instinct says |
Hi @alexlyttle, for now, I don't have much intuition for the difference between |
@alexlyttle It seems to me that we lost precision with import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
def f1(x):
return (x + jnp.pi) % (2 * jnp.pi) - jnp.pi
def f2(x):
return jnp.arctan2(jnp.sin(x), jnp.cos(x))
def f3(x):
return (x.copy() + np.pi) % (2 * np.pi) - np.pi
x = jnp.linspace(-10., 10., 1000)
y1 = f1(x)
y2 = f2(x)
y3 = f3(x)
print("numpy vs remainder:", abs(y1 - y3).max())
print("numpy vs arctan:", abs(y2 - y3).max())
plt.plot(x, y2 - y3, "o");
I am a bit curious. Does this happen persistently across random seeds? How much different the results are? Maybe MCMC is happier with imprecise computation in this case. >__< Btw, you can find and resolve lint issues by running |
@fehiepsi Thanks for investigating, I found the same thing and couldn't understand why arctan seemed to produced better inference. I will check in case it was to do with the random seed.
Thanks! I forgot to do this last time round! |
@fehiepsi It seems it is the choice of random key that was the issue, I should remember to vary it more often! The remainder method is perfectly fine. Just to summary changes I am about to push: When We wanted a warning to encourage the user to consider Otherwise, this work should be nearly done! |
@@ -962,10 +962,21 @@ def single_loglik(samples): | |||
|
|||
@contextmanager | |||
def helpful_support_errors(site): | |||
# Warnings | |||
name = site["name"] | |||
support = getattr(site["fn"], "support", None) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
for completeness, you can add
if isinstance(support, constraints.independent):
support = support.base_constraint
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Makes sense, I've added that now
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks @alexlyttle! Regarding the multiple warnings, one way is to add a boolean argument raise_warnings
to helpful_support_errors
and raise it here.
I've added the |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Happy to merge this! Thanks so much for addressing this tricky issue.
Add circular reparameterization to improve the sampling of the
VonMises
distribution re. issue #1070.To Do:
CircularReparam
classcircular = _Circular()
constraintVonMises
to circular constraintVonMises
distributionCircularReparam
Issues so far:
VonMises
does not work without usingCircularReparam
and sotest_distribution_constraints
fails withNotImplementedError: <numpyro.distributions.constraints._Circular object at 0x13129f520> not implemented.
.VonMises
withCircularReparam
?