Skip to content

Commit

Permalink
Warning printed too many times (`RuntimeWarning: invalid value encoun…
Browse files Browse the repository at this point in the history
…tered in scalar divide ....`) (#371)
  • Loading branch information
frances-h committed Nov 15, 2023
1 parent 77855a3 commit ceebde9
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 4 deletions.
12 changes: 8 additions & 4 deletions copulas/univariate/truncated_gaussian.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""TruncatedGaussian module."""

import warnings

import numpy as np
from scipy.optimize import fmin_slsqp
from scipy.stats import truncnorm
Expand Down Expand Up @@ -47,10 +49,12 @@ def nnlf(params):
return truncnorm.nnlf((a, b, loc, scale), X)

initial_params = X.mean(), X.std()
optimal = fmin_slsqp(nnlf, initial_params, iprint=False, bounds=[
(self.min, self.max),
(0.0, (self.max - self.min)**2)
])
with warnings.catch_warnings():
warnings.simplefilter('ignore', category=RuntimeWarning)
optimal = fmin_slsqp(nnlf, initial_params, iprint=False, bounds=[
(self.min, self.max),
(0.0, (self.max - self.min)**2)
])

loc, scale = optimal
a = (self.min - loc) / scale
Expand Down
23 changes: 23 additions & 0 deletions tests/unit/univariate/test_truncated_gaussian.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import warnings
from unittest import TestCase
from unittest.mock import patch

import numpy as np
from scipy.stats import truncnorm
Expand Down Expand Up @@ -35,6 +37,27 @@ def test__fit(self):
for key, value in distribution._params.items():
np.testing.assert_allclose(value, expected[key], atol=0.3)

@patch('copulas.univariate.truncated_gaussian.fmin_slsqp')
def test__fit_silences_warnings(self, mocked_wrapper):
"""Test the ``_fit`` method does not emit RuntimeWarnings."""
# Setup
def mock_fmin_sqlsqp(*args, **kwargs):
warnings.warn(
message='Runtime Warning occured!',
category=RuntimeWarning
)
return 0, 1

mocked_wrapper.side_effect = mock_fmin_sqlsqp
distribution = TruncatedGaussian()

data = truncnorm.rvs(size=10000, a=0, b=3, loc=3, scale=1)

# Run and assert
with warnings.catch_warnings():
warnings.simplefilter('error')
distribution._fit(data)

def test__is_constant_true(self):
distribution = TruncatedGaussian()

Expand Down

0 comments on commit ceebde9

Please sign in to comment.