In [None]:
import logging
import os
import sys

import matplotlib.pyplot as plt
import torch
from torch.utils.tensorboard import SummaryWriter
import numpy as np

import monai
from monai.apps import download_and_extract
from monai.config import print_config
from monai.data import DataLoader, ImageDataset
from monai.transforms import (
    EnsureChannelFirst,
    Compose,
    Resize,
    ScaleIntensity,
    ToTensor,
)

# `pin_memory` simply mirrors CUDA availability; `device` is chosen from the same flag.
pin_memory = torch.cuda.is_available()
device = torch.device("cuda") if pin_memory else torch.device("cpu")
print('Device:', device)

# Send INFO-level log records to stdout, then call MONAI's print_config().
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
print_config()

ImportError: cannot import name 'AddChannel' from 'monai.transforms' (/home/patrick-do/Documents/Projects/synthetic-CT-slices/.venv/lib/python3.11/site-packages/monai/transforms/__init__.py)

Define Hyperparameters

In [2]:
# data_dir = 'data_dongyang'
# model_type = 'DDIM'

In [11]:
import os, json
import torch
from torch.utils.data import Dataset
import nrrd
import numpy as np

class NRRDDataset(Dataset):
    """Slice-wise dataset over NRRD volumes, optionally paired with segmentation masks.

    Directory layout read by this class::

        <img_dir>/<split>/<name>.nrrd
        <seg_dir>/<seg_type>/<split>/<name>.nrrd

    All volumes are read eagerly in ``__init__`` and cut along their third axis.
    Each slice becomes one record (a dict) holding:

    * ``"images"``            -- transformed image slice (only when ``img_dir`` is given)
    * ``"seg_<seg_type>"``    -- transformed mask slice, one entry per sub-directory of
      ``seg_dir`` (only when ``segmentation_guided`` is true)
    * ``"image_filenames"``   -- ``"<volume stem>_axial_<zzzz>"``

    Args:
        img_dir: root of the image volumes, or ``None`` for a mask-only dataset.
        seg_dir: root of the mask volumes; required when ``segmentation_guided`` is true.
        split: sub-directory name selecting the split (e.g. ``"train"``, ``"val"``).
        img_size: slices are resized to ``(img_size, img_size)``.
        segmentation_guided: whether mask slices are loaded alongside the images.

    Raises:
        ValueError: if ``segmentation_guided`` is true but ``seg_dir`` is ``None``, or
            if neither images nor masks were requested.
    """

    def __init__(self, 
                 img_dir = None, 
                 seg_dir = None, 
                 split="train",
                 img_size=256,
                 segmentation_guided=True,
        ):
        super().__init__()

        # Fail early: os.listdir(None) would silently list the current directory.
        if segmentation_guided and seg_dir is None:
            raise ValueError("seg_dir must be provided when segmentation_guided=True")
        if img_dir is None and not segmentation_guided:
            raise ValueError("Nothing to load: pass img_dir and/or set segmentation_guided=True")

        self.img_dir = img_dir
        self.seg_dir = seg_dir
        self.split = split
        self.segmentation_guided = segmentation_guided
        self.samples = []

        # One sub-directory of seg_dir per segmentation type. Stray files are skipped
        # and the list is sorted so the result does not depend on os.listdir() order.
        seg_types = []
        if segmentation_guided:
            seg_types = sorted(
                entry for entry in os.listdir(seg_dir)
                if os.path.isdir(os.path.join(seg_dir, entry))
            )

        #--- Transforms ---------------------------------------------------
        # Both pipelines receive channel-first (1,H,W) arrays (see the slicing loop).
        img_tf = Compose([
            Resize((img_size, img_size)),
            ScaleIntensity(minv=-1.0, maxv=1.0),
            ToTensor(),
        ]) if img_dir is not None else None

        seg_tf = Compose([
            # nearest-neighbour keeps mask values discrete
            Resize((img_size, img_size), mode="nearest"),
            ToTensor(),
        ]) if segmentation_guided else None

        #--- Get Volume Names -----------------------------------------
        if img_dir is not None:
            listing = os.listdir(os.path.join(img_dir, split))
        else:
            listing = [
                f
                for seg_type in seg_types
                for f in os.listdir(os.path.join(seg_dir, seg_type, split))
            ]
        # De-duplicate: in mask-only mode the same volume name appears once per seg
        # type, but a single record already carries the masks of every seg type.
        vol_names = sorted({f for f in listing if f.endswith('.nrrd')})

        #--- Pre-load and Slice Volumes -------------------------------
        for vol_name in vol_names:
            vol_img = None
            if img_dir is not None:
                vol_img, _ = nrrd.read(os.path.join(img_dir, split, vol_name))  # indexed as [:, :, z] below

            # seg_types is empty unless segmentation_guided, so this loop is a no-op then.
            mask_vols = {}
            for seg_type in seg_types:
                m_path = os.path.join(seg_dir, seg_type, split, vol_name)
                mask_vols[seg_type], _ = nrrd.read(m_path)

            reference_vol = vol_img if vol_img is not None else next(iter(mask_vols.values()))
            depth = reference_vol.shape[2]
            stem = os.path.splitext(vol_name)[0]

            for z in range(depth):
                record = {}

                # --- image slice -------------------------------------------
                if vol_img is not None:
                    # Add the channel axis: Resize expects channel-first input, and a
                    # bare (H,W) slice would have H interpreted as the channel axis.
                    img_slice = vol_img[:, :, z].astype(np.float32)[np.newaxis]
                    record["images"] = img_tf(img_slice)          # (1,H,W)

                # --- mask slices -------------------------------------------
                for seg_type, mask_vol in mask_vols.items():
                    mask_slice = mask_vol[:, :, z].astype(np.float32)[np.newaxis]
                    record[f"seg_{seg_type}"] = seg_tf(mask_slice)   # (1,H,W)

                # --- filename metadata ------------------------------------
                record["image_filenames"] = f"{stem}_axial_{z:04d}"

                self.samples.append(record)

    def __len__(self):
        """Number of pre-loaded slices."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return the pre-computed record dict for slice ``idx``."""
        return self.samples[idx]


In [None]:
def make_loaders(
    img_dir,
    seg_dir,
    img_size,
    segmentation_guided,
    batch_sizes,
    num_workers=4,
):
    """Build the train and validation data loaders.

    An ``NRRDDataset`` is created for the ``"train"`` and ``"val"`` splits (in that
    order), then each is wrapped in a ``DataLoader``. Only the training loader shuffles.

    Args:
        img_dir, seg_dir, img_size, segmentation_guided: forwarded to ``NRRDDataset``.
        batch_sizes: mapping with the keys ``"train"`` and ``"val"``.
        num_workers: forwarded to both ``DataLoader`` instances.

    Returns:
        ``(train_loader, val_loader)``
    """
    split_names = ("train", "val")

    # Datasets first, loaders second -- same construction order for both splits.
    datasets = {}
    for split_name in split_names:
        datasets[split_name] = NRRDDataset(
            img_dir,
            seg_dir,
            split=split_name,
            img_size=img_size,
            segmentation_guided=segmentation_guided,
        )

    loaders = tuple(
        DataLoader(
            datasets[split_name],
            batch_size=batch_sizes[split_name],
            shuffle=(split_name == "train"),
            num_workers=num_workers,
        )
        for split_name in split_names
    )

    train_loader, val_loader = loaders
    return train_loader, val_loader

# ------- data locations ------------------------------------------------
# Defined once so the dataset and the loaders below cannot point at different
# directories (the loaders previously used a misspelled segmentation path).
IMG_DIR = "data_dongyang/img"          # None if mask-only
SEG_DIR = "data_dongyang/seg"          # required for segs
IMG_SIZE = 256
SEGMENTATION_GUIDED = True             # False for image-only

# ------- create the raw dataset --------------------------------------
ds = NRRDDataset(
    img_dir=IMG_DIR,
    seg_dir=SEG_DIR,
    split="train",
    img_size=IMG_SIZE,
    segmentation_guided=SEGMENTATION_GUIDED,
)

print(f"Dataset length : {len(ds)} slices")

sample = ds[0]                      # any index < len(ds)
print("Keys           :", sample.keys())
print("Image shape    :", sample["images"].shape)
# One mask entry exists per sub-directory of SEG_DIR; report each of them rather
# than assuming a particular segmentation type is present.
for key, value in sample.items():
    if key.startswith("seg_"):
        print("Mask shape     :", key, value.shape)
print("Slice filename :", sample["image_filenames"])

# ------- wrap in loaders ---------------------------------------------
train_loader, val_loader = make_loaders(
    img_dir=IMG_DIR,
    seg_dir=SEG_DIR,
    img_size=IMG_SIZE,
    segmentation_guided=SEGMENTATION_GUIDED,
    batch_sizes={"train": 4, "val": 4},
)

batch = next(iter(train_loader))
print("Batch tensor keys :", batch.keys())
print("Batch 'images'    :", batch["images"].shape)

In [None]:
# model = diffusers.UNet2DModel(
#         sample_size=config.image_size,  # the target image resolution
#         in_channels=in_channels,  # the number of input channels, 3 for RGB images
#         out_channels=num_img_channels,  # the number of output channels
#         layers_per_block=2,  # how many ResNet layers to use per UNet block
#         block_out_channels=(128, 128, 256, 256, 512, 512),  # the number of output channels for each UNet block
#         down_block_types=(
#             "DownBlock2D",  # a regular ResNet downsampling block
#             "DownBlock2D",
#             "DownBlock2D",
#             "DownBlock2D",
#             "AttnDownBlock2D",  # a ResNet downsampling block with spatial self-attention
#             "DownBlock2D",
#         ),
#         up_block_types=(
#             "UpBlock2D",  # a regular ResNet upsampling block
#             "AttnUpBlock2D",  # a ResNet upsampling block with spatial self-attention
#             "UpBlock2D",
#             "UpBlock2D",
#             "UpBlock2D",
#             "UpBlock2D"
#         ),
#     )

# model = nn.DataParallel(model)
# model.to(device)

In [None]:
# # define noise scheduler
#     if model_type == "DDPM":
#         noise_scheduler = diffusers.DDPMScheduler(num_train_timesteps=1000)
#     elif model_type == "DDIM":
#         noise_scheduler = diffusers.DDIMScheduler(num_train_timesteps=1000)

In [None]:
from tqdm.auto import tqdm
import numpy as np

import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter

import diffusers

from eval import evaluate, add_segmentations_to_noise, SegGuidedDDPMPipeline, SegGuidedDDIMPipeline

def _forward_diffusion(clean_images, noise_scheduler):
    """Sample noise and per-image timesteps, and noise ``clean_images`` accordingly.

    Returns ``(noise, timesteps, noisy_images)``.
    """
    noise = torch.randn(clean_images.shape).to(clean_images.device)
    bs = clean_images.shape[0]
    # one random timestep per image in the batch
    timesteps = torch.randint(
        0, noise_scheduler.config.num_train_timesteps, (bs,), device=clean_images.device
    ).long()
    # forward diffusion: noise magnitude is determined by each image's timestep
    noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)
    return noise, timesteps, noisy_images


def _sample_class_labels(batch_size, label, config, device):
    """Return a batch of ``label`` class ids, replaced by all zeros with probability
    ``config.cfg_p_uncond`` (classifier-free guidance)."""
    class_labels = torch.full([batch_size], label).long().to(device)
    if np.random.uniform() <= config.cfg_p_uncond:
        class_labels = torch.zeros_like(class_labels).long()
    return class_labels


def _optimization_step(model, optimizer, lr_scheduler, loss):
    """Backpropagate ``loss``, clip the gradient norm to 1.0, then step the optimizer
    and the LR scheduler and clear the gradients."""
    loss.backward()
    nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()
    lr_scheduler.step()
    optimizer.zero_grad()


def _build_pipeline(config, model, noise_scheduler, eval_dataloader):
    """Create the sampling pipeline matching ``config`` (assumed already validated)."""
    # Hand the pipeline the bare network whether or not it is wrapped for multi-GPU use.
    is_wrapped = isinstance(model, (nn.DataParallel, nn.parallel.DistributedDataParallel))
    unet = model.module if is_wrapped else model
    use_ddpm = config.model_type == "DDPM"

    if config.segmentation_guided:
        seg_pipeline_cls = SegGuidedDDPMPipeline if use_ddpm else SegGuidedDDIMPipeline
        return seg_pipeline_cls(
            unet=unet, scheduler=noise_scheduler, eval_dataloader=eval_dataloader, external_config=config
        )
    pipeline_cls = diffusers.DDPMPipeline if use_ddpm else diffusers.DDIMPipeline
    return pipeline_cls(unet=unet, scheduler=noise_scheduler)


def train_loop(config, model, noise_scheduler, optimizer, train_dataloader, eval_dataloader, lr_scheduler, device='cuda'):
    """Train a noise-prediction model and periodically sample images / save the pipeline.

    Per training batch the images under ``batch['images']`` are noised at random
    timesteps and the model is optimised with an MSE loss against the sampled noise.
    When ``config.class_conditional`` is set, a second optimisation step is run on
    ``batch['images_target']`` (class label 2, blank masks when segmentation-guided).
    Losses and learning rate are shown on a tqdm bar and written to TensorBoard.

    After each epoch, depending on ``config.save_image_epochs`` /
    ``config.save_model_epochs`` (and always on the last epoch), ``evaluate`` is called
    and/or the pipeline is saved to ``config.output_dir``.

    Args:
        config: object exposing ``model_type``, ``dataset``, ``image_size``,
            ``segmentation_guided``, ``class_conditional``, ``cfg_p_uncond``,
            ``resume_epoch``, ``num_epochs``, ``save_image_epochs``,
            ``save_model_epochs`` and ``output_dir``.
        model: the network; may be wrapped in ``nn.DataParallel``.
        noise_scheduler: provides ``config.num_train_timesteps`` and ``add_noise``.
        optimizer, lr_scheduler: stepped once per optimisation step.
        train_dataloader: sized iterable of batch dicts.
        eval_dataloader: iterable of batches used to condition evaluation samples.
        device: device the batches are moved to.

    Raises:
        ValueError: if ``config.model_type`` is neither ``"DDPM"`` nor ``"DDIM"``.
        NotImplementedError: for class-conditional training without segmentation guidance.
    """
    # Reject unsupported configurations before spending an epoch on training.
    if config.model_type not in ("DDPM", "DDIM"):
        raise ValueError(
            "Unsupported model_type {!r}; expected 'DDPM' or 'DDIM'".format(config.model_type)
        )
    if config.class_conditional and not config.segmentation_guided:
        raise NotImplementedError(
            "TODO: Conditional training not implemented for non-seg-guided {}".format(config.model_type)
        )

    global_step = 0

    # logging
    run_name = '{}-{}-{}'.format(config.model_type.lower(), config.dataset, config.image_size)
    if config.segmentation_guided:
        run_name += "-segguided"
    writer = SummaryWriter(comment=run_name)

    # for loading segs to condition on; restarted below once it runs out
    eval_iterator = iter(eval_dataloader)

    start_epoch = 0
    if config.resume_epoch is not None:
        start_epoch = config.resume_epoch

    try:
        for epoch in range(start_epoch, config.num_epochs):
            progress_bar = tqdm(total=len(train_dataloader))
            progress_bar.set_description(f"Epoch {epoch}")

            model.train()

            for batch in train_dataloader:
                clean_images = batch['images'].to(device)
                noise, timesteps, noisy_images = _forward_diffusion(clean_images, noise_scheduler)

                if config.segmentation_guided:
                    noisy_images = add_segmentations_to_noise(noisy_images, batch, config, device)

                # Predict the noise residual
                if config.class_conditional:
                    class_labels = _sample_class_labels(noisy_images.size(0), 1, config, device)
                    noise_pred = model(noisy_images, timesteps, class_labels=class_labels, return_dict=False)[0]
                else:
                    noise_pred = model(noisy_images, timesteps, return_dict=False)[0]
                loss = F.mse_loss(noise_pred, noise)
                _optimization_step(model, optimizer, lr_scheduler, loss)

                # also train on target domain images if conditional
                # (we don't have masks for this domain, so we can't do segmentation-guided; just use blank masks)
                if config.class_conditional:
                    target_domain_images = batch['images_target'].to(device)
                    noise, timesteps, noisy_images = _forward_diffusion(target_domain_images, noise_scheduler)

                    if config.segmentation_guided:
                        # no masks in target domain so just use blank masks
                        noisy_images = torch.cat((noisy_images, torch.zeros_like(noisy_images)), dim=1)

                    class_labels = _sample_class_labels(noisy_images.size(0), 2, config, device)
                    noise_pred = model(noisy_images, timesteps, class_labels=class_labels, return_dict=False)[0]
                    loss_target_domain = F.mse_loss(noise_pred, noise)
                    _optimization_step(model, optimizer, lr_scheduler, loss_target_domain)

                progress_bar.update(1)
                loss_value = loss.detach().item()
                logs = {"loss": loss_value}
                if config.class_conditional:
                    # log the target-domain loss itself (not the source-domain loss)
                    loss_target_value = loss_target_domain.detach().item()
                    logs["loss_target_domain"] = loss_target_value
                    writer.add_scalar("loss_target_domain", loss_target_value, global_step)
                logs["lr"] = lr_scheduler.get_last_lr()[0]
                logs["step"] = global_step
                writer.add_scalar("loss", loss_value, global_step)

                progress_bar.set_postfix(**logs)
                global_step += 1

            progress_bar.close()
            model.eval()

            # After each epoch you optionally sample some demo images with evaluate() and save the model
            is_last_epoch = epoch == config.num_epochs - 1
            save_images = (epoch + 1) % config.save_image_epochs == 0 or is_last_epoch
            save_model = (epoch + 1) % config.save_model_epochs == 0 or is_last_epoch

            if save_images or save_model:
                seg_batch = None
                if save_images and config.segmentation_guided:
                    try:
                        seg_batch = next(eval_iterator)
                    except StopIteration:
                        # eval data exhausted: start over instead of aborting training
                        eval_iterator = iter(eval_dataloader)
                        seg_batch = next(eval_iterator)

                # NOTE(review): built after fetching the batch so the pipeline receives
                # the live iterator; assumes pipeline construction does not itself
                # consume batches -- confirm against eval.py.
                pipeline = _build_pipeline(config, model, noise_scheduler, eval_iterator)

                if save_images:
                    if config.segmentation_guided:
                        evaluate(config, epoch, pipeline, seg_batch)
                    else:
                        evaluate(config, epoch, pipeline)

                if save_model:
                    pipeline.save_pretrained(config.output_dir)
    finally:
        # flush pending TensorBoard events even if training is interrupted
        writer.close()
