The notebooks takes unseen test data and produces a multi-visit reconstruction based on the initial reconstruction (Step 01) and the previous registered scan (Step 02).

A docker container was again used to run the following notebook

In [None]:

import sys


import numpy as np
import os
import glob
import sys
import matplotlib.pyplot as plt
# Importing keras and data augmentation utils
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping
from tensorflow.keras.optimizers import  Adam
import pickle
import nibabel as nib
from keras.models import load_model
from keras.preprocessing.image import ImageDataGenerator

# Include path where reconstruction modules are defined
MY_UTILS_PATH = "../src"
if not MY_UTILS_PATH in sys.path:
    sys.path.append(MY_UTILS_PATH)

import enhancement_unet as eu

In [2]:
batch_size = 16
lr = 1e-4
decay = 1e-7

In [3]:
#model_path = '../models/enhancement_unet_sc_512x512_5x.h5'
weights_path = '../models/weights_enhancement_unet_sc_10x-iki.h5'
previous_path = '../data/reference_reg_10x-iki/'
follow_up_path = '../data/predicted/10x-iki/'
reference_path = '../../../data/brain-cancer/'
out_path = '../data/predicted/10x-enhanced-iki/'#save folder
os.makedirs(out_path, exist_ok=True)


In [1]:
all_files = np.loadtxt("../data/train_val_test_split/test_long.txt",dtype = str)
test_files = all_files

for ii in range(test_files.shape[0]):
    ref = nib.load(reference_path + all_files[ii][1])
    aux = nib.load(os.path.join(previous_path,'elastic_' + test_files[ii,0][:-4] + "_" + test_files[ii,1][:14] + ".nii")).get_fdata()
    aux2 = nib.load(os.path.join(follow_up_path,test_files[ii,1][:-4] + '_predicted.nii')).get_fdata()
    aux = np.swapaxes(aux,0,2) / np.abs(aux).max()
    aux2 = np.swapaxes(aux2,0,2) / np.abs(aux2).max()
    Z,W,H = aux.shape
    print(aux.shape, aux2.shape)
    test = np.zeros((aux.shape[0],aux.shape[1],aux.shape[2],2),dtype = np.float32)
    print(test.shape)
    test[:,:,:,0] = aux
    test[:,:,:,1] = aux2

    Hpad = (8 - (H%8))//2
    Wpad = (8 - (W%8))//2

    model = eu.enhancement_unet(H=H, W=W,Hpad = Hpad, Wpad = Wpad)
    opt = Adam(learning_rate = lr,decay = decay)
    model.compile(loss = 'mse',optimizer=opt)
    model.load_weights(weights_path)
    print(model.summary())
    pred = model.predict([test,test[:,:,:,1,np.newaxis]])
    
    pred = np.swapaxes(pred,0,2)
    #pred = np.flip(pred,axis=-1)
    out_file = os.path.join(out_path,test_files[ii,1][:-4] + '_predicted.nii')
    img = nib.Nifti1Image(pred[:,:,:,0], ref.affine)
    nib.save(img,out_file)
    print(img.shape,ref.shape)

(152, 512, 512) (152, 512, 512)
(152, 512, 512, 2)
Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 512, 512, 2) 0                                            
__________________________________________________________________________________________________
zero_padding2d (ZeroPadding2D)  (None, 520, 520, 2)  0           input_1[0][0]                    
__________________________________________________________________________________________________
conv2d (Conv2D)                 (None, 520, 520, 48) 912         zero_padding2d[0][0]             
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 520, 520, 48) 20784       conv2d[0][0]                     
___________________________________________

(512, 512, 159) (512, 512, 159)
(179, 512, 512) (179, 512, 512)
(179, 512, 512, 2)
Model: "model_2"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_5 (InputLayer)            [(None, 512, 512, 2) 0                                            
__________________________________________________________________________________________________
zero_padding2d_2 (ZeroPadding2D (None, 520, 520, 2)  0           input_5[0][0]                    
__________________________________________________________________________________________________
conv2d_44 (Conv2D)              (None, 520, 520, 48) 912         zero_padding2d_2[0][0]           
__________________________________________________________________________________________________
conv2d_45 (Conv2D)              (None, 520, 520, 48) 20784       conv2d_44[0][0]                  
_________

(512, 512, 142) (512, 512, 142)
(156, 512, 512) (156, 512, 512)
(156, 512, 512, 2)
Model: "model_4"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_9 (InputLayer)            [(None, 512, 512, 2) 0                                            
__________________________________________________________________________________________________
zero_padding2d_4 (ZeroPadding2D (None, 520, 520, 2)  0           input_9[0][0]                    
__________________________________________________________________________________________________
conv2d_88 (Conv2D)              (None, 520, 520, 48) 912         zero_padding2d_4[0][0]           
__________________________________________________________________________________________________
conv2d_89 (Conv2D)              (None, 520, 520, 48) 20784       conv2d_88[0][0]                  
_________

(512, 512, 166) (512, 512, 166)
(165, 512, 512) (165, 512, 512)
(165, 512, 512, 2)
Model: "model_6"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_13 (InputLayer)           [(None, 512, 512, 2) 0                                            
__________________________________________________________________________________________________
zero_padding2d_6 (ZeroPadding2D (None, 520, 520, 2)  0           input_13[0][0]                   
__________________________________________________________________________________________________
conv2d_132 (Conv2D)             (None, 520, 520, 48) 912         zero_padding2d_6[0][0]           
__________________________________________________________________________________________________
conv2d_133 (Conv2D)             (None, 520, 520, 48) 20784       conv2d_132[0][0]                 
_________