Skip to content

Commit

Permalink
Rename allow_failures to errors to be consistent with other methods
Browse files Browse the repository at this point in the history
  • Loading branch information
mgunyho committed Jun 4, 2023
1 parent d3decbb commit efce0b3
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 14 deletions.
12 changes: 6 additions & 6 deletions xarray/core/dataarray.py
Expand Up @@ -6162,7 +6162,7 @@ def curvefit(
p0: dict[str, float | DataArray] | None = None,
bounds: dict[str, tuple[float | DataArray, float | DataArray]] | None = None,
param_names: Sequence[str] | None = None,
allow_failures: bool = False,
errors: ErrorOptions = "raise",
kwargs: dict[str, Any] | None = None,
) -> Dataset:
"""
Expand Down Expand Up @@ -6207,10 +6207,10 @@ def curvefit(
this will be automatically determined by arguments of `func`. `param_names`
should be manually supplied when fitting a function that takes a variable
number of parameters.
allow_failures: bool, default: False
If True and the underlying `scipy.optimize_curve_fit` optimization fails for
any of the fits, return NaN in coefficients and covariances for those
coordinates.
errors : {"raise", "ignore"}, default: "raise"
If 'raise', any errors from the `scipy.optimize_curve_fit` optimization will
raise an exception. If 'ignore', the coefficients and covariances for the
coordinates where the fitting failed will be NaN.
**kwargs : optional
Additional keyword arguments to passed to scipy curve_fit.
Expand Down Expand Up @@ -6315,7 +6315,7 @@ def curvefit(
p0=p0,
bounds=bounds,
param_names=param_names,
allow_failures=allow_failures,
errors=errors,
kwargs=kwargs,
)

Expand Down
15 changes: 9 additions & 6 deletions xarray/core/dataset.py
Expand Up @@ -8631,7 +8631,7 @@ def curvefit(
p0: dict[str, float | DataArray] | None = None,
bounds: dict[str, tuple[float | DataArray, float | DataArray]] | None = None,
param_names: Sequence[str] | None = None,
allow_failures: bool = False,
errors: ErrorOptions = "raise",
kwargs: dict[str, Any] | None = None,
) -> T_Dataset:
"""
Expand Down Expand Up @@ -8676,10 +8676,10 @@ def curvefit(
this will be automatically determined by arguments of `func`. `param_names`
should be manually supplied when fitting a function that takes a variable
number of parameters.
allow_failures: bool, default: False
If True and the underlying `scipy.optimize_curve_fit` optimization fails for
any of the fits, return NaN in coefficients and covariances for those
coordinates.
errors : {"raise", "ignore"}, default: "raise"
If 'raise', any errors from the `scipy.optimize_curve_fit` optimization will
raise an exception. If 'ignore', the coefficients and covariances for the
coordinates where the fitting failed will be NaN.
**kwargs : optional
Additional keyword arguments to passed to scipy curve_fit.
Expand Down Expand Up @@ -8762,6 +8762,9 @@ def curvefit(
f"dimensions {preserved_dims}."
)

if errors not in ["raise", "ignore"]:
raise ValueError('errors must be either "raise" or "ignore"')

# Broadcast all coords with each other
coords_ = broadcast(*coords_)
coords_ = [
Expand Down Expand Up @@ -8802,7 +8805,7 @@ def _wrapper(Y, *args, **kwargs):
try:
popt, pcov = curve_fit(func, x, y, p0=p0_, bounds=(lb, ub), **kwargs)
except RuntimeError:
if not allow_failures:
if errors == "raise":
raise
popt = np.full([n_params], np.nan)
pcov = np.full([n_params, n_params], np.nan)
Expand Down
4 changes: 2 additions & 2 deletions xarray/tests/test_dataarray.py
Expand Up @@ -4573,7 +4573,7 @@ def sine(t, a, f, p):

@requires_scipy
@pytest.mark.parametrize("use_dask", [True, False])
def test_curvefit_allow_failures(self, use_dask: bool) -> None:
def test_curvefit_ignore_errors(self, use_dask: bool) -> None:
if use_dask and not has_dask:
pytest.skip("requires dask")

Expand Down Expand Up @@ -4606,7 +4606,7 @@ def line(x, a, b):
fit = da.curvefit(
coords="x",
func=line,
allow_failures=True,
errors="ignore",
# limit maximum number of calls so the optimization fails
kwargs=dict(maxfev=5),
)
Expand Down

0 comments on commit efce0b3

Please sign in to comment.