Skip to content

Commit

Permalink
fix_reverse_out_bound_quadratic_spline (#3140)
Browse files Browse the repository at this point in the history
* fix_reverse_out_bound_quadratic_spline

* fix_reverse_out_bound_quadratic_spline

* add two spaces before comment

* add sapce

* space

* lint

Co-authored-by: Fritz Obermeyer <fritz.obermeyer@gmail.com>
  • Loading branch information
LiaoShiqi97 and fritzo committed Oct 21, 2022
1 parent 0fe132e commit c573af1
Showing 1 changed file with 2 additions and 0 deletions.
2 changes: 2 additions & 0 deletions pyro/distributions/transforms/spline.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,8 @@ def _monotonic_rational_spline(
c = -input_delta * (inputs - input_cumheights)

discriminant = b.pow(2) - 4 * a * c
# Make sure outside_interval input can be reversed as identity.
discriminant = discriminant.masked_fill(outside_interval_mask, 0)
assert (discriminant >= 0).all()

root = (2 * c) / (-b - torch.sqrt(discriminant))
Expand Down

0 comments on commit c573af1

Please sign in to comment.