
Commit fc8122a

ENH add gap safe screening rules to enet_coordinate_descent_gram (#31987)
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
1 parent: e714369

12 files changed (+264, -117 lines)
Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
+- :class:`sklearn.covariance.GraphicalLasso`,
+  :class:`sklearn.covariance.GraphicalLassoCV` and
+  :func:`sklearn.covariance.graphical_lasso` with `mode="cd"` profit from the
+  fit time performance improvement of :class:`sklearn.linear_model.Lasso` by means of
+  gap safe screening rules.
+  By :user:`Christian Lorentzen <lorentzenchr>`.
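
For context, the estimators listed in this changelog entry call the coordinate descent Lasso solver internally, so they inherit the speedup without any API change. A minimal usage sketch (synthetic data; parameter values are illustrative, not from this commit):

```python
import numpy as np
from sklearn.covariance import GraphicalLassoCV
from sklearn.datasets import make_sparse_spd_matrix

# Synthetic data from a sparse multivariate normal (illustrative only).
rng = np.random.RandomState(0)
prec = make_sparse_spd_matrix(20, alpha=0.95, random_state=rng)
X = rng.multivariate_normal(np.zeros(20), np.linalg.inv(prec), size=200)

# mode="cd" routes through enet_coordinate_descent_gram, which now applies
# gap safe screening rules; the regularization path fitted by the CV variant
# is where the speedup is most visible.
model = GraphicalLassoCV(mode="cd").fit(X)
print(model.alpha_)
```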
Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
+- Fixed uncontrollable randomness in :class:`sklearn.covariance.GraphicalLasso`,
+  :class:`sklearn.covariance.GraphicalLassoCV` and
+  :func:`sklearn.covariance.graphical_lasso`. For `mode="cd"`, they now use cyclic
+  coordinate descent. Before, it was random coordinate descent with uncontrollable
+  random number seeding.
+  By :user:`Christian Lorentzen <lorentzenchr>`.
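
To illustrate what changed (a schematic sketch, not the Cython solver): cyclic coordinate descent visits the coordinates in a fixed order, so repeated fits are reproducible, while random coordinate descent draws each coordinate index from an RNG whose seed the user previously could not control here.

```python
import numpy as np

# Illustration only: coordinate visit order in cyclic vs. random CD.
n_features, n_sweeps = 5, 2
rng = np.random.RandomState(0)  # the old code path seeded this uncontrollably

cyclic_order = [j for _ in range(n_sweeps) for j in range(n_features)]
random_order = [rng.randint(n_features) for _ in range(n_sweeps * n_features)]

print("cyclic:", cyclic_order)  # always [0, 1, 2, 3, 4, 0, 1, 2, 3, 4]
print("random:", random_order)  # depends on the RNG state
```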
Lines changed: 11 additions & 0 deletions
@@ -0,0 +1,11 @@
+- :class:`sklearn.decomposition.DictionaryLearning` and
+  :class:`sklearn.decomposition.MiniBatchDictionaryLearning` with `fit_algorithm="cd"`,
+  :class:`sklearn.decomposition.SparseCoder` with `transform_algorithm="lasso_cd"`,
+  :class:`sklearn.decomposition.MiniBatchSparsePCA`,
+  :class:`sklearn.decomposition.SparsePCA`,
+  :func:`sklearn.decomposition.dict_learning` and
+  :func:`sklearn.decomposition.dict_learning_online` with `method="cd"`,
+  :func:`sklearn.decomposition.sparse_encode` with `algorithm="lasso_cd"`
+  all profit from the fit time performance improvement of
+  :class:`sklearn.linear_model.Lasso` by means of gap safe screening rules.
+  By :user:`Christian Lorentzen <lorentzenchr>`.
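
These code paths all funnel into the same coordinate descent Lasso solver. A minimal, illustrative call through one of them (shapes and `alpha` chosen for demonstration only):

```python
import numpy as np
from sklearn.decomposition import sparse_encode

rng = np.random.RandomState(0)
dictionary = rng.randn(8, 20)  # (n_components, n_features)
X = rng.randn(5, 20)           # (n_samples, n_features)

# algorithm="lasso_cd" dispatches to the coordinate descent Lasso solver,
# which now uses gap safe screening rules.
codes = sparse_encode(X, dictionary, algorithm="lasso_cd", alpha=0.1)
print(codes.shape)  # (5, 8)
```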

doc/whats_new/upcoming_changes/sklearn.linear_model/32014.efficiency.rst

Lines changed: 3 additions & 4 deletions
@@ -3,12 +3,11 @@
   :class:`linear_model.MultiTaskElasticNetCV`, :class:`linear_model.MultiTaskLassoCV`
   as well as
   :func:`linear_model.lasso_path` and :func:`linear_model.enet_path` now implement
-  gap safe screening rules in the coordinate descent solver for dense `X` (with
-  `precompute=False` or `"auto"` with `n_samples < n_features`) and sparse `X`
-  (always).
+  gap safe screening rules in the coordinate descent solver for dense and sparse `X`.
   The speedup of fitting time is particularly pronounced (10-times is possible) when
   computing regularization paths like the \*CV-variants of the above estimators do.
   There is now an additional check of the stopping criterion before entering the main
   loop of descent steps. As the stopping criterion requires the computation of the dual
   gap, the screening happens whenever the dual gap is computed.
-  By :user:`Christian Lorentzen <lorentzenchr>` :pr:`31882`, :pr:`31986` and
+  By :user:`Christian Lorentzen <lorentzenchr>` :pr:`31882`, :pr:`31986`,
+  :pr:`31987` and
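
To make the rule concrete, here is a schematic NumPy sketch of a gap safe screening test for the Lasso in the `1/2 * ||y - X w||^2 + alpha * ||w||_1` parametrization, following Ndiaye et al. (2017). It only illustrates the idea; scikit-learn's Cython implementation uses the `1/(2 * n_samples)` scaling and operates in place.

```python
import numpy as np

def gap_safe_screen(X, y, w, alpha):
    """Return a boolean mask of features that can provably be discarded."""
    residual = y - X @ w
    # A dual feasible point: rescale the residual so that the dual
    # constraint max_j |x_j^T theta| <= 1 holds.
    theta = residual / max(alpha, np.max(np.abs(X.T @ residual)))
    # Duality gap = primal objective minus dual objective.
    primal = 0.5 * residual @ residual + alpha * np.abs(w).sum()
    dual = 0.5 * y @ y - 0.5 * np.sum((y - alpha * theta) ** 2)
    gap = max(primal - dual, 0.0)
    # The dual optimum lies in a ball of this radius around theta; a feature
    # whose worst-case correlation over that ball stays below 1 has a zero
    # coefficient at the optimum and can be removed from the problem.
    radius = np.sqrt(2.0 * gap) / alpha
    return np.abs(X.T @ theta) + radius * np.linalg.norm(X, axis=0) < 1.0
```

Because the test reuses the duality gap, running it exactly where the solver already computes the gap for its stopping criterion adds little overhead, which is the coupling the entry above describes.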

sklearn/covariance/_graph_lasso.py

Lines changed: 17 additions & 10 deletions
@@ -138,16 +138,23 @@ def _graphical_lasso(
                     / (precision_[idx, idx] + 1000 * eps)
                 )
                 coefs, _, _, _ = cd_fast.enet_coordinate_descent_gram(
-                    coefs,
-                    alpha,
-                    0,
-                    sub_covariance,
-                    row,
-                    row,
-                    max_iter,
-                    enet_tol,
-                    check_random_state(None),
-                    False,
+                    w=coefs,
+                    alpha=alpha,
+                    beta=0,
+                    Q=sub_covariance,
+                    q=row,
+                    y=row,
+                    # TODO: It is not ideal that the max_iter of the outer
+                    # solver (graphical lasso) is coupled with the max_iter of
+                    # the inner solver (CD). Ideally, CD has its own parameter
+                    # enet_max_iter (like enet_tol). A minimum of 20 is rather
+                    # arbitrary, but not unreasonable.
+                    max_iter=max(20, max_iter),
+                    tol=enet_tol,
+                    rng=check_random_state(None),
+                    random=False,
+                    positive=False,
+                    do_screening=True,
                 )
             else:  # mode == "lars"
                 _, _, coefs = lars_path_gram(
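
At the user level, the parameters touched here map onto `graphical_lasso` as follows: `tol` and `max_iter` govern the outer graphical lasso loop, while `enet_tol` (and now an inner iteration floor of 20) governs the inner coordinate descent solver. A small illustrative call (data and values made up):

```python
import numpy as np
from sklearn.covariance import empirical_covariance, graphical_lasso

rng = np.random.RandomState(0)
X = rng.randn(60, 5)
emp_cov = empirical_covariance(X)

# mode="cd" is the path that calls enet_coordinate_descent_gram shown above.
cov, prec = graphical_lasso(
    emp_cov, alpha=0.2, mode="cd", tol=1e-6, enet_tol=1e-8, max_iter=100
)
print(np.round(prec, 2))
```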

sklearn/covariance/tests/test_graphical_lasso.py

Lines changed: 35 additions & 23 deletions
@@ -25,16 +25,12 @@
 )


-def test_graphical_lassos(random_state=1):
-    """Test the graphical lasso solvers.
-
-    This checks is unstable for some random seeds where the covariance found with "cd"
-    and "lars" solvers are different (4 cases / 100 tries).
-    """
+def test_graphical_lassos(global_random_seed):
+    """Test the graphical lasso solvers."""
     # Sample data from a sparse multivariate normal
-    dim = 20
+    dim = 10
     n_samples = 100
-    random_state = check_random_state(random_state)
+    random_state = check_random_state(global_random_seed)
     prec = make_sparse_spd_matrix(dim, alpha=0.95, random_state=random_state)
     cov = linalg.inv(prec)
     X = random_state.multivariate_normal(np.zeros(dim), cov, size=n_samples)
@@ -45,24 +41,29 @@ def test_graphical_lassos(random_state=1):
         icovs = dict()
         for method in ("cd", "lars"):
             cov_, icov_, costs = graphical_lasso(
-                emp_cov, return_costs=True, alpha=alpha, mode=method
+                emp_cov,
+                return_costs=True,
+                alpha=alpha,
+                mode=method,
+                tol=1e-7,
+                enet_tol=1e-11,
+                max_iter=100,
             )
             covs[method] = cov_
             icovs[method] = icov_
             costs, dual_gap = np.array(costs).T
             # Check that the costs always decrease (doesn't hold if alpha == 0)
             if not alpha == 0:
-                # use 1e-12 since the cost can be exactly 0
-                assert_array_less(np.diff(costs), 1e-12)
+                # use 1e-10 since the cost can be exactly 0
+                assert_array_less(np.diff(costs), 1e-10)
         # Check that the 2 approaches give similar results
-        assert_allclose(covs["cd"], covs["lars"], atol=5e-4)
-        assert_allclose(icovs["cd"], icovs["lars"], atol=5e-4)
+        assert_allclose(covs["cd"], covs["lars"], atol=1e-3)
+        assert_allclose(icovs["cd"], icovs["lars"], atol=1e-3)

     # Smoke test the estimator
-    model = GraphicalLasso(alpha=0.25).fit(X)
+    model = GraphicalLasso(alpha=0.25, tol=1e-7, enet_tol=1e-11, max_iter=100).fit(X)
     model.score(X)
-    assert_array_almost_equal(model.covariance_, covs["cd"], decimal=4)
-    assert_array_almost_equal(model.covariance_, covs["lars"], decimal=4)
+    assert_allclose(model.covariance_, covs["cd"], rtol=1e-6)

     # For a centered matrix, assume_centered could be chosen True or False
     # Check that this returns indeed the same result for centered data
@@ -87,6 +88,7 @@ def test_graphical_lasso_when_alpha_equals_0(global_random_seed):


 @pytest.mark.parametrize("mode", ["cd", "lars"])
+@pytest.mark.filterwarnings("ignore::sklearn.exceptions.ConvergenceWarning")
 def test_graphical_lasso_n_iter(mode):
     X, _ = datasets.make_classification(n_samples=5_000, n_features=20, random_state=0)
     emp_cov = empirical_covariance(X)
@@ -138,12 +140,25 @@ def test_graph_lasso_2D():
     assert_array_almost_equal(icov, icov_skggm)


-def test_graphical_lasso_iris_singular():
+@pytest.mark.parametrize("method", ["cd", "lars"])
+def test_graphical_lasso_iris_singular(method):
     # Small subset of rows to test the rank-deficient case
     # Need to choose samples such that none of the variances are zero
     indices = np.arange(10, 13)

     # Hard-coded solution from R glasso package for alpha=0.01
+    # library(glasso)
+    # X = t(array(c(
+    #     5.4, 3.7, 1.5, 0.2,
+    #     4.8, 3.4, 1.6, 0.2,
+    #     4.8, 3. , 1.4, 0.1),
+    #     dim = c(4, 3)
+    # ))
+    # n = nrow(X)
+    # emp_cov = cov(X) * (n - 1)/n  # without Bessel correction
+    # sol = glasso(emp_cov, 0.01, penalize.diagonal = FALSE)
+    # # print cov_R
+    # print(noquote(format(sol$w, scientific=FALSE, digits = 10)))
     cov_R = np.array(
         [
             [0.08, 0.056666662595, 0.00229729713223, 0.00153153142149],
@@ -162,12 +177,9 @@ def test_graphical_lasso_iris_singular():
     )
     X = datasets.load_iris().data[indices, :]
     emp_cov = empirical_covariance(X)
-    for method in ("cd", "lars"):
-        cov, icov = graphical_lasso(
-            emp_cov, alpha=0.01, return_costs=False, mode=method
-        )
-        assert_array_almost_equal(cov, cov_R, decimal=5)
-        assert_array_almost_equal(icov, icov_R, decimal=5)
+    cov, icov = graphical_lasso(emp_cov, alpha=0.01, return_costs=False, mode=method)
+    assert_allclose(cov, cov_R, atol=1e-6)
+    assert_allclose(icov, icov_R, atol=1e-5)


 def test_graphical_lasso_cv(global_random_seed):

sklearn/decomposition/_dict_learning.py

Lines changed: 1 addition & 0 deletions
@@ -146,6 +146,7 @@ def _sparse_encode_precomputed(
             alpha=alpha,
             fit_intercept=False,
             precompute=gram,
+            tol=1e-8,  # TODO: This parameter should be exposed.
             max_iter=max_iter,
             warm_start=True,
             positive=positive,
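
For reference, a simplified sketch of the Lasso configuration that `_sparse_encode_precomputed` sets up after this change (shapes and `alpha` are illustrative; details of the real code path, such as the alpha rescaling and `check_input` handling, are omitted here):

```python
import numpy as np
from sklearn.linear_model import Lasso

rng = np.random.RandomState(0)
dictionary = rng.randn(8, 20)     # (n_components, n_features)
x = rng.randn(20)                 # one signal to encode
gram = dictionary @ dictionary.T  # precomputed Gram matrix

# Simplified stand-in for the inner solver configuration; tol=1e-8 is the
# hard-coded value from the diff (the TODO says it should become a
# parameter), everything else here is illustrative.
clf = Lasso(
    alpha=0.1,
    fit_intercept=False,
    precompute=gram,
    tol=1e-8,
    max_iter=1000,
    warm_start=True,
    positive=False,
)
clf.fit(dictionary.T, x)
code = clf.coef_  # sparse code of x w.r.t. the dictionary rows
print(np.count_nonzero(code))
```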

sklearn/decomposition/tests/test_dict_learning.py

Lines changed: 4 additions & 4 deletions
@@ -89,7 +89,7 @@ def ricker_matrix(width, resolution, n_components):
         return D

     transform_algorithm = "lasso_cd"
-    resolution = 1024
+    resolution = 256
     subsampling = 3  # subsampling factor
     n_components = resolution // subsampling

@@ -99,7 +99,7 @@ def ricker_matrix(width, resolution, n_components):
             ricker_matrix(
                 width=w, resolution=resolution, n_components=n_components // 5
             )
-            for w in (10, 50, 100, 500, 1000)
+            for w in (10, 50, 100, 500)
         )
     ]

@@ -120,7 +120,7 @@ def ricker_matrix(width, resolution, n_components):
     with warnings.catch_warnings():
         warnings.simplefilter("error", ConvergenceWarning)
         model = SparseCoder(
-            D_multi, transform_algorithm=transform_algorithm, transform_max_iter=2000
+            D_multi, transform_algorithm=transform_algorithm, transform_max_iter=500
        )
         model.fit_transform(X)

@@ -864,7 +864,7 @@ def test_dict_learning_dtype_match(data_type, expected_type, method):
 @pytest.mark.parametrize("method", ("lars", "cd"))
 def test_dict_learning_numerical_consistency(method):
     # verify numerically consistent among np.float32 and np.float64
-    rtol = 1e-6
+    rtol = 1e-4
     n_components = 4
     alpha = 2

sklearn/decomposition/tests/test_sparse_pca.py

Lines changed: 2 additions & 2 deletions
@@ -71,7 +71,7 @@ def test_fit_transform(global_random_seed):
         n_components=3, method="cd", random_state=global_random_seed, alpha=alpha
     )
     spca_lasso.fit(Y)
-    assert_array_almost_equal(spca_lasso.components_, spca_lars.components_)
+    assert_allclose(spca_lasso.components_, spca_lars.components_, rtol=5e-4)


 # TODO: remove mark once loky bug is fixed:
@@ -117,7 +117,7 @@ def test_fit_transform_tall(global_random_seed):
     U1 = spca_lars.fit_transform(Y)
     spca_lasso = SparsePCA(n_components=3, method="cd", random_state=rng)
     U2 = spca_lasso.fit(Y).transform(Y)
-    assert_array_almost_equal(U1, U2)
+    assert_allclose(U1, U2, rtol=1e-4, atol=1e-5)


 def test_initialization(global_random_seed):
