In [None]:
import sys 
sys.path.append('/host/d/Github/')
import os
import torch
import numpy as np
import nibabel as nb
import random
import matplotlib.pyplot as plt

import torch
from torch.utils.data import Dataset

import CT_registration_diffusion.functions_collection as ff
import CT_registration_diffusion.Build_lists.Build_list as Build_list
import CT_registration_diffusion.Data_processing as Data_processing
import CT_registration_diffusion.Generator as Generator
import CT_registration_diffusion.model.model as model
import CT_registration_diffusion.model.train_engine as train_engine

### define our trial name

In [None]:
# Name of this experiment; it becomes a sub-folder of the results path built in step 4.
trial_name = "trial_1"

### step 0: define some pre-set parameters

In [None]:
# --- training settings ---
# Training batch size; keep it at 1 if your GPU memory is limited.
train_batch_size = 1

# --- image settings ---
# Handed to the generators as the target image size after center-crop.
image_size = [224, 224, 96]
# Handed to the generators as the cutoff range for the CT images.
cutoff_range = [-200, 250]

### step 1: define patient list

In [None]:
# Spreadsheet that lists the patients; change this path to your own.
patient_list_spreadsheet = '/host/d/Data/4DCT/Patient_lists/ct_list.xlsx'
build_sheet = Build_list.Build(patient_list_spreadsheet)

# Training split: take batch 0 and keep only its first case for now, just to
# get the whole pipeline running end-to-end.
batch_list_train, dataset_id_list_train, case_id_list_train, image_folder_list_train = (
    entries[:1] for entries in build_sheet.__build__(batch_list = [0])
)

# Validation split: batch 3 is queried from the spreadsheet ...
batch_list_val, dataset_id_list_val, case_id_list_val, image_folder_list_val = build_sheet.__build__(batch_list = [3])
# ... but, while we are only getting the code to run, training and validation
# temporarily share the very same single case, so the result above is overridden.
(batch_list_val,
 dataset_id_list_val,
 case_id_list_val,
 image_folder_list_val) = (batch_list_train,
                           dataset_id_list_train,
                           case_id_list_train,
                           image_folder_list_train)

# Print one path as a sanity check.
print('train image folder:', image_folder_list_train[0])

### step 2: define generator

In [None]:
# Settings that the training and validation generators have in common.
shared_generator_kwargs = dict(
    image_size = image_size,      # target image size after center-crop
    cutoff_range = cutoff_range,  # cutoff range for CT images
)

# Training generator.
generator_train = Generator.Dataset_4DCT(
    image_folder_list = image_folder_list_train,

    # How many pairs of time frames to randomly pick from one 4DCT case
    # (e.g. picking frames (0,2) as one pair and (1,3) as another means 2).
    num_of_pairs_each_case = 1,
    # Optional preset pairing of time frames for every case, e.g. [(0,2),(1,3)]
    # pairs frame 0 with 2 and frame 1 with 3. When set, num_of_pairs_each_case
    # must equal the length of this list. When None, num_of_pairs_each_case
    # pairs are drawn at random from the 4DCT each time.
    preset_paired_tf = None,

    shuffle = True,
    augment = True,           # data augmentation is switched on for training
    augment_frequency = 0.5,

    **shared_generator_kwargs,
)

# Validation generator.
# Unlike training, validation should stay as consistent as possible, so there is
# no shuffling and no augmentation, and the time-frame pairs are fixed through
# preset_paired_tf so the same pairs are used every time.
preset_paired_tf_val = [(0,1),(0,2),(0,4), (0,5),(0,6),(0,8)]  # preset time-frame pairs for every validation case
generator_val = Generator.Dataset_4DCT(
    image_folder_list = image_folder_list_val,

    num_of_pairs_each_case = len(preset_paired_tf_val),
    preset_paired_tf = preset_paired_tf_val,

    shuffle = False,
    augment = False,          # no data augmentation in validation
    augment_frequency = 0.0,

    **shared_generator_kwargs,
)

### step 3: model

In [None]:
# Collect the U-Net settings first, then build the model from them.
unet_config = dict(
    # We are solving a 3D image registration problem.
    problem_dimension = '3D',

    # 1 if a single 4DCT time frame (e.g. only frame 0) is the model input;
    # 2 if two time frames (e.g. frame 0 and frame 2 as moving and fixed image) are the input.
    input_channels = 2,
    # 2 for a 2D deformation field; 3 for a 3D deformation field.
    out_channels = 3,

    # Default initial feature dimension after the first conv layer.
    initial_dim = 4,
    dim_mults = (2,4,8,16),
    groups = 4,

    # Attention settings for the downsampling/upsampling paths and for the middle
    # bottleneck layer; all are None for now to keep GPU memory usage down.
    full_attn_paths = (None, None, None, None),
    full_attn_bottleneck = None,

    act = 'ReLU',
)

our_model = model.Unet(**unet_config)

### step 4: build trainer and start to train the model

In [None]:
# Weight of the deformation-field smoothness regularization term; the best value
# has to be determined experimentally.
regularization_weight = 1.0

total_epochs = 1000
# Save the model every N epochs. The fewer training samples there are, the larger
# this should be (the original author's rule of thumb: 1-5 in general, 20-50 when
# training on a single case).
save_models_every = 2
# Validate every N epochs; kept equal to save_models_every.
validation_every = save_models_every

# Where to save the model weights. Change this path to your own.
results_folder = os.path.join('/host/d/projects/registration/models', trial_name, 'models')
# Create the output folder together with any missing parents. The previous code
# passed os.path.basename(results_folder) (the relative name 'models') to the
# folder helper, which points at a stray 'models' directory in the current
# working directory rather than the parent of results_folder.
os.makedirs(results_folder, exist_ok = True)

trainer = train_engine.Trainer(
        our_model,
        generator_train,
        generator_val,
        train_batch_size = train_batch_size,
        regularization_weight = regularization_weight,
        train_num_steps = total_epochs,
        results_folder = results_folder,

        train_lr_decay_every = 100,
        save_models_every = save_models_every,
        validation_every = validation_every,)

In [None]:
# Epoch to start from: 0 when there is no pre-trained model, otherwise the epoch
# of the model that gets loaded.
start_epoch = 0

# Pre-trained model to resume from, if we have one; None means there is none.
pretrained_model = None


In [None]:
# Kick off training.
# If it cannot run because the GPU is out of memory, try a smaller initial_dim
# when building the model (e.g. 2).
trainer.train(
    start_step = start_epoch,
    pre_trained_model = pretrained_model,
)