-
Notifications
You must be signed in to change notification settings - Fork 1
/
proxs.py
executable file
·534 lines (430 loc) · 17.6 KB
/
proxs.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
# -*- coding: utf-8 -*-
r"""PROXIMAL OPERATORS.
Defines proximal operators to be fed to ModOpt algorithm that are
specific to MCCD (or rather, not currently in ``modopt.opt.proximity``).
: Authors: Tobias Liaudat <tobiasliaudat@gmail.com>,
Morgan Schmitz <github @MorganSchmitz>
"""
from __future__ import absolute_import, print_function
import numpy as np
from modopt.signal.wavelet import filter_convolve
from modopt.opt.proximity import ProximityParent
import mccd.utils as utils
import tensorflow as tf
from . import saving_unets as unet_model
from . import saving_learnlets as learnlet_model
from mccd.denoising.learnlets.learnlet_model import Learnlet
from sklearn.model_selection import train_test_split
from tensorflow.keras.optimizers import Adam
class LinRecombine(ProximityParent):
    r"""Multiply eigenvectors ``S`` and (factorized) weights ``A``.

    Maintain the knowledge about the linear operator norm which is
    calculated as the spectral norm (highest eigenvalue of the matrix).
    The recombination is done with ``S`` living in the transformed domain.

    Parameters
    ----------
    A: numpy.ndarray
        Matrix defining the linear operator.
    filters: numpy.ndarray
        Filters used by the wavelet transform.
    compute_norm: bool
        Computation of the matrix spectral radius in the initialization.
    """

    def __init__(self, A, filters, compute_norm=False):
        r"""Initialize class attributes."""
        self.A = A
        self.op = self.recombine
        self.adj_op = self.adj_rec
        self.filters = filters
        if compute_norm:
            self._update_norm()

    def _update_norm(self):
        r"""Set ``self.norm`` to the spectral norm of ``A``.

        The spectral norm is the square root of the largest singular
        value of ``A A^T``.
        """
        # compute_uv=False avoids building the (unused) U and Vt factors.
        s = np.linalg.svd(self.A.dot(self.A.T), compute_uv=False)
        self.norm = np.sqrt(s[0])

    def recombine(self, transf_S):
        r"""Recombine new S and return it.

        Each transformed eigenvector is brought back to the direct domain
        through ``filter_convolve`` before applying ``A``.
        """
        S = np.array([
            filter_convolve(transf_Sj, self.filters, filter_rot=True)
            for transf_Sj in transf_S
        ])
        return utils.rca_format(S).dot(self.A)

    def adj_rec(self, Y):
        r"""Return the adjoint operator of ``recombine``."""
        return utils.apply_transform(Y.dot(self.A.T), self.filters)

    def update_A(self, new_A, update_norm=True):
        r"""Update the ``A`` matrix.

        Also recalculate the operator norm of ``A`` unless
        ``update_norm`` is False.
        """
        self.A = new_A
        if update_norm:
            self._update_norm()
class LinRecombineNoFilters(ProximityParent):
    r"""Multiply eigenvectors ``S`` and (factorized) weights ``A``.

    Maintain the knowledge about the linear operator norm which is
    calculated as the spectral norm (highest eigenvalue of the matrix).
    The recombination is done with ``S`` living in the direct domain.

    Parameters
    ----------
    A: numpy.ndarray
        Matrix defining the linear operator.
    compute_norm: bool
        Computation of the matrix spectral radius in the initialization.
    """

    def __init__(self, A, compute_norm=False):
        r"""Initialize class attributes."""
        self.A = A
        self.op = self.recombine
        self.adj_op = self.adj_rec
        if compute_norm:
            self._update_norm()

    def _update_norm(self):
        r"""Set ``self.norm`` to the spectral norm of ``A``.

        The spectral norm is the square root of the largest singular
        value of ``A A^T``.
        """
        # compute_uv=False avoids building the (unused) U and Vt factors.
        s = np.linalg.svd(self.A.dot(self.A.T), compute_uv=False)
        self.norm = np.sqrt(s[0])

    def recombine(self, S):
        r"""Recombine new S and return it."""
        return S.dot(self.A)

    def adj_rec(self, Y):
        r"""Return the adjoint operator of ``recombine``."""
        return Y.dot(self.A.T)

    def update_A(self, new_A, update_norm=True):
        r"""Update the ``A`` matrix.

        Also recalculate the operator norm of ``A`` unless
        ``update_norm`` is False.
        """
        self.A = new_A
        if update_norm:
            self._update_norm()
class KThreshold(ProximityParent):
    r"""Linewise hard-thresholding operator with variable thresholds.

    Parameters
    ----------
    iter_func: function
        Callable that, given the current iteration number and the line
        length, returns how many non-zero entries to keep in each line.
    """

    def __init__(self, iter_func):
        r"""Initialize class attributes."""
        self.iter_func = iter_func
        self.iter = 0

    def reset_iter(self):
        r"""Reset the iteration counter to zero."""
        self.iter = 0

    def op(self, data, extra_factor=1.0):
        r"""Threshold ``data`` line by line and return the result."""
        self.iter += 1
        n_keep = self.iter_func(self.iter, data.shape[1])
        return utils.lineskthresholding(data, n_keep)

    def cost(self, x):
        r"""Return the cost.

        (Indicator of :math:`\Omega` is either 0 or infinity).
        """
        return 0
class StarletThreshold(ProximityParent):
    r"""Soft thresholding in the wavelet (default Starlet) domain.

    Parameters
    ----------
    threshold: numpy.ndarray
        Threshold levels.
    thresh_type: str
        Whether soft- or hard-thresholding should be used.
        Default is ``'soft'``.
    """

    def __init__(self, threshold, thresh_type='soft'):
        r"""Initialize class attributes."""
        self.threshold = threshold
        self._thresh_type = thresh_type

    def update_threshold(self, new_threshold, new_thresh_type=None):
        r"""Update the starlet threshold (and optionally the type).

        An unrecognized ``new_thresh_type`` is silently ignored.
        """
        self.threshold = new_threshold
        if new_thresh_type in ('soft', 'hard'):
            self._thresh_type = new_thresh_type

    def op(self, transf_data, **kwargs):
        r"""Threshold the wavelet coefficients in place and return them."""
        # Every scale is thresholded except the coarse (last) one.
        fine_scales = transf_data[:, :-1]
        transf_data[:, :-1] = utils.SoftThresholding(
            fine_scales, self.threshold[:, :-1])
        return transf_data

    def cost(self, x, y):
        r"""Return cost."""
        return 0
class Learnlets(ProximityParent):
    r"""Apply Learnlets denoising.

    Parameters
    ----------
    model: str
        Which denoising algorithm to use.

    Notes
    -----
    We couldn't save the whole architecture of the model, thus we use the
    weights of the model. However, this requires a first step of
    initialization that we didn't need for the U-Nets.
    """
    def __init__(self, items=None):
        r"""Initialize class attributes.

        Builds a ``Learnlet`` network, runs one dummy fit step so that the
        layer weights get created, then overwrites them with the
        pre-trained checkpoint shipped with the package.
        """
        # Fixed stamp size expected by the denoiser.
        self.im_shape = (51,51)
        # Calculate window function for estimating the noise
        # We couldn't use Galsim to estimate the moments, so we chose to
        # work with the real center of the image (25.5,25.5) instead of
        # using the real centroid. Also, we use 13 instead of 5*obs_sigma,
        # so that we are sure to cut all the flux from the star.
        # ``noise_window`` is True outside a radius-13 disk around the
        # image center: only background pixels enter the noise estimate.
        self.noise_window = np.ones(self.im_shape, dtype=bool)
        for coord_x in range(self.im_shape[0]):
            for coord_y in range(self.im_shape[1]):
                if np.sqrt((coord_x - 25.5)**2 + (coord_y - 25.5)**2) <= 13 :
                    self.noise_window[coord_x, coord_y] = False
        # Random dummy batch (2 images + 2 noise levels) used only to
        # build the model weights via a single training step.
        im_val = tf.convert_to_tensor(np.random.rand(2, self.im_shape[0], self.im_shape[1], 1))
        std_val = tf.convert_to_tensor(np.random.rand(2))
        # Hyper-parameters; these must match the pre-trained checkpoint
        # loaded below, otherwise the weight shapes will not line up.
        run_params = {
            'denoising_activation': 'dynamic_soft_thresholding',
            'learnlet_analysis_kwargs':{
                'n_tiling': 256,
                'mixing_details': False,
                'skip_connection': True,
            },
            'learnlet_synthesis_kwargs': {
                'res': True,
            },
            'threshold_kwargs':{
                'noise_std_norm': True,
            },
            # 'wav_type': 'bior',
            'n_scales': 5,
            'n_reweights_learn': 1,
            'clip': False,
        }
        learnlets = Learnlet(**run_params)
        learnlets.compile(
            # NOTE(review): ``lr`` is a deprecated alias of
            # ``learning_rate`` in recent TF/Keras versions — confirm the
            # pinned TF version still accepts it.
            optimizer=Adam(lr=1e-3),
            loss='mse',
        )
        # One dummy epoch/step only: this builds the weights, the result
        # of the training itself is immediately discarded.
        learnlets.fit(
            (im_val, std_val),
            im_val,
            validation_data = ((im_val, std_val), im_val),
            steps_per_epoch = 1,
            epochs = 1,
            batch_size=12,
        )
        learnlets.load_weights(learnlet_model.__path__[0] + '/cp.h5')
        self.model = learnlets
        # Per-image noise estimates, filled in by ``op``.
        self.noise = None
    def mad(self, x):
        r"""Compute an estimation of the standard deviation
        of a Gaussian distribution using the robust
        MAD (Median Absolute Deviation) estimator."""
        # 1.4826 is the consistency constant relating MAD to the Gaussian
        # standard deviation.
        return 1.4826*np.median(np.abs(x - np.median(x)))
    def noise_estimator(self, image):
        r"""Estimate the noise level of the image.

        The MAD estimator is applied only to the pixels selected by
        ``noise_window`` (background, away from the star's flux).
        """
        # Calculate noise std dev
        return self.mad(image[self.noise_window])
    def convert_and_pad(self, image):
        r"""Convert the image stack to ``(n, h, w, 1)`` tensors for the model.

        Padding to 64x64 is currently disabled (see commented code).
        """
        image = tf.reshape(
            tf.convert_to_tensor(image),
            [np.shape(image)[0], np.shape(image)[1], np.shape(image)[2], 1]
        )
        # pad = tf.constant([[0,0], [6,7],[6,7], [0,0]])
        # return tf.pad(image, pad, "CONSTANT")
        return image
    def crop_and_convert(self, image):
        r"""Drop the channel axis (back to ``(n, 51, 51)``) and convert to
        a numpy array. Cropping is currently disabled (see commented code)."""
        #image = tf.reshape(tf.image.crop_to_bounding_box(image, 6, 6, 51, 51), [np.shape(image)[0], 51, 51])
        image = tf.reshape(image, [np.shape(image)[0], 51, 51])
        return image.numpy()
    def op(self, image, **kwargs):
        r"""Apply Learnlets denoising to a stack of images.

        Images whose total flux is non-positive are sign-flipped before
        denoising and flipped back afterwards, so the network always sees
        positive-flux stamps.
        """
        image = utils.reg_format(image)
        # +1 for images with positive total flux, -1 otherwise.
        multiple = np.array([np.sum(image[i,:,:])>0 for i in np.arange(len(image))]) * 2. - 1.
        image *= multiple.reshape((-1, 1, 1))
        # Per-image noise std estimate, fed to the model alongside the
        # images (the Learnlet thresholds are noise-level dependent).
        self.noise = np.array([self.noise_estimator(image[_i,:,:]) for _i in np.arange(len(image))])
        self.noise = tf.reshape(tf.convert_to_tensor(self.noise), [len(image), 1])
        image = self.convert_and_pad(image)
        image = self.model.predict((image, self.noise))
        # Undo the sign flip on the denoised output.
        image = tf.math.multiply(multiple.reshape((-1, 1, 1, 1)), image)
        return utils.rca_format(self.crop_and_convert(image))
    def cost(self, x, y):
        r"""Return cost."""
        return 0
class Unets(ProximityParent):
    r"""Apply Unets denoising.

    Parameters
    ----------
    model: str
        Which denoising algorithm to use.
    """

    def __init__(self, items=None):
        r"""Load the pre-trained U-Net model from the package directory."""
        self.model = tf.keras.models.load_model(unet_model.__path__[0])

    def convert_and_pad(self, image):
        r"""Reshape the image stack into ``(n, h, w, 1)`` tensors for the model."""
        shape = np.shape(image)
        as_tensor = tf.convert_to_tensor(image)
        return tf.reshape(as_tensor, [shape[0], shape[1], shape[2], 1])

    def crop_and_convert(self, image):
        r"""Drop the channel axis (back to ``(n, 51, 51)``) and convert to numpy."""
        squeezed = tf.reshape(image, [np.shape(image)[0], 51, 51])
        return squeezed.numpy()

    def op(self, image, **kwargs):
        r"""Apply Unets denoising to a stack of images.

        Images whose total flux is non-positive are sign-flipped before
        denoising and flipped back afterwards.
        """
        image = utils.reg_format(image)
        # +1 for images with positive total flux, -1 otherwise.
        signs = np.array(
            [np.sum(image[idx, :, :]) > 0 for idx in np.arange(len(image))]
        ) * 2. - 1.
        image *= signs.reshape((-1, 1, 1))
        image = self.convert_and_pad(image)
        image = self.model.predict(image)
        # Undo the sign flip on the denoised output.
        image = tf.math.multiply(signs.reshape((-1, 1, 1, 1)), image)
        return utils.rca_format(self.crop_and_convert(image))

    def cost(self, x, y):
        r"""Return cost."""
        return 0
class proxNormalization(ProximityParent):
    r"""Normalize rows or columns of :math:`x` relatively to L2 norm.

    Parameters
    ----------
    type: str
        String defining the axis to normalize. If is ``lines`` or
        ``columns``. Default is ``columns``.

    Notes
    -----
    The normalization is currently disabled because it was constraining
    the model too strongly, so ``normalize`` is an identity operator.
    The ``type`` attribute is kept for interface compatibility.
    """

    def __init__(self, type='columns'):
        r"""Initialize class attributes."""
        self.op = self.normalize
        self.type = type

    def normalize(self, x, extra_factor=1.0):
        r"""Return ``x`` unchanged.

        The actual per-line / per-column L2 normalization has been
        deliberately disabled (it constrained the model too strongly).
        """
        return x

    def cost(self, x):
        r"""Return cost."""
        return 0
class PositityOff(ProximityParent):
    r"""Project to the positive subset, taking into account an offset.

    The projection keeps entries of ``x`` above ``-offset`` and clips the
    rest to ``-offset``, so that ``x + offset >= 0`` element-wise.

    Parameters
    ----------
    offset: numpy.ndarray
        Offset values, broadcastable against the input of ``op``.
    """

    def __init__(self, offset):
        r"""Initialize class attributes."""
        self.offset = offset
        self.op = self.off_positive_part

    def update_offset(self, new_offset):
        r"""Update the offset value."""
        self.offset = new_offset

    def off_positive_part(self, x, extra_factor=1.0):
        r"""Perform the projection accounting for the offset.

        Equivalent to the original boolean-mask implementation: a single
        vectorized element-wise maximum of ``x`` and ``-offset``.
        """
        return np.maximum(x, -self.offset)

    def cost(self, x):
        r"""Return cost."""
        return 0
class LinRecombineAlpha(ProximityParent):
    r"""Compute alpha recombination.

    Multiply alpha and VT/Pi matrices (in this function named M) and
    compute the operator norm.

    Parameters
    ----------
    M: numpy.ndarray
        Matrix defining the linear operator.
    """

    def __init__(self, M):
        r"""Initialize class attributes."""
        self.M = M
        self.op = self.recombine
        self.adj_op = self.adj_rec
        # Spectral norm of M: sqrt of the largest singular value of M M^T.
        # compute_uv=False avoids building the (unused) U and Vt factors.
        s = np.linalg.svd(self.M.dot(self.M.T), compute_uv=False)
        self.norm = np.sqrt(s[0])

    def recombine(self, x):
        r"""Return recombination."""
        return x.dot(self.M)

    def adj_rec(self, y):
        r"""Return adjoint recombination."""
        return y.dot(self.M.T)
class GMCAlikeProxL1(ProximityParent):
    """Classic l1 prox with GMCA-like decreasing weighting values.

    GMCA stands for Generalized Morphological Component Analysis.

    Parameters
    ----------
    iter_func: function
        Input function that calculates the number of non-zero values to
        keep in each line at each iteration.

    Notes
    -----
    Not being used by the MCCD algorithm for the moment.
    """

    def __init__(self, iter_func, kmax):
        r"""Initialize class attributes."""
        self.iter_func = iter_func
        self.iter = 0
        self.iter_max = kmax

    def reset_iter(self):
        r"""Reset the iteration counter to zero."""
        self.iter = 0

    def op(self, data, extra_factor=1.0):
        r"""Return input data after thresholding."""
        self.iter += 1
        return self.op_tobi_prox_l1(data, self.iter, self.iter_max)

    def op_tobi_prox_l1(self, mat, k, kmax):
        r"""Apply GMCA-style hard thresholding to each line of ``mat``.

        The per-line threshold is the magnitude of the entry whose rank
        (by absolute value) decreases linearly with the iteration number
        ``k``, floored at 20% of the line length.
        """
        mat_out = np.copy(mat)
        n_lines = mat.shape[0]
        for row in range(n_lines):
            line = mat_out[row, :]
            # Fraction used to pick the threshold rank shrinks linearly
            # with k, never going below 0.2.
            keep_frac = np.max([0.9 - (k / kmax) * 3, 0.2])
            rank = np.floor(len(line) * keep_frac).astype(int)
            thresh = abs(line[np.argsort(abs(line))[rank]])
            # NOTE(review): the comparison below is on signed values, so
            # every negative entry is zeroed regardless of magnitude —
            # confirm this asymmetry is intended.
            mat_out[row, :] = self.HardThresholding(mat_out[row, :], thresh)
        return mat_out

    @staticmethod
    def HardThresholding(data, thresh):
        r"""Zero out, in place, every entry of ``data`` below ``thresh``."""
        data[data < thresh] = 0.
        return data

    def cost(self, x):
        r"""Cost function. To do."""
        return 0
class ClassicProxL2(ProximityParent):
    r"""Classic l2 proximal operator.

    Notes
    -----
    ``prox_weights``: weights of the weighted norm :math:`l_{w,2}`.
    They are set by default to ones. Not being used in this
    implementation.
    ``beta_param``: the beta (or lambda) parameter of the function the
    prox is calculated on, :math:`prox_{\lambda f(.)}(y)`.
    ``iter``: iteration counter, kept to track the iterations. It could
    be part of the lambda update strategy for the prox calculation.

    Reference: « Mixed-norm estimates for the M/EEG inverse problem using
    accelerated gradient methods
    Alexandre Gramfort, Matthieu Kowalski, Matti Hämäläinen »
    """

    def __init__(self):
        r"""Initialize class attributes."""
        self.beta_param = 0
        self.iter = 0

    def set_beta_param(self, beta_param):
        r"""Set ``beta_param``."""
        self.beta_param = beta_param

    def reset_iter(self):
        """Reset the iteration counter to zero."""
        self.iter = 0

    def op(self, data, extra_factor=1.0):
        r"""Return input data after the l2 prox.

        The ``beta_param`` attribute plays the role of the extra factor;
        set it before using this proximal operator.
        """
        # Counter kept for bookkeeping only; it does not affect the prox.
        self.iter += 1
        return self.op_tobi_prox_l2(data)

    def op_tobi_prox_l2(self, data):
        r"""Apply the operator to the whole data matrix at once.

        For a vector: :math:`x = prox_{\lambda || . ||^{2}_{w,2}}(y)`
        :math:`\Rightarrow x_i = y_i /(1 + \lambda w_i)`.
        """
        scale = 1. + self.beta_param
        return data / scale

    def cost(self, x):
        r"""Cost function. To do."""
        return 0