Sag handle numerical error outside of cython #13389

@@ -191,6 +191,11 @@ Support for Python 3.4 and below has been officially dropped.
 :mod:`sklearn.linear_model`
 ...........................

+- |Fix| Fixed a performance issue of ``saga`` and ``sag`` solvers when called
+  in a :class:`joblib.Parallel` setting with ``n_jobs > 1`` and
+  ``backend="threading"``, causing them to perform worse than in the sequential
+  case. :issue:`13389` by :user:`Pierre Glaser <pierreglaser>`.
+
[Review comment by jeremiedbb (Contributor), Mar 4, 2019; conversation resolved by ogrisel]
should be PR number
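
For context on the entry above, here is a minimal sketch of the setting the fix targets (dataset size, n_jobs and max_iter are illustrative, not taken from the PR): several sag/saga fits share Python threads through joblib, so the solver's earlier GIL re-acquisition for error handling made threaded runs slower than sequential ones.

# Minimal sketch of the affected setting: concurrent sag/saga fits in
# Python threads via joblib's threading backend. Parameters illustrative.
from joblib import Parallel, delayed
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression

X, y = make_classification(n_samples=5000, n_features=50, random_state=0)

def fit_one(seed):
    # each worker fits an independent estimator on shared data
    return LogisticRegression(solver="saga", max_iter=200,
                              random_state=seed).fit(X, y)

# Before this fix, error handling inside the Cython loop re-acquired the
# GIL, so with backend="threading" the workers effectively serialized.
models = Parallel(n_jobs=4, backend="threading")(
    delayed(fit_one)(seed) for seed in range(4))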

 - |Feature| :class:`linear_model.LogisticRegression` and
   :class:`linear_model.LogisticRegressionCV` now support Elastic-Net penalty,
   with the 'saga' solver. :issue:`11646` by :user:`Nicolas Hug <NicolasHug>`.
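
As a usage note for the feature entry above, a minimal sketch (parameter values illustrative); `l1_ratio` interpolates between pure L2 (0.0) and pure L1 (1.0):

# Minimal sketch of the Elastic-Net penalty in LogisticRegression; it is
# only supported by the 'saga' solver. l1_ratio=0.5 mixes L1 and L2 equally.
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression

X, y = make_classification(random_state=0)
clf = LogisticRegression(penalty="elasticnet", solver="saga",
                         l1_ratio=0.5, max_iter=1000)
clf.fit(X, y)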
@@ -59,12 +59,6 @@ from ..utils.seq_dataset cimport SequentialDataset32, SequentialDataset64

 from libc.stdio cimport printf

-cdef void raise_infinite_error(int n_iter):
-    raise ValueError("Floating-point under-/overflow occurred at "
-                     "epoch #%d. Lowering the step_size or "
-                     "scaling the input data with StandardScaler "
-                     "or MinMaxScaler might help." % (n_iter + 1))
-

 {{for name, c_type, np_type in get_dispatch(dtypes)}}
@@ -349,6 +343,9 @@ def sag{{name}}(SequentialDataset{{name}} dataset,
     # the scalar used for multiplying z
     cdef {{c_type}} wscale = 1.0

+    # return value (-1 if an error occurred, 0 otherwise)
+    cdef int status = 0
+
     # the cumulative sums for each iteration for the sparse implementation
     cumulative_sums[0] = 0.0

@@ -402,16 +399,19 @@ def sag{{name}}(SequentialDataset{{name}} dataset,

                 # make the weight updates
                 if sample_itr > 0:
-                    lagged_update{{name}}(weights, wscale, xnnz,
-                                          n_samples, n_classes, sample_itr,
-                                          cumulative_sums,
-                                          cumulative_sums_prox,
-                                          feature_hist,
-                                          prox,
-                                          sum_gradient,
-                                          x_ind_ptr,
-                                          False,
-                                          n_iter)
+                    status = lagged_update{{name}}(weights, wscale, xnnz,
+                                                   n_samples, n_classes,
+                                                   sample_itr,
+                                                   cumulative_sums,
+                                                   cumulative_sums_prox,
+                                                   feature_hist,
+                                                   prox,
+                                                   sum_gradient,
+                                                   x_ind_ptr,
+                                                   False,
+                                                   n_iter)
+                    if status == -1:
+                        break

                 # find the current prediction
                 predict_sample{{name}}(x_data_ptr, x_ind_ptr, xnnz, weights, wscale,
@@ -460,8 +460,12 @@ def sag{{name}}(SequentialDataset{{name}} dataset,

                         # check to see that the intercept is not inf or NaN
                         if not skl_isfinite{{name}}(intercept[class_ind]):
-                            with gil:
-                                raise_infinite_error(n_iter)
+                            status = -1
+                            break
+                    # Break from the n_samples outer loop if an error happened
+                    # in the fit_intercept n_classes inner loop
+                    if status == -1:
+                        break

                 # update the gradient memory for this sample
                 for class_ind in range(n_classes):
@@ -484,21 +488,32 @@ def sag{{name}}(SequentialDataset{{name}} dataset,
                 if verbose:
                     with gil:
                         print("rescaling...")
-                wscale = scale_weights{{name}}(
-                    weights, wscale, n_features, n_samples, n_classes,
+                status = scale_weights{{name}}(
+                    weights, &wscale, n_features, n_samples, n_classes,
                     sample_itr, cumulative_sums,
                     cumulative_sums_prox,
                     feature_hist,
                     prox, sum_gradient, n_iter)
+                if status == -1:
+                    break
+
+            # Break from the n_iter outer loop if an error happened in the
+            # n_samples inner loop
+            if status == -1:
+                break

             # we scale the weights every n_samples iterations and reset the
             # just-in-time update system for numerical stability.
-            wscale = scale_weights{{name}}(weights, wscale, n_features, n_samples,
-                                           n_classes, n_samples - 1, cumulative_sums,
-                                           cumulative_sums_prox,
-                                           feature_hist,
-                                           prox, sum_gradient, n_iter)
-
+            status = scale_weights{{name}}(weights, &wscale, n_features,
+                                           n_samples,
+                                           n_classes, n_samples - 1,
+                                           cumulative_sums,
+                                           cumulative_sums_prox,
+                                           feature_hist,
+                                           prox, sum_gradient, n_iter)
+
+            if status == -1:
+                break
             # check if the stopping criteria is reached
             max_change = 0.0
             max_weight = 0.0
@@ -520,6 +535,12 @@ def sag{{name}}(SequentialDataset{{name}} dataset,
                     printf('Epoch %d, change: %.8f\n', n_iter + 1,
                            max_change / max_weight)
             n_iter += 1
+    # We do the error treatment here based on error code in status to avoid
+    # re-acquiring the GIL within the cython code, which slows the computation.
[Review comment by ogrisel (Member), Mar 6, 2019; conversation resolved by ogrisel]
... when the sag/saga solver is used concurrently in multiple Python threads.

+    if status == -1:
+        raise ValueError(("Floating-point under-/overflow occurred at epoch"
+                          " #%d. Scaling input data with StandardScaler or"
+                          " MinMaxScaler might help.") % (n_iter + 1))
[Review comment by ogrisel (Member), Mar 6, 2019; conversation resolved by ogrisel]
There is already n_iter += 1 at line 537. I don't think we need n_iter + 1 here.
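
To summarize the new control flow, here is a hedged pure-Python sketch (names simplified, not the PR's code): inner routines report failure through an integer status instead of raising, and the single ValueError is raised once, after the nogil section ends.

import math

def lagged_update_sketch(weights):
    # stand-in for the Cython helper: report failure via a C-style status
    # code instead of raising, so no GIL is needed inside the loop
    for w in weights:
        if not math.isfinite(w):
            return -1
    return 0

def fit_sketch(weights, max_iter=10):
    status = 0
    n_iter = 0
    for n_iter in range(max_iter):
        status = lagged_update_sketch(weights)
        if status == -1:
            break  # leave the hot loop immediately
    # error treatment happens once, outside the loop, mirroring the PR
    if status == -1:
        raise ValueError("Floating-point under-/overflow occurred at epoch"
                         " #%d." % (n_iter + 1))

# fit_sketch([1.0, float("nan")])  # would raise ValueError at epoch #1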


     if verbose and n_iter >= max_iter:
         end_time = time(NULL)
@@ -533,14 +554,15 @@ def sag{{name}}(SequentialDataset{{name}} dataset,

 {{for name, c_type, np_type in get_dispatch(dtypes)}}

-cdef {{c_type}} scale_weights{{name}}({{c_type}}* weights, {{c_type}} wscale, int n_features,
-                                      int n_samples, int n_classes, int sample_itr,
-                                      {{c_type}}* cumulative_sums,
-                                      {{c_type}}* cumulative_sums_prox,
-                                      int* feature_hist,
-                                      bint prox,
-                                      {{c_type}}* sum_gradient,
-                                      int n_iter) nogil:
+cdef int scale_weights{{name}}({{c_type}}* weights, {{c_type}}* wscale,
+                               int n_features,
+                               int n_samples, int n_classes, int sample_itr,
+                               {{c_type}}* cumulative_sums,
+                               {{c_type}}* cumulative_sums_prox,
+                               int* feature_hist,
+                               bint prox,
+                               {{c_type}}* sum_gradient,
+                               int n_iter) nogil:
"""Scale the weights with wscale for numerical stability.

wscale = (1 - step_size * alpha) ** (n_iter * n_samples + sample_itr)
@@ -550,34 +572,37 @@ cdef {{c_type}} scale_weights{{name}}({{c_type}}* weights, {{c_type}} wscale, in
     This also limits the size of `cumulative_sums`.
     """

-    lagged_update{{name}}(weights, wscale, n_features,
-                          n_samples, n_classes, sample_itr + 1,
-                          cumulative_sums,
-                          cumulative_sums_prox,
-                          feature_hist,
-                          prox,
-                          sum_gradient,
-                          NULL,
-                          True,
-                          n_iter)
-    # reset wscale to 1.0
-    return 1.0
+    cdef int status
+    status = lagged_update{{name}}(weights, wscale[0], n_features,
+                                   n_samples, n_classes, sample_itr + 1,
+                                   cumulative_sums,
+                                   cumulative_sums_prox,
+                                   feature_hist,
+                                   prox,
+                                   sum_gradient,
+                                   NULL,
+                                   True,
+                                   n_iter)
+    # if lagged update succeeded, reset wscale to 1.0
+    if status == 0:
+        wscale[0] = 1.0
+    return status

{{endfor}}
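
For intuition about the docstring's formula, a small hedged numeric illustration (step size and alpha are invented values): the running scale factor decays geometrically, so without periodically folding it back into the weights it underflows to zero.

# Illustrative numbers only: wscale = (1 - step_size * alpha) ** k decays
# geometrically and eventually underflows double precision, which is why
# the solver folds wscale back into `weights` and resets it to 1.0.
step_size, alpha = 0.01, 0.1
decay = 1.0 - step_size * alpha      # 0.999
print(decay ** 1_000)                # ~0.368, still representable
print(decay ** 1_000_000)            # 0.0 after floating-point underflow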


{{for name, c_type, np_type in get_dispatch(dtypes)}}

-cdef void lagged_update{{name}}({{c_type}}* weights, {{c_type}} wscale, int xnnz,
-                                int n_samples, int n_classes, int sample_itr,
-                                {{c_type}}* cumulative_sums,
-                                {{c_type}}* cumulative_sums_prox,
-                                int* feature_hist,
-                                bint prox,
-                                {{c_type}}* sum_gradient,
-                                int* x_ind_ptr,
-                                bint reset,
-                                int n_iter) nogil:
+cdef int lagged_update{{name}}({{c_type}}* weights, {{c_type}} wscale, int xnnz,
+                               int n_samples, int n_classes, int sample_itr,
+                               {{c_type}}* cumulative_sums,
+                               {{c_type}}* cumulative_sums_prox,
+                               int* feature_hist,
+                               bint prox,
+                               {{c_type}}* sum_gradient,
+                               int* x_ind_ptr,
+                               bint reset,
+                               int n_iter) nogil:
"""Hard perform the JIT updates for non-zero features of present sample.
The updates that awaits are kept in memory using cumulative_sums,
cumulative_sums_prox, wscale and feature_hist. See original SAGA paper
@@ -605,8 +630,9 @@ cdef void lagged_update{{name}}({{c_type}}* weights, {{c_type}} wscale, int xnnz
             if reset:
                 weights[idx] *= wscale
                 if not skl_isfinite{{name}}(weights[idx]):
-                    with gil:
-                        raise_infinite_error(n_iter)
+                    # returning here does not require the gil as the return
+                    # type is a C integer
+                    return -1
         else:
             for class_ind in range(n_classes):
                 idx = f_idx + class_ind
@@ -640,8 +666,7 @@ cdef void lagged_update{{name}}({{c_type}}* weights, {{c_type}} wscale, int xnnz
                     weights[idx] *= wscale
                     # check to see that the weight is not inf or NaN
                     if not skl_isfinite{{name}}(weights[idx]):
-                        with gil:
-                            raise_infinite_error(n_iter)
+                        return -1
             if reset:
                 feature_hist[feature_ind] = sample_itr % n_samples
             else:
@@ -652,6 +677,8 @@ cdef void lagged_update{{name}}({{c_type}}* weights, {{c_type}} wscale, int xnnz
         if prox:
             cumulative_sums_prox[sample_itr - 1] = 0.0

+    return 0

{{endfor}}
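
For readers new to the just-in-time update trick this function implements, here is a hedged, heavily simplified Python sketch (dense weights, one class, no prox term; all names illustrative): each feature's missed gradient steps are applied in one shot using a running sum of step sizes.

def lazy_update_sketch(weights, sum_gradient, cumulative_sums,
                       feature_hist, sample_itr, nonzero_features):
    # cumulative_sums[t] holds the sum of the step sizes of iterations
    # 0..t; feature_hist[j] records the iteration after which feature j
    # was last brought up to date.
    for j in nonzero_features:
        last = feature_hist[j]
        # all updates feature j missed since `last`, applied at once
        missed = cumulative_sums[sample_itr]
        if last > 0:
            missed -= cumulative_sums[last - 1]
        weights[j] -= missed * sum_gradient[j]
        feature_hist[j] = sample_itr + 1

w = [1.0, 2.0, 3.0]
g = [0.1, 0.0, -0.2]
cum = [0.5, 1.1, 1.8]   # running sum of step sizes for iterations 0..2
hist = [0, 0, 0]
lazy_update_sketch(w, g, cum, hist, sample_itr=2, nonzero_features=[0, 2])
# w[0] and w[2] received all three missed steps at once; w[1] is untouched

Touching only the features that are non-zero in the current sample is what keeps SAGA efficient on sparse data.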


@@ -24,7 +24,7 @@
 from sklearn.utils import compute_class_weight
 from sklearn.utils import check_random_state
 from sklearn.preprocessing import LabelEncoder, LabelBinarizer
-from sklearn.datasets import make_blobs, load_iris
+from sklearn.datasets import make_blobs, load_iris, make_classification
 from sklearn.base import clone

 iris = load_iris()
@@ -826,3 +826,24 @@ def test_multinomial_loss_ground_truth():
                          [-0.903942, +5.258745, -4.354803]])
     assert_almost_equal(loss_1, loss_gt)
     assert_array_almost_equal(grad_1, grad_gt)
+
+
+@pytest.mark.parametrize("solver", ["sag", "saga"])
+def test_sag_classifier_raises_error(solver):
+    # Following #13316, the error handling behavior changed in cython sag.
+    # This is simply a non-regression test to make sure numerical errors are
+    # properly raised.
+
+    # Train a classifier on a simple problem
+    rng = np.random.RandomState(42)
+    X, y = make_classification(random_state=rng)
+    clf = LogisticRegression(solver=solver, random_state=rng, warm_start=True)
+    clf.fit(X, y)
[Review comment by ogrisel (Member), Mar 6, 2019; conversation resolved by ogrisel]
Maybe record clf.n_iter_ here and check below that "at epoch #%d" % (clf.n_iter_ + 1) appears in the error message.

[Review comment by ogrisel (Member), Mar 6, 2019]
Actually, the n_iter reported in the error message raised by the Cython code when warm start is enabled cannot be correct (as in consistent with clf.n_iter_) because of the way the Cython code is structured. I don't want to mix concerns in this PR; let's keep it simple.


+    # Trigger a numerical error by:
+    # - corrupting the fitted coefficients of the classifier
+    # - fitting it again from its current state, thanks to warm_start
+    clf.coef_[:] = np.nan
+
+    with pytest.raises(ValueError, match="Floating-point under-/overflow"):
+        clf.fit(X, y)