In [1]:
%cd /home/pac/gits/phd/mri-online/
%matplotlib widget

/home/pac/gits/phd/mri-online


In [2]:
import os
import time
import numpy as np
import copy
import matplotlib.pyplot as plt
import scipy as sp
from mri.operators import FFT,WaveletN, OWL
from modopt.opt.linear import Identity
from modopt.opt.proximity import GroupLASSO, IdentityProx
from online.operators.proximity import LASSO
from online.generators import Column2DKspaceGenerator,  DataOnlyKspaceGenerator, KspaceGenerator, PartialColumn2DKspaceGenerator
from online.reconstructors import OnlineReconstructor
from online.operators.fourier import ColumnFFT
from project_utils import implot, load_data, create_cartesian_metrics
from online.metrics import ssos, psnr_ssos,ssim_ssos,mse_ssos

plt.rcParams['axes.formatter.useoffset'] = False

plt.style.use({'figure.facecolor':'white'})



In [3]:
def plot_metric(results, name, *args, log=False, ax=None,**kwargs):
    if ax == None:
        ax = plt.gca()
    if log:
        ax.semilogy(results['metrics'][name]['index'], results['metrics'][name]['values'],*args,**kwargs,label=name)
    else:
        ax.plot(results['metrics'][name]['index'], results['metrics'][name]['values'],*args, **kwargs,label=name)


In [4]:
DATA_DIR = "data/"
N_JOBS = -1
results = dict()


full_k, real_img, mask_loc, final_mask = load_data(DATA_DIR, 2, monocoil=False)
final_k = np.squeeze(full_k * final_mask[np.newaxis])
square_mask= np.zeros(real_img.shape)
real_img_size = real_img.shape
img_size = [min(real_img.shape)]*2
square_mask[real_img_size[0] // 2 - img_size[0] // 2:real_img_size[0] // 2 + img_size[0] // 2,
            real_img_size[1] // 2 - img_size[1] // 2:real_img_size[1] // 2 + img_size[1] // 2] = 1

# Type II reconstruction

In [10]:
full_k, real_img, mask_loc, final_mask = load_data(DATA_DIR, 2, monocoil=False)
final_k = np.squeeze(full_k * final_mask[np.newaxis])

line_kspace_gen = DataOnlyKspaceGenerator(full_kspace=final_k, mask_cols=mask_loc)
kspace_gen = PartialColumn2DKspaceGenerator(full_kspace=final_k, mask_cols=mask_loc)

K_DIM = line_kspace_gen.shape[-2:]
N_COILS = line_kspace_gen.shape[0] if full_k.ndim == 3 else 1
#line_fourier_op = ColumnFFT(shape=K_DIM, mask=final_mask, n_coils=N_COILS)
line_fourier_op = ColumnFFT(shape=K_DIM, n_coils=N_COILS)
fourier_op = FFT(shape=K_DIM,n_coils=N_COILS, mask=final_mask)


In [11]:
linear_op = WaveletN("sym8", nb_scale=4, n_coils=N_COILS, n_jobs=N_JOBS)
# initialisation of wavelet transform
linear_op.op(np.zeros_like(final_k))

# Regularizer definition
# Regularizer definition
if N_COILS ==1:
    GL_op = LASSO(weights=2e-6)
else:
    GL_op = GroupLASSO(weights=2e-6)
    
OWL_op = OWL(alpha=1e-05,
             beta=1e-12,
             bands_shape=linear_op.coeffs_shape,
             mode='band_based',
             n_coils=N_COILS,
             n_jobs=N_JOBS)
cost_op_kwargs = {'cost_interval': 1}

In [15]:
online_pb = OnlineReconstructor(line_fourier_op,
                                Identity(),
                                regularizer_op=IdentityProx(),
#                                 linear_op,
#                                 regularizer_op=OWL_op,
                                opt='vanilla',
                                verbose=0)

metrics_config = create_cartesian_metrics(online_pb, real_img, final_mask, final_k)
metrics_config['metrics']['grad'] = {'metric': lambda x,y: np.sqrt(np.sum(np.square(abs(y*x)))),
                               'mapping': {'dir_grad': 'x',
                                           'speed_grad':'y'},
                               'early_stopping': False,
                               'cst_kwargs': dict(),
                               }

# metrics_config['metrics'].pop('ssim')
# metrics_config['metrics'].pop('psnr')

results = online_pb.reconstruct(line_kspace_gen,
                                eta=1.,
                                beta=1.,
#                                 eta_update= lambda eta, idx: eta*1.001,
                       #         **metrics_config,
                                epsilon=1e-8,
                                nb_run=1,
                                epoch_size=1,
                                )
x=ssos(results['x_final'])

psnr = psnr_ssos(x,real_img,mask=square_mask)
ssim = ssim_ssos(x,real_img,mask=square_mask)
implot(x,title=f"PSNR = {psnr:.2f} dB, ssim={ssim:.3f}",mask=square_mask, colorbar=True);
                                  
                                  

  0%|          | 0/80 [00:00<?, ?it/s]

vanilla


100%|██████████| 80/80 [00:22<00:00,  3.61it/s]


Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [16]:

plt.figure()
plt.plot(abs(real_img[320]),label='ref')
plt.plot(abs(ssos(x)[320]),label='xf')
plt.legend()
plt.figure()
plot_metric(results,'data_res_off',log=True)
plot_metric(results,'data_res_on',log=True)
plot_metric(results,'grad','--',log=True)
plot_metric(results, 'reg_res',log=True)
plt.legend()
fig, (ax1, ax2) = plt.subplots(2, 1)
plot_metric(results,'psnr',ax=ax1)
plot_metric(results,'ssim',ax=ax2)


Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

KeyError: 'data_res_off'