<a href="https://colab.research.google.com/github/ritwiks9635/Segmentation-Model/blob/main/Diffusion_Models_for_Image_Segmentation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>


#**Diffusion Models for Image Segmentation**

This pipeline is based on the MONAI tutorial [1] and uses MONAI for 2D image segmentation with denoising diffusion probabilistic models (DDPMs), as proposed in [2].

The same structure can also be used for conditional image generation, or image-to-image translation, as proposed in [3,4].

1. https://github.com/Project-MONAI/GenerativeModels/blob/main/tutorials/generative/image_to_image_translation/tutorial_segmentation_with_ddpm.ipynb

2. Wolleb et al. "Diffusion Models for Implicit Image Segmentation Ensembles", https://arxiv.org/abs/2112.03145

3. Waibel et al. "A Diffusion Model Predicts 3D Shapes from 2D Microscopy Images", https://arxiv.org/abs/2208.14125

4. Durrer et al. "Diffusion Models for Contrast Harmonization of Magnetic Resonance Images", https://aps.arxiv.org/abs/2303.08189

In [None]:
# https://www.kaggle.com/datasets/user164919/hutu-80

In [None]:
!unzip /content/https:/www.kaggle.com/datasets/user164919/hutu-80/hutu-80.zip

In [None]:
!python -c "import monai" || pip install -q "monai[pillow, tqdm]"
!python -c "import matplotlib" || pip install -q matplotlib
!python -c "import seaborn" || pip install -q seaborn
!python -c "import generative" || pip install -q monai-generative

In [None]:
import os
import tempfile
import time
from glob import glob

from sklearn.model_selection import train_test_split

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from monai import transforms
from monai.apps import DecathlonDataset
from monai.config import print_config
from monai.data import DataLoader, load_decathlon_datalist, CacheDataset
from monai.utils import set_determinism
from torch.cuda.amp import GradScaler, autocast
from tqdm import tqdm

from generative.inferers import DiffusionInferer
from generative.networks.nets.diffusion_model_unet import DiffusionModelUNet
from generative.networks.schedulers.ddpm import DDPMScheduler

torch.multiprocessing.set_sharing_strategy("file_system")
print_config()

MONAI version: 1.3.0
Numpy version: 1.25.2
Pytorch version: 2.2.1+cu121
MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False
MONAI rev id: 865972f7a791bf7b42efbcd87c8402bd865b329e
MONAI __file__: /usr/local/lib/python3.10/dist-packages/monai/__init__.py

Optional dependencies:
Pytorch Ignite version: NOT INSTALLED or UNKNOWN VERSION.
ITK version: NOT INSTALLED or UNKNOWN VERSION.
Nibabel version: 4.0.2
scikit-image version: 0.19.3
scipy version: 1.11.4
Pillow version: 9.4.0
Tensorboard version: 2.15.2
gdown version: 4.7.3
TorchVision version: 0.17.1+cu121
tqdm version: 4.66.2
lmdb version: NOT INSTALLED or UNKNOWN VERSION.
psutil version: 5.9.5
pandas version: 1.5.3
einops version: NOT INSTALLED or UNKNOWN VERSION.
transformers version: 4.38.2
mlflow version: NOT INSTALLED or UNKNOWN VERSION.
pynrrd version: NOT INSTALLED or UNKNOWN VERSION.
clearml version: NOT INSTALLED or UNKNOWN VERSION.

For details about installing the optional dependencies, please visit:
   

#**Microscopic Images Dataset of Human Duodenum Adenocarcinoma**

In [None]:
data_dir = "/content/MCF-7 cell populations Dataset"

os.listdir(data_dir)

['images', 'masks']

###**Set deterministic training for reproducibility**

In [None]:
set_determinism(1217)

In [None]:
dataset_descr_json = {
    "description": "HuTu-80 Dataset",
    "name": "HuTu-80 Dataset",
    "test": [],
    "training": [],
    "validation": []
}

all_images = glob(os.path.join(data_dir, "images", "*.png"))
print("Total Images :::", len(all_images))

train_images, test_images = train_test_split(all_images, test_size = 0.2, random_state = 1217)

print("Train subset ::", len(train_images), "\nTest subset ::", len(test_images))

#dataset_descr_json["training"]

for image in tqdm(train_images):
    image_paths = {}
    image_paths["image"] = image
    image_paths["label"] = image.replace("images", "masks")
    dataset_descr_json["training"].append(image_paths)


for image in tqdm(test_images):
    image_paths = {}
    image_paths["image"] = image
    image_paths["label"] = image.replace("images", "masks")
    dataset_descr_json["validation"].append(image_paths)


print()
print(f"dataset_descr_json[training]: {len(dataset_descr_json['training'])}")
print(f"dataset_descr_json[validation]: {len(dataset_descr_json['validation'])}")

Total Images ::: 180
Train subset :: 144 
Test subset :: 36


100%|██████████| 144/144 [00:00<00:00, 451404.91it/s]
100%|██████████| 36/36 [00:00<00:00, 244724.38it/s]


dataset_descr_json[training]: 144
dataset_descr_json[validation]: 36





In [None]:
import json


split_json = "./dataset_0.json"

with open(split_json, 'w') as f:
    json.dump(dataset_descr_json, f)
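
The loops above derive each mask path with `image.replace("images", "masks")`, which rewrites every occurrence of the substring, including one that happens to appear in a parent folder or in the file name. A minimal sketch of a stricter pairing follows, assuming masks share the image's file name and sit in a sibling `masks` folder (as the directory listing suggests). `mask_path_for` is a hypothetical helper, not part of MONAI.

```python
import os


def mask_path_for(image_path):
    """Return the mask path for an image by swapping only the parent folder name.

    Assumes a layout of <root>/images/<name> and <root>/masks/<name>.
    """
    folder, name = os.path.split(image_path)
    root, leaf = os.path.split(folder)
    if leaf != "images":
        raise ValueError(f"expected an 'images' folder, got {folder!r}")
    return os.path.join(root, "masks", name)


# Hypothetical path in which "images" also appears outside the folder name.
example = os.path.join("data", "my_images_v2", "images", "cell_images_01.png")
print(mask_path_for(example))
print(example.replace("images", "masks"))  # the plain string replacement, for comparison
```

For the folder layout shown earlier both approaches give the same result; the helper only matters when the substring occurs elsewhere in the path.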

#**Preprocessing of the HuTu Dataset for training**

In [None]:
train_transforms = transforms.Compose(
    [
        transforms.LoadImaged(keys=["image", "label"]),
        transforms.EnsureChannelFirstd(keys=["image", "label"]),
        transforms.Lambdad(keys=["label"], func=lambda x: np.where(x[0, :, :].unsqueeze(0) > 127, 1, 0), overwrite=True),
        transforms.EnsureTyped(keys=["image", "label"]),
        transforms.Resized(keys=["image", "label"], spatial_size =(512, 512), mode=["bilinear","nearest"]),
        transforms.ScaleIntensityRangePercentilesd(keys="image", lower=0, upper=99.5, b_min=0, b_max=1),
    ]
)
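
To make the two intensity steps in this pipeline concrete, here is a NumPy sketch of what the label binarization and the percentile scaling compute. It is an illustration on random stand-in data, not MONAI's implementation, and it assumes that no clipping is applied after scaling, so a few pixels above the 99.5th percentile can end up slightly above 1. The function names are made up for this example.

```python
import numpy as np


def binarize_mask(mask, threshold=127):
    """Keep the first channel and map values above the threshold to 1, all others to 0."""
    first = mask[:1]  # shape (1, H, W), same as x[0, :, :].unsqueeze(0)
    return np.where(first > threshold, 1, 0)


def scale_by_percentiles(img, lower=0.0, upper=99.5, b_min=0.0, b_max=1.0):
    """Linearly map the [lower, upper] percentile range of img onto [b_min, b_max].

    Assumption: no clipping, so values above the upper percentile exceed b_max.
    """
    a_min = np.percentile(img, lower)
    a_max = np.percentile(img, upper)
    return (img - a_min) / (a_max - a_min) * (b_max - b_min) + b_min


rng = np.random.default_rng(0)
mask = rng.integers(0, 256, size=(3, 4, 4))                        # stand-in for an RGB mask
image = rng.integers(0, 256, size=(3, 64, 64)).astype(np.float32)  # stand-in for an RGB image

label = binarize_mask(mask)
scaled = scale_by_percentiles(image)
print(label.shape, np.unique(label).tolist())
print(float(scaled.min()), float(scaled.max()))
```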

In [None]:
train_files = load_decathlon_datalist(split_json, True, "training")
train_ds = CacheDataset(
    data=train_files,
    transform=train_transforms,
#              cache_num=1, # Uncomment it for debug purposes
    cache_rate=1.0,
    num_workers=4,
)

Loading dataset: 100%|██████████| 144/144 [00:52<00:00,  2.74it/s]


In [None]:
print(f"Length of training data: {len(train_ds)}")  # this gives the number of samples in the training set
print(f'Train image shape {train_ds[0]["image"].shape}')
print(f'Train label shape {train_ds[0]["label"].shape}')

Length of training data: 144
Train image shape torch.Size([3, 512, 512])
Train label shape torch.Size([1, 512, 512])


In [None]:
batch_size = 1
train_loader = DataLoader(
    train_ds, batch_size=batch_size, shuffle=True, num_workers=4, drop_last=True, persistent_workers=True
)



#**Preprocessing of the Dataset for validation**

In [None]:
val_files = load_decathlon_datalist(split_json, True, "validation")

In [None]:
val_ds = CacheDataset(
    data=val_files,
    transform=train_transforms,
#             cache_num=4,  # Uncomment it for debug purposes
    cache_rate=1.0,
    num_workers=4,
)

Loading dataset: 100%|██████████| 36/36 [00:11<00:00,  3.02it/s]


In [None]:
print(f"Length of validation data: {len(val_ds)}")
print(f'Validation Image shape {val_ds[0]["image"].shape}')
print(f'Validation Label shape {val_ds[0]["label"].shape}')

Length of validation data: 36
Validation Image shape torch.Size([3, 512, 512])
Validation Label shape torch.Size([1, 512, 512])


In [None]:
val_loader = DataLoader(
    val_ds, batch_size=batch_size, shuffle=False, num_workers=4, drop_last=True, persistent_workers=True
)

#**Define network, scheduler, optimizer, and inferer**

At this step, we instantiate the MONAI components that make up the DDPM: the UNet, the noise scheduler, and the inferer used for training and sampling.

We use a DDPM scheduler with 10 timesteps (n_timesteps=10) and a 2D UNet with an attention mechanism in the third level (num_head_channels=64).

In [None]:
device = torch.device("cuda")

In [None]:
model = DiffusionModelUNet(
    spatial_dims=2,
    in_channels=4,
    out_channels=1,
    num_channels=(64, 64, 64),
    attention_levels=(False, False, True),
    num_res_blocks=1,
    num_head_channels=64,
    with_conditioning=False,
)
model.to(device)

DiffusionModelUNet(
  (conv_in): Convolution(
    (conv): Conv2d(4, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
  (time_embed): Sequential(
    (0): Linear(in_features=64, out_features=256, bias=True)
    (1): SiLU()
    (2): Linear(in_features=256, out_features=256, bias=True)
  )
  (down_blocks): ModuleList(
    (0-1): 2 x DownBlock(
      (resnets): ModuleList(
        (0): ResnetBlock(
          (norm1): GroupNorm(32, 64, eps=1e-06, affine=True)
          (nonlinearity): SiLU()
          (conv1): Convolution(
            (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          )
          (time_emb_proj): Linear(in_features=256, out_features=64, bias=True)
          (norm2): GroupNorm(32, 64, eps=1e-06, affine=True)
          (conv2): Convolution(
            (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          )
          (skip_connection): Identity()
        )
      )
      (downsampler): Downsample(
       

#**Model training of the Diffusion Model**
We train our diffusion model for 100 epochs.

In every step, we concatenate the original microscopic image to the noisy segmentation mask, to predict a slightly denoised segmentation mask.

This is described in Equation 7 of the paper https://arxiv.org/pdf/2112.03145.pdf.
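
The noising and concatenation step can be sketched in NumPy as follows. This is a sketch of the standard closed-form DDPM forward process, not MONAI's code, and the linear beta schedule between 1e-4 and 2e-2 is an assumption that the text above does not state. It also shows why the network takes 4 input channels: 3 image channels plus 1 noisy mask channel.

```python
import numpy as np

n_timesteps = 10
# Assumed schedule: linear betas between 1e-4 and 2e-2 (not confirmed by the text above).
betas = np.linspace(1e-4, 2e-2, n_timesteps)
alphas_cumprod = np.cumprod(1.0 - betas)


def add_noise(x0, noise, t):
    """Closed-form forward process: x_t = sqrt(abar_t) * x0 + sqrt(1 - abar_t) * noise."""
    abar_t = alphas_cumprod[t]
    return np.sqrt(abar_t) * x0 + np.sqrt(1.0 - abar_t) * noise


rng = np.random.default_rng(0)
image = rng.random((1, 3, 8, 8))                            # stand-in for the RGB image
seg = (rng.random((1, 1, 8, 8)) > 0.5).astype(np.float64)   # stand-in for the binary mask
noise = rng.standard_normal(seg.shape)

t = int(rng.integers(0, n_timesteps))                       # random timestep, as in training
noisy_seg = add_noise(seg, noise, t)
combined = np.concatenate((image, noisy_seg), axis=1)       # 3 image channels + 1 mask channel

print(combined.shape)
print(float(alphas_cumprod[-1]))
```

Under this assumed schedule, the cumulative product at the last of the 10 steps is still about 0.9, so the noisiest training input stays close to the clean mask, while sampling starts from pure Gaussian noise. Whether that applies here depends on the schedule the scheduler actually uses, which is worth checking.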

In [None]:
n_epochs = 100
n_timesteps = 10
val_interval = 5
epoch_loss_list = []
val_epoch_loss_list = []

In [None]:
scheduler = DDPMScheduler(num_train_timesteps=n_timesteps)
optimizer = torch.optim.Adam(params=model.parameters(), lr=2.5e-5)
inferer = DiffusionInferer(scheduler)

In [None]:
scaler = GradScaler()
total_start = time.time()

for epoch in range(n_epochs):
    model.train()
    epoch_loss = 0

    for step, data in enumerate(train_loader):
        images = data["image"].to(device)
        seg = data["label"].to(device)  # this is the ground truth segmentation
        optimizer.zero_grad(set_to_none=True)
        timesteps = torch.randint(0, n_timesteps, (len(images),)).to(device)  # pick a random time step t

        with autocast(enabled=True):
            # Generate random noise
            noise = torch.randn_like(seg).to(device)
            noisy_seg = scheduler.add_noise(
                original_samples=seg, noise=noise, timesteps=timesteps
            )  # we only add noise to the segmentation mask
            combined = torch.cat(
                (images, noisy_seg), dim=1
            )  # we concatenate the microscopic image with the noisy segmentation mask, to condition the generation process
            # Get model prediction
            prediction = model(x=combined, timesteps=timesteps)
            loss = F.mse_loss(prediction.float(), noise.float())
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        epoch_loss += loss.item()

    epoch_loss_list.append(epoch_loss / (step + 1))
    if epoch % val_interval == 0:
        model.eval()
        val_epoch_loss = 0
        for step, data_val in enumerate(val_loader):
            images = data_val["image"].to(device)
            seg = data_val["label"].to(device)  # this is the ground truth segmentation
            timesteps = torch.randint(0, n_timesteps, (len(images),)).to(device)
            with torch.no_grad():
                with autocast(enabled=True):
                    noise = torch.randn_like(seg).to(device)
                    noisy_seg = scheduler.add_noise(original_samples=seg, noise=noise, timesteps=timesteps)
                    combined = torch.cat((images, noisy_seg), dim=1)
                    prediction = model(x=combined, timesteps=timesteps)
                    val_loss = F.mse_loss(prediction.float(), noise.float())
            val_epoch_loss += val_loss.item()
        print("Epoch", epoch, "Validation loss", val_epoch_loss / (step + 1))
        val_epoch_loss_list.append(val_epoch_loss / (step + 1))

torch.save(model.state_dict(), "./segmodel.pt")
total_time = time.time() - total_start

Epoch 0 Validation loss 0.46243801216284436
Epoch 5 Validation loss 0.03852957238753637
Epoch 10 Validation loss 0.026561259395546384
Epoch 15 Validation loss 0.03044199488229222
Epoch 20 Validation loss 0.017325865984376933
Epoch 25 Validation loss 0.014769794484083023
Epoch 30 Validation loss 0.014738088137366705
Epoch 35 Validation loss 0.0114727019228869


In [None]:
print(f"train diffusion completed, total time: {total_time}.")
plt.title("Learning Curves Diffusion Model", fontsize=20)
plt.plot(np.arange(1, len(epoch_loss_list) + 1), epoch_loss_list, color="red", linewidth=2.0, label="Train")
plt.plot(
    np.arange(len(val_epoch_loss_list)) * val_interval + 1,  # validation runs after epochs 0, val_interval, 2 * val_interval, ...
    val_epoch_loss_list,
    color="green",
    linewidth=2.0,
    label="Validation",
)
plt.yticks(fontsize=12)
plt.xticks(fontsize=12)
plt.xlabel("Epochs", fontsize=16)
plt.ylabel("Loss", fontsize=16)
plt.legend(prop={"size": 14})
plt.show()

#**Sampling of a new segmentation mask for an input image of the validation set**

Starting from random noise, we want to generate a segmentation mask for a microscopic image of our validation set.

Due to the stochastic generation process, we can sample an ensemble of n different segmentation masks per image.

First, we pick an image of our validation set, and check the ground truth segmentation mask.

In [None]:
idx = 0
data = val_ds[idx]
inputimg = data["image"]  # Pick an input image of the validation set to be segmented
inputlabel = data["label"]  # Check out the ground truth label mask.


plt.figure("input image")
plt.imshow(inputimg.T, vmin=0, vmax=1, cmap="gray")
plt.axis("off")
plt.tight_layout()
plt.show()

plt.figure("ground truth label")
plt.imshow(inputlabel.T, vmin=0, vmax=1, cmap="gray")
plt.axis("off")
plt.tight_layout()
plt.show()


model.eval()

Then we set the number of samples in the ensemble, n. Starting from the input image (which is the microscopic image), we follow Algorithm 1 of the paper "Diffusion Models for Implicit Image Segmentation Ensembles" (https://arxiv.org/pdf/2112.03145.pdf) n times. This gives us an ensemble of n different predicted segmentation masks.
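
The sampling procedure can be sketched without a trained network. The block below uses the textbook DDPM update with a variance of beta_t and an assumed linear beta schedule. MONAI's `scheduler.step` may differ in details such as the variance choice or clipping, and `dummy_model` is a hypothetical stand-in for the UNet. The point is the control flow: start from noise, concatenate the image at every step, and repeat n times to obtain an ensemble.

```python
import numpy as np

n_timesteps = 10
betas = np.linspace(1e-4, 2e-2, n_timesteps)  # assumed linear schedule
alphas = 1.0 - betas
alphas_cumprod = np.cumprod(alphas)


def reverse_step(x_t, eps_pred, t, rng):
    """One textbook DDPM update from x_t to x_{t-1}, given the predicted noise."""
    mean = (x_t - betas[t] / np.sqrt(1.0 - alphas_cumprod[t]) * eps_pred) / np.sqrt(alphas[t])
    if t == 0:
        return mean  # no noise is added at the final step
    sigma = np.sqrt(betas[t])  # one common variance choice
    return mean + sigma * rng.standard_normal(x_t.shape)


def sample(eps_model, image, rng):
    """Run the full reverse chain for one mask, conditioning on the image by concatenation."""
    x = rng.standard_normal((1, 1) + image.shape[2:])  # start from pure noise
    for t in reversed(range(n_timesteps)):
        combined = np.concatenate((image, x), axis=1)
        x = reverse_step(x, eps_model(combined, t), t, rng)
    return x


def dummy_model(combined, t):
    """Hypothetical stand-in for the trained UNet: always predicts zero noise."""
    return np.zeros_like(combined[:, -1:])


rng = np.random.default_rng(0)
image = rng.random((1, 3, 8, 8))
n = 5
ensemble = [sample(dummy_model, image, rng) for _ in range(n)]
print(len(ensemble), ensemble[0].shape)
```

Because each chain draws its own starting noise and its own per-step noise, the n members differ even for the same input image, which is what makes the ensemble informative.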

In [None]:
n = 5
input_img = inputimg[None, ...].to(device)

ensemble = []
for k in range(n):
    noise_shape = list(input_img.shape)
    noise_shape[1] = 1
    noise = torch.randn(noise_shape).to(device)
    current_img = noise  # for the segmentation mask, we start from random noise.
    combined = torch.cat(
        (input_img, noise), dim=1
    )  # We concatenate the input microscopic image to add spatial information.
    scheduler.set_timesteps(num_inference_steps=n_timesteps)
    progress_bar = tqdm(scheduler.timesteps)
    chain = torch.zeros(current_img.shape)
    for t in progress_bar:  # go through the denoising process
        with autocast(enabled=False):
            with torch.no_grad():
                model_output = model(combined, timesteps=torch.Tensor((t,)).to(current_img.device))
                current_img, _ = scheduler.step(
                    model_output, t, current_img
                )  # this is the denoised sample x_{t-1} obtained from x_t
                if t % (n_timesteps // 10) == 0:
                    chain = torch.cat((chain, current_img.cpu()), dim=-1)
                combined = torch.cat(
                    (input_img, current_img), dim=1
                )  # in every step during the denoising process, the microscopic image is concatenated to add spatial information

    plt.style.use("default")
    plt.imshow(chain[0, 0, ..., 512:].cpu(), vmin=0, vmax=1, cmap="gray")
    plt.tight_layout()
    plt.axis("off")
    plt.show()
    ensemble.append(current_img)  # this is the output of the diffusion model after T=n_timesteps denoising steps

#**Segmentation prediction**

The predicted segmentation mask is obtained from the output of the diffusion model by thresholding.

We compute the Dice score for all predicted segmentations of the ensemble, as well as the pixel-wise mean and the variance map over the ensemble. As shown in the paper "Diffusion Models for Implicit Image Segmentation Ensembles" (https://arxiv.org/abs/2112.03145), we see that taking the mean over n=5 samples improves the segmentation performance.

The variance map highlights pixels where the model is unsure about its own prediction.

In [None]:
def dice_coeff(im1, im2, empty_score=1.0):
    im1 = np.asarray(im1).astype(bool)
    im2 = np.asarray(im2).astype(bool)

    im_sum = im1.sum() + im2.sum()
    if im_sum == 0:
        return empty_score

    # Compute Dice coefficient
    intersection = np.logical_and(im1, im2)

    return 2.0 * intersection.sum() / im_sum
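
As a quick sanity check, the same function can be run on tiny hand-made masks. The definition is repeated so the snippet runs on its own, and the arrays are illustrative only.

```python
import numpy as np


def dice_coeff(im1, im2, empty_score=1.0):
    """Dice coefficient 2|A ∩ B| / (|A| + |B|) for two binary masks."""
    im1 = np.asarray(im1).astype(bool)
    im2 = np.asarray(im2).astype(bool)
    im_sum = im1.sum() + im2.sum()
    if im_sum == 0:
        return empty_score
    return 2.0 * np.logical_and(im1, im2).sum() / im_sum


pred = np.array([[1, 1], [0, 0]])   # two predicted foreground pixels
truth = np.array([[1, 0], [0, 0]])  # one true foreground pixel, overlapping with pred

print(dice_coeff(pred, truth))                          # 2 * 1 / (2 + 1)
print(dice_coeff(truth, truth))                         # identical masks
print(dice_coeff(np.zeros((2, 2)), np.zeros((2, 2))))   # both empty: returns empty_score
```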

In [None]:
for i in range(len(ensemble)):
    prediction = torch.where(ensemble[i] > 0.5, 1, 0).float()  # a binary mask is obtained via thresholding
    score = dice_coeff(
        prediction[0, 0].cpu(), inputlabel.cpu()
    )  # we compute the dice scores for all samples separately
    print("Dice score of sample", i, score)


E = torch.where(torch.cat(ensemble) > 0.5, 1, 0).float()
var = torch.var(E, dim=0)  # pixel-wise variance map over the ensemble
mean = torch.mean(E, dim=0)  # pixel-wise mean map over the ensemble
mean_prediction = torch.where(mean > 0.5, 1, 0).float()

score = dice_coeff(mean_prediction[0, ...].cpu(), inputlabel.cpu())  # Here we compute the Dice score for the mean map
print("Dice score on the mean map", score)

plt.style.use("default")
plt.imshow(mean[0, ...].cpu(), vmin=0, vmax=1, cmap="gray")  # We plot the mean map
plt.tight_layout()
plt.axis("off")
plt.show()
plt.style.use("default")
plt.imshow(var[0, ...].cpu(), vmin=0, vmax=1, cmap="jet")  # We plot the variance map
plt.tight_layout()
plt.axis("off")
plt.show()
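
To help read the mean and variance maps, the following NumPy sketch works through the per-pixel arithmetic for an ensemble of 5 binary votes. It uses `ddof=1` to mirror the unbiased estimator that `torch.var` applies by default. The vote patterns are made up for illustration.

```python
import numpy as np

n = 5
# Each row is one pixel; each column is the binary vote of one ensemble member.
votes = np.array(
    [
        [0, 0, 0, 0, 0],  # unanimous background
        [1, 0, 0, 0, 0],
        [1, 1, 0, 0, 0],
        [1, 1, 1, 0, 0],
        [1, 1, 1, 1, 1],  # unanimous foreground
    ],
    dtype=np.float64,
)

mean = votes.mean(axis=1)
var = votes.var(axis=1, ddof=1)        # unbiased variance over the ensemble axis
majority = (mean > 0.5).astype(int)    # thresholding the mean is a majority vote

for k, m, v, p in zip(votes.sum(axis=1), mean, var, majority):
    print(f"{int(k)}/{n} votes: mean={m:.1f} var={v:.2f} prediction={p}")
```

With 5 binary votes the unbiased variance never exceeds 0.3 (reached at a 2/3 split), so a colour scale with `vmax=1` uses only the lower part of its range. A smaller `vmax` would make disagreement easier to see.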