'''Use two adjacent noisy slices as input and the current noisy slice as the output reference.'''

In [None]:
# Make sure you have GitHub Copilot installed (search for it in the VS Code extension marketplace); it will make your coding much easier.
import sys 
sys.path.append('/host/d/Github/')
import os
import torch
import numpy as np
import CTDenoising_Diffusion_N2N.noise2noise.model as noise2noise
import CTDenoising_Diffusion_N2N.functions_collection as ff
import CTDenoising_Diffusion_N2N.Build_lists.Build_list as Build_list
import CTDenoising_Diffusion_N2N.noise2noise.Generator as Generator

# If you get "No module named ..." (e.g., lpips), install the package inside the container:
# 1. Open PowerShell and type `wsl` to enter WSL.
# 2. In WSL, run `sudo docker container ls` to find the ID of your running container.
# 3. Run `sudo docker container exec -it -u 0 <container_id> bash`.
# 4. You now have a root shell inside the container; run `pip install lpips`.

### Step 1: define trial name

In [8]:
# Experiment name; it also names the sub-folder the trained models are written to (see the train cell).
trial_name = "noise2noise"

### Step 2: define parameters (no need to change)

In [9]:
# Geometry: full slice size and the patch size used for training (2-element lists).
image_size, patch_size = [512, 512], [128, 128]
num_patches_per_slice = 2  # patches taken from each training slice

# Intensity preprocessing, forwarded unchanged to Generator.Dataset_2D.
histogram_equalization = True
background_cutoff, maximum_cutoff = -1000, 2000  # presumably HU clipping bounds -- confirm in Dataset_2D
normalize_factor = 'equation'

### Step 3: build patient list

In [None]:
# Path to the patient-list spreadsheet -- change this to your own copy.
# An example spreadsheet ships with this repo (example data folder), e.g.
#   /host/d/Github/CTDenoising_Diffusion_N2N/example_data/patient_lists/patient_list_supervised.xlsx
# Otherwise use the spreadsheet you were given:
patient_list_file = '/host/d/projects/denoising/Patient_lists/fixedCT_static_simulation_train_test_gaussian_local.xlsx'
build_sheet = Build_list.Build(patient_list_file)

# __build__ returns six lists; only the last two (condition images and x0 images) are used here.
# Batches 0-3 form the training set.
_, _, _, _, condition_list_train, x0_list_train = build_sheet.__build__(batch_list=[0, 1, 2, 3])

# Batch 4 is used for validation.
_, _, _, _, condition_list_val, x0_list_val = build_sheet.__build__(batch_list=[4])

print('train:', x0_list_train.shape, condition_list_train.shape, 'val:', x0_list_val.shape, condition_list_val.shape)
print(x0_list_train[0:5], condition_list_train[0:5], x0_list_val[0:5], condition_list_val[0:5])

train: (1,) (1,) val: (1,) (1,)
['/host/d/Github/CTDenoising_Diffusion_N2N/example_data/fixedCT/00004038/0000455420/img_thinslice_partial.nii.gz'] ['/host/d/Github/CTDenoising_Diffusion_N2N/example_data/simulation/00004038/0000455420/poisson_random_0/recon.nii.gz'] ['/host/d/Github/CTDenoising_Diffusion_N2N/example_data/fixedCT/00004038/0000455420/img_thinslice_partial.nii.gz'] ['/host/d/Github/CTDenoising_Diffusion_N2N/example_data/simulation/00004038/0000455420/poisson_random_0/recon.nii.gz']


### Step 4: build model

In [11]:
# 2D U-Net: 2 input channels (the two adjacent noisy slices) -> 1 output channel.
model = noise2noise.Unet2D(
    init_dim=16,
    channels=2,
    out_dim=1,
    dim_mults=(2, 4, 8, 16),               # channel multipliers of init_dim, one per resolution level
    full_attn=(None, None, False, True),   # presumably one entry per level in dim_mults -- confirm in Unet2D
    act='ReLU',
)

in out is :  [(16, 32), (32, 64), (64, 128), (128, 256)]


### Step 5: Data generator

In [13]:
# Load the pre-saved histogram-equalization lookup tables.
# Change this directory to your own path; the files are in this repo (example data folder).
# Keeping it in one variable means only one line has to be edited.
histogram_equalization_dir = '/host/d/Github/CTDenoising_Diffusion_N2N/example_data/histogram_equalization'
bins = np.load(os.path.join(histogram_equalization_dir, 'bins.npy'))
bins_mapped = np.load(os.path.join(histogram_equalization_dir, 'bins_mapped.npy'))

In [14]:
# Training generator: picks 50 slices at random from each training volume and
# trains on num_patches_per_slice patches of patch_size per slice, with augmentation.
generator_train = Generator.Dataset_2D(
    img_list=condition_list_train,
    image_size=image_size,

    num_slices_per_image=50,
    random_pick_slice=True,
    slice_range=None,

    num_patches_per_slice=num_patches_per_slice,
    patch_size=patch_size,  # train on patches to save GPU memory

    bins=bins,
    bins_mapped=bins_mapped,
    histogram_equalization=histogram_equalization,
    background_cutoff=background_cutoff,
    maximum_cutoff=maximum_cutoff,
    normalize_factor=normalize_factor,

    shuffle=True,
    augment=True,
    augment_frequency=0.5,
)

# Validation generator: a fixed set of 20 slices (slice_range [50, 70]), each used
# as a single full-size patch.
generator_val = Generator.Dataset_2D(
    img_list=condition_list_val,
    image_size=image_size,

    num_slices_per_image=20,
    random_pick_slice=False,
    slice_range=[50, 70],

    num_patches_per_slice=1,
    # Validate on the full image: follow image_size rather than a hard-coded size,
    # so changing image_size cannot silently desynchronize the two.
    patch_size=list(image_size),

    bins=bins,
    bins_mapped=bins_mapped,
    histogram_equalization=histogram_equalization,
    background_cutoff=background_cutoff,
    maximum_cutoff=maximum_cutoff,
    normalize_factor=normalize_factor,

    # Explicit so validation stays deterministic and un-augmented regardless of
    # Dataset_2D's default values for these arguments.
    shuffle=False,
    augment=False,
)

### Step 6: train

In [15]:
# Resume settings: pre_trained_model is the path of a pretrained model, if any
# (None = train from scratch); start_step is passed to trainer.train as the
# starting step -- presumably used when resuming, confirm in noise2noise.Trainer.
pre_trained_model, start_step = None, 0

In [None]:
# Train.
# Models are written to <models root>/<trial_name>/models -- change the root to your own path.
model_save_path = os.path.join('/host/d/projects/denoising/models', trial_name, 'models')
# The parent folder is listed before the models folder itself.
ff.make_folder([os.path.dirname(model_save_path), model_save_path])

trainer = noise2noise.Trainer(
    model=model,
    generator_train=generator_train,
    generator_val=generator_val,
    train_batch_size=25,

    # NOTE(review): the argument is named *steps* but was documented as "total training
    # epochs" -- confirm which unit noise2noise.Trainer actually counts.
    train_num_steps=10000,
    results_folder=model_save_path,

    train_lr=1e-4,
    train_lr_decay_every=200,  # learning-rate decay schedule
    save_models_every=1,       # model-saving frequency (documented as epochs)
    validation_every=1,        # validation frequency (documented as epochs)
)

trainer.train(pre_trained_model=pre_trained_model, start_step=start_step)