In [1]:
#### create config
from dataclasses import dataclass

@dataclass
class Config: 
    """Central collection of hyperparameters, paths and switches used by this notebook.

    NOTE(review): none of the attributes below carries a type annotation, so
    ``@dataclass`` registers no fields; they are plain class attributes that act
    as shared defaults, and the generated ``__init__``/``__repr__``/``__eq__``
    ignore them. Later cells attach further attributes to the instance at
    runtime (e.g. ``num_processes``, ``train_batch_size``, ``model``).
    """
    target_shape = None # will transform t1n during preprocessing (computationally expensive)
    unet_img_shape = (160,256)  # passed to UNet2DModel as sample_size
    channels = 1  # used for both in_channels and out_channels of the UNet
    effective_train_batch_size=16  # divided by gradient_accumulation_steps and num_processes to get train_batch_size
    eval_batch_size = 4   
    sorted_slice_sample_size = 1  # values > 1 switch the 2D dataloaders to multi_slice mode
    num_dataloader_workers = 8
    num_epochs = 192
    gradient_accumulation_steps = 1
    learning_rate = 1e-4
    lr_warmup_steps = 500
    evaluate_epochs = 5 # adjust to num_epochs
    evaluate_num_batches = -1 # ~3s/batch. 2.5 min/Evaluation 3D epoch with all batches
    evaluate_num_batches_3d = -1 
    evaluate_unconditional_img = True
    deactivate2Devaluation = False
    deactivate3Devaluation = True
    evaluate_3D_epochs = 1000  # 3 min/Evaluation 3D 
    mixed_precision = "fp16"  # `no` for float32, `fp16` for automatic mixed precision
    output_dir = "lesion-synthesis-256-unconditioned"  # the model name locally and on the HF Hub
    dataset_train_path = "./datasets/synthesis/dataset_train/imgs"
    segm_train_path = "./datasets/synthesis/dataset_train/segm"
    masks_train_path = "./datasets/synthesis/dataset_train/masks"
    dataset_eval_path = "./datasets/synthesis/dataset_eval/imgs"
    segm_eval_path = "./datasets/synthesis/dataset_eval/segm"
    masks_eval_path = "./datasets/synthesis/dataset_eval/masks"
    train_connected_masks=False  # No Training with lesion masks
    eval_connected_masks=False 
    num_inference_steps=50
    log_csv = True
    add_lesion_technique = "other_lesions_99Quantile" # 'mean_intensity', 'other_lesions_1stQuantile', 'other_lesions_mean', 'other_lesions_median', 'other_lesions_3rdQuantile', 'other_lesions_99Quantile' 
    intermediate_timestep = 3 # starting from this timesteps. num_inference_steps means the whole pipeline and 1 the last step. 
    mode = "train" # 'train', 'eval' or "tuning_parameters"
    debug = True     
    # evaluated once while the class body runs (== 3); later changes to
    # intermediate_timestep do NOT propagate to jump_length
    jump_length=intermediate_timestep
    jump_n_sample=3
    brightness_augmentation = False
    # a later cell asserts that the length of this list equals eval_batch_size
    eval_loss_timesteps=[20,40,80,140]
    restrict_train_slices = "segm"
    restrict_eval_slices = "mask"
    use_min_snr_loss=True
    snr_gamma=5

    push_to_hub = False  # whether to upload the saved model to the HF Hub 
    seed = 0
config = Config()

In [2]:
#setup huggingface accelerate
import torch
import numpy as np
import accelerate
# Create accelerate's default config file if none exists yet (an existing file is
# left untouched and the call returns False). The function is imported from its
# public location instead of reaching into accelerate.commands.* through attribute
# access, which only works if that submodule happens to be imported already.
from accelerate.utils import write_basic_config

write_basic_config(mixed_precision=config.mixed_precision)
# If there are problems with ports, manually add "main_process_port": 0 (or another
# number) to the config file. The file is parsed with json.load further down in this
# notebook, so the manual edit has to keep it valid JSON.

Configuration already exists at /home/jovyan/.cache/huggingface/accelerate/default_config.yaml, will not override. Run `accelerate config` manually or pass a different `save_location`.


False

In [3]:
from pathlib import Path
import json
# Read the process count from accelerate's default config file; the file is parsed
# as JSON even though it carries a .yaml extension.
accelerate_config_path = Path.home() / ".cache/huggingface/accelerate/default_config.yaml"
with accelerate_config_path.open() as config_file:
    accelerate_settings = json.load(config_file)
config.num_processes = accelerate_settings["num_processes"]

In [4]:
# Per-process batch size: the effective batch size is first divided by the number of
# gradient accumulation steps and then by the number of processes; the result is
# truncated to an integer.
samples_per_optimizer_step = config.effective_train_batch_size / config.gradient_accumulation_steps
config.train_batch_size = int(samples_per_optimizer_step / config.num_processes)

In [5]:
# Debug mode: shrink the run (single inference step, batch size 1, three evaluation
# batches, one dataloader worker) and train on the evaluation data paths.
if config.debug:
    config.num_inference_steps = 1
    config.intermediate_timestep = 1
    # jump_length was copied from intermediate_timestep when the Config class was
    # created, so it does not follow the override above; re-synchronise it explicitly
    # (the parameter-tuning cell also always sets both to the same value).
    config.jump_length = config.intermediate_timestep
    config.train_batch_size = 1
    config.eval_batch_size = 1
    # keep len(eval_loss_timesteps) == eval_batch_size, which is asserted in the next cell
    config.eval_loss_timesteps = [20]
    config.train_connected_masks=False
    config.eval_connected_masks=False
    config.evaluate_num_batches = 3
    config.deactivate3Devaluation = False
    # reuse the evaluation data as training data
    config.dataset_train_path = "./datasets/synthesis/dataset_eval/imgs"
    config.segm_train_path = "./datasets/synthesis/dataset_eval/segm"
    config.masks_train_path = "./datasets/synthesis/dataset_eval/masks" 
    config.num_dataloader_workers = 1

In [6]:
# Sanity check: one evaluation-loss timestep per sample of an evaluation batch.
# The message names both values so a failing configuration can be fixed right away.
assert len(config.eval_loss_timesteps) == config.eval_batch_size, (
    f"eval_loss_timesteps has {len(config.eval_loss_timesteps)} entries "
    f"but eval_batch_size is {config.eval_batch_size}; both must be equal"
)

In [7]:
from custom_modules import DatasetMRI2D, DatasetMRI3D, ScaleDecorator
from pathlib import Path
from torchvision import transforms 

# Optional augmentation: with probability 0.5 apply a brightness ColorJitter wrapped
# in ScaleDecorator; no transformation at all when the config switch is off.
transformations = (
    transforms.RandomApply([ScaleDecorator(transforms.ColorJitter(brightness=1))], p=0.5)
    if config.brightness_augmentation
    else None
)

# 2D training dataset (slices restricted according to config.restrict_train_slices)
dataset_train = DatasetMRI2D(
    root_dir_img=Path(config.dataset_train_path),
    restriction=config.restrict_train_slices,
    root_dir_segm=Path(config.segm_train_path),
    connected_masks=config.train_connected_masks,
    target_shape=config.target_shape,
    transforms=transformations,
)

# 2D evaluation dataset
# NOTE(review): root_dir_synthesis points at the masks directory — confirm this is intended
dataset_evaluation = DatasetMRI2D(
    root_dir_img=Path(config.dataset_eval_path),
    restriction=config.restrict_eval_slices,
    root_dir_masks=Path(config.masks_eval_path),
    root_dir_synthesis=Path(config.masks_eval_path),
    connected_masks=config.eval_connected_masks,
    target_shape=config.target_shape,
)

# 3D evaluation dataset built from the same evaluation paths
dataset_3D_evaluation = DatasetMRI3D(
    root_dir_img=Path(config.dataset_eval_path),
    root_dir_masks=Path(config.masks_eval_path),
    root_dir_synthesis=Path(config.masks_eval_path),
    connected_masks=config.eval_connected_masks,
    target_shape=config.target_shape,
)

2024-06-14 11:58:20.804144: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.
2024-06-14 11:58:21.709384: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-06-14 11:58:21.709448: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-06-14 11:58:21.709504: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-06-14 11:58:21.764558: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.
2024-06-14 11:58:21.783251: I tensorflow/core/platform/cpu_feature_guard.cc:182] This Tens

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

### Finding good intensity of lesions

In [8]:
from pathlib import Path
import nibabel as nib
from tqdm.auto import tqdm 
from custom_modules import DatasetMRI2D, DatasetMRI3D, ScaleDecorator

#skip to speed up
if False:
    # Computes intensity statistics (mean / std / median) over all lesion voxels of the
    # full FLAIR dataset. Alternative sources that were tried before:
    #t1w_norm_noskull_list = list(Path("temp/unhealthy_DL+DiReCT_Segmentation").rglob("*T1w_norm_noskull.nii.gz"))
    #lesion_list = list(Path("temp/unhealthy_registered_lesions").rglob("*transformed_lesion.nii.gz"))
    #t1w_norm_noskull_list = list(Path("datasets_full/unhealthy_flair").rglob("*3DFLAIR.nii.gz"))
    #lesion_list = list(Path("datasets_full/unhealthy_flair_masks").rglob("*Consensus.nii.gz"))

    dataset_eval_path="datasets_full/unhealthy_flair"
    masks_eval_path="datasets_full/unhealthy_flair_masks"

    # Uses the same keyword/config names as the dataset-creation cell above
    # (connected_masks / target_shape); the previous names eval_only_connected_masks and
    # t1n_target_shape do not exist on the config. A dedicated variable is used so the
    # global dataset_3D_evaluation consumed by the dataloaders is not overwritten.
    lesion_stats_dataset = DatasetMRI3D(
        root_dir_img=Path(dataset_eval_path),
        root_dir_masks=Path(masks_eval_path),
        root_dir_synthesis=Path(masks_eval_path),
        connected_masks=config.eval_connected_masks,
        target_shape=config.target_shape,
    )

    t1ws = list()
    lesions = list()
    for i in tqdm(range(len(lesion_stats_dataset))):
        # fetch each volume once instead of once per key
        sample = lesion_stats_dataset[i]
        t1ws.append(sample["gt_image"])
        lesions.append(sample["mask"])
    t1w_big = torch.cat(t1ws)
    lesion_big = torch.cat(lesions)

    # intensities of all voxels that lie inside a lesion mask
    lesion_voxels = t1w_big[lesion_big.to(torch.bool)]
    print("mean lesion intensity: ", lesion_voxels.mean()) # -0.3572
    print("std lesion intensity: ", lesion_voxels.std()) # 0.1829
    print("median lesion intensity: ", lesion_voxels.median()) # value not recorded (the old comment repeated the std)

In [9]:
import matplotlib.pyplot as plt 
import random 
 

#skip to speed up
if False: 
    # torch.randint's upper bound is exclusive, so len(...) already yields only valid
    # indices; using len(...)-1 could never select the last sample and raised an error
    # for a dataset with a single element.
    random_idx = torch.randint(len(dataset_evaluation), size=(1,)).item()

    # fetch the sample once and reuse it for all four panels
    sample = dataset_evaluation[random_idx]
    gt_image = sample["gt_image"].squeeze()
    lesion_mask = sample["mask"].squeeze()

    # paint the lesion area with a constant intensity on a copy, so the ground-truth
    # tensor shown in the first panel stays untouched
    img = gt_image.clone()
    # NOTE(review): labelled "mean-std of lesion intensity", but the values noted in the
    # previous cell give -0.3572 - 0.1829 = -0.5401 — confirm which number is intended
    img[lesion_mask.to(torch.bool)] = -0.5492 # mean-std of lesion intensity

    fig, axis = plt.subplots(1,4, figsize=(16,4)) 
    axis[0].imshow(gt_image/2+0.5)
    axis[0].set_axis_off()
    axis[1].imshow(lesion_mask)
    axis[1].set_axis_off()
    axis[2].imshow(img/2+0.5)
    axis[2].set_axis_off()
    axis[3].imshow(sample["synthesis"].squeeze())
    axis[3].set_axis_off()
    fig.show()

### Playground LPIPS

In [10]:
#skip to speed up
if False:
    # NOTE(review): these module paths and the positional constructor call below differ
    # from the custom_modules imports / keyword construction used in the training cells —
    # this playground may be stale; verify before enabling it.
    from Evaluation2DSynthesis import Evaluation2DSynthesis
    from accelerate import Accelerator
    from DDIMGuidedPipeline import DDIMGuidedPipeline
    
    pipeline = DDIMGuidedPipeline.from_pretrained(config.output_dir) 
    
    # named lpips_evaluation instead of `eval` so the builtin eval() is not shadowed
    # for the rest of the notebook session
    lpips_evaluation = Evaluation2DSynthesis(config, pipeline, None, None, Accelerator(), None)

    sample_idx = 123  # fixed evaluation slice used for the comparison
    images_with_lesions, _ = lpips_evaluation._add_coarse_lesions(dataset_evaluation[sample_idx]["gt_image"], dataset_evaluation[sample_idx])
    clean_image=dataset_evaluation[sample_idx]["gt_image"]

    # run the pipeline with one and with three timesteps on the same input
    synthesized_images = pipeline(images_with_lesions, 1).images
    synthesized_images_2 = pipeline(images_with_lesions, 3).images
    print("Pips for one timestep: ", lpips_evaluation._calc_lpip(clean_image.unsqueeze(0), synthesized_images))
    print("Pips for three timestep: ", lpips_evaluation._calc_lpip(clean_image.unsqueeze(0), synthesized_images_2))

In [11]:
#skip to speed up
if False:
    # side-by-side view of the images produced by the LPIPS playground cell;
    # every image is mapped with x/2+0.5 before it is shown
    panels = [
        (clean_image, "original image"),
        (images_with_lesions, "coarse lesions"),
        (synthesized_images, "Diffusion with one timestep"),
        (synthesized_images_2, "Diffusion with three timestep"),
    ]
    fig, axis = plt.subplots(1,4, figsize=(24,12))
    for panel_axis, (panel_image, panel_title) in zip(axis, panels):
        panel_axis.imshow(panel_image.squeeze()/2+0.5)
        panel_axis.set_title(panel_title)

In [12]:
#skip to speed up
if False:
    # difference between the clean image and each synthesized variant
    diff_1=clean_image-synthesized_images
    diff_3=clean_image-synthesized_images_2
    difference_panels = [
        (diff_1, "Differences with one timestep"),
        (diff_3, "Differences with three timestep"),
    ]
    fig, axis = plt.subplots(1,2, figsize=(16,8))
    for panel_axis, (difference, panel_title) in zip(axis, difference_panels):
        panel_axis.imshow(difference.squeeze()/2+0.5)
        panel_axis.set_title(panel_title)

### Training

In [13]:
#create model
from diffusers import UNet2DModel

# Block layout of the UNet: six down and six up blocks; the only attention blocks are
# the fifth down block and the second up block.
down_block_layout = ("DownBlock2D",) * 4 + ("AttnDownBlock2D", "DownBlock2D")
up_block_layout = ("UpBlock2D", "AttnUpBlock2D") + ("UpBlock2D",) * 4

model = UNet2DModel(
    sample_size=config.unet_img_shape,  # the target image resolution
    in_channels=config.channels,  # number of input channels
    out_channels=config.channels,  # number of output channels
    layers_per_block=2,  # ResNet layers per UNet block
    block_out_channels=(128, 128, 256, 256, 512, 512),  # output channels of each UNet block
    down_block_types=down_block_layout,
    up_block_types=up_block_layout,
)

# remember the model type in the config
config.model = "UNet2DModel"

In [14]:
#setup noise scheduler
import torch
from PIL import Image
from diffusers import DDIMScheduler

noise_scheduler = DDIMScheduler(num_train_timesteps=1000)

# Description of the scheduler stored in the config. It is built from the scheduler's
# class name and its registered config value instead of a hand-written copy of the
# constructor call, so it cannot drift from the object that is actually used.
# Evaluates to "DDIMScheduler(num_train_timesteps=1000)" for the settings above.
config.noise_scheduler = (
    f"{type(noise_scheduler).__name__}"
    f"(num_train_timesteps={noise_scheduler.config.num_train_timesteps})"
)

In [15]:
# setup lr scheduler
from diffusers.optimization import get_cosine_schedule_with_warmup
import math

optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)

# total scheduler steps = iterations per epoch * number of epochs
iterations_per_epoch = math.ceil(len(dataset_train)/config.train_batch_size)
total_training_steps = iterations_per_epoch * config.num_epochs

lr_scheduler = get_cosine_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=config.lr_warmup_steps,
    num_training_steps=total_training_steps,
)

# remember the scheduler type in the config
config.lr_scheduler = "cosine_schedule_with_warmup"

In [16]:
from accelerate import Accelerator 

# Accelerator configured from the notebook config: precision mode ("no"/"fp16") and
# the number of gradient accumulation steps.
accelerator = Accelerator(
    mixed_precision=config.mixed_precision,
    gradient_accumulation_steps=config.gradient_accumulation_steps,  
)

In [17]:
from torch.utils.tensorboard import SummaryWriter
import os

# Output directory and TensorBoard writer are created on the main process only, so
# tb_summary is undefined on every other process.
if accelerator.is_main_process:
    if config.output_dir is not None:
        os.makedirs(config.output_dir, exist_ok=True) 
    #setup tensorboard
    tb_summary = SummaryWriter(config.output_dir, purge_step=0)
    # NOTE(review): the Accelerator above is created without a tracker backend, so this
    # call may be a no-op — confirm before removing
    accelerator.init_trackers("train_example") #maybe delete

In [18]:
# The logger exists on the main process only; later cells pass None on all other processes.
if accelerator.is_main_process:
    from custom_modules import Logger
    # wraps the TensorBoard writer; the second argument is the config's log_csv switch
    logger = Logger(tb_summary, config.log_csv)
    # NOTE(review): attributes added to config after this point (e.g. conditional_data)
    # are not part of this call — confirm whether they are logged elsewhere
    logger.log_config(config)

In [19]:
from custom_modules import get_dataloader

# settings shared by all three dataloaders
shared_loader_settings = dict(num_workers=config.num_dataloader_workers, seed=config.seed)
# the 2D loaders use multi_slice mode when more than one sorted slice is sampled
use_multi_slice = config.sorted_slice_sample_size > 1

# training: random sampling
train_dataloader = get_dataloader(
    dataset=dataset_train,
    batch_size=config.train_batch_size,
    random_sampler=True,
    multi_slice=use_multi_slice,
    **shared_loader_settings,
)
# 2D evaluation: no random sampling
d2_eval_dataloader = get_dataloader(
    dataset=dataset_evaluation,
    batch_size=config.eval_batch_size,
    random_sampler=False,
    multi_slice=use_multi_slice,
    **shared_loader_settings,
)
# 3D evaluation: batch size 1, never multi_slice
d3_eval_dataloader = get_dataloader(
    dataset=dataset_3D_evaluation,
    batch_size=1,
    random_sampler=False,
    multi_slice=False,
    **shared_loader_settings,
)

In [21]:
from custom_modules import ModelInputGenerator, Evaluation2DSynthesis, Evaluation3DSynthesis

# input generator for the unconditional model
model_input_generator = ModelInputGenerator(conditional=False, noise_scheduler=noise_scheduler)

# keyword arguments of the 2D evaluation; the logger is only handed over on the main process
args = dict(
    intermediate_timestep=config.intermediate_timestep,
    add_lesion_technique=config.add_lesion_technique,
    eval_dataloader=d2_eval_dataloader,
    train_dataloader=train_dataloader,
    logger=logger if accelerator.is_main_process else None,
    accelerator=accelerator,
    num_inference_steps=config.num_inference_steps,
    model_input_generator=model_input_generator,
    output_dir=config.output_dir,
    eval_loss_timesteps=config.eval_loss_timesteps,
    evaluate_num_batches=config.evaluate_num_batches,
    seed=config.seed,
)
evaluation2D = Evaluation2DSynthesis(**args)

# keyword arguments of the 3D evaluation (the name args is reused on purpose, as before)
args = dict(
    intermediate_timestep=config.intermediate_timestep,
    add_lesion_technique=config.add_lesion_technique,
    dataloader=d3_eval_dataloader,
    logger=logger if accelerator.is_main_process else None,
    accelerator=accelerator,
    output_dir=config.output_dir,
    num_inference_steps=config.num_inference_steps,
    eval_batch_size=config.eval_batch_size,
    sorted_slice_sample_size=config.sorted_slice_sample_size,
    evaluate_num_batches=config.evaluate_num_batches_3d,
    seed=config.seed,
)
evaluation3D = Evaluation3DSynthesis(**args)

In [23]:
from custom_modules import Training, GuidedRePaintPipeline, Evaluation2DSynthesis, Evaluation3DSynthesis 
from custom_modules import PipelineFactories

# this notebook trains without conditional data
config.conditional_data = "None"

# everything the Training object needs: model and optimisation objects, dataloaders,
# evaluation helpers and the evaluation/loss switches taken from the config
args = dict(
    accelerator=accelerator,
    model=model,
    noise_scheduler=noise_scheduler,
    optimizer=optimizer,
    lr_scheduler=lr_scheduler,
    train_dataloader=train_dataloader,
    d2_eval_dataloader=d2_eval_dataloader,
    d3_eval_dataloader=d3_eval_dataloader,
    model_input_generator=model_input_generator,
    evaluation2D=evaluation2D,
    evaluation3D=evaluation3D,
    logger=logger if accelerator.is_main_process else None,  # logger exists on the main process only
    pipeline_factory=PipelineFactories.get_guided_repaint_pipeline,
    num_epochs=config.num_epochs,
    evaluate_2D_epochs=config.evaluate_epochs,
    evaluate_3D_epochs=config.evaluate_3D_epochs,
    min_snr_loss=config.use_min_snr_loss,
    snr_gamma=config.snr_gamma,
    evaluate_unconditional_img=config.evaluate_unconditional_img,
    deactivate_2D_evaluation=config.deactivate2Devaluation,
    deactivate_3D_evaluation=config.deactivate3Devaluation,
    evaluation_pipeline_parameters=dict(
        jump_length=config.jump_length,
        jump_n_sample=config.jump_n_sample,
    ),
    debug=config.debug,
)
trainingSynthesis = Training(**args)

In [None]:
# Start training only when the notebook is configured for "train" mode.
if config.mode == "train":
    trainingSynthesis.train()

  0%|          | 0/226 [00:00<?, ?it/s]

In [None]:
# "eval" mode: load the pipeline stored under config.output_dir and evaluate it
# without training.
if config.mode == "eval":
    # presumably switches 3D evaluation on regardless of config.deactivate3Devaluation — confirm in Training
    trainingSynthesis.deactivate_3D_evaluation = False
    pipeline = GuidedRePaintPipeline.from_pretrained(config.output_dir) 
    trainingSynthesis.evaluate(pipeline, deactivate_save_model=True)

In [None]:
import os
import csv
from itertools import product
import matplotlib.pyplot as plt

# "tuning_parameters" mode: evaluate the saved pipeline for every combination of
# timestep and resample step, then plot the LPIPS score of each combination.
if config.mode == "tuning_parameters":
    trainingSynthesis.deactivate_3D_evaluation = False
    pipeline = GuidedRePaintPipeline.from_pretrained(config.output_dir)
    original_output_dir = config.output_dir 
    timesteps = [1, 3, 5]
    # plural name so the loop variable below no longer shadows the list
    resample_steps = [1, 3, 5]

    for timestep, resample_step in product(timesteps, resample_steps):
        print("Begin evaluation for timestep ", timestep, " and resample step ", resample_step)
        # NOTE(review): evaluation2D / evaluation3D received intermediate_timestep and
        # output_dir at construction time — confirm that Training forwards these
        # attributes to them
        trainingSynthesis.intermediate_timestep = timestep
        trainingSynthesis.jump_length = timestep
        trainingSynthesis.jump_n_sample = resample_step
        trainingSynthesis.output_dir = original_output_dir + "/tuning_parameters/timestep_" + str(timestep) + "_resample_" + str(resample_step)
        if trainingSynthesis.accelerator.is_main_process:
            # create the directory of THIS run (previously config.output_dir was created,
            # which already exists and left the per-run directory missing)
            os.makedirs(trainingSynthesis.output_dir, exist_ok=True)
            trainingSynthesis.log_meta_logs()
        trainingSynthesis.evaluate(pipeline, deactivate_save_model=True) 

    # plot lpips score vs. parameter combination
    # Globbing below the tuning directory (instead of embedding the output dir in the
    # pattern) also works for absolute output paths; sorting makes the bar order stable.
    tuning_root = Path(original_output_dir) / "tuning_parameters"
    folder_list = sorted(tuning_root.glob("timestep_*"))
    lpips = []
    labels = [] 
    for folder in folder_list:
        # folder name layout: timestep_<timestep>_resample_<resample_step>
        name_parts = folder.name.split("_")
        timestep = name_parts[-3]
        resample_step = name_parts[-1]
        with open(folder / "metrics.csv", 'r') as fp: 
            fp.readline()  # the first line is skipped
            metrics_line = fp.readline()
        # the second line holds comma-separated "name:value" entries
        for entry in metrics_line.split(','):
            name, separator, value = entry.strip().partition(':')
            # entries without ':' (empty fields, trailing newline) are ignored
            if separator and name == "lpips":
                lpips.append(float(value))
                labels.append(timestep + "_" + resample_step) 

    fig, ax = plt.subplots()
    ax.bar(labels, lpips) 
    ax.set_xlabel("timestep_resample")
    ax.set_ylabel("lpips")
    fig.savefig(original_output_dir + '/lpips_parameters.png')

In [None]:
#create python script for ubelix 
import os

!jupyter nbconvert --to script "lesion_synthesis_unconditioned.ipynb"
filename="lesion_synthesis_unconditioned.py"

# delete this cell from python file
lines = []
with open(filename, 'r') as fp:
    lines = fp.readlines()
with open(filename, 'w') as fp:
    for number, line in enumerate(lines):
        if number < len(lines)-17: 
            fp.write(line)