## Preprocessing Steps

In this script, we perform the following preprocessing tasks:

1. Prepare the **left ventricular (LV) segmentation masks** using a pre-trained network.
2. Sample the original time frames into **10 evenly spaced cardiac phases** and get the ground-truth LV ejection fraction (LVEF) for each case.
3. Prepare augmented data for training (since the dataset is large, on-the-fly augmentation would be too time-consuming).


In [2]:
import sys
sys.path.append('/workspace/Documents')
import os
import nibabel as nb
import numpy as np
import pandas as pd
import torch
import Cardiac4DCT_Synth_Diffusion.Build_lists.Build_list as Build_list
import Cardiac4DCT_Synth_Diffusion.functions_collection as ff
import Cardiac4DCT_Synth_Diffusion.Data_processing as Data_processing

# Root data directory: the paths used in the cells below are built from it with os.path.join.
main_path = '/mnt/camca_NAS/4DCT' 


### Task 1: Prepare LV segmentation masks using pre-trained segmentation network
This mask will be used by our mapping function during diffusion model training.

In [2]:
import Cardiac4DCT_Synth_Diffusion.segmentation_network.Generator as Generator
import Cardiac4DCT_Synth_Diffusion.segmentation_network.model as seg_model
# Root output folder for the predicted segmentations; per-patient sub-folders are
# created inside it by the segmentation loop below.
save_folder = os.path.join(main_path,'example_data/predicted_seg')
ff.make_folder([save_folder])

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# define pre-trained segmentation model and the patient list
# Checkpoint file that is handed to sampler.sample(...) in the segmentation loop.
trained_model_filename = os.path.join(main_path, 'models', 'seg_3D', 'models/model-352.pt')

# Spreadsheet listing the cases to process; Build_list.Build turns it into parallel
# arrays of patient classes and patient ids (the third return value is unused here).
# NOTE(review): batch_list = [0] presumably selects the rows assigned to batch 0 — confirm in Build_list.
data_sheet = os.path.join(main_path,'example_data/Patient_lists/example_data/patient_list.xlsx')
b = Build_list.Build(data_sheet)
patient_class_list, patient_id_list,_ = b.__build__(batch_list = [0])

In [4]:
# model
# Build the 3D U-Net architecture only; no weights are loaded in this cell.
# The checkpoint path (trained_model_filename) is passed to sampler.sample(...) later.
model = seg_model.Unet3D(
    init_dim = 16,
    channels = 1,
    num_classes = 4, #4 classes: background, LV, LA, LVOT
    dim_mults = (2,4,8,16),
    full_attn = (None,None, None, None),
    act = 'LeakyReLU',)

in out is :  [(16, 32), (32, 64), (64, 128), (128, 256)]


In [6]:
# Generate a segmentation for every time frame of every patient.
#
# For each frame the network prediction is written to
#   <save_folder>/<patient_class>/<patient_id>/pred_s_<tf>.nii.gz
# and then post-processed: ff.remove_scatter3D is applied to the label-1 (LV) mask,
# while labels 2 and 3 are copied back unchanged from the raw prediction.
# Frames whose output file already exists are skipped, so the cell can be re-run.
# NOTE(review): a run interrupted between sampling and post-processing leaves a raw
# prediction on disk that will be skipped (not post-processed) on the next run.

# All example data lives under this folder; derive it from main_path so the
# generator and the sampler always look at the same location.
example_data_folder = os.path.join(main_path, 'example_data')

for i in range(patient_class_list.shape[0]):

    patient_class = patient_class_list[i]
    patient_id = patient_id_list[i]

    print(patient_class, patient_id)

    # get the number of time frames (= number of resampled image volumes)
    files = ff.find_all_target_files(
        ['*.nii.gz'],
        os.path.join(example_data_folder, 'nii-images', patient_class, patient_id, 'img-nii-resampled-1.5mm'))
    tf_num = len(files)
    print('num of tf:', tf_num)

    save_folder_case = os.path.join(save_folder, patient_class, patient_id)
    ff.make_folder([os.path.join(save_folder, patient_class), save_folder_case])

    for tf in range(tf_num):
        print('generating segmentation for tf:', tf)

        save_file = os.path.join(save_folder_case, 'pred_s_' + str(tf) + '.nii.gz')

        # already generated in a previous run
        if os.path.isfile(save_file):
            continue

        generator = Generator.Dataset_3D(
            np.asarray([patient_class]),
            np.asarray([patient_id]),
            image_folder = example_data_folder,
            have_manual_seg = False,
            img_size_3D = [160,160,96], # default
            picked_tf = tf, #'random' or specific tf
            )

        # sample:
        sampler = seg_model.Sampler(
            model,
            generator,
            image_size = [160,160,96], # default
            batch_size = 1)

        sampler.sample(
            trained_model_filename,
            save_file,
            patient_class,
            patient_id,
            picked_tf = tf,
            reshape_pred = True,
            save_gt_and_img = False,
            main_folder = example_data_folder)

        # do postprocessing if needed (remove scatter)
        # Load the prediction once; the voxel data, affine and header are all in memory
        # before the file is overwritten below.
        pred_nii = nb.load(save_file)
        original_a = np.round(pred_nii.get_fdata()).astype(int)

        # binary mask of label 1 only
        a = np.copy(original_a)
        a[a != 1] = 0

        # no label-1 voxel predicted: keep the raw prediction on disk as is
        if np.sum(a) == 0:
            continue

        new_image, need_to_remove = ff.remove_scatter3D(a, 1)
        # print('need to remove:',need_to_remove)

        # restore the other labels, which the binary mask dropped
        new_image[original_a == 2] = 2
        new_image[original_a == 3] = 3
        nb.save(nb.Nifti1Image(new_image, pred_nii.affine, pred_nii.header), save_file)
            

example_data example_1
num of tf: 20
generating segmentation for tf: 0


  data = torch.load(trained_model_filename, map_location=self.device)


generating segmentation for tf: 1
generating segmentation for tf: 2
generating segmentation for tf: 3
generating segmentation for tf: 4
generating segmentation for tf: 5
generating segmentation for tf: 6
generating segmentation for tf: 7
generating segmentation for tf: 8
generating segmentation for tf: 9
generating segmentation for tf: 10
generating segmentation for tf: 11
generating segmentation for tf: 12
generating segmentation for tf: 13
generating segmentation for tf: 14
generating segmentation for tf: 15
generating segmentation for tf: 16
generating segmentation for tf: 17
generating segmentation for tf: 18
generating segmentation for tf: 19
example_data example_2
num of tf: 20
generating segmentation for tf: 0
generating segmentation for tf: 1
generating segmentation for tf: 2
generating segmentation for tf: 3
generating segmentation for tf: 4
generating segmentation for tf: 5
generating segmentation for tf: 6
generating segmentation for tf: 7
generating segmentation for tf: 8
g

### Task 2: Sample 10 time frames and get EF

The results will be saved to a spreadsheet.

In [10]:
import sys
sys.path.append('/workspace/Documents')
import Cardiac4DCT_Synth_Diffusion.denoising_diffusion_pytorch.denoising_diffusion_pytorch.conditional_EDM_warp as warp_func

# For every case in the patient list:
#   1. count the LV voxels (label 1) in the predicted segmentation of each time frame,
#   2. pick 10 time frames at normalized times 0, 0.1, ..., 0.9,
#   3. compute the ejection fraction obtained when the time-frame-0 LV mask is warped
#      with the motion vector field (MVF) of each picked frame,
# and write one row per case to patient_list_final_selection_timeframes.xlsx.
spreadsheet = pd.read_excel(os.path.join(main_path,'example_data/Patient_lists/example_data/patient_list.xlsx'))
output_sheet = os.path.join(main_path,'example_data/Patient_lists/example_data/patient_list_final_selection_timeframes.xlsx')
output_columns = ['patient_class', 'patient_id', 'total_tf_num', 'es_index','sampled_time_frame_list', 'normalized_time_frame_list_copy', 'LV_volume_list',   'EF_sampled_in_10tf_by_mvf']

results = []
for i in range(spreadsheet.shape[0]):
    patient_class = spreadsheet.iloc[i]['patient_class']
    patient_id = spreadsheet.iloc[i]['patient_id']

    print('patient_class:', patient_class, ' patient_id:', patient_id)

    # load segmentation
    seg_folder = os.path.join(main_path,'example_data/predicted_seg/',patient_class,patient_id)
    seg_files = ff.sort_timeframe(ff.find_all_target_files(['*.nii.gz'],seg_folder),2,'_')
    total_tf_num = len(seg_files)
    if total_tf_num == 0:
        raise FileNotFoundError('no predicted segmentation found in ' + seg_folder + ' (run Task 1 first)')

    # get a list of LV volumes (pixel value = 1) for all timeframes
    # Counted frame by frame so the full 4D stack never has to be held in memory.
    LV_volume_list = []
    for seg_file in seg_files:
        seg = np.round(nb.load(seg_file).get_fdata()).astype(np.uint8)
        LV_volume_list.append(np.sum(seg == 1))
    LV_volume_list = np.asarray(LV_volume_list)

    # get ES time frame: first frame with the smallest LV volume
    es_index = int(np.argmin(LV_volume_list))

    # sample the temporal series:
    # turn the time frame list into [0,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9] --> sample 10 time frames evenly
    normalized_time_frame_list = [0,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9]
    normalized_time_frame_list_copy = normalized_time_frame_list.copy()
    # closest frame index for each normalized time, clamped to the last available frame
    sampled_time_frame_list = [
        min(int(round(normalized_time * total_tf_num)), total_tf_num - 1)
        for normalized_time in normalized_time_frame_list]

    # calculate the ejection fraction using segmentation at ED 0 and applying deformation field at each
    # picked time frame --> don't just use the EF from segmentation, this EF from deformation field
    # will be used in the model training
    seg_template = nb.load(os.path.join(seg_folder, 'pred_s_0.nii.gz')).get_fdata()
    seg_template = np.round(seg_template).astype(np.int16)
    seg_template[seg_template != 1] = 0
    seg_template = Data_processing.crop_or_pad(seg_template, [160,160,96], value = 0)
    # add batch and channel axes in front of the 3D volume
    seg_template_torch = torch.from_numpy(seg_template).unsqueeze(0).unsqueeze(0).float().cuda()

    mvf_folder = os.path.join(main_path,'example_data/mvf_warp0_onecase/',patient_class, patient_id, 'voxel_final')
    volume_list = []
    for tf_index in sampled_time_frame_list:
        mvf_file = os.path.join(mvf_folder, str(tf_index) + '.nii.gz')
        mvf_data = nb.load(mvf_file).get_fdata()
        # move the last axis of the MVF to the front, then add a batch axis
        mvf_data_torch = torch.from_numpy(np.transpose(mvf_data, (3, 0, 1, 2))).unsqueeze(0).float().cuda()
        warped_seg = warp_func.warp_segmentation_from_mvf(seg_template_torch, mvf_data_torch)
        warped_seg = warped_seg.squeeze(0).squeeze(0).cpu().numpy()
        volume_list.append(np.sum(warped_seg))
    # print('volume_list: ', volume_list)
    volume_list = np.asarray(volume_list)

    # NOTE(review): the reference ED volume is the voxel count of the un-cropped time-frame-0
    # segmentation, while the warped volumes come from the cropped/padded template — confirm
    # that crop_or_pad never removes LV voxels.
    ed_volume = LV_volume_list[0]
    if ed_volume == 0:
        # no LV voxel at time frame 0: the EF is undefined, avoid dividing by zero
        print('WARNING: LV volume at time frame 0 is zero, EF is left empty for', patient_class, patient_id)
        ejection_fraction_picked_10tf_from_mvf = np.nan
    else:
        ejection_fraction_picked_10tf_from_mvf = (ed_volume - np.min(volume_list))/ed_volume

    LV_volume_list = LV_volume_list.tolist()

    # append the results to the results list
    results.append([patient_class, patient_id, total_tf_num, es_index,  sampled_time_frame_list, normalized_time_frame_list_copy, LV_volume_list, ejection_fraction_picked_10tf_from_mvf])

    # The spreadsheet is rewritten after every case so that finished cases are kept
    # on disk even if a later case fails.
    df = pd.DataFrame(results, columns = output_columns)
    df.to_excel(output_sheet, index = False)

patient_class: example_data  patient_id: example_1
patient_class: example_data  patient_id: example_2


### Task 3: Prepare augmented data

In [9]:
import ast
import shutil
import pandas as pd
import random
from scipy import ndimage
from skimage.measure import block_reduce

# Per-case time-frame selection written by Task 2 (one row per case, including the
# 'sampled_time_frame_list' column that the augmentation loop reads).
timeframe_info = pd.read_excel(os.path.join(main_path,'example_data/Patient_lists/example_data/patient_list_final_selection_timeframes.xlsx'))

# define the patient list
# Same spreadsheet and same Build_list call as in Task 1, repeated here so that
# this task does not depend on variables left over from the Task 1 cells.
data_sheet = os.path.join(main_path,'example_data/Patient_lists/example_data/patient_list.xlsx')
b = Build_list.Build(data_sheet)
patient_class_list, patient_id_list,_ = b.__build__(batch_list = [0])

In [10]:
# do augmentation, 3 augmentation for each case
#
# For each case and each aug_index in {0, 1, 2} the following is written under
#   <main_path>/example_data/mvf_aug/<patient_class>/<patient_id>/aug_<aug_index>/
#     aug_parameter.npy              [z_rotate_degree, x_translate, y_translate]
#     mvf/<tf>.nii.gz                augmented MVF (aug_index 1 and 2 only)
#     mvf_downsampled/<tf>.nii.gz    MVF block-averaged by (4,4,4,1)
#     condition_img/0.nii.gz         augmented, block-averaged time-frame-0 image
#     segmentation_original_res/0.nii.gz   augmented binary LV mask
#     segmentation/0.nii.gz          same mask, block-max downsampled
# aug_index 0 is the identity (all parameters are 0). Augmentation parameters and all
# outputs except the aug_index-0 downsampled MVFs are re-used if they already exist on disk.
# NOTE(review): every downsampled volume is saved with the affine of its full-resolution
# source, so the voxel spacing in the header does not reflect the downsampling —
# confirm that downstream code only uses the voxel data.
for i in range(0, patient_id_list.shape[0]):
    patient_class = patient_class_list[i]
    patient_id = patient_id_list[i]

    print('i:', i, 'patient_class:', patient_class, 'patient_id:', patient_id)

    # save folder
    save_folder = os.path.join(main_path,'example_data/mvf_aug/', patient_class, patient_id)
    ff.make_folder([os.path.join(main_path,'example_data/mvf_aug'), os.path.join(main_path,'example_data/mvf_aug', patient_class), os.path.join(main_path,'example_data/mvf_aug', patient_class, patient_id)])

    # load the original MVF
    path = os.path.join(main_path, 'example_data/mvf_warp0_onecase',patient_class,patient_id,'voxel_final')
    files = ff.find_all_target_files(['*.nii.gz'],path)
    final_files = np.copy(files)
    for f in files:
        if 'moved' in f or 'original' in f:
            # remove it from the numpy array
            final_files = np.delete(final_files, np.where(final_files == f))
    # NOTE(review): files[tf] is used below as "the MVF of time frame tf", which assumes
    # the sorted list has no missing frame — confirm.
    files = ff.sort_timeframe(final_files,2)

    # get time frames
    # Match on class AND id: the data folders are keyed by (patient_class, patient_id),
    # so an id is only guaranteed to be unique within its class.
    row = timeframe_info[(timeframe_info['patient_class'] == patient_class) & (timeframe_info['patient_id'] == patient_id)]
    if row.shape[0] == 0:
        raise ValueError('no time frame selection found for ' + str(patient_class) + ' / ' + str(patient_id) + ' (run Task 2 first)')
    # the list was stored in the spreadsheet as its string representation
    sampled_time_frame_list = ast.literal_eval(row['sampled_time_frame_list'].iloc[0])
    print('time_frame_list:', sampled_time_frame_list)

    for aug_index in range(0,3):
        print('aug_index:', aug_index)

        # output locations for this augmentation
        aug_folder = os.path.join(save_folder, 'aug_'+str(aug_index))
        aug_parameter_file = os.path.join(aug_folder, 'aug_parameter.npy')
        mvf_folder = os.path.join(aug_folder, 'mvf')
        mvf_downsampled_folder = os.path.join(aug_folder, 'mvf_downsampled')
        condition_folder = os.path.join(aug_folder, 'condition_img')
        seg_original_res_folder = os.path.join(aug_folder, 'segmentation_original_res')
        seg_folder_aug = os.path.join(aug_folder, 'segmentation')

        # set augmentation parameters
        if os.path.isfile(aug_parameter_file):
            # re-use the stored parameters so that every output of this augmentation is consistent
            print('load aug_parameter from file')
            aug_parameter = np.load(aug_parameter_file)
            z_rotate_degree = aug_parameter[0]
            x_translate = int(aug_parameter[1])
            y_translate = int(aug_parameter[2])
        else:
            if aug_index == 0:
                # aug_index 0 is the un-augmented case
                z_rotate_degree, x_translate, y_translate = 0, 0, 0
            else:
                z_rotate_degree = random.uniform(-15,15)
                x_translate = int(round(random.uniform(-15,15)))
                y_translate = int(round(random.uniform(-15,15)))
            aug_parameter = [z_rotate_degree, x_translate, y_translate]
            ff.make_folder([aug_folder])
            np.save(aug_parameter_file, np.asarray(aug_parameter))
        print('z_rotate_degree:', z_rotate_degree, 'x_translate:', x_translate, 'y_translate:', y_translate)

        ######## load each MVF and do augmentation as well as latent encoding
        for tf_index in sampled_time_frame_list:
            print('current time frame: ', tf_index , ' file:', files[tf_index])

            mvf_aug_file = os.path.join(mvf_folder, str(tf_index)+'.nii.gz')
            mvf_downsampled_file = os.path.join(mvf_downsampled_folder, str(tf_index)+'.nii.gz')

            ##### aug_index = 0, no augmentation: only the downsampled original MVF is written
            # NOTE(review): the 'mvf' folder is created for aug_0 but nothing is saved into it —
            # confirm that the full-resolution aug_0 MVF is read from the original location.
            if aug_index == 0:
                ff.make_folder([mvf_folder, mvf_downsampled_folder])
                mvf_nii = nb.load(files[tf_index])
                downsample_mvf = block_reduce(mvf_nii.get_fdata(), (4,4,4,1), func=np.mean)
                nb.save(nb.Nifti1Image(downsample_mvf, mvf_nii.affine), mvf_downsampled_file)
                continue

            ###### MVF
            if os.path.isfile(mvf_aug_file):
                print('Aug mvf file exists:', mvf_aug_file)
                mvf_aug_nii = nb.load(mvf_aug_file)
                affine = mvf_aug_nii.affine
                mvf_aug = mvf_aug_nii.get_fdata()
            else:
                mvf_nii = nb.load(files[tf_index])
                affine = mvf_nii.affine
                mvf_aug = Data_processing.random_move(mvf_nii.get_fdata(),x_translate,y_translate,z_rotate_degree, fill_val=0, do_augment=True)
                # print('mvf aug shape:', mvf_aug.shape)

                # save the augmented MVF
                ff.make_folder([os.path.dirname(mvf_folder), mvf_folder])
                nb.save(nb.Nifti1Image(mvf_aug, affine), mvf_aug_file)

            ###### downsample MVF:
            if os.path.isfile(mvf_downsampled_file):
                print('Aug downsampled mvf file exists:', mvf_downsampled_file)
            else:
                downsample_mvf_aug = block_reduce(mvf_aug, (4,4,4,1), func=np.mean)
                ff.make_folder([os.path.dirname(mvf_downsampled_folder), mvf_downsampled_folder])
                nb.save(nb.Nifti1Image(downsample_mvf_aug, affine), mvf_downsampled_file)

        # augment for condition image as well
        condition_file = os.path.join(condition_folder, '0.nii.gz')
        if os.path.isfile(condition_file):
            print('condition image file exists:', condition_file)
        else:
            img_path = os.path.join(main_path, 'example_data/nii-images', patient_class, patient_id, 'img-nii-resampled-1.5mm/0.nii.gz')

            con_nii = nb.load(img_path)
            con_img = con_nii.get_fdata()
            affine = con_nii.affine
            if len(con_img.shape) == 4:
                con_img = con_img[:,:,:,0]
            # pad with the image minimum so padded voxels look like background
            con_img = Data_processing.crop_or_pad(con_img, [160,160,96], value = np.min(con_img))

            con_img1 = Data_processing.random_move(np.copy(con_img),x_translate,y_translate,z_rotate_degree,do_augment = True, fill_val = np.min(con_img))
            # block-average by 4 along each axis
            con_img1 = block_reduce(con_img1, (160//40,160//40, 96//24), func=np.mean)

            ff.make_folder([os.path.dirname(condition_folder), condition_folder])
            nb.save(nb.Nifti1Image(con_img1, affine), condition_file)

        # augment for segmentation as well
        seg_original_res_file = os.path.join(seg_original_res_folder, '0.nii.gz')
        seg_file_aug = os.path.join(seg_folder_aug, '0.nii.gz')
        if os.path.isfile(seg_file_aug) and os.path.isfile(seg_original_res_file):
            print('segmentation file exists:', seg_file_aug)
        else:
            seg_path = os.path.join(main_path,'example_data/predicted_seg', patient_class, patient_id,'pred_s_0.nii.gz')
            seg_nii = nb.load(seg_path)
            seg_img = np.round(seg_nii.get_fdata()).astype(np.int16)
            affine = seg_nii.affine
            if len(seg_img.shape) == 4:
                seg_img = seg_img[:,:,:,0]
            # make it binary (keep the LV label only)
            seg_img[seg_img != 1] = 0
            # crop
            seg_img = Data_processing.crop_or_pad(seg_img, [160,160,96], value = 0)
            # augmentation (order = 0 is passed so the mask is not blended by interpolation —
            # presumably nearest-neighbour; confirm in Data_processing.random_move)
            seg_img1 = Data_processing.random_move(np.copy(seg_img),x_translate,y_translate,z_rotate_degree,do_augment = True, fill_val = 0, order = 0)
            ff.make_folder([os.path.dirname(seg_original_res_folder), seg_original_res_folder])
            nb.save(nb.Nifti1Image(seg_img1, affine), seg_original_res_file)
            # block-max by 4 along each axis: a low-resolution voxel is foreground if any voxel in its block is
            seg_img1 = block_reduce(seg_img1, (160//40,160//40, 96//24), func=np.max)
            ff.make_folder([os.path.dirname(seg_folder_aug), seg_folder_aug])
            nb.save(nb.Nifti1Image(seg_img1, affine), seg_file_aug)

i: 0 patient_class: example_data patient_id: example_1
time_frame_list: [0, 2, 4, 6, 8, 10, 12, 14, 16, 18]
aug_index: 0
load aug_parameter from file
z_rotate_degree: 0 x_translate: 0 y_translate: 0
current time frame:  0  file: /mnt/camca_NAS/4DCT/example_data/mvf_warp0_onecase/example_data/example_1/voxel_final/0.nii.gz
current time frame:  2  file: /mnt/camca_NAS/4DCT/example_data/mvf_warp0_onecase/example_data/example_1/voxel_final/2.nii.gz
current time frame:  4  file: /mnt/camca_NAS/4DCT/example_data/mvf_warp0_onecase/example_data/example_1/voxel_final/4.nii.gz
current time frame:  6  file: /mnt/camca_NAS/4DCT/example_data/mvf_warp0_onecase/example_data/example_1/voxel_final/6.nii.gz
current time frame:  8  file: /mnt/camca_NAS/4DCT/example_data/mvf_warp0_onecase/example_data/example_1/voxel_final/8.nii.gz
current time frame:  10  file: /mnt/camca_NAS/4DCT/example_data/mvf_warp0_onecase/example_data/example_1/voxel_final/10.nii.gz
current time frame:  12  file: /mnt/camca_NAS/4DC