In [1]:
import sys 
sys.path.append('/host/d/Github/')
import os
import torch
import numpy as np
import nibabel as nb
import random
import matplotlib.pyplot as plt

import torch
from torch.utils.data import Dataset

import CT_registration_diffusion.functions_collection as ff
import CT_registration_diffusion.Build_lists.Build_list as Build_list
import CT_registration_diffusion.Data_processing as Data_processing
import CT_registration_diffusion.Generator as Generator
import CT_registration_diffusion.model.model as model
import CT_registration_diffusion.model.train_engine as train_engine
import CT_registration_diffusion.model.predict_engine as predict_engine

  from .autonotebook import tqdm as notebook_tqdm


### define our trial name

In [2]:
# Name of this experiment; it is reused below to build the folder paths for model weights and prediction results.
trial_name = 'trial_1_movingTF0_MSE_weight0.01'

### step 0: define some pre-set parameters

In [None]:
# ---- image settings (passed to every data generator below) ----
image_size = [224, 224, 96]
cutoff_range = [-200, 250]

# ---- training settings ----
# Training batch size; set it to 1 if GPU memory is insufficient.
train_batch_size = 1
# Gradient accumulation steps used to simulate a larger batch size;
# lower this value if GPU memory is insufficient.
accumulation_steps = 5

### step 1: define patient list

In [4]:
# Spreadsheet that lists the 4DCT cases -- change this to your own path.
patient_list_spreadsheet = '/host/d/Data/4DCT/Patient_lists/ct_list.xlsx'
build_sheet = Build_list.Build(patient_list_spreadsheet)

# ---- training list: cases from batch 0 ----
train_lists = build_sheet.__build__(batch_list = [0])
# Keep only the first case for now, just to get the code running end to end.
(batch_list_train,
 dataset_id_list_train,
 case_id_list_train,
 image_folder_list_train) = [entries[:1] for entries in train_lists]

# ---- validation list: cases from batch 3 ----
(batch_list_val,
 dataset_id_list_val,
 case_id_list_val,
 image_folder_list_val) = build_sheet.__build__(batch_list = [3])
# While getting the code running, train and val temporarily share the same single case,
# so the batch-3 lists built just above are overridden here.
batch_list_val, dataset_id_list_val = batch_list_train, dataset_id_list_train
case_id_list_val, image_folder_list_val = case_id_list_train, image_folder_list_train

# Print one path as a quick sanity check.
print('train image folder:', image_folder_list_train[0])

train image folder: /host/d/Data/4DCT/DIR_LAB/Case1/cropped_image


### step 2: define generator

In [5]:
# If True, only time frame 0 is used as the moving image; otherwise the moving time frame is selected randomly.
only_use_tf0_as_moving = True

# Settings shared by the training and validation generators.
shared_generator_kwargs = dict(
    image_size = image_size,        # target image size after center-crop
    cutoff_range = cutoff_range,    # default cutoff range for CT images
    only_use_tf0_as_moving = only_use_tf0_as_moving,
)

# ---- training generator ----
generator_train = Generator.Dataset_4DCT(
    image_folder_list = image_folder_list_train,

    # Number of time-frame pairs randomly drawn from one 4DCT case (e.g. picking frames 0 & 2 as one
    # pair and frames 1 & 3 as another gives num_of_pairs_each_case = 2).
    num_of_pairs_each_case = 10,
    # Preset time-frame pairing for each case, e.g. [(0,2),(1,3)] pairs frame 0 with 2 and frame 1 with 3.
    # When set, num_of_pairs_each_case must equal the length of this list. When None,
    # num_of_pairs_each_case pairs are drawn at random from the 4DCT each time.
    preset_paired_tf = None,

    shuffle = True,
    augment = True,             # data augmentation is switched on for training
    augment_frequency = 0.5,
    **shared_generator_kwargs,
)

# ---- validation generator ----
# Unlike training, validation should keep the data as consistent as possible, so there is no shuffling
# and no data augmentation. preset_paired_tf is also fixed so the same time-frame pairs are used every time.
preset_paired_tf_val = [(0,1),(0,2),(0,4), (0,5),(0,6),(0,8)]  # preset time-frame pairs for each validation case
generator_val = Generator.Dataset_4DCT(
    image_folder_list = image_folder_list_val,
    num_of_pairs_each_case = len(preset_paired_tf_val),
    preset_paired_tf = preset_paired_tf_val,
    shuffle = False,
    augment = False,            # no data augmentation for validation
    augment_frequency = 0.0,
    **shared_generator_kwargs,
)

### step 3: model

In [6]:
# Network hyper-parameters, gathered in one place before the model is built.
unet_config = {
    'problem_dimension': '3D',  # we are solving a 3D image registration problem

    # =1 if a single 4DCT time frame (e.g. only time frame 0) is the model input;
    # =2 if two time frames (e.g. time frame 0 and 2 as moving and fixed image) are the model input
    'input_channels': 2,
    'out_channels': 3,          # =2 for 2D deformation field; =3 for 3D deformation field

    'initial_dim': 4,           # default initial feature dimension after first conv layer
    'dim_mults': (2,4,8,16),
    'groups': 4,

    # Attention for the downsampling and upsampling paths; all None for now to save GPU memory.
    'full_attn_paths': (None, None, None, None),
    # Attention for the middle bottleneck layer; disabled for now to save GPU memory.
    'full_attn_bottleneck': False,
    'act': 'ReLU',
}

# build model
our_model = model.Unet(**unet_config)

in out is :  [(4, 8), (8, 16), (16, 32), (32, 64)]


### step 4: build trainer and start to train the model

In [11]:
# Weight of the deformation-field smoothness regularization term; the best value has to be found by experiment.
regularization_weight = 0.01

total_epochs = 4000
# Save the model every N epochs. The fewer the training samples, the larger this number should be
# (usually 1-5; with a single case something like 20-50).
save_models_every = 20
# Validate every N epochs; should normally be the same as save_models_every.
# NOTE(review): 10000 is larger than total_epochs, so validation presumably never runs in this
# configuration -- switch back to save_models_every once validation is wanted.
validation_every = 10000#save_models_every

# Where to save the model weights -- change this path to your own path.
results_folder = os.path.join('/host/d/projects/registration/models', trial_name, 'models')
# Create the parent folder first, then the folder itself. os.path.dirname gives the parent
# (<...>/<trial_name>); os.path.basename would only give the relative name 'models', which asks
# make_folder to create a stray folder in the current working directory.
ff.make_folder([os.path.dirname(results_folder), results_folder])

trainer = train_engine.Trainer(
        our_model,
        generator_train,
        generator_val,
        train_batch_size = train_batch_size,
        accum_iter= accumulation_steps,

        regularization_weight = regularization_weight,
        train_num_steps = total_epochs,
        results_folder = results_folder,

        train_lr_decay_every = 200,
        save_models_every = save_models_every,
        validation_every = validation_every,)

In [12]:
# Epoch to start from: 0 when there is no pre-trained model, otherwise the epoch of the loaded model.
start_epoch = 0

# Pre-trained model to continue from; None means there is no pre-trained model.
pretrained_model = None


In [None]:
# Start training.
trainer.train(pre_trained_model = pretrained_model, start_step = start_epoch)
# If training cannot run because GPU memory is insufficient:
# 1. try reducing initial_dim in the model, e.g. to 4 or 2
# 2. try reducing train_batch_size
# 3. try reducing accumulation_steps

  0%|          | 0/4000 [00:00<?, ?it/s]

training epoch:  1
learning rate:  0.0001


average loss: 0.0107, average similarity loss: 0.0106, average regularization loss: 0.0122:   0%|          | 1/4000 [00:14<15:47:20, 14.21s/it]

training epoch:  2
learning rate:  0.0001


average loss: 0.0082, average similarity loss: 0.0081, average regularization loss: 0.0095:   0%|          | 2/4000 [00:38<22:02:41, 19.85s/it]

training epoch:  3
learning rate:  0.0001


average loss: 0.0063, average similarity loss: 0.0062, average regularization loss: 0.0090:   0%|          | 3/4000 [01:08<27:15:23, 24.55s/it]

training epoch:  4
learning rate:  0.0001


### step 5: build predictor and use trained model to predict

In [10]:
# Folder that receives the prediction results of this trial.
save_folder = os.path.join('/host/d/projects/registration/models', trial_name, 'results')
# Create the parent folder first, then the folder itself. os.path.dirname gives the parent
# (<...>/<trial_name>); os.path.basename would only give the relative name 'results', which asks
# make_folder to create a stray folder in the current working directory.
ff.make_folder([os.path.dirname(save_folder), save_folder])

In [11]:
# Spreadsheet that lists the 4DCT cases -- change this to your own path.
patient_list_spreadsheet = '/host/d/Data/4DCT/Patient_lists/ct_list.xlsx'
build_sheet = Build_list.Build(patient_list_spreadsheet)

# Test list (for demonstration we simply run on the training case, i.e. batch 0).
test_lists = build_sheet.__build__(batch_list = [0])
# Keep only the first case for now, just to get the code running end to end.
(batch_list_tst,
 dataset_id_list_tst,
 case_id_list_tst,
 image_folder_list_tst) = [entries[:1] for entries in test_lists]

### step 5.2 define the pre-trained model (the epoch that has the best validation loss)

In [12]:
# Epoch of the saved checkpoint to load for prediction.
epoch = 60
# Checkpoint file of that epoch inside this trial's 'models' folder.
trained_model_file = os.path.join(
    '/host/d/projects/registration/models', trial_name, 'models', f'model-{epoch}.pt')

### step 5.3 do the prediction

In [13]:
# Predict the motion vector field (MVF) for every test case and save ground truth, warped image and MVF.
for i in range(len(case_id_list_tst)):  # len() works for both lists and numpy arrays

    # find out how many time frames there are in this 4DCT
    image_folder = image_folder_list_tst[i]
    timeframes = ff.sort_timeframe(ff.find_all_target_files(['img*'], image_folder),2,'_')
    print('case id:', case_id_list_tst[i], 'has', len(timeframes), 'time frames.')

    # save folder for this case
    save_folder_case = os.path.join(save_folder, case_id_list_tst[i], 'epoch_' + str(epoch))
    # Create the parent (case) folder first, then the epoch folder. os.path.dirname gives that parent;
    # os.path.basename would only give the relative name 'epoch_<N>', i.e. a stray folder in the
    # current working directory.
    ff.make_folder([os.path.dirname(save_folder_case), save_folder_case])

    ### get the deformation fields using the first time frame as moving image
    # NOTE(review): only time frame 5 is processed right now (range(5,6)); the trailing comment suggests
    # the loop is meant to run up to len(timeframes) -- widen the range when all frames are wanted.
    for tf in range(5,6):# len(timeframes)):
        moving_tf = 0
        fixed_tf = tf

        # define prediction generator with exactly one preset (moving, fixed) pair
        generator_pred = Generator.Dataset_4DCT(
            image_folder_list = [image_folder_list_tst[i]],
            image_size = image_size,
            cutoff_range = cutoff_range,
            only_use_tf0_as_moving=True,

            num_of_pairs_each_case = 1,
            preset_paired_tf = [(moving_tf, fixed_tf)],
        )

        # define predictor
        predictor = predict_engine.Predictor(
            our_model,
            generator_pred,
            batch_size = 1,
        )

        # predict MVF and apply it to the moving image, using the weights stored in trained_model_file
        pred_MVF, pred_MVF_numpy, warped_moving_image_numpy = predictor.predict_MVF_and_apply(trained_model_filename = trained_model_file)
        print('predicted MVF_numpy shape:', pred_MVF_numpy.shape)
        print('warped moving image shape:', warped_moving_image_numpy.shape)

        # save truth (the moving and fixed images as stored on disk)
        moving_image_file = timeframes[moving_tf]
        fixed_image_file = timeframes[fixed_tf]
        moving_image_nii = nb.load(moving_image_file)
        fixed_image_nii = nb.load(fixed_image_file)
        # NOTE(review): the moving image's affine is reused for every output below; this assumes the
        # generator pre-processing keeps the voxel grid of the stored images -- TODO confirm.
        affine = moving_image_nii.affine

        nb.save(nb.Nifti1Image(moving_image_nii.get_fdata(), affine), os.path.join(save_folder_case, 'gt_tf' + str(moving_tf) + '.nii.gz'))
        nb.save(nb.Nifti1Image(fixed_image_nii.get_fdata(), affine), os.path.join(save_folder_case, 'gt_tf' + str(fixed_tf) + '.nii.gz'))
        # save warped moving image
        nb.save(nb.Nifti1Image(warped_moving_image_numpy, affine), os.path.join(save_folder_case, 'warped_tf' + str(fixed_tf) + '.nii.gz'))
        # save predicted MVF, one file per component (index 0 -> x, 1 -> y, 2 -> z)
        for axis_index, axis_name in enumerate('xyz'):
            nb.save(nb.Nifti1Image(pred_MVF_numpy[axis_index,...], affine), os.path.join(save_folder_case, 'pred_MVF_tf' + str(fixed_tf) + '_' + axis_name + '.nii.gz'))

  data = torch.load(trained_model_filename, map_location=self.device)


case id: Case1 has 10 time frames.


100%|██████████| 1/1 [00:00<00:00,  1.24it/s]


predicted MVF_numpy shape: (3, 224, 224, 96)
warped moving image shape: (224, 224, 96)


In [14]:
# Value range of the MVF predicted in the last loop iteration above.
mvf_min, mvf_max = pred_MVF_numpy.min(), pred_MVF_numpy.max()
print('data range:', mvf_min, mvf_max)

data range: -1.5182745 2.666669


In [40]:
import CT_registration_diffusion.model.loss as my_loss

# Use the GPU when one is available instead of hard-coding 'cuda'.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Prepend batch and channel axes to both volumes before handing them to the loss.
# NOTE(review): the fixed image is read raw from disk via get_fdata(), whereas the warped image comes
# out of the prediction pipeline (which was given cutoff_range). If that pipeline rescales or crops
# intensities, the two inputs are on different scales and the value below is not a meaningful NCC --
# TODO confirm and apply the same pre-processing to the fixed image.
img = fixed_image_nii.get_fdata()[np.newaxis, np.newaxis, ...]
img_torch = torch.from_numpy(img).float().to(device)
warped_torch = torch.from_numpy(warped_moving_image_numpy[np.newaxis, np.newaxis, ...]).float().to(device)

similarity_metric = my_loss.NCCLoss()
value = similarity_metric(warped_torch, img_torch)
# The two inputs are the warped moving image and the fixed image (not identical images),
# so the printed label says exactly that.
print('NCC loss between warped moving image and fixed image:', value.item())

NCC value between identical images: -25475170.0
