BUG: make curve_fit() work with array_like input. Closes gh-3037. #3166

Merged
merged 1 commit into from

5 participants

@rgommers
SciPy member

No description provided.

@coveralls

Coverage remained the same when pulling 856e396 on rgommers:curvefit-list-input into 34ae412 on scipy:master.

@dlax
SciPy member

Looks good to me. Does it work for other people involved in the issue discussion? @josef-pkt

@josef-pkt
SciPy member

It still looks dangerous to me. I would maybe take out the tuple and only convert lists (if I converted anything at all).

I don't have any code using curve_fit (except scripts to try it out while looking at issues for it).
The only one I remember recently from stackoverflow is xdata = (x1, x2) with tuple unpacking inside the user function, which wouldn't break with this change; it would just add an unnecessary array conversion.

I also don't know if anyone using [x1, x2] actually would like column_stack([x1, x2]) instead of row stacking.

(I would just change the starting values to 0.999 and let the user deal with the exceptions in his/her code,
and maybe go for deprecating the default starting values, as has been suggested.)

Changing this is just guessing at what users might be doing. I'm just pointing out some cases, and have no idea what different users are actually doing, which essentially means I'm -0 on this.
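
To make the row-versus-column stacking question concrete, here is a minimal NumPy-only sketch. The names `x1`, `x2`, `rows`, `cols`, `u1`, and `u2` are illustrative and do not come from the PR.

```python
import numpy as np

x1 = [1.0, 2.0, 3.0]
x2 = [10.0, 20.0, 30.0]

# Converting a list or tuple of sequences with asarray stacks them as
# rows, giving shape (k, N).
rows = np.asarray((x1, x2))

# The alternative layout mentioned above: one column per sequence,
# giving shape (N, k).
cols = np.column_stack([x1, x2])

# Tuple-style unpacking inside a user function still works on the
# row-stacked array, because iterating a 2-D array yields its rows.
u1, u2 = rows
```

With row stacking, a user function that does `x1, x2 = x` keeps working after the conversion; with column stacking it would not.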

@pv
SciPy member
@josef-pkt: a better way than 0.999 is to replace params with params.tolist() in _general_function and _weighted_general_function. Just changing the value doesn't help because np.float64(0.999) * [1,2,3] == [].

However, this doesn't solve the issue that in the most common use case, where x is meant to be an array, passing in a list doesn't work.
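
The underlying difference between list and array multiplication can be sketched as follows. This example uses plain Python ints for the list case, since how NumPy scalars interact with lists depends on the NumPy version; the variable names are illustrative.

```python
import numpy as np

x_list = [1, 2, 3]
x_arr = np.asarray(x_list)

# Multiplying a list by a Python int repeats it; it does not scale it.
repeated = 2 * x_list   # three elements become six
emptied = 0 * x_list    # zero repetitions leave an empty list

# Multiplying an array by a number scales each element.
scaled = 2 * x_arr
```

A model function such as `a*x + b` therefore only behaves as intended when `x` is an array.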

@josef-pkt
SciPy member

"Just changing the value doesn't help because np.float64(0.999) * [1,2,3] == []."

I think it would avoid the original case where curve_fit returned wrong results instead of raising an exception.

I agree that in most cases users who pass a list actually want an array. I'm less convinced for the tuple case.
Also, if we convert only lists, then users can still easily pass in other "things", which was the main part of my original objection.

Executive decision required: guess what users want (which will be correct in most cases) or be more liberal with exceptions.

@rgommers
SciPy member

Raising exceptions instead of returning wrong/unexpected results is also not acceptable I think. This should just work, like for any other function with array_like input.

Anything else, like passing in random objects that the user-provided function then knows how to handle, is nice but of secondary importance imho.

@josef-pkt
SciPy member

Still we are just guessing on the usage.

This reminds me that curve_fit originally only allowed functions and broke with a method attached to a class (in inspect).
It also reminds me that there is a way to get arbitrary content into the user function, either by using a class or by reusing things from the outer scope, as in nested functions. That means there is another workaround if users don't want the array conversion.
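
A minimal sketch of the nested-function workaround, assuming a current SciPy where `curve_fit` converts `xdata` to a float array. The names `make_model`, `frame`, and `idx` are hypothetical: the real inputs live in an outer-scope object and `xdata` is only an index.

```python
import numpy as np
from scipy.optimize import curve_fit

def make_model(frame):
    """Return a model that reads its inputs from `frame` (hypothetical
    dict of columns) captured from the enclosing scope."""
    def model(idx, a, b):
        # xdata may arrive as floats, so cast back to integer indices.
        idx = np.asarray(idx, dtype=int)
        return a * frame["x1"][idx] + b * frame["x2"][idx]
    return model

frame = {
    "x1": np.array([1.0, 2.0, 3.0, 4.0]),
    "x2": np.array([1.0, 0.0, 2.0, 1.0]),
}
ydata = 2.0 * frame["x1"] + 3.0 * frame["x2"]
idx = np.arange(len(ydata))

# p0 is given explicitly so the example does not rely on default
# starting values.
popt, pcov = curve_fit(make_model(frame), idx, ydata, p0=[1.0, 1.0])
```

Because the data is captured by the closure, whatever conversion is applied to `xdata` never touches it.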

@rgommers
SciPy member

Besides common sense, there's a bunch of confused users on Stackoverflow. I wouldn't call that just guessing.

For anything that is not a list/tuple, there is still no array conversion after merging this PR. So no workaround needed, even though those are possible.

@pv
SciPy member
Well, the documentation states "xdata : An N-length sequence or an (k,N)-shaped array". I believe it would be fine with an unconditional xdata = np.asanyarray(xdata). (Although people might pass in pandas dataframes, so just dealing with lists and tuples is probably OK, albeit quite ugly.)
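
For reference, the difference between `np.asarray` and `np.asanyarray` can be sketched with an ndarray subclass. A masked array is used here as a stand-in; it is an assumption chosen to avoid a pandas dependency, and a DataFrame, which is not an ndarray subclass, would be converted by either call.

```python
import numpy as np

masked = np.ma.masked_array([1.0, 2.0, 3.0], mask=[False, True, False])

# asarray converts ndarray subclasses to a plain base-class ndarray.
as_base = np.asarray(masked)

# asanyarray passes ndarray subclasses through unchanged.
as_any = np.asanyarray(masked)

# Lists become plain ndarrays with either function.
from_list = np.asanyarray([1, 2, 3])
```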

@rgommers
SciPy member

That would also be OK with me.

@josef-pkt
SciPy member

OK, I'm +0, since there are easy workarounds.

@dlax
SciPy member

So the consensus is on "unconditional asanyarray"?

@rgommers
SciPy member

I'm OK either way but still think the current PR makes the most sense, to not break any current off-label usage (classes, pandas dataframes, etc.). That can be documented and tested explicitly if it helps this PR go in.

@dlax dlax merged commit 3982a90 into scipy:master

1 check passed: the Travis CI build passed.
@rgommers rgommers deleted the rgommers:curvefit-list-input branch
Showing with 18 additions and 1 deletion.
  1. +9 −1 scipy/optimize/minpack.py
  2. +9 −0 scipy/optimize/tests/test_minpack.py
10 scipy/optimize/minpack.py
@@ -3,6 +3,7 @@
 import warnings
 from . import _minpack
+import numpy as np
 from numpy import (atleast_1d, dot, take, triu, shape, eye,
                    transpose, zeros, product, greater, array,
                    all, where, isscalar, asarray, inf, abs,
@@ -532,15 +533,22 @@ def curve_fit(f, xdata, ydata, p0=None, sigma=None, absolute_sigma=False, **kw):
         else:
             p0 = [1.0] * (len(args)-1)
+    # Check input arguments
     if isscalar(p0):
         p0 = array([p0])
+    ydata = np.asanyarray(ydata)
+    if isinstance(xdata, (list, tuple)):
+        # `xdata` is passed straight to the user-defined `f`, so allow
+        # non-array_like `xdata`.
+        xdata = np.asarray(xdata)
+
     args = (xdata, ydata, f)
     if sigma is None:
         func = _general_function
     else:
         func = _weighted_general_function
-        args += (1.0/asarray(sigma),)
+        args += (1.0 / asarray(sigma),)
     # Remove full_output from kw, otherwise we're passing it in twice.
     return_full = kw.pop('full_output', False)
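
The input check added in this diff can be sketched in isolation as below. `prepare_inputs` and `Opaque` are hypothetical names used only for illustration; they are not part of SciPy.

```python
import numpy as np

def prepare_inputs(xdata, ydata):
    """Hypothetical helper mirroring the conversion logic in the diff."""
    # ydata is always converted; ndarray subclasses pass through.
    ydata = np.asanyarray(ydata)
    # xdata is converted only when it is a list or tuple, so any other
    # object reaches the user-defined model function untouched.
    if isinstance(xdata, (list, tuple)):
        xdata = np.asarray(xdata)
    return xdata, ydata

class Opaque:
    """Stand-in for an arbitrary object a user function can handle."""

x_from_list, y_arr = prepare_inputs([1, 2, 3], [2, 4, 6])
obj = Opaque()
x_passthrough, _ = prepare_inputs(obj, [2, 4, 6])
```

Lists and tuples come out as arrays, while anything else is handed through as the same object.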
9 scipy/optimize/tests/test_minpack.py
@@ -375,6 +375,15 @@ def f_flat(x, a, b):
         assert_(pcov.shape == (2, 2))
         assert_array_equal(pcov, np.inf)
+    def test_array_like(self):
+        # Test sequence input. Regression test for gh-3037.
+        def f_linear(x, a, b):
+            return a*x + b
+
+        x = [1, 2, 3, 4]
+        y = [2, 4, 6, 8]
+        assert_allclose(curve_fit(f_linear, x, y)[0], [2, 0], atol=1e-10)
+
 class TestFixedPoint(TestCase):
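
The regression test can also be run as a standalone script. This is a sketch that assumes a SciPy version that includes this fix; the looser tolerance in the final check is a choice made here, not the one used in the test suite.

```python
import numpy as np
from scipy.optimize import curve_fit

def f_linear(x, a, b):
    return a * x + b

# Plain Python lists, as in the regression test for gh-3037.
x = [1, 2, 3, 4]
y = [2, 4, 6, 8]

popt, pcov = curve_fit(f_linear, x, y)

# The data lie exactly on y = 2*x, so the fit should recover a=2, b=0.
fit_ok = np.allclose(popt, [2.0, 0.0], atol=1e-6)
```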