In [1]:
# Move up to the repository root (see the printed path below), presumably so
# the local modules imported next (`fourier`, `utils`) can be found.
# NOTE: not idempotent — re-running this cell moves up one more directory.
%cd ..

/home/zaccharie/workspace/fastmri-reproducible-benchmark


In [2]:
%load_ext autoreload
%autoreload 2
%matplotlib nbagg
import glob

import h5py
import matplotlib.pyplot as plt
import numpy as np

from fourier import FFT2
from utils import crop_center, gen_mask

from modopt.opt.proximity import SparseThreshold
from mri.numerics.gradient import GradAnalysis2
from mri.numerics.linear import WaveletUD
from mri.numerics.reconstruct import sparse_rec_condatvu


                 .|'''|       /.\      '||'''|,
                 ||          // \\      ||   ||
'||''|, '||  ||` `|'''|,    //...\\     ||...|'
 ||  ||  `|..||   .   ||   //     \\    ||
 ||..|'      ||   |...|' .//       \\. .||
 ||       ,  |'
.||        ''

Package version: 0.0.3

License: CeCILL-B

Authors: 

Antoine Grigis <antoine.grigis@cea.fr>
Samuel Farrens <samuel.farrens@cea.fr>
Jean-Luc Starck <jl.stark@cea.fr>
Philippe Ciuciu <philippe.ciuciu@cea.fr>

Dependencies: 

scipy          : >=0.18.0  - required | 1.3.0     installed
numpy          : >=1.11.0  - required | 1.16.4    installed
matplotlib     : >=2.0.0   - required | 3.1.0     installed
future         : >=0.16.0  - required | 0.17.1    installed
astropy        : ==2.0.8   - required | 2.0.8     installed
nibabel        : >=2.1.0   - required | 2.4.1     installed
pyqtgraph      : >=0.10.0  - required | 0.10.0    installed
progressbar2   : >=3.34.3  - required | ?         installed
modopt         : >=1.1.5   - requi

In [3]:
# Default figure size and a grayscale colormap for all images shown below.
plt.rcParams.update({
    'figure.figsize': (9, 5),
    'image.cmap': 'gray',
})

In [4]:
# Fourier operator with a single-element mask np.array([1]); presumably FFT2
# broadcasts it so nothing is masked (TODO confirm against FFT2). Used below
# to take the adjoint of the fully sampled k-space.
fourier_op = FFT2(np.array([1]))

# Loading input data

In [5]:
# Directory of the single-coil validation files, relative to the repo root.
val_path = '../singlecoil_val/'
# glob returns paths in arbitrary, filesystem-dependent order; sort them so
# that `filenames[i]` selects the same file on every machine and every run.
filenames = sorted(glob.glob(val_path + '*'))

In [6]:
def from_file_to_mask_and_kspace(filename):
    """Load the target images and the raw k-space stored in an HDF5 file.

    Despite its name, this function does not return a mask.

    Parameters
    ----------
    filename: str
        Path to an HDF5 file containing 'reconstruction_esc' and 'kspace'
        datasets.

    Returns
    -------
    image: np.ndarray
        Full content of the 'reconstruction_esc' dataset (indexed by slice
        on its first axis in the cells below).
    kspace: np.ndarray
        Full content of the 'kspace' dataset, same slice indexing.
    """
    # Open read-only: older h5py versions defaulted to append mode, which can
    # modify or even create the file. The `with` block releases the handle;
    # the `[()]` reads return in-memory arrays that outlive the file.
    with h5py.File(filename, 'r') as h5_obj:
        image = h5_obj['reconstruction_esc'][()]
        kspace = h5_obj['kspace'][()]
    return image, kspace

In [7]:
# Pick one validation volume from the listing above and load its arrays.
file_0 = filenames[5]
images, kspaces = from_file_to_mask_and_kspace(filename=file_0)

In [8]:
# Work on a single slice: its target image and its raw k-space.
slice_selected = 10
image, kspace = images[slice_selected], kspaces[slice_selected]

In [9]:
# Size of the raw k-space grid of the selected slice.
np.shape(kspace)

(640, 368)

In [10]:
# Size of the target image: smaller than the k-space grid, hence the crops below.
np.shape(image)

(320, 320)

# Visualizing input data

In [11]:
# Target image of the selected slice (the 'reconstruction_esc' dataset).
plt.figure()
plt.imshow(image)
plt.title('Target image (reconstruction_esc), slice {}'.format(slice_selected))

<IPython.core.display.Javascript object>

<matplotlib.image.AxesImage at 0x7f62565762b0>

In [12]:
# Magnitude of the fully sampled k-space of the selected slice.
plt.figure()
plt.imshow(np.abs(kspace))
plt.title('|k-space|, slice {}'.format(slice_selected))

<IPython.core.display.Javascript object>

<matplotlib.image.AxesImage at 0x7f62564ba470>

In [13]:
# Inverse FFT of the fully sampled k-space. The k-space is presumably stored
# with low frequencies in the centre (the original ifftshift of the output
# relied on that too). So it is un-centred before the IFFT and the image is
# re-centred after. For even grid sizes the magnitude equals
# |ifftshift(ifft2(kspace))|; for odd sizes this avoids a one-pixel offset.
plt.figure()
plt.imshow(np.abs(np.fft.fftshift(np.fft.ifft2(np.fft.ifftshift(kspace)))))
plt.title('Inverse FFT of the full k-space')

<IPython.core.display.Javascript object>

<matplotlib.image.AxesImage at 0x7f62564787b8>

In [14]:
# Magnitude of the adjoint Fourier transform of the full k-space, cropped to
# the 320x320 size of the target image.
reco = crop_center(np.abs(fourier_op.adj_op(kspace)), 320)

In [15]:
# Adjoint reconstruction from the full k-space (cropped); compare with the target.
plt.figure()
plt.imshow(reco)
plt.title('Adjoint FFT of the full k-space (cropped)')

<IPython.core.display.Javascript object>

<matplotlib.image.AxesImage at 0x7f62564b1b38>

# Mask generation

In [16]:
# Generate the undersampling mask for the requested acceleration; the last
# line shows the effective acceleration (mask entries / sampled entries).
accel_factor = 8
mask = gen_mask(kspace, accel_factor=accel_factor)
mask.shape[0] / np.sum(mask)

8.177777777777777

# Create the operators

In [17]:
# Settings of the sparsifying wavelet operator: number of scales and the
# wavelet identifier passed to WaveletUD.
nb_scales, wavelet_id = 4, 24

In [18]:
# Broadcast the 1-D sampling mask to the full k-space grid by repeating it
# along the first axis. Builtin `float` replaces `np.float`, a deprecated
# alias of it that NumPy 1.24 removed.
fourier_mask = np.repeat(mask.astype(float)[None, :], kspace.shape[0], axis=0)
fourier_op_masked = FFT2(mask=fourier_mask)
# Retrospectively undersampled k-space: unsampled entries are zeroed.
masked_kspace = kspace * fourier_mask

# Sparsifying wavelet operator (UD presumably = undecimated).
linear_op = WaveletUD(
    nb_scale=nb_scales,
    wavelet_id=wavelet_id,
    set_norm=1.1,
)

# Gradient operator of the data term, built on the masked k-space and the
# masked Fourier operator.
gradient_op = GradAnalysis2(
    data=masked_kspace,
    fourier_op=fourier_op_masked)

# Define the proximity dual/primal operator
prox_op = SparseThreshold(linear_op, None, thresh_type="soft")
cost_op = None



In [19]:
# Sampling mask over the k-space grid (sampled entries are 1).
plt.figure()
plt.imshow(np.abs(fourier_mask))
plt.title('Sampling mask (acceleration factor {})'.format(accel_factor))

<IPython.core.display.Javascript object>

<matplotlib.image.AxesImage at 0x7f62563f7978>

In [20]:
def objective_cost(x):
    """Full objective value: data fidelity plus weighted wavelet sparsity."""
    fidelity_term = data_fidelity(x)
    sparsity_term = sparsity(x)
    return fidelity_term + sparsity_term

def sparsity(x):
    """Weighted l1 norm of the wavelet coefficients of `x`.

    The weight is the global `mu`, looked up at call time (it is defined in a
    later cell).
    """
    coefficients = linear_op.op(x)
    return mu * np.sum(np.abs(coefficients))

def data_fidelity(x):
    """Half squared l2 distance between F_masked(x) and the masked k-space."""
    residual = fourier_op_masked.op(x) - masked_kspace
    return 0.5 * np.linalg.norm(residual) ** 2

In [21]:
def nrmse(x):
    """Normalised root-mean-square error of `x` against the global `image`.

    `x` is cropped to 320x320 with `crop_center` and its magnitude is compared
    to `image`. The RMSE is divided by the mean of `image` ("mean"
    normalisation).

    Returns
    -------
    float
        sqrt(mean((|crop(x)| - image)**2)) / mean(image).
    """
    im = crop_center(x, 320)
    # Use the RMSE, not np.linalg.norm: the plain l2 norm sums over all pixels
    # and inflates the result by sqrt(number of pixels) (320 here).
    rmse = np.sqrt(np.mean((np.abs(im) - image) ** 2))
    return rmse / np.mean(image)

def psnr(x):
    """Peak signal-to-noise ratio (in dB) of `x` against the global `image`.

    `x` is cropped to 320x320 with `crop_center` and its magnitude is compared
    to `image`; the peak value is the maximum of `image`.

    Returns
    -------
    float
        10 * log10(max(image)**2 / MSE), or +inf for a perfect match
        (returned explicitly instead of dividing by zero).
    """
    im = crop_center(x, 320)
    mse = np.mean((np.abs(im) - image) ** 2)
    if mse == 0:
        return np.inf
    return 10 * np.log10(np.max(image) ** 2 / mse)

In [22]:
# Metrics handed to the solver. Each one maps the solver's `x_new` onto the
# function's `x` argument (presumably modopt's convention), with no extra
# kwargs and no early stopping.
metrics_ = {
    metric_name: {
        "metric": metric_fn,
        "mapping": {"x_new": "x"},
        "cst_kwargs": {},
        "early_stopping": False,
    }
    for metric_name, metric_fn in (
        ("cost", objective_cost),
        ("nrmse", nrmse),
        ("psnr", psnr),
    )
}

In [23]:
# Weight of the l1 wavelet-sparsity term. It is read as a global by
# `sparsity` (and so by the "cost" metric) and passed to the solver as `mu`.
mu = 1e-9

In [24]:
# Number of Condat-Vu iterations (the run below reports `converged: False`).
max_iter = 10

In [25]:
# Start the Condat reconstruction
cost_op = None  # no cost operator; progress is tracked through `metrics_`
x_final, _, _, metrics = sparse_rec_condatvu(
    gradient_op, linear_op, prox_op, cost_op,
    # regularisation weight and step parameter (checked in the printed
    # `1/tau - sigma||L||^2 >= beta/2` condition)
    mu=mu,
    sigma=10,
    # fixed noise level, no estimation method
    std_est_method=None,
    std_est=0.1,
    # iteration budget, no reweighting
    max_nb_of_iter=max_iter,
    nb_of_reweights=0,
    # metric tracking and logging
    metrics=metrics_,
    metric_call_period=1,
    verbose=1,
)


   ____   U  ___ u  _   _    ____       _       _____      __     __    _   _
U /"___|   \/"_ \/ | \ |"|  |  _"\  U  /"\  u  |_ " _|     \ \   /"/uU |"|u| |
\| | u     | | | |<|  \| |>/| | | |  \/ _ \/     | |        \ \ / //  \| |\| |
 | |/__.-,_| |_| |U| |\  |uU| |_| |\ / ___ \    /| |\       /\ V /_,-. | |_| |
  \____|\_)-\___/  |_| \_|  |____/ u/_/   \_\  u |_|U      U  \_/-(_/ <<\___/
 _// \\      \\    ||   \\,-.|||_    \\    >>  _// \\_       //      (__) )(
(__)(__)    (__)   (_")  (_/(__)_)  (__)  (__)(__) (__)     (__)         (__)
    
 - mu:  1e-09
 - lipschitz constant:  1.1
 - tau:  0.07905138333671825
 - sigma:  10
 - rho:  1.0
 - std:  0.1
 - 1/tau - sigma||L||^2 >= beta/2:  True
 - data:  (640, 368)
 - wavelet:  <mri.reconstruct.linear.WaveletUD object at 0x7f6256449400> - 4
 - max iterations:  10
 - number of reweights:  0
 - primal variable shape:  (640, 368)
 - dual variable shape:  (10, 640, 368)
----------------------------------------
Starting optimization...


100% (10 of 10) |########################| Elapsed Time: 0:00:26 Time:  0:00:26


 - converged:  False
Done.
Execution time:  35.570751  seconds
----------------------------------------


In [26]:
# Magnitude of the Condat-Vu reconstruction, titled with its last tracked PSNR.
plt.figure()
plt.imshow(np.abs(x_final))
plt.title('Condat-Vu reconstruction (final PSNR {:.2f} dB)'.format(metrics['psnr']['values'][-1]))

<IPython.core.display.Javascript object>

<matplotlib.image.AxesImage at 0x7f625130dbe0>

In [27]:
# Objective value recorded at each metric call (metric_call_period=1).
plt.figure()
plt.plot(np.array(metrics['cost']['values']))
plt.xlabel('Iteration')
plt.ylabel('Objective cost')
plt.title('Final cost {}'.format(metrics['cost']['values'][-1]))

<IPython.core.display.Javascript object>

Text(0.5, 1.0, 'Final cost 1.2525651996368082e-05')

In [28]:
# PSNR against the target image, recorded at each metric call.
plt.figure()
plt.plot(np.array(metrics['psnr']['values']))
plt.xlabel('Iteration')
plt.ylabel('PSNR (dB)')
plt.title('Final psnr {}'.format(metrics['psnr']['values'][-1]))

<IPython.core.display.Javascript object>

Text(0.5, 1.0, 'Final psnr 21.309001025858592')

In [29]:
# Zero-filled baseline: adjoint of the masked operator on the masked k-space.
plt.figure()
plt.imshow(np.abs(fourier_op_masked.adj_op(masked_kspace)))
plt.title('Zero-filled reconstruction (adjoint of masked k-space)')

<IPython.core.display.Javascript object>

<matplotlib.image.AxesImage at 0x7f6255a6cdd8>