'''use two adajacent noisy slices as input, use current noisy slice as output reference'''

In [1]:
import sys 
sys.path.append('/workspace/Documents')
import os
import torch
import numpy as np
import Example_UNet.model.model as model
import Example_UNet.functions_collection as ff
import Example_UNet.Build_lists.Build_list as Build_list
import Example_UNet.Generator as Generator

  from .autonotebook import tqdm as notebook_tqdm


### step 1: define trial name

In [2]:
trial_name = 'trial'

### step 2: define parameters (no need to change)

In [3]:
image_size = [512,512]
num_patches_per_slice = 2
patch_size = [128,128]

background_cutoff = -1000
maximum_cutoff = 2000
normalize_factor = 'equation'

### step 3: build patient list

In [4]:
# change the excel path to your own path
patient_list_spreadsheet = os.path.join('/workspace/Documents/Example_UNet/example_data/Patient_lists/patient_list_example.xlsx') # change to your own path
build_sheet =  Build_list.Build(patient_list_spreadsheet)
_,_,input_file_train, reference_file_train = build_sheet.__build__(batch_list = [0,1]) 
 
# define val
_,_,input_file_val, reference_file_val = build_sheet.__build__(batch_list = [1])  # just as an example, use the same batch for val

print('input:', input_file_train.shape, ' reference:', reference_file_train.shape)
print('input file example: ', input_file_train[0], ' reference file: ', reference_file_train[0])

input: (2,)  reference: (2,)
input file example:  /workspace/Documents/Example_UNet/example_data/data/ID_001/input.nii.gz  reference file:  /workspace/Documents/Example_UNet/example_data/data/ID_001/output_reference.nii.gz


### Step 4: build model

In [5]:
# build model
our_model = model.Unet2D(
    init_dim = 16,
    channels = 1, 
    out_dim = 1,
    dim_mults = (2,4,8,16),
    full_attn = (None,None, False, True),
    act = 'ReLU',
)

in out is :  [(16, 32), (32, 64), (64, 128), (128, 256)]


### Step 5: Data generator

In [6]:
generator_train = Generator.Dataset_2D(
        input_list = input_file_train,
        reference_list = reference_file_train,

        image_size = image_size,

        num_slices_per_image = 50,
        random_pick_slice = True,
        slice_range = None, # None or [a,b]

        background_cutoff = background_cutoff,
        maximum_cutoff = maximum_cutoff,
        normalize_factor = normalize_factor,

        num_patches_per_slice = 2,
        patch_size = patch_size, # train on patches

        shuffle = True,
        augment = True,
        augment_frequency = 0.5,)

generator_val = Generator.Dataset_2D(
        input_list = input_file_train,
        reference_list = reference_file_train,

        image_size = image_size,

        num_slices_per_image = 20,
        random_pick_slice = False,
        slice_range = [0,20], # None or [a,b]

        background_cutoff = background_cutoff,
        maximum_cutoff = maximum_cutoff,
        normalize_factor = normalize_factor,

        num_patches_per_slice = 1,
        patch_size = image_size, ## validation on full image

        shuffle = False,
        augment = False,
        augment_frequency = 0,)

### step 6: train

In [7]:
# define pretrained model path if any
pre_trained_model = None 
start_step = 0

In [8]:
# train
trainer = model.Trainer(
    model= our_model,
    generator_train = generator_train,
    generator_val = generator_val,
    train_batch_size = 5,

    train_num_steps = 100, # total training epochs
    results_folder = os.path.join('/mnt/camca_NAS/denoising/models', trial_name, 'models'), # define your own save path
   
    train_lr = 1e-4,
    train_lr_decay_every = 200,  # define your own lr decay frequency
    save_models_every = 1, # define your own save frequency
    validation_every = 1, # define your own validation frequency
)

trainer.train(pre_trained_model=pre_trained_model, start_step= start_step)

  0%|          | 0/100 [00:00<?, ?it/s]

training epoch:  1
learning rate:  0.0001
random origin x is:  1  and random origin y is:  374
random origin x is:  344  and random origin y is:  215
random origin x is:  159  and random origin y is:  221
z rotate degree is:  1.9384927083056631  translate is:  -1
random origin x is:  169  and random origin y is:  23
random origin x is:  70  and random origin y is:  34
z rotate degree is:  9.046319777224138  translate is:  4
random origin x is:  221  and random origin y is:  108
z rotate degree is:  6.640953412170308  translate is:  8
random origin x is:  58  and random origin y is:  327
random origin x is:  68  and random origin y is:  331
random origin x is:  112  and random origin y is:  115
z rotate degree is:  -3.086908525883098  translate is:  0
random origin x is:  272  and random origin y is:  373
z rotate degree is:  -5.640099117602764  translate is:  2
random origin x is:  141  and random origin y is:  76
z rotate degree is:  -8.177807689149715  translate is:  -1
random origin

average loss: 0.8192:   0%|          | 0/100 [00:17<?, ?it/s]

z rotate degree is:  -0.0004065853122945384  translate is:  0
random origin x is:  282  and random origin y is:  278
random origin x is:  75  and random origin y is:  76
z rotate degree is:  -2.553919938985782  translate is:  9
validation at step:  1
random origin x is:  0  and random origin y is:  0
random origin x is:  0  and random origin y is:  0
random origin x is:  0  and random origin y is:  0
random origin x is:  0  and random origin y is:  0
random origin x is:  0  and random origin y is:  0
random origin x is:  0  and random origin y is:  0
random origin x is:  0  and random origin y is:  0
random origin x is:  0  and random origin y is:  0
random origin x is:  0  and random origin y is:  0
random origin x is:  0  and random origin y is:  0
random origin x is:  0  and random origin y is:  0
random origin x is:  0  and random origin y is:  0
random origin x is:  0  and random origin y is:  0
random origin x is:  0  and random origin y is:  0
random origin x is:  0  and random 

average loss: 0.8192:   1%|          | 1/100 [00:30<50:04, 30.35s/it]

validation loss:  0.21189108677208424
now run on_epoch_end function
now run on_epoch_end function
training epoch:  2
learning rate:  0.0001
random origin x is:  333  and random origin y is:  330
random origin x is:  93  and random origin y is:  253
z rotate degree is:  -9.599975696096996  translate is:  4
random origin x is:  259  and random origin y is:  315
z rotate degree is:  6.975830184121882  translate is:  -4
random origin x is:  277  and random origin y is:  264
random origin x is:  1  and random origin y is:  328
random origin x is:  175  and random origin y is:  310
random origin x is:  212  and random origin y is:  275
z rotate degree is:  -2.024095372005153  translate is:  -4
random origin x is:  97  and random origin y is:  271
z rotate degree is:  7.220625842945651  translate is:  -8
random origin x is:  124  and random origin y is:  131
z rotate degree is:  -9.772397948232426  translate is:  -1
random origin x is:  121  and random origin y is:  8
random origin x is:  267

average loss: 0.2181:   1%|          | 1/100 [00:43<50:04, 30.35s/it]

z rotate degree is:  -5.709833337678041  translate is:  9
random origin x is:  355  and random origin y is:  166
z rotate degree is:  -3.8141403296754373  translate is:  8
random origin x is:  13  and random origin y is:  282
random origin x is:  66  and random origin y is:  241
z rotate degree is:  1.895392780312779  translate is:  7
random origin x is:  217  and random origin y is:  153
z rotate degree is:  -7.773556169729847  translate is:  5
validation at step:  2
random origin x is:  0  and random origin y is:  0
random origin x is:  0  and random origin y is:  0
random origin x is:  0  and random origin y is:  0
random origin x is:  0  and random origin y is:  0
random origin x is:  0  and random origin y is:  0
random origin x is:  0  and random origin y is:  0
random origin x is:  0  and random origin y is:  0
random origin x is:  0  and random origin y is:  0
random origin x is:  0  and random origin y is:  0
random origin x is:  0  and random origin y is:  0
random origin x i

average loss: 0.2181:   2%|▏         | 2/100 [00:58<47:35, 29.14s/it]

now run on_epoch_end function
now run on_epoch_end function
training epoch:  3
learning rate:  0.0001
random origin x is:  142  and random origin y is:  160
z rotate degree is:  5.69461611880865  translate is:  -1
random origin x is:  112  and random origin y is:  98
z rotate degree is:  3.3948419173150626  translate is:  3
random origin x is:  69  and random origin y is:  242
z rotate degree is:  1.6318651719760382  translate is:  3
random origin x is:  63  and random origin y is:  116
random origin x is:  42  and random origin y is:  221
random origin x is:  253  and random origin y is:  63
z rotate degree is:  2.3483836013397212  translate is:  7
random origin x is:  76  and random origin y is:  237
random origin x is:  356  and random origin y is:  256
z rotate degree is:  8.590233321862769  translate is:  7
random origin x is:  88  and random origin y is:  323
z rotate degree is:  8.342506190666029  translate is:  7
random origin x is:  98  and random origin y is:  1
z rotate degr

average loss: 0.1422:   2%|▏         | 2/100 [01:03<47:35, 29.14s/it]

z rotate degree is:  -7.339174545360832  translate is:  2
validation at step:  3


average loss: 0.1422:   2%|▏         | 2/100 [01:07<54:53, 33.61s/it]


KeyboardInterrupt: 