## MVF synthesis using a trained model


In [13]:
import sys 
sys.path.append('/workspace/Documents')
import os
import torch
import ast
import numpy as np
import pandas as pd
import nibabel as nb
from ema_pytorch import EMA
from scipy.ndimage import zoom
import Cardiac4DCT_Synth_Diffusion.denoising_diffusion_pytorch.denoising_diffusion_pytorch.conditional_diffusion_3D as ddpm_3D
import Cardiac4DCT_Synth_Diffusion.denoising_diffusion_pytorch.denoising_diffusion_pytorch.conditional_EDM as edm
import Cardiac4DCT_Synth_Diffusion.denoising_diffusion_pytorch.denoising_diffusion_pytorch.conditional_EDM_warp as edm_warp
import Cardiac4DCT_Synth_Diffusion.Build_lists.Build_list as Build_list
import Cardiac4DCT_Synth_Diffusion.Generator as Generator
import Cardiac4DCT_Synth_Diffusion.functions_collection as ff
main_path = '/mnt/camca_NAS/4DCT'

### step 1: set default parameters 

In [2]:
# Experiment name; it selects the sub-folder under 'example_data/models'.
trial_name = 'MVF_EDM'

# Time frames handled together in one sample.
# NOTE(review): picked_tf looks like normalized time points, one per frame - confirm.
picked_tf = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
how_many_timeframes_together = 10

# Size of the MVF volume and the slice range handed to the data generator.
mvf_size_3D = [40, 40, 24]
mvf_slice_range = [0, 96]
mvf_folder = os.path.join(main_path, 'example_data/mvf_warp0_onecase')

# One down-sampling flag per U-Net level.
downsample_list = (True, True, False, False)

# Conditioning switches for the diffusion model.
conditional_diffusion_image = True
conditional_diffusion_EF = True
conditional_diffusion_timeframe = False
conditional_diffusion_seg = False

augment_pre_done = True  # done in step2 jupyter notebook

# Output folder for the predicted MVFs; created up front if missing.
save_folder = os.path.join(main_path, 'example_data/models', trial_name, 'pred_mvf')
os.makedirs(save_folder, exist_ok=True)

### step 2: define pre-trained model

In [3]:
# Checkpoint loaded at sampling time; it lives in this trial's model folder.
trained_model_filename = os.path.join(
    main_path,
    'example_data/models',
    trial_name,
    'models/model-final.pt',
)

### step 3: define patient list

In [4]:
# Spreadsheet listing the patients to run the synthesis on.
data_sheet = os.path.join(
    main_path,
    'example_data/Patient_lists/example_data/patient_list.xlsx',
)

b = Build_list.Build(data_sheet)
# Only batch 0 is requested; the third return value is not used here.
patient_class_list, patient_id_list, _ = b.__build__(batch_list=[0])

### step 4: define diffusion model

In [11]:
# Denoising network. Input and output carry 3 channels per time frame.
# The time-frame conditioning arguments are deliberately not passed here.
mvf_channels = 3 * how_many_timeframes_together

# Up-sampling flags: the first three down-sampling flags reversed, then False.
upsample_list = tuple(reversed(downsample_list[:3])) + (False,)

model = ddpm_3D.Unet3D_tfcondition(
    init_dim=64,
    channels=mvf_channels,
    out_dim=mvf_channels,
    conditional_diffusion_image=conditional_diffusion_image,
    conditional_diffusion_EF=conditional_diffusion_EF,
    conditional_diffusion_seg=conditional_diffusion_seg,  # should be False
    dim_mults=(1, 2, 4, 8),
    downsample_list=downsample_list,
    upsample_list=upsample_list,
    flash_attn=False,
    full_attn=(None, None, False, False),
)

# EDM wrapper around the network: 50 sampling steps, output clipped to [-1, 1].
diffusion_model = edm.EDM(
    model,
    image_size=mvf_size_3D,
    num_sample_steps=50,
    clip_or_not=True,
    clip_range=[-1, 1],
)


### step 5: generate MVF

In [16]:
# Synthesize MVFs for every patient in the list.
#
# Each patient is sampled twice:
#   round 0 - conditioned on the patient's own EF from the time-frame sheet,
#   round 1 - conditioned on a randomly drawn EF in [0.10, 0.80]
#             (only when conditional_diffusion_EF is on).
#
# Per case this writes EF.txt and one x/y/z NIfTI file per sampled time frame
# into <save_folder>/<patient_class>/<patient_id>/test_<round>/.
# 'pred_mvf.nii.gz' is passed to the sampler as save_file.
# The running table of preset vs. predicted EF is saved to EF_results.xlsx
# after every case, so partial results survive an interrupted run.
result = []

# The time-frame sheet is the same for every patient and round: read it once.
timeframe_info = pd.read_excel(os.path.join(main_path,'example_data/Patient_lists/example_data/patient_list_final_selection_timeframes.xlsx'))
results_file = os.path.join(os.path.dirname(save_folder), 'EF_results.xlsx')

for patient_class, patient_id in zip(patient_class_list, patient_id_list):
    print('patient_class:', patient_class, 'patient_id:', patient_id)

    ff.make_folder([os.path.join(save_folder, patient_class), os.path.join(save_folder, patient_class, patient_id)])

    # Per-patient metadata does not depend on the round, so look it up once.
    row = timeframe_info[timeframe_info['patient_id'] == patient_id]
    if row.empty:
        # Fail with a clear message instead of an IndexError from .iloc[0].
        raise ValueError('patient_id {!r} not found in the time-frame sheet'.format(patient_id))
    factual_EF = round(row['EF_sampled_in_10tf_by_mvf'].iloc[0], 2)
    sampled_time_frame_list = ast.literal_eval(row['sampled_time_frame_list'].iloc[0])
    # Not used in this cell; kept so the name stays available to later cells.
    normalized_time_frame_list = ast.literal_eval(row['normalized_time_frame_list_copy'].iloc[0])

    # The reference image gives the sampler its input and the outputs their affine.
    original_image_file = os.path.join(main_path,'example_data/nii-images',patient_class, patient_id, 'img-nii-resampled-1.5mm/0.nii.gz')
    affine = nb.load(original_image_file).affine

    for round_test in range(2):  # one factual synthesis, one counterfactual synthesis

        if round_test > 0 and conditional_diffusion_EF:
            # Counterfactual round: randomly draw a new EF.
            preset_EF = round(np.random.uniform(0.10, 0.80), 2)
        else:
            preset_EF = factual_EF
        print('EF:', preset_EF)

        save_folder_case = os.path.join(save_folder, patient_class, patient_id, 'test_' + str(round_test))
        os.makedirs(save_folder_case, exist_ok=True)

        with open(os.path.join(save_folder_case, 'EF.txt'), 'w') as f:
            f.write(str(preset_EF))

        # The generator is rebuilt each round because it depends on preset_EF.
        generator = Generator.Dataset_dual_3D(
            patient_class_list = np.asarray([patient_class]),
            patient_id_list = np.asarray([patient_id]),
            main_path = main_path,
            timeframe_info = timeframe_info,

            how_many_timeframes_together = how_many_timeframes_together,

            mvf_size_3D = mvf_size_3D,
            slice_range = mvf_slice_range,

            picked_tf = picked_tf,
            preset_EF = preset_EF,
            condition_on_image = True,
            prepare_seg = True,
            mvf_cutoff = [-20,20],
            augment = False,
            augment_pre_done = augment_pre_done,)

        sampler = edm.Sampler(diffusion_model, generator, batch_size = 1, image_size = mvf_size_3D,)

        save_file = os.path.join(save_folder_case, 'pred_mvf.nii.gz')

        pred_mvf_torch, EF_pred, pred_mvf_numpy = sampler.sample_3D_w_trained_model(
            trained_model_filename = trained_model_filename, cutoff_min = -20, cutoff_max = 20,
            save_file = save_file, patient_class = patient_class, patient_id = patient_id,
            image_file = original_image_file,)

        # The sampler may return the predicted EF as a tensor; reduce it to a scalar.
        EF_pred = EF_pred.cpu().detach().numpy()[0][0] if isinstance(EF_pred, torch.Tensor) else EF_pred

        result.append([patient_class, patient_id, round_test, preset_EF, EF_pred])
        df = pd.DataFrame(result, columns=['patient_class', 'patient_id', 'round_test', 'preset_EF', 'pred_EF'])
        df.to_excel(results_file, index=False)

        # Save in original resolution.
        # Always rewrite the per-time-frame files. EF.txt is overwritten on
        # every run and the counterfactual EF is re-drawn, so skipping the
        # export when files already exist would leave stale fields that no
        # longer match EF.txt.
        for frame_idx, frame in enumerate(sampled_time_frame_list):
            # Channels 3*k .. 3*k+2 hold the x/y/z components of time frame k.
            mvf1 = np.moveaxis(pred_mvf_numpy[3 * frame_idx:3 * (frame_idx + 1), ...], 0, -1)
            # NOTE(review): the factor 4 is hard-coded; confirm it matches the
            # ratio between the original image grid and mvf_size_3D.
            mvf1 = zoom(mvf1, (4, 4, 4, 1), order=1)  # upsample to original resolution
            for axis_idx, axis_name in enumerate('xyz'):
                nb.save(nb.Nifti1Image(mvf1[:, :, :, axis_idx], affine),
                        os.path.join(save_folder_case, 'pred_tf' + str(frame) + '_' + axis_name + '.nii.gz'))

patient_class: example_data patient_id: example_1
EF: 0.68
data_condition_EF:  tensor([[0.6800]], device='cuda:0')


sampling time step: 100%|██████████| 50/50 [00:01<00:00, 29.08it/s]


EF predicted:  tensor([[0.7035]], device='cuda:0')
EF: 0.22
data_condition_EF:  tensor([[0.2200]], device='cuda:0')


sampling time step: 100%|██████████| 50/50 [00:01<00:00, 28.80it/s]


EF predicted:  tensor([[0.1645]], device='cuda:0')
patient_class: example_data patient_id: example_2
EF: 0.31
data_condition_EF:  tensor([[0.3100]], device='cuda:0')


sampling time step: 100%|██████████| 50/50 [00:01<00:00, 29.13it/s]


EF predicted:  tensor([[0.2804]], device='cuda:0')
EF: 0.13
data_condition_EF:  tensor([[0.1300]], device='cuda:0')


sampling time step: 100%|██████████| 50/50 [00:01<00:00, 29.08it/s]


EF predicted:  tensor([[0.1121]], device='cuda:0')
