In [1]:
import sys 
sys.path.append('/host/d/Github/')
import os
import torch
import numpy as np
import nibabel as nb
import random
import matplotlib.pyplot as plt

import torch
from torch.utils.data import Dataset

import CT_registration_diffusion.functions_collection as ff
import CT_registration_diffusion.Build_lists.Build_list as Build_list
import CT_registration_diffusion.Data_processing as Data_processing
import CT_registration_diffusion.Generator as Generator
import CT_registration_diffusion.model.model as model
import CT_registration_diffusion.model.train_engine as train_engine

  from .autonotebook import tqdm as notebook_tqdm


### Define the trial name

In [2]:
# Name of this experiment; also used as the sub-folder name for the saved model weights.
trial_name = 'trial_1'

# Seed every RNG this notebook imports, so that data shuffling, augmentation and
# weight initialisation are reproducible across "Restart & Run All".
# NOTE(review): assumes Generator / model / train_engine draw from the global
# random / numpy / torch RNGs — TODO confirm (a private RNG would need its own seed).
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

### Step 1: define the patient lists (train / validation)

In [None]:
# Build the train / validation case lists from the patient spreadsheet.
# Change DATA_ROOT to your own data location.
DATA_ROOT = '/host/d/Data/4DCT'
patient_list_spreadsheet = os.path.join(DATA_ROOT, 'Patient_lists', 'ct_list.xlsx')
build_sheet = Build_list.Build(patient_list_spreadsheet)

# Number of cases to keep per split. 1 = smoke test (get the pipeline running
# end-to-end on a single case first); set to None to use every case.
N_CASES = 1


def take_first_cases(case_lists, n_cases):
    """Truncate each of several per-case lists to its first ``n_cases`` entries.

    Parameters
    ----------
    case_lists : iterable of sliceable sequences
        The lists unpacked from ``build_sheet.__build__`` (batch list, dataset id
        list, case id list, image folder list).
    n_cases : int or None
        How many leading entries to keep; ``None`` keeps every entry.

    Returns
    -------
    list
        The truncated lists, in the same order as ``case_lists``.
    """
    return [case_list[:n_cases] for case_list in case_lists]


# Training split: spreadsheet batch 0.
(batch_list_train, dataset_id_list_train,
 case_id_list_train, image_folder_list_train) = take_first_cases(
    build_sheet.__build__(batch_list=[0]), N_CASES)
# len() rather than truthiness so this also works if the lists are numpy arrays.
assert len(image_folder_list_train) > 0, 'no training cases found for batch 0'

# Validation split: spreadsheet batch 3.
(batch_list_val, dataset_id_list_val,
 case_id_list_val, image_folder_list_val) = take_first_cases(
    build_sheet.__build__(batch_list=[3]), N_CASES)
assert len(image_folder_list_val) > 0, 'no validation cases found for batch 3'

### Step 2: define the data generators

In [4]:
# Preprocessing shared by the training and validation generators. Defining it
# once guarantees both splits are cropped and intensity-clipped identically.
IMAGE_SIZE = (224, 224, 96)    # target image size after center-crop
CUTOFF_RANGE = (-200, 250)     # default cutoff range for CT images
# Value passed as `augment_frequency` for training (validation uses 0.0);
# presumably the probability of augmenting a sample — TODO confirm in Generator.
TRAIN_AUGMENT_FREQUENCY = 0.5


def build_4dct_generator(image_folder_list, is_training):
    """Create a ``Generator.Dataset_4DCT`` for one data split.

    Parameters
    ----------
    image_folder_list : list
        Image folders of the cases in this split.
    is_training : bool
        True enables shuffling and data augmentation (with
        ``augment_frequency=TRAIN_AUGMENT_FREQUENCY``); False disables both and
        passes ``augment_frequency=0.0``.

    Returns
    -------
    Generator.Dataset_4DCT
        The dataset object for this split.
    """
    return Generator.Dataset_4DCT(
        image_folder_list=image_folder_list,
        # Fresh lists per call, so each dataset gets its own objects just as the
        # original inline list literals did.
        image_size=list(IMAGE_SIZE),
        cutoff_range=list(CUTOFF_RANGE),
        shuffle=is_training,
        augment=is_training,
        augment_frequency=TRAIN_AUGMENT_FREQUENCY if is_training else 0.0,
    )


generator_train = build_4dct_generator(image_folder_list_train, is_training=True)
generator_val = build_4dct_generator(image_folder_list_val, is_training=False)

### Step 3: build the model

In [9]:
# Registration network: a 3D U-Net that receives both the fixed and the moving
# image (2 input channels) and predicts a 3D deformation field (3 output channels).
# NOTE(review): if `groups` is the GroupNorm group count, every channel count
# (initial_dim and initial_dim * each dim_mult) must be divisible by it — TODO confirm.
our_model = model.Unet(
    problem_dimension='3D',                    # 3D image registration
    input_channels=2,                          # 1 = moving image only; 2 = fixed + moving images
    out_channels=3,                            # 2 for a 2D deformation field; 3 for a 3D one
    initial_dim=8,                             # feature dimension after the first conv layer
    dim_mults=(2, 4, 8, 16),                   # per-level multipliers applied to initial_dim
    groups=4,
    full_attn_paths=(None, None, None, None),  # attention settings for the down/up-sampling paths
    full_attn_bottleneck=None,                 # attention setting for the middle bottleneck
    act='ReLU',
)

in out is :  [(8, 16), (16, 32), (32, 64), (64, 128)]


### Step 4: build the trainer and train the model

In [10]:
# Training hyper-parameters.
train_batch_size = 1          # training batch size; keep at 1 if your GPU memory is limited
regularization_weight = 1.0   # weight of the deformation-field smoothness regularization term
total_epochs = 1000           # passed to the Trainer as train_num_steps (the log reports one epoch per step)
save_models_every = 1         # save model weights every N epochs
validation_every = 1          # validate every N epochs
lr_decay_every = 100          # passed to the Trainer as train_lr_decay_every

# Validation and checkpointing are meant to run on the same schedule; fail fast
# rather than silently training with mismatched intervals.
assert validation_every == save_models_every, (
    'validation_every should equal save_models_every')

# Where to save the model weights? Change MODELS_ROOT to your own path.
MODELS_ROOT = '/host/d/projects/registration/models'
results_folder = os.path.join(MODELS_ROOT, trial_name)
ff.make_folder([results_folder])

trainer = train_engine.Trainer(
    our_model,
    generator_train,
    generator_val,
    train_batch_size=train_batch_size,
    regularization_weight=regularization_weight,
    train_num_steps=total_epochs,
    results_folder=results_folder,
    train_lr_decay_every=lr_decay_every,
    save_models_every=save_models_every,
    validation_every=validation_every,
)

In [7]:
# Resume settings, consumed by the trainer.train() call in the next cell:
#   pretrained_model -- None trains from scratch; otherwise the pre-trained model
#                       to resume from (expected type: see Trainer.train — TODO confirm).
#   start_epoch      -- 0 when training from scratch, else the epoch the loaded
#                       model was saved at.
pretrained_model, start_epoch = None, 0


In [None]:
# Start (or resume) training.
# If training does not fit on the GPU (out of memory), try lowering `initial_dim`
# in the model cell, e.g. to 4 or 2.
# NOTE(review): if `groups` sets GroupNorm groups, initial_dim=2 is not divisible
# by groups=4, so lower `groups` as well — TODO confirm.
trainer.train(
    pre_trained_model=pretrained_model,
    start_step=start_epoch,
)

  0%|          | 0/1000 [00:00<?, ?it/s]

training epoch:  1
learning rate:  0.0001


average loss: 0.4100, average similarity loss: 0.0247, average regularization loss: 0.3853:   0%|          | 0/1000 [00:04<?, ?it/s]

validation at step:  1


average loss: 0.4100, average similarity loss: 0.0247, average regularization loss: 0.3853:   0%|          | 1/1000 [00:06<1:46:15,  6.38s/it]

validation loss: 0.5097, validation similarity loss: 0.0432, validation regularization loss: 0.4666
now run on_epoch_end function
now run on_epoch_end function
training epoch:  2
learning rate:  0.0001


average loss: 0.3225, average similarity loss: 0.0272, average regularization loss: 0.2953:   0%|          | 1/1000 [00:07<1:46:15,  6.38s/it]

validation at step:  2


average loss: 0.3225, average similarity loss: 0.0272, average regularization loss: 0.2953:   0%|          | 2/1000 [00:09<1:14:23,  4.47s/it]

validation loss: 0.4238, validation similarity loss: 0.0450, validation regularization loss: 0.3788
now run on_epoch_end function
now run on_epoch_end function
training epoch:  3
learning rate:  0.0001


average loss: 0.3225, average similarity loss: 0.0272, average regularization loss: 0.2953:   0%|          | 2/1000 [00:09<1:21:09,  4.88s/it]


KeyboardInterrupt: 