In [None]:
import logging
import os
import sys

import torch
import numpy as np

import monai
from monai.apps import download_and_extract
from monai.config import print_config
from monai.data import DataLoader, ImageDataset
from monai.transforms import (
    EnsureChannelFirst,
    Compose,
    Resize,
    ScaleIntensity,
    ToTensor,
)

pin_memory = torch.cuda.is_available()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print('Device:', device)

logging.basicConfig(stream=sys.stdout, level=logging.INFO)
print_config()

  from .autonotebook import tqdm as notebook_tqdm


Device: cpu
MONAI version: 1.5.0
Numpy version: 1.26.1
Pytorch version: 2.1.2+cu121
MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False
MONAI rev id: d388d1c6fec8cb3a0eebee5b5a0b9776ca59ca83
MONAI __file__: /home/<username>/Documents/Projects/synthetic-CT-slices/.venv/lib/python3.11/site-packages/monai/__init__.py

Optional dependencies:
Pytorch Ignite version: NOT INSTALLED or UNKNOWN VERSION.
ITK version: NOT INSTALLED or UNKNOWN VERSION.
Nibabel version: NOT INSTALLED or UNKNOWN VERSION.
scikit-image version: NOT INSTALLED or UNKNOWN VERSION.
scipy version: NOT INSTALLED or UNKNOWN VERSION.
Pillow version: 10.0.1
Tensorboard version: 2.19.0
gdown version: NOT INSTALLED or UNKNOWN VERSION.
TorchVision version: 0.16.2+cu121
tqdm version: 4.66.1
lmdb version: NOT INSTALLED or UNKNOWN VERSION.
psutil version: 7.0.0
pandas version: 2.3.1
einops version: NOT INSTALLED or UNKNOWN VERSION.
transformers version: NOT INSTALLED or UNKNOWN VERSION.
mlflow version: NOT INST

Define Hyperamaters

In [2]:
from dataclasses import dataclass, asdict
import json, pathlib, datetime

@dataclass
class TrainConfig:
    mode = "train"
    model_type: str = "DDPM"
    image_size: int = 256  # the generated image resolution
    num_img_channels: int = 1
    train_batch_size: int = 32
    eval_batch_size: int = 8  # how many images to sample during evaluation
    num_epochs: int = 200
    gradient_accumulation_steps: int = 1
    learning_rate: float = 1e-4
    lr_warmup_steps: int = 500
    save_image_epochs: int = 20
    save_model_epochs: int = 30
    mixed_precision: str = 'fp16'  # `no` for float32, `fp16` for automatic mixed precision
    output_dir: str = None
    img_dir: str = "data/img"  # directory with training images
    seg_dir: str = "data/seg"  # directory with training segmentations

    # push_to_hub: bool = False  # whether to upload the saved model to the HF Hub
    # hub_private_repo: bool = False
    # overwrite_output_dir: bool = True  # overwrite the old model when re-running the notebook
    seed: int = 0

    # custom options
    segmentation_guided: bool = True
    segmentation_channel_mode: str = "single"
    num_segmentation_classes: int = 2 # INCLUDING background
    use_ablated_segmentations: bool = False
    dataset: str = "AVT_dongyang"
    resume_epoch: int = None

    eval_sample_size: int = 100
    eval_mask_removal: bool = True
    eval_blank_mask: bool = True

    #  EXPERIMENTAL/UNTESTED: classifier-free class guidance and image translation
    class_conditional: bool = False
    cfg_p_uncond: float = 0.2 # p_uncond in classifier-free guidance paper
    cfg_weight: float = 0.3 # w in the paper
    trans_noise_level: float = 0.5 # ratio of time step t to noise trans_start_images to total T before denoising in translation. e.g. value of 0.5 means t = 500 for default T = 1000.
    use_cfg_for_eval_conditioning: bool = True  # whether to use classifier-free guidance for or just naive class conditioning for main sampling loop
    cfg_maskguidance_condmodel_only: bool = True  # if using mask guidance AND cfg, only give mask to conditional network
    # ^ this is because giving mask to both uncond and cond model make class guidance not work 
    # (see "Classifier-free guidance resolution weighting." in ControlNet paper)

# ---------- save ----------
cfg = TrainConfig()

cfg.output_dir = "runs/AVT_dongyang"
pathlib.Path(cfg.output_dir).mkdir(parents=True, exist_ok=True)

cfg.img_dir = "data_dongyang/img"
cfg.seg_dir = "data_dongyang/seg"

# Save the config to a JSON file
ts = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
path = pathlib.Path(cfg.output_dir, ts, "config.json")
path.parent.mkdir(parents=True, exist_ok=True)
path.write_text(json.dumps(asdict(cfg), indent=2))

898

In [None]:
import os
from torch.utils.data import Dataset
import nrrd
import numpy as np
import psutil, shutil

# ---------- tweak these two values ---------------------------------
MIN_FREE_RAM_GB   = 1.0   # stop if < 1 GB RAM left
MIN_FREE_DISK_GB  = 1.0   # stop if < 1 GB free on the caching partition
CHECK_EVERY_N_SLICES = 50 # how often to poll resources
# -------------------------------------------------------------------

def _enough_resources():
    # --- RAM ---
    avail_ram_gb = psutil.virtual_memory().available / 2**30
    # --- disk (whatever partition holds ~/.cache) ---
    cache_root   = os.path.expanduser("~/.cache")
    avail_disk_gb = shutil.disk_usage(cache_root).free / 2**30
    return (
        avail_ram_gb > MIN_FREE_RAM_GB
        and avail_disk_gb > MIN_FREE_DISK_GB
    ), avail_ram_gb, avail_disk_gb


class NRRDDataset(Dataset):
    def __init__(self, 
                 img_dir = None, 
                 seg_dir = None, 
                 split="train",
                 img_size=256,
                 segmentation_guided=True,
        ):
        super().__init__()
        self.img_dir = img_dir
        self.seg_dir = seg_dir
        self.split = split
        self.segmentation_guided = segmentation_guided
        self.samples = []

        seg_types = os.listdir(self.seg_dir)

        #--- Transforms ---------------------------------------------------
        img_tf = Compose([
            # EnsureChannelFirst(),
            Resize((img_size, img_size)),
            ScaleIntensity(minv=-1.0, maxv=1.0),
            ToTensor(),
        ]) if img_dir is not None else None

        seg_tf = Compose([
            # EnsureChannelFirst(),
            Resize((img_size, img_size), mode="nearest"),
            ToTensor(),
        ]) if segmentation_guided else None

        #--- Get Volume Paths ----------------------------------------
        if img_dir is not None:
            vol_paths = [os.path.join(img_dir, split, f) for f in os.listdir(os.path.join(img_dir, split)) if f.endswith('.nrrd')]
        else:
            vol_paths = [os.path.join(seg_dir, seg_type, split, f) for seg_type in seg_types for f in os.listdir(os.path.join(seg_dir, seg_type, split)) if f.endswith('.nrrd')]

        # --- Pre‑load and Slice Volumes -------------------------------
        for vol_path in vol_paths:
            # read volume
            vol_img = None
            if img_dir is not None:
                vol_img, _ = nrrd.read(vol_path)             # (H,W,D)

            mask_vols = {}
            if segmentation_guided:
                for seg_type in seg_types:
                    m_path = os.path.join(seg_dir, seg_type, split, os.path.basename(vol_path))
                    mask_vols[seg_type], _ = nrrd.read(m_path)

            depth = vol_img.shape[2] if img_dir else next(iter(mask_vols.values())).shape[2]

            for z in range(depth):
                record = {}

                # image slice
                if img_dir:
                    img_slice = vol_img[:, :, z].astype(np.float32)
                    img_slice = np.expand_dims(img_slice, axis=0)  # (1,H,W)
                    img_slice = img_tf(img_slice)
                    record["images"] = img_slice

                # mask slices
                if segmentation_guided:
                    for st in seg_types:
                        m = mask_vols[st][:, :, z].astype(np.float32)
                        m = np.expand_dims(m, axis=0) 
                        record[f"seg_{st}"] = seg_tf(m)

                # filename
                stem = os.path.splitext(os.path.basename(vol_path))[0]
                record["image_filenames"] = f"{stem}_axial_{z:04d}"

                self.samples.append(record)
                # slice_counter += 1

                # # periodic resource check
                # if slice_counter % CHECK_EVERY_N_SLICES == 0:
                #     ok, ram, disk = _enough_resources()
                #     if not ok:
                #         print(
                #             f"[NRRDDataset] stopping preload:"
                #             f" only {ram:.1f} GB RAM / {disk:.1f} GB disk free"
                #         )
                #         return


    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]


In [4]:
def make_loaders(
    img_dir,
    seg_dir,
    img_size,
    segmentation_guided,
    batch_sizes,
    num_workers=4,
):
    train_ds = NRRDDataset(
        img_dir,
        seg_dir,
        split="train",
        img_size=img_size,
        segmentation_guided=segmentation_guided,
    )
    val_ds = NRRDDataset(
        img_dir,
        seg_dir,
        split="val",
        img_size=img_size,
        segmentation_guided=segmentation_guided,
    )

    train_loader = DataLoader(
        train_ds,
        batch_size=batch_sizes["train"],
        shuffle=True,
        num_workers=num_workers,
    )
    val_loader = DataLoader(
        val_ds,
        batch_size=batch_sizes["val"],
        shuffle=False,
        num_workers=num_workers,
    )

    return train_loader, val_loader

# ------- create the raw dataset --------------------------------------
ds = NRRDDataset(
    img_dir="data_dongyang/img",          # None if mask‑only
    seg_dir="data_dongyang/seg",           # required for segs
    split="train",
    img_size=256,
    segmentation_guided=True,       # False for image‑only
)

print(f"Dataset length : {len(ds)} slices")

sample = ds[0]                      # any index < len(ds)
print("Keys           :", sample.keys())
print("Image shape    :", sample["images"].shape)        # → (1, 256, 256)
print("Mask shape     :", sample["seg_all"].shape)    # each → (1, 256, 256)
print("Slice filename :", sample["image_filenames"])

# ------- wrap in loaders ---------------------------------------------
train_loader, val_loader = make_loaders(
    img_dir="data_dongyang/img",
    seg_dir="data_dongyang/seg",
    img_size=256,
    segmentation_guided=True,
    batch_sizes={"train": 4, "val": 4},
)

batch = next(iter(train_loader))
print("Batch tensor keys :", batch.keys())
print("Batch 'images'    :", batch["images"].shape)      # (B, 1, 256, 256)

Dataset length : 1891 slices
Keys           : dict_keys(['images', 'seg_all', 'image_filenames'])
Image shape    : torch.Size([1, 256, 256])
Mask shape     : torch.Size([1, 256, 256])
Slice filename : D6_axial_0000
Batch tensor keys : dict_keys(['images', 'seg_all', 'image_filenames'])
Batch 'images'    : torch.Size([4, 1, 256, 256])


In [None]:
from tqdm.auto import tqdm
import numpy as np

import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter

import diffusers

from eval import evaluate, add_segmentations_to_noise, SegGuidedDDPMPipeline, SegGuidedDDIMPipeline

def train_loop(config, model, noise_scheduler, optimizer, train_dataloader, eval_dataloader, lr_scheduler, device='cuda'):
    # Prepare everything
    # There is no specific order to remember, you just need to unpack the
    # objects in the same order you gave them to the prepare method.

    global_step = 0

    # logging
    run_name = '{}-{}-{}'.format(config.model_type.lower(), config.dataset, config.image_size)
    if config.segmentation_guided:
        run_name += "-segguided"
    writer = SummaryWriter(comment=run_name)

    # for loading segs to condition on:
    eval_dataloader = iter(eval_dataloader)

    # Now you train the model
    start_epoch = 0
    if config.resume_epoch is not None:
        start_epoch = config.resume_epoch

    for epoch in range(start_epoch, config.num_epochs):
        print(f"Epoch {epoch + 1}/{config.num_epochs}")
        progress_bar = tqdm(total=len(train_dataloader))
        progress_bar.set_description(f"Epoch {epoch}")

        model.train()

        for step, batch in enumerate(train_dataloader):
            clean_images = batch['images']
            clean_images = clean_images.to(device)

            # Sample noise to add to the images
            noise = torch.randn(clean_images.shape).to(clean_images.device)
            bs = clean_images.shape[0]

            # Sample a random timestep for each image
            timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bs,), device=clean_images.device).long()

            # Add noise to the clean images according to the noise magnitude at each timestep
            # (this is the forward diffusion process)
            noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)

            if config.segmentation_guided:
                noisy_images = add_segmentations_to_noise(noisy_images, batch, config, device)

            # Predict the noise residual
            if config.class_conditional:
                class_labels = torch.ones(noisy_images.size(0)).long().to(device)
                # classifier-free guidance
                a = np.random.uniform()
                if a <= config.cfg_p_uncond:
                    class_labels = torch.zeros_like(class_labels).long()
                noise_pred = model(noisy_images, timesteps, class_labels=class_labels, return_dict=False)[0]
            else:
                noise_pred = model(noisy_images, timesteps, return_dict=False)[0]
            loss = F.mse_loss(noise_pred, noise)
            loss.backward()

            nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

            # also train on target domain images if conditional
            # (we don't have masks for this domain, so we can't do segmentation-guided; just use blank masks)
            if config.class_conditional:
                target_domain_images = batch['images_target']
                target_domain_images = target_domain_images.to(device)

                # Sample noise to add to the images
                noise = torch.randn(target_domain_images.shape).to(target_domain_images.device)
                bs = target_domain_images.shape[0]

                # Sample a random timestep for each image
                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bs,), device=target_domain_images.device).long()

                # Add noise to the clean images according to the noise magnitude at each timestep
                # (this is the forward diffusion process)
                noisy_images = noise_scheduler.add_noise(target_domain_images, noise, timesteps)

                if config.segmentation_guided:
                    # no masks in target domain so just use blank masks
                    noisy_images = torch.cat((noisy_images, torch.zeros_like(noisy_images)), dim=1)

                # Predict the noise residual
                class_labels = torch.full([noisy_images.size(0)], 2).long().to(device)
                # classifier-free guidance
                a = np.random.uniform()
                if a <= config.cfg_p_uncond:
                    class_labels = torch.zeros_like(class_labels).long()
                noise_pred = model(noisy_images, timesteps, class_labels=class_labels, return_dict=False)[0]
                loss_target_domain = F.mse_loss(noise_pred, noise)
                loss_target_domain.backward()

                nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            progress_bar.update(1)
            if config.class_conditional:
                logs = {"loss": loss.detach().item(), "loss_target_domain": loss_target_domain.detach().item(), 
                        "lr": lr_scheduler.get_last_lr()[0], "step": global_step}
                writer.add_scalar("loss_target_domain", loss.detach().item(), global_step)
            else: 
                logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step}
            writer.add_scalar("loss", loss.detach().item(), global_step)

            progress_bar.set_postfix(**logs)
            global_step += 1

        # After each epoch you optionally sample some demo images with evaluate() and save the model
        if config.model_type == "DDPM":
            if config.segmentation_guided:
                pipeline = SegGuidedDDPMPipeline(
                    unet=model.module, scheduler=noise_scheduler, eval_dataloader=eval_dataloader, external_config=config
                    )
            else:
                if config.class_conditional:
                    raise NotImplementedError("TODO: Conditional training not implemented for non-seg-guided DDPM")
                else:
                    pipeline = diffusers.DDPMPipeline(unet=model.module, scheduler=noise_scheduler)
        elif config.model_type == "DDIM":
            if config.segmentation_guided:
                pipeline = SegGuidedDDIMPipeline(
                    unet=model.module, scheduler=noise_scheduler, eval_dataloader=eval_dataloader, external_config=config
                    )
            else:
                if config.class_conditional:
                    raise NotImplementedError("TODO: Conditional training not implemented for non-seg-guided DDIM")
                else:
                    pipeline = diffusers.DDIMPipeline(unet=model.module, scheduler=noise_scheduler)

        model.eval()

        if (epoch + 1) % config.save_image_epochs == 0 or epoch == config.num_epochs - 1:
            if config.segmentation_guided:
                seg_batch = next(eval_dataloader)
                evaluate(config, epoch, pipeline, seg_batch)
            else:
                evaluate(config, epoch, pipeline)

        if (epoch + 1) % config.save_model_epochs == 0 or epoch == config.num_epochs - 1:
            pipeline.save_pretrained(config.output_dir)


In [6]:
import diffusers
from diffusers.optimization import get_cosine_schedule_with_warmup
import torch.nn as nn

# custom imports
# from training import train_loop
from eval import evaluate_generation, evaluate_sample_many

# define the model
in_channels = cfg.num_img_channels
if cfg.segmentation_guided:
    assert cfg.num_segmentation_classes is not None
    assert cfg.num_segmentation_classes > 1, "must have at least 2 segmentation classes (INCLUDING background)" 
    if cfg.segmentation_channel_mode == "single":
        in_channels += 1
    elif cfg.segmentation_channel_mode == "multi":
        in_channels = len(os.listdir(cfg.seg_dir)) + in_channels

model = diffusers.UNet2DModel(
        sample_size=cfg.image_size,  # the target image resolution
        in_channels=in_channels,  # the number of input channels, 3 for RGB images
        out_channels=cfg.num_img_channels,  # the number of output channels
        layers_per_block=2,  # how many ResNet layers to use per UNet block
        block_out_channels=(128, 128, 256, 256, 512, 512),  # the number of output channes for each UNet block
        down_block_types=(
            "DownBlock2D",  # a regular ResNet downsampling block
            "DownBlock2D",
            "DownBlock2D",
            "DownBlock2D",
            "AttnDownBlock2D",  # a ResNet downsampling block with spatial self-attention
            "DownBlock2D",
        ),
        up_block_types=(
            "UpBlock2D",  # a regular ResNet upsampling block
            "AttnUpBlock2D",  # a ResNet upsampling block with spatial self-attention
            "UpBlock2D",
            "UpBlock2D",
            "UpBlock2D",
            "UpBlock2D"
        ),
    )

mode = cfg.mode
resume_epoch = cfg.resume_epoch
model_type = cfg.model_type

if (mode == "train" and resume_epoch is not None) or "eval" in mode:
    if mode == "train":
        print("resuming from model at training epoch {}".format(resume_epoch))
    elif "eval" in mode:
        print("loading saved model...")
    model = model.from_pretrain0ed(os.path.join(cfg.output_dir, 'unet'), use_safetensors=True)

model = nn.DataParallel(model)
model.to(device)

# define noise scheduler
if model_type == "DDPM":
    noise_scheduler = diffusers.DDPMScheduler(num_train_timesteps=1000)
elif model_type == "DDIM":
    noise_scheduler = diffusers.DDIMScheduler(num_train_timesteps=1000)

if mode == "train":
    # training setup
    optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.learning_rate)
    lr_scheduler = get_cosine_schedule_with_warmup(
        optimizer=optimizer,
        num_warmup_steps=cfg.lr_warmup_steps,
        num_training_steps=(len(train_loader) * cfg.num_epochs),
    )

    # train
    train_loop(
        cfg, 
        model, 
        noise_scheduler, 
        optimizer, 
        train_loader, 
        val_loader, 
        lr_scheduler, 
        device=device
        )
elif mode == "eval":
    """
    default eval behavior:
    evaluate image generation or translation (if for conditional model, either evaluate naive class conditioning but not CFG,
    or with CFG),
    possibly conditioned on masks.

    has various options.
    """
    evaluate_generation(
        cfg, 
        model, 
        noise_scheduler,
        val_loader, 
        eval_mask_removal=cfg.eval_mask_removal,
        eval_blank_mask=cfg.eval_blank_mask,
        device=device
        )

elif mode == "eval_many":
    """
    generate many images and save them to a directory, saved individually
    """
    evaluate_sample_many(
        cfg.eval_sample_size,
        cfg,
        model,
        noise_scheduler,
        val_loader,
        device=device
        )

else:
    raise ValueError("mode \"{}\" not supported.".format(mode))

BP1
BP2
BP3
Epoch 1/200


Epoch 0:   0%|          | 0/473 [00:00<?, ?it/s]

BP4
BP5
BP6


Epoch 0:   0%|          | 1/473 [00:09<1:17:32,  9.86s/it, loss=1.02, lr=2e-7, step=0]

BP6


Epoch 0:   0%|          | 2/473 [00:20<1:20:08, 10.21s/it, loss=1.04, lr=4e-7, step=1]

BP6


Epoch 0:   1%|          | 3/473 [00:31<1:23:49, 10.70s/it, loss=1.04, lr=6e-7, step=2]

BP6


Epoch 0:   1%|          | 4/473 [00:43<1:27:34, 11.20s/it, loss=1.06, lr=8e-7, step=3]

BP6


KeyboardInterrupt: 