Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix_reverse_out_bound_quadratic_spline #3140

Merged

Conversation

LiaoShiqi97
Copy link
Contributor

Issue Description
quadratic spline in pyro.distributions.transforms.Spline does not consider outside_interval input at reverse.

According to the Neural spline flow(https://arxiv.org/abs/1906.04032), monotonic rational-quadratic transforms sets the boundary derivatives as 1 to match the linear tails. However, when it takes the y(sample from the target distribution ) to reverse back to the x(sample from the base distribution) in the function _monotonic_rational_spline, it does not consider the scenario where y is out of bound, which would render a negative discriminant and return Error.

Environment
For any bugs, please provide the following:

python version: 3.8.7
PyTorch version: 1.12.1+cpu
Pyro version: 1.8.2
Solution
I replace a fake discriminant for the out-of-bound input.
Origin:
discriminant = b.pow(2) - 4 * a * c
assert (discriminant >= 0).all()

Changed:
discriminant = b.pow(2) - 4 * a * c
discriminant[outside_interval_mask] = 0 # added to make sure outputs[outside_interval_mask] = inputs[outside_interval_mask]
assert (discriminant >= 0).all()

@@ -254,6 +254,7 @@ def _monotonic_rational_spline(
c = -input_delta * (inputs - input_cumheights)

discriminant = b.pow(2) - 4 * a * c
discriminant[outside_interval_mask] = 0 # added to make sure outside_interval input can be reversed as identity.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you'll need two spaces before the comment, and possibly to move the line comment to a separate line to satisfy max line length requirements. You can locally run make lint to see lint errors.

Also you might need to use a non-inplace version like

discriminant = discriminant.masked_fill(outside_interval_mask, 0)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for your comment. I am new to pulling a request. So in this case, it is because of my non-standard comment that my request fails, right?

Copy link
Member

@fritzo fritzo Oct 2, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe so, we use the black automatic style checker to ensure all code follows the same style. This helps everyone collaborate by making the code more uniform.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I get it. Thanks. I will edit it soon

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have re-pulled it. Please check it.

@fritzo fritzo merged commit c573af1 into pyro-ppl:dev Oct 21, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants