-
-
Notifications
You must be signed in to change notification settings - Fork 986
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Sine Skewed Toridial distribution #2826
Conversation
@fehiepsi, thanks for the review! 👍 |
…tributions/__init__.py`
…to feature/ss_dist
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you clarify the input shapes for base density and skewness? It is not clear to me what's the requirements so I couldn't check if the expand method and the reshape ops in sample and log_prob are valid.
Hi @fehiepsi, Thanks for the review! Sorry for the input dimensions being vague. Let me clarify; the requirements are ss = SineSkewed(base_dist=BaseDist(Size([3,3,2])).to_event(1), skewness=Size([3,3,2]))
assert ss.event_shape == (2,)
assert ss.batch_shape == (3,3)
ss1 = SineSkewed(base_dist=BaseDist(Size([3,3,2])).to_event(1), skewness=Size([3,3,2])).to_event(1)
assert ss1.event_shape == (3,2)
assert ss1.batch_shape == (3,)
ss2 = SineSkewed(base_dist=BaseDist(Size([3,3,2])).to_event(2), skewness=Size([3,3,2]))
assert ss2.event_shape == (3,2)
assert ss2.batch_shape == (3,)
try:
SineSkewed(base_dist=BaseDist(Size([3,3,2])).to_event(2), skewness=Size([3,3,2])).to_event(1)
assert 1==0
except AssertionError:
pass
try:
SineSkewed(base_dist=BaseDist(Size([3,3,2])).to_event(1), skewness=Size([3,3,2])).to_event(2)
assert 1==0
except AssertionError:
pass I would like for the construction |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks much cleaner than before. 👍
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM after a couple tiny nits
psi_bound = 1 - skew_phi.abs() | ||
skew_psi = pyro.sample('skew_psi', Uniform(-1., 1.)) | ||
skewness = torch.stack((skew_phi, psi_bound * skew_psi), dim=-1) | ||
assert skewness.shape == (num_mix_comp, 2) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It would be nice to have a constraint+transform for this (in a follow-up PR). I believe we can use signed stick-breaking transform here. This way users can define distributions over skewness
(or just simply use pyro.param
with correct constraint). Without that, it is a bit cumbersome for users to define correct skewness
over general d-torus.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
With that, we can have correct constraints
in the distribution definition. :D
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That would be neat; I'll add it to the backlog.
This PR introduces the Sine Skewed Toridial distribution described by Jose Ameijeiras-Alonso and Christophe Ley.
The distribution enables skewing a base distribution on a d-dimensional torus, which is useful, for example, with dihedral angles (1-torus) of peptides as can be seen on a Ramachandran plot.
Missing prior suggestion in docstring; inference method and inferable params as suggested in #2821.