Skip to content

Commit

Permalink
BUG: fix travis failures
Browse files Browse the repository at this point in the history
  • Loading branch information
aarchiba committed May 12, 2015
1 parent 235bcf5 commit 71fc658
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 12 deletions.
16 changes: 8 additions & 8 deletions scipy/misc/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False):
return_sign : bool, optional
If this is set to True, the result will be a pair containing sign
information; if False, results that are negative will be returned
as NaN.
as NaN. Default is False (no sign information).
.. versionadded:: 0.16.0
Returns
Expand Down Expand Up @@ -92,16 +92,16 @@ def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False):
Returning a sign flag
>>> logsumexp([1,2],b=[1,-1])
(1.5413248546129181, -1)
>>> logsumexp([1,2],b=[1,-1],return_sign=True)
(1.5413248546129181, -1.0)
"""
a = asarray(a)
if b is not None:
a, b = broadcast_arrays(a,b)
if np.any(b==0):
a = a + 0. # promote to at least float
a[b==0] = -np.inf
if np.any(b == 0):
a = a + 0. # promote to at least float
a[b == 0] = -np.inf

# keepdims is available in numpy.sum and numpy.amax since NumPy 1.7.0
#
Expand Down Expand Up @@ -139,7 +139,7 @@ def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False):
s = sum(tmp, axis=axis)
if return_sign:
sgn = sign(s)
s *= sgn # /= makes more sense but we need zero -> zero
s *= sgn # /= makes more sense but we need zero -> zero
out = log(s)

out += a_max
Expand Down Expand Up @@ -169,7 +169,7 @@ def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False):
s = sum(tmp, axis=axis, keepdims=keepdims)
if return_sign:
sgn = sign(s)
s *= sgn # /= makes more sense but we need zero -> zero
s *= sgn # /= makes more sense but we need zero -> zero
out = log(s)

if not keepdims:
Expand Down
17 changes: 13 additions & 4 deletions scipy/misc/tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,17 +127,23 @@ def test_logsumexp_sign_zero():
r, s = logsumexp(a, b=b, return_sign=True)
assert_(not np.isfinite(r))
assert_(not np.isnan(r))
assert_(r<0)
assert_(r < 0)
assert_equal(s,0)

def test_logsumexp_sign_shape():
a = np.ones((1,2,3,4))
b = np.ones_like(a)

r, s = logsumexp(a, axis=(1,3), b=b, return_sign=True)
r, s = logsumexp(a, axis=2, b=b, return_sign=True)

assert_equal(r.shape, s.shape)
assert_equal(r.shape, (1,3))
assert_equal(r.shape, (1,2,4))

if NumpyVersion(np.__version__) >= NumpyVersion('1.7.0'):
r, s = logsumexp(a, axis=(1,3), b=b, return_sign=True)

assert_equal(r.shape, s.shape)
assert_equal(r.shape, (1,3))

def test_logsumexp_shape():
a = np.ones((1,2,3,4))
Expand All @@ -146,6 +152,10 @@ def test_logsumexp_shape():
r = logsumexp(a, axis=(1,3), b=b)

assert_equal(r.shape, (1,3))
if NumpyVersion(np.__version__) >= NumpyVersion('1.7.0'):
r = logsumexp(a, axis=(1,3), b=b)

assert_equal(r.shape, (1,3))

def test_logsumexp_b_zero():
a = [1,10000]
Expand All @@ -160,7 +170,6 @@ def test_logsumexp_b_shape():
logsumexp(a, b=b)



def test_face():
assert_equal(face().shape, (768, 1024, 3))

Expand Down

0 comments on commit 71fc658

Please sign in to comment.