In [1]:
import subprocess

def install(name):
    subprocess.call(['pip', 'install', name])

install('nibabel')
install('scikit-learn')

In [2]:
import numpy as np
import os
import glob
import sys
import nibabel as nib
import logging
import matplotlib.pyplot as plt

# Importing our model
MY_UTILS_PATH = "../src/"
if not MY_UTILS_PATH in sys.path:
    sys.path.append(MY_UTILS_PATH)
import cs_models_sc as fsnet
import tensorflow as tf
# Importing callbacks and data augmentation utils

from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.optimizers import  Adam
from tensorflow.keras.models import Model

physical_devices = tf.config.experimental.list_physical_devices('GPU')
tf.config.experimental.set_memory_growth(physical_devices[0], True)

In [3]:
lr = 1e-4
decay = 1e-7
H,W = 512,512
w1,w2 = 0,1
model_name = "../models/end-to-end-transfer-learning.hdf5"
save_path = '../data/predicted/10x-transfer'

In [4]:
test_files = np.loadtxt('../data/test.txt',dtype=str)
ref_path = '../../../data/brain-cancer/'
next_path = '../data/zero_filled_rec/10x/'
ref_reg_path = '../data/reference_reg_10x/'

In [5]:
#pathes to test data
test_previous_files = [ref_reg_path + 'elastic_' + moving[:-4] + '_' + fixed for moving, fixed in zip(test_files[:,0],test_files[:,1])]
test_follow_up_files = [ref_path + file for file in test_files[:,1]]
#test_follow_up_files = [next_path + file for file in test_files[:,1]]
test_reference_files = [ref_path + file for file in test_files[:,1]]

In [6]:
#generate masks for k-space data of zero_filled reconstruction
#there may be a better way of implementing this but will try this first
# Loading sampling patterns. Notice that here we are using uncentred k-space
var_sampling_mask = np.fft.fftshift(~np.load("../data/sampling_masks/R10_512x512_poisson_center_true_radius_40_r_2.66.npy"),axes=(1,2))
var_sampling_mask = np.concatenate((var_sampling_mask[:,:,:,np.newaxis],var_sampling_mask[:,:,:,np.newaxis]),axis = -1)


In [7]:

for ii, file in enumerate(test_follow_up_files):
    img = nib.load(file)
    zero_filled_rec = np.swapaxes(img.get_fdata(),0,2)
    zero_filled_rec = zero_filled_rec / np.abs(zero_filled_rec).max()
    #convert zero filled reconstruction to kspace and undersample
    f = np.fft.fft2(zero_filled_rec)
    zero_filled_kspace = np.zeros((*zero_filled_rec.shape,2))
    zero_filled_kspace[:,:,:,0] = f.real
    zero_filled_kspace[:,:,:,1] = f.imag
    var_sampling_mask_test = np.tile(var_sampling_mask[0],(zero_filled_kspace.shape[0],1,1,1))
    print(var_sampling_mask_test.shape)
    zero_filled_kspace[:,var_sampling_mask[0]] = 0
    print(zero_filled_kspace.shape)
    
    #load previous registered reconstruction
    img2 = nib.load(test_previous_files[ii])
    previous_rec = np.swapaxes(img2.get_fdata(),0,2)[...,np.newaxis]
    previous_rec = previous_rec / np.abs(previous_rec).max()
    print(previous_rec.shape)
    #load our model
    
    model = fsnet.deep_cascade_flat_unrolled_end("iki", H, W)
    opt = Adam(learning_rate = lr,decay = decay)
    model.compile(loss = ['mse','mse'],loss_weights=[w1,w2],optimizer=opt)
    model.load_weights(model_name)
    
    #model2 = Model(inputs=model.inputs, outputs=model.layers[-4])
    print(model.summary())
    
    pred = model.predict([zero_filled_kspace, previous_rec, var_sampling_mask_test])
    
    single = pred[0][:,:,:,0]
    single = np.swapaxes(single, 0,2)
    single_nifti = nib.Nifti1Image(single, img.affine)
    
    name = test_files[ii][1][:-4]
    os.makedirs(save_path, exist_ok=True)
    nib.save(single_nifti, save_path + '/' + name +'_single_visit.nii')
    
    multi = pred[1][:,:,:,0]
    multi = np.swapaxes(multi, 0,2)
    multi_nifti = nib.Nifti1Image(multi, img.affine)
    name = test_files[ii][1][:-4]
    nib.save(multi_nifti, save_path + '/' + name +'_multi_visit.nii')

(152, 512, 512, 2)
(152, 512, 512, 2)
(152, 512, 512, 1)


ValueError: You are trying to load a weight file containing 23 layers into a model with 40 layers.