## Model inference: 4DCT generation

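In this script, we use ***synthesized MVFs*** (motion vector fields) and a ***3DCT template*** (the first timeframe of each patient) to generate 4DCT sequences: the template is warped by the predicted MVF of each selected timeframe.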


In [14]:
# standard library imports
import ast
import os
import random
import sys
sys.path.append('/workspace/Documents')

# third party imports
import torch
import numpy as np 
import pandas as pd
import nibabel as nb
from skimage.measure import block_reduce
from scipy.ndimage import zoom

# local imports (resolved through the sys.path entry above)
import Cardiac4DCT_Synth_Diffusion.Build_lists.Build_list as Build_list
import Cardiac4DCT_Synth_Diffusion.functions_collection as ff
import Cardiac4DCT_Synth_Diffusion.Data_processing as Data_processing
import Cardiac4DCT_Synth_Diffusion.denoising_diffusion_pytorch.denoising_diffusion_pytorch.conditional_EDM_warp as warp_func

# Root of the 4DCT data / model tree; every input and output path below is joined onto it.
main_path = "/mnt/camca_NAS/4DCT/"

In [18]:
# Per-patient timeframe bookkeeping; the generation loop reads the
# 'sampled_time_frame_list' and 'normalized_time_frame_list_copy' columns from it.
timeframe_sheet = os.path.join(
    main_path,
    'example_data/Patient_lists/example_data/patient_list_final_selection_timeframes.xlsx',
)
timeframe_info = pd.read_excel(timeframe_sheet)

trial_name = 'MVF_EDM'

# Folder that holds the predicted MVFs; warped images are written next to them.
save_path = os.path.join(main_path, 'example_data/models', trial_name, 'pred_mvf')

# Patients to process: batch 0 of the patient list.
data_sheet = os.path.join(main_path, 'example_data/Patient_lists/example_data/patient_list.xlsx')
b = Build_list.Build(data_sheet)
patient_class_list, patient_id_list, _ = b.__build__(batch_list=[0])

In [19]:
def _load_cropped_volume(path, target_shape=(160, 160, 96)):
    """Load a NIfTI volume and crop/pad it to ``target_shape``.

    If the image is 4D, only index 0 of the last axis is kept. Padding uses the
    volume's minimum intensity.

    Returns:
        (volume, affine): the cropped/padded array and the file's affine.
    """
    nii = nb.load(path)
    volume = nii.get_fdata()
    if volume.ndim == 4:
        volume = volume[:, :, :, 0]
    volume = Data_processing.crop_or_pad(volume, list(target_shape), value=np.min(volume))
    return volume, nii.affine


def _load_pred_mvf(folder, tf):
    """Load the predicted x/y/z MVF components for timeframe ``tf``, stacked on a trailing axis."""
    components = [
        nb.load(os.path.join(folder, 'pred_tf' + str(tf) + '_' + axis + '.nii.gz')).get_fdata()
        for axis in ('x', 'y', 'z')
    ]
    return np.stack(components, axis=-1)


# Pick the device once. Every tensor below goes through `.to(device)`, so the
# loop also runs on machines without CUDA (the old `.cuda()` call crashed there).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Normalized cardiac phases to generate; each is mapped to a concrete timeframe per patient.
picked_tf_normalized = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]

for i, (patient_class, patient_id) in enumerate(zip(patient_class_list, patient_id_list)):
    print(i, patient_class, patient_id)

    img_path = os.path.join(main_path, 'example_data/nii-images', patient_class, patient_id, 'img-nii-resampled-1.5mm')
    save_folder = os.path.join(save_path, patient_class, patient_id)
    ff.make_folder([os.path.dirname(save_folder), save_folder])

    # Template = first (sorted) timeframe image of the patient; loaded once.
    tf_files = ff.sort_timeframe(ff.find_all_target_files(['*.nii.gz'], img_path), 2)
    template_image, affine = _load_cropped_volume(tf_files[0])
    # Template tensor is identical for every timeframe / sub-folder, so build it once.
    template_image_torch = torch.from_numpy(template_image).unsqueeze(0).unsqueeze(0).float().to(device)

    row = timeframe_info[timeframe_info['patient_id'] == patient_id]
    if row.empty:
        raise KeyError(f'patient_id {patient_id!r} not found in the timeframe sheet')
    sampled_time_frame_list = ast.literal_eval(row['sampled_time_frame_list'].iloc[0])
    normalized_time_frame_list = ast.literal_eval(row['normalized_time_frame_list_copy'].iloc[0])

    # Map each normalized phase to the sampled timeframe at the same position.
    picked_tf = [sampled_time_frame_list[normalized_time_frame_list.index(t)] for t in picked_tf_normalized]
    print('picked tf:', picked_tf)

    # Each sub-folder (e.g. test_0, test_1) holds one set of predicted MVFs; skip stray files.
    save_folder_sub_list = [p for p in ff.find_all_target_files(['*'], save_folder) if os.path.isdir(p)]

    for save_folder_sub in save_folder_sub_list:
        print('current folder:', save_folder_sub)
        nb.save(nb.Nifti1Image(template_image, affine), os.path.join(save_folder_sub, 'template_img.nii.gz'))

        for tf in picked_tf:
            mvf = _load_pred_mvf(save_folder_sub, tf)
            if tf == 0:
                # Timeframe 0 is the template frame itself: use a zero displacement field.
                mvf = np.zeros_like(mvf)

            # (X, Y, Z, 3) -> (1, 3, X, Y, Z) for the warp function.
            mvf_torch = torch.from_numpy(np.transpose(mvf, (3, 0, 1, 2))).unsqueeze(0).float().to(device)
            # Apply the deformation field to the template image.
            warped_img_torch = warp_func.warp_segmentation_from_mvf(template_image_torch, mvf_torch)
            warped_img = warped_img_torch.cpu().numpy().squeeze()

            nb.save(nb.Nifti1Image(warped_img, affine), os.path.join(save_folder_sub, 'warped_4DCT_pred_tf' + str(tf) + '.nii.gz'))


0 example_data example_1
picked tf: [0, 2, 4, 6, 8, 10, 12, 14, 16, 18]
save_folder : /mnt/camca_NAS/4DCT/example_data/models/MVF_EDM/pred_mvf/example_data/example_1
current folder: /mnt/camca_NAS/4DCT/example_data/models/MVF_EDM/pred_mvf/example_data/example_1/test_0
current folder: /mnt/camca_NAS/4DCT/example_data/models/MVF_EDM/pred_mvf/example_data/example_1/test_1
1 example_data example_2
picked tf: [0, 2, 4, 6, 8, 10, 12, 14, 16, 18]
save_folder : /mnt/camca_NAS/4DCT/example_data/models/MVF_EDM/pred_mvf/example_data/example_2
current folder: /mnt/camca_NAS/4DCT/example_data/models/MVF_EDM/pred_mvf/example_data/example_2/test_0
current folder: /mnt/camca_NAS/4DCT/example_data/models/MVF_EDM/pred_mvf/example_data/example_2/test_1
