/
block_separable.py
384 lines (306 loc) · 13 KB
/
block_separable.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
import numpy as np
from numpy.linalg import norm
from numba import float64, int32
from numba.types import bool_
from skglm.penalties.base import BasePenalty
from skglm.utils.prox_funcs import (
BST, prox_block_2_05, prox_SCAD, value_SCAD, prox_MCP, value_MCP)
class L2_1(BasePenalty):
    """Row-wise L2/1 penalty.

    For a coefficient matrix ``W`` the penalty is ``alpha`` times the sum of
    the Euclidean norms of the rows of ``W``.

    Parameters
    ----------
    alpha : float
        Regularization strength.
    """

    def __init__(self, alpha):
        self.alpha = alpha

    def get_spec(self):
        return (
            ('alpha', float64),
        )

    def params_to_dict(self):
        return {'alpha': self.alpha}

    def value(self, W):
        """Return ``alpha * sum_j ||W_j||_2`` over the rows ``W_j`` of ``W``."""
        row_norms = np.sqrt(np.sum(W ** 2, axis=1))
        return self.alpha * row_norms.sum()

    def prox_1feat(self, value, stepsize, j):
        """Block soft-threshold the row ``value`` at level ``alpha * stepsize``.

        ``j`` (the row index) is not used by this penalty.
        """
        threshold = self.alpha * stepsize
        return BST(value, threshold)

    def subdiff_distance(self, W, grad, ws):
        """Distance of ``-grad`` to the subdifferential, for each row in ``ws``.

        ``grad[pos]`` is the gradient row matching row ``ws[pos]`` of ``W``.
        """
        dist = np.zeros_like(ws, dtype=grad.dtype)
        for pos in range(len(ws)):
            row = ws[pos]
            grad_row = grad[pos, :]
            if np.any(W[row, :]):
                # Nonzero row: the subdifferential is the single point
                # alpha * W_j / ||W_j||.
                W_row = W[row, :]
                dist[pos] = norm(grad_row + self.alpha * W_row / norm(W_row))
            else:
                # Zero row: distance of -grad_j to the ball of radius alpha.
                dist[pos] = max(0, norm(grad_row) - self.alpha)
        return dist

    def is_penalized(self, n_features):
        """Boolean mask marking every feature as penalized."""
        return np.ones(n_features, dtype=np.bool_)
class L2_05(BasePenalty):
    """Row-wise L2/0.5 penalty.

    For a coefficient matrix ``W`` the penalty is ``alpha`` times the sum,
    over rows, of the square root of each row's Euclidean norm.

    Parameters
    ----------
    alpha : float
        Regularization strength.
    """

    def __init__(self, alpha):
        self.alpha = alpha

    def get_spec(self):
        return (
            ('alpha', float64),
        )

    def params_to_dict(self):
        return {'alpha': self.alpha}

    def value(self, W):
        """Return ``alpha * sum_j sqrt(||W_j||_2)`` over the rows of ``W``."""
        n_rows = W.shape[0]
        row_norms = np.zeros(n_rows)
        for row in range(n_rows):
            row_norms[row] = norm(W[row])
        return self.alpha * np.sum(np.sqrt(row_norms))

    def prox_1feat(self, value, stepsize, j):
        """Apply the block L2/0.5 prox to ``value`` at level ``alpha * stepsize``.

        ``j`` (the row index) is not used by this penalty.
        """
        level = self.alpha * stepsize
        return prox_block_2_05(value, level)

    def subdiff_distance(self, W, grad, ws):
        """Distance of ``-grad`` to the subdifferential, for each row in ``ws``.

        ``grad[pos]`` is the gradient row matching row ``ws[pos]`` of ``W``.
        """
        dist = np.zeros_like(ws, dtype=grad.dtype)
        for pos in range(len(ws)):
            W_row = W[ws[pos], :]
            if not np.any(W_row):
                # Zero row: the distance is reported as 0. The slope of
                # sqrt(||.||) is unbounded at the origin.
                continue
            row_norm = norm(W_row)
            # Gradient of alpha * ||w||^(1/2) at a nonzero row.
            pen_grad = self.alpha * W_row / (2 * row_norm ** (3. / 2.))
            dist[pos] = norm(grad[pos, :] + pen_grad)
        return dist

    def is_penalized(self, n_features):
        """Boolean mask marking every feature as penalized."""
        return np.ones(n_features, dtype=np.bool_)
class BlockMCPenalty(BasePenalty):
    """Block Minimax Concave Penalty.

    Parameters
    ----------
    alpha : float
        Regularization strength.
    gamma : float
        Non-convexity parameter: rows whose norm exceeds ``gamma * alpha``
        contribute the constant ``gamma * alpha ** 2 / 2``.

    Notes
    -----
    With W_j the j-th row of W, the penalty is::

        pen(||W_j||) = alpha * ||W_j|| - ||W_j||^2 / (2 * gamma)
                           if ||W_j|| <= gamma * alpha
                     = gamma * alpha ** 2 / 2
                           if ||W_j|| > gamma * alpha

        value = sum_{j=1}^{n_features} pen(||W_j||)
    """

    def __init__(self, alpha, gamma):
        self.alpha = alpha
        self.gamma = gamma

    def get_spec(self):
        spec = (
            ('alpha', float64),
            ('gamma', float64),
        )
        return spec

    def params_to_dict(self):
        return dict(alpha=self.alpha,
                    gamma=self.gamma)

    def value(self, W):
        """Compute the value of BlockMCP at W (MCP applied to row norms)."""
        norm_rows = np.sqrt(np.sum(W ** 2, axis=1))
        return value_MCP(norm_rows, self.alpha, self.gamma)

    def prox_1feat(self, value, stepsize, j):
        """Compute the proximal operator of BlockMCP for one row.

        The scalar MCP prox is applied to ``||value||``. Its result rescales
        the direction ``value / ||value||``.

        Parameters
        ----------
        value : ndarray
            Row the prox is applied to.
        stepsize : float
            Step size scaling the penalty.
        j : int
            Row index (unused, kept for a uniform interface).

        Returns
        -------
        ndarray
            The prox of ``value``. This is a zero vector when ``value`` is zero.
        """
        norm_value = norm(value)
        if norm_value == 0.:
            # The direction value / ||value|| is undefined here, and the
            # rescaling below would give 0 / 0 = NaN. The prox of 0 is 0.
            return np.zeros_like(value)
        prox = prox_MCP(norm_value, stepsize, self.alpha, self.gamma)
        return prox * value / norm_value

    def subdiff_distance(self, W, grad, ws):
        """Distance of ``-grad`` to the subdifferential, for each row in ``ws``.

        ``grad[idx]`` is the gradient row matching row ``ws[idx]`` of ``W``.
        """
        subdiff_dist = np.zeros_like(ws, dtype=grad.dtype)
        for idx, j in enumerate(ws):
            norm_Wj = norm(W[j])
            if not np.any(W[j]):
                # distance of -grad_j to alpha * unit ball
                norm_grad_j = norm(grad[idx])
                subdiff_dist[idx] = max(0, norm_grad_j - self.alpha)
            elif norm_Wj < self.alpha * self.gamma:
                # distance of -grad_j to alpha * W[j] / ||W_j|| - W[j] / gamma
                subdiff_dist[idx] = norm(
                    grad[idx] + self.alpha * W[j] / norm_Wj - W[j] / self.gamma)
            else:
                # flat region of the penalty: distance of -grad_j to 0
                subdiff_dist[idx] = norm(grad[idx])
        return subdiff_dist

    def is_penalized(self, n_features):
        """Return a binary mask with the penalized features."""
        return np.ones(n_features, dtype=np.bool_)
class BlockSCAD(BasePenalty):
    r"""Block Smoothly Clipped Absolute Deviation.

    Parameters
    ----------
    alpha : float
        Regularization strength.
    gamma : float
        Shape parameter: rows whose norm exceeds ``alpha * gamma`` contribute
        a constant. The middle branch of the penalty divides by ``gamma - 1``.

    Notes
    -----
    With :math:`W_j` the j-th row of :math:`W`, the penalty is:

    .. math::
        "pen"(||W_j||) = {
        (alpha ||W_j|| , if \ \ \ \ \ \ \ \ \ \ ||W_j|| <= alpha),
        ((2 alpha gamma ||W_j|| - ||W_j||^2 - alpha^2) / (2 (gamma - 1))
        , if alpha \ \ < ||W_j|| <= alpha gamma),
        ((alpha^2 (gamma + 1)) / 2, if alpha gamma < ||W_j||)
        :}

    .. math::
        "value" = sum_(j=1)^(n_"features") "pen"(||W_j||)
    """

    def __init__(self, alpha, gamma):
        self.alpha = alpha
        self.gamma = gamma

    def get_spec(self):
        spec = (
            ('alpha', float64),
            ('gamma', float64),
        )
        return spec

    def params_to_dict(self):
        return dict(alpha=self.alpha,
                    gamma=self.gamma)

    def value(self, W):
        """Compute the value of the SCAD penalty at W (SCAD of row norms)."""
        norm_rows = np.sqrt(np.sum(W ** 2, axis=1))
        return value_SCAD(norm_rows, self.alpha, self.gamma)

    def prox_1feat(self, value, stepsize, j):
        """Compute the proximal operator of BlockSCAD for one row.

        The scalar SCAD prox is applied to ``||value||``. Its result rescales
        the direction ``value / ||value||``.

        Parameters
        ----------
        value : ndarray
            Row the prox is applied to.
        stepsize : float
            Step size scaling the penalty.
        j : int
            Row index (unused, kept for a uniform interface).

        Returns
        -------
        ndarray
            The prox of ``value``. This is a zero vector when ``value`` is zero.
        """
        norm_value = norm(value)
        if norm_value == 0.:
            # The direction value / ||value|| is undefined here, and the
            # rescaling below would give 0 / 0 = NaN. The prox of 0 is 0.
            return np.zeros_like(value)
        prox = prox_SCAD(norm_value, stepsize, self.alpha, self.gamma)
        return prox * value / norm_value

    def subdiff_distance(self, W, grad, ws):
        """Distance of ``-grad`` to the subdifferential, for each row in ``ws``.

        ``grad[idx]`` is the gradient row matching row ``ws[idx]`` of ``W``.
        """
        subdiff_dist = np.zeros_like(ws, dtype=grad.dtype)
        for idx, j in enumerate(ws):
            norm_Wj = norm(W[j])
            if not np.any(W[j]):
                # distance of -grad_j to alpha * unit_ball
                subdiff_dist[idx] = max(0, norm(grad[idx]) - self.alpha)
            elif norm_Wj <= self.alpha:
                # distance of -grad_j to alpha * W[j] / ||W[j]||
                subdiff_dist[idx] = norm(grad[idx] + self.alpha * W[j] / norm_Wj)
            elif norm_Wj <= self.gamma * self.alpha:
                # distance of -grad_j to (alpha * gamma - ||W[j]||)
                # / ((gamma - 1) * ||W[j]||) * W[j]
                subdiff_dist[idx] = norm(grad[idx] + (
                    (self.alpha * self.gamma - norm_Wj)
                    / (norm_Wj * (self.gamma - 1))
                ) * W[j])
            else:
                # flat region of the penalty: distance of -grad_j to 0
                subdiff_dist[idx] = norm(grad[idx])
        return subdiff_dist

    def is_penalized(self, n_features):
        """Return a binary mask with the penalized features."""
        return np.ones(n_features, dtype=np.bool_)
class WeightedGroupL2(BasePenalty):
    r"""Weighted Group L2 penalty.

    The penalty reads

    .. math::
        alpha sum_{g=1}^{n_"groups"} "weights"_g xx ||w_{[g]}||

    with :math:`w_{[g]}` being the coefficients of the g-th group.
    When ``positive=True``, it reads

    .. math::
        alpha sum_{g=1}^{n_"groups"} "weights"_g xx ||w_{[g]}|| + i_{w_{[g]} \geq 0}

    where :math:`i_{w_{[g]} \geq 0}` is the indicator function of the
    non-negative orthant.

    Refer to :ref:`prox_nn_group_lasso` for details on the derivation of the proximal
    operator and the distance to subdifferential.

    Attributes
    ----------
    alpha : float
        The regularization parameter.

    weights : array, shape (n_groups,)
        The weights of the groups.

    grp_indices : array, shape (n_features,)
        The group indices stacked contiguously
        ([grp1_indices, grp2_indices, ...]).

    grp_ptr : array, shape (n_groups + 1,)
        The group pointers such that two consecutive elements delimit
        the indices of a group in ``grp_indices``.

    positive : bool, optional
        When set to ``True``, forces the coefficient vector to be non-negative.
    """

    def __init__(self, alpha, weights, grp_ptr, grp_indices, positive=False):
        self.alpha, self.weights = alpha, weights
        self.grp_ptr, self.grp_indices = grp_ptr, grp_indices
        self.positive = positive

    def get_spec(self):
        spec = (
            ('alpha', float64),
            ('weights', float64[:]),
            ('grp_ptr', int32[:]),
            ('grp_indices', int32[:]),
            ('positive', bool_),
        )
        return spec

    def params_to_dict(self):
        return dict(alpha=self.alpha, weights=self.weights,
                    grp_ptr=self.grp_ptr, grp_indices=self.grp_indices,
                    positive=self.positive)

    def value(self, w):
        """Value of penalty at vector ``w``.

        Returns ``np.inf`` when ``positive=True`` and ``w`` has a negative
        entry, because the indicator term is infinite there.
        """
        if self.positive and np.any(w < 0):
            return np.inf

        alpha, weights = self.alpha, self.weights
        grp_ptr, grp_indices = self.grp_ptr, self.grp_indices
        n_grp = len(grp_ptr) - 1

        sum_weighted_L2 = 0.
        for g in range(n_grp):
            grp_g_indices = grp_indices[grp_ptr[g]: grp_ptr[g+1]]
            w_g = w[grp_g_indices]

            sum_weighted_L2 += alpha * weights[g] * norm(w_g)

        return sum_weighted_L2

    def prox_1group(self, value, stepsize, g):
        """Compute the proximal operator of group ``g``.

        This is block soft-thresholding at level ``alpha * stepsize * weights[g]``,
        with ``positive`` forwarded to ``BST``.
        """
        return BST(
            value, self.alpha * stepsize * self.weights[g], positive=self.positive)

    def subdiff_distance(self, w, grad_ws, ws):
        """Compute distance to the subdifferential at ``w`` of negative gradient.

        Refer to :ref:`subdiff_positive_group_lasso` for details of the derivation.

        Notes
        -----
        ``grad_ws`` is a stacked array of gradients ``[grad_ws_1, grad_ws_2, ...]``,
        one block per group of ``ws``, in the order of ``ws``.
        """
        alpha, weights = self.alpha, self.weights
        grp_ptr, grp_indices = self.grp_ptr, self.grp_indices

        scores = np.zeros(len(ws))
        grad_ptr = 0
        for idx, g in enumerate(ws):
            grp_g_indices = grp_indices[grp_ptr[g]: grp_ptr[g+1]]

            # Gradient blocks are stacked contiguously in the order of ``ws``.
            grad_g = grad_ws[grad_ptr: grad_ptr + len(grp_g_indices)]
            grad_ptr += len(grp_g_indices)

            w_g = w[grp_g_indices]
            norm_w_g = norm(w_g)

            if self.positive:
                if norm_w_g == 0:
                    # Subdifferential at 0: ball of radius alpha * weights[g]
                    # plus the non-positive orthant. Only the positive part of
                    # -grad_g (entries where grad_g < 0) counts toward the distance.
                    neg_grad_g = grad_g[grad_g < 0.]
                    scores[idx] = max(0,
                                      norm(neg_grad_g) - alpha * weights[g])
                elif np.any(w_g < 0):
                    # Outside the feasible set: the subdifferential is empty.
                    scores[idx] = np.inf
                else:
                    res = np.zeros_like(grad_g)
                    for j in range(len(w_g)):
                        thresh = alpha * weights[g] * w_g[j] / norm_w_g
                        if w_g[j] > 0:
                            res[j] = -grad_g[j] - thresh
                        else:
                            # thresh is 0, we simplify the expression
                            res[j] = max(-grad_g[j], 0)
                    scores[idx] = norm(res)
            else:
                if norm_w_g == 0:
                    # distance of -grad_g to the ball of radius alpha * weights[g]
                    scores[idx] = max(0, norm(grad_g) - alpha * weights[g])
                else:
                    # distance of -grad_g to the subdiff (here a singleton)
                    subdiff = alpha * weights[g] * w_g / norm_w_g
                    scores[idx] = norm(grad_g + subdiff)

        return scores

    def is_penalized(self, n_groups):
        """Return a boolean mask marking every group as penalized."""
        return np.ones(n_groups, dtype=np.bool_)

    def generalized_support(self, w):
        """Return a boolean mask of groups in the generalized support.

        A group is in the support if it is unpenalized or has a nonzero
        coefficient in ``w``.
        """
        grp_indices, grp_ptr = self.grp_indices, self.grp_ptr
        n_groups = len(grp_ptr) - 1
        is_penalized = self.is_penalized(n_groups)

        gsupp = np.zeros(n_groups, dtype=np.bool_)
        for g in range(n_groups):
            if not is_penalized[g]:
                gsupp[g] = True
                continue

            grp_g_indices = grp_indices[grp_ptr[g]: grp_ptr[g+1]]
            if np.any(w[grp_g_indices]):
                gsupp[g] = True

        return gsupp