-
-
Notifications
You must be signed in to change notification settings - Fork 5.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
domain check for a < b in stats.truncnorm #9174
domain check for a < b in stats.truncnorm #9174
Conversation
adding a domain condition of a < b
Changed the condition for py_test.
scipy/stats/_continuous_distns.py
Outdated
@@ -6254,7 +6254,7 @@ def _argcheck(self, a, b): | |||
-(self._sb - self._sa), | |||
self._nb - self._na) | |||
self._logdelta = np.log(self._delta) | |||
return a != b | |||
return np.bitwise_and(np.not_equal(a, b), np.less(a, b)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Less (<) implies not_equal
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I used the numpy function as a one-to-one correspondence during the argument checking.
As you already fully knows, the !=
symbol is np.not_equal
, the <
symbol is np.less
.
If so, does it mean that you can use only one np.less
function in this section?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
return a < b
does the same thing.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OK, I will reflect that mention.
But, I used the np.less
function because it fails in py_test when using a<b
condition.
bshape = [5, 4, 3, 2]
dist = 'truncnorm'
distfunc = <scipy.stats._continuous_distns.truncnorm_gen object at 0x6ffef15b710>
k = 1
loc = array([0., 0.])
nargs = 2
scale = array([[1.],
[1.],
[1.]])
shape_args = (0.1, 2.0)
shape_only = False
shp = (5, 1, 1, 1)
scipy/stats/tests/test_continuous_basic.py:258:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
scipy/stats/tests/common_tests.py:300: in check_rvs_broadcast
sample = distfunc.rvs(*allargs)
scipy/stats/_distn_infrastructure.py:936: in rvs
cond = logical_and(self._argcheck(*args), (scale >= 0))
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <scipy.stats._continuous_distns.truncnorm_gen object at 0x6ffef15b710>
a = array([[[[0.1, 0.1],
[0.1, 0.1],
[0.1, 0.1]],
[[0.1, 0.1],
[0.1, 0.1],
[0....1, 0.1],
[0.1, 0.1],
[0.1, 0.1]],
[[0.1, 0.1],
[0.1, 0.1],
[0.1, 0.1]]]])
b = array([[[[2., 2.],
[2., 2.],
[2., 2.]],
[[2., 2.],
[2., 2.],
[2., 2.]],
...
[[2., 2.],
[2., 2.],
[2., 2.]],
[[2., 2.],
[2., 2.],
[2., 2.]]]])
def _argcheck(self, a, b):
self.a = a
self.b = b
self._nb = _norm_cdf(b)
self._na = _norm_cdf(a)
self._sb = _norm_sf(b)
self._sa = _norm_sf(a)
self._delta = np.where(self.a > 0,
-(self._sb - self._sa),
self._nb - self._na)
self._logdelta = np.log(self._delta)
> return a != b and a < b
E ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It should be only return a < b
. Equality is ruled out by strict inequality <
anyways.
If you want to combine bool statements then it should read as (a != b) & (a < b)
. But here not needed.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OK, I understood an expression of equality and comparison operator.
I will use the comparison operator <
as your advised.
My worries were about the py_test as well as the comparison operator.
I used the numpy function(np.less
, np.not_equal
, np.bitwise_and
) to solve the section
where py_test failed in multiple condition.
As a result, I will commit the conditional statements used only by the comparison operator.
only using comparison operator.
I've added |
All green, thanks @akahard2dj @chrisb83 |
I fix and a PR of the bug issued by @chrisb83.
I changed the condition formula for py_test from
a!=b and a<b
tonp.bitwise_and(np.not_equal(a, b), np.less(a, b))
in_argcheck
.Closes #9169