## Data Preparation

You should prepare the following things before running this step. Please refer to the `example_data` folder for guidance:

1. **PAR images and corresponding motion parameters** from step 1
   - The default dimension of the PAR array is [25,128,128,50], where 25 is the number of PAR images, [128,128] is the x-y dimension, and 50 is the number of slices.

2. **A patient list** that enumerates all your cases.  
   - To understand the expected format, please refer to the file:  
     `example_data/Patient_list/patient_list.xlsx`.
   - Our model takes 15 consecutive slices as input, so we define three starting slices (5, 20, 35) in the patient list, which represent different regions of the head (bottom, mid, top).
---

### Docker environment
1. Please use `docker/docker_tensorflow`; it builds a TensorFlow Docker image.


In [3]:
# %%
import os, sys
sys.path.append('/workspace/Documents')

import argparse
import os
import numpy as np
import nibabel as nb
from contextlib import redirect_stdout
import tensorflow as tf
from tensorflow.keras.callbacks import ModelCheckpoint, LearningRateScheduler, CSVLogger
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input
from tensorflow.keras.regularizers import l2

import HeadCT_MotionCorrection_PARDL.STN.model_STN as model_STN
import HeadCT_MotionCorrection_PARDL.STN.Generator_STN as Generator_STN
import HeadCT_MotionCorrection_PARDL.Data_processing as dp
import HeadCT_MotionCorrection_PARDL.functions_collection as ff
from HeadCT_MotionCorrection_PARDL.Build_lists import Build_list
import HeadCT_MotionCorrection_PARDL.Hyperparameters as hyper

# Draw an integer seed in [0, 500) from NumPy's global RNG and hand it to
# TensorFlow. NumPy is not seeded in this cell, so the seed (and therefore
# the TensorFlow random stream) differs from run to run.
tf_seed = int(np.random.rand() * 500)
tf.random.set_seed(tf_seed)

# Root directory under which `example_data/` is looked up by the cells below.
main_path = '/mnt/camca_NAS/motion_correction/'  # replace with your main path

### set the trial name and default parameters

In [4]:
# Name of this training run; it is also used as the output sub-folder name.
trial_name = 'PAR_model'

# CP_num is forwarded to model_STN.get_CNN and sizes the generator's output
# vector (CP_num - 1) in later cells.
CP_num = 5

# Model input shape, ordered (x, y, z, PAR):
# 128 x 128 in-plane, 15 slices by default, 25 PAR images.
input_shape = (128, 128, 15, 25)

# Outputs of this run go to <main_path>/example_data/models/<trial_name>.
# Both the parent folder and the run folder are handed to ff.make_folder,
# parent first.
models_root = os.path.join(main_path, 'example_data/models')
save_folder = os.path.join(models_root, trial_name)
ff.make_folder([models_root, save_folder])
    

### set the patient list

In [5]:
# Spreadsheet that enumerates the cases (see example_data/Patient_list).
data_sheet = os.path.join(main_path, 'example_data/Patient_list/patient_list.xlsx')
b = Build_list.Build(data_sheet)

# `b.__build__` returns a 10-element tuple; only the start/end slices, the
# motion parameters and the PAR image entries are kept here, the remaining
# positions are discarded.
(
    _, _, _, _,
    start_slice_trn, end_slice_trn,
    _, _,
    y_motion_param_trn, x_par_image_trn,
) = b.__build__(batch_list = [0])

# Just as an example, the training batch is reused as validation data.
(
    _, _, _, _,
    start_slice_val, end_slice_val,
    _, _,
    y_motion_param_val, x_par_image_val,
) = b.__build__(batch_list = [0])
  

### build model

In [6]:
# Wrap the CNN from model_STN in a Keras Model: a single input tensor of
# shape `input_shape` and six outputs (tx, ty, tz, rx, ry, rz) -- presumably
# translation and rotation parameters along each axis; confirm in model_STN.
model_inputs = [Input(input_shape)]
tx, ty, tz, rx, ry, rz = model_STN.get_CNN(nb_filters = [16,32,64,128,256], dimension = 3, CP_num = CP_num)(model_inputs[0])
model_outputs = [tx, ty, tz, rx, ry, rz]
model = Model(inputs = model_inputs, outputs = model_outputs)

# if continue your training:
# point `load_model_file` at a saved weights file to resume from it.
load_model_file = None # or your own model file path
if load_model_file is not None:
    print(load_model_file)
    model.load_weights(load_model_file)

# compile model
print('Compile Model...')
# `learning_rate` is the supported keyword; `lr` is a deprecated alias.
opt = Adam(learning_rate = 1e-4)
# One MSE loss per output, all weighted equally.
weights = [1,1,1,1,1,1]
model.compile(optimizer = opt,
              loss = ['MSE','MSE','MSE', 'MSE', 'MSE', 'MSE'],
              loss_weights = weights,)

# set callbacks
print('Set callbacks...')
model_fld = os.path.join(save_folder,'models')
ff.make_folder([model_fld, os.path.join(save_folder, 'logs')])

# `{epoch:03d}` is filled in by ModelCheckpoint, giving one file per epoch.
filepath = os.path.join(model_fld, 'model-{epoch:03d}.hdf5')

csv_logger = CSVLogger(os.path.join(save_folder, 'logs','training-log.csv'))
callbacks = [
    csv_logger,
    # save_best_only=False: a checkpoint is written every epoch regardless of
    # the monitored value.
    ModelCheckpoint(filepath,
                    monitor='val_loss',
                    save_best_only=False,),
    # Learning-rate schedule supplied by the project's Hyperparameters module.
    LearningRateScheduler(hyper.learning_rate_step_decay_classic),
]

2025-07-09 17:37:11.741400: I tensorflow/compiler/jit/xla_cpu_device.cc:41] Not creating XLA devices, tf_xla_enable_xla_devices not set
2025-07-09 17:37:11.741565: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcuda.so.1
2025-07-09 17:37:11.833081: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:941] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2025-07-09 17:37:11.835262: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1720] Found device 0 with properties: 
pciBusID: 0000:13:00.0 name: NVIDIA A100-SXM4-40GB computeCapability: 8.0
coreClock: 1.41GHz coreCount: 108 deviceMemorySize: 39.49GiB deviceMemoryBandwidth: 1.41TiB/s
2025-07-09 17:37:11.835309: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudart.so.11.0
2025-07-09 17:37:11.835389: I tensorflow/stream_executor/platform/defau

Compile Model...
Set callbacks...


### data generator

In [7]:
def _make_generator(x_par_image, y_motion_param, start_slice, end_slice, shuffle, seed):
    """Build a Generator_STN.DataGenerator with the settings shared by both splits.

    Only the data arrays, the shuffle flag and the seed differ between the
    training and validation generators; everything else is fixed here.
    """
    return Generator_STN.DataGenerator(
        x_par_image,
        y_motion_param,
        start_slice,
        end_slice,
        # model input size is 15 slices so we define three different start
        # slices which represent different regions of the head
        start_slice_sampling = np.array([5,20,35]),
        patient_num = x_par_image.shape[0],
        batch_size = 1,
        input_dimension = input_shape,
        output_vector_dimension = (CP_num - 1,),
        shuffle = shuffle,
        augment = False,
        seed = seed,
    )


# Training generator: shuffled, seed 10.
datagen = _make_generator(x_par_image_trn, y_motion_param_trn,
                          start_slice_trn, end_slice_trn,
                          shuffle = True, seed = 10)

# Validation generator: fixed order, seed 11.
valgen = _make_generator(x_par_image_val, y_motion_param_val,
                         start_slice_val, end_slice_val,
                         shuffle = False, seed = 11)

### train the model

In [9]:
# Train for 100 epochs, using `valgen` as validation data and the callbacks
# configured above (CSV log, per-epoch checkpoint, learning-rate schedule).
# `Model.fit` accepts the generator directly; `Model.fit_generator` is
# deprecated in TensorFlow 2.x.
model.fit(datagen,
          epochs = 100,
          validation_data = valgen,
          callbacks = callbacks,
          verbose = 1,
          )