Sag handle numerical error outside of cython #13389

Merged
Changes from 6 commits
5 changes: 5 additions & 0 deletions doc/whats_new/v0.21.rst
@@ -191,6 +191,11 @@ Support for Python 3.4 and below has been officially dropped.
:mod:`sklearn.linear_model`
...........................

- |Fix| Fixed a performance issue with the ``saga`` and ``sag`` solvers when
called in a :class:`joblib.Parallel` setting with ``n_jobs > 1`` and
``backend="threading"``, which caused them to perform worse than in the
sequential case. :issue:`13389` by :user:`Pierre Glaser <pierreglaser>`.

- |Feature| :class:`linear_model.LogisticRegression` and
:class:`linear_model.LogisticRegressionCV` now support Elastic-Net penalty,
with the 'saga' solver. :issue:`11646` by :user:`Nicolas Hug <NicolasHug>`.
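The |Fix| entry above concerns sag/saga fits running concurrently under joblib's threading backend. A rough reproduction sketch of that setting (hypothetical script, not part of the PR; before this change, the solvers' error handling could need the GIL inside the hot Cython loop, so threads contended with each other):

from joblib import Parallel, delayed
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression

X, y = make_classification(n_samples=5000, n_features=50, random_state=0)

def fit_one():
    # With backend="threading" all workers share one process, so any GIL
    # acquisition inside the solver's inner loop serializes them.
    return LogisticRegression(solver="saga", max_iter=100).fit(X, y)

models = Parallel(n_jobs=4, backend="threading")(
    delayed(fit_one)() for _ in range(4))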
147 changes: 87 additions & 60 deletions sklearn/linear_model/sag_fast.pyx.tp
@@ -59,12 +59,6 @@ from ..utils.seq_dataset cimport SequentialDataset32, SequentialDataset64

from libc.stdio cimport printf

cdef void raise_infinite_error(int n_iter):
raise ValueError("Floating-point under-/overflow occurred at "
"epoch #%d. Lowering the step_size or "
"scaling the input data with StandardScaler "
"or MinMaxScaler might help." % (n_iter + 1))



{{for name, c_type, np_type in get_dispatch(dtypes)}}
@@ -349,6 +343,9 @@ def sag{{name}}(SequentialDataset{{name}} dataset,
# the scalar used for multiplying z
cdef {{c_type}} wscale = 1.0

# return value (-1 if an error occurred, 0 otherwise)
cdef int status = 0

# the cumulative sums for each iteration for the sparse implementation
cumulative_sums[0] = 0.0

@@ -402,16 +399,19 @@

# make the weight updates
if sample_itr > 0:
lagged_update{{name}}(weights, wscale, xnnz,
n_samples, n_classes, sample_itr,
cumulative_sums,
cumulative_sums_prox,
feature_hist,
prox,
sum_gradient,
x_ind_ptr,
False,
n_iter)
status = lagged_update{{name}}(weights, wscale, xnnz,
n_samples, n_classes,
sample_itr,
cumulative_sums,
cumulative_sums_prox,
feature_hist,
prox,
sum_gradient,
x_ind_ptr,
False,
n_iter)
if status == -1:
break

# find the current prediction
predict_sample{{name}}(x_data_ptr, x_ind_ptr, xnnz, weights, wscale,
@@ -460,8 +460,12 @@ def sag{{name}}(SequentialDataset{{name}} dataset,

# check to see that the intercept is not inf or NaN
if not skl_isfinite{{name}}(intercept[class_ind]):
with gil:
raise_infinite_error(n_iter)
status = -1
break
# Break from the n_samples outer loop if an error happened
# in the fit_intercept n_classes inner loop
if status == -1:
break

# update the gradient memory for this sample
for class_ind in range(n_classes):
@@ -484,21 +488,32 @@ def sag{{name}}(SequentialDataset{{name}} dataset,
if verbose:
with gil:
print("rescaling...")
wscale = scale_weights{{name}}(
weights, wscale, n_features, n_samples, n_classes,
status = scale_weights{{name}}(
weights, &wscale, n_features, n_samples, n_classes,
sample_itr, cumulative_sums,
cumulative_sums_prox,
feature_hist,
prox, sum_gradient, n_iter)
if status == -1:
break

# Break from the n_iter outer loop if an error happened in the
# n_samples inner loop
if status == -1:
break

# we scale the weights every n_samples iterations and reset the
# just-in-time update system for numerical stability.
wscale = scale_weights{{name}}(weights, wscale, n_features, n_samples,
n_classes, n_samples - 1, cumulative_sums,
cumulative_sums_prox,
feature_hist,
prox, sum_gradient, n_iter)

status = scale_weights{{name}}(weights, &wscale, n_features,
n_samples,
n_classes, n_samples - 1,
cumulative_sums,
cumulative_sums_prox,
feature_hist,
prox, sum_gradient, n_iter)

if status == -1:
break
# check if the stopping criteria is reached
max_change = 0.0
max_weight = 0.0
@@ -520,6 +535,12 @@ def sag{{name}}(SequentialDataset{{name}} dataset,
printf('Epoch %d, change: %.8f\n', n_iter + 1,
max_change / max_weight)
n_iter += 1
# We do the error handling here, based on the error code in status, to
# avoid re-acquiring the GIL within the Cython code, which would slow down
# the computation.
if status == -1:
raise ValueError(("Floating-point under-/overflow occurred at epoch"
" #%d. Scaling input data with StandardScaler or"
" MinMaxScaler might help.") % (n_iter + 1))

if verbose and n_iter >= max_iter:
end_time = time(NULL)
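The status-based error handling above has a straightforward pure-Python analogue (schematic only; the names below are illustrative, not from the PR): the hot loop reports failures through an integer return code, and the exception is raised once, at the top level, rather than from within the nogil code.

import math

def inner_update(weights):
    # Stands in for the nogil Cython kernel: errors are signalled with a
    # C-style return code instead of a raise, which would need the GIL.
    for i, w in enumerate(weights):
        if not math.isfinite(w):
            return -1
        weights[i] = 0.9 * w
    return 0

def solver(weights, max_iter=10):
    status = 0
    n_iter = 0
    for n_iter in range(max_iter):
        status = inner_update(weights)
        if status == -1:
            break
    # Single point of error handling, outside the hot loop.
    if status == -1:
        raise ValueError("Floating-point under-/overflow occurred at "
                         "epoch #%d." % (n_iter + 1))
    return weights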
@@ -533,14 +554,15 @@ def sag{{name}}(SequentialDataset{{name}} dataset,

{{for name, c_type, np_type in get_dispatch(dtypes)}}

cdef {{c_type}} scale_weights{{name}}({{c_type}}* weights, {{c_type}} wscale, int n_features,
int n_samples, int n_classes, int sample_itr,
{{c_type}}* cumulative_sums,
{{c_type}}* cumulative_sums_prox,
int* feature_hist,
bint prox,
{{c_type}}* sum_gradient,
int n_iter) nogil:
cdef int scale_weights{{name}}({{c_type}}* weights, {{c_type}}* wscale,
int n_features,
int n_samples, int n_classes, int sample_itr,
{{c_type}}* cumulative_sums,
{{c_type}}* cumulative_sums_prox,
int* feature_hist,
bint prox,
{{c_type}}* sum_gradient,
int n_iter) nogil:
"""Scale the weights with wscale for numerical stability.

wscale = (1 - step_size * alpha) ** (n_iter * n_samples + sample_itr)
@@ -550,34 +572,37 @@ cdef {{c_type}} scale_weights{{name}}({{c_type}}* weights, {{c_type}} wscale, in
This also limits the size of `cumulative_sums`.
"""

lagged_update{{name}}(weights, wscale, n_features,
n_samples, n_classes, sample_itr + 1,
cumulative_sums,
cumulative_sums_prox,
feature_hist,
prox,
sum_gradient,
NULL,
True,
n_iter)
# reset wscale to 1.0
return 1.0
cdef int status
status = lagged_update{{name}}(weights, wscale[0], n_features,
n_samples, n_classes, sample_itr + 1,
cumulative_sums,
cumulative_sums_prox,
feature_hist,
prox,
sum_gradient,
NULL,
True,
n_iter)
# if lagged update succeeded, reset wscale to 1.0
if status == 0:
wscale[0] = 1.0
return status

{{endfor}}
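The docstring above notes that wscale can become very small; concretely, it decays geometrically with the total number of updates and eventually underflows in double precision, which is why it is reset every n_samples iterations. A quick check with illustrative numbers (not taken from the PR):

# wscale = (1 - step_size * alpha) ** (n_iter * n_samples + sample_itr)
step_size_times_alpha = 1e-4
n_updates = 10_000_000  # e.g. 100 epochs over 100,000 samples
wscale = (1.0 - step_size_times_alpha) ** n_updates
print(wscale)  # 0.0 -- underflows float64, hence the periodic reset to 1.0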


{{for name, c_type, np_type in get_dispatch(dtypes)}}

cdef void lagged_update{{name}}({{c_type}}* weights, {{c_type}} wscale, int xnnz,
int n_samples, int n_classes, int sample_itr,
{{c_type}}* cumulative_sums,
{{c_type}}* cumulative_sums_prox,
int* feature_hist,
bint prox,
{{c_type}}* sum_gradient,
int* x_ind_ptr,
bint reset,
int n_iter) nogil:
cdef int lagged_update{{name}}({{c_type}}* weights, {{c_type}} wscale, int xnnz,
int n_samples, int n_classes, int sample_itr,
{{c_type}}* cumulative_sums,
{{c_type}}* cumulative_sums_prox,
int* feature_hist,
bint prox,
{{c_type}}* sum_gradient,
int* x_ind_ptr,
bint reset,
int n_iter) nogil:
"""Hard perform the JIT updates for non-zero features of present sample.
The updates that awaits are kept in memory using cumulative_sums,
cumulative_sums_prox, wscale and feature_hist. See original SAGA paper
@@ -605,8 +630,9 @@ cdef void lagged_update{{name}}({{c_type}}* weights, {{c_type}} wscale, int xnnz
if reset:
weights[idx] *= wscale
if not skl_isfinite{{name}}(weights[idx]):
with gil:
raise_infinite_error(n_iter)
# Returning here does not require the GIL, as the return
# type is a C integer.
return -1
else:
for class_ind in range(n_classes):
idx = f_idx + class_ind
@@ -640,8 +666,7 @@ cdef void lagged_update{{name}}({{c_type}}* weights, {{c_type}} wscale, int xnnz
weights[idx] *= wscale
# check to see that the weight is not inf or NaN
if not skl_isfinite{{name}}(weights[idx]):
with gil:
raise_infinite_error(n_iter)
return -1
if reset:
feature_hist[feature_ind] = sample_itr % n_samples
else:
@@ -652,6 +677,8 @@ cdef void lagged_update{{name}}({{c_type}}* weights, {{c_type}} wscale, int xnnz
if prox:
cumulative_sums_prox[sample_itr - 1] = 0.0

return 0

{{endfor}}
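The lagged-update bookkeeping that this function implements can be seen in a toy, fully dense form (a deliberate simplification; the real kernel also tracks wscale, per-feature histories and the proximal term): as long as sum_gradient stays fixed for an untouched feature, all the updates that feature missed can be applied in one shot from a cumulative sum of step sizes.

import numpy as np

step_size, n_iter = 0.1, 5
sum_gradient = np.array([0.5, -0.2, 1.0])

# Eager: update every feature at every iteration.
w_eager = np.zeros(3)
for _ in range(n_iter):
    w_eager -= step_size * sum_gradient

# Lagged: accumulate the scalar step sizes, apply them all at once.
cumulative_sum = step_size * n_iter
w_lagged = np.zeros(3) - cumulative_sum * sum_gradient

print(np.allclose(w_eager, w_lagged))  # True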


24 changes: 23 additions & 1 deletion sklearn/linear_model/tests/test_sag.py
@@ -24,7 +24,7 @@
from sklearn.utils import compute_class_weight
from sklearn.utils import check_random_state
from sklearn.preprocessing import LabelEncoder, LabelBinarizer
from sklearn.datasets import make_blobs, load_iris
from sklearn.datasets import make_blobs, load_iris, make_classification
from sklearn.base import clone

iris = load_iris()
@@ -826,3 +826,25 @@ def test_multinomial_loss_ground_truth():
[-0.903942, +5.258745, -4.354803]])
assert_almost_equal(loss_1, loss_gt)
assert_array_almost_equal(grad_1, grad_gt)


@pytest.mark.parametrize("solver", ["sag", "saga"])
def test_sag_classifier_raises_error(solver):
# Following #13316, the sag cython function does not raise any error by
# itself if a numerical problem (under-/overflow, nans...) occurs. Instead,
# it notifies its caller, sag_solver, with a return code of -1. The caller
# is responsible for raising the error correctly.

# Train a classifier on a simple problem
rng = np.random.RandomState(42)
X, y = make_classification(random_state=rng)
clf = LogisticRegression(solver=solver, random_state=rng, warm_start=True)
clf.fit(X, y)

# Trigger a numerical error by:
# - corrupting the fitted coefficients of the classifier
# - fitting it again from its current state, thanks to warm_start
clf.coef_[:] = np.nan

with pytest.raises(ValueError, match="Floating-point under-/overflow"):
clf.fit(X, y)