From efce0b328c9766456db33c9ec3d6b5ea7be0a7dd Mon Sep 17 00:00:00 2001 From: mgunyho <20118130+mgunyho@users.noreply.github.com> Date: Sun, 4 Jun 2023 20:51:53 +0300 Subject: [PATCH] Rename allow_failures to errors to be consistent with other methods --- xarray/core/dataarray.py | 12 ++++++------ xarray/core/dataset.py | 15 +++++++++------ xarray/tests/test_dataarray.py | 4 ++-- 3 files changed, 17 insertions(+), 14 deletions(-) diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 8b8bd9fdbd6..15f6ab95d6e 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -6162,7 +6162,7 @@ def curvefit( p0: dict[str, float | DataArray] | None = None, bounds: dict[str, tuple[float | DataArray, float | DataArray]] | None = None, param_names: Sequence[str] | None = None, - allow_failures: bool = False, + errors: ErrorOptions = "raise", kwargs: dict[str, Any] | None = None, ) -> Dataset: """ @@ -6207,10 +6207,10 @@ def curvefit( this will be automatically determined by arguments of `func`. `param_names` should be manually supplied when fitting a function that takes a variable number of parameters. - allow_failures: bool, default: False - If True and the underlying `scipy.optimize_curve_fit` optimization fails for - any of the fits, return NaN in coefficients and covariances for those - coordinates. + errors : {"raise", "ignore"}, default: "raise" + If 'raise', any errors from the `scipy.optimize_curve_fit` optimization will + raise an exception. If 'ignore', the coefficients and covariances for the + coordinates where the fitting failed will be NaN. **kwargs : optional Additional keyword arguments to passed to scipy curve_fit. @@ -6315,7 +6315,7 @@ def curvefit( p0=p0, bounds=bounds, param_names=param_names, - allow_failures=allow_failures, + errors=errors, kwargs=kwargs, ) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 846000219a7..c4091a0a819 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -8631,7 +8631,7 @@ def curvefit( p0: dict[str, float | DataArray] | None = None, bounds: dict[str, tuple[float | DataArray, float | DataArray]] | None = None, param_names: Sequence[str] | None = None, - allow_failures: bool = False, + errors: ErrorOptions = "raise", kwargs: dict[str, Any] | None = None, ) -> T_Dataset: """ @@ -8676,10 +8676,10 @@ def curvefit( this will be automatically determined by arguments of `func`. `param_names` should be manually supplied when fitting a function that takes a variable number of parameters. - allow_failures: bool, default: False - If True and the underlying `scipy.optimize_curve_fit` optimization fails for - any of the fits, return NaN in coefficients and covariances for those - coordinates. + errors : {"raise", "ignore"}, default: "raise" + If 'raise', any errors from the `scipy.optimize_curve_fit` optimization will + raise an exception. If 'ignore', the coefficients and covariances for the + coordinates where the fitting failed will be NaN. **kwargs : optional Additional keyword arguments to passed to scipy curve_fit. @@ -8762,6 +8762,9 @@ def curvefit( f"dimensions {preserved_dims}." ) + if errors not in ["raise", "ignore"]: + raise ValueError('errors must be either "raise" or "ignore"') + # Broadcast all coords with each other coords_ = broadcast(*coords_) coords_ = [ @@ -8802,7 +8805,7 @@ def _wrapper(Y, *args, **kwargs): try: popt, pcov = curve_fit(func, x, y, p0=p0_, bounds=(lb, ub), **kwargs) except RuntimeError: - if not allow_failures: + if errors == "raise": raise popt = np.full([n_params], np.nan) pcov = np.full([n_params, n_params], np.nan) diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index c183a70fc24..368f518abf2 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -4573,7 +4573,7 @@ def sine(t, a, f, p): @requires_scipy @pytest.mark.parametrize("use_dask", [True, False]) - def test_curvefit_allow_failures(self, use_dask: bool) -> None: + def test_curvefit_ignore_errors(self, use_dask: bool) -> None: if use_dask and not has_dask: pytest.skip("requires dask") @@ -4606,7 +4606,7 @@ def line(x, a, b): fit = da.curvefit( coords="x", func=line, - allow_failures=True, + errors="ignore", # limit maximum number of calls so the optimization fails kwargs=dict(maxfev=5), )