<a href="https://colab.research.google.com/github/sdgroeve/Machine_Learning_course_UGent_D012554_2025/blob/main/notebooks/U-Net_blood_cell_segmentation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Blood Cell Segmentation: U-Net

This notebook implements a U-Net architecture for blood cell segmentation using PyTorch.
The U-Net is a convolutional neural network that was developed for biomedical image segmentation.


In [None]:
# Notebook shell command: install the kagglehub package imported in the next cell.
!pip install kagglehub

Download the blood cell dataset with masks:


In [None]:
import kagglehub

# Fetch the dataset through kagglehub and keep the value it returns;
# later cells build every image/mask path on top of it.
jeetblahiri_bccd_dataset_with_mask_path = kagglehub.dataset_download(
    "jeetblahiri/bccd-dataset-with-mask"
)

# Echo the returned location as the cell output
jeetblahiri_bccd_dataset_with_mask_path

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from PIL import Image

from pathlib import Path

import torch
import torch.nn as nn
import torchvision.transforms.functional as TF

from torchvision import io
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import transforms

# Seed the PyTorch and NumPy global RNGs by *calling* the seeding functions.
# Assigning to them (``torch.manual_seed = 42``) would overwrite the functions
# with an int: nothing gets seeded and any later call to them fails.
torch.manual_seed(42)
np.random.seed(42)

Collect paths to all training and test images and their corresponding masks:


In [None]:
# Dataset root inside the directory returned by kagglehub
root_dir = jeetblahiri_bccd_dataset_with_mask_path + '/BCCD Dataset with mask'

# List every entry of each split folder in sorted order, so that position i in
# an image list and position i in the matching mask list refer to the same
# sample (presumably images and masks share file names - verify with the
# output below).
train_images = sorted(Path(f'{root_dir}/train/original').glob('*'))
train_masks = sorted(Path(f'{root_dir}/train/mask').glob('*'))

test_images = sorted(Path(f'{root_dir}/test/original').glob('*'))
test_masks = sorted(Path(f'{root_dir}/test/mask').glob('*'))

# Sanity check of the pairing: show the first mask path and the first image path
str(train_masks[0]), str(train_images[0])

Display a sample image and its mask:


In [None]:
# Index of the training sample to display
image_idx = 10

# Use image_idx for BOTH reads so the displayed image and mask belong to the
# same sample. read_image yields channels-first tensors; permute to
# height x width x channels, the layout imshow expects.
plt.subplot(1, 2, 1)
plt.imshow(io.read_image(str(train_images[image_idx])).permute(1, 2, 0))

plt.subplot(1, 2, 2)
plt.imshow(io.read_image(str(train_masks[image_idx])).permute(1, 2, 0))

Create a custom dataset class for blood cell images:


In [None]:
class BloodCellDatase(Dataset):
    """
    Custom PyTorch Dataset pairing blood cell images with their masks.

    Sample ``i`` is built from ``images[i]`` and ``masks[i]``, so both lists
    must be in the same order.

    Attributes:
        images (list): List of paths to images
        masks (list): List of paths to corresponding masks
        transform: Optional callable applied to the image tensor and to the
            mask tensor; ``None`` leaves both untouched
    """
    def __init__(self, images, masks, transform=None):
        """
        Store the path lists and the shared transform.

        Args:
            images (list): Paths to the input images
            masks (list): Paths to the masks, aligned with ``images``
            transform: Optional callable applied to each loaded tensor
        """
        self.images = images
        self.masks = masks
        self.transform = transform

    def __len__(self):
        """Return the total number of samples in the dataset"""
        return len(self.images)

    def __getitem__(self, idx):
        """
        Get a sample from the dataset.

        Args:
            idx (int): Index of the sample

        Returns:
            tuple: (image, mask) as float32 tensors. Values are only cast to
            float32, not rescaled. The mask keeps just its first channel
            (shape ``(1, H, W)``).
        """
        image = io.read_image(str(self.images[idx]))
        mask = io.read_image(str(self.masks[idx]))

        if self.transform is not None:
            image = self.transform(image)
            mask = self.transform(mask)

        # Slice with 0:1 (not 0) so the channel dimension is preserved
        return image.to(torch.float32), mask[0:1, :, :].to(torch.float32)

Define transformation pipeline and create datasets:


In [None]:
# Define transformation to resize images and masks to 180x180
transform = transforms.Compose([
    transforms.Resize((180, 180)),
])

# Create train and test datasets
train_dataset = BloodCellDatase(train_images, train_masks, transform=transform)
test_dataset = BloodCellDatase(test_images, test_masks, transform=transform)

# Split training data into training and validation sets (80%-20% split).
# Both sizes are derived from the dataset being split, so they always sum to
# its length as random_split requires.
train_size = int(0.8 * len(train_dataset))
val_size = len(train_dataset) - train_size

# A dedicated, explicitly seeded generator makes the split identical on every
# run, independent of the state of the global RNG.
train_dataset, val_dataset = random_split(
    train_dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(42),
)

# Display dataset sizes
len(train_dataset), len(val_dataset), len(test_dataset)

Visualize a sample from the training dataset:

In [None]:
image_idx = 10

# Load the sample once and reuse both tensors for the two panels
sample_image, sample_mask = train_dataset[image_idx]

# Left panel: the image, moved to height x width x channels and cast back to
# 8-bit integers for display
plt.subplot(1, 2, 1)
plt.imshow(sample_image.permute(1, 2, 0).numpy().astype(np.uint8))

# Right panel: the single-channel mask rendered in grayscale
plt.subplot(1, 2, 2)
plt.imshow(sample_mask.permute(1, 2, 0).numpy().astype(np.uint8), cmap='gray')

Create data loaders for batch processing:


In [None]:
# Number of samples per mini-batch, shared by all three loaders
batch_size = 64

# Only the training loader shuffles; validation and test keep a fixed order
train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
)
val_dataloader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,
)
test_dataloader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
)

# Show the number of samples in each split
len(train_dataset), len(val_dataset), len(test_dataset)

Check shapes of images and masks:

In [None]:
# Channels-last shapes of the image and the mask of training sample 100
tuple(tensor.permute(1, 2, 0).shape for tensor in train_dataset[100])

## The U-Net model architecture

First, define the double convolutional block used in U-Net.

**Overview of the DoubleConv Block:**
  - Implements two consecutive convolutional layers for feature extraction.
  - Each convolution is followed by batch normalization and a ReLU activation.

**Attributes Defined in the `__init__` Method:**
  - **`conv` (nn.Sequential):**
    - A sequential container that stacks layers in the following order:
      - **First Convolution:**
        - `nn.Conv2d` layer transforming input channels to output channels.
        - Uses a 3×3 kernel, stride of 1, and padding of 1 to preserve spatial dimensions.
        - `bias=False`, because the batch normalization that follows has its own learnable shift, which makes a separate convolution bias redundant.
      - **First Batch Normalization:**
        - `nn.BatchNorm2d` applied to the output channels.
      - **First Activation:**
        - `nn.ReLU` with `inplace=True` for non-linear activation.
      - **Second Convolution:**
        - Another `nn.Conv2d` layer that processes the already transformed features.
        - Keeps the number of channels constant (from output channels to output channels).
        - Uses the same kernel size, stride, and padding as the first convolution.
      - **Second Batch Normalization:**
        - Normalizes the features again to stabilize training.
      - **Second Activation:**
        - A second `nn.ReLU` activation to introduce additional non-linearity.

**Initialization Process:**
  - The `__init__` method takes two parameters:
    - `in_channels`: Number of channels in the input tensor.
    - `out_channels`: Desired number of channels after the convolutional operations.
  - These parameters are used to construct the two convolutional blocks within the `nn.Sequential` container.

**Forward Pass:**
  - The `forward` method applies the sequential block (`self.conv`) to the input tensor `x`.
  - Returns the output after processing through both convolutional layers, their corresponding batch normalization, and ReLU activations.

In [None]:
class DoubleConv(nn.Module):
    """
    Two stacked Conv -> BatchNorm -> ReLU stages, the basic U-Net building block.

    Both convolutions use a 3x3 kernel with stride 1 and padding 1, so the
    spatial size of the input is preserved; only the channel count changes.

    Attributes:
        conv (nn.Sequential): The six layers, in execution order
    """
    def __init__(self, in_channels, out_channels):
        """
        Build the two convolution stages.

        Args:
            in_channels (int): Channels of the incoming tensor
            out_channels (int): Channels produced by both stages
        """
        super().__init__()

        stages = []
        # The first stage maps in_channels -> out_channels, the second keeps
        # out_channels -> out_channels.
        for stage_in in (in_channels, out_channels):
            stages.append(
                nn.Conv2d(
                    stage_in,
                    out_channels,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    bias=False,
                )
            )
            stages.append(nn.BatchNorm2d(out_channels))
            stages.append(nn.ReLU(inplace=True))

        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Run the input through both convolution stages and return the result"""
        features = self.conv(x)
        return features

Define the complete U-Net architecture.

**Overview of the U-Net Architecture:**
  - Designed for image segmentation tasks.
  - Uses an encoder-decoder (downsampling-upsampling) structure.
  - Incorporates skip connections to fuse high-resolution features from the encoder with upsampled features in the decoder.

**Attributes Defined in the `__init__` Method:**
  - **`ups` (nn.ModuleList):**
    - Stores the decoder path as alternating pairs: an `nn.ConvTranspose2d` upsampling layer followed by a `DoubleConv` block.
  - **`downs` (nn.ModuleList):**
    - Stores the encoder path: one `DoubleConv` block per level (the downsampling itself is performed by `pool`).
  - **`pool` (nn.MaxPool2d):**
    - A max pooling layer with kernel size 2 and stride 2 to reduce spatial dimensions during encoding.
  - **`bottleneck` (DoubleConv):**
    - A convolutional block at the lowest part of the U, connecting the encoder and decoder.
  - **`final_conv` (nn.Conv2d):**
    - A 1×1 convolution that maps the final feature maps to the desired number of output channels (e.g., a binary mask).

**Initialization Process (`__init__` Method):**
  - **Input Parameters:**
    - `in_channels`: Number of channels in the input image (default is 3 for RGB images).
    - `out_channels`: Number of channels in the output segmentation mask (default is 1 for binary segmentation).
    - `features`: A list of integers defining the number of feature maps at each level of the U-Net.
  - **Encoder Path Setup:**
    - Iterates over the `features` list.
    - For each feature value:
      - Adds a `DoubleConv` layer (two 3×3 convolutions, each followed by batch normalization and ReLU) to the `downs` list.
      - Updates `in_channels` to the current feature count for the next layer.
  - **Decoder Path Setup:**
    - Iterates over the reversed `features` list.
    - For each feature value:
      - Adds an upsampling layer using `nn.ConvTranspose2d`:
        - This layer upsamples the feature maps (doubling the spatial dimensions).
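        - The number of input channels to this layer is `feature * 2`, the channel count produced by the bottleneck or by the previous (deeper) decoder level; it outputs `feature` channels.
      - Adds another `DoubleConv` layer to refine the features after concatenation. It also takes `feature * 2` input channels, because the upsampled map (`feature` channels) is concatenated with the skip connection (`feature` channels).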
  - **Bottleneck and Final Convolution:**
    - The bottleneck is a `DoubleConv` that processes the most compressed feature representation.
    - The `final_conv` layer reduces the number of channels to `out_channels` using a 1×1 convolution, producing the final segmentation mask.

**Forward Pass (`forward` Method):**
  - **Encoder Phase (Downsampling):**
    - Iterates over each module in the `downs` list:
      - Applies the `DoubleConv` block.
      - Stores the output in a `skip_connections` list for later use.
      - Applies max pooling (`self.pool`) to reduce the spatial dimensions before passing to the next block.
  - **Bottleneck Processing:**
    - Processes the pooled feature map through the `bottleneck` layer to extract deep features.
  - **Decoder Phase (Upsampling):**
    - Reverses the `skip_connections` list so that the last encoded features are used first in the decoder.
    - Processes the feature maps in pairs (upsampling layer followed by a `DoubleConv`):
      - **Upsampling:**
        - Uses the `nn.ConvTranspose2d` layer to upsample the feature maps.
      - **Skip Connection:**
        - Retrieves the corresponding encoder output.
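        - If there is a mismatch in shape, the upsampled feature map is resized to match the skip connection. This happens whenever a pooled level had an odd size: with 180×180 inputs, a 45×45 map is pooled to 22×22 and upsampled back to 44×44, one pixel short of the 45×45 skip connection.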
      - **Concatenation & Refinement:**
        - Concatenates the skip connection with the upsampled features along the channel dimension.
        - Applies the subsequent `DoubleConv` layer to fuse and refine the combined features.
  - **Final Output:**
    - Passes the refined feature maps through the `final_conv` (1×1 convolution) to produce the final segmentation mask.

In [None]:
class UNet(nn.Module):
    """
    U-Net architecture for image segmentation.

    The U-Net consists of an encoder (downsampling) path and
    a decoder (upsampling) path with skip connections between them.

    Attributes:
        ups (nn.ModuleList): Decoder layers stored as alternating pairs of
            (ConvTranspose2d upsampling, DoubleConv refinement)
        downs (nn.ModuleList): Encoder DoubleConv blocks, one per level
        pool (nn.MaxPool2d): Pooling layer for downsampling
        bottleneck (DoubleConv): The bottleneck layer at the base of the "U"
        final_conv (nn.Conv2d): Final 1x1 convolution to produce output
    """
    def __init__(
            self, in_channels=3, out_channels=1, features=(64, 128, 256, 512),
    ):
        """
        Initialize the U-Net model.

        Args:
            in_channels (int): Number of input channels (default: 3 for RGB images)
            out_channels (int): Number of output channels (default: 1 for binary mask)
            features (sequence of int): Feature dimensions for each level of
                the U-Net, from the shallowest to the deepest level
        """
        super().__init__()
        # The default is an immutable tuple (no shared mutable default);
        # copying into a list also accepts any iterable of ints.
        features = list(features)

        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Down part of UNET (encoder path)
        for feature in features:
            self.downs.append(DoubleConv(in_channels, feature))
            in_channels = feature

        # Up part of UNET (decoder path)
        for feature in reversed(features):
            # The incoming tensor has feature*2 channels: it is the output of
            # the bottleneck or of the previous (deeper) decoder level.
            self.ups.append(
                nn.ConvTranspose2d(
                    feature * 2, feature, kernel_size=2, stride=2,
                )
            )
            # feature*2 inputs again: upsampled map concatenated with the skip
            self.ups.append(DoubleConv(feature * 2, feature))

        # Bottleneck at the bottom of the U
        self.bottleneck = DoubleConv(features[-1], features[-1] * 2)
        # Final 1x1 convolution to produce output mask
        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

    def forward(self, x):
        """
        Forward pass through the U-Net model.

        Args:
            x (torch.Tensor): Input image tensor

        Returns:
            torch.Tensor: Output of the final 1x1 convolution; no activation
            is applied here
        """
        # Store skip connections
        skip_connections = []

        # Encoder path (downsampling)
        for down in self.downs:
            x = down(x)
            skip_connections.append(x)
            x = self.pool(x)

        # Bottleneck
        x = self.bottleneck(x)
        # Reverse skip connections so the deepest one is consumed first
        skip_connections = skip_connections[::-1]

        # Decoder path (upsampling) with skip connections; self.ups holds
        # (upsample, DoubleConv) pairs, hence the step of 2
        for idx in range(0, len(self.ups), 2):
            x = self.ups[idx](x)
            skip_connection = skip_connections[idx // 2]

            # Pooling rounds odd sizes down, so the upsampled map can be
            # smaller than the skip connection. Only the spatial dimensions
            # can differ, so only those are compared.
            if x.shape[2:] != skip_connection.shape[2:]:
                x = TF.resize(x, size=skip_connection.shape[2:])

            # Concatenate skip connection with upsampled feature maps
            concat_skip = torch.cat((skip_connection, x), dim=1)
            x = self.ups[idx + 1](concat_skip)

        # Final convolution
        return self.final_conv(x)

Train the U-Net model:

In [None]:
# Determine device (GPU or CPU)
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Define training parameters
epoches = 3
# Initialize the model and move to device
model = UNet().to(device)
# Set optimizer with learning rate
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
# Define loss function - Binary Cross Entropy with Logits
loss_fn = nn.BCEWithLogitsLoss()

# Training loop
for epoch in range(epoches):
    model.train()
    train_loss = 0.0
    num_batches = 0

    # Loop through batches
    for images, masks in train_dataloader:
        optimizer.zero_grad()

        # Move data to device
        images, masks = images.to(device), masks.to(device)

        # BCEWithLogitsLoss expects targets in [0, 1], but the dataset yields
        # mask pixel intensities merely cast to float. Binarize them here.
        # NOTE(review): assumes 8-bit masks whose foreground is well above
        # 127 - confirm against the dataset.
        masks = (masks > 127).to(images.dtype)

        # Forward pass
        outputs = model(images)

        # Calculate loss
        loss = loss_fn(outputs, masks)

        # Backward pass and optimization
        loss.backward()
        optimizer.step()

        # Accumulate loss
        train_loss += loss.item()
        num_batches += 1

    # Calculate and print average loss for the epoch
    # (max(..., 1) avoids a ZeroDivisionError if the loader yields no batches)
    avg_loss = train_loss / max(num_batches, 1)
    print(f'Epoch {epoch+1}/{epoches} loss: {avg_loss}')

Model inference and visualization:

In [None]:
def predict_mask(model, image):
    """
    Generate a segmentation mask prediction for a single image.

    The model is switched to evaluation mode (and left in it) and the forward
    pass runs without gradient tracking.

    Args:
        model: Trained U-Net model
        image (torch.Tensor): Input image tensor without a batch dimension

    Returns:
        torch.Tensor: Raw model output on the CPU (no sigmoid or threshold is
        applied here), with the batch dimension removed and the leading
        channel dimension removed when it has size 1
    """
    model.eval()
    # Run on the device the model's parameters live on instead of relying on
    # a module-level `device` variable; a parameter-free model leaves the
    # image where it is.
    first_param = next(model.parameters(), None)
    with torch.no_grad():
        if first_param is not None:
            image = image.to(first_param.device)
        return model(image.unsqueeze(0)).squeeze(0).squeeze(0).to('cpu')

# Select an image from the test dataset
image_idx = 17

# Get the image and its ground truth mask
x, y = test_dataset[image_idx]
# Generate prediction (raw model output, one value per pixel)
y_pred = predict_mask(model, x)

# Convert tensors to numpy arrays for visualization
image = x.permute(1,2,0).detach().numpy().astype(np.uint8)
mask = y.permute(1,2,0).detach().numpy().astype(np.uint8)
# The model is trained with BCEWithLogitsLoss, so its outputs are logits:
# logit > 0 is equivalent to sigmoid(logit) > 0.5. Threshold BEFORE casting;
# casting raw logits to uint8 would truncate them and wrap negative values.
y_mask = (y_pred > 0).detach().numpy().astype(np.uint8)

# Visualize the original image, ground truth mask, and predicted mask
plt.figure(figsize=(12, 6))

plt.subplot(1, 3, 1)
plt.title('Input Image')
plt.imshow(image)

plt.subplot(1, 3, 2)
plt.title('Actual Mask')
plt.imshow(mask, cmap='gray')

plt.subplot(1, 3, 3)
plt.title('Predicted Mask')
# Foreground is 1 (white), same polarity as the ground-truth panel
plt.imshow(y_mask, cmap='gray')