'''Noise2Noise training: use two adjacent noisy slices as input and the current noisy slice as the output reference.'''

In [1]:
import sys 
sys.path.append('/workspace/Documents')
import os
import torch
import numpy as np
import CTDenoising_Diffusion_N2N.noise2noise.model as noise2noise
import CTDenoising_Diffusion_N2N.functions_collection as ff
import CTDenoising_Diffusion_N2N.Build_lists.Build_list as Build_list
import CTDenoising_Diffusion_N2N.noise2noise.Generator as Generator

  from .autonotebook import tqdm as notebook_tqdm


### Step 1: Define the trial name

In [22]:
# Name of this experiment; also used as the sub-folder for saved models.
trial_name = 'noise2noise'

### Step 2: Define parameters (defaults; usually no need to change)

In [14]:
# --- slice / patch geometry ---
# Full slice size, and the patch size / patch count used for training.
image_size = [512, 512]
patch_size = [128, 128]
num_patches_per_slice = 2

# --- intensity preprocessing (forwarded to the data generators) ---
# Cutoff values are presumably in HU -- confirm against Generator.Dataset_2D.
histogram_equalization = True
background_cutoff = -1000
maximum_cutoff = 2000
normalize_factor = 'equation'

### Step 3: Build the patient lists

In [15]:
# Patient-list spreadsheet -- change this to your own path.
patient_list_excel = '/mnt/camca_NAS/denoising/Patient_lists/fixedCT_static_simulation_train_test_gaussian_local.xlsx'
build_sheet = Build_list.Build(patient_list_excel)

# __build__ returns six items; only the last two are kept here.
# Per the printed output, condition_* are simulated noisy recon paths and x0_* are the original CT paths.
# Training set: batches 0-3.
_, _, _, _, condition_list_train, x0_list_train = build_sheet.__build__(batch_list = [0, 1, 2, 3])

# Validation set: batch 4.
_, _, _, _, condition_list_val, x0_list_val = build_sheet.__build__(batch_list = [4])

# Sanity check: list sizes and the first few entries of each list.
print('train:', x0_list_train.shape, condition_list_train.shape,
      'val:', x0_list_val.shape, condition_list_val.shape)
print(x0_list_train[0:5], condition_list_train[0:5],
      x0_list_val[0:5], condition_list_val[0:5])

train: (136,) (136,) val: (32,) (32,)
['/workspace/Documents/Data/denoising/fixedCT/00139437/0000258390/img_thinslice_partial.nii.gz'
 '/workspace/Documents/Data/denoising/fixedCT/00139437/0000258390/img_thinslice_partial.nii.gz'
 '/workspace/Documents/Data/denoising/fixedCT/00019599/0000029506/img_thinslice_partial.nii.gz'
 '/workspace/Documents/Data/denoising/fixedCT/00019599/0000029506/img_thinslice_partial.nii.gz'
 '/workspace/Documents/Data/denoising/fixedCT/00148611/0000455363/img_thinslice_partial.nii.gz'] ['/workspace/Documents/Data/denoising/simulation/00139437/0000258390/gaussian_random_0/recon.nii.gz'
 '/workspace/Documents/Data/denoising/simulation/00139437/0000258390/gaussian_random_1/recon.nii.gz'
 '/workspace/Documents/Data/denoising/simulation/00019599/0000029506/gaussian_random_0/recon.nii.gz'
 '/workspace/Documents/Data/denoising/simulation/00019599/0000029506/gaussian_random_1/recon.nii.gz'
 '/workspace/Documents/Data/denoising/simulation/00148611/0000455363/gaussian

### Step 4: Build the model

In [16]:
# build model
# 2-D U-Net: 2 input channels (the two adjacent noisy slices) -> 1 output channel.
# dim_mults scale init_dim per level; the cell output shows widths 16 -> 32 -> 64 -> 128 -> 256.
# full_attn has one entry per level (presumably enabling attention only at the deepest level -- confirm in Unet2D).
model = noise2noise.Unet2D(
    init_dim=16,
    dim_mults=(2, 4, 8, 16),
    full_attn=(None, None, False, True),
    channels=2,
    out_dim=1,
    act='ReLU',
)

in out is :  [(16, 32), (32, 64), (64, 128), (128, 256)]


### Step 5: Build the data generators

In [17]:
# load histogram equalization pre-saved files
# Pre-saved histogram-equalization tables (copies ship with this repo as example data).
# Change both paths to your own location.
bins_path = '/mnt/camca_NAS/denoising/Data/histogram_equalization/bins.npy'
bins_mapped_path = '/mnt/camca_NAS/denoising/Data/histogram_equalization/bins_mapped.npy'

bins = np.load(bins_path)
bins_mapped = np.load(bins_mapped_path)

In [18]:
# Noise2Noise setup: both generators read only the noisy (condition) volumes;
# no clean x0 images are passed in.

# Training generator: `num_slices_per_image` randomly picked slices per volume,
# `num_patches_per_slice` patches of `patch_size` each, with shuffling and augmentation.
generator_train = Generator.Dataset_2D(
        img_list = condition_list_train,
        image_size = image_size,

        num_slices_per_image = 50,
        random_pick_slice = True,
        slice_range = None,

        num_patches_per_slice = num_patches_per_slice,
        patch_size = patch_size,

        bins = bins,
        bins_mapped = bins_mapped,
        histogram_equalization = histogram_equalization,
        background_cutoff = background_cutoff,
        maximum_cutoff = maximum_cutoff,
        normalize_factor = normalize_factor,

        shuffle = True,
        augment = True,
        augment_frequency = 0.5,)

# Validation generator: a fixed slice window (slices 50-70) and one whole-slice patch per slice.
# The patch size is derived from image_size (not hard-coded) so validation keeps covering
# the full slice if image_size is changed.
# NOTE(review): shuffle/augment are not passed here, so they fall back to Dataset_2D's
# defaults -- confirm those defaults are off for validation.
generator_val = Generator.Dataset_2D(
        img_list = condition_list_val,
        image_size = image_size,

        num_slices_per_image = 20,
        random_pick_slice = False,
        slice_range = [50,70],

        num_patches_per_slice = 1,
        patch_size = list(image_size),

        bins = bins,
        bins_mapped = bins_mapped,
        histogram_equalization = histogram_equalization,
        background_cutoff = background_cutoff,
        maximum_cutoff = maximum_cutoff,
        normalize_factor = normalize_factor,)

### Step 6: Train

In [19]:
### define pretrained model path if any
# Train from scratch by default. To resume from a checkpoint, set e.g.:
#   pre_trained_model = os.path.join('/mnt/camca_NAS/denoising/models', trial_name, 'models', 'model-13.pt')
#   start_step = 13
pre_trained_model = None
start_step = 0

In [21]:
# train
# Training outputs (e.g. saved models) go under <models root>/<trial_name>/models.
results_folder = os.path.join('/mnt/camca_NAS/denoising/models', trial_name, 'models')

trainer = noise2noise.Trainer(
    model=model,
    generator_train=generator_train,
    generator_val=generator_val,
    train_batch_size=25,

    # Described by the authors as total training epochs, although the parameter is named
    # *_steps -- confirm the unit in noise2noise.Trainer.
    train_num_steps=10000,
    results_folder=results_folder,

    # Base learning rate and how often it is decayed.
    train_lr=1e-4,
    train_lr_decay_every=200,

    # Save a checkpoint and run validation at every interval.
    save_models_every=1,
    validation_every=1,
)

trainer.train(pre_trained_model=pre_trained_model, start_step=start_step)