In [1]:
import subprocess

def install(name):
    subprocess.call(['pip', 'install', name])

install('nibabel')
install('scikit-learn')

In [2]:
import numpy as np
import os
import glob
import sys
import nibabel as nib
import logging
import matplotlib.pyplot as plt

# Importing our model
MY_UTILS_PATH = "../src/"
if not MY_UTILS_PATH in sys.path:
    sys.path.append(MY_UTILS_PATH)
import cs_models_sc as fsnet
import tensorflow as tf
# Importing callbacks and data augmentation utils

from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.optimizers import  Adam
from tensorflow.keras.models import Model

from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv2D, Lambda, Add, LeakyReLU,  \
                                    MaxPooling2D, concatenate, UpSampling2D,\
                                    Multiply, ZeroPadding2D, Cropping2D,    \
                                    Concatenate

physical_devices = tf.config.experimental.list_physical_devices('GPU')
tf.config.experimental.set_memory_growth(physical_devices[0], True)

In [3]:

H,W = 512,512
w1,w2 = 0,1
string = 'iki'
channels = 2
base_model_name = "../models/flat_unrolled_cascade_iki.hdf5"
model_name = "../models/end-to-end-transfer-learning-0-1-no-dc.hdf5"
save_path = '../data/predicted/10x-transfer-0-1-no-dc'

In [4]:
test_files = np.loadtxt('../data/test.txt',dtype=str)
ref_path = '../../../data/brain-cancer/'
next_path = '../data/zero_filled_rec/10x/'
#ref_reg_path = '../data/reference_reg_10x/'
ref_reg_path = '../../../my-repos/multi-visit-isbi-2022/data/reference_reg_10x-iki/'

In [5]:
#pathes to test data
test_previous_files = [ref_reg_path + 'elastic_' + moving[:-4] + '_' + fixed for moving, fixed in zip(test_files[:,0],test_files[:,1])]
test_follow_up_files = [ref_path + file for file in test_files[:,1]]
#test_follow_up_files = [next_path + file for file in test_files[:,1]]
test_reference_files = [ref_path + file for file in test_files[:,1]]

In [6]:
#generate masks for k-space data of zero_filled reconstruction
#there may be a better way of implementing this but will try this first
# Loading sampling patterns. Notice that here we are using uncentred k-space
var_sampling_mask = np.fft.fftshift(~np.load("../data/sampling_masks/R10_512x512_poisson_center_true_radius_40_r_2.66.npy"),axes=(1,2))
var_sampling_mask = np.concatenate((var_sampling_mask[:,:,:,np.newaxis],var_sampling_mask[:,:,:,np.newaxis]),axis = -1)


In [7]:

for ii, file in enumerate(test_follow_up_files):
    img = nib.load(file)
    zero_filled_rec = np.swapaxes(img.get_fdata(),0,2)
    zero_filled_rec = zero_filled_rec / np.abs(zero_filled_rec).max()
    #convert zero filled reconstruction to kspace and undersample
    f = np.fft.fft2(zero_filled_rec)
    zero_filled_kspace = np.zeros((*zero_filled_rec.shape,2))
    zero_filled_kspace[:,:,:,0] = f.real
    zero_filled_kspace[:,:,:,1] = f.imag
    var_sampling_mask_test = np.tile(var_sampling_mask[0],(zero_filled_kspace.shape[0],1,1,1))
    print(var_sampling_mask_test.shape)
    zero_filled_kspace[:,var_sampling_mask[0]] = 0
    print(zero_filled_kspace.shape)
    
    #load previous registered reconstruction
    img2 = nib.load(test_previous_files[ii])
    previous_rec = np.swapaxes(img2.get_fdata(),0,2)[...,np.newaxis]
    previous_rec = previous_rec / np.abs(previous_rec).max()
    print(previous_rec.shape)
    #load our model
    
    
    # single-visit reconstruction model independently trained in modular fashion
    base_model = fsnet.deep_cascade_flat_unrolled(string, H, W)
    #end-to-end model inputs
    inputs = Input(shape=(H,W,channels))
    inputs2 = Input(shape=(H,W,1))
    mask = Input(shape=(H,W,channels))
    #append base model with layers for unet_block and data consistency 
    
    x = base_model([inputs, mask], training=False)
    mag = Lambda(fsnet.abs_layer)(x)
    ph = Lambda(fsnet.phase_layer)(x)
    x = tf.concat([mag, inputs2],-1)
    x = fsnet.unet_block(x, channels=1)

    out = Add()([x, mag])
    out1 = fsnet.polar2cartesian(ph, out)
    
    #out2 = fsnet.DC_block(out1, mask, inputs, channels, kspace=False)
    #out3 = Lambda(fsnet.ifft_layer)(out2)
    out4 = fsnet.abs_layer(out1)

    model = Model(inputs=[inputs,inputs2,mask], outputs=[mag,out4])
    model.load_weights(model_name)
    pred = model.predict([zero_filled_kspace, previous_rec, var_sampling_mask_test])
    
    single = pred[0][:,:,:,0]
    single = np.swapaxes(single, 0,2)
    single_nifti = nib.Nifti1Image(single, img.affine)
    
    name = test_files[ii][1][:-4]
    os.makedirs(save_path, exist_ok=True)
    nib.save(single_nifti, save_path + '/' + name +'_single_visit.nii')
    
    multi = pred[1][:,:,:,0]
    multi = np.swapaxes(multi, 0,2)
    multi_nifti = nib.Nifti1Image(multi, img.affine)
    name = test_files[ii][1][:-4]
    nib.save(multi_nifti, save_path + '/' + name +'_multi_visit.nii')

(152, 512, 512, 2)
(152, 512, 512, 2)
(152, 512, 512, 1)
(159, 512, 512, 2)
(159, 512, 512, 2)
(159, 512, 512, 1)
(179, 512, 512, 2)
(179, 512, 512, 2)
(179, 512, 512, 1)
(142, 512, 512, 2)
(142, 512, 512, 2)
(142, 512, 512, 1)
(156, 512, 512, 2)
(156, 512, 512, 2)
(156, 512, 512, 1)
(166, 512, 512, 2)
(166, 512, 512, 2)
(166, 512, 512, 1)
(165, 512, 512, 2)
(165, 512, 512, 2)
(165, 512, 512, 1)


# Compare base model weights to end-to-end model

In [8]:
model.layers[2].get_weights()

[array([[[[ 8.06608796e-02,  6.18829466e-02,  9.50664654e-02,
            5.50335785e-03,  1.12089969e-01,  5.28192706e-02,
           -1.64499640e-01, -7.68155307e-02,  2.82528698e-02,
           -6.34312406e-02,  4.68422435e-02,  1.13049522e-01,
            8.57907385e-02,  3.49574462e-02, -9.47760716e-02,
           -1.58189945e-02, -1.77190334e-01,  6.76353574e-02,
           -1.97243690e-02, -1.35127664e-01,  1.69734463e-01,
           -9.80107039e-02, -8.00574720e-02,  3.90859582e-02,
            1.08011499e-01,  8.24879408e-02,  2.02184450e-02,
           -1.73835963e-01, -1.19159497e-01,  4.23874855e-02,
           -1.38673231e-01,  1.18771136e-01, -1.31270751e-01,
            1.65880814e-01,  8.67916122e-02, -1.90241531e-01,
           -1.03974760e-01, -1.24005049e-01,  1.42652124e-01,
            3.32921743e-02, -8.45480636e-02,  7.52143264e-02,
           -1.01905331e-01,  1.36679769e-01, -5.18107116e-02,
           -1.35888070e-01,  5.07452562e-02, -6.05266802e-02],
       

In [9]:
base_model.load_weights(base_model_name)

In [10]:
base_model.layers[2].get_weights()

[array([[[[ 8.06608796e-02,  6.18829466e-02,  9.50664654e-02,
            5.50335785e-03,  1.12089969e-01,  5.28192706e-02,
           -1.64499640e-01, -7.68155307e-02,  2.82528698e-02,
           -6.34312406e-02,  4.68422435e-02,  1.13049522e-01,
            8.57907385e-02,  3.49574462e-02, -9.47760716e-02,
           -1.58189945e-02, -1.77190334e-01,  6.76353574e-02,
           -1.97243690e-02, -1.35127664e-01,  1.69734463e-01,
           -9.80107039e-02, -8.00574720e-02,  3.90859582e-02,
            1.08011499e-01,  8.24879408e-02,  2.02184450e-02,
           -1.73835963e-01, -1.19159497e-01,  4.23874855e-02,
           -1.38673231e-01,  1.18771136e-01, -1.31270751e-01,
            1.65880814e-01,  8.67916122e-02, -1.90241531e-01,
           -1.03974760e-01, -1.24005049e-01,  1.42652124e-01,
            3.32921743e-02, -8.45480636e-02,  7.52143264e-02,
           -1.01905331e-01,  1.36679769e-01, -5.18107116e-02,
           -1.35888070e-01,  5.07452562e-02, -6.05266802e-02],
       

In [11]:
model.summary()

Model: "model_13"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_33 (InputLayer)           [(None, 512, 512, 2) 0                                            
__________________________________________________________________________________________________
input_35 (InputLayer)           [(None, 512, 512, 2) 0                                            
__________________________________________________________________________________________________
model_12 (Functional)           (None, 512, 512, 2)  252438      input_33[0][0]                   
                                                                 input_35[0][0]                   
__________________________________________________________________________________________________
lambda_65 (Lambda)              (None, 512, 512, 1)  0           model_12[0][0]            

In [12]:
base_model.summary()

Model: "model_12"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_31 (InputLayer)           [(None, 512, 512, 2) 0                                            
__________________________________________________________________________________________________
lambda_60 (Lambda)              (None, 512, 512, 2)  0           input_31[0][0]                   
__________________________________________________________________________________________________
conv2d_240 (Conv2D)             (None, 512, 512, 48) 912         lambda_60[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_90 (LeakyReLU)      (None, 512, 512, 48) 0           conv2d_240[0][0]                 
___________________________________________________________________________________________

In [13]:
base_model.layers[2]

<keras.layers.convolutional.Conv2D at 0x7f2aabfdc630>

In [14]:
model.layers[2].load_weights(base_model_name)