
## Project: Cell image segmentation

Contact: Elena Casiraghi (University of Milan, elena.casiraghi@unimi.it)

Cell segmentation is usually the first step for downstream single-cell analysis in microscopy image-based biology and biomedical research. Deep learning has been widely used for cell-image segmentation.
The CellSeg competition aims to benchmark cell segmentation methods that can be applied to various microscopy images across multiple imaging platforms and tissue types. The challenge organizers provide both labeled and unlabeled images.
The “2018 Data Science Bowl” Kaggle competition provides cell images and their masks for training cell/nuclei segmentation models.

In 2022, another [Cell Segmentation challenge was proposed at NeurIPS](https://neurips22-cellseg.grand-challenge.org/).
For interested readers, the competition proceedings have been published in [PMLR](https://proceedings.mlr.press/v212/).

### Project Description

In biomedical image processing, segmentation is commonly performed with U-Nets [1,2].

A U-Net consists of an encoder (the contracting path), a series of convolution and max-pooling layers that progressively reduce the spatial resolution while increasing the number of feature channels, followed by a decoder (the expanding path), a series of upsampling (e.g. transposed convolution) and convolution layers that restore the spatial resolution. Encoder and decoder are joined by a bottleneck, the block that operates at the lowest spatial resolution; in the original U-Net [1] it is also the block with the most channels (1024).
The key innovation of U-Net is the addition of skip connections that concatenate the feature maps of each contracting level with the corresponding level of the expanding path, allowing the network to recover fine-grained details lost during downsampling.

<img src='https://production-media.paperswithcode.com/methods/Screen_Shot_2020-07-07_at_9.08.00_PM_rpNArED.png' width="400"/>
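To make the shape bookkeeping concrete, here is a small pure-Python sketch (not part of the original notebook) that traces spatial size and channel count through a five-level U-Net. It assumes padded 3x3 convolutions, 2x2 max pooling and 2x2 stride-2 transposed convolutions, with the channel widths of the original paper [1]; `unet_shapes` is a hypothetical helper name.

```python
def unet_shapes(size, channels=(64, 128, 256, 512, 1024)):
    """Trace spatial size and channel count through a padded U-Net.

    Assumes 3x3 convolutions with padding=1 (size preserved), 2x2 max pooling
    (size halved with floor division) and 2x2 stride-2 transposed convolutions
    (size doubled). The last entry of `channels` is the bottleneck.
    """
    encoder = []  # (spatial size, channels) after each double convolution
    for level, ch in enumerate(channels):
        encoder.append((size, ch))
        if level < len(channels) - 1:
            size //= 2  # max pooling
    decoder = []  # (upsampled size, skip size, channels after concat, output channels)
    for skip_size, skip_ch in reversed(encoder[:-1]):
        up_size = size * 2  # transposed convolution
        decoder.append((up_size, skip_size, 2 * skip_ch, skip_ch))
        size = skip_size  # assume a mismatch is resized to the skip size before concatenation
    return encoder, decoder


enc, dec = unet_shapes(512)
print(enc)  # [(512, 64), (256, 128), (128, 256), (64, 512), (32, 1024)]
print(dec)  # [(64, 64, 1024, 512), (128, 128, 512, 256), (256, 256, 256, 128), (512, 512, 128, 64)]
```

With a 512x512 input every upsampled map matches its skip connection exactly. With 500x500, the second decoder level produces 124 pixels against a 125-pixel skip, which is why UNet_1 below resizes the upsampled map before concatenating.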


At this [link](https://rpubs.com/eR_ic/unet) you can find a tutorial implementation of basic U-Nets, and at this [link](https://github.com/zhixuhao/unet) a Keras implementation.  
Implementations of more advanced U-Nets are also available, e.g. [UNet++](https://github.com/MrGiovanni/UNetPlusPlus),
and the baseline models provided by the CellSeg organizers: [https://neurips22-cellseg.grand-challenge.org/baseline-and-tutorial/](https://neurips22-cellseg.grand-challenge.org/baseline-and-tutorial/)


### Project aim

The aim of the project is to download the *gray-level* cell images (.tiff or .tif files) from the [CellSeg](https://neurips22-cellseg.grand-challenge.org/dataset/) competition and assess the performance of a U-Net, or any other deep model, for cell segmentation.
We suggest using gray-level images to obtain a model that is specialized on a subclass of images.
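A minimal sketch for selecting the gray-level subset, assuming each file has been loaded as a channels-last NumPy array (for example with `cv2.imread(path, cv2.IMREAD_UNCHANGED)`, as done later in this notebook). `is_gray_level` is a hypothetical helper: it treats 2-D arrays, single-channel arrays, and multi-channel arrays whose channels are all identical as gray-level.

```python
import numpy as np


def is_gray_level(arr: np.ndarray) -> bool:
    """Return True if a channels-last image array carries a single intensity channel."""
    if arr.ndim == 2:  # (H, W)
        return True
    if arr.ndim == 3 and arr.shape[2] == 1:  # (H, W, 1)
        return True
    if arr.ndim == 3:  # (H, W, C): gray content stored as RGB has identical channels
        return all(np.array_equal(arr[..., 0], arr[..., c]) for c in range(1, arr.shape[2]))
    return False
```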

Students are not restricted to U-Nets: other models are welcome, e.g. the transformer-based models on the [leaderboard](https://neurips22-cellseg.grand-challenge.org/evaluation/testing/leaderboard/) may also be tested.
Students are free to choose any model, as long as they are able to explain their rationale, architecture, strengths and weaknesses.



### References

[1] Ronneberger, O., Fischer, P., Brox, T. (2015). U-Net: Convolutional Networks for Biomedical Image Segmentation. In: Navab, N., Hornegger, J., Wells, W., Frangi, A. (eds) Medical Image Computing and Computer-Assisted Intervention – MICCAI 2015. MICCAI 2015. Lecture Notes in Computer Science, vol. 9351. Springer, Cham. https://doi.org/10.1007/978-3-319-24574-4_28

[2] Long, F. Microscopy cell nuclei segmentation with enhanced U-Net. BMC Bioinformatics 21, 8 (2020). https://doi.org/10.1186/s12859-019-3332-1


## Initialization

In [None]:
!pip install --upgrade gdown
!pip install torch
!pip install torchvision
!pip install opencv-contrib-python
!pip install torchsummary

In [2]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import os
import tempfile
from typing import Callable, List, Tuple
from torchvision import transforms
from PIL import Image

In [3]:
TRAIN_PATH = 'data_train/Training-labeled/images/'
TRAIN_LABELS_PATH = 'data_train/Training-labeled/labels/'

TEST_PATH = 'data_test/Testing/Public/images/'
TEST_LABELS_PATH = 'data_test/Testing/Public/labels/'

VAL_PATH = 'data_val/Tuning/images/'
VAL_LABELS_PATH = 'data_val/Tuning/labels/'

use_cuda = torch.cuda.is_available()
if use_cuda:
  device = torch.device("cuda")
  dataloader_kwargs = {"batch_size": 2, "shuffle": True, "pin_memory": True}
else:
  device = torch.device("cpu")
  dataloader_kwargs = {"batch_size": 2}


### Data preparation
[Browse the data](https://drive.google.com/drive/folders/1MaJibsHYitCPOltxVzYjr3rm5s9Vpjpv)

In [None]:
!curl -o data_test.zip "https://zenodo.org/records/10719375/files/Testing.zip?download=1"
!unzip -d data_test data_test.zip

In [None]:
!curl -o data_train.zip "https://zenodo.org/records/10719375/files/Training-labeled.zip?download=1"
!unzip -d data_train data_train.zip

In [None]:
!curl -o data_val.zip "https://zenodo.org/records/10719375/files/Tuning.zip?download=1"
!unzip -d data_val data_val.zip

In [7]:
from typing import Callable, List, Optional, Tuple


class PadToSquare:
    """Center-pad a PIL image to a square, then resize it to (size, size)."""
    def __init__(self, size=512, fill=0,
                 interpolation=transforms.InterpolationMode.BILINEAR):
        self.size = size
        self.fill = fill
        # Use InterpolationMode.NEAREST for label masks so no new label values appear
        self.interpolation = interpolation

    def __call__(self, img):
        # Get current dimensions
        w, h = img.size

        # Split the padding evenly between the two sides
        max_dim = max(w, h)
        pad_w = (max_dim - w) // 2
        pad_h = (max_dim - h) // 2

        # Apply (left, top, right, bottom) padding to make the image square
        padding = (pad_w, pad_h, max_dim - w - pad_w, max_dim - h - pad_h)
        img = transforms.functional.pad(img, padding, fill=self.fill)

        # Resize to target size
        return transforms.functional.resize(
            img, (self.size, self.size), interpolation=self.interpolation)


def mask_to_binary_tensor(mask) -> torch.Tensor:
    """Map an instance label map (0 = background, k > 0 = k-th cell) to a
    (1, H, W) float tensor with 1.0 on cell pixels and 0.0 on background."""
    arr = np.array(mask)
    if arr.ndim == 3:  # multi-channel mask: foreground if any channel is non-zero
        arr = arr.any(axis=-1)
    return torch.from_numpy((arr > 0).astype(np.float32)).unsqueeze(0)


# Partially adapted from https://colab.research.google.com/github/mim-ml-teaching/public-dnn-2024-25/blob/master/docs/DNN-Lab-7-UNet-in-Pytorch-student-version.ipynb
class ImageTiffDataset(torch.utils.data.Dataset):
  def __init__(self,
               image_dir: str,
               target_dir: str,
               cache_dir: str,
               file_pairs: List[Tuple[str, str]],
               transform: Optional[Callable] = None,
               target_transform: Optional[Callable] = None):
    self.image_dir = image_dir
    self.target_dir = target_dir
    self.cache_dir = cache_dir
    self.file_pairs = file_pairs
    # Default image pipeline: 512x512, one gray channel, ToTensor scaling
    if transform is None:
      transform = transforms.Compose([
          PadToSquare(size=512),
          transforms.Grayscale(num_output_channels=1),
          transforms.ToTensor(),
      ])
    # Default mask pipeline: nearest-neighbour resize, then binarize the instance ids
    # (bilinear resizing would blend neighbouring ids into non-existent labels)
    if target_transform is None:
      target_transform = transforms.Compose([
          PadToSquare(size=512, fill=0, interpolation=transforms.InterpolationMode.NEAREST),
          mask_to_binary_tensor,
      ])
    self.transform = transform
    self.target_transform = target_transform

    # Preprocessed tensors are cached as .pt files in two subdirectories
    os.makedirs(os.path.join(self.cache_dir, "images"), exist_ok=True)
    os.makedirs(os.path.join(self.cache_dir, "target"), exist_ok=True)

  def _load_cached(self, path: str, cache_path: str, transform: Callable) -> torch.Tensor:
    if os.path.exists(cache_path):
      return torch.load(cache_path)
    # Use Image.open: the PIL.Image module has no load() function
    with Image.open(path) as img:
      tensor = transform(img)
    torch.save(tensor, cache_path)
    return tensor

  def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
    img_filename, target_filename = self.file_pairs[idx]
    img = self._load_cached(
        os.path.join(self.image_dir, img_filename),
        os.path.join(self.cache_dir, "images", img_filename + ".pt"),
        self.transform)
    target = self._load_cached(
        os.path.join(self.target_dir, target_filename),
        os.path.join(self.cache_dir, "target", target_filename + ".pt"),
        self.target_transform)
    return img, target

  def __len__(self) -> int:
    return len(self.file_pairs)

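def is_valid_tiff(path):
    """Return True if PIL can open and verify the file."""
    try:
        with Image.open(path) as img:
            img.verify()
        return True
    except Exception:  # a bare except would also swallow KeyboardInterrupt
        return False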

def find_image_mask_pairs(images_dir, masks_dir, max_pairs=101):
    """Pair every image with its '<name>_label.tif(f)' mask.

    max_pairs caps the number of pairs to keep experiments fast
    (101 reproduces the counts printed below); pass None to use all images.
    """
    print(f"Searching in {images_dir} and {masks_dir}")
    image_files = [f for f in os.listdir(images_dir) if f.endswith('.tiff') or f.endswith('.tif')]
    mask_files = set(f for f in os.listdir(masks_dir) if f.endswith('.tiff') or f.endswith('.tif'))

    pairs = []
    for img_file in image_files:
        if max_pairs is not None and len(pairs) >= max_pairs:
            break
        base_name = os.path.splitext(img_file)[0]

        found = False
        for suffix in ['_label.tiff', '_label.tif']:
            candidate_mask = base_name + suffix
            if candidate_mask in mask_files:
                img_path = os.path.join(images_dir, img_file)
                mask_path = os.path.join(masks_dir, candidate_mask)
                if is_valid_tiff(img_path) and is_valid_tiff(mask_path):
                    pairs.append((img_file, candidate_mask))
                else:
                    print(f"Invalid file: {img_file} or {candidate_mask}")
                found = True
                break

        if not found:
            print(f"No mask for image {img_file}")

    return pairs

def make_tiff_dataset(image_dir: str, target_dir: str, cache_dir: str):
    file_pairs = find_image_mask_pairs(image_dir, target_dir)
    return ImageTiffDataset(image_dir, target_dir, cache_dir, file_pairs)
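`PadToSquare` resizes with `transforms.functional.resize`, whose default interpolation is bilinear. That suits intensities but not the label files, which store one integer id per cell: bilinear resizing blends neighbouring ids into values that belong to no cell. The NumPy sketch below (hypothetical helper names, independent of torchvision) shows a mask-safe variant: center-pad to a square with the same padding split as `PadToSquare`, then resize by nearest-neighbour sampling at pixel centres, so every output pixel copies an existing label.

```python
import numpy as np


def pad_to_square(arr: np.ndarray, fill=0) -> np.ndarray:
    """Center-pad a 2-D array to a square (extra pixel goes to the bottom/right)."""
    h, w = arr.shape
    m = max(h, w)
    top, left = (m - h) // 2, (m - w) // 2
    return np.pad(arr, ((top, m - h - top), (left, m - w - left)), constant_values=fill)


def resize_nearest(arr: np.ndarray, size: int) -> np.ndarray:
    """Nearest-neighbour resize of a square 2-D array to (size, size).

    Output index i samples input index floor((i + 0.5) * n / size),
    so only label values already present in `arr` can appear.
    """
    n = arr.shape[0]
    idx = ((2 * np.arange(size) + 1) * n) // (2 * size)
    return arr[idx[:, None], idx[None, :]]
```

In torchvision itself, the equivalent is passing `interpolation=transforms.InterpolationMode.NEAREST` to `transforms.functional.resize`.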

In [8]:
train_dataset = make_tiff_dataset(TRAIN_PATH, TRAIN_LABELS_PATH, tempfile.mkdtemp())
test_dataset = make_tiff_dataset(TEST_PATH, TEST_LABELS_PATH, tempfile.mkdtemp())
val_dataset = make_tiff_dataset(VAL_PATH, VAL_LABELS_PATH, tempfile.mkdtemp())
dataloader_train = torch.utils.data.DataLoader(train_dataset, **dataloader_kwargs)
dataloader_test = torch.utils.data.DataLoader(test_dataset, **dataloader_kwargs)
dataloader_val = torch.utils.data.DataLoader(val_dataset, **dataloader_kwargs)

print('Train:', len(train_dataset))
print('Test:', len(test_dataset))
print('Val:', len(val_dataset))

Searching in data_train/Training-labeled/images/ and data_train/Training-labeled/labels/
Invalid file: cell_00502.tif or cell_00502_label.tiff
Invalid file: cell_00315.tiff or cell_00315_label.tiff
Invalid file: cell_00507.tif or cell_00507_label.tiff
Invalid file: cell_00301.tiff or cell_00301_label.tiff
Invalid file: cell_00304.tiff or cell_00304_label.tiff
Searching in data_test/Testing/Public/images/ and data_test/Testing/Public/labels/
Searching in data_val/Tuning/images/ and data_val/Tuning/labels/
Train: 101
Test: 30
Val: 58


## Basic U-Nets

In [9]:
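class UNetConvBlock(nn.Module):
 """3x3 convolution (reflect padding keeps H and W) followed by ReLU."""
 def __init__(self, in_channels: int, out_channels: int):
   super().__init__()
   self.layer = nn.Sequential(
       nn.Conv2d(
           in_channels,
           out_channels,
           kernel_size=3,
           padding=1,
           padding_mode='reflect'),  # dilation must be >= 1; the default (1) is used
       nn.ReLU() # Leaky ReLU?
   )

 def forward(self, x):
   return self.layer(x)

class UNetEncoderBlock(nn.Module):
 """Two conv blocks, optionally followed by 2x2 max pooling.

 forward returns (features, pooled): `features` (before pooling) feeds the skip
 connection, `pooled` feeds the next level (no pooling at the bottleneck).
 """
 def __init__(self, in_channels: int, out_channels: int, maxpool: bool = True):
   super().__init__()
   assert out_channels > in_channels
   self.convs = nn.Sequential(
       UNetConvBlock(in_channels, out_channels),
       UNetConvBlock(out_channels, out_channels)
   )
   self.pool = nn.MaxPool2d(2) if maxpool else nn.Identity()

 def forward(self, x):
   features = self.convs(x)
   return features, self.pool(features)

class UNetDecoderBlock(nn.Module):
 """2x2 transposed-conv upsampling, concatenation with the skip, two conv blocks."""
 def __init__(self, in_channels: int, out_channels: int):
   super().__init__()
   assert out_channels < in_channels
   self.up = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
   # After concatenating the skip connection the block sees 2 * out_channels channels
   self.convs = nn.Sequential(
       UNetConvBlock(2 * out_channels, out_channels),
       UNetConvBlock(out_channels, out_channels)
   )

 def forward(self, x, skip):
   x = self.up(x)
   if x.shape[2:] != skip.shape[2:]:  # H or W not divisible by 2**depth
     x = F.interpolate(x, size=skip.shape[2:], mode='bilinear', align_corners=False)
   return self.convs(torch.cat([x, skip], dim=1))

class UNet(nn.Module):
 def __init__(self, encoder_channels: List[int], in_channels: int = 1, out_channels: int = 1):
   super().__init__()
   assert len(encoder_channels) > 0
   # Expected for encoder_channels=[64, 128, 256, 512, 1024]:
   #   UNetEncoderBlock(1, 64), UNetEncoderBlock(64, 128), UNetEncoderBlock(128, 256),
   #   UNetEncoderBlock(256, 512), UNetEncoderBlock(512, 1024, maxpool=False)
   self.encoder = nn.ModuleList()
   for i, ch in enumerate(encoder_channels):
     is_bottleneck = i == len(encoder_channels) - 1
     self.encoder.append(UNetEncoderBlock(in_channels, ch, maxpool=not is_bottleneck))
     in_channels = ch

   # The decoder mirrors the encoder: 1024 -> 512, 512 -> 256, ..., 128 -> 64
   self.decoder = nn.ModuleList(
       UNetDecoderBlock(c_in, c_out)
       for c_in, c_out in zip(encoder_channels[:0:-1], encoder_channels[-2::-1]))

   # 1x1 convolution: first-level channels -> one logit per pixel (per class)
   self.head = nn.Conv2d(encoder_channels[0], out_channels, kernel_size=1)

 def forward(self, x):
   skips = []
   for block in self.encoder:
     features, x = block(x)
     skips.append(features)
   skips.pop()  # the bottleneck output is x itself, not a skip connection
   for block in self.decoder:
     x = block(x, skips.pop())
   return self.head(x)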



In [10]:
import cv2

img_path = "data_train/Training-labeled/images/cell_00302.tiff"
image = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
print(image)

[[43. 42. 41. ... 44. 42. 43.]
 [43. 40. 40. ... 42. 43. 42.]
 [43. 41. 40. ... 42. 43. 41.]
 ...
 [51. 48. 46. ... 42. 43. 42.]
 [50. 48. 47. ... 43. 44. 42.]
 [49. 49. 50. ... 45. 46. 46.]]
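The printed values are raw floating-point intensities (about 40 to 51 in the printed corners), not 8-bit values. A common way to bring such images into [0, 1] is per-image percentile normalization. The sketch below is one possible choice, not what the notebook currently does (its transform converts images to 8-bit gray level via `transforms.Grayscale`); `normalize_percentile` is a hypothetical helper and the 1st/99.8th percentile defaults are assumptions to tune.

```python
import numpy as np


def normalize_percentile(img: np.ndarray, low: float = 1.0, high: float = 99.8) -> np.ndarray:
    """Rescale intensities to [0, 1] using per-image percentiles.

    Values below the `low` percentile map to 0 and above `high` to 1,
    so a few saturated pixels do not compress the useful range.
    """
    img = img.astype(np.float32)
    lo, hi = (float(v) for v in np.percentile(img, [low, high]))
    if hi <= lo:  # constant image: avoid division by zero
        return np.zeros_like(img)
    return np.clip((img - lo) / (hi - lo), 0.0, 1.0)
```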


## Basic UNet (UNet_1) from https://rpubs.com/eR_ic/unet

In [None]:
# This code implements a U-Net model for semantic segmentation from the paper U-Net: Convolutional Networks for Biomedical Image Segmentation
import torch
import torch.nn as nn
import torchvision.transforms.functional

# Implement the double 3X3 convolution blocks
# The original paper did not use padding, but we will use padding to keep the image size the same

class double_convolution(nn.Module):
    """
    This class implements the double convolution block which consists of two 3X3 convolution layers,
    each followed by a ReLU activation function.

    """
    def __init__(self, in_channels, out_channels): # Initialize the class
        super().__init__() # Initialize the parent class

        # First 3X3 convolution layer
        self.first_cnn = nn.Conv2d(in_channels = in_channels, out_channels = out_channels, kernel_size = 3, padding = 1)
        self.act1 = nn.ReLU()

        # Second 3X3 convolution layer
        self.second_cnn = nn.Conv2d(in_channels = out_channels, out_channels = out_channels, kernel_size = 3, padding = 1)
        self.act2 = nn.ReLU()

    # Pass the input through the double convolution block
    def forward(self, x):
        x = self.first_cnn(x)
        x = self.act1(x)
        x = self.act2(self.second_cnn(x))
        return x


# Implement the Downsample block that occurs after each double convolution block
class down_sample(nn.Module):
    """
    This class implements the downsample block which consists of a Max Pooling layer with a kernel size of 2.
    The Max Pooling layer halves the image size reducing the spatial resolution of the feature maps
    while retaining the most important features.
    """
    def __init__(self):
        super().__init__()
        self.max_pool = nn.MaxPool2d(kernel_size = 2, stride = 2)

    # Pass the input through the downsample block
    def forward(self, x):
        x = self.max_pool(x)
        return x

# Implement the UpSample block that occurs in the decoder part of the network
class up_sample(nn.Module):
    """
    This class implements the upsample block which consists of a convolution transpose layer with a kernel size of 2.
    The convolution transpose layer doubles the image size increasing the spatial resolution of the feature maps
    while retaining the learned features.
    """
    def __init__(self, in_channels, out_channels):
        super().__init__()

        # Convolution transpose layer
        self.up_sample = nn.ConvTranspose2d(in_channels = in_channels, out_channels = out_channels, kernel_size = 2, stride = 2)

    # Pass the input through the upsample block
    def forward(self, x):
        x = self.up_sample(x)
        return x

# Implement the crop and concatenate block that occurs in the decoder part of the network
# This block concatenates the output of the upsample block with the output of the corresponding downsample block
# The output of the crop and concatenate block is then passed through a double convolution block
class crop_and_concatenate_fixed(nn.Module):
    """Memory-efficient crop and concatenate"""
    def forward(self, upsampled, bypass):
        # Use F.interpolate instead of torchvision resize (more memory efficient)
        if upsampled.shape[2:] != bypass.shape[2:]:
            upsampled = F.interpolate(upsampled, size=bypass.shape[2:],
                                    mode='bilinear', align_corners=False)
        return torch.cat([upsampled, bypass], dim=1)

# m = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
# input = torch.randn(1, 1024, 28, 28)
# m(input).shape

# m = nn.MaxPool2d(kernel_size = 2, stride = 2)
# xx = torch.randn(1, 1, 143, 143)
# m(xx).shape

## Implement the UNet architecture
class UNet(nn.Module):
    # in_channels: number of channels in the input image
    # out_channels: number of channels in the output image
    def __init__(self, in_channels, out_channels):
        super().__init__()

        # Define the contracting path: convolution blocks followed by downsample blocks
        self.down_conv = nn.ModuleList(double_convolution(in_chans, out_chans) for in_chans, out_chans in
                                       [(in_channels, 64), (64, 128), (128, 256), (256, 512)]) # Encoder double-convolution blocks

        self.down_samples = nn.ModuleList(down_sample() for _ in range(4))

        # Define the bottleneck layer
        self.bottleneck = double_convolution(in_channels = 512, out_channels = 1024)

        # Define the expanding path: upsample blocks followed by convolution blocks
        self.up_samples = nn.ModuleList(up_sample(in_chans, out_chans) for in_chans, out_chans in
                                        [(1024, 512), (512, 256), (256, 128), (128, 64)]) # List of upsample blocks

        self.concat = nn.ModuleList(crop_and_concatenate_fixed() for _ in range(4))

        self.up_conv = nn.ModuleList(double_convolution(in_chans, out_chans) for in_chans, out_chans in
                                        [(1024, 512), (512, 256), (256, 128), (128, 64)]) # List of convolution blocks

        # Final 1X1 convolution layer to produce the output segmentation map:
        # The primary purpose of 1x1 convolutions is to transform the channel dimension of the feature map,
        # while leaving the spatial dimensions unchanged.
        self.final_conv = nn.Conv2d(in_channels = 64, out_channels = out_channels, kernel_size = 1)

    # Pass the input through the UNet architecture
    def forward(self, x):
        # Pass the input through the contracting path
        skip_connections = [] # List to store the outputs of the downsample blocks
        for down_conv, down_sample in zip(self.down_conv, self.down_samples):
            x = down_conv(x)
            skip_connections.append(x)
            x = down_sample(x)

        # Pass the output of the contracting path through the bottleneck layer
        x = self.bottleneck(x)

        # Pass the output of the bottleneck layer through the expanding path
        skip_connections = skip_connections[::-1] # Reverse the list of skip connections
        for up_sample, concat, up_conv in zip(self.up_samples, self.concat, self.up_conv):
            x = up_sample(x)
            x = concat(x, skip_connections.pop(0)) # Remove the first element from the list of skip connections
            x = up_conv(x)

        # Pass the output of the expanding path through the final convolution layer
        x = self.final_conv(x)
        return x

# Sanity check for the model
import torchsummary
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNet(in_channels = 3, out_channels = 1).to(device)
dummy_input = torch.randn((1, 3, 32, 32)).to(device)
mask = model(dummy_input)
print(mask.shape)  # expected: torch.Size([1, 1, 32, 32])

# See how data flows through the network
# (torchsummary defaults to device="cuda", so pass the actual device type)
torchsummary.summary(model, input_size = (3, 32, 32), device = device.type)
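To cross-check the parameter count that `torchsummary.summary` reports, the arithmetic can be done by hand: a convolution or transposed convolution with `groups=1` has `in_channels * out_channels * k * k` weights plus `out_channels` biases. The helper names below are hypothetical; the layer list mirrors the `UNet` class above (ReLU, max pooling and the concatenation have no parameters).

```python
def conv2d_params(in_ch: int, out_ch: int, k: int, bias: bool = True) -> int:
    """Learnable parameters of nn.Conv2d / nn.ConvTranspose2d with groups=1."""
    return in_ch * out_ch * k * k + (out_ch if bias else 0)


def double_conv_params(in_ch: int, out_ch: int) -> int:
    """Two 3x3 convolutions, as in double_convolution."""
    return conv2d_params(in_ch, out_ch, 3) + conv2d_params(out_ch, out_ch, 3)


def unet1_params(in_channels: int = 3, out_channels: int = 1) -> int:
    """Total parameters of UNet_1 (64-128-256-512 encoder, 1024-channel bottleneck)."""
    chans = [in_channels, 64, 128, 256, 512]
    total = sum(double_conv_params(a, b) for a, b in zip(chans[:-1], chans[1:]))
    total += double_conv_params(512, 1024)  # bottleneck
    for c in (1024, 512, 256, 128):
        total += conv2d_params(c, c // 2, 2)    # transposed conv c -> c/2
        total += double_conv_params(c, c // 2)  # after concatenation: c -> c/2
    total += conv2d_params(64, out_channels, 1)  # final 1x1 convolution
    return total


print(unet1_params(3, 1))  # 31031745 (the sanity-check model)
print(unet1_params(1, 1))  # 31030593 (the gray-level model trained below)
```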

## Training UNet_1

In [None]:

def calculate_dice(pred, target):
    """Calculate the Dice Similarity Coefficient with proper edge case handling.

    Any non-zero value counts as foreground, so instance label maps
    (one integer id per cell) are handled as well as 0/1 masks.
    """
    # Ensure both are binary (bool); a bitwise AND on raw ids would be wrong
    pred = pred > 0
    target = target > 0

    # Calculate intersection and sums
    intersection = (pred & target).float().sum()
    pred_sum = pred.float().sum()
    target_sum = target.float().sum()

    # Handle edge cases:
    # 1. Both masks empty (all background) - this is a perfect prediction
    if pred_sum == 0 and target_sum == 0:
        return 1.0

    # 2. One mask empty, other has positives - worst case
    if pred_sum == 0 or target_sum == 0:
        return 0.0

    # 3. Normal case - both have positive pixels
    dice = (2.0 * intersection) / (pred_sum + target_sum)

    return dice.item()


def train(model, device, train_loader, optimizer, epoch, log_interval=10):
    model.train()
    train_loss = 0
    total_dice = 0
    num_samples = 0

    # Track different types of samples
    samples_both_empty = 0
    samples_target_empty = 0
    samples_pred_empty = 0
    samples_both_positive = 0

    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)

        optimizer.zero_grad()
        output = model(data)

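        # Binarize the masks: any non-zero value (e.g. a cell instance id) is foreground
        target_binary = (target.squeeze(1) > 0).float()

        # Calculate loss (BCE targets must lie in [0, 1])
        loss = F.binary_cross_entropy_with_logits(output.squeeze(1), target_binary)

        # Get predictions
        pred = (torch.sigmoid(output) > 0.5).squeeze(1)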

        loss.backward()
        optimizer.step()

        train_loss += loss.item()

        # Calculate Dice for each image in the batch
        batch_dice = 0
        for i in range(pred.shape[0]):
            dice = calculate_dice(pred[i], target_binary[i])
            batch_dice += dice

            # Track sample types
            pred_positive = pred[i].sum().item() > 0
            target_positive = target_binary[i].sum().item() > 0

            if not pred_positive and not target_positive:
                samples_both_empty += 1
            elif not pred_positive and target_positive:
                samples_pred_empty += 1
            elif pred_positive and not target_positive:
                samples_target_empty += 1
            else:
                samples_both_positive += 1

        total_dice += batch_dice
        num_samples += pred.shape[0]

        if batch_idx % log_interval == 0:
            current_avg_dice = batch_dice / pred.shape[0]
            pred_positives = pred.sum().item()
            target_positives = target_binary.sum().item()
            print(f'Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} '
                  f'({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}, '
                  f'Batch Dice: {current_avg_dice:.4f}, '
                  f'Pred+: {pred_positives}, Target+: {target_positives}')

    avg_loss = train_loss / len(train_loader)
    avg_dice = total_dice / num_samples

    print(f'\nTrain Epoch: {epoch} Summary:')
    print(f'  Average loss: {avg_loss:.4f}, Average Dice: {avg_dice:.4f}')
    print(f'  Sample distribution:')
    print(f'    Both empty (background only): {samples_both_empty} ({100*samples_both_empty/num_samples:.1f}%)')
    print(f'    Target has object, pred empty: {samples_pred_empty} ({100*samples_pred_empty/num_samples:.1f}%)')
    print(f'    Target empty, pred has false positives: {samples_target_empty} ({100*samples_target_empty/num_samples:.1f}%)')
    print(f'    Both have positives: {samples_both_positive} ({100*samples_both_positive/num_samples:.1f}%)')

    return avg_loss, avg_dice


def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    total_dice = 0
    num_samples = 0

    # Track different types of samples
    samples_both_empty = 0
    samples_target_empty = 0
    samples_pred_empty = 0
    samples_both_positive = 0

    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)

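            # Binarize the masks: any non-zero value (e.g. a cell instance id) is foreground
            target_binary = (target.squeeze(1) > 0).float()

            # Calculate loss (BCE targets must lie in [0, 1])
            loss = F.binary_cross_entropy_with_logits(
                output.squeeze(1),
                target_binary,
                reduction='mean'
            )

            # Get predictions
            pred = (torch.sigmoid(output) > 0.5).squeeze(1)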

            test_loss += loss.item()

            # Calculate Dice for each image in the batch
            for i in range(pred.shape[0]):
                dice = calculate_dice(pred[i], target_binary[i])
                total_dice += dice

                # Track sample types
                pred_positive = pred[i].sum().item() > 0
                target_positive = target_binary[i].sum().item() > 0

                if not pred_positive and not target_positive:
                    samples_both_empty += 1
                elif not pred_positive and target_positive:
                    samples_pred_empty += 1
                elif pred_positive and not target_positive:
                    samples_target_empty += 1
                else:
                    samples_both_positive += 1

            num_samples += pred.shape[0]

    avg_loss = test_loss / len(test_loader)
    avg_dice = total_dice / num_samples

    print(f'\nTest set Results:')
    print(f'  Average loss: {avg_loss:.4f}, Average Dice: {avg_dice:.4f}')
    print(f'  Sample distribution:')
    print(f'    Both empty (background only): {samples_both_empty} ({100*samples_both_empty/num_samples:.1f}%)')
    print(f'    Target has object, pred empty: {samples_pred_empty} ({100*samples_pred_empty/num_samples:.1f}%)')
    print(f'    Target empty, pred has false positives: {samples_target_empty} ({100*samples_target_empty/num_samples:.1f}%)')
    print(f'    Both have positives: {samples_both_positive} ({100*samples_both_positive/num_samples:.1f}%)\n')

    return avg_loss, avg_dice
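`calculate_dice` scores the overlap between predicted and true foreground. For reference, here is a NumPy sketch of Dice together with the closely related IoU (Jaccard index), using the same empty-mask convention; `dice_and_iou` is a hypothetical helper name. The two are monotonically related, Dice = 2 IoU / (1 + IoU), so they agree on which of two predictions is better for a single image, although averages over many images can rank models differently.

```python
import numpy as np


def dice_and_iou(pred: np.ndarray, target: np.ndarray) -> tuple:
    """Dice = 2|P & T| / (|P| + |T|) and IoU = |P & T| / |P | T| for binary masks.

    Both masks empty -> (1.0, 1.0), as in calculate_dice.
    Any non-zero value in `target` counts as foreground.
    """
    p = pred.astype(bool)
    t = target > 0
    inter = np.logical_and(p, t).sum()
    total = p.sum() + t.sum()
    if total == 0:
        return 1.0, 1.0
    dice = 2.0 * inter / total
    iou = inter / (total - inter)  # |P | T| = |P| + |T| - |P & T|
    return float(dice), float(iou)
```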

In [None]:
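batch_size = 2        # must match dataloader_kwargs["batch_size"] set above
test_batch_size = 10  # not used by the loaders above
epochs = 5
lr = 1e-4             # learning rate of the Adam optimizer below
seed = 1
log_interval = 10
test_size = 5

torch.manual_seed(seed)  # reproducible weight initialization and shuffling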

In [None]:
torch.cuda.empty_cache()

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# in_channels=1 for the gray-level images; out_channels=1 gives one logit per pixel
# (cell vs background), matching binary_cross_entropy_with_logits in train()/test()
model = UNet(in_channels=1, out_channels=1).to(device)

# Use the Adam optimizer with a small learning rate (good starting point for U-Net training)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# Set the number of training epochs
num_epochs = 5

# Training loop over all epochs: monitor the validation (Tuning) set and
# keep the test set for the final evaluation only
for epoch in range(1, num_epochs + 1):
    torch.cuda.empty_cache()
    train(model, device, dataloader_train, optimizer, epoch, log_interval)
    test(model, device, dataloader_val)

In [None]:
import os

path = "data_train/Training-labeled/images/cell_00311.tiff"
print(f"File exists: {os.path.exists(path)}")
print(f"File size: {os.path.getsize(path)} bytes")

import cv2

img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
print(type(img), img.shape if img is not None else "Can't load image")

In [None]:
import os
from PIL import Image, UnidentifiedImageError

def validate_tiff_dataset(image_dir, label_dir):
    broken_images = []
    broken_labels = []

    image_files = sorted(os.listdir(image_dir))
    label_files = sorted(os.listdir(label_dir))

    print(f"Checking {len(image_files)} images and {len(label_files)} masks...")

    for filename in image_files:
        img_path = os.path.join(image_dir, filename)
        try:
            with Image.open(img_path) as img:
                img.verify()  # only verifies the file (does not fully decode it)
        except (UnidentifiedImageError, OSError, FileNotFoundError) as e:
            print(f"[IMAGE ERROR] {filename}: {e}")
            broken_images.append(filename)

    for filename in label_files:
        label_path = os.path.join(label_dir, filename)
        try:
            with Image.open(label_path) as img:
                img.verify()
        except (UnidentifiedImageError, OSError, FileNotFoundError) as e:
            print(f"[MASK ERROR] {filename}: {e}")
            broken_labels.append(filename)

    print("\nSummary:")
    print(f"- Broken images: {len(broken_images)}")
    print(f"- Broken masks: {len(broken_labels)}")

    return broken_images, broken_labels

bad_imgs, bad_labels = validate_tiff_dataset(TRAIN_PATH, TRAIN_LABELS_PATH)

In [None]:
Image.open('data_train/Training-labeled/images/cell_00501.tif').verify()

In [None]:
target_np = cv2.imread('data_train/Training-labeled/labels/cell_00490_label.tiff', cv2.IMREAD_UNCHANGED)
print(target_np)

In [None]:
import cv2

for img in bad_imgs:
    path = os.path.join(TRAIN_PATH, img)
    image = cv2.imread(path, cv2.IMREAD_UNCHANGED)
    if image is None:
        print(f"{img} cannot be opened with OpenCV")
    else:
        print(f"{img} OK: OpenCV can be used instead of PIL")
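Since some files that PIL rejects may still be readable by OpenCV, a loader could try several readers in turn. A sketch under that assumption; `load_with_fallback`, `read_with_pil` and `read_with_cv2` are hypothetical names. `cv2.imread` returns `None` instead of raising when it cannot decode a file, so `None` is treated as a failure too.

```python
from typing import Callable, Sequence


def load_with_fallback(path: str, readers: Sequence[Callable[[str], object]]):
    """Return the result of the first reader that succeeds on `path`.

    A reader fails if it raises an exception or returns None.
    """
    errors = []
    for reader in readers:
        name = getattr(reader, "__name__", repr(reader))
        try:
            result = reader(path)
        except Exception as e:  # e.g. PIL.UnidentifiedImageError, OSError
            errors.append(f"{name}: {e}")
            continue
        if result is not None:
            return result
        errors.append(f"{name}: returned None")
    raise OSError(f"No reader could open {path}: " + "; ".join(errors))


def read_with_pil(path):
    import numpy as np
    from PIL import Image
    with Image.open(path) as img:
        return np.array(img)


def read_with_cv2(path):
    import cv2
    return cv2.imread(path, cv2.IMREAD_UNCHANGED)
```

For example, `load_with_fallback(os.path.join(TRAIN_PATH, bad_imgs[0]), [read_with_pil, read_with_cv2])` would try PIL first and fall back to OpenCV.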

## Testing UNet_1

In [None]:
@torch.no_grad()
def test_unet(model, dataloader, device, threshold=0.5):
    """Pixel accuracy for a model that outputs one logit per pixel."""
    model.eval()
    total_correct = 0
    total_pixels = 0

    for images, masks in dataloader:
        images = images.to(device)                     # (B, 1, H, W)
        masks = masks.to(device).squeeze(1) > 0        # (B, H, W), True = cell

        outputs = model(images)                        # (B, 1, H, W) logits
        # argmax over a single channel is always 0, so threshold the sigmoid instead
        preds = (torch.sigmoid(outputs) > threshold).squeeze(1)  # (B, H, W)

        total_correct += (preds == masks).sum().item()
        total_pixels += masks.numel()

    accuracy = 100.0 * total_correct / total_pixels
    print(f"Test Accuracy: {accuracy:.2f}%")
    return accuracy

import matplotlib.pyplot as plt

@torch.no_grad()
def visualize_predictions(model, dataloader, device, num_examples=3, threshold=0.5):
    model.eval()
    for images, masks in dataloader:
        images = images.to(device)
        outputs = model(images)
        preds = (torch.sigmoid(outputs) > threshold).squeeze(1)

        for i in range(min(num_examples, images.size(0))):
            img = images[i, 0].cpu().numpy()        # (H, W): imshow rejects (H, W, 1)
            mask = masks[i].squeeze(0).cpu().numpy()
            pred = preds[i].cpu().numpy()

            fig, axs = plt.subplots(1, 3, figsize=(12, 4))
            axs[0].imshow(img, cmap="gray")
            axs[0].set_title("Input Image")
            axs[1].imshow(mask, cmap="gray")
            axs[1].set_title("Ground Truth")
            axs[2].imshow(pred, cmap="gray")
            axs[2].set_title("Prediction")
            plt.show()
        break  # only the first batch

In [None]:
test_unet(model, dataloader_test, device)
visualize_predictions(model, dataloader_test, device)
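Pixel accuracy is dominated by the background, and a binary mask does not reveal whether touching cells were separated. The NeurIPS 2022 CellSeg challenge instead ranks methods by an instance-level F1 score at an IoU threshold of 0.5 (see the official evaluation code for the exact protocol). The sketch below is a simplified, hypothetical version with greedy one-to-one matching; it expects integer label maps, so a binary prediction would first have to be split into instances, e.g. with `scipy.ndimage.label` or `cv2.connectedComponents`.

```python
import numpy as np


def instance_f1(gt: np.ndarray, pred: np.ndarray, iou_threshold: float = 0.5) -> float:
    """F1 over cell instances, matched one-to-one by IoU (greedy, highest IoU first).

    gt, pred: integer label maps of equal shape, 0 = background, k > 0 = instance k.
    """
    gt_ids = [i for i in np.unique(gt) if i != 0]
    pred_ids = [j for j in np.unique(pred) if j != 0]
    if not gt_ids and not pred_ids:
        return 1.0
    candidates = []  # (iou, gt_id, pred_id) for overlapping pairs above the threshold
    for i in gt_ids:
        g = gt == i
        for j in np.unique(pred[g]):
            if j == 0:
                continue
            p = pred == j
            iou = np.logical_and(g, p).sum() / np.logical_or(g, p).sum()
            if iou >= iou_threshold:
                candidates.append((iou, i, j))
    matched_gt, matched_pred = set(), set()
    for _, i, j in sorted(candidates, reverse=True):
        if i not in matched_gt and j not in matched_pred:
            matched_gt.add(i)
            matched_pred.add(j)
    tp = len(matched_gt)
    fp = len(pred_ids) - tp  # unmatched predictions
    fn = len(gt_ids) - tp    # missed cells
    return 2 * tp / (2 * tp + fp + fn)
```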

## ResidualAttentionUnet (UNet_2) from https://rpubs.com/eR_ic/unet

In [None]:
import torchvision

# Define a Residual block
class residual_block(nn.Module):
    """
    This class implements a residual block which consists of two convolution layers with group normalization
    """
    def __init__(self, in_channels, out_channels, n_groups = 8):
        super().__init__()
        # First convolution layer
        self.first_conv = nn.Conv2d(in_channels = in_channels, out_channels = out_channels, kernel_size = 3, padding = 1)
        self.first_norm = nn.GroupNorm(num_groups = n_groups, num_channels = out_channels)
        self.act1 = nn.SiLU() # Swish activation function

        # Second convolution layer
        self.second_conv = nn.Conv2d(in_channels = out_channels, out_channels = out_channels, kernel_size = 3, padding = 1)
        self.second_norm = nn.GroupNorm(num_groups = n_groups, num_channels = out_channels)
        self.act2 = nn.SiLU() # Swish activation function

        # If the number of input channels is not equal to the number of output channels,
        # then use a 1X1 convolution layer to compensate for the difference in dimensions
        # This allows the input to have the same dimensions as the output of the residual block
        if in_channels != out_channels:
            self.shortcut = nn.Conv2d(in_channels = in_channels, out_channels = out_channels, kernel_size = 1)
        else:
            # Pass the input as is
            self.shortcut = nn.Identity()

    # Pass the input through the residual block
    def forward(self, x):
        # Keep the input for the skip connection (named to avoid shadowing the built-in `input`)
        residual = x

        # First convolution -> first group norm -> SiLU
        x = self.act1(self.first_norm(self.first_conv(x)))

        # Second convolution -> second group norm -> SiLU
        x = self.act2(self.second_norm(self.second_conv(x)))

        # Add the (possibly 1x1-projected) input to the output of the second convolution layer
        # This is the residual (skip) connection
        x = x + self.shortcut(residual)
        return x

# Implement the DownSample block that occurs after each residual block
class down_sample(nn.Module):
    def __init__(self):
        super().__init__()
        self.max_pool = nn.MaxPool2d(kernel_size = 2, stride = 2)

    # Pass the input through the downsample block
    def forward(self, x):
        x = self.max_pool(x)
        return x

# Implement the UpSample block that occurs in the decoder path/expanding path
class up_sample(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()

        # Convolution transpose layer to upsample the input
        self.up_sample = nn.ConvTranspose2d(in_channels = in_channels, out_channels = out_channels, kernel_size = 2, stride = 2)

    # Pass the input through the upsample block
    def forward(self, x):
        x = self.up_sample(x)
        return x

# Implement the crop and concatenate layer
class crop_and_concatenate(nn.Module):
    def forward(self, upsampled, bypass):
        # Instead of cropping (as in the original U-Net), resize the upsampled feature map
        # to the spatial size of the bypass (skip) feature map so that the two can be concatenated
        upsampled = torchvision.transforms.functional.resize(upsampled, size = bypass.shape[2:], antialias=True)
        x = torch.cat([upsampled, bypass], dim = 1) # Concatenate along the channel dimension
        return x

# Implement an attention block
class attention_block(nn.Module):
    def __init__(self, skip_channels, gate_channels, inter_channels = None, n_groups = 8):
        super().__init__()

        if inter_channels is None:
            inter_channels = skip_channels // 2

        # W_g, i.e. the layer that operates on the gate signal: a transposed convolution that
        # upsamples the gate signal (2x) to the spatial size and channel count of the skip connection
        self.W_g = up_sample(in_channels = gate_channels, out_channels = skip_channels)
        #self.W_g_norm = nn.GroupNorm(num_groups = n_groups, num_channels = skip_channels)
        #self.W_g_act = nn.SiLU() # Swish activation function

        # W_x, i.e. the 1x1 convolution that projects the skip connection to the intermediate channels
        self.W_x = nn.Conv2d(in_channels = skip_channels, out_channels = inter_channels, kernel_size = 1)
        #self.W_x_norm = nn.GroupNorm(num_groups = n_groups, num_channels = inter_channels)
        #self.W_x_act = nn.SiLU() # Swish activation function

        # phi, i.e. the 1x1 convolution that maps the combined signal (W_x + W_g path) to a single-channel attention map
        self.phi = nn.Conv2d(in_channels = inter_channels, out_channels = 1, kernel_size = 1)
        #self.phi_norm = nn.GroupNorm(num_groups = n_groups, num_channels = 1)
        #self.phi_act = nn.SiLU() # Swish activation function

        # Implement the sigmoid activation function
        self.sigmoid = nn.Sigmoid()
        # Implement the Swish activation function
        self.act = nn.SiLU()

        # Implement final group normalization layer
        self.final_norm = nn.GroupNorm(num_groups = n_groups, num_channels = skip_channels)
        # 1x1 convolution applied to the gated skip connection. It must be registered here rather than
        # created inside forward(), otherwise its weights are re-initialised on every call, never trained,
        # not saved in the state dict and left on the CPU when the model is moved to the GPU
        self.final_conv = nn.Conv2d(in_channels = skip_channels, out_channels = skip_channels, kernel_size = 1)

    # Pass the input through the attention block
    def forward(self, skip_connection, gate_signal):
        # Upsample the gate signal (2x spatially) and map it to the channel count of the skip connection
        gate_signal = self.W_g(gate_signal)
        # Ensure that the spatial sizes of the skip connection and the gate signal match before addition
        if gate_signal.shape[2:] != skip_connection.shape[2:]:
            gate_signal = torchvision.transforms.functional.resize(gate_signal, size = skip_connection.shape[2:], antialias=True)
        # Project the gate signal to the intermediate channels
        # (this reuses W_x, so the gate and skip projections share weights)
        gate_signal = self.W_x(gate_signal)

        # Project the skip connection to the intermediate channels
        skip_signal = self.W_x(skip_connection)

        # Add the skip connection and the gate signal (additive attention)
        add_xg = gate_signal + skip_signal

        # Pass the output of the addition through the activation function
        add_xg = self.act(add_xg)

        # 1x1 convolution to a single channel followed by a sigmoid gives an attention map in (0, 1)
        attention_map = self.sigmoid(self.phi(add_xg))

        # Element-wise multiplication: the (N, 1, H, W) attention map is broadcast over all channels
        skip_connection = torch.mul(skip_connection, attention_map)

        skip_connection = self.final_conv(skip_connection)
        skip_connection = self.act(self.final_norm(skip_connection))

        return skip_connection


## Implement a residual attention U-Net
class ResidualAttentionUnet(nn.Module):
    def __init__(self, in_channels, out_channels, n_groups = 4, n_channels = [64, 128, 256, 512, 1024]):
        super().__init__()

        # Define the contracting path: residual blocks followed by downsampling
        self.down_conv = nn.ModuleList(residual_block(in_chans, out_chans) for in_chans, out_chans in
                                       [(in_channels, n_channels[0]), (n_channels[0], n_channels[1]), (n_channels[1], n_channels[2]), (n_channels[2], n_channels[3])])
        self.down_samples = nn.ModuleList(down_sample() for _ in range(4))

        # Define the bottleneck residual block
        self.bottleneck = residual_block(n_channels[3], n_channels[4])

        # Define the attention blocks
        self.attention_blocks = nn.ModuleList(attention_block(skip_channels = residuals_chans, gate_channels = gate_chans) for gate_chans, residuals_chans in
                                              [(n_channels[4], n_channels[3]), (n_channels[3], n_channels[2]), (n_channels[2], n_channels[1]), (n_channels[1], n_channels[0])])

        # Define the expanding path: upsample blocks, followed by crop and concatenate, followed by residual blocks
        self.upsamples = nn.ModuleList(up_sample(in_chans, out_chans) for in_chans, out_chans in
                                       [(n_channels[4], n_channels[3]), (n_channels[3], n_channels[2]), (n_channels[2], n_channels[1]), (n_channels[1], n_channels[0])])

        self.concat = nn.ModuleList(crop_and_concatenate() for _ in range(4))

        self.up_conv = nn.ModuleList(residual_block(in_chans, out_chans) for in_chans, out_chans in
                                     [(n_channels[4], n_channels[3]), (n_channels[3], n_channels[2]), (n_channels[2], n_channels[1]), (n_channels[1], n_channels[0])])

        # Final 1X1 convolution layer to produce the output segmentation map:
        # The primary purpose of 1x1 convolutions is to transform the channel dimension of the feature map,
        # while leaving the spatial dimensions unchanged.
        self.final_conv = nn.Conv2d(in_channels = n_channels[0] , out_channels = out_channels, kernel_size = 1)

    # Pass the input through the residual attention U-Net
    def forward(self, x):
        # Store the skip connections (one per encoder level, plus the bottleneck output)
        skip_connections = []

        # Pass the input through the contracting path
        # (loop variables are named so they do not shadow the down_sample/up_sample classes)
        for down, pool in zip(self.down_conv, self.down_samples):
            x = down(x)
            skip_connections.append(x)
            x = pool(x)

        # Pass the output of the contracting path through the bottleneck
        x = self.bottleneck(x)
        skip_connections.append(x)

        # Attention on the skip connections: each one is gated by the feature map one level deeper,
        # i.e. (gate, skip) index pairs (4, 3), (3, 2), (2, 1), (1, 0).
        # The list is therefore ordered deepest first, the same order in which the expanding path consumes it.
        n = len(skip_connections)
        indices = [(n - 1 - i, n - 2 - i) for i in range(n - 1)]
        attentions = []
        for attention, (g_gate, x_residual) in zip(self.attention_blocks, indices):
            attn = attention(skip_connections[x_residual], skip_connections[g_gate])
            attentions.append(attn)

        # Pass the output of the attention blocks through the expanding path
        for up, concat, up_conv in zip(self.upsamples, self.concat, self.up_conv):
            x = up(x)
            x = concat(x, attentions.pop(0))
            x = up_conv(x)

        # Pass the output of the expanding path through the final convolution layer
        x = self.final_conv(x)
        return x

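## Sanity check
import torchsummary  # third-party package (pip install torchsummary)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ResidualAttentionUnet(in_channels = 3, out_channels = 1).to(device)
x = torch.randn((1, 3, 32, 32)).to(device)
mask = model(x)
print(mask.shape)  # only the last expression of a cell is displayed, so print explicitly; expected torch.Size([1, 1, 32, 32])

# See how data flows through the network.
# torchsummary.summary defaults to device="cuda", so pass the actual device type to make this work on CPU too
torchsummary.summary(model, input_size = (3, 32, 32), device = device.type)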
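The index bookkeeping in `ResidualAttentionUnet.forward` (which feature map gates which skip connection, and at what resolution) can be checked without building the network. The sketch below uses a hypothetical helper, `trace_unet_shapes`, that only mirrors the 2×2 max-pooling and the `indices` list from `forward()`; it assumes the default `n_channels` and does not run any PyTorch code.

```python
from typing import List, Tuple

Shape = Tuple[int, int, int]  # (channels, height, width)


def trace_unet_shapes(height: int, width: int,
                      channels: Tuple[int, ...] = (64, 128, 256, 512, 1024)
                      ) -> Tuple[List[Shape], List[Tuple[Shape, Shape]]]:
    """Hypothetical helper: mirror the shape bookkeeping of ResidualAttentionUnet.forward.

    Returns the shapes stored in `skip_connections` (encoder levels + bottleneck) and the
    (gate, skip) shape pairs fed to the attention blocks, deepest level first.
    """
    skips: List[Shape] = []
    h, w = height, width
    for level, c in enumerate(channels):
        skips.append((c, h, w))        # output of down_conv (or the bottleneck) at this level
        if level < len(channels) - 1:  # the bottleneck output is not pooled
            h, w = h // 2, w // 2      # MaxPool2d(kernel_size=2, stride=2) floors odd sizes
    n = len(skips)
    # Same (gate index, skip index) pairs as `indices` in forward()
    indices = [(n - 1 - i, n - 2 - i) for i in range(n - 1)]
    pairs = [(skips[g], skips[x]) for g, x in indices]
    return skips, pairs


skips, pairs = trace_unet_shapes(32, 32)
for gate, skip in pairs:
    # W_g's transposed convolution doubles the gate's spatial size before it meets the skip connection
    print(f"gate {gate} -> upsampled to {gate[1] * 2}x{gate[2] * 2}, skip {skip}")
```

For a 32×32 input the first attention block is gated by the 2×2 bottleneck (1024 channels), which `W_g` upsamples to 4×4 to match the 512-channel skip connection. For a 30×30 input the pooled sizes are 30, 15, 7, 3 and 1, so the upsampled gate (2×2) no longer matches the 3×3 skip connection; this is why `attention_block` and `crop_and_concatenate` resize before combining tensors.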

## Training UNet_2

## Testing UNet_2