Permalink
Browse files

BUG: support full_output in optimize.curve_fit. Closes #1415.

  • Loading branch information...
1 parent 3243256 commit 24485988985dd6ab0e9947f9af7f34f0e32d4628 @rgommers rgommers committed Jun 4, 2011
Showing with 17 additions and 5 deletions.
  1. +12 −5 scipy/optimize/minpack.py
  2. +5 −0 scipy/optimize/tests/test_minpack.py
@@ -375,7 +375,6 @@ def curve_fit(f, xdata, ydata, p0=None, sigma=None, **kw):
This vector, if given, will be used as weights in the
least-squares problem.
-
Returns
-------
popt : array
@@ -385,11 +384,14 @@ def curve_fit(f, xdata, ydata, p0=None, sigma=None, **kw):
The estimated covariance of popt. The diagonals provide the variance
of the parameter estimate.
+ See Also
+ --------
+ leastsq
+
Notes
-----
- The algorithm uses the Levenburg-Marquardt algorithm:
- scipy.optimize.leastsq. Additional keyword arguments are passed directly
- to that algorithm.
+ The algorithm uses the Levenburg-Marquardt algorithm through `leastsq`.
+ Additional keyword arguments are passed directly to that algorithm.
Examples
--------
@@ -423,6 +425,8 @@ def curve_fit(f, xdata, ydata, p0=None, sigma=None, **kw):
func = _weighted_general_function
args += (1.0/asarray(sigma),)
+ # Remove full_output from kw, otherwise we're passing it in twice.
+ return_full = kw.pop('full_output', False)
res = leastsq(func, p0, args=args, full_output=1, **kw)
(popt, pcov, infodict, errmsg, ier) = res
@@ -436,7 +440,10 @@ def curve_fit(f, xdata, ydata, p0=None, sigma=None, **kw):
else:
pcov = inf
- return popt, pcov
+ if return_full:
+ return popt, pcov, infodict, errmsg, ier
+ else:
+ return popt, pcov
def check_gradient(fcn, Dfcn, x0, args=(), col_deriv=0):
"""Perform a simple check on the gradient for correctness.
@@ -209,6 +209,11 @@ def func(x,a):
assert_almost_equal(popt[0], 1.9149, decimal=4)
assert_almost_equal(pcov[0,0], 0.0016, decimal=4)
+ # Test if we get the same with full_output. Regression test for #1415.
+ res = curve_fit(func, self.x, self.y, full_output=1)
+ (popt2, pcov2, infodict, errmsg, ier) = res
+ assert_array_almost_equal(popt, popt2)
+
def test_two_argument(self):
def func(x, a, b):
return b*x**a

0 comments on commit 2448598

Please sign in to comment.