Skip to content

Commit

Permalink
BUG: make curve_fit() work with array_like input. Closes scipygh-3037.
Browse files Browse the repository at this point in the history
  • Loading branch information
rgommers committed Dec 20, 2013
1 parent 34ae412 commit 856e396
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 1 deletion.
10 changes: 9 additions & 1 deletion scipy/optimize/minpack.py
Expand Up @@ -3,6 +3,7 @@
import warnings
from . import _minpack

import numpy as np
from numpy import (atleast_1d, dot, take, triu, shape, eye,
transpose, zeros, product, greater, array,
all, where, isscalar, asarray, inf, abs,
Expand Down Expand Up @@ -532,15 +533,22 @@ def curve_fit(f, xdata, ydata, p0=None, sigma=None, absolute_sigma=False, **kw):
else:
p0 = [1.0] * (len(args)-1)

# Check input arguments
if isscalar(p0):
p0 = array([p0])

ydata = np.asanyarray(ydata)
if isinstance(xdata, (list, tuple)):
# `xdata` is passed straight to the user-defined `f`, so allow
# non-array_like `xdata`.
xdata = np.asarray(xdata)

args = (xdata, ydata, f)
if sigma is None:
func = _general_function
else:
func = _weighted_general_function
args += (1.0/asarray(sigma),)
args += (1.0 / asarray(sigma),)

# Remove full_output from kw, otherwise we're passing it in twice.
return_full = kw.pop('full_output', False)
Expand Down
9 changes: 9 additions & 0 deletions scipy/optimize/tests/test_minpack.py
Expand Up @@ -375,6 +375,15 @@ def f_flat(x, a, b):
assert_(pcov.shape == (2, 2))
assert_array_equal(pcov, np.inf)

def test_array_like(self):
# Test sequence input. Regression test for gh-3037.
def f_linear(x, a, b):
return a*x + b

x = [1, 2, 3, 4]
y = [2, 4, 6, 8]
assert_allclose(curve_fit(f_linear, x, y)[0], [2, 0], atol=1e-10)


class TestFixedPoint(TestCase):

Expand Down

0 comments on commit 856e396

Please sign in to comment.