# MRI Reconstruction: a comparison between classical and learning-based approaches

In this notebook we compare two techniques for reconstructing MRI images from undersampled k-space:
- the classical, wavelet-based, iterative reconstruction;
- the cascade net [1].

[1] Schlemper, J., Caballero, J., Hajnal, J. V., Price, A., & Rueckert, D. (2018). A Deep Cascade of Convolutional Neural Networks for MR Image Reconstruction. IEEE Transactions on Medical Imaging, 37(2), 491–503. https://doi.org/10.1109/TMI.2017.2760978

In [1]:
!pip install --upgrade pip


In [2]:
!pip install python-pysap pysap-mri


In [2]:
%matplotlib nbagg
import matplotlib.pyplot as plt
from mri.numerics.fourier import FFT2
from mri.numerics.reconstruct import sparse_rec_fista
from mri.numerics.utils import generate_operators
from mri.numerics.utils import convert_mask_to_locations
import numpy as np

from fastmri_recon.helpers.utils import gen_mask, crop_center
from fastmri_recon.helpers.evaluate import psnr, ssim 
from fastmri_recon.models.cascading import cascade_net


Package version: 0.0.3

License: CeCILL-B

Authors: 

Antoine Grigis <antoine.grigis@cea.fr>
Samuel Farrens <samuel.farrens@cea.fr>
Jean-Luc Starck <jl.stark@cea.fr>
Philippe Ciuciu <philippe.ciuciu@cea.fr>

Dependencies: 

scipy          : >=1.3.0   - required | 1.4.1     installed
numpy          : >=1.16.4  - required | 1.16.4    installed
matplotlib     : >=3.0.0   - required | 3.1.1     installed
astropy        : >=3.0.0   - required | 4.0       installed
nibabel        : >=2.3.2   - required | 2.4.1     installed
pyqtgraph      : >=0.10.0  - required | 0.10.0    installed
progressbar2   : >=3.34.3  - required | ?         installed
modopt         : >=1.4.0   - required | 1.4.1     installed
scikit-learn   : >=0.19.1  - requi

Using TensorFlow backend.


In [3]:
plt.rcParams['figure.figsize'] = (5, 5)
plt.rcParams['image.cmap'] = 'gray'

In [4]:
np.random.seed(0)

# Data handling

The data come from the fastMRI dataset [2]. We selected the 16th slice of the first validation file (`file1000000`).

[2] Zbontar, J., Knoll, F., Sriram, A., Muckley, M. J., Bruno, M., Defazio, A., … Lui, Y. W. (2018). fastMRI: An Open Dataset and Benchmarks for Accelerated MRI. arXiv:1811.08839. Retrieved from https://arxiv.org/pdf/1811.08839.pdf

## Data loading

In [5]:
image = np.load('gt_image.npy')
kspace = np.load('gt_kspace.npy')

In [6]:
fig, axs = plt.subplots(1, 2)
axs[0].imshow(image[..., 0])
axs[0].set_title('Ground-truth image')
axs[1].imshow(np.abs(kspace[..., 0]))
axs[1].set_title('k-space magnitude')


## Retrospective undersampling

In [7]:
AF = 4
mask = gen_mask(kspace[..., 0], accel_factor=AF, seed=0)
# the mask is returned in fastMRI format (a single row of sampled/unsampled columns);
# we repeat it over all rows to get a 2D mask of 1s and 0s in Fourier space
fourier_mask = np.repeat(mask.astype(float), kspace.shape[0], axis=0)
masked_kspace = fourier_mask[..., None] * kspace
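`gen_mask` comes from `fastmri_recon` and its internals are not shown here. As a rough illustration only, the sketch below builds a fastMRI-style Cartesian mask under two assumptions: a fully sampled band of low-frequency columns (8% of the columns, a hypothetical choice) and random selection of the remaining columns so that about 1/AF of all columns are acquired. `make_column_mask` is a hypothetical helper, not the library function.

```python
import numpy as np

def make_column_mask(n_rows, n_cols, accel_factor=4, center_fraction=0.08, seed=0):
    """Hypothetical fastMRI-style Cartesian mask (not fastmri_recon's gen_mask).

    Returns a {0, 1} float array of shape (n_rows, n_cols) in which whole
    columns (phase-encoding lines) are either sampled or not.
    """
    rng = np.random.RandomState(seed)
    n_center = int(round(n_cols * center_fraction))
    # probability for non-central columns so that ~n_cols / accel_factor columns are kept overall
    prob = (n_cols / accel_factor - n_center) / (n_cols - n_center)
    cols = rng.uniform(size=n_cols) < prob
    # always keep a contiguous block of low-frequency (central) columns
    start = (n_cols - n_center + 1) // 2
    cols[start:start + n_center] = True
    # repeat the 1D column pattern over every row, as np.repeat does in the cell above
    return np.repeat(cols[None, :].astype(float), n_rows, axis=0)

mask_2d = make_column_mask(640, 368, accel_factor=4)
print(mask_2d.shape)        # (640, 368)
print(mask_2d[0].mean())    # fraction of sampled columns, roughly 1/4
```

Sampling whole columns mimics Cartesian acquisition, where each phase-encoding line is either acquired entirely or skipped.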

In [8]:
plt.figure()
plt.imshow(fourier_mask)


In [9]:
plt.figure()
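# zero-filled reconstruction (unsampled k-space left at 0), cropped to the central 320x320 region
zero_filled = np.abs(np.fft.fftshift(np.fft.ifft2(masked_kspace[..., 0], norm='ortho')))
plt.imshow(crop_center(zero_filled, 320))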


## Batching

Neural networks operate on batches, both during training and at prediction time, so we add a leading batch dimension of size 1 to the mask and to the k-space.

In [10]:
mask_batch = fourier_mask[None, ...]
masked_kspace_batch = masked_kspace[None, ...]
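Indexing with `None` adds a leading batch axis of size 1 without copying the data, so `masked_kspace_batch` is a view of `masked_kspace`. The check below uses an illustrative stand-in array (the names are not from the notebook) to show the new shape and the shared memory. One consequence is that the in-place `masked_kspace_batch *= 1e6` in the next cell also rescales `masked_kspace`; this is harmless here because `masked_kspace` is not reused afterwards, but `masked_kspace[None, ...] * 1e6` would avoid the side effect.

```python
import numpy as np

# illustrative stand-in for masked_kspace: shape (rows, columns, channel)
kspace_like = np.ones((640, 368, 1), dtype=np.complex64)
batch = kspace_like[None, ...]                 # add a leading batch axis

print(batch.shape)                             # (1, 640, 368, 1)
print(np.shares_memory(batch, kspace_like))    # True: batch is a view, not a copy

# an in-place update of the view therefore also changes the original array
batch *= 1e6
print(kspace_like[0, 0, 0])                    # scaled as well
```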

## Scaling
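The network was trained on scaled inputs to avoid numerical issues that lead to poor training. Inputs are usually normalized, but here it was simpler to multiply them by a constant factor based on the mean over the training set. The same factor (`1e6`) must therefore be applied at inference time, and the predictions divided by it before comparing them with the ground truth.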

In [11]:
masked_kspace_batch *= 1e6
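The sketch below shows one way a global scale factor could be derived from the mean k-space magnitude over a training set. `compute_scale_factor` and the toy data are hypothetical, and this is not necessarily how the value `1e6` was obtained for this model.

```python
import numpy as np

def compute_scale_factor(kspaces):
    """Hypothetical helper: inverse of the mean k-space magnitude over a set of volumes."""
    mean_magnitude = np.mean([np.mean(np.abs(k)) for k in kspaces])
    return 1.0 / mean_magnitude

rng = np.random.RandomState(0)
# toy "training set": complex k-spaces with very small magnitudes
train_kspaces = [1e-6 * (rng.randn(64, 64) + 1j * rng.randn(64, 64)) for _ in range(3)]

factor = compute_scale_factor(train_kspaces)
scaled = [factor * k for k in train_kspaces]
print(np.mean([np.mean(np.abs(k)) for k in scaled]))  # 1.0 up to floating-point rounding
```

A fixed factor keeps the mapping linear and trivially invertible, which is why the predictions can simply be divided by `1e6` in the metrics below.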

# Reconstruction using the Cascade net neural network

## Model loading

In [12]:
run_params = {
    'n_cascade': 5,
    'n_convs': 5,
    'n_filters': 48,
    'noiseless': True,
}
run_id = 'cascadenet_af4_oasis_1569491836'
epoch = 300

In [13]:
model = cascade_net(input_size=(None, None, 1), fastmri=True, **run_params)
chkpt_path = f'../checkpoints/{run_id}-{epoch}.hdf5'
model.load_weights(chkpt_path)
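In the cascade net [1], each CNN block is followed by a data-consistency step that re-imposes the acquired k-space samples on the current estimate; the `'noiseless': True` parameter presumably selects the noiseless variant, in which predicted k-space values at sampled locations are simply replaced by the measured ones. The NumPy sketch below illustrates that operation on a single 2D slice, assuming orthonormal FFTs; it is not the Keras layer used by `fastmri_recon`.

```python
import numpy as np

def data_consistency_noiseless(image, measured_kspace, mask):
    """Noiseless data consistency as described in [1], NumPy sketch.

    image: current complex image estimate, shape (H, W)
    measured_kspace: acquired k-space, zero where not sampled, shape (H, W)
    mask: {0, 1} sampling mask, shape (H, W)
    """
    predicted_kspace = np.fft.fft2(image, norm='ortho')
    # keep the prediction where nothing was measured, the measurement where it was
    corrected_kspace = (1 - mask) * predicted_kspace + mask * measured_kspace
    return np.fft.ifft2(corrected_kspace, norm='ortho')

rng = np.random.RandomState(0)
true_image = rng.randn(32, 32)
mask = (rng.uniform(size=(1, 32)) < 0.3).astype(float).repeat(32, axis=0)
measured = mask * np.fft.fft2(true_image, norm='ortho')

estimate = np.zeros((32, 32), dtype=complex)   # e.g. a very poor network output
dc_estimate = data_consistency_noiseless(estimate, measured, mask)
# after the step, the sampled k-space locations match the measurement exactly
assert np.allclose(mask * np.fft.fft2(dc_estimate, norm='ortho'), measured)
```

Starting from an all-zero estimate, this step returns exactly the zero-filled reconstruction; the CNN blocks are what fill in the missing frequencies.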


## Reconstruction

In [14]:
%%time
cascade_reconstructed_image = model.predict_on_batch([masked_kspace_batch, mask_batch])

CPU times: user 10.4 s, sys: 474 ms, total: 10.9 s
Wall time: 4.8 s


## Visual comparison

In [15]:
fig, axs = plt.subplots(1, 2, sharex=True, sharey=True)
axs[0].imshow(image[..., 0])
axs[0].set_title('Ground truth')
axs[1].imshow(cascade_reconstructed_image[0, ..., 0])
axs[1].set_title('Cascade net')


## Quantitative comparison

In [16]:
print('PSNR of the net reconstructed image:', psnr(image, cascade_reconstructed_image[0]/1e6))

PSNR of the net reconstructed image: 26.48027258193485
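For reference, PSNR compares the mean squared error with the dynamic range of the ground truth, in decibels (higher is better). The sketch below assumes the fastMRI convention of using the maximum of the ground truth as the data range; `psnr_sketch` is a hypothetical re-implementation, and `fastmri_recon.helpers.evaluate.psnr` may differ in details.

```python
import numpy as np

def psnr_sketch(gt, pred):
    """PSNR in dB: 10 * log10(max(gt)^2 / MSE), data range taken as max(gt) (assumption)."""
    mse = np.mean((gt - pred) ** 2)
    return 10 * np.log10(gt.max() ** 2 / mse)

gt = np.zeros((4, 4))
gt[0, 0] = 1.0               # toy ground truth with maximum value 1
pred = gt + 0.1              # constant error of 0.1, so MSE = 0.01
print(psnr_sketch(gt, pred)) # about 20 dB
```

The ratio is unchanged when both images are scaled by the same factor, so what matters is that the prediction is brought back to the ground-truth scale (the `/1e6` above) before comparison.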


In [17]:
print('SSIM of the net reconstructed image:', ssim(image[None, ..., 0], cascade_reconstructed_image[..., 0]/1e6))

SSIM of the net reconstructed image: 0.5352749583786784
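SSIM combines luminance, contrast and structure comparisons, with 1 meaning identical images. The standard metric averages the formula below over local windows; as a simplification, this sketch applies it once to the whole image just to show the formula, with the usual constants C1 = (0.01 L)^2 and C2 = (0.03 L)^2, where L is the data range. `global_ssim` is hypothetical and is not the implementation behind `fastmri_recon.helpers.evaluate.ssim`.

```python
import numpy as np

def global_ssim(x, y, data_range=None):
    """Single-window (global) SSIM: a simplification of the usual windowed SSIM."""
    if data_range is None:
        data_range = x.max()  # assumption: data range taken from the reference image
    c1 = (0.01 * data_range) ** 2
    c2 = (0.03 * data_range) ** 2
    mu_x, mu_y = x.mean(), y.mean()
    var_x, var_y = x.var(), y.var()
    cov_xy = np.mean((x - mu_x) * (y - mu_y))
    numerator = (2 * mu_x * mu_y + c1) * (2 * cov_xy + c2)
    denominator = (mu_x ** 2 + mu_y ** 2 + c1) * (var_x + var_y + c2)
    return numerator / denominator

rng = np.random.RandomState(0)
x = rng.uniform(size=(32, 32))
print(global_ssim(x, x))                                   # 1.0 up to rounding: identical images
print(global_ssim(x, x + 0.2 * rng.randn(32, 32)) < 1.0)  # True: noise lowers SSIM
```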


# Reconstruction using pysap

## Reformatting the data

In [18]:
kspace_squeeze = np.squeeze(kspace*1e6)
k_shape = kspace_squeeze.shape
kspace_loc = convert_mask_to_locations(fourier_mask)
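`convert_mask_to_locations` turns the binary Cartesian mask into the list of sampled k-space coordinates used by pysap's operators. The sketch below shows the idea under an assumed convention (coordinates normalized to [-0.5, 0.5) along each axis); `mask_to_locations_sketch` is hypothetical, and pysap-mri's exact ordering and normalization may differ.

```python
import numpy as np

def mask_to_locations_sketch(mask):
    """Hypothetical: sampled positions of a 2D {0, 1} mask as normalized coordinates in [-0.5, 0.5)."""
    rows, cols = np.nonzero(mask)
    rows = rows / mask.shape[0] - 0.5
    cols = cols / mask.shape[1] - 0.5
    return np.stack([rows, cols], axis=1)  # shape (n_samples, 2)

toy_mask = np.zeros((4, 6))
toy_mask[:, [1, 3]] = 1                    # two fully sampled columns
locations = mask_to_locations_sketch(toy_mask)
print(locations.shape)                     # (8, 2): one row per sampled point
```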

## Creating the appropriate operators

In [19]:
fourier_op = FFT2(samples=kspace_loc, shape=k_shape)
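Conceptually, an undersampled Cartesian Fourier operator applies a 2D FFT followed by masking, and its adjoint applies the mask and then the inverse FFT. The sketch below is a hypothetical `MaskedFFT2` class using orthonormal FFTs (pysap's `FFT2` may use a different normalization or data layout); it checks the adjoint relation numerically, which is what gradient-based solvers rely on.

```python
import numpy as np

class MaskedFFT2:
    """Hypothetical masked 2D Fourier operator (illustration, not pysap's FFT2)."""

    def __init__(self, mask):
        self.mask = mask  # {0, 1} array with the image shape

    def op(self, image):
        # forward: orthonormal FFT, then keep only the sampled frequencies
        return self.mask * np.fft.fft2(image, norm='ortho')

    def adj_op(self, kspace):
        # adjoint: zero the unsampled frequencies, then orthonormal inverse FFT
        return np.fft.ifft2(self.mask * kspace, norm='ortho')

rng = np.random.RandomState(0)
mask = (rng.uniform(size=(1, 16)) < 0.4).astype(float).repeat(16, axis=0)
A = MaskedFFT2(mask)
x = rng.randn(16, 16) + 1j * rng.randn(16, 16)
y = rng.randn(16, 16) + 1j * rng.randn(16, 16)
# adjoint test: <A x, y> should equal <x, A^H y>
lhs = np.vdot(y, A.op(x))
rhs = np.vdot(A.adj_op(y), x)
print(np.isclose(lhs, rhs))  # True
```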

In [20]:
gradient_op, linear_op, prox_op, cost_op = generate_operators(
    data=kspace_squeeze,
    wavelet_name="sym8",
    samples=kspace_loc,
    nb_scales=4,
    mu=0.01,
    non_cartesian=False,
    uniform_data_shape=None,
    gradient_space="synthesis",
    padding_mode="periodization",
)
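With `gradient_space="synthesis"`, the unknowns are the wavelet coefficients rather than the image. Assuming the standard sparse synthesis formulation (our reading of the parameters above, not a statement from the pysap documentation), the problem solved is

$$
\hat{\alpha} = \underset{\alpha}{\arg\min}\; \frac{1}{2}\left\| \mathcal{F}_\Omega \Psi^{*} \alpha - y \right\|_2^2 + \mu \left\| \alpha \right\|_1,
\qquad \hat{x} = \Psi^{*} \hat{\alpha},
$$

where $y$ is the acquired k-space data (`kspace_squeeze`), $\mathcal{F}_\Omega$ the Fourier transform restricted to the sampled locations `kspace_loc`, $\Psi^{*}$ the inverse `sym8` wavelet transform over 4 scales, and $\mu = 0.01$. `gradient_op` handles the gradient of the quadratic term, `prox_op` the soft-thresholding of the $\ell_1$ term, and `linear_op` the wavelet transform. Because `sym8` is orthogonal and `padding_mode="periodization"` keeps the transform non-redundant, there are as many coefficients as pixels: 640 * 368 = 235520, consistent with the `alpha variable shape` reported by the solver below.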



## Reconstruction
For the reconstruction we use FISTA (Fast Iterative Shrinkage-Thresholding Algorithm), run for a fixed budget of 200 iterations.
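FISTA alternates a gradient step on the data-fidelity term, a soft-thresholding (proximal) step on the $\ell_1$ term, and a momentum update. The sketch below applies it to the generic problem $\min_x \frac{1}{2}\|Ax - y\|_2^2 + \mu\|x\|_1$, with a small dense random matrix `A` standing in for the Fourier and wavelet operators; `fista_sketch` is illustrative and is not pysap's `sparse_rec_fista`.

```python
import numpy as np

def soft_threshold(x, threshold):
    """Proximal operator of threshold * ||x||_1 (works for real or complex x)."""
    magnitude = np.abs(x)
    return np.where(magnitude > threshold, (1 - threshold / np.maximum(magnitude, 1e-12)) * x, 0)

def fista_sketch(A, y, mu, n_iter=200):
    """Minimize 0.5 * ||A x - y||^2 + mu * ||x||_1 with FISTA (Beck & Teboulle, 2009)."""
    step = 1.0 / np.linalg.norm(A, 2) ** 2     # 1 / Lipschitz constant of the gradient
    x = np.zeros(A.shape[1], dtype=A.dtype)
    z, t = x.copy(), 1.0
    for _ in range(n_iter):
        grad = A.conj().T @ (A @ z - y)          # gradient of the data-fidelity term at z
        x_next = soft_threshold(z - step * grad, step * mu)
        t_next = (1 + np.sqrt(1 + 4 * t ** 2)) / 2
        z = x_next + ((t - 1) / t_next) * (x_next - x)  # momentum (extrapolation) step
        x, t = x_next, t_next
    return x

rng = np.random.RandomState(0)
A = rng.randn(40, 60)
x_true = np.zeros(60)
x_true[[3, 17, 42]] = [1.0, -2.0, 1.5]        # sparse ground truth
y = A @ x_true
x_hat = fista_sketch(A, y, mu=0.01, n_iter=2000)
print(np.round(x_hat[[3, 17, 42]], 2))        # typically close to the true values 1.0, -2.0, 1.5
```

The momentum term is what distinguishes FISTA from plain ISTA and gives its faster O(1/k^2) decrease of the objective.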

In [None]:
max_iter = 200
x_final, costs, metrics = sparse_rec_fista(
    gradient_op,
    linear_op,
    prox_op,
    cost_op,
    max_nb_of_iter=max_iter,
    verbose=1)
# same post-processing as the zero-filled image: fftshift, magnitude, crop to 320x320
image_rec_fs = crop_center(np.abs(np.fft.fftshift(x_final)), 320)

 - mu:  0.01
 - lipschitz constant:  1.1000000000000685
 - data:  (640, 368)
 - wavelet:  <mri.reconstruct.linear.WaveletN object at 0x7f235c728780> - 4
 - max iterations:  200
 - image variable shape:  (640, 368)
 - alpha variable shape:  (235520,)
----------------------------------------
Starting optimization...


 97% (195 of 200) |##################### | Elapsed Time: 0:00:37 ETA:   0:00:01

## Visual comparison

In [None]:
fig, axs = plt.subplots(1, 2, sharex=True, sharey=True)
axs[0].imshow(image[..., 0])
axs[0].set_title('Ground truth')
axs[1].imshow(image_rec_fs)
axs[1].set_title('pysap (FISTA)')

## Quantitative comparison

In [None]:
print('PSNR of the pysap reconstructed image:', psnr(image[..., 0], image_rec_fs/1e6))

In [None]:
print('SSIM of the pysap reconstructed image:', ssim(image[None, ..., 0], image_rec_fs[None, ...]/1e6))