Skip to content

Commit

Permalink
fixed NaN with in log_prob corr < -1e-8 for SineBivariateVonMises (
Browse files Browse the repository at this point in the history
…#3165)

* fixed NaN with in `log_prob` corr < -1e-8 for `SineBivariateVonMises`

* fixed lint

* added clamp to corr.

* eps -> tiny.
  • Loading branch information
OlaRonning committed Dec 20, 2022
1 parent 3422c3a commit 0b1818c
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions pyro/distributions/sine_bivariate_von_mises.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@ class SineBivariateVonMises(TorchDistribution):
:param torch.Tensor phi_concentration: concentration of first angle
:param torch.Tensor psi_concentration: concentration of second angle
:param torch.Tensor correlation: correlation between the two angles
:param torch.Tensor weighted_correlation: set correlation to weigthed_corr * sqrt(phi_conc*psi_conc)
to avoid bimodality (see note). The `weightd_correlation` should be in [0,1].
:param torch.Tensor weighted_correlation: set correlation to weighted_corr * sqrt(phi_conc*psi_conc)
to avoid bimodality (see note). The `weighted_correlation` should be in [0,1].
"""

arg_constraints = {
Expand Down Expand Up @@ -95,9 +95,8 @@ def __init__(
sqrt_ = (
torch.sqrt if isinstance(phi_concentration, torch.Tensor) else math.sqrt
)
correlation = (
weighted_correlation * sqrt_(phi_concentration * psi_concentration)
+ 1e-8
correlation = weighted_correlation * sqrt_(
phi_concentration * psi_concentration
)

(
Expand Down Expand Up @@ -130,14 +129,15 @@ def __init__(

@lazy_property
def norm_const(self):
corr = self.correlation.view(1, -1) + 1e-8
corr = self.correlation.view(1, -1)
conc = torch.stack(
(self.phi_concentration, self.psi_concentration), dim=-1
).view(-1, 2)
m = torch.arange(50, device=self.phi_loc.device).view(-1, 1)
tiny = torch.finfo(corr.dtype).tiny
fs = (
SineBivariateVonMises._lbinoms(m.max() + 1).view(-1, 1)
+ 2 * m * torch.log(corr)
+ m * torch.log((corr**2).clamp(min=tiny))
- m * torch.log(4 * torch.prod(conc, dim=-1))
)
fs += log_I1(m.max(), conc, 51).sum(-1)
Expand Down

0 comments on commit 0b1818c

Please sign in to comment.