Skip to content

Commit

Permalink
BUG: optimize: make curve_fit work with method as callable. Closes #1531
Browse files Browse the repository at this point in the history
.
  • Loading branch information
calbaker authored and rgommers committed Dec 20, 2011
1 parent c2ff01c commit 2fdf5ce
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 5 deletions.
12 changes: 8 additions & 4 deletions scipy/optimize/minpack.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,16 +407,20 @@ def curve_fit(f, xdata, ydata, p0=None, sigma=None, **kw):
>>> popt, pcov = curve_fit(func, x, yn)
"""
if p0 is None or isscalar(p0):
if p0 is None:
# determine number of parameters by inspecting the function
import inspect
args, varargs, varkw, defaults = inspect.getargspec(f)
if len(args) < 2:
msg = "Unable to determine number of fit parameters."
raise ValueError(msg)
if p0 is None:
p0 = 1.0
p0 = [p0]*(len(args)-1)
if 'self' in args:
p0 = [1.0] * (len(args)-2)
else:
p0 = [1.0] * (len(args)-1)

if isscalar(p0):
p0 = array([p0])

args = (xdata, ydata, f)
if sigma is None:
Expand Down
18 changes: 17 additions & 1 deletion scipy/optimize/tests/test_minpack.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,23 @@ def func(x, a, b):
assert_(len(popt) == 2)
assert_(pcov.shape == (2,2))
assert_array_almost_equal(popt, [1.7989, 1.1642], decimal=4)
assert_array_almost_equal(pcov, [[0.0852, -0.1260],[-0.1260, 0.1912]], decimal=4)
assert_array_almost_equal(pcov, [[0.0852, -0.1260],[-0.1260, 0.1912]],
decimal=4)

def test_func_is_classmethod(self):
class test_self(object):
"""This class tests if curve_fit passes the correct number of
arguments when the model function is a class instance method.
"""
def func(self, x, a, b):
return b * x**a

test_self_inst = test_self()
popt, pcov = curve_fit(test_self_inst.func, self.x, self.y)
assert_(pcov.shape == (2,2))
assert_array_almost_equal(popt, [1.7989, 1.1642], decimal=4)
assert_array_almost_equal(pcov, [[0.0852, -0.1260], [-0.1260, 0.1912]],
decimal=4)


class TestFixedPoint(TestCase):
Expand Down

0 comments on commit 2fdf5ce

Please sign in to comment.