## Data Preparation

Prepare the following two items before running this step (see the `example_data` folder for examples):

1. **NIfTI images** of the 4DCT scans, resampled to an isotropic voxel size of **1.5 × 1.5 × 1.5 mm³**.  
   - Each cardiac phase should be saved as a separate file.  
   - All files should be placed in a folder named:  
     `img-nii-resampled-1.5mm`.

2. **A patient list** that enumerates all your cases.  
   - To understand the expected format, please refer to the file:  
     `example_data/Patient_lists/example_data/patient_list.xlsx`.
   - Make sure the number of time frames listed for each case equals the number of NIfTI files in that case's `img-nii-resampled-1.5mm` folder.


---

## Get reference MVF for each 4DCT case

In this notebook, we use VoxelMorph to obtain a **motion vector field (MVF)** for each cardiac phase (time frame), using the first phase (end-diastole) as the template.

---

### Docker environment
1. Use `docker/docker_tensorflow`, which builds a TensorFlow Docker image.
2. Make sure `voxelmorph` (and its dependency `neurite`, imported below) is installed.


In [1]:
# imports
import os, sys
sys.path.append('/workspace/Documents')

# third party imports
import numpy as np 
import tensorflow as tf

import voxelmorph as vxm
import neurite as ne
import pandas as pd
import random
import nibabel as nb

from tensorflow.keras.utils import Sequence
import Cardiac4DCT_Synth_Diffusion.Build_lists.Build_list as Build_list
import Cardiac4DCT_Synth_Diffusion.functions_collection as ff
import Cardiac4DCT_Synth_Diffusion.Data_processing as Data_processing
import Cardiac4DCT_Synth_Diffusion.Generator_voxelmorph as Generator_voxelmorph
# Root data folder; every input, model and output path below is built under it.
main_path = "/mnt/camca_NAS/4DCT/"

2025-06-27 14:51:09.872624: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudart.so.11.0


For each case, a VoxelMorph model is optimized (trained) individually. The trained model is then applied to that case to obtain its MVFs.

In [2]:
# Experiment name: becomes a sub-folder of example_data/models.
trial_name = "voxel_morph_warp0_onecase"

# Registration direction:
#   'others' -> time frame 0 is warped to each of the other time frames
#   '0'      -> each of the other time frames is warped to time frame 0
which_timeframe_is_template = "others"

## Train a model for each case until convergence

In [4]:
######## Training ########
# Train one VoxelMorph model per case (each case is optimized individually).
# A checkpoint `models/vxm_<epoch>.h5` and the loss log are written every 10 epochs.
# `models/vxm_final.h5` is written once a stopping rule fires.
# Cases that already have `vxm_final.h5` are skipped; delete that file to re-train a case.

## set patient list
data_sheet = os.path.join(main_path, 'example_data/Patient_lists/example_data/patient_list.xlsx')

b = Build_list.Build(data_sheet)
patient_class_list, patient_id_list, tf_list = b.__build__(batch_list = [0])

# Iterate over EVERY case. Previously only the first one was trained,
# while the testing cell expects a final model for each case.
for patient_class, patient_id, tf_num in zip(patient_class_list, patient_id_list, tf_list):
    print(patient_class, patient_id, tf_num)

    # set save path: <main_path>/example_data/models/<trial>/individual_models/<class>/<id>
    model_root = os.path.join(main_path, 'example_data/models')
    individual_root = os.path.join(model_root, trial_name, 'individual_models')
    save_path = os.path.join(individual_root, patient_class, patient_id)
    final_model_file = os.path.join(save_path, 'models/vxm_final.h5')
    # create every level of the output tree, parent first
    ff.make_folder([model_root,
                    os.path.join(model_root, trial_name),
                    individual_root,
                    os.path.join(individual_root, patient_class),
                    save_path,
                    os.path.join(save_path, 'models'),
                    os.path.join(save_path, 'logs')])

    # check whether the patient has been processed
    if os.path.isfile(final_model_file):
        print('patient:', patient_id, 'has been processed')
        continue

    ## build the model
    input_shape = [160,160,96]
    nb_features = [[16, 32, 32, 32],[32, 32, 32, 32, 32, 16, 16]]
    vxm_model = vxm.networks.VxmDense(input_shape, nb_features, int_steps=0)
    # one loss per model output: image MSE on the moved image, L2 gradient penalty on the flow
    losses = [vxm.losses.MSE().loss, vxm.losses.Grad('l2').loss]
    loss_weights = [1, 0.05]
    vxm_model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4), loss=losses, loss_weights=loss_weights)

    ## set the generator (single case, all time frames)
    train_generator = Generator_voxelmorph.DataGenerator_alltf(
            np.asarray([patient_class]),
            np.asarray([patient_id]),
            np.asarray([tf_num]),
            which_timeframe_is_template = which_timeframe_is_template,
            main_path = main_path,  # was a hard-coded copy of main_path
            patient_num = 1,
            batch_size = 1,
            shuffle = False,
            normalize = True,
            adapt_shape = input_shape,)

    # Resume from the latest epoch checkpoint, if any.
    # The final model also matches 'vxm_*', so exclude it by file name.
    pre_model_list = np.asarray([f for f in ff.find_all_target_files(['vxm_*'], os.path.join(save_path, 'models'))
                                 if os.path.basename(f) != 'vxm_final.h5'])

    if len(pre_model_list) == 0:
        start_epoch = 0
        print('no pre-trained model')
    else:
        pre_model_list = ff.sort_timeframe(pre_model_list,1,'_')
        pre_model = pre_model_list[-1]
        start_epoch = ff.find_timeframe(pre_model,1,'_')
        vxm_model.load_weights(pre_model)
        print('pre-trained model loaded, epoch:', start_epoch)

    # ### train the model
    nb_epochs = 1000  # upper bound; the stopping rules below normally end training earlier
    last_epoch = start_epoch + nb_epochs

    # rows written to logs/training_metrics.xlsx every 10 epochs
    loss_results = []

    # best rounded loss seen so far, and the number of consecutive
    # 10-epoch checks without improvement
    previous_loss = float('inf'); freeze_count = 0
    for epoch in range(start_epoch, last_epoch):
        print(f"Epoch {epoch + 1}/{last_epoch}")

        # Train the model for one epoch
        hist = vxm_model.fit(train_generator, epochs=1, verbose=1, use_multiprocessing=False, workers = 1, shuffle = False,)

        # total loss plus the per-output losses (None if Keras used other key names)
        training_loss = hist.history['loss'][0]
        transformer_loss = hist.history.get('vxm_dense_transformer_loss', [None])[0]
        flow_loss = hist.history.get('vxm_dense_flow_loss', [None])[0]

        # logging, checkpointing and the stopping checks run only every 10 epochs
        if (epoch + 1) % 10 != 0:
            continue

        # save the loss results
        loss_results.append([epoch + 1, training_loss, transformer_loss, flow_loss])
        df = pd.DataFrame(loss_results, columns=['Epoch', 'Training Loss', 'Transformer Loss', 'Flow Loss'])
        df.to_excel(os.path.join(save_path, 'logs/training_metrics.xlsx'), index=False)

        # checkpoint named with the 1-based epoch count
        vxm_model.save(os.path.join(save_path,'models/vxm_'+str(epoch + 1)+'.h5'))

        training_loss_round = round(training_loss, 4)
        if training_loss_round < previous_loss:
            previous_loss = training_loss_round; freeze_count = 0
        else:
            freeze_count += 1

        if epoch <= 150:
            continue # at least train 150 epochs

        if training_loss_round <= 0.0021 or epoch >= 300:
            print('training loss is less than 0.0021 or epoch >= 300, stop at epoch:', epoch + 1)
            vxm_model.save(final_model_file)
            break

        if freeze_count >= 4: # 4 checks x 10 epochs = 40 epochs without improvement
            print('training loss has not improved for 40 epochs, stop at epoch:', epoch + 1)
            vxm_model.save(final_model_file)
            break
    else:
        # Epoch budget exhausted without a stopping rule firing.
        # Still write the final model so the testing cell can use this case.
        vxm_model.save(final_model_file)

## Generate ground-truth MVFs for each case using the trained VoxelMorph models

In [5]:
####### Testing ########
# Apply each case's trained model (models/vxm_final.h5) to every time frame of that case.
# For each frame it saves three files:
#   <tf>.nii.gz          - the predicted MVF
#   <tf>_moved.nii.gz    - the warped (moved) image
#   <tf>_original.nii.gz - the cropped/padded input frame
# Frames whose MVF file already exists are skipped.

# set the patient list
data_sheet = os.path.join(main_path,'example_data/Patient_lists/example_data/patient_list.xlsx')

b = Build_list.Build(data_sheet)
patient_class_test_list, patient_id_test_list, tf_test_list = b.__build__(batch_list = [0])

adapt_shape = [160,160,96]  # network input size; every frame is cropped/padded to this


def _load_3d(path):
    """Load a NIfTI file as a float array, keeping only the first volume if it is 4D.

    Applied to every frame so the template, the moving frame and the saved
    "original" image are all handled the same way. Previously only the network
    inputs were reduced to 3D.
    """
    data = nb.load(path).get_fdata()
    if data.ndim == 4:
        data = data[..., 0]
    return data


for patient_class, patient_id, tf_num in zip(patient_class_test_list, patient_id_test_list, tf_test_list):
    print(patient_class, patient_id, tf_num)

    ### check whether we have the voxelmorph final model
    model_path = os.path.join(main_path, 'example_data/models', trial_name, 'individual_models',patient_class,patient_id,'models/vxm_final.h5')
    if not os.path.isfile(model_path):
        print('no model for patient:', patient_id)
        continue

    ### set save path
    mvf_root = os.path.join(main_path, 'example_data/mvf_warp0_onecase')
    save_path = os.path.join(mvf_root, patient_class, patient_id, 'voxel_final')
    ff.make_folder([mvf_root, os.path.join(mvf_root, patient_class),
                    os.path.join(mvf_root, patient_class, patient_id),
                    save_path])

    ### build the model (same architecture as in training) and load its weights
    nb_features = [[16, 32, 32, 32],[32, 32, 32, 32, 32, 16, 16]]
    vxm_model = vxm.networks.VxmDense(adapt_shape, nb_features, int_steps=0)
    losses = [vxm.losses.MSE().loss, vxm.losses.Grad('l2').loss]
    loss_weights = [1, 0.05]
    vxm_model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4), loss=losses, loss_weights=loss_weights)
    vxm_model.load_weights(model_path)

    ### do the prediction
    image_path = os.path.join(main_path,'example_data/nii-images' ,patient_class, patient_id,'img-nii-resampled-1.5mm')
    tf_files = ff.sort_timeframe(ff.find_all_target_files(['*.nii.gz'],image_path),2)

    # NOTE(review): outputs are cropped/padded to adapt_shape but keep the affine of
    # the uncropped frame 0 - confirm the voxel-to-world mapping is still correct.
    affine = nb.load(tf_files[0]).affine

    for timeframe, tf_file in enumerate(tf_files):
        mvf_file = os.path.join(save_path, str(timeframe) + '.nii.gz')
        if os.path.isfile(mvf_file):
            print('timeframe:', timeframe, 'has been processed')
            continue

        original_image = _load_3d(tf_file)
        original_cropped = Data_processing.crop_or_pad(original_image, adapt_shape, value = np.min(original_image))

        if timeframe == 0:
            # frame 0 against itself: zero displacement, image unchanged
            mvf = np.zeros(adapt_shape + [3])
            moved_pred = original_cropped
        else:
            template = _load_3d(tf_files[0])
            moving = _load_3d(tf_file)
            # 'others': warp frame 0 onto this frame; '0': warp this frame onto frame 0
            if which_timeframe_is_template == 'others':
                tf1, tf2 = template, moving
            else:
                tf1, tf2 = moving, template

            # intensities are divided by 1000 for the network; the moved image is rescaled back
            tf1 = Data_processing.crop_or_pad(tf1, adapt_shape, value = np.min(tf1)) / 1000
            tf2 = Data_processing.crop_or_pad(tf2, adapt_shape, value = np.min(tf2)) / 1000

            val_input = [tf1[np.newaxis, ..., np.newaxis],
                         tf2[np.newaxis, ..., np.newaxis]]
            val_pred = vxm_model.predict(val_input)
            moved_pred = val_pred[0].squeeze() * 1000  # output 0: moved image
            mvf = val_pred[1].squeeze()                # output 1: flow field

        nb.save(nb.Nifti1Image(mvf, affine), mvf_file)
        nb.save(nb.Nifti1Image(moved_pred, affine), os.path.join(save_path, str(timeframe) + '_moved.nii.gz'))
        nb.save(nb.Nifti1Image(original_cropped, affine), os.path.join(save_path, str(timeframe) + '_original.nii.gz'))

example_data example_1 20
example_data example_2 20
