The idea of this notebook is to provide one plot showing that greedy faasta (combining Gale and Jingwei's ideas) is still slower than Condat for an analysis reconstruction problem with undecimated wavelets.

In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib nbagg

# Third party import
import matplotlib.pyplot as plt
import numpy as np
import scipy.io
from tqdm import tqdm_notebook

# Package import
from modopt.math.metrics import ssim
from modopt.opt.linear import Identity
from modopt.opt.proximity import SparseThreshold, LinearCompositionProx, LinearCompositionIterativeProx
from mri.numerics.fourier import FFT2, NFFT
from mri.numerics.gradient import GradAnalysis2
from mri.numerics.linear import Wavelet2, WaveletUD
from mri.numerics.reconstruct import sparse_rec_fista, sparse_rec_condatvu
from mri.numerics.utils import convert_mask_to_locations
import pysap
from pysap.data import get_sample_data

np.random.seed(0)


                 .|'''|       /.\      '||'''|,
                 ||          // \\      ||   ||
'||''|, '||  ||` `|'''|,    //...\\     ||...|'
 ||  ||  `|..||   .   ||   //     \\    ||
 ||..|'      ||   |...|' .//       \\. .||
 ||       ,  |'
.||        ''

Package version: 0.0.3

License: CeCILL-B

Authors: 

Antoine Grigis <antoine.grigis@cea.fr>
Samuel Farrens <samuel.farrens@cea.fr>
Jean-Luc Starck <jl.stark@cea.fr>
Philippe Ciuciu <philippe.ciuciu@cea.fr>

Dependencies: 

scipy          : >=0.18.0  - required | 1.3.0     installed
numpy          : >=1.11.0  - required | 1.16.3    installed
matplotlib     : >=2.0.0   - required | 3.1.0     installed
future         : >=0.16.0  - required | 0.17.1    installed
astropy        : ==2.0.8   - required | 2.0.8     installed
nibabel        : >=2.1.0   - required | 2.4.0     installed
pyqtgraph      : >=0.10.0  - required | 0.10.0    installed
progressbar2   : >=3.34.3  - required | ?         installed
modopt         : >=1.1.5   - requi

#  Loading input data

In [2]:
image = get_sample_data("mri-slice-nifti")
original_image_data = np.copy(image.data)
image.data += np.random.randn(*image.shape) * 20.
image.data = image.data.astype(np.complex128)
mask = get_sample_data("mri-mask")

In [3]:
non_cartesian = True
sparkling = True

# Generate the kspace 
From the 2D brain slice and the acquistion mask, we generate the acquisition measurments, the observed kspace. We then reconstruct the zero order solution.

In [4]:
# potential path to sparkling trajectory
sparkling_traj_file_path = '../personal_experiments/2019-Mar-01_N512_nc34_ns3073_OS1_decim64_decay2_tau0.75_nrevol1/samples_SPARKLING_N512_nc34x3073_OS1.mat'

In [5]:
# Generate the subsampled kspace
if non_cartesian:
    if sparkling:
        kspace_loc = scipy.io.loadmat(sparkling_traj_file_path)['samples']
    else:
        kspace_loc = convert_mask_to_locations(mask.data)
    fourier_op = NFFT(samples=kspace_loc, shape=image.shape)
    kspace_data = fourier_op.op(image.data)

    # Zero order solution
    image_rec0 = pysap.Image(data=fourier_op.adj_op(kspace_data), metadata=image.metadata)
else:
    kspace_mask = np.fft.ifftshift(mask.data)
    kspace_loc = convert_mask_to_locations(kspace_mask)
    fourier_op = FFT2(samples=kspace_loc, shape=image.shape)
    kspace_data = fourier_op.op(image.data)

    # Zero order solution
    image_rec0 = pysap.Image(data=fourier_op.adj_op(kspace_data), metadata=image.metadata)



# Operators

In [6]:
nb_scales = 4
wavelet_id = 2

In [7]:
linear_op = WaveletUD(
    nb_scale=nb_scales,
    wavelet_id=wavelet_id,
    coarse=False,
    set_norm=1.1,
)

gradient_op = GradAnalysis2(
    data=kspace_data,
    fourier_op=fourier_op)

prox_op_iterative = LinearCompositionIterativeProx(
    linear_op=linear_op,
    prox_op=SparseThreshold(Identity(), None, thresh_type="soft"),
    max_precision_level=150,
    solver_sigma=10.0,
)

prox_op_condat = SparseThreshold(linear_op, None, thresh_type="soft")

cost_op = None



# FAASTA optimization
We now want to refine the zero order solution using a FAASTA optimization.

In [8]:
def objective_cost(x):
    return data_fidelity(x) + sparsity(x)

def sparsity(x):
    return mu * np.sum(np.abs(linear_op.op(x)))

def data_fidelity(x):
    return 0.5 * np.linalg.norm(fourier_op.op(x) - kspace_data)**2

In [9]:
def nrmse(x):
    return np.linalg.norm(x - original_image_data) / np.mean(original_image_data)

In [10]:
metrics_ = {
    "cost": {"metric": objective_cost, "mapping": {"x_new": "x"}, "cst_kwargs": {}, "early_stopping": False},
    "nrmse": {"metric": nrmse, "mapping": {"x_new": "x"}, "cst_kwargs": {}, "early_stopping": False},
}

In [11]:
mu = 15.0

In [12]:
max_iter = 1000

In [13]:
# Start the FISTA reconstruction
cost_op = None

_, _, _, metrics_iterative = sparse_rec_fista(
    gradient_op=gradient_op,
    linear_op=Identity(),
    prox_op=prox_op_iterative,
    cost_op=cost_op,
    xi_restart=0.96,
    s_greedy=1.1,
    mu=mu,
    restart_strategy='greedy',
    pov='analysis',
    iterative_prox=True,
    max_nb_of_iter=max_iter,
    metrics=metrics_,
    metric_call_period=1,
    verbose=1,
    adaptative_precision=True,
    precision_increase_rate=1.01,
    initial_precision_level=5,
)


  _____             ____     _____      _
 |" ___|    ___    / __"| u |_ " _| U  /"\  u
U| |_  u   |_"_|  <\___ \/    | |    \/ _ \/
\|  _|/     | |    u___) |   /| |\   / ___ \\
 |_|      U/| |\u  |____/>> u |_|U  /_/   \_\\
 )(\\\,-.-,_|___|_,-.)(  (__)_// \\\_  \\\    >>
(__)(_/ \_)-' '-(_/(__)    (__) (__)(__)  (__)
    
 - mu:  15.0
 - lipschitz constant:  95.5108226538341
 - data:  (512, 512)
 - max iterations:  1000
 - image variable shape:  (512, 512)
 - alpha variable shape:  (512, 512)
----------------------------------------
Starting optimization...






























100% (1000 of 1000) |####################| Elapsed Time: 0:42:27 Time:  0:42:27


 - converged:  False
Done.
Execution time:  21170.337991  seconds
----------------------------------------


In [14]:
# Start the Condat reconstruction
cost_op = None
_, _, _, metrics_condat = sparse_rec_condatvu(
    gradient_op,
    linear_op,
    prox_op_condat,
    cost_op,
    std_est_method='dual',
    std_est=0.1,
    mu=mu,
    sigma=10,
    nb_of_reweights=0,
    max_nb_of_iter=max_iter,
    metrics=metrics_,
    metric_call_period=1,
    verbose=1,
)


   ____   U  ___ u  _   _    ____       _       _____      __     __    _   _
U /"___|   \/"_ \/ | \ |"|  |  _"\  U  /"\  u  |_ " _|     \ \   /"/uU |"|u| |
\| | u     | | | |<|  \| |>/| | | |  \/ _ \/     | |        \ \ / //  \| |\| |
 | |/__.-,_| |_| |U| |\  |uU| |_| |\ / ___ \    /| |\       /\ V /_,-. | |_| |
  \____|\_)-\___/  |_| \_|  |____/ u/_/   \_\  u |_|U      U  \_/-(_/ <<\___/
 _// \\      \\    ||   \\,-.|||_    \\    >>  _// \\_       //      (__) )(
(__)(__)    (__)   (_")  (_/(__)_)  (__)  (__)(__) (__)     (__)         (__)
    
 - mu:  15.0
 - lipschitz constant:  95.5108226538341
 - tau:  0.016706927204478663
 - sigma:  10
 - rho:  1.0
 - std:  0.1
 - 1/tau - sigma||L||^2 >= beta/2:  True
 - data:  (512, 512)
 - wavelet:  <mri.reconstruct.linear.WaveletUD object at 0x7f0e3ad220f0> - 4
 - max iterations:  1000
 - number of reweights:  0
 - primal variable shape:  (512, 512)
 - dual variable shape:  (3, 512, 512)
----------------------------------------
Starting opti

100% (1000 of 1000) |####################| Elapsed Time: 0:12:44 Time:  0:12:44


 - converged:  False
Done.
Execution time:  16185.141667  seconds
----------------------------------------


In [15]:
opt_metrics = {metric_name: metric_val['values'][-1] for metric_name, metric_val in metrics_condat.items()}

In [20]:
metric_name = 'cost'
plt.figure(figsize=(9, 5))
plt.plot(
#     np.log10(np.abs(np.array(metrics_condat[metric_name]['values'] - opt_metrics[metric_name]))), 
    np.log10(np.abs(np.array(metrics_condat[metric_name]['values']))), 
    label='Condat', 
)
plt.plot(
#     np.log10(np.abs(np.array(metrics_iterative[metric_name]['values'] - opt_metrics[metric_name]))), 
    np.log10(np.abs(np.array(metrics_iterative[metric_name]['values']))), 
    label='greedy FAASTA', 
)
plt.legend()

<IPython.core.display.Javascript object>

<matplotlib.legend.Legend at 0x7f0f20394e10>

In [18]:
metrics_iterative['nrmse']['values'][-1]

155.29523659835667