In [1]:
%matplotlib inline
import numpy as np
import matplotlib.pylab as plt
import os
import glob
import sys
import natsort
from tensorflow.keras.optimizers import  Adam
import nibabel as nib
import h5py
import tensorflow as tf

# Make the project's helper modules (e.g. cs_models_mc) importable.
# The path is relative, so it is resolved against the kernel's working directory.
MY_UTILS_PATH = "../Modules/"
if MY_UTILS_PATH not in sys.path:
    sys.path.append(MY_UTILS_PATH)

# Import my modules
import cs_models_mc as fsnet

# Let TensorFlow allocate GPU memory on demand instead of reserving it all up front.
# Applied to every visible GPU: TF requires the memory-growth setting to match across
# GPUs, and on a CPU-only machine the list is empty, so indexing [0] would raise IndexError.
physical_devices = tf.config.experimental.list_physical_devices('GPU')
for gpu in physical_devices:
    tf.config.experimental.set_memory_growth(gpu, True)

In [2]:
# ---- Data & output locations ----------------------------------------------
test_path = "../../Dataset/Test-R=10/"                    # folder scanned for *.h5 test volumes
results_path = "../WWnet-ikik-baseline/12-channel-R=10/"  # reconstructions are written here
weights_path = "../Models/weights_wwnet_ikik_mc_r10.h5"   # trained weights loaded into the model

# ---- Model definition --------------------------------------------------------
cascade = "unet"       # model builder to use: 'unet' or 'flat'
model_string = "ikik"  # passed to the builder as depth_str
channels = 24          # model channel count; presumably 12 coils x (real, imag) -- TODO confirm

# ---- Inference -------------------------------------------------------------------
batch_size = 8
crop = (50, -50)       # slices to crop: keep kspace[50:-50] along the first axis
verbose = True

In [3]:
def _kspace_plane_size(path):
    """Return kx * ky, i.e. the product of axes 1 and 2 of the file's 'kspace' dataset."""
    with h5py.File(path, 'r') as f:
        shape = f['kspace'].shape
    return shape[1]*shape[2]

# Collect the test volumes in a deterministic order: glob's order depends on the
# filesystem. os.path.join also works when test_path lacks a trailing slash.
test_files = np.asarray(sorted(glob.glob(os.path.join(test_path, "*.h5"))))

# Sort by kx - ky size. A stable sort keeps the name order among equal sizes, so the
# processing order is reproducible and files with equal kx*ky end up adjacent.
files_sizes = np.asarray([_kspace_plane_size(p) for p in test_files])
indexes = np.argsort(files_sizes, kind="stable")
test_files = test_files[indexes]

# Print a summary of the run configuration.
if verbose:
    print(cascade)
    print("Domains: ", model_string)
    print("Number of files:", len(test_files))
    # When no .h5 files were found, test_files[0] would raise IndexError.
    if len(test_files) > 0:
        print(test_files[0])
    print("Weights path:", weights_path)
    print("Test path:", test_path)
    print("results path:", results_path)
    
# Fixed normalisation constant: sqrt of a 218 x 170 in-plane matrix size.
# k-space is divided by it before inference, and predictions are multiplied back by it afterwards.
norm = np.sqrt(218*170)

# (H, W) that the currently loaded model was built for; None until the first build.
# The builders take explicit H/W, so a model is only reused for files of exactly
# that in-plane size and is rebuilt whenever the size changes.
model_shape = None

# h5py does not create missing parent directories when opening a file for writing.
os.makedirs(results_path, exist_ok=True)

for ii in test_files:
    name = os.path.basename(ii)

    # Load data: read only the cropped slice range from disk, not the whole volume.
    with h5py.File(ii, 'r') as f:
        kspace_test = f['kspace'][crop[0]:crop[1]].astype(np.float32)
    kspace_test = kspace_test/norm

    _, H, W, _ = kspace_test.shape

    # Per-side padding passed to the U-Net builder. When (8 - H % 8) is even,
    # H + 2*Hpad is a multiple of 8. Note that it still pads by 4 when H is
    # already a multiple of 8.
    Wpad = (8 - W % 8)//2
    Hpad = (8 - H % 8)//2

    if model_shape != (H, W):
        if cascade == 'unet':
            model = fsnet.deep_cascade_unet(depth_str=model_string, H=H, W=W, Hpad=Hpad, Wpad=Wpad, channels=channels)
        elif cascade == 'flat':
            model = fsnet.deep_cascade_flat_unrolled(depth_str=model_string, H=H, W=W, depth=14, kshape=(3, 3), nf=116, channels=channels)
        else:
            # Without this branch, `model` would be undefined (or stale) below.
            raise ValueError(f"Unknown cascade type {cascade!r}; expected 'unet' or 'flat'")
        model.compile(loss='mse', optimizer=Adam())
        model.load_weights(weights_path)
        model_shape = (H, W)

    # 1.0 where the normalised k-space entry is exactly zero, 0.0 elsewhere.
    # Presumably this marks the samples that were not acquired -- TODO confirm with fsnet.
    var_sampling_mask = (kspace_test == 0).astype(np.float32)
    pred = model.predict([kspace_test, var_sampling_mask], batch_size=batch_size)*norm
    # Combine even-indexed channels (real part) with odd-indexed channels (imaginary part).
    pred = pred[:, :, :, ::2] + 1j*pred[:, :, :, 1::2]
    # Root sum of squares over the last (channel) axis -> one magnitude image per slice.
    pred = np.sqrt((np.abs(pred)**2).sum(axis=-1))

    out_path = os.path.join(results_path, name)
    if verbose:
        print(out_path)
    with h5py.File(out_path, 'w') as hf:
        hf.create_dataset('reconstruction', data=pred)

unet
Domains:  ikik
Number of files: 50
../../Dataset/Test-R=10/e15275s3_P27648.7.h5
Weights path:
 ../Models/weights_wwnet_ikik_mc_r10.h5
Test path:
 ../../Dataset/Test-R=10/
results path:
 ../WWnet-ikik-baseline/Track01/12-channel-R=10/
../WWnet-ikik-baseline/Track01/12-channel-R=10/e15275s3_P27648.7.h5
../WWnet-ikik-baseline/Track01/12-channel-R=10/e13991s3_P01536.7.h5
../WWnet-ikik-baseline/Track01/12-channel-R=10/e15865s13_P62464.7.h5
../WWnet-ikik-baseline/Track01/12-channel-R=10/e14736s3_P55296.7.h5
../WWnet-ikik-baseline/Track01/12-channel-R=10/e14080s3_P18944.7.h5
../WWnet-ikik-baseline/Track01/12-channel-R=10/e15493s3_P16896.7.h5
../WWnet-ikik-baseline/Track01/12-channel-R=10/e15781s13_P96256.7.h5
../WWnet-ikik-baseline/Track01/12-channel-R=10/e15866s13_P72192.7.h5
../WWnet-ikik-baseline/Track01/12-channel-R=10/e14618s3_P51200.7.h5
../WWnet-ikik-baseline/Track01/12-channel-R=10/e15272s3_P07680.7.h5
../WWnet-ikik-baseline/Track01/12-channel-R=10/e14781s3_P18944.7.h5
../WWnet-i