Permalink
Browse files

BUG: optimize: make curve_fit work with method as callable. Closes #1531

.
  • Loading branch information...
1 parent c2ff01c commit 2fdf5ce1187d3d9837789eef76fc542bcd554918 @calbaker calbaker committed with rgommers Dec 17, 2011
Showing with 25 additions and 5 deletions.
  1. +8 −4 scipy/optimize/minpack.py
  2. +17 −1 scipy/optimize/tests/test_minpack.py
View
@@ -407,16 +407,20 @@ def curve_fit(f, xdata, ydata, p0=None, sigma=None, **kw):
>>> popt, pcov = curve_fit(func, x, yn)
"""
- if p0 is None or isscalar(p0):
+ if p0 is None:
# determine number of parameters by inspecting the function
import inspect
args, varargs, varkw, defaults = inspect.getargspec(f)
if len(args) < 2:
msg = "Unable to determine number of fit parameters."
raise ValueError(msg)
- if p0 is None:
- p0 = 1.0
- p0 = [p0]*(len(args)-1)
+ if 'self' in args:
+ p0 = [1.0] * (len(args)-2)
+ else:
+ p0 = [1.0] * (len(args)-1)
+
+ if isscalar(p0):
+ p0 = array([p0])
args = (xdata, ydata, f)
if sigma is None:
@@ -221,7 +221,23 @@ def func(x, a, b):
assert_(len(popt) == 2)
assert_(pcov.shape == (2,2))
assert_array_almost_equal(popt, [1.7989, 1.1642], decimal=4)
- assert_array_almost_equal(pcov, [[0.0852, -0.1260],[-0.1260, 0.1912]], decimal=4)
+ assert_array_almost_equal(pcov, [[0.0852, -0.1260],[-0.1260, 0.1912]],
+ decimal=4)
+
+ def test_func_is_classmethod(self):
+ class test_self(object):
+ """This class tests if curve_fit passes the correct number of
+ arguments when the model function is a class instance method.
+ """
+ def func(self, x, a, b):
+ return b * x**a
+
+ test_self_inst = test_self()
+ popt, pcov = curve_fit(test_self_inst.func, self.x, self.y)
+ assert_(pcov.shape == (2,2))
+ assert_array_almost_equal(popt, [1.7989, 1.1642], decimal=4)
+ assert_array_almost_equal(pcov, [[0.0852, -0.1260], [-0.1260, 0.1912]],
+ decimal=4)
class TestFixedPoint(TestCase):

0 comments on commit 2fdf5ce

Please sign in to comment.