The notebooks takes unseen test data and produces a multi-visit reconstruction based on the initial reconstruction (Step 01) and the previous registered scan (Step 02).

A docker container was again used to run the following notebook

In [1]:
import subprocess

def install(name):
    subprocess.call(['pip', 'install', name])

install('nibabel')
install('scikit-learn')

import sys


import numpy as np
import os
import glob
import sys
import matplotlib.pyplot as plt
# Importing keras and data augmentation utils
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping
from tensorflow.keras.optimizers import  Adam
import pickle
import nibabel as nib
from keras.models import load_model
from keras.preprocessing.image import ImageDataGenerator

# Include path where reconstruction modules are defined
MY_UTILS_PATH = "../src"
if not MY_UTILS_PATH in sys.path:
    sys.path.append(MY_UTILS_PATH)

import enhancement_unet as eu

In [40]:
batch_size = 16
lr = 1e-4
decay = 1e-7

In [82]:
#model_path = '../models/enhancement_unet_sc_512x512_5x.h5'
weights_path = '../models/weights_enhancement_unet_sc_10x-recurrentvar-rnl-3.h5'
previous_path = '../data/reference_reg_10x-varnet/'
follow_up_path = '../data/predicted/10x-varnet/nib/'
reference_path = '../../../data/brain-cancer/'
out_path = '../data/predicted/10x-enhanced-varnet/'#save folder
os.makedirs(out_path, exist_ok=True)


In [97]:
test_files = np.loadtxt("../data/test_long.txt",dtype = str)
print(test_files)

[['TUM04-20170309.nii' 'TUM04-20171108.nii' 'TUM04-20180103.nii']
 ['TUM04-20171108.nii' 'TUM04-20190328.nii' 'TUM04-20190523.nii']
 ['TUM04-20180618.nii' 'TUM04-20181009.nii' 'TUM04-20190328.nii']
 ['TUM20-20170928.nii' 'TUM20-20180205.nii' 'TUM20-20180402.nii']
 ['TUM10-20170316.nii' 'TUM10-20171018.nii' 'TUM10-20180122.nii']
 ['TUM10-20171018.nii' 'TUM10-20180122.nii' 'TUM10-20180307.nii']
 ['TUM15-20170531.nii' 'TUM15-20170801.nii' 'TUM15-20170816.nii']]


In [98]:
test_files = test_files[:,1:]

In [99]:
test_files

array([['TUM04-20171108.nii', 'TUM04-20180103.nii'],
       ['TUM04-20190328.nii', 'TUM04-20190523.nii'],
       ['TUM04-20181009.nii', 'TUM04-20190328.nii'],
       ['TUM20-20180205.nii', 'TUM20-20180402.nii'],
       ['TUM10-20171018.nii', 'TUM10-20180122.nii'],
       ['TUM10-20180122.nii', 'TUM10-20180307.nii'],
       ['TUM15-20170801.nii', 'TUM15-20170816.nii']], dtype='<U18')

In [100]:
count = 0
for ii in range(test_files.shape[0]):
    ref = nib.load(reference_path + test_files[ii][1])
    aux = nib.load(os.path.join(previous_path,'elastic_' + test_files[ii,0][:-4] + "_" + test_files[ii,1][:14] +"-2.nii")).get_fdata()
    aux2 = nib.load(os.path.join(follow_up_path, test_files[ii,1][:-4] + '_predicted.nii')).get_fdata()
    aux = np.swapaxes(aux,0,2) / np.abs(aux).max()
    aux2 = np.swapaxes(aux2,0,2) / np.abs(aux2).max()
    Z,W,H = aux.shape
    count = count + aux.shape[0] - 40
    print(aux.shape, aux2.shape)
    

(160, 512, 512) (160, 512, 512)
(163, 512, 512) (163, 512, 512)
(152, 512, 512) (152, 512, 512)
(179, 512, 512) (179, 512, 512)
(142, 512, 512) (142, 512, 512)
(156, 512, 512) (156, 512, 512)
(165, 512, 512) (165, 512, 512)


In [101]:
count


837

In [102]:


for ii in range(test_files.shape[0]):
    ref = nib.load(reference_path + test_files[ii][1])
    aux = nib.load(os.path.join(previous_path,'elastic_' + test_files[ii,0][:-4] + "_" + test_files[ii,1][:14]  +"-2.nii")).get_fdata()
    aux2 = nib.load(os.path.join(follow_up_path, test_files[ii,1][:-4] + '_predicted.nii')).get_fdata()
    aux = np.swapaxes(aux,0,2) / np.abs(aux).max()
    aux2 = np.swapaxes(aux2,0,2) / np.abs(aux2).max()
    Z,W,H = aux.shape
    print(aux.shape, aux2.shape)
    test = np.zeros((aux.shape[0],aux.shape[1],aux.shape[2],2),dtype = np.float32)
    print(test.shape)
    test[:,:,:,0] = aux
    test[:,:,:,1] = aux2

    Hpad = (8 - (H%8))//2
    Wpad = (8 - (W%8))//2

    model = eu.enhancement_unet(H=H, W=W,Hpad = Hpad, Wpad = Wpad)
    opt = Adam(learning_rate = lr,decay = decay)
    model.compile(loss = 'mse',optimizer=opt)
    model.load_weights(weights_path)
    print(model.summary())
    pred = model.predict([test,test[:,:,:,1,np.newaxis]])
    
    pred = np.swapaxes(pred,0,2)
    #pred = np.flip(pred,axis=-1)
    out_file = os.path.join(out_path,test_files[ii,1][:-4] + '_predicted-2.nii')
    img = nib.Nifti1Image(pred[:,:,:,0], ref.affine)
    nib.save(img,out_file)
    print(img.shape,ref.shape)

(160, 512, 512) (160, 512, 512)
(160, 512, 512, 2)
Model: "model_61"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_123 (InputLayer)         [(None, 512, 512, 2  0           []                               
                                )]                                                                
                                                                                                  
 zero_padding2d_61 (ZeroPadding  (None, 520, 520, 2)  0          ['input_123[0][0]']              
 2D)                                                                                              
                                                                                                  
 conv2d_1342 (Conv2D)           (None, 520, 520, 48  912         ['zero_padding2d_61[0][0]']      
                                )       

                                                                                                  
 concatenate_185 (Concatenate)  (None, 520, 520, 11  0           ['up_sampling2d_185[0][0]',      
                                2)                                'conv2d_1344[0][0]']            
                                                                                                  
 conv2d_1360 (Conv2D)           (None, 520, 520, 48  48432       ['concatenate_185[0][0]']        
                                )                                                                 
                                                                                                  
 conv2d_1361 (Conv2D)           (None, 520, 520, 48  20784       ['conv2d_1360[0][0]']            
                                )                                                                 
                                                                                                  
 conv2d_13

                                                                                                  
 conv2d_1376 (Conv2D)           (None, 130, 130, 12  442496      ['concatenate_186[0][0]']        
                                8)                                                                
                                                                                                  
 conv2d_1377 (Conv2D)           (None, 130, 130, 12  147584      ['conv2d_1376[0][0]']            
                                8)                                                                
                                                                                                  
 conv2d_1378 (Conv2D)           (None, 130, 130, 12  147584      ['conv2d_1377[0][0]']            
                                8)                                                                
                                                                                                  
 up_sampli

                                                                                                  
 max_pooling2d_190 (MaxPooling2  (None, 130, 130, 64  0          ['conv2d_1391[0][0]']            
 D)                             )                                                                 
                                                                                                  
 conv2d_1392 (Conv2D)           (None, 130, 130, 12  73856       ['max_pooling2d_190[0][0]']      
                                8)                                                                
                                                                                                  
 conv2d_1393 (Conv2D)           (None, 130, 130, 12  147584      ['conv2d_1392[0][0]']            
                                8)                                                                
                                                                                                  
 conv2d_13

 input_129 (InputLayer)         [(None, 512, 512, 2  0           []                               
                                )]                                                                
                                                                                                  
 zero_padding2d_64 (ZeroPadding  (None, 520, 520, 2)  0          ['input_129[0][0]']              
 2D)                                                                                              
                                                                                                  
 conv2d_1408 (Conv2D)           (None, 520, 520, 48  912         ['zero_padding2d_64[0][0]']      
                                )                                                                 
                                                                                                  
 conv2d_1409 (Conv2D)           (None, 520, 520, 48  20784       ['conv2d_1408[0][0]']            
          

                                2)                                'conv2d_1410[0][0]']            
                                                                                                  
 conv2d_1426 (Conv2D)           (None, 520, 520, 48  48432       ['concatenate_194[0][0]']        
                                )                                                                 
                                                                                                  
 conv2d_1427 (Conv2D)           (None, 520, 520, 48  20784       ['conv2d_1426[0][0]']            
                                )                                                                 
                                                                                                  
 conv2d_1428 (Conv2D)           (None, 520, 520, 48  20784       ['conv2d_1427[0][0]']            
                                )                                                                 
          

                                8)                                                                
                                                                                                  
 conv2d_1443 (Conv2D)           (None, 130, 130, 12  147584      ['conv2d_1442[0][0]']            
                                8)                                                                
                                                                                                  
 conv2d_1444 (Conv2D)           (None, 130, 130, 12  147584      ['conv2d_1443[0][0]']            
                                8)                                                                
                                                                                                  
 up_sampling2d_196 (UpSampling2  (None, 260, 260, 12  0          ['conv2d_1444[0][0]']            
 D)                             8)                                                                
          

 D)                             )                                                                 
                                                                                                  
 conv2d_1458 (Conv2D)           (None, 130, 130, 12  73856       ['max_pooling2d_199[0][0]']      
                                8)                                                                
                                                                                                  
 conv2d_1459 (Conv2D)           (None, 130, 130, 12  147584      ['conv2d_1458[0][0]']            
                                8)                                                                
                                                                                                  
 conv2d_1460 (Conv2D)           (None, 130, 130, 12  147584      ['conv2d_1459[0][0]']            
                                8)                                                                
          

                                )]                                                                
                                                                                                  
 zero_padding2d_67 (ZeroPadding  (None, 520, 520, 2)  0          ['input_135[0][0]']              
 2D)                                                                                              
                                                                                                  
 conv2d_1474 (Conv2D)           (None, 520, 520, 48  912         ['zero_padding2d_67[0][0]']      
                                )                                                                 
                                                                                                  
 conv2d_1475 (Conv2D)           (None, 520, 520, 48  20784       ['conv2d_1474[0][0]']            
                                )                                                                 
          

 conv2d_1492 (Conv2D)           (None, 520, 520, 48  48432       ['concatenate_203[0][0]']        
                                )                                                                 
                                                                                                  
 conv2d_1493 (Conv2D)           (None, 520, 520, 48  20784       ['conv2d_1492[0][0]']            
                                )                                                                 
                                                                                                  
 conv2d_1494 (Conv2D)           (None, 520, 520, 48  20784       ['conv2d_1493[0][0]']            
                                )                                                                 
                                                                                                  
 conv2d_1495 (Conv2D)           (None, 520, 520, 1)  49          ['conv2d_1494[0][0]']            
          