This notebook demonstrates testing of the previously trained single-channel model (from Step 01-1) at an acceleration factor of 10x. A subset of 18 subject scans, each with a corresponding fully-sampled previous scan, is reserved for testing.

In [2]:
import subprocess

def install(name, check=False):
    """Install a package with pip into the interpreter running this kernel.

    Runs ``sys.executable -m pip install <name>`` rather than a bare ``pip``:
    the ``pip`` found on PATH may belong to a different Python installation
    than the notebook kernel, so the package would not be importable here.

    Parameters
    ----------
    name : str
        Requirement specifier handed to ``pip install`` (e.g. ``'nibabel'``
        or a pinned ``'nibabel==5.2.1'``).
    check : bool, optional
        If True, raise ``subprocess.CalledProcessError`` when pip exits with
        a non-zero status. Defaults to False (status is only returned).

    Returns
    -------
    int
        Exit status of the pip process (0 on success).
    """
    import sys  # local import: this cell runs before the main import cell

    cmd = [sys.executable, "-m", "pip", "install", name]
    returncode = subprocess.call(cmd)
    if check and returncode != 0:
        raise subprocess.CalledProcessError(returncode, cmd)
    return returncode

# Install the third-party packages this notebook needs (versions are not pinned).
# NOTE(review): scikit-learn is not imported in any cell shown here — confirm it is still required.
for package in ('nibabel', 'scikit-learn'):
    install(package)

In [1]:
import numpy as np
import os
import glob
import sys
import nibabel as nib
import logging
import matplotlib.pyplot as plt

# Importing our model
# Put the project's src/ directory on the import path (added at most once), so
# project-local modules such as cs_models_sc can be imported; presumably that is
# where they live — the path is relative to the notebook's working directory.
MY_UTILS_PATH = "../src/"
if MY_UTILS_PATH not in sys.path:
    sys.path.append(MY_UTILS_PATH)
import cs_models_sc as fsnet
import tensorflow as tf
# Importing callbacks and data augmentation utils

from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.optimizers import  Adam

# Let TensorFlow grow GPU memory on demand instead of reserving it all up front.
# The setting is applied to every visible GPU (TensorFlow expects it to be the
# same across GPUs); on a CPU-only machine the list is empty and this is a
# no-op rather than an IndexError.
physical_devices = tf.config.experimental.list_physical_devices('GPU')
for gpu in physical_devices:
    tf.config.experimental.set_memory_growth(gpu, True)

In [2]:
## PARAMETERS
H = W = 512  # training image dimensions (rows, cols)
channels = 2  # complex data as two channels: 0 -> real, 1 -> imaginary
norm = np.sqrt(H * W)  # square root of the number of pixels per image
# NOTE(review): `channels` and `norm` are not referenced in the cells shown here.

In [3]:
# Location of the brain-cancer data set; change this path to where the downloaded data lives.
data_path = '../../../data/brain-cancer/'
# Define the paths to the test files (one file name per line in the split list).
# ndmin=1 keeps a single-entry list iterable: by default np.loadtxt squeezes a
# one-line file down to a 0-d array, which cannot be looped over.
test_files = np.loadtxt('../data/train_val_test_split/test_initial.txt', dtype=str, ndmin=1)
rec_files_test = [os.path.join(data_path, fname) for fname in test_files]

In [4]:
# Load the R10 Poisson-disc sampling pattern. `~` inverts it so that True marks
# k-space locations that are NOT acquired (the test loop zeroes those entries),
# and fftshift over axes (1, 2) matches it to the uncentred k-space produced by
# np.fft.fft2.
# NOTE(review): assumes the .npy holds a boolean (num_patterns, H, W) array — confirm.
sampling_pattern = np.load("../data/sampling_masks/R10_512x512_poisson_center_true_radius_40_r_2.66.npy")
unsampled = np.fft.fftshift(~sampling_pattern, axes=(1, 2))
# Duplicate the mask for the real and imaginary channels, then keep the first pattern.
mask_4d = unsampled[:, :, :, np.newaxis]
var_sampling_mask = np.concatenate((mask_4d, mask_4d), axis=-1)[0]

# Fraction of k-space locations flagged True, i.e. discarded (about 0.9 for R=10).
print("Undersampling:", var_sampling_mask.mean())

Undersampling: 0.89959716796875


In [5]:
# Load the previously trained network (Step 01-1) for inference; nothing is trained here.
model_name = "../models/flat_unrolled_cascade_iki.hdf5"
model = fsnet.deep_cascade_flat_unrolled("iki", H, W)
# Compilation is not needed for predict(); it is kept as before. The
# learning-rate `decay` argument was dropped: it only affects training, and the
# Keras optimizers shipped with TensorFlow 2.11+ reject it with a ValueError.
opt = Adam(learning_rate=1e-3)
model.compile(loss='mse', optimizer=opt)
model.load_weights(model_name)

In [6]:
def _strip_nifti_ext(filename):
    """Return ``filename`` without its NIfTI extension.

    '.nii.gz' is removed as a whole; any other name keeps the original
    behaviour of dropping its last four characters (i.e. '.nii').
    """
    if filename.endswith('.nii.gz'):
        return filename[:-len('.nii.gz')]
    return filename[:-4]


# Reconstruct every test volume from retrospectively undersampled k-space and
# save the network output as NIfTI.
save_path = '../data/predicted/10x-iki'
os.makedirs(save_path, exist_ok=True)

for rec_file in rec_files_test:
    nib_file = nib.load(rec_file)
    # Swap axes 0 and 2 so slices are stacked along axis 0 and fft2 (which acts
    # on the last two axes) transforms each slice in-plane.
    # NOTE(review): assumes 3-D volumes — confirm for the whole test set.
    rec_test = np.swapaxes(nib_file.get_fdata(), 0, 2)
    peak = np.abs(rec_test).max()
    if peak > 0:  # an all-zero volume would otherwise become 0/0 = NaN
        rec_test = rec_test / peak

    # Uncentred k-space of each slice split into real/imaginary channels:
    # shape (slices, rows, cols, 2).
    kspace_full = np.fft.fft2(rec_test)
    kspace_test = np.stack((kspace_full.real, kspace_full.imag), axis=-1)
    if kspace_test.shape[1:] != var_sampling_mask.shape:
        raise ValueError(
            f"{rec_file}: per-slice k-space shape {kspace_test.shape[1:]} does not "
            f"match the sampling mask shape {var_sampling_mask.shape}")

    # Per-slice copy of the mask for the network, then zero every location the
    # mask flags True (retrospective undersampling).
    var_sampling_mask_test = np.tile(var_sampling_mask, (kspace_test.shape[0], 1, 1, 1))
    kspace_test[:, var_sampling_mask] = 0

    pred = model.predict([kspace_test, var_sampling_mask_test])
    # Magnitude image from the real/imaginary output channels, axes restored.
    rec_pred = np.abs(pred[:, :, :, 0] + 1j * pred[:, :, :, 1])
    rec_pred = np.swapaxes(rec_pred, 0, 2)
    pred_nifti = nib.Nifti1Image(rec_pred, nib_file.affine)

    name = _strip_nifti_ext(os.path.basename(rec_file))
    nib.save(pred_nifti, os.path.join(save_path, name + '_predicted.nii'))