In [2]:
# Make sure you have GitHub Copilot installed (search for it in the VS Code extension marketplace); it will make your coding much easier.
import sys 
sys.path.append('/host/d/Github/')
import os
import torch
import numpy as np
import Example_UNet.model.model as model
import Example_UNet.functions_collection as ff
import Example_UNet.Build_lists.Build_list as Build_list
import Example_UNet.Generator as Generator

# If you get "No module named ..." (e.g., lpips), install the missing package inside the container:
# 1. Open PowerShell and type `wsl` to enter WSL.
# 2. In WSL, run `sudo docker container ls` to see the container ID of your running container.
# 3. Run `sudo docker container exec -it -u 0 <container_id> bash` to open a root shell inside the container.
# 4. You are now inside the container as root; run `pip install lpips` (or whichever module is missing).

  from .autonotebook import tqdm as notebook_tqdm


### Step 1: Define the trial name

In [3]:
# Experiment name; it becomes a sub-folder of the model save path in step 6.
trial_name = "trial"

### Step 2: Define parameters (no changes needed)

In [4]:
# In-plane size of each image, handed to both data generators.
image_size = [512, 512]
# Number of patches extracted per 2D slice for training.
num_patches_per_slice = 2
# Size of each training patch (training runs on patches; validation on the full image).
patch_size = [128, 128]

# Lower/upper intensity bounds passed to the generators.
# NOTE(review): presumably CT Hounsfield units -- confirm against Generator.Dataset_2D.
background_cutoff, maximum_cutoff = -1000, 2000
# Normalization mode string passed to the generators; its meaning is defined in Generator (not visible here).
normalize_factor = 'equation'

### Step 3: Build the patient lists

In [None]:
# change the excel path to your own path
patient_list_spreadsheet = os.path.join('/host/d/Github/Example_UNet/example_data/Patient_lists/patient_list_example.xlsx')
build_sheet =  Build_list.Build(patient_list_spreadsheet)
_,_,input_file_train, reference_file_train = build_sheet.__build__(batch_list = [0,1]) 
 
# define val
_,_,input_file_val, reference_file_val = build_sheet.__build__(batch_list = [1])  # just as an example, use the same batch for val

print('input:', input_file_train.shape, ' reference:', reference_file_train.shape)
print('input file example: ', input_file_train[0], ' reference file: ', reference_file_train[0])

input: (2,)  reference: (2,)
input file example:  /host/d/Github/Example_UNet/example_data/data/ID_001/input.nii.gz  reference file:  /host/d/Github/Example_UNet/example_data/data/ID_001/output_reference.nii.gz


### Step 4: Build the model

In [7]:
# Build the 2D U-Net.
# With init_dim=16 and dim_mults=(2, 4, 8, 16), the stage channel pairs are
# (16, 32), (32, 64), (64, 128), (128, 256) -- see the "in out is" printout of this cell.
our_model = model.Unet2D(
    init_dim=16,
    channels=1,        # presumably the number of input channels -- confirm in model.Unet2D
    out_dim=1,         # number of output channels
    dim_mults=(2, 4, 8, 16),
    full_attn=(None, None, False, True),  # per-stage attention flags; None vs False semantics defined in model.Unet2D
    act='ReLU',
)

in out is :  [(16, 32), (32, 64), (64, 128), (128, 256)]


### Step 5: Build the data generators

In [15]:
# Training generator: randomly picked slices, cut into patches, with augmentation.
generator_train = Generator.Dataset_2D(
        input_list = input_file_train,
        reference_list = reference_file_train,

        image_size = image_size,

        num_slices_per_image = 50,
        random_pick_slice = True,
        slice_range = None, # None or [a,b]

        background_cutoff = background_cutoff,
        maximum_cutoff = maximum_cutoff,
        normalize_factor = normalize_factor,

        num_patches_per_slice = num_patches_per_slice, # use the step-2 parameter rather than a duplicated literal
        patch_size = patch_size, # train on patches

        shuffle = True,
        augment = True,
        augment_frequency = 0.5,)

# Validation generator: fixed slice range, full image, no shuffling or augmentation.
# It must read the validation lists built in step 3, not the training lists.
generator_val = Generator.Dataset_2D(
        input_list = input_file_val,
        reference_list = reference_file_val,

        image_size = image_size,

        num_slices_per_image = 20,
        random_pick_slice = False,
        slice_range = [0,20], # None or [a,b]

        background_cutoff = background_cutoff,
        maximum_cutoff = maximum_cutoff,
        normalize_factor = normalize_factor,

        num_patches_per_slice = 1,
        patch_size = image_size, ## validation on full image

        shuffle = False,
        augment = False,
        augment_frequency = 0,)

### Step 6: Train

In [9]:
# define pretrained model path if any
# Path to a pretrained model if any (None if there is none), and the step to start counting from;
# both are forwarded to trainer.train() below.
pre_trained_model, start_step = None, 0

In [None]:
# train
# first define a path to save your model, and create the folder
# Checkpoints go to <models root>/<trial_name>/models.
# ff.make_folder receives the trial folder followed by the models folder.
model_save_path = os.path.join('/host/d/projects/Example_UNet/models', trial_name, 'models')
ff.make_folder([os.path.dirname(model_save_path), model_save_path])


trainer = model.Trainer(
    model=our_model,
    generator_train=generator_train,
    generator_val=generator_val,
    train_batch_size=5,

    # Length of training. NOTE(review): the argument is named "steps" but was described
    # as epochs -- confirm which unit model.Trainer uses.
    train_num_steps=1000,
    results_folder=model_save_path,

    # Learning rate and how often it decays; the remaining frequencies are yours to tune.
    train_lr=1e-4,
    train_lr_decay_every=200,
    save_models_every=2,
    validation_every=2,
)

trainer.train(pre_trained_model=pre_trained_model, start_step=start_step)