In [1]:
import sys
sys.path.append('/workspace/Documents')
# imports
import os, sys

# third party imports
import numpy as np 
import pandas as pd
import random
import nibabel as nb
import torch
import torch.nn.functional as F

import Diffusion_motion_field.Build_lists.Build_list as Build_list
import Diffusion_motion_field.functions_collection as ff
import Diffusion_motion_field.Data_processing as Data_processing

from Diffusion_motion_field.denoising_diffusion_pytorch.denoising_diffusion_pytorch.conditional_EDM_warp import *

# Root directory of the 4DCT data; the patient lists and predicted segmentations used below are
# read from (and the result tables written to) paths joined onto it.
main_path = '/mnt/camca_NAS/4DCT'

  from .autonotebook import tqdm as notebook_tqdm
  @autocast(enabled = False)


### Load the patient list

In [2]:
# Read the patient list and report its (rows, columns) shape.
patient_list_file = os.path.join(main_path,'Patient_lists/uc/patient_list_MVF_diffusion_train_test.xlsx')
patient_list = pd.read_excel(patient_list_file)
print(patient_list.shape)

(317, 12)


### Analyze the time frames

In [3]:
# Per-patient time-frame analysis.
# For each patient: count label-1 ("LV") voxels in every predicted segmentation frame, take ES as
# the frame with the smallest count, sample frames at fixed fractions of the sequence (ES replaces
# the fraction it rounds to), and compare the ejection fraction EF = (V_0 - V_min) / V_0 over all
# frames with the EF over 5- and 10-frame subsets -- computed both from the per-frame
# segmentations and from warping the frame-0 segmentation with per-frame MVFs.
# The results table is checkpointed to Excel after every patient.

START_INDEX = 282  # first row of patient_list to (re)process; set to 0 to recompute every patient
TIMEFRAME_INFO_PATH = os.path.join(main_path, 'Patient_lists/uc/patient_list_final_selection_timeframes.xlsx')
SEG_SUBFOLDER = 'seg-pred-0.625-4classes-connected-retouch-resampled-1.5mm'
MVF_ROOT = '/mnt/camca_NAS/4DCT/mvf_warp0_onecase/'
TEMPLATE_SHAPE = [160, 160, 96]  # crop/pad size of the frame-0 template; presumably the MVF grid -- TODO confirm
NORMALIZED_TIME_FRAMES = [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]  # frame index / total_tf_num
PICKED_5TF = [0.1, 0.3, 0.5, 0.7, 0.9]
PICKED_10TF = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
RESULT_COLUMNS = ['patient_class', 'patient_id', 'total_tf_num', 'es_index', 'es_index_normalized',
                  'sampled_time_frame_list', 'normalized_time_frame_list_copy', 'LV_volume_list',
                  'EF_original', 'last_tf_percent', 'EF_sampled_in_5tf_by_seg', 'EF_sampled_in_5tf_by_mvf',
                  'EF_sampled_in_10tf_by_seg', 'EF_sampled_in_10tf_by_mvf']


def _lv_volume_per_frame(seg_files):
    """Return an array with the number of voxels labelled 1 in each segmentation file, in order."""
    volumes = []
    for seg_file in seg_files:
        seg_data = nb.load(seg_file).get_fdata().astype(np.int16)
        volumes.append(np.sum(seg_data == 1))
    return np.asarray(volumes)


def _sample_time_frames(es_index, total_tf_num):
    """Map NORMALIZED_TIME_FRAMES onto frame indices of a sequence with total_tf_num frames.

    The fraction equal to round(es_index / total_tf_num, 1) is mapped to es_index itself; every
    other fraction f maps to round(f * total_tf_num) and is dropped when that index equals
    es_index or was already sampled.

    Returns:
        es_index_normalized: round(es_index / total_tf_num, 1).
        sampled: distinct sampled frame indices, in NORMALIZED_TIME_FRAMES order.
        kept: the fractions that produced `sampled`, position-aligned with it.
    """
    es_index_normalized = round(es_index / total_tf_num, 1)
    sampled, kept = [], []
    for fraction in NORMALIZED_TIME_FRAMES:
        if fraction == es_index_normalized:
            time_index = es_index
        else:
            time_index = round(fraction * total_tf_num)
            if time_index == es_index or time_index in sampled:
                continue
        sampled.append(time_index)
        kept.append(fraction)
    return es_index_normalized, sampled, kept


def _ejection_fraction(ed_volume, volumes):
    """Relative drop from ed_volume to the smallest of `volumes`: (ed_volume - min) / ed_volume."""
    return (ed_volume - np.min(volumes)) / ed_volume


def _warped_template_volumes(seg_folder, mvf_folder, time_frames):
    """Warp the frame-0 label-1 mask with each frame's MVF and return the sum of each warped mask.

    Reads `<seg_folder>/pred_s_0.nii.gz` and `<mvf_folder>/<frame>.nii.gz`; tensors are moved to
    CUDA, so a GPU is required.
    """
    seg_template = nb.load(os.path.join(seg_folder, 'pred_s_0.nii.gz')).get_fdata()
    seg_template = np.round(seg_template).astype(np.int16)
    seg_template[seg_template != 1] = 0  # keep only label 1
    seg_template = Data_processing.crop_or_pad(seg_template, TEMPLATE_SHAPE, value=0)
    template_torch = torch.from_numpy(seg_template).unsqueeze(0).unsqueeze(0).float().cuda()

    volumes = []
    for tf in time_frames:
        mvf_data = nb.load(os.path.join(mvf_folder, str(tf) + '.nii.gz')).get_fdata()
        # move the last axis to the front (assumed: displacement components -> channels; TODO confirm)
        mvf_torch = torch.from_numpy(np.transpose(mvf_data, (3, 0, 1, 2))).unsqueeze(0).float().cuda()
        warped = warp_segmentation_from_mvf(template_torch, mvf_torch)
        volumes.append(np.sum(warped.squeeze(0).squeeze(0).cpu().numpy()))
    return np.asarray(volumes)


def _previous_rows_to_keep(recomputed_ids):
    """Rows of an existing results file whose patient_id is not in `recomputed_ids`.

    Returns None when starting from the first patient or when no results file exists yet.
    Without this, resuming at START_INDEX would overwrite rows written by an earlier run.
    """
    if START_INDEX == 0 or not os.path.exists(TIMEFRAME_INFO_PATH):
        return None
    previous = pd.read_excel(TIMEFRAME_INFO_PATH)
    return previous[~previous['patient_id'].isin(recomputed_ids)]


def _save_results(rows, kept_rows):
    """Write kept_rows followed by this run's rows to TIMEFRAME_INFO_PATH and return the table."""
    table = pd.DataFrame(rows, columns=RESULT_COLUMNS)
    if kept_rows is not None and not kept_rows.empty:
        table = pd.concat([kept_rows, table], ignore_index=True)
    table.to_excel(TIMEFRAME_INFO_PATH, index=False)
    return table


kept_rows = _previous_rows_to_keep(set(patient_list['patient_id'].iloc[START_INDEX:]))
results = []
for i in range(START_INDEX, patient_list.shape[0]):
    patient_id = patient_list['patient_id'].iloc[i]
    patient_class = patient_list['patient_class'].iloc[i]
    print(i, patient_id, patient_class)

    seg_folder = os.path.join(main_path, 'predicted_seg', patient_class, patient_id, SEG_SUBFOLDER)
    seg_files = ff.sort_timeframe(ff.find_all_target_files(['pred*'], seg_folder), 2, '_')
    total_tf_num = len(seg_files)
    if total_tf_num == 0:
        raise FileNotFoundError(f'no pred* segmentation files in {seg_folder}')

    LV_volume_list = _lv_volume_per_frame(seg_files)
    ed_volume = LV_volume_list[0]
    es_index = np.argmin(LV_volume_list)
    ejection_fraction = _ejection_fraction(ed_volume, LV_volume_list)
    last_tf_percent = (ed_volume - LV_volume_list[-1]) / ed_volume

    es_index_normalized, sampled_time_frame_list, normalized_time_frame_list_copy = \
        _sample_time_frames(es_index, total_tf_num)
    assert len(sampled_time_frame_list) == len(set(sampled_time_frame_list))

    if len(sampled_time_frame_list) < len(NORMALIZED_TIME_FRAMES):
        # Some fractions collapsed onto the same frame: the subset EFs are left blank.
        ef_5tf_seg = ef_5tf_mvf = ef_10tf_seg = ef_10tf_mvf = ''
    else:
        positions_5tf = [normalized_time_frame_list_copy.index(f) for f in PICKED_5TF]
        positions_10tf = [normalized_time_frame_list_copy.index(f) for f in PICKED_10TF]

        # EF from the per-frame segmentations, restricted to the picked frames.
        sampled_volumes = LV_volume_list[sampled_time_frame_list]
        ef_5tf_seg = _ejection_fraction(ed_volume, sampled_volumes[positions_5tf])
        ef_10tf_seg = _ejection_fraction(ed_volume, sampled_volumes[positions_10tf])

        # EF from the frame-0 segmentation warped by each picked frame's MVF.
        mvf_folder = os.path.join(MVF_ROOT, patient_class, patient_id, 'voxel_final')
        frames_10tf = [sampled_time_frame_list[p] for p in positions_10tf]
        mvf_volumes = _warped_template_volumes(seg_folder, mvf_folder, frames_10tf)
        ef_10tf_mvf = _ejection_fraction(ed_volume, mvf_volumes)
        ef_5tf_mvf = _ejection_fraction(ed_volume, mvf_volumes[[PICKED_10TF.index(f) for f in PICKED_5TF]])

    results.append([patient_class, patient_id, total_tf_num, es_index, es_index_normalized,
                    sampled_time_frame_list, normalized_time_frame_list_copy, LV_volume_list.tolist(),
                    ejection_fraction, last_tf_percent, ef_5tf_seg, ef_5tf_mvf, ef_10tf_seg, ef_10tf_mvf])
    # Checkpoint after every patient so an interrupted run can be resumed via START_INDEX.
    df = _save_results(results, kept_rows)

282 CVC1807171636 Abnormal
283 CVC2006261457 Abnormal
284 CVC1811270917 Normal
285 CVC1805310959 Normal
286 CVC1910150845 Abnormal
287 CVC1908301036 Normal
288 CVC1801031507 Normal
289 CVC1912181517 Abnormal
290 CVC1907301359 Abnormal
291 CVC1912160957 Normal
292 CVC1802051130 Abnormal
293 CVC2005201013 Abnormal
294 CVC1811191453 Abnormal
295 CVC1904240910 Normal
296 CVC2006041408 Normal
297 CVC2001081016 Abnormal
298 CVC2006021121 Abnormal
299 CVC1910221524 Normal
300 CVC2001280905 Abnormal
301 CVC2006250902 Normal
302 CVC1912111149 Normal
303 CVC2002250842 Normal
304 CVC1812271121 Normal
305 CVC1905311311 Abnormal
306 CVC2002131112 Abnormal
307 CVC1911050924 Normal
308 CVC2005201443 Normal
309 CVC1901111110 Normal
310 CVC1912121107 Abnormal
311 CVC1908061347 Abnormal
312 CVC1907251059 Normal
313 CVC2004131548 Abnormal
314 CVC2003191457 Abnormal
315 CVC1905071046 Normal
316 CVC1902151307 Normal


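### Exclude cases based on the time-frame analysis

A case is excluded when its normalized ES index lies outside [0.3, 0.7), or when `|last_tf_percent|` — the relative LV-volume change between frame 0 and the last frame — is 0.30 or more.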

In [77]:
# Reload the patient list and the per-patient time-frame table used by the exclusion step.
# NOTE(review): both paths omit the 'uc/' sub-folder that the patient-list read and the
# time-frame table write use earlier in this notebook -- confirm which location is intended.
patient_list_path = os.path.join(main_path,'Patient_lists/patient_list_MVF_diffusion_train_test.xlsx')
timeframe_info_path = os.path.join(main_path,'Patient_lists/patient_list_final_selection_timeframes.xlsx')
patient_list = pd.read_excel(patient_list_path)
timeframe_info = pd.read_excel(timeframe_info_path)

In [80]:
# Exclude cases using the per-patient time-frame table:
#   - the normalized ES index must lie in [ES_NORMALIZED_MIN, ES_NORMALIZED_MAX);
#   - |last_tf_percent|, i.e. |(V_0 - V_last) / V_0| as computed in the time-frame analysis,
#     must stay below MAX_LAST_TF_PERCENT.
ES_NORMALIZED_MIN = 0.3
ES_NORMALIZED_MAX = 0.7
MAX_LAST_TF_PERCENT = 0.30

# One row per patient; if a patient_id is repeated, its first row is used.
timeframe_by_patient = timeframe_info.drop_duplicates('patient_id').set_index('patient_id')
missing_ids = [p for p in patient_list['patient_id'] if p not in timeframe_by_patient.index]
if missing_ids:
    raise KeyError(f'no time-frame information for patients: {missing_ids}')

es_norm_per_patient = patient_list['patient_id'].map(timeframe_by_patient['es_index_normalized'])
last_tf_change_per_patient = patient_list['patient_id'].map(timeframe_by_patient['last_tf_percent']).abs()
exclude = (
    (es_norm_per_patient < ES_NORMALIZED_MIN)
    | (es_norm_per_patient >= ES_NORMALIZED_MAX)
    | (last_tf_change_per_patient >= MAX_LAST_TF_PERCENT)
)
# Index labels (not positions) of the excluded rows: this keeps the cell safe to re-run after
# patient_list has already been filtered -- a second run matches nothing and drops nothing.
exclude_index = exclude[exclude].index.tolist()

# remove the excluded rows from the patient list
patient_list = patient_list.drop(exclude_index)
print(patient_list.shape)
patient_list.to_excel(os.path.join(main_path,'Patient_lists/patient_list_MVF_diffusion_train_test_filtered.xlsx'), index = False)

(290, 12)
