# Magnetic Resonance Imaging 

## Looking Ahead

In the previous example, we verified that an MR image from a dataset, once transformed from the spatial domain, is sparse in the wavelet domain. In this example, we will exploit this sparsity to recover the MR image from compressive measurements. Recall that MR images are acquired in the Fourier domain; this has implications for the measurement map that we construct in a compressed sensing setup. Please review the homework writeup for a discussion of these issues and a justification for the sensing model we implement below.

## Loading the Data

We use the same code as in the previous notebook to load the data. As before, we will focus on a single sagittal slice of the patient's anatomical data.

In [None]:
## Install AWS CLI tools
!pip install awscli
## Prepare data directory
import os
os.chdir('/content')
!mkdir -p bold5000
os.chdir('/content/bold5000')

## Grab the data
#!aws s3 sync --no-sign-request --exclude "*" --include "*06*" s3://openneuro.org/ds001499/derivatives/fmriprep/sub-CSI3/ses-13/func/ /content/bold5000/sub-CSI3_ses-13_run-06/
!aws s3 sync --no-sign-request s3://openneuro.org/ds001499/sub-CSI3/ses-16/anat/ /content/bold5000/sub-CSI3_anat/

We add the same auxiliary definitions as last time.

In [None]:
## Auxiliary code for our wavelet experiments
import numpy as np
import bokeh
import bokeh.plotting as bpl
from bokeh.models import ColorBar, BasicTicker, LinearColorMapper
import pywt

## Try to do something like imagesc in MATLAB using Bokeh tools.
def imagesc(M, title=''):
  m, n = M.shape

  # 600 px should be good; calculate ph to try to get aspect ratio right
  pw = 600
  ph = round(1.0 * pw * m / n)
  # `width`/`height` replace the `plot_width`/`plot_height` arguments removed in Bokeh 3
  h = bpl.figure(width=pw, height=ph, x_range=(0, 1.0*n),
                 y_range=(0, 1.0*m), toolbar_location='below',
                 title=title, match_aspect=True
                )

  minval = np.min(M)
  maxval = np.max(M)

  # Grayscale colormap spanning the data range, with a colorbar on the right
  color_mapper = LinearColorMapper(palette="Greys256", low=minval, high=maxval)
  h.image(image=[M], x=0, y=0, dw=1.0*n, dh=1.0*m, color_mapper=color_mapper)

  color_bar = ColorBar(color_mapper=color_mapper, ticker=BasicTicker(),
                       label_standoff=12, border_line_color=None, location=(0, 0))

  h.add_layout(color_bar, 'right')

  bpl.show(h)
  return h

## Wavelet functions below
## Note: we expect all image sizes to be powers-of-two and square!
## So if you adapt this code, be sure to fix this or enforce this requirement...

# Get a default slice object for a multilevel wavelet transform
# Used to abstract this annoying notation out of the transform...
def default_slices(levels, n):
  c = pywt.wavedec2(np.zeros((n, n)), 'db4', mode='periodization', level=levels)
  _, slices = pywt.coeffs_to_array(c)
  return slices

# Wrapper for forward discrete wavelet transform
# Output data as a matrix (we don't care about tuple format)
def dwt(levels, sdom_data):
  c = pywt.wavedec2(sdom_data, 'db4', mode='periodization', level=levels)
  output, _ = pywt.coeffs_to_array(c)
  return output

# Wrapper for inverse discrete wavelet transform
# Expect wdom_data as a matrix (we don't care about tuple format)
def idwt(levels, wdom_data, slices=None):
  n = wdom_data.shape[0]
  if slices is None:
    slices = default_slices(levels, n)
  c = pywt.array_to_coeffs(wdom_data, slices, output_format='wavedec2')
  return pywt.waverec2(c, 'db4', mode='periodization')

Finally, we extract the sagittal slice we will study, as in the previous notebook.

In [None]:
import numpy as np
import nibabel as nib

img = nib.load('/content/bold5000/sub-CSI3_anat/sub-CSI3_ses-16_T1w.nii.gz')

data = img.get_fdata()

## Store dimensions
Nx = data.shape[0]
Ny = data.shape[1]
Nz = data.shape[2]
n = Ny  # assumes the slice is square (Ny == Nz) with a power-of-two side
X = data[Nx//2, :, :]  # middle sagittal slice

bpl.output_notebook()
imagesc(X, title='MR Image We Will Recover')

## Implementing the Measurement Model

An MR machine collects samples of the 2D Fourier transform of the underlying spatial profile (e.g., the image above). An MR machine that uses compressed sensing does the same, but collects far fewer than the `n ** 2` measurements needed to represent the image exactly at the resolution used above. From lecture, we know that subsampling the Fourier coefficients at random yields an incoherent measurement map, so we expect L1 minimization to recover the image from the compressive measurements.

Here we implement the Bernoulli sampling model described in the homework handout: each Fourier coefficient is kept independently with probability `p` and deleted otherwise.
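In our own notation (not necessarily the handout's), the code cell below maps wavelet coefficients $S$ to masked Fourier data. Let $\mathcal{W}$ denote the db4 wavelet transform (`dwt`), $\mathcal{F}$ the orthonormal 2D DFT (`np.fft.fft2(..., norm="ortho")`), and $\mathcal{P}$ the operator that zeroes the entries indexed by the deletion set $\Omega$. Since $\mathcal{F}$ and $\mathcal{W}$ are unitary and $\mathcal{P}$ is an orthogonal projection, the map (`meas_map`) and its adjoint (`meas_map_adj`) can be written as:

$$
\mathcal{A}(S) = \mathcal{P}\,\mathcal{F}\,\mathcal{W}^{-1}(S),
\qquad
\mathcal{A}^*(Y) = \mathcal{W}\,\mathcal{F}^{-1}\,\mathcal{P}(Y),
\qquad
\Pr\big[(i,j) \notin \Omega\big] = p \ \text{ independently for each } (i,j).
$$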

In [None]:
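## Create the index set for our mapping
p = 0.3  # Bernoulli parameter: probability of *keeping* each Fourier coefficient
# Omega contains the indices to _delete_: (i, j) is deleted when its coin flip exceeds p.
# We store Omega as an array of (row, col) pairs, using one vectorized coin flip per entry.
idxs = np.argwhere(np.random.rand(n, n) > p)
len(X[idxs[:, 0], idxs[:, 1]])  # Index like this; this is the number of deleted coefficients

## Create the operator
levels = 2

def meas_map(mtx):
  # Wavelet coefficients -> image -> orthonormal 2D FFT -> zero the deleted frequencies
  pre_proj = np.fft.fft2(idwt(levels, mtx), norm="ortho")
  pre_proj[idxs[:, 0], idxs[:, 1]] = 0
  return pre_proj

def meas_map_adj(mtx):
  # Adjoint: zero the deleted frequencies -> inverse orthonormal FFT -> wavelet coefficients
  mtx = np.copy(mtx)  # work on a copy so the caller's array is not modified
  mtx[idxs[:, 0], idxs[:, 1]] = 0
  return dwt(levels, np.fft.ifft2(mtx, norm="ortho"))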

## Performing Sparse Recovery

We generate observations `Y` from the image `X` and the measurement map. In particular, we first compute the wavelet coefficients `S` of the image, which are sparse, and then apply the measurement map to them.

In [None]:
S = dwt(levels, X)
Y = meas_map(S)

Below we plot the magnitude of the full 2D Fourier transform of the image (all of the coefficients, not only the observed ones).

In [None]:
bpl.output_notebook()
imagesc(np.fft.fftshift(np.abs(np.fft.fft2(X/n**2))))

Above is the magnitude of the normalized 2D FFT of the image `X`, shifted so that low frequencies are at the center of the image. Nearly all of the frequency content is concentrated at low frequencies, so the spectrum appears quite sparse! However, the Fourier phase contains a *lot* of information about the image: reconstructing from only the largest-magnitude Fourier coefficients (as we did with the largest-magnitude wavelet coefficients in the previous homework) gives a small squared error, but a result with poor visual quality, as we can see in the figure below.

In [None]:
# Keep only the num_keep largest-magnitude Fourier coefficients of X
F = np.fft.fft2(X)
absmags = np.absolute(F.flatten())
idxs_absmag = np.argsort(absmags)[::-1]  # flat indices, sorted by decreasing magnitude

num_keep = 8000
F_copy = np.copy(F)
F_copy[np.unravel_index(idxs_absmag[num_keep:], (n, n))] = 0  # discard all other coefficients
X_reconstr = np.fft.ifft2(F_copy)

bpl.output_notebook()
imagesc(np.abs(X_reconstr))

This is what motivates us to use the wavelet transform, and its corresponding notion of sparsity, in our measurement map.
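The "small squared error" part of this observation can be made precise. The sketch below uses our own helper, `topk_fourier_error` (hypothetical, not part of the assignment). It keeps the largest-magnitude coefficients of the orthonormal DFT and checks Parseval's theorem: the relative squared error equals the fraction of spectral energy in the discarded coefficients. Because the spectrum is concentrated at low frequencies, that fraction stays small even when visually important detail is lost, such as sharp edges that need many high-frequency coefficients.

```python
import numpy as np

def topk_fourier_error(X, num_keep):
    """Keep the num_keep largest-magnitude coefficients of the orthonormal 2D DFT
    of a real image X; return the relative squared error measured two ways."""
    F = np.fft.fft2(X, norm="ortho")
    flat = F.ravel()
    order = np.argsort(np.abs(flat))[::-1]   # coefficient indices, largest magnitude first
    dropped = order[num_keep:]
    F_kept = flat.copy()
    F_kept[dropped] = 0                      # zero every coefficient outside the top num_keep
    X_approx = np.fft.ifft2(F_kept.reshape(F.shape), norm="ortho")
    # Error measured directly in the spatial domain
    spatial = np.linalg.norm(X - X_approx) ** 2 / np.linalg.norm(X) ** 2
    # Error predicted by Parseval: with norm="ortho" the DFT is unitary, so the
    # spatial error equals the energy of the discarded Fourier coefficients
    discarded = np.sum(np.abs(flat[dropped]) ** 2) / np.sum(np.abs(flat) ** 2)
    return spatial, discarded
```

For the slice above, `topk_fourier_error(X, 8000)` should return two numerically equal values. Both values only measure discarded energy, and that energy is small precisely because the spectrum is so concentrated.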

## Your Tasks

Complete each of the tasks in the level three headers below.

### Task 1: Sparse Recovery with Proximal Gradient

For this task, you should implement the proximal gradient descent algorithm for the LASSO objective, using the measurement map specified above and the theoretical setting sketched in the homework writeup. Your algorithm will be quite similar to the one you wrote (or will write) for the spectrum sensing application, but you will need to replace the matrix-vector multiplications with the "matrix linear maps" used in this problem.
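For reference, here is one standard way to write the matrix LASSO problem and its proximal gradient update (our notation). Write $\mathcal{A}$ for the measurement map `meas_map` and $\mathcal{A}^*$ for its adjoint `meas_map_adj`. The step size satisfies $\alpha \le 1/L$, where $L = \|\mathcal{A}\|^2$ is the Lipschitz constant of the gradient of the least-squares term. For complex entries, the prox is complex soft thresholding, which shrinks each modulus and keeps each phase:

$$
\min_{S \in \mathbb{C}^{n \times n}} \ \tfrac{1}{2}\,\|Y - \mathcal{A}(S)\|_F^2 + \lambda \|S\|_1,
\qquad \|S\|_1 = \sum_{i,j} |S_{ij}|,
$$

$$
S_{k+1} = \operatorname{prox}_{\alpha\lambda\|\cdot\|_1}\Big(S_k - \alpha\,\mathcal{A}^*\big(\mathcal{A}(S_k) - Y\big)\Big),
\qquad
\big[\operatorname{prox}_{\tau\|\cdot\|_1}(Z)\big]_{ij} = \max\big(|Z_{ij}| - \tau,\ 0\big)\,\frac{Z_{ij}}{|Z_{ij}|},
$$

where the last factor is taken to be $0$ when $Z_{ij} = 0$.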

See the first task's description in the spectrum sensing problem for hints about how to code and debug your proximal gradient descent algorithm.
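One debugging check that is useful here (our suggestion, not part of the handout) is the dot-product test. For a correctly implemented adjoint, $\langle \mathcal{A}(x), y\rangle = \langle x, \mathcal{A}^*(y)\rangle$ for all $x, y$. The helper `adjoint_mismatch` below is hypothetical. We demonstrate it on a stand-in operator, a randomly masked orthonormal FFT without the wavelet step, so that it runs with NumPy alone.

```python
import numpy as np

def adjoint_mismatch(A, A_adj, shape, rng):
    """Relative gap between <A(x), y> and <x, A_adj(y)> for random complex x, y.
    Assumes A maps arrays of `shape` to arrays of the same shape."""
    x = rng.standard_normal(shape) + 1j * rng.standard_normal(shape)
    y = rng.standard_normal(shape) + 1j * rng.standard_normal(shape)
    lhs = np.vdot(A(x), y)        # np.vdot flattens and conjugates its first argument
    rhs = np.vdot(x, A_adj(y))
    return abs(lhs - rhs) / max(abs(lhs), abs(rhs))

# Stand-in operator: keep each orthonormal 2D DFT coefficient with probability p_demo
rng = np.random.default_rng(0)
n_demo, p_demo = 32, 0.3
keep = rng.random((n_demo, n_demo)) < p_demo

def masked_fft(x):
    return np.where(keep, np.fft.fft2(x, norm="ortho"), 0)

def masked_fft_adj(y):
    return np.fft.ifft2(np.where(keep, y, 0), norm="ortho")

# Should be at floating-point roundoff level for a consistent operator/adjoint pair
print(adjoint_mismatch(masked_fft, masked_fft_adj, (n_demo, n_demo), rng))
```

In the notebook, the same check could be run as `adjoint_mismatch(meas_map, meas_map_adj, (n, n), np.random.default_rng(0))`. This assumes `pywt` accepts complex coefficients, which the solvers below already rely on once their iterates become complex. Dropping `norm="ortho"` from the inverse FFT makes the mismatch close to 1 instead of near zero.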

**Hint**: Be sure your choice of initialization matches the scale of the data `Y`. A good practice is to use Gaussian initialization, such that the expected Frobenius norm of the initialization matrix is around 1; then also scale the matrix `Y` to have Frobenius norm 1. You can restore the original scale after you solve the optimization problem.
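A minimal sketch of this normalization and initialization, assuming NumPy's `default_rng` interface. The helper name `normalize_and_init` is hypothetical. Scaling the entries by $1/\sqrt{n^2} = 1/n$ matches the `1/n * np.random.randn(n, n)` initialization used in the solution code.

```python
import numpy as np

def normalize_and_init(Y, rng):
    """Return (Y / ||Y||_F, Gaussian init with E||S_init||_F^2 = 1, ||Y||_F).
    Entries of S_init have variance 1 / Y.size, i.e. 1 / n^2 for an n x n array."""
    scale = np.linalg.norm(Y)     # Frobenius norm for a 2D array
    S_init = rng.standard_normal(Y.shape) / np.sqrt(Y.size)
    return Y / scale, S_init, scale
```

After solving on the normalized data, multiply the estimate by `scale` to restore the original scale. The choice of $\lambda$ interacts with this normalization. Suppose $\hat S$ minimizes $\tfrac12\|Y/c - \mathcal{A}(S)\|_F^2 + \lambda\|S\|_1$ for some $c > 0$. Then $c\hat S$ minimizes $\tfrac12\|Y - \mathcal{A}(S)\|_F^2 + c\lambda\|S\|_1$. A $\lambda$ tuned on normalized data therefore corresponds to $c\lambda$ on the original scale.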

**Hint**: The algorithm can take a while to run. To speed things up, for the purpose of testing you may want to downsample the input MR image, so that instead of being `256 x 256` it is e.g. `64 x 64`. Be sure to keep the sizes as powers of two, or the wavelet transform code will break. This will also help with experimenting to find the best setting of the sparsifying parameter `lambda`: it plays a very large role in the visual quality of the signal you reconstruct.
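One simple way to downsample is block averaging. This sketch assumes block averaging is acceptable for testing; other choices, such as keeping every fourth pixel, would also work. `downsample` is a hypothetical helper.

```python
import numpy as np

def downsample(X, factor):
    """Block-average a square image by an integer factor, e.g. 256x256 -> 64x64 with factor=4.
    Also checks that the output side is a power of two, as the wavelet helpers require."""
    m, k = X.shape
    if m != k or m % factor != 0:
        raise ValueError("expected a square image whose side is divisible by factor")
    s = m // factor
    if (s & (s - 1)) != 0:
        raise ValueError("output side must be a power of two")
    # Row i = a*factor + b and column j = c*factor + d land at index [a, b, c, d];
    # averaging over b and d gives the mean of each factor x factor block
    return X.reshape(s, factor, s, factor).mean(axis=(1, 3))
```

For example, `X_small = downsample(X, 4)` gives a `64 x 64` image. The index set and the solvers' initialization use the global `n`. So set `X = X_small` and `n = X.shape[0]`, then rerun the cells that build `idxs`, `S`, and `Y`.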

### Task 2: Assessing Performance

Complete the following performance evaluation tasks for your sparse recovery algorithm:
1. For at least 3 values of the Bernoulli parameter `p`, say `[0.1, 0.2, 0.4, 0.5, 0.7]`, perform at least 3 independent trials of the sparse recovery experiment. Here, "independent trials" means you should re-generate the measurement map in each separate experiment. For each experiment, calculate the mean squared error between your LASSO solver's output and the ground truth image `X` (make sure they are in the same domain!), and average the mean squared errors for each setting of `p` over the independent trials. Output a plot of these averaged mean squared errors as a function of `p`; include error bars corresponding to the trial variances.
2. How large do you need `p` to be before you get acceptable (in terms of both MSE and in terms of visual quality) results for the sparse recovery experiment? Can you give an explanation for why the performance may be worse here than in other experiments in terms of properties of the specific measurement map we use here?

## Solutions

### Task 1

We give the code for performing proximal gradient on the matrix LASSO objective below.

A challenging aspect of this problem is its size ($n^2 = 2^{16}$ unknowns when $n = 256$), which makes convergence slow; initialization also affects the speed of convergence. In addition, for small $p$, some problem instances (realizations of the set $\Omega$) are not well-posed. In the solution we demonstrate two ways to get around this. The first is a biased initialization using the pseudoinverse, which in this context reduces to `S_init = meas_map_adj(Y)` (commented out in the code below). It does not speed up convergence much, but it removes the randomness in the rate of convergence that comes with a naive random initialization.
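The structure of the measurement map explains two choices in the code below: the pseudoinverse reduces to the adjoint, and `lip = 1`. In our notation, write $\mathcal{A} = \mathcal{P}\,\mathcal{F}\,\mathcal{W}^{-1}$. Here $\mathcal{W}$ (the wavelet transform) and $\mathcal{F}$ (the `norm="ortho"` FFT) are unitary, and $\mathcal{P}$ is the orthogonal projection that zeroes the deleted entries. A short calculation gives:

$$
\mathcal{A}\mathcal{A}^* = \mathcal{P}\,\mathcal{F}\,\mathcal{W}^{-1}\,\mathcal{W}\,\mathcal{F}^{-1}\,\mathcal{P} = \mathcal{P}^2 = \mathcal{P},
\qquad
\mathcal{A}^\dagger = \mathcal{A}^*\big(\mathcal{A}\mathcal{A}^*\big)^\dagger = \mathcal{A}^*\mathcal{P} = \mathcal{A}^*,
\qquad
\|\mathcal{A}\|^2 = \|\mathcal{A}\mathcal{A}^*\| = 1,
$$

where the last equality requires that at least one Fourier coefficient is kept. So $\mathcal{A}$ is a partial isometry, `meas_map_adj(Y)` is the minimum-norm least-squares solution, and the gradient of $\tfrac12\|Y - \mathcal{A}(S)\|_F^2$ is 1-Lipschitz.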

In [None]:
## Prox Gradient
def soft_thresh(x, lambd):
  # Shrink nonnegative magnitudes toward zero by lambd
  return np.maximum(x - lambd, 0)

def prox_l1(x, lambd):
  # Complex soft thresholding: shrink each entry's modulus by lambd, keep its phase
  phases = np.angle(x)
  mags = np.abs(x)
  thresholded = soft_thresh(mags, lambd)
  return thresholded * np.exp(1j * phases)

def objective(x, A, y):
  # Least-squares data-fidelity term 0.5 * ||y - A(x)||_F^2
  return 0.5 * np.linalg.norm(y - A(x), ord='fro')**2

def pg_lasso(Y, A, A_adj, lambd):
  lip = 1                # Lipschitz constant of the gradient (||A||^2 = 1 for our map)
  alpha = 0.5 * 1/lip    # conservative step size 1/(2L)
  MAX_ITER = int(2e3)
  TOL = 1e-8
  S_init = 1/n * np.random.randn(n, n)   # expected squared Frobenius norm is 1
  # S_init = A_adj(Y)                     # alternative: pseudoinverse (adjoint) initialization

  pt = S_init
  obj = np.zeros((MAX_ITER,))

  for iteration in range(MAX_ITER):
    # Compute gradient of the smooth term at the current point
    grad = A_adj(A(pt) - Y)
    # Gradient step followed by the prox of the l1 penalty
    pt = prox_l1(pt - alpha * grad, alpha * lambd)
    # Evaluate the full LASSO objective
    obj[iteration] = objective(pt, A, Y) + lambd * np.linalg.norm(pt.flatten(), ord=1)

    if iteration % 100 == 0:
      print('iter {}, obj {}'.format(iteration, obj[iteration]))

    # Stopping criterion: objective has stopped changing
    if iteration > 0 and np.abs(obj[iteration-1] - obj[iteration]) < TOL:
      print('Met minimum tolerance at iter {}. Breaking.'.format(iteration))
      break

  obj = obj[:iteration + 1]  # keep every objective value actually computed
  return pt, obj



Another way to make things run more smoothly is to use what we have learned since completing the homework and apply the accelerated proximal gradient algorithm, which speeds up convergence. The version implemented here uses a very simple choice of the momentum parameter sequence: the extrapolation weight at iteration $k$ (counting from $k = 0$) is $k/(k+3)$. This choice is worth remembering.
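To see why the weight $k/(k+3)$ is a reasonable stand-in, it can be compared numerically with the classical FISTA momentum of Beck and Teboulle. FISTA uses the weight $(t_k - 1)/t_{k+1}$, with $t_1 = 1$ and $t_{k+1} = \big(1 + \sqrt{1 + 4t_k^2}\big)/2$. This comparison is our own addition, and the helper names are hypothetical. Both sequences start at 0 and approach 1 at a rate of roughly $3/k$, and their difference shrinks quickly as $k$ grows.

```python
import numpy as np

def simple_momentum(num_iters):
    """Momentum weights used in apg_lasso: k / (k + 3) for k = 0, 1, ..., num_iters - 1."""
    k = np.arange(num_iters)
    return k / (k + 3)

def fista_momentum(num_iters):
    """FISTA weights (t_k - 1) / t_{k+1}, with t_1 = 1 and t_{k+1} = (1 + sqrt(1 + 4 t_k^2)) / 2.
    Entry k is the weight used at the (k+1)-th iteration, aligned with simple_momentum."""
    t = np.ones(num_iters + 1)          # t[0] holds t_1
    for k in range(num_iters):
        t[k + 1] = (1 + np.sqrt(1 + 4 * t[k] ** 2)) / 2
    return (t[:-1] - 1) / t[1:]
```

The simple weight needs no extra state, which is why it is convenient inside `apg_lasso`.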

In [None]:
def apg_lasso(Y, A, A_adj, lambd):
  lip = 1            # Lipschitz constant of the gradient (||A||^2 = 1 for our map)
  alpha = 1/lip      # step size 1/L

  MAX_ITER = int(1e6)
  TOL = 1e-8
  S_init = 1/n * np.random.randn(n, n)
  # S_init = A_adj(Y)   # alternative: pseudoinverse (adjoint) initialization

  pt = S_init
  pt_prev = pt
  obj = np.zeros((MAX_ITER,))

  for iteration in range(MAX_ITER):
    # Step to aux point: extrapolate along the last step with momentum weight k/(k+3)
    aux_pt = pt + (iteration / (iteration + 3)) * (pt - pt_prev)
    pt_prev = pt
    # Compute gradient at the auxiliary point
    grad = A_adj(A(aux_pt) - Y)
    # Gradient step followed by the prox of the l1 penalty
    pt = prox_l1(aux_pt - alpha * grad, alpha * lambd)

    obj[iteration] = objective(pt, A, Y) + lambd * np.linalg.norm(pt.flatten(), ord=1)

    if iteration % 100 == 0:
      print('iter {}, obj {}'.format(iteration, obj[iteration]))

    # Stopping criterion: objective has stopped changing
    if iteration > 0 and np.abs(obj[iteration-1] - obj[iteration]) < TOL:
      print('Met minimum tolerance at iter {}. Breaking.'.format(iteration))
      break

  obj = obj[:iteration + 1]  # keep every objective value actually computed
  return pt, obj

We test the code below. You can call either `pg_lasso` or `apg_lasso` and compare their convergence.

In [None]:
x_try1, obj = pg_lasso(Y/np.linalg.norm(Y, ord='fro'), meas_map, meas_map_adj, 1e-4)

In [None]:
import matplotlib.pyplot as plt
plt.loglog(obj)
plt.show()

bpl.output_notebook()
imagesc(np.real(idwt(levels, x_try1)))

### Task 2

Now we complete the evaluation tasks.

In [None]:
def trial(prob, lambd):
  # Run a single trial of the experiment with a freshly drawn measurement map.
  # Same construction as above: each Fourier coefficient is deleted with probability 1 - prob
  idxs = np.argwhere(np.random.rand(n, n) > prob)

  ## Create the operator
  levels = 2
  def meas_map(mtx):
    pre_proj = np.fft.fft2(idwt(levels, mtx), norm="ortho")
    pre_proj[idxs[:, 0], idxs[:, 1]] = 0
    return pre_proj
  def meas_map_adj(mtx):
    mtx = np.copy(mtx)  # avoid modifying the caller's array in place
    mtx[idxs[:, 0], idxs[:, 1]] = 0
    return dwt(levels, np.fft.ifft2(mtx, norm="ortho"))

  # Take measurements of the ground-truth wavelet coefficients S
  Y = meas_map(S)

  # Perform recovery on unit-norm data, then restore the original scale
  scale = np.linalg.norm(Y, ord='fro')
  S_est, obj = apg_lasso(Y/scale, meas_map, meas_map_adj, lambd)
  X_est = idwt(levels, scale * S_est).real

  # Output result
  # Note that because the wavelet transform is unitary, there is no difference
  # between comparing MSEs of S and S_est (wavelet domain) or X and X_est (spatial domain)
  bpl.output_notebook()
  imagesc(X_est)
  return np.linalg.norm(X_est - X, ord='fro')**2/n**2


In [None]:
# Run the experiments and plot the results.
p_vec = [0.1, 0.25, 0.35, 0.5, 0.6]
lambda_sched = [1e-4, 1e-4, 1e-4, 1e-4, 5e-4]
num_trials = 3

nmse_results = np.zeros((len(p_vec), num_trials))
for prob_idx in np.arange(len(p_vec)):
  prob = p_vec[prob_idx]
  lambd = lambda_sched[prob_idx]
  print('- probability of keeping: {}'.format(prob))
  for trial_idx in np.arange(num_trials):
    print('--- trial number: {}'.format(trial_idx + 1))
    nmse_results[prob_idx, trial_idx] = trial(prob, lambd)


In [None]:
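nmse = np.mean(nmse_results, axis=1)     # mean MSE over trials, for each p
nmse_std = np.std(nmse_results, axis=1)  # standard deviation over trials, for each p
plt.errorbar(p_vec, nmse, yerr=nmse_std, capsize=4)
plt.xlabel('p (probability of keeping a Fourier coefficient)')
plt.ylabel('MSE')
plt.show()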

The large variance across trials can be attributed to the random measurement map: some draws of $\Omega$ produce ill-posed problem instances. It would be possible to further improve performance by tuning the sparsifying parameter $\lambda$ individually for each setting of $p$.

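In our experiments, fine-detail features start to become discernible in the recovered image once $p$ is close to $0.5$. One way to interpret the relatively poor scaling in this application, compared with e.g. the spectrum sensing application, is that the measurement operator we implement here does not have the RIP for arbitrary sparsity levels. That operator is uniform random subsampling of Fourier coefficients, composed with the inverse wavelet transform. In particular, coarse-scale wavelets are concentrated at low frequencies, so Fourier samples and wavelets are far from incoherent there. Uniform random sampling also discards many of the low-frequency coefficients that carry most of the image's energy.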