## Model training

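In this notebook, we set the training parameters, define the training and validation data, build the diffusion model and the data generators, and run the model training.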


In [1]:
import sys 
sys.path.append('/workspace/Documents')
import os
import torch
import numpy as np
import pandas as pd
import nibabel as nb
from ema_pytorch import EMA
import Cardiac4DCT_Synth_Diffusion.denoising_diffusion_pytorch.denoising_diffusion_pytorch.conditional_diffusion_3D as ddpm_3D
import Cardiac4DCT_Synth_Diffusion.denoising_diffusion_pytorch.denoising_diffusion_pytorch.conditional_EDM as edm
import Cardiac4DCT_Synth_Diffusion.denoising_diffusion_pytorch.denoising_diffusion_pytorch.conditional_EDM_warp as edm_warp
import Cardiac4DCT_Synth_Diffusion.Build_lists.Build_list as Build_list
import Cardiac4DCT_Synth_Diffusion.Generator as Generator
# Root directory that holds the example data, patient lists and model outputs
# used throughout this notebook.
# It can be overridden without editing the notebook by setting the
# CARDIAC4DCT_MAIN_PATH environment variable; when the variable is unset or
# empty, the original NAS location is used.
main_path = os.environ.get('CARDIAC4DCT_MAIN_PATH') or '/mnt/camca_NAS/4DCT'

  from .autonotebook import tqdm as notebook_tqdm
  @autocast(enabled = False)


### step 1: set some default parameters

In [2]:
# ---- experiment identity and loss weighting ----
trial_name = 'MVF_EDM'  # also used as the sub-folder name for saved models in step 5
EF_loss_weight = 1

# ---- time frames handled together ----
picked_tf = [i / 10 for i in range(10)]  # 0.0, 0.1, ..., 0.9
how_many_timeframes_together = 10

# ---- checkpoint / resume settings ----
pre_trained_model, start_step = None, 0

# ---- MVF data settings ----
mvf_size_3D = [40, 40, 24]
mvf_slice_range = [0, 96]
mvf_folder = os.path.join(main_path, 'example_data/mvf_warp0_onecase')

# ---- per-level downsampling flags passed to the U-Net in step 3 ----
downsample_list = (True, True, False, False)

# ---- augmentation and conditioning switches ----
augment_pre_done = True  # augmentation was already done in the step-2 notebook
conditional_diffusion_timeframe = False
conditional_diffusion_image = True
conditional_diffusion_EF = True
conditional_diffusion_seg = False

### step 2: define your own training and validation data

In [3]:
# Define the training and validation data from the patient spreadsheet.
data_sheet = os.path.join(
    main_path, 'example_data/Patient_lists/example_data/patient_list.xlsx')

# Build the patient class / id lists for batch 0; the third returned value is unused.
b = Build_list.Build(data_sheet)
patient_class_train_list, patient_id_train_list, _ = b.__build__(batch_list=[0])

# In this example the validation lists simply reuse the training lists.
patient_class_val_list, patient_id_val_list = patient_class_train_list, patient_id_train_list

print(
    'patient_class_train_list:', len(patient_class_train_list),
    ' patient_class_val_list:', len(patient_class_val_list),
)

patient_class_train_list: 2  patient_class_val_list: 2


### step 3: define diffusion model

In [4]:
# Build the U-Net that the diffusion process wraps.
# Input and output both use 3 channels for each of the time frames handled
# together (presumably the 3 components of the MVF per frame -- TODO confirm).
model = ddpm_3D.Unet3D_tfcondition(
    init_dim=64,
    channels=3 * how_many_timeframes_together,
    out_dim=3 * how_many_timeframes_together,
    # Timeframe conditioning arguments are left unset:
    # conditional_timeframe_input_dim = None,
    # conditional_diffusion_timeframe = conditional_diffusion_timeframe,
    conditional_diffusion_image=conditional_diffusion_image,
    conditional_diffusion_EF=conditional_diffusion_EF,
    conditional_diffusion_seg=conditional_diffusion_seg,  # should be False
    dim_mults=(1, 2, 4, 8),
    downsample_list=downsample_list,
    # First three downsampling flags in reverse order, then a final False.
    upsample_list=tuple(reversed(downsample_list[:3])) + (False,),
    flash_attn=False,
    full_attn=(None, None, False, False),
)

# Wrap the U-Net in the EDM diffusion model, sampling in 50 steps without clipping.
diffusion_model = edm.EDM(
    model,
    image_size=mvf_size_3D,
    num_sample_steps=50,
    clip_or_not=False,
)

### step 4: define generators

In [5]:
# Define the data generators.
# The time-frame selection sheet is shared by the training and validation generators.
timeframe_info = pd.read_excel(os.path.join(
    main_path,
    'example_data/Patient_lists/example_data/patient_list_final_selection_timeframes.xlsx'))

# Training generator: shuffled, with augmentation enabled.
generator_train = Generator.Dataset_dual_3D(
    patient_class_list=patient_class_train_list,
    patient_id_list=patient_id_train_list,
    main_path=main_path,
    timeframe_info=timeframe_info,

    how_many_timeframes_together=how_many_timeframes_together,

    mvf_size_3D=mvf_size_3D,
    slice_range=mvf_slice_range,

    picked_tf=picked_tf,
    condition_on_image=True,
    prepare_seg=True,
    mvf_cutoff=[-20, 20],
    shuffle=True,
    augment=True,
    augment_frequency=0.8,
    augment_pre_done=augment_pre_done,
    augment_aug_index=[1, 2],
)

# Validation generator: no augmentation.
# It is built from the *validation* patient lists defined in step 2, so that
# changing those lists there actually changes what is validated on. (Previously
# the training lists were passed here; the result is identical as long as the
# validation lists alias the training lists, as in the example setup.)
generator_val = Generator.Dataset_dual_3D(
    patient_class_list=patient_class_val_list,
    patient_id_list=patient_id_val_list,
    main_path=main_path,
    timeframe_info=timeframe_info,

    how_many_timeframes_together=how_many_timeframes_together,

    mvf_size_3D=mvf_size_3D,
    slice_range=mvf_slice_range,

    picked_tf=picked_tf,
    condition_on_image=True,
    prepare_seg=True,
    mvf_cutoff=[-20, 20],
    augment=False,
    augment_pre_done=augment_pre_done,
)

### step 5: train the model

In [None]:
# Define the pretrained model, if any; it is passed to trainer.train() below.
# NOTE(review): presumably a checkpoint path to resume from, with None meaning
# "train from scratch" -- confirm against Trainer.train.
# This repeats the default already set in step 1.
pre_trained_model = None

In [8]:
# Step index passed to trainer.train(); 0 here (repeats the default from step 1).
start_step = 0

# Define the trainer around the diffusion model and the two generators.
# Checkpoints go to <main_path>/example_data/models/<trial_name>/models.
trainer = edm_warp.Trainer(
    diffusion_model=diffusion_model,
    generator_train=generator_train,
    generator_val=generator_val,
    EF_loss_weight=EF_loss_weight,
    train_batch_size=2,
    results_folder=os.path.join(main_path, 'example_data/models', trial_name, 'models'),
)

# Training schedule, set as attributes on the already-constructed trainer.
# NOTE(review): if Trainer.__init__ builds its optimizer/scheduler from these
# values, assigning them afterwards may not take effect -- confirm in edm_warp.
trainer.train_num_steps = 1500
trainer.train_lr = 1e-4
trainer.train_lr_decay_every = 500
trainer.save_models_every = 10
trainer.validation_every = 10

# Run the training.
trainer.train(pre_trained_model=pre_trained_model, start_step=start_step)

conditional_image:  True  condition_EF:  True  condition_seg:  False


  0%|          | 0/1500 [00:00<?, ?it/s]

training epoch:  1
learning rate:  0.0001


average loss: 1.0097:   0%|          | 1/1500 [00:00<14:39,  1.70it/s]

average loss:  1.0096626281738281 average diffusion loss:  1.0096626281738281  average EF loss for factual:  0.0013932535657659173  average EF loss for counterfactual:  0.1352357566356659
now run on_epoch_end function
now run on_epoch_end function
training epoch:  2
learning rate:  0.0001


average loss: 1.0565:   0%|          | 2/1500 [00:01<14:50,  1.68it/s]

average loss:  1.0564978122711182 average diffusion loss:  1.0564978122711182  average EF loss for factual:  0.01633423939347267  average EF loss for counterfactual:  0.1374102085828781
now run on_epoch_end function
now run on_epoch_end function
training epoch:  3
learning rate:  0.0001


average loss: 0.6646:   0%|          | 3/1500 [00:01<13:54,  1.79it/s]

average loss:  0.6646173000335693 average diffusion loss:  0.6646173000335693  average EF loss for factual:  0.07123744487762451  average EF loss for counterfactual:  0.1268320530653
now run on_epoch_end function
now run on_epoch_end function
training epoch:  4
learning rate:  0.0001


average loss: 0.8329:   0%|          | 4/1500 [00:02<13:35,  1.83it/s]

average loss:  0.8329178094863892 average diffusion loss:  0.8329178094863892  average EF loss for factual:  0.048838552087545395  average EF loss for counterfactual:  0.3127792775630951
now run on_epoch_end function
now run on_epoch_end function
training epoch:  5
learning rate:  0.0001


average loss: 0.9047:   0%|          | 5/1500 [00:02<13:33,  1.84it/s]

average loss:  0.9046949148178101 average diffusion loss:  0.9046949148178101  average EF loss for factual:  0.043366432189941406  average EF loss for counterfactual:  0.31793615221977234
now run on_epoch_end function
now run on_epoch_end function
training epoch:  6
learning rate:  0.0001


average loss: 0.4820:   0%|          | 6/1500 [00:03<13:15,  1.88it/s]

average loss:  0.48204296827316284 average diffusion loss:  0.48204296827316284  average EF loss for factual:  0.10319265723228455  average EF loss for counterfactual:  0.11382845044136047
now run on_epoch_end function
now run on_epoch_end function
training epoch:  7
learning rate:  0.0001


average loss: 0.4539:   0%|          | 7/1500 [00:03<14:02,  1.77it/s]

average loss:  0.45393040776252747 average diffusion loss:  0.45393040776252747  average EF loss for factual:  0.18206103146076202  average EF loss for counterfactual:  0.0017920746468007565
now run on_epoch_end function
now run on_epoch_end function
training epoch:  8
learning rate:  0.0001


average loss: 0.8582:   1%|          | 8/1500 [00:04<13:38,  1.82it/s]

average loss:  0.8582490682601929 average diffusion loss:  0.8582490682601929  average EF loss for factual:  0.004037504084408283  average EF loss for counterfactual:  0.022405628114938736
now run on_epoch_end function
now run on_epoch_end function
training epoch:  9
learning rate:  0.0001


average loss: 0.6921:   1%|          | 9/1500 [00:04<13:24,  1.85it/s]

average loss:  0.692130446434021 average diffusion loss:  0.692130446434021  average EF loss for factual:  0.08118289709091187  average EF loss for counterfactual:  0.0263849887996912
now run on_epoch_end function
now run on_epoch_end function
training epoch:  10
learning rate:  0.0001


average loss: 0.5289:   1%|          | 9/1500 [00:05<13:24,  1.85it/s]

average loss:  0.5288728475570679 average diffusion loss:  0.5288728475570679  average EF loss for factual:  0.18977992236614227  average EF loss for counterfactual:  0.08865281939506531
validation at step:  10


average loss: 0.5289:   1%|          | 10/1500 [00:22<2:21:56,  5.72s/it]

validation loss:  0.8826918601989746  validation diffusion loss:  0.8826918601989746  validation EF loss for factual:  0.004359617363661528  validation EF loss for counterfactual:  0.10569492727518082
now run on_epoch_end function
now run on_epoch_end function
training epoch:  11
learning rate:  0.0001


average loss: 0.6036:   1%|          | 11/1500 [00:22<1:42:28,  4.13s/it]

average loss:  0.6035605669021606 average diffusion loss:  0.6035605669021606  average EF loss for factual:  0.05417240783572197  average EF loss for counterfactual:  0.0831354483962059
now run on_epoch_end function
now run on_epoch_end function
training epoch:  12
learning rate:  0.0001


average loss: 0.7858:   1%|          | 12/1500 [00:23<1:14:59,  3.02s/it]

average loss:  0.7858350276947021 average diffusion loss:  0.7858350276947021  average EF loss for factual:  0.015112066641449928  average EF loss for counterfactual:  0.10164407640695572
now run on_epoch_end function
now run on_epoch_end function
training epoch:  13
learning rate:  0.0001


average loss: 0.4516:   1%|          | 13/1500 [00:23<56:03,  2.26s/it]  

average loss:  0.4516294002532959 average diffusion loss:  0.4516294002532959  average EF loss for factual:  0.04537529498338699  average EF loss for counterfactual:  0.05671017989516258
now run on_epoch_end function
now run on_epoch_end function
training epoch:  14
learning rate:  0.0001


average loss: 0.7414:   1%|          | 14/1500 [00:24<42:50,  1.73s/it]

average loss:  0.7414093017578125 average diffusion loss:  0.7414093017578125  average EF loss for factual:  0.015088777989149094  average EF loss for counterfactual:  0.2918218970298767
now run on_epoch_end function
now run on_epoch_end function
training epoch:  15
learning rate:  0.0001


average loss: 0.5221:   1%|          | 15/1500 [00:24<33:42,  1.36s/it]

average loss:  0.5221033096313477 average diffusion loss:  0.5221033096313477  average EF loss for factual:  0.05714167654514313  average EF loss for counterfactual:  0.00806721206754446
now run on_epoch_end function
now run on_epoch_end function
training epoch:  16
learning rate:  0.0001


average loss: 0.3031:   1%|          | 16/1500 [00:25<27:56,  1.13s/it]

average loss:  0.3030874729156494 average diffusion loss:  0.3030874729156494  average EF loss for factual:  0.08668564260005951  average EF loss for counterfactual:  0.25117528438568115
now run on_epoch_end function
now run on_epoch_end function
training epoch:  17
learning rate:  0.0001


average loss: 0.5049:   1%|          | 17/1500 [00:25<23:31,  1.05it/s]

average loss:  0.5049225687980652 average diffusion loss:  0.5049225687980652  average EF loss for factual:  0.11188337206840515  average EF loss for counterfactual:  0.13713885843753815
now run on_epoch_end function
now run on_epoch_end function
training epoch:  18
learning rate:  0.0001


average loss: 0.5505:   1%|          | 18/1500 [00:26<19:49,  1.25it/s]

average loss:  0.5505450367927551 average diffusion loss:  0.5505450367927551  average EF loss for factual:  0.03741660341620445  average EF loss for counterfactual:  0.21775120496749878
now run on_epoch_end function
now run on_epoch_end function
training epoch:  19
learning rate:  0.0001


average loss: 0.6070:   1%|▏         | 19/1500 [00:26<17:13,  1.43it/s]

average loss:  0.6070193648338318 average diffusion loss:  0.6070193648338318  average EF loss for factual:  0.12068792432546616  average EF loss for counterfactual:  0.1949847936630249
now run on_epoch_end function
now run on_epoch_end function
training epoch:  20
learning rate:  0.0001


average loss: 0.5175:   1%|▏         | 19/1500 [00:27<17:13,  1.43it/s]

average loss:  0.5174548625946045 average diffusion loss:  0.5174548625946045  average EF loss for factual:  0.18598932027816772  average EF loss for counterfactual:  0.012523302808403969
