In [1]:
import sys 
sys.path.append('/host/d/Github/')
import os
import torch
import numpy as np
import nibabel as nb
import random
import matplotlib.pyplot as plt

import torch
from torch.utils.data import Dataset

import CT_registration_diffusion.functions_collection as ff
import CT_registration_diffusion.Build_lists.Build_list as Build_list
import CT_registration_diffusion.Data_processing as Data_processing
import CT_registration_diffusion.Generator as Generator
import CT_registration_diffusion.model.model as model
import CT_registration_diffusion.model.train_engine as train_engine
import CT_registration_diffusion.model.predict_engine as predict_engine

  from .autonotebook import tqdm as notebook_tqdm


### define our trial name

In [None]:
# Name of this experiment; it is used as a sub-folder name when the model and result paths are built.
trial_name = "trial_1"

### step 0: define some pre-set parameters

In [3]:
# Image pre-processing settings shared by every generator in this notebook.
image_size = [224, 224, 96]    # target image size handed to Generator.Dataset_4DCT
cutoff_range = [-200, 250]     # cutoff range for the CT images handed to Generator.Dataset_4DCT

### step 1: define patient list

In [None]:
# Excel sheet listing every 4DCT case -- point this at your own copy.
patient_list_spreadsheet = '/host/d/Data/4DCT/Patient_lists/ct_list.xlsx'
build_sheet = Build_list.Build(patient_list_spreadsheet)

# ---- training cases (batch 0) ----
# Only one case is kept for now so that the whole pipeline can be run end to end first.
######## per-person slice -- zhennong: 0:1, yuanqi: 1:2, leyu: 2:3, luxin: 3:4
train_case_slice = slice(0, 1)
(batch_list_train,
 dataset_id_list_train,
 case_id_list_train,
 image_folder_list_train) = [entries[train_case_slice] for entries in build_sheet.__build__(batch_list = [0])]

# ---- validation cases (batch 3) ----
# The batch-3 lists are still built, but they are replaced straight away: while the code is
# being brought up on a single case, validation temporarily reuses the training case.
batch_list_val, dataset_id_list_val, case_id_list_val, image_folder_list_val = build_sheet.__build__(batch_list = [3])
batch_list_val, dataset_id_list_val, case_id_list_val, image_folder_list_val = (
    batch_list_train, dataset_id_list_train, case_id_list_train, image_folder_list_train)

# Print one folder path as a quick check.
print('train image folder:', image_folder_list_train[0])

train image folder: /host/d/Data/4DCT/DIR_LAB/Case1/cropped_image


### step 2: define generator

In [None]:
# ---- training generator ----
# True: only time frame 0 is used as the moving image.
# False: the moving time frame is selected at random.
only_use_tf0_as_moving = True

generator_train = Generator.Dataset_4DCT(
    image_folder_list = image_folder_list_train,
    image_size = image_size,        # target image size after center-crop
    cutoff_range = cutoff_range,    # cutoff range for CT images

    # How many time-frame pairs are randomly drawn from one 4DCT case. Example: picking
    # (time frame 0, time frame 2) and (time frame 1, time frame 3) gives a value of 2.
    num_of_pairs_each_case = 10,
    # Optional fixed pairing per case, e.g. [(0,2),(1,3)] pairs frame 0 with 2 and frame 1 with 3.
    # When it is given, num_of_pairs_each_case has to equal the length of this list.
    # None means num_of_pairs_each_case pairs are drawn at random from the 4DCT each time.
    preset_paired_tf = None,
    only_use_tf0_as_moving = only_use_tf0_as_moving,

    shuffle = True,
    augment = True,                 # data augmentation is switched on for training
    augment_frequency = 0.5,
)

# ---- validation generator ----
# Unlike training, validation should see the same data every time: no shuffling, no
# augmentation, and a preset list of time-frame pairs so the pairing never changes.
preset_paired_tf_val = [(0,1),(0,2),(0,4), (0,5),(0,6),(0,8)]  # fixed pairs used for every validation case

generator_val = Generator.Dataset_4DCT(
    image_folder_list = image_folder_list_val,
    image_size = image_size,
    cutoff_range = cutoff_range,

    preset_paired_tf = preset_paired_tf_val,
    num_of_pairs_each_case = len(preset_paired_tf_val),  # must match the preset list length
    only_use_tf0_as_moving = only_use_tf0_as_moving,

    shuffle = False,
    augment = False,                # no data augmentation during validation
    augment_frequency = 0.0,
)

### step 3: model

In [None]:
# Build the registration network.
our_model = model.Unet(
    problem_dimension = '3D',   # this notebook solves a 3D image registration problem

    # input_channels: 1 when a single 4DCT time frame (e.g. only time frame 0) is the model input;
    #                 2 when two time frames (e.g. time frame 0 as moving and time frame 2 as fixed) are the input.
    input_channels = 2,
    # out_channels: 2 for a 2D deformation field, 3 for a 3D deformation field.
    out_channels = 3,

    initial_dim = 4,            # feature dimension after the first conv layer
    dim_mults = (2,4,8,16),
    groups = 4,

    # Attention legend: None = no attention, False = linear attention, True = full attention.
    # Down-/up-sampling paths: all None for now to save GPU memory.
    full_attn_paths = (None, None, None, None),
    # Middle bottleneck layer.
    # NOTE(review): per the legend above, False selects linear attention rather than no attention;
    # pass None instead if the bottleneck is meant to run without attention -- confirm the intent.
    full_attn_bottleneck = False,

    act = 'ReLU',
)

in out is :  [(4, 8), (8, 16), (16, 32), (32, 64)]


### step 4: build trainer and start to train the model

In [None]:
# IMPORTANT: training parameters
regularization_weight = 0.01  # weight of the deformation-field smoothness regularization term; the best value has to be found by experiment
similarity_metric = 'MSE'     # similarity metric for the image similarity loss, can be 'MSE' or 'NCC'

train_batch_size = 1    # training batch size; set it to 1 if GPU memory is insufficient
accumulation_steps = 5  # gradient accumulation steps to simulate a larger batch size
                        # (original note: lower this value if GPU memory is insufficient)

# training schedule
total_epochs = 4000
# Save the model every N epochs. The fewer training samples, the larger this number should be,
# and vice versa (the author usually uses 1-5, or 20-50 when there is only a single case).
save_models_every = 10
validation_every = save_models_every  # validate every N epochs, should be the same as save_models_every

# Where to save your model weights? Change this path to your own path.
results_folder = os.path.join('/host/d/projects/registration/models', trial_name, 'models')
# Create the parent folder first, then the folder itself. os.path.dirname gives the parent;
# os.path.basename would only give the relative name 'models', i.e. a folder in the current
# working directory rather than the parent of results_folder.
ff.make_folder([os.path.dirname(results_folder), results_folder])

trainer = train_engine.Trainer(
        our_model,
        generator_train,
        generator_val,
        train_batch_size = train_batch_size,
        accum_iter= accumulation_steps,

        regularization_weight = regularization_weight,
        similarity_metric = similarity_metric,

        train_num_steps = total_epochs,
        results_folder = results_folder,

        train_lr_decay_every = 200,
        save_models_every = save_models_every,
        validation_every = validation_every,)

In [26]:
### do we have a pre-trained model?
# Epoch of the checkpoint to resume from. If there is no pre-trained model, start from epoch 0;
# otherwise use the epoch of the checkpoint that is loaded.
start_epoch = 170

# The checkpoint file name is derived from start_epoch so the two can never disagree.
# NOTE(review): what to pass as pre_trained_model when training from scratch depends on
# train_engine.Trainer.train, which is not visible here -- confirm before using start_epoch = 0.
pretrained_model = os.path.join(results_folder, 'model-' + str(start_epoch) + '.pt')


In [None]:
# Start (or resume) training from the checkpoint selected in the previous cell.
trainer.train(
    pre_trained_model = pretrained_model,
    start_step = start_epoch,
)
# If training does not run because GPU memory is insufficient, the original notes suggest:
# 1. lowering initial_dim in the model, e.g. to 4 or 2
# 2. lowering train_batch_size
# 3. lowering accumulation_steps
#    NOTE(review): gradient accumulation normally does not raise peak memory -- confirm that point 3 helps

  data = torch.load(trained_model_filename, map_location=device)


model loaded from  /host/d/projects/registration/models/trial_1_movingTF0_MSE_weight0.01/models/model-170.pt


  4%|▍         | 170/4000 [00:00<?, ?it/s]

training epoch:  171
learning rate:  0.0001


average loss: 0.0007, average similarity loss: 0.0007, average regularization loss: 0.0018:   4%|▍         | 171/4000 [00:24<26:02:27, 24.48s/it]

training epoch:  172
learning rate:  0.0001


average loss: 0.0008, average similarity loss: 0.0008, average regularization loss: 0.0017:   4%|▍         | 172/4000 [00:55<29:58:50, 28.19s/it]

training epoch:  173
learning rate:  0.0001


average loss: 0.0008, average similarity loss: 0.0008, average regularization loss: 0.0017:   4%|▍         | 173/4000 [01:12<24:42:49, 23.25s/it]

training epoch:  174
learning rate:  0.0001


average loss: 0.0007, average similarity loss: 0.0007, average regularization loss: 0.0017:   4%|▍         | 174/4000 [01:36<25:00:35, 23.53s/it]

training epoch:  175
learning rate:  0.0001


average loss: 0.0008, average similarity loss: 0.0008, average regularization loss: 0.0018:   4%|▍         | 175/4000 [01:57<23:55:11, 22.51s/it]

training epoch:  176
learning rate:  0.0001


average loss: 0.0008, average similarity loss: 0.0008, average regularization loss: 0.0017:   4%|▍         | 176/4000 [02:14<22:01:49, 20.74s/it]

training epoch:  177
learning rate:  0.0001


average loss: 0.0008, average similarity loss: 0.0007, average regularization loss: 0.0016:   4%|▍         | 177/4000 [02:51<27:38:14, 26.03s/it]

training epoch:  178
learning rate:  0.0001


average loss: 0.0008, average similarity loss: 0.0008, average regularization loss: 0.0016:   4%|▍         | 178/4000 [03:12<25:59:16, 24.48s/it]

training epoch:  179
learning rate:  0.0001


average loss: 0.0007, average similarity loss: 0.0007, average regularization loss: 0.0015:   4%|▍         | 179/4000 [03:43<27:56:07, 26.32s/it]

training epoch:  180
learning rate:  0.0001


average loss: 0.0009, average similarity loss: 0.0008, average regularization loss: 0.0016:   4%|▍         | 179/4000 [04:00<27:56:07, 26.32s/it]

validation at step:  180


average loss: 0.0009, average similarity loss: 0.0008, average regularization loss: 0.0016:   4%|▍         | 180/4000 [04:04<26:28:13, 24.95s/it]

validation loss: 0.0010, validation similarity loss: 0.0010, validation regularization loss: 0.0017
training epoch:  181
learning rate:  0.0001


average loss: 0.0008, average similarity loss: 0.0008, average regularization loss: 0.0017:   5%|▍         | 181/4000 [04:25<25:03:01, 23.61s/it]

training epoch:  182
learning rate:  0.0001


average loss: 0.0008, average similarity loss: 0.0007, average regularization loss: 0.0017:   5%|▍         | 182/4000 [04:49<25:09:29, 23.72s/it]

training epoch:  183
learning rate:  0.0001


average loss: 0.0009, average similarity loss: 0.0009, average regularization loss: 0.0016:   5%|▍         | 183/4000 [05:10<24:12:52, 22.84s/it]

training epoch:  184
learning rate:  0.0001


average loss: 0.0007, average similarity loss: 0.0007, average regularization loss: 0.0016:   5%|▍         | 184/4000 [05:41<26:55:45, 25.41s/it]

training epoch:  185
learning rate:  0.0001


average loss: 0.0006, average similarity loss: 0.0006, average regularization loss: 0.0016:   5%|▍         | 185/4000 [06:12<28:48:44, 27.19s/it]

training epoch:  186
learning rate:  0.0001


average loss: 0.0008, average similarity loss: 0.0007, average regularization loss: 0.0017:   5%|▍         | 186/4000 [06:37<27:59:01, 26.41s/it]

training epoch:  187
learning rate:  0.0001


average loss: 0.0008, average similarity loss: 0.0008, average regularization loss: 0.0017:   5%|▍         | 187/4000 [07:08<29:20:51, 27.71s/it]

training epoch:  188
learning rate:  0.0001


average loss: 0.0008, average similarity loss: 0.0008, average regularization loss: 0.0017:   5%|▍         | 188/4000 [07:35<29:11:51, 27.57s/it]

training epoch:  189
learning rate:  0.0001


average loss: 0.0008, average similarity loss: 0.0008, average regularization loss: 0.0018:   5%|▍         | 189/4000 [07:59<28:03:08, 26.50s/it]

training epoch:  190
learning rate:  0.0001


average loss: 0.0008, average similarity loss: 0.0008, average regularization loss: 0.0017:   5%|▍         | 189/4000 [08:23<28:03:08, 26.50s/it]

validation at step:  190


average loss: 0.0008, average similarity loss: 0.0008, average regularization loss: 0.0017:   5%|▍         | 190/4000 [08:27<28:36:11, 27.03s/it]

validation loss: 0.0009, validation similarity loss: 0.0009, validation regularization loss: 0.0017
training epoch:  191
learning rate:  0.0001


average loss: 0.0007, average similarity loss: 0.0007, average regularization loss: 0.0016:   5%|▍         | 191/4000 [08:57<29:35:55, 27.97s/it]

training epoch:  192
learning rate:  0.0001


average loss: 0.0007, average similarity loss: 0.0007, average regularization loss: 0.0015:   5%|▍         | 192/4000 [09:15<26:12:59, 24.78s/it]

training epoch:  193
learning rate:  0.0001


average loss: 0.0007, average similarity loss: 0.0007, average regularization loss: 0.0015:   5%|▍         | 193/4000 [09:42<26:55:31, 25.46s/it]

training epoch:  194
learning rate:  0.0001


average loss: 0.0008, average similarity loss: 0.0008, average regularization loss: 0.0017:   5%|▍         | 194/4000 [09:56<23:21:15, 22.09s/it]

training epoch:  195
learning rate:  0.0001


average loss: 0.0009, average similarity loss: 0.0009, average regularization loss: 0.0017:   5%|▍         | 195/4000 [10:17<22:56:05, 21.70s/it]

training epoch:  196
learning rate:  0.0001


average loss: 0.0008, average similarity loss: 0.0008, average regularization loss: 0.0017:   5%|▍         | 196/4000 [10:44<24:37:22, 23.30s/it]

training epoch:  197
learning rate:  0.0001


average loss: 0.0006, average similarity loss: 0.0006, average regularization loss: 0.0017:   5%|▍         | 197/4000 [11:14<26:45:18, 25.33s/it]

training epoch:  198
learning rate:  0.0001


average loss: 0.0007, average similarity loss: 0.0007, average regularization loss: 0.0016:   5%|▍         | 198/4000 [11:38<26:16:09, 24.87s/it]

training epoch:  199
learning rate:  0.0001


average loss: 0.0009, average similarity loss: 0.0009, average regularization loss: 0.0017:   5%|▍         | 199/4000 [11:55<23:56:02, 22.67s/it]

training epoch:  200
learning rate:  0.0001


average loss: 0.0008, average similarity loss: 0.0007, average regularization loss: 0.0017:   5%|▍         | 199/4000 [12:22<23:56:02, 22.67s/it]

validation at step:  200


average loss: 0.0008, average similarity loss: 0.0007, average regularization loss: 0.0017:   5%|▌         | 200/4000 [12:26<26:36:13, 25.20s/it]

validation loss: 0.0009, validation similarity loss: 0.0009, validation regularization loss: 0.0017
training epoch:  201
learning rate:  9.5e-05


average loss: 0.0007, average similarity loss: 0.0007, average regularization loss: 0.0017:   5%|▌         | 201/4000 [12:50<26:07:30, 24.76s/it]

training epoch:  202
learning rate:  9.5e-05


average loss: 0.0008, average similarity loss: 0.0008, average regularization loss: 0.0017:   5%|▌         | 202/4000 [13:11<24:48:35, 23.52s/it]

training epoch:  203
learning rate:  9.5e-05


average loss: 0.0006, average similarity loss: 0.0006, average regularization loss: 0.0016:   5%|▌         | 203/4000 [13:41<26:55:15, 25.52s/it]

training epoch:  204
learning rate:  9.5e-05


average loss: 0.0006, average similarity loss: 0.0006, average regularization loss: 0.0015:   5%|▌         | 204/4000 [14:11<28:28:19, 27.00s/it]

training epoch:  205
learning rate:  9.5e-05


average loss: 0.0007, average similarity loss: 0.0007, average regularization loss: 0.0016:   5%|▌         | 205/4000 [14:42<29:29:10, 27.97s/it]

training epoch:  206
learning rate:  9.5e-05


average loss: 0.0006, average similarity loss: 0.0006, average regularization loss: 0.0015:   5%|▌         | 206/4000 [15:06<28:10:34, 26.74s/it]

training epoch:  207
learning rate:  9.5e-05


average loss: 0.0006, average similarity loss: 0.0006, average regularization loss: 0.0016:   5%|▌         | 207/4000 [15:32<28:11:05, 26.75s/it]

training epoch:  208
learning rate:  9.5e-05


average loss: 0.0008, average similarity loss: 0.0007, average regularization loss: 0.0016:   5%|▌         | 208/4000 [15:56<27:12:33, 25.83s/it]

training epoch:  209
learning rate:  9.5e-05


average loss: 0.0008, average similarity loss: 0.0008, average regularization loss: 0.0016:   5%|▌         | 209/4000 [16:17<25:34:20, 24.28s/it]

training epoch:  210
learning rate:  9.5e-05


average loss: 0.0008, average similarity loss: 0.0008, average regularization loss: 0.0017:   5%|▌         | 209/4000 [16:47<25:34:20, 24.28s/it]

validation at step:  210


average loss: 0.0008, average similarity loss: 0.0008, average regularization loss: 0.0017:   5%|▌         | 210/4000 [16:51<28:51:06, 27.41s/it]

validation loss: 0.0009, validation similarity loss: 0.0009, validation regularization loss: 0.0017
training epoch:  211
learning rate:  9.5e-05


average loss: 0.0008, average similarity loss: 0.0008, average regularization loss: 0.0016:   5%|▌         | 211/4000 [17:05<24:37:51, 23.40s/it]

training epoch:  212
learning rate:  9.5e-05


### step 5: build predictor and use trained model to predict

In [21]:
# Folder that receives the prediction results of this trial.
save_folder = os.path.join('/host/d/projects/registration/models', trial_name, 'results')
# Create the parent folder first, then the folder itself. os.path.dirname gives the parent;
# os.path.basename would only give the relative name 'results', i.e. a folder in the current
# working directory rather than the parent of save_folder.
ff.make_folder([os.path.dirname(save_folder), save_folder])

In [None]:
# Excel sheet listing every 4DCT case -- point this at your own copy.
patient_list_spreadsheet = '/host/d/Data/4DCT/Patient_lists/ct_list.xlsx'
build_sheet = Build_list.Build(patient_list_spreadsheet)

# ---- test cases ----
# For demonstration the training batch (batch 0) is reused as the test set, and only one
# case is kept so that the code can be run end to end first.
######## per-person slice -- zhennong: 0:1, yuanqi: 1:2, leyu: 2:3, luxin: 3:4
test_case_slice = slice(0, 1)
(batch_list_tst,
 dataset_id_list_tst,
 case_id_list_tst,
 image_folder_list_tst) = [entries[test_case_slice] for entries in build_sheet.__build__(batch_list = [0])]

### step 5.2: define the trained model to use (the epoch with the best validation loss)

In [23]:
# Epoch of the checkpoint used for prediction, and the path of the matching weight file.
epoch = 170
trained_model_file = os.path.join(
    '/host/d/projects/registration/models',
    trial_name,
    'models',
    'model-{}.pt'.format(epoch),
)

### step 5.3: run the prediction

In [None]:
# For every test case: predict the motion vector field (MVF) from time frame 0 (moving) to the
# selected fixed time frame(s), then save the input images, the warped moving image and the MVF.
for i in range(len(case_id_list_tst)):

    # find out how many time frames there are in this 4DCT
    image_folder = image_folder_list_tst[i]
    timeframes = ff.sort_timeframe(ff.find_all_target_files(['img*'], image_folder),2,'_')
    print('case id:', case_id_list_tst[i], 'has', len(timeframes), 'time frames.')

    # save folder for this case and epoch
    save_folder_case = os.path.join(save_folder, case_id_list_tst[i], 'epoch_' + str(epoch))
    # Create the parent (per-case) folder, then the per-epoch folder. os.path.dirname gives the
    # parent; os.path.basename would only give the relative name 'epoch_<N>'.
    ff.make_folder([os.path.dirname(save_folder_case), save_folder_case])

    ### get the deformation field for each fixed time frame, using the first time frame as moving image
    # Only time frame 5 is processed for now; use range(1, len(timeframes)) to process every frame.
    fixed_tf_candidates = range(5, 6)
    for tf in fixed_tf_candidates:
        moving_tf = 0
        fixed_tf = tf

        # define prediction generator: a single preset (moving, fixed) pair for this case
        generator_pred = Generator.Dataset_4DCT(
            image_folder_list = [image_folder],
            image_size = image_size,
            cutoff_range = cutoff_range,
            only_use_tf0_as_moving=True,

            num_of_pairs_each_case = 1,
            preset_paired_tf = [(moving_tf, fixed_tf)],
        )

        # define predictor
        predictor = predict_engine.Predictor(
            our_model,
            generator_pred,
            batch_size = 1,
        )

        # predict the MVF and apply it to the moving image
        pred_MVF, pred_MVF_numpy, warped_moving_image_numpy = predictor.predict_MVF_and_apply(trained_model_filename = trained_model_file)
        print('predicted MVF_numpy shape:', pred_MVF_numpy.shape)
        print('warped moving image shape:', warped_moving_image_numpy.shape)

        # save the input images; the affine of the moving image is reused for every saved volume
        moving_image_file = timeframes[moving_tf]
        fixed_image_file = timeframes[fixed_tf]
        moving_image_nii = nb.load(moving_image_file)
        fixed_image_nii = nb.load(fixed_image_file)
        affine = moving_image_nii.affine

        nb.save(nb.Nifti1Image(moving_image_nii.get_fdata(), affine), os.path.join(save_folder_case, 'gt_tf' + str(moving_tf) + '.nii.gz'))
        nb.save(nb.Nifti1Image(fixed_image_nii.get_fdata(), affine), os.path.join(save_folder_case, 'gt_tf' + str(fixed_tf) + '.nii.gz'))

        # save warped moving image
        nb.save(nb.Nifti1Image(warped_moving_image_numpy, affine), os.path.join(save_folder_case, 'warped_tf' + str(fixed_tf) + '.nii.gz'))

        # save predicted MVF, one file per component (first axis of pred_MVF_numpy: 0 -> x, 1 -> y, 2 -> z)
        for component_index, component_name in enumerate(('x', 'y', 'z')):
            mvf_component_file = os.path.join(save_folder_case, 'pred_MVF_tf' + str(fixed_tf) + '_' + component_name + '.nii.gz')
            nb.save(nb.Nifti1Image(pred_MVF_numpy[component_index,...], affine), mvf_component_file)

case id: Case1 has 10 time frames.


  data = torch.load(trained_model_filename, map_location=self.device)
100%|██████████| 1/1 [00:00<00:00,  1.17it/s]


predicted MVF_numpy shape: (3, 224, 224, 96)
warped moving image shape: (224, 224, 96)


In [25]:
# Value range of the MVF predicted last (minimum and maximum over all components).
print('data range:', pred_MVF_numpy.min(), pred_MVF_numpy.max())

data range: -0.5904154 4.5218554


In [16]:
import CT_registration_diffusion.model.loss as my_loss

# Sanity check: NCC loss between the fixed image and the warped moving image produced by the
# last prediction. Nothing from earlier cells is overwritten, so the cell can be re-run safely
# (the normalisation is not applied twice and the `similarity_metric` setting stays a string).

# Use the GPU when one is available, otherwise fall back to the CPU.
ncc_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Fixed image: add batch and channel axes, then normalise with the generator's settings.
img = fixed_image_nii.get_fdata()[np.newaxis, np.newaxis, ...]
img = Data_processing.normalize_image(img, normalize_factor = generator_pred.normalize_factor, image_max = generator_pred.maximum_cutoff, image_min = generator_pred.background_cutoff ,invert = False,final_max = 1, final_min = 0)
img_torch = torch.from_numpy(img).float().to(ncc_device)

# Warped moving image: normalise the same way into a NEW variable, then add batch and channel axes.
# NOTE(review): assumes predictor.predict_MVF_and_apply returns the warped image in un-normalised
# intensities, so that normalising it here is appropriate -- confirm.
warped_moving_image_normalized = Data_processing.normalize_image(warped_moving_image_numpy, normalize_factor =generator_pred.normalize_factor, image_max = generator_pred.maximum_cutoff, image_min = generator_pred.background_cutoff ,invert = False,final_max = 1, final_min = 0)
warped_torch = torch.from_numpy(warped_moving_image_normalized[np.newaxis, np.newaxis, ...]).float().to(ncc_device)

ncc_loss = my_loss.NCCLoss()
value = ncc_loss(warped_torch, img_torch)
# The two inputs are the warped moving image and the fixed image, not identical images.
print('NCC loss between warped moving image and fixed image:', value.item())

NCC value between identical images: -0.08591607958078384
