<a href="https://colab.research.google.com/github/wps0/deep4life/blob/main/Project04_CellSegmentation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Project : Cell image segmentation projects

Contact: Elena Casiraghi (University Milano elena.casiraghi@unimi.it)

Cell segmentation is usually the first step for downstream single-cell analysis in microscopy image-based biology and biomedical research. Deep learning has been widely used for cell-image segmentation.
The CellSeg competition aims to benchmark cell segmentation methods that could be applied to various microscopy images across multiple imaging platforms and tissue types for cell Segmentation. The  Dataset challenge organizers provide both labeled images and unlabeled ones.
The “2018 Data Science Bowl” Kaggle competition provides cell images and their masks for training cell/nuclei segmentation models.

In 2022 another [Cell Segmentation challenge was proposed at Neurips](https://neurips22-cellseg.grand-challenge.org/).
For interested readers, the competition proceeding has been published on [PMLR](https://proceedings.mlr.press/v212/)

### Project Description

In the field of (bio-medical) image processing, segmentation of images is typically performed via U-Nets [1,2].

A U-Net consists of an encoder - a series of convolution and pooling layers which reduce the spatial resolution of the input, followed by a decoder - a series of transposed convolution and upsampling layers which increase the spatial resolution of the input. The encoder and decoder are connected by a bottleneck layer which is responsible for reducing the number of channels in the input.
The key innovation of U-Net is the addition of skip connections that connect the contracting path to the corresponding layers in the expanding path, allowing the network to recover fine-grained details lost during downsampling.

<img src='https://production-media.paperswithcode.com/methods/Screen_Shot_2020-07-07_at_9.08.00_PM_rpNArED.png' width="400"/>


At this [link](https://rpubs.com/eR_ic/unet), you find an R implementation of basic U-Nets. At this [link](https://github.com/zhixuhao/unet), you find a Keras implementation of UNets.  
Other implementations of more advanced UNets are also made available in [2] at these links: [UNet++](https://github.com/MrGiovanni/UNetPlusPlus)
and by the CellSeg organizers as baseline models: [https://neurips22-cellseg.grand-challenge.org/baseline-and-tutorial/](https://neurips22-cellseg.grand-challenge.org/baseline-and-tutorial/)


### Project aim

The aim of the project is to download the *gray-level* (.tiff or .tif files) cell images from the [CellSeg](https://neurips22-cellseg.grand-challenge.org/dataset/) competition and assess the performance of an UNet or any other Deep model for cell segmentation.
We suggest using gray-level images to obtain a model that is better specified on a sub class of images.

Students are not restricted to use UNets but may other model is wellcome; e.g., even transformer based model in the [leaderboard](https://neurips22-cellseg.grand-challenge.org/evaluation/testing/leaderboard/) may be tested.
Students are free to choose any model, as long as they are able to explain their rationale, architecture, strengths and weaknesses.



### References

[1] Ronneberger, O., Fischer, P., Brox, T. (2015). U-Net: Convolutional Networks for Biomedical Image Segmentation. In: Navab, N., Hornegger, J., Wells, W., Frangi, A. (eds) Medical Image Computing and Computer-Assisted Intervention – MICCAI 2015. MICCAI 2015. Lecture Notes in Computer Science(), vol 9351. Springer, Cham. https://doi.org/10.1007/978-3-319-24574-4_28

[2] Long, F. Microscopy cell nuclei segmentation with enhanced U-Net. BMC Bioinformatics 21, 8 (2020). https://doi.org/10.1186/s12859-019-3332-1


## Initialization

In [1]:
!pip install --upgrade gdown



In [43]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import os
import tempfile
from typing import Callable, List, Tuple
from torchvision import transforms
from PIL import Image, UnidentifiedImageError

In [3]:
TRAIN_PATH = 'data_train/Training-labeled/images/'
TRAIN_LABELS_PATH = 'data_train/Training-labeled/labels/'

TEST_PATH = 'data_test/Testing/Public/images/'
TEST_LABELS_PATH = 'data_test/Testing/Public/labels/'

VAL_PATH = 'data_val/Tuning/images/'
VAL_LABELS_PATH = ' data_val/Tuning/labels/'

use_cuda = torch.cuda.is_available()
if use_cuda:
  device = torch.device("cuda")
  dataloader_kwargs = {"batch_size": 32, "shuffle": True, "pin_memory": True}
else:
  device = torch.device("cpu")
  dataloader_kwargs = {"batch_size": 64}


### Data preparation
[Browse the data](https://drive.google.com/drive/folders/1MaJibsHYitCPOltxVzYjr3rm5s9Vpjpv)

In [4]:
!curl -o data_test.zip https://zenodo.org/records/10719375/files/Testing.zip?download=1
!unzip -q -d data_test data_test.zip

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100 2793M  100 2793M    0     0  20.0M      0  0:02:19  0:02:19 --:--:-- 20.0M


In [5]:
!curl -o data_train.zip https://zenodo.org/records/10719375/files/Training-labeled.zip?download=1
!unzip -q -d data_train data_train.zip

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100 1926M  100 1926M    0     0  18.3M      0  0:01:45  0:01:45 --:--:-- 18.4M


In [6]:
!curl -o data_val.zip https://zenodo.org/records/10719375/files/Tuning.zip?download=1
!unzip -q -d data_val data_val.zip

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100  595M  100  595M    0     0  14.0M      0  0:00:42  0:00:42 --:--:-- 9528k


In [45]:
# Partially adapted from https://colab.research.google.com/github/mim-ml-teaching/public-dnn-2024-25/blob/master/docs/DNN-Lab-7-UNet-in-Pytorch-student-version.ipynb
class ImageTiffDataset(torch.utils.data.Dataset):
  def __init__(self,
               image_dir: str,
               target_dir: str,
               cache_dir: str,
               filenames: List[str],
               transform: torch.nn.Module = transforms.ToTensor(),
               target_transform: torch.nn.Module = transforms.ToTensor()):
    self.image_dir = image_dir
    self.target_dir = target_dir
    self.cache_dir = cache_dir
    self.filenames = filenames
    self.transform = transform
    self.target_transform = target_transform

    if not os.path.exists(self.cache_dir):
      os.mkdir(self.cache_dir)
      os.mkdir(os.path.join(self.cache_dir, "images"))
      os.mkdir(os.path.join(self.cache_dir, "target"))

  def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
    img_filename = self.filenames[idx]
    target_filename = os.path.basename(img_filename)
    img_path = os.path.join(self.image_dir, img_filename)
    target_path = os.path.join(self.target_dir, target_filename)
    img_cache = os.path.join(self.cache_dir, img_filename)
    target_cache = os.path.join(self.cache_dir, target_filename)

    if not os.path.exists(img_cache):
      with Image.open(img_path) as im:
        img = self.transform(im)
        torch.save(img, img_cache)
    else:
      img = torch.load(img_cache)

    if not os.path.exists(target_cache):
      with Image.open(target_path) as im:
        target = self.target_transform(im)
        torch.save(target, target_cache)
    else:
      target = torch.load(target_cache)

    return img, target

  def __len__(self) -> int:
    return len(self.filenames)

def make_tiff_dataset(image_dir: str, target_dir: str, cache_dir: str):
  filenames = []
  with os.scandir(image_dir) as it:
    for entry in it:
      if not entry.is_file():
        continue
      lower_name = entry.name.lower()
      if lower_name.endswith('.tiff') or lower_name.endswith('.tif'):
        img_path = entry.path
        # Attempt to open the image to check validity
        try:
            with Image.open(img_path) as im:
                # Using .verify() attempts to read the image data and raises
                # an error if the file is corrupt or not a valid image format
                im.verify()
            filenames.append(entry.name)
        except (UnidentifiedImageError, IOError) as e:
            # Catch specific errors related to image opening/identification
            print(f"Warning: Skipping invalid image file {img_path}: {e}")
        except Exception as e:
            # Catch any other unexpected errors during opening/verification
            print(f"Warning: An unexpected error occurred processing {img_path}: {e}")
  return ImageTiffDataset(image_dir, target_dir, cache_dir, filenames)

In [46]:
train_dataset = make_tiff_dataset(TRAIN_PATH, TRAIN_LABELS_PATH, tempfile.mkdtemp())
test_dataset = make_tiff_dataset(TEST_PATH, TEST_LABELS_PATH, tempfile.mkdtemp())
val_dataset = make_tiff_dataset(VAL_PATH, VAL_LABELS_PATH, tempfile.mkdtemp())
train_dataloader = torch.utils.data.DataLoader(train_dataset, **dataloader_kwargs)
test_dataloader = torch.utils.data.DataLoader(test_dataset, **dataloader_kwargs)
val_dataloader = torch.utils.data.DataLoader(val_dataset, **dataloader_kwargs)

print('Train:', len(train_dataset))
print('Test:', len(test_dataset))
print('Val:', len(val_dataset))

Train: 468
Test: 30
Val: 58


## Basic U-Nets

In [47]:
def train_unet(
    model: torch.nn.Module,
    device: torch.device,
    train_loader: torch.utils.data.DataLoader,
    optimizer: torch.optim.Optimizer,
    epoch: int,
    log_interval: int):
  model.train()
  correct = 0
  for batch_idx, (data, target) in enumerate(train_loader):
    data, target = data.to(device), target.to(device)
    optimizer.zero_grad()
    output = model(data)
    log_probs = F.log_softmax(output, dim=1)
    loss = F.nll_loss(log_probs, target)
    pred = log_probs.argmax(
        dim=1, keepdim=True
    )  # get the index of the max log-probability
    correct += pred.eq(target.view_as(pred)).sum().item()
    loss.backward()
    optimizer.step()
    if batch_idx % log_interval == 0:
      _, _, image_width, image_height = data.size()
      print("Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
          epoch,
          batch_idx * len(data),
          len(train_loader.dataset),
          100.0 * batch_idx / len(train_loader),
          loss.item(),
        ))
  print(
    "Train accuracy: {}/{} ({:.0f}%)".format(
        correct,
          (len(train_loader.dataset) * image_width * image_height),
          100.0 * correct / (len(train_loader.dataset) * image_width * image_height),
      )
  )

In [52]:
class UNetConvBlock(nn.Module):
  def __init__(self, in_channels, out_channels):
    super().__init__()
    self.layer = nn.Sequential(
        nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=3,
            padding=1,
            dilation=0,
            padding_mode='reflect'),
        nn.ReLU() # Leaky ReLU?
    )

  def forward(self, x):
    return self.layer(x)

class UNetEncoderBlock(nn.Module):
  def __init__(self, in_channels: int, out_channels: int, maxpool: bool = True):
    super().__init__()
    assert out_channels > in_channels
    if maxpool:
      self.layer = nn.Sequential(
          UNetConvBlock(in_channels, out_channels),
          UNetConvBlock(out_channels, out_channels),
          nn.MaxPool2d(2, dilation=0)
      )
    else:
      self.layer = nn.Sequential(
          UNetConvBlock(in_channels, out_channels),
          UNetConvBlock(out_channels, out_channels)
      )

  def forward(self, x):
    return self.layer(x)

class UNetDecoderBlock(nn.Module):
  def __init__(self, in_channels: int, out_channels: int, unmaxpool: bool = True):
    super().__init__()
    if unmaxpool:
      assert out_channels < in_channels
      self.layer = nn.Sequential(
          nn.ConvTranspose2d(
              in_channels + out_channels,
              out_channels,
              kernel_size=3,
              output_padding=1,
              dilation=0),
          UNetConvBlock(out_channels, out_channels),
          UNetConvBlock(out_channels, out_channels)
      )
    else:
      assert False
      # assert out_channels == in_channels
      # self.layer = nn.Sequential(
      #     UNetConvBlock(in_channels, out_channels),
      #     UNetConvBlock(out_channels, out_channels)
      # )

class UNet(nn.Module):
  def __init__(self, encoder_channels: List[int], decoder_channels: List[int]):
    super().__init__()
    assert len(encoder_channels) > 0
    self.encoder = nn.ModuleList()
    self.decoder = nn.ModuleList()

    in_channels = 1
    for out_channels in encoder_channels[:-1]:
      # print(f"Encoder {in_channels} {out_channels}")
      self.encoder.append(UNetEncoderBlock(in_channels, out_channels))
      in_channels = out_channels
    self.encoder.append(UNetEncoderBlock(in_channels, encoder_channels[-1], maxpool=False))

    in_channels = encoder_channels[-1]
    for out_channels in decoder_channels:
      # print(f"Decoder {in_channels} {out_channels}")
      self.decoder.append(UNetDecoderBlock(in_channels, out_channels))
      in_channels = out_channels

  def forward(self, x):
    encoded = []
    for layer in self.encoder:
      x = layer(x)
      encoded.append(x)
    for residual, layer in zip(encoded, self.decoder):
      cc = torch.cat((residual, x), dim=1)
      x = layer(x)

In [53]:
basic_unet = UNet(
    encoder_channels=[64, 128, 256, 512, 1024],
    decoder_channels=[512, 256, 128, 64, 2])

optimizer_unet = optim.AdamW(basic_unet.parameters())
train_unet(basic_unet, device, train_dataloader, optimizer_unet, 10, 10)

RuntimeError: stack expects each tensor to be equal size, but got [1, 1608, 1608] at entry 0 and [1, 502, 500] at entry 1