## Data Preparation

Prepare the following before running this step. A set of example data is also provided in the folder ```example_data```.

1. **Simulated dataset**
   - See step 1.
   - Example data: one case, ```00004038/0000455420```, is provided. Its clean, low-noise ground truth is under ```example_data/fixedCT```. Under ```example_data/simulation```, ```gaussian_random_0``` is used for unsupervised learning and ```poisson_random_0``` for supervised learning.


2. **A patient list** that enumerates the dataset
   - See step 2.
   - Example data: two lists are provided, ```example_data/Patient_lists/patient_list_unsupervised_gaussian.xlsx``` for unsupervised learning (our proposed method) and ```example_data/Patient_lists/patient_list_supervised_poisson.xlsx``` for supervised learning.


3. bins for **histogram equalization**
    - provided in ```/help_data```
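
Before training, it can help to confirm that these inputs are in place. The following is a minimal sketch, not part of the repository: `missing_inputs` is a hypothetical helper, and it assumes `example_data` sits under your data root while `help_data` sits under the repository root.

```python
import os

# Paths as named in this section. Which root each one lives under is an assumption:
# example_data under the data root, help_data under the repository root.
DATA_INPUTS = [
    "example_data/fixedCT",
    "example_data/simulation",
    "example_data/Patient_lists/patient_list_unsupervised_gaussian.xlsx",
    "example_data/Patient_lists/patient_list_supervised_poisson.xlsx",
]
REPO_INPUTS = [
    "help_data/histogram_equalization/bins.npy",
    "help_data/histogram_equalization/bins_mapped.npy",
]

def missing_inputs(data_root, repo_root="."):
    """Return every expected input path that does not exist on disk."""
    expected = [os.path.join(data_root, p) for p in DATA_INPUTS]
    expected += [os.path.join(repo_root, p) for p in REPO_INPUTS]
    return [p for p in expected if not os.path.exists(p)]

# Replace the data root with your own path.
for path in missing_inputs("/mnt/camca_NAS/denoising/"):
    print("missing:", path)
```

An empty result means all listed inputs were found.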

---

## Task: Train the model

- There are two types of noisy data: type 1 (Poisson) and type 2 (Gaussian).
- These are the settings of the model:
   - **supervised vs. unsupervised**: 
      - **supervised** trains on pairs of noise-free and noisy thin-slice images with type 1 noise. It is then tested on type 2 noise to evaluate the influence of domain shift.
      - **unsupervised** is our method, based on diffusion + Noise2Noise, and is trained directly on type 2 noise.

   - **beta**: the weight of the bias loss. The total loss = diffusion loss + beta * bias loss. Currently beta = 0, so the bias loss does not contribute.
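
The role of beta can be illustrated with plain numbers. This is only a sketch of the weighting described above: `total_loss` is a hypothetical helper, and the loss values are made up, since how each term is computed is not covered in this section.

```python
def total_loss(diffusion_loss, bias_loss, beta=0.0):
    """Combine the two terms as: diffusion loss + beta * bias loss."""
    return diffusion_loss + beta * bias_loss

# With the current default beta = 0, the bias term has no effect.
print(total_loss(0.25, 3.0))            # 0.25
print(total_loss(0.25, 3.0, beta=0.5))  # 1.75
```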

---

### Docker environment
Please use `docker/docker_pytorch` to build a PyTorch Docker image.


In [2]:
import sys 
sys.path.append('/workspace/Documents')
import os
import torch
import numpy as np 
import CTDenoising_Diffusion_N2N.denoising_diffusion_pytorch.denoising_diffusion_pytorch.conditional_diffusion as ddpm
import CTDenoising_Diffusion_N2N.functions_collection as ff
import CTDenoising_Diffusion_N2N.Build_lists.Build_list as Build_list
import CTDenoising_Diffusion_N2N.Generator as Generator

main_path = '/mnt/camca_NAS/denoising/'  # replace with your own path



### step 1: define settings 

In [3]:
supervision = 'unsupervised' # 'unsupervised' or 'supervised'
noise_type = 'possion' if supervision == 'supervised' else 'gaussian'
beta = 0 # by default

trial_name = 'model_'+supervision + '_' + noise_type + '_beta' + str(beta)
print(trial_name)

model_unsupervised_gaussian_beta0


### step 2: set default parameters
Usually you do not need to change these.

In [4]:
problem_dimension = '2D'
condition_channel = 1 if (supervision == 'supervised') or ('mean' in trial_name) else 2
image_size = [512,512]
num_patches_per_slice = 2
patch_size = [128,128]

objective = 'pred_x0'

histogram_equalization = True
background_cutoff = -1000
maximum_cutoff = 2000
normalize_factor = 'equation'
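
The two cutoffs bound the intensity range that the model sees. The formula selected by `normalize_factor = 'equation'` is defined in the repository and is not shown in this notebook, so the sketch below assumes a plain linear rescaling of the clipped range to [-1, 1], purely for illustration. `clip_and_scale` is a hypothetical helper.

```python
import numpy as np

def clip_and_scale(img, background_cutoff=-1000, maximum_cutoff=2000):
    """Clip to [background_cutoff, maximum_cutoff], then map linearly to [-1, 1].

    The linear map is an assumption for illustration, not the repository's formula.
    """
    img = np.clip(np.asarray(img, dtype=np.float32), background_cutoff, maximum_cutoff)
    return 2.0 * (img - background_cutoff) / (maximum_cutoff - background_cutoff) - 1.0

values = np.array([-2000.0, -1000.0, 500.0, 2000.0, 3000.0])
print(clip_and_scale(values).tolist())  # [-1.0, -1.0, 0.0, 1.0, 1.0]
```

Values below the background cutoff or above the maximum cutoff are saturated rather than discarded.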

### step 3: define patient list

In [5]:
# define train
if supervision == 'supervised':
    build_sheet =  Build_list.Build(os.path.join(main_path, 'example_data/Patient_lists','patient_list_supervised_poisson.xlsx'))
else:
    build_sheet =  Build_list.Build(os.path.join(main_path, 'example_data/Patient_lists','patient_list_unsupervised_gaussian.xlsx'))

# batch_list selects which batches are used for training. Usually you have several batches
# and hold out one for validation and another for testing. For this example, the same data
# is used for both training and validation.
_,_,_,_, condition_list_train, x0_list_train = build_sheet.__build__(batch_list = [0])
x0_list_train = x0_list_train[0:1]; condition_list_train = condition_list_train[0:1]

# define val
_,_,_,_, condition_list_val, x0_list_val = build_sheet.__build__(batch_list = [0])
x0_list_val = x0_list_val[0:1]; condition_list_val = condition_list_val[0:1]


print('train:', x0_list_train.shape, condition_list_train.shape, 'val:', x0_list_val.shape, condition_list_val.shape)
print('training condition:', condition_list_train[0], ' x0:', x0_list_train[0])
print('validation condition:', condition_list_val[0], ' x0:', x0_list_val[0])

train: (1,) (1,) val: (1,) (1,)
training condition: /mnt/camca_NAS/denoising/example_data/simulation/00004038/0000455420/gaussian_random_0/recon.nii.gz  x0: /mnt/camca_NAS/denoising/example_data/fixedCT/00004038/0000455420/img_thinslice_partial.nii.gz
validation condition: /mnt/camca_NAS/denoising/example_data/simulation/00004038/0000455420/gaussian_random_0/recon.nii.gz  x0: /mnt/camca_NAS/denoising/example_data/fixedCT/00004038/0000455420/img_thinslice_partial.nii.gz
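
With a real dataset, training, validation and testing would use different batches instead of reusing batch 0. A minimal sketch of that split follows, assuming a hypothetical patient list with five batches indexed 0 to 4; `split_batches` is not part of the repository.

```python
def split_batches(all_batches, val_batch, test_batch):
    """Return (train, val, test) batch lists, holding out one batch each for val and test."""
    if val_batch == test_batch:
        raise ValueError("validation and test batches must differ")
    train = [b for b in all_batches if b not in (val_batch, test_batch)]
    return train, [val_batch], [test_batch]

train_batches, val_batches, test_batches = split_batches(range(5), val_batch=3, test_batch=4)
print(train_batches, val_batches, test_batches)  # [0, 1, 2] [3] [4]
```

Each resulting list could then be passed as `batch_list` to `build_sheet.__build__`.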


### step 4: define model

In [6]:
# define u-net and diffusion model
model = ddpm.Unet(
    problem_dimension = problem_dimension,
    init_dim = 64,
    out_dim = 1,
    channels = 1, 
    conditional_diffusion = True,
    condition_channels = condition_channel,

    downsample_list = (True, True, True, False), # don't change
    upsample_list = (True, True, True, False), # don't change
    # full_attn per stage: True = full attention, False = linear attention, None = no attention.
    # If GPU memory is limited, change True to False (full attention -> linear attention);
    # to save more memory, change False to None (remove attention).
    full_attn = (None, None, False, True),)

diffusion_model = ddpm.GaussianDiffusion(
    model,
    image_size = image_size if num_patches_per_slice is None else patch_size,
    timesteps = 1000,
    sampling_timesteps = 250,
    objective = objective,
    clip_or_not =False,
    auto_normalize = False,)


is ddim sampling True
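
The printed message is consistent with `sampling_timesteps` (250) being smaller than `timesteps` (1000). Assuming this module keeps the schedule of the upstream `denoising_diffusion_pytorch` package (this has not been checked against the fork used here), sampling visits an evenly spaced subset of the 1000 training steps. The sketch below reproduces that spacing in plain Python; `ddim_time_pairs` is a hypothetical name.

```python
def ddim_time_pairs(timesteps=1000, sampling_timesteps=250):
    """Evenly spaced (t, t_next) pairs from timesteps - 1 down to -1.

    Mirrors the linspace-based schedule of upstream denoising_diffusion_pytorch;
    whether this fork uses the same schedule is an assumption.
    """
    # Evenly spaced points between -1 and timesteps - 1, truncated to integers.
    times = [int(-1 + i * timesteps / sampling_timesteps) for i in range(sampling_timesteps + 1)]
    times.reverse()
    return list(zip(times[:-1], times[1:]))

pairs = ddim_time_pairs()
print(len(pairs), pairs[0], pairs[-1])  # 250 (999, 995) (3, -1)
```

Under this assumption, each sampling step jumps four training steps, so inference needs 250 network evaluations instead of 1000.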


### step 5: define data generator (Training and validation)

In [7]:
generator_train = Generator.Dataset_2D(
        supervision = supervision,

        img_list = x0_list_train,
        condition_list = condition_list_train,
        image_size = image_size,

        num_slices_per_image = 50,
        random_pick_slice = True,
        slice_range = None,

        num_patches_per_slice = num_patches_per_slice,
        patch_size = patch_size,

        histogram_equalization = histogram_equalization,
        bins = np.load('./help_data/histogram_equalization/bins.npy'),
        bins_mapped = np.load('./help_data/histogram_equalization/bins_mapped.npy'),

        background_cutoff = background_cutoff,
        maximum_cutoff = maximum_cutoff,
        normalize_factor = normalize_factor,

        shuffle = True,
        augment = True,
        augment_frequency = 0.5,)

generator_val = Generator.Dataset_2D(
        supervision = supervision,

        img_list = x0_list_val,
        condition_list = condition_list_val,
        image_size = image_size,

        num_slices_per_image = 20,
        random_pick_slice = False,
        slice_range = [50,70],

        num_patches_per_slice = 1,
        patch_size = [512,512],

        histogram_equalization = histogram_equalization,
        bins = np.load('./help_data/histogram_equalization/bins.npy'),
        bins_mapped = np.load('./help_data/histogram_equalization/bins_mapped.npy'),
        
        background_cutoff = background_cutoff,
        maximum_cutoff = maximum_cutoff,
        normalize_factor = normalize_factor,)
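
The training generator works on `num_patches_per_slice = 2` patches of size 128 x 128 per 512 x 512 slice, while validation uses the whole slice as a single patch. How `Generator.Dataset_2D` chooses patch locations is not shown here; the sketch below assumes uniformly random top-left corners, and `random_patches` is a hypothetical helper.

```python
import numpy as np

def random_patches(slice_2d, patch_size=(128, 128), num_patches=2, rng=None):
    """Crop num_patches random patches of patch_size from one 2D slice."""
    rng = np.random.default_rng() if rng is None else rng
    height, width = slice_2d.shape
    ph, pw = patch_size
    patches = []
    for _ in range(num_patches):
        # The upper bound of rng.integers is exclusive, so the patch always fits.
        top = int(rng.integers(0, height - ph + 1))
        left = int(rng.integers(0, width - pw + 1))
        patches.append(slice_2d[top:top + ph, left:left + pw])
    return np.stack(patches)

slice_2d = np.zeros((512, 512), dtype=np.float32)
print(random_patches(slice_2d, rng=np.random.default_rng(0)).shape)  # (2, 128, 128)
```

In a paired setting, the same crop location would have to be applied to the condition image and to its target so that the two stay aligned.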

### step 6: train

In [8]:
# define trainer
results_folder = os.path.join(main_path, 'models', trial_name, 'models')
ff.make_folder([os.path.join(main_path, 'models'), os.path.join(main_path, 'models', trial_name), results_folder, os.path.join(main_path, 'models', trial_name, 'log')])

trainer = ddpm.Trainer(
    diffusion_model= diffusion_model,
    generator_train = generator_train,
    generator_val = generator_val,
    train_batch_size = 25, # make it small if you have limited GPU memory
    
    accum_iter = 1,
    train_num_steps = 200, # total training epochs
    results_folder = results_folder, # defined above from main_path, so a custom main_path is respected
   
    train_lr = 1e-4,
    train_lr_decay_every = 200, 
    save_models_every = 1,
    validation_every = 1,)

conditional diffusion:  True
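
If GPU memory is limited, `train_batch_size` and `accum_iter` can be traded against each other. The sketch below assumes that `accum_iter` is the number of gradient-accumulation steps per optimizer update, and that periodic actions (saving, validation, learning-rate decay) follow a simple modulo rule on a 1-based step counter. Both are assumptions about `ddpm.Trainer`, and the helper names are hypothetical.

```python
def effective_batch_size(train_batch_size, accum_iter):
    """Samples contributing to one optimizer update under gradient accumulation."""
    return train_batch_size * accum_iter

def steps_with_event(train_num_steps, every):
    """1-based steps on which a periodic action would run, assuming a modulo rule."""
    return [s for s in range(1, train_num_steps + 1) if s % every == 0]

print(effective_batch_size(25, 1))    # 25
print(effective_batch_size(5, 5))     # 25
print(len(steps_with_event(200, 1)))  # 200
print(steps_with_event(200, 200))     # [200]
```

Under these assumptions, the settings above save and validate at every step, and the learning rate would decay only at the final step of this short example run.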


In [9]:
# define pretrained model if any
pre_trained_model = None
start_step = 0 # define it as 0 if not using pre-trained model

In [11]:
# train
trainer.train(pre_trained_model=pre_trained_model, start_step= start_step, beta = beta)