## Project: Cell image segmentation


Cell segmentation is usually the first step of downstream single-cell analysis in microscopy-based biological and biomedical research. Deep learning is now widely used for cell image segmentation.
The NeurIPS 2022 CellSeg competition aims to benchmark cell segmentation methods that can be applied to microscopy images from multiple imaging platforms and tissue types. The dataset provided by the challenge organizers contains both labeled and unlabeled images.
The “2018 Data Science Bowl” Kaggle competition provides cell images and their masks for training cell/nuclei segmentation models.

### Project Description

In biomedical image processing, segmentation is commonly performed with U-Nets [1,2].

A U-Net consists of an encoder (the contracting path), a series of convolution and pooling layers that progressively reduce the spatial resolution of the feature maps while increasing the number of channels, followed by a decoder (the expanding path), a series of up-convolutions (transposed convolutions, or upsampling followed by convolution) and convolutions that restore the spatial resolution. The two paths meet at a bottleneck, the lowest-resolution stage of the network, where the feature maps have the most channels.
The key innovation of U-Net is the skip connections, which concatenate feature maps from the contracting path with the corresponding layers of the expanding path, allowing the network to recover fine-grained spatial details lost during downsampling.

<img src='https://production-media.paperswithcode.com/methods/Screen_Shot_2020-07-07_at_9.08.00_PM_rpNArED.png' width="400"/>
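In the original U-Net [1] the 3x3 convolutions are unpadded ("valid"), so each one trims 2 pixels from every spatial dimension. Encoder feature maps are therefore larger than the decoder maps they are joined with and must be center-cropped before concatenation. The sketch below traces a square input of side `n` through the four-level architecture of [1] (two 3x3 convolutions per level, 2x2 max pooling, 2x2 up-convolution) and reports the crop at each skip connection. The function name `unet_sizes` is ours, not from any library; many modern implementations use padded convolutions instead, so input and output sizes match.

```python
def unet_sizes(n: int, depth: int = 4) -> tuple[int, list[tuple[int, int]]]:
    """Trace the spatial size through a U-Net with unpadded 3x3 convolutions.

    Returns the output size and, for each skip connection (deepest first),
    the pair (encoder feature-map size, decoder size it is cropped to).
    """

    def double_conv(size: int) -> int:
        # two 3x3 'valid' convolutions, each removing 2 pixels
        return size - 4

    skips = []
    size = n
    # contracting path: double conv, keep the map for the skip, then 2x2 max pool
    for _ in range(depth):
        size = double_conv(size)
        skips.append(size)
        if size % 2:
            raise ValueError("feature map size must be even before pooling")
        size //= 2

    size = double_conv(size)  # bottleneck

    crops = []
    # expanding path: 2x2 up-convolution doubles the size, skip is cropped, double conv
    for skip in reversed(skips):
        size *= 2
        crops.append((skip, size))
        size = double_conv(size)
    return size, crops


out, crops = unet_sizes(572)
print(out)    # 388
print(crops)  # [(64, 56), (136, 104), (280, 200), (568, 392)]
```

The 572 to 388 reduction matches the input and output tile sizes reported in [1]; the cropping is what lets the skip connections carry fine detail across despite the size mismatch.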


An R implementation of basic U-Nets is available at this [link](https://rpubs.com/eR_ic/unet), and a Keras implementation at this [link](https://github.com/zhixuhao/unet).  
Implementations of more advanced U-Nets are also available, such as [UNet++](https://github.com/MrGiovanni/UNetPlusPlus) (see also [2]),
and the baseline models provided by the CellSeg organizers: [https://neurips22-cellseg.grand-challenge.org/baseline-and-tutorial/](https://neurips22-cellseg.grand-challenge.org/baseline-and-tutorial/)

### Project aim

The aim of the project is to download the cell images (preferably from the “2018 Data Science Bowl” competition) and assess the performance of a U-Net or any other deep learning model for cell segmentation.
Students are free to choose any model, as long as they can explain its rationale, architecture, strengths, and weaknesses.
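Segmentation performance is usually reported with overlap scores between predicted and ground-truth masks. The Dice coefficient, 2|A∩B| / (|A| + |B|), is the metric tracked by `MeanDice` in the notebook below; the intersection over union (IoU, or Jaccard index), |A∩B| / |A∪B|, is closely related (Dice = 2·IoU / (1 + IoU)). A minimal numpy sketch for one pair of binary masks follows; `dice_and_iou` is a hypothetical helper, and scoring two empty masks as 1 is an assumed convention. The 2018 Data Science Bowl leaderboard used an object-level score (average precision over a range of IoU thresholds on matched nuclei), which these pixel-level scores do not reproduce.

```python
import numpy as np


def dice_and_iou(pred: np.ndarray, target: np.ndarray) -> tuple[float, float]:
    """Pixel-level Dice coefficient and IoU between two binary masks."""
    pred = pred.astype(bool)
    target = target.astype(bool)
    intersection = np.logical_and(pred, target).sum()
    union = np.logical_or(pred, target).sum()
    total = pred.sum() + target.sum()
    if total == 0:  # both masks empty: treated as perfect agreement (assumption)
        return 1.0, 1.0
    dice = 2.0 * intersection / total
    iou = intersection / union
    return float(dice), float(iou)


pred = np.array([[1, 1, 0], [0, 1, 0]])
target = np.array([[1, 0, 0], [0, 1, 1]])
print(dice_and_iou(pred, target))  # (0.6666666666666666, 0.5)
```

Dice weights the overlap more generously than IoU, so the two should not be compared across papers without checking which one is reported.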



### References

[1] Ronneberger, O., Fischer, P., Brox, T. (2015). U-Net: Convolutional Networks for Biomedical Image Segmentation. In: Navab, N., Hornegger, J., Wells, W., Frangi, A. (eds) Medical Image Computing and Computer-Assisted Intervention – MICCAI 2015. Lecture Notes in Computer Science, vol. 9351. Springer, Cham. https://doi.org/10.1007/978-3-319-24574-4_28

[2] Long, F. (2020). Microscopy cell nuclei segmentation with enhanced U-Net. BMC Bioinformatics, 21, 8. https://doi.org/10.1186/s12859-019-3332-1


# Solution
## Download data
The data were downloaded manually from Kaggle.
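The download can also be scripted with the official Kaggle command-line tool (`kaggle competitions download -c data-science-bowl-2018`) once API credentials are configured. In the Data Science Bowl training set, each sample is a folder `<id>/` holding the image as `images/<id>.png` and one binary PNG per nucleus under `masks/`, so the per-nucleus masks must be merged before training. The sketch below does that; `merge_nucleus_masks` and `load_dsb_sample` are hypothetical helpers, Pillow is assumed for reading PNGs, and the alpha channel that some of the images carry is dropped by converting to RGB.

```python
from pathlib import Path

import numpy as np


def merge_nucleus_masks(masks: list[np.ndarray]) -> np.ndarray:
    """Combine per-nucleus binary masks into one instance label map.

    Background is 0 and the i-th mask (1-based, in the given order) gets label i.
    If masks overlap, the later mask wins.
    """
    if not masks:
        raise ValueError("need at least one mask")
    label = np.zeros(np.asarray(masks[0]).shape[:2], dtype=np.int32)
    for i, m in enumerate(masks, start=1):
        label[np.asarray(m) > 0] = i
    return label


def load_dsb_sample(sample_dir: Path) -> tuple[np.ndarray, np.ndarray]:
    """Load one sample laid out as <id>/images/<id>.png and <id>/masks/*.png."""
    from PIL import Image  # imported here so merge_nucleus_masks works without Pillow

    sample_dir = Path(sample_dir)
    image_path = sample_dir / "images" / f"{sample_dir.name}.png"
    image = np.array(Image.open(image_path).convert("RGB"))  # drop any alpha channel
    mask_paths = sorted((sample_dir / "masks").glob("*.png"))
    masks = [np.array(Image.open(p)) for p in mask_paths]
    return image, merge_nucleus_masks(masks)
```

The instance label map keeps nuclei distinguishable, which instance-aware models such as HoVer-Net need; `label > 0` gives the plain foreground mask for a semantic U-Net.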

## Model
[MONAI tutorial: 2D segmentation with a U-Net](https://github.com/Project-MONAI/tutorials/tree/master/2d_segmentation/torch)

[MONAI tutorial: HoVer-Net for pathology](https://github.com/Project-MONAI/tutorials/tree/main/pathology/hovernet)
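Besides the nuclear-pixel (NP) mask, HoVer-Net is trained on horizontal and vertical (HoVer) distance maps: each nucleus pixel stores its horizontal and vertical offset from the nucleus's centre of mass, scaled to [-1, 1] within that nucleus. In the MONAI tutorial the `hover_label_inst` key passed to `PrepareBatchHoVerNet` below holds these two maps, derived from instance labels by the tutorial's own transforms. The numpy function below is an independent sketch of the idea, not MONAI code, and `hover_maps` is a hypothetical name.

```python
import numpy as np


def hover_maps(inst: np.ndarray) -> np.ndarray:
    """Horizontal/vertical distance maps from an instance label map (0 = background).

    Returns an array of shape (2, H, W): channel 0 is the horizontal (x) map and
    channel 1 the vertical (y) map, each scaled to [-1, 1] per nucleus.
    """
    out = np.zeros((2,) + inst.shape, dtype=np.float32)
    for label in np.unique(inst):
        if label == 0:
            continue  # background stays 0
        ys, xs = np.nonzero(inst == label)
        # offsets of every pixel from the nucleus's centre of mass
        offsets = ((0, xs - xs.mean()), (1, ys - ys.mean()))
        for ch, d in offsets:
            d = d.astype(np.float32)
            # scale negative and positive offsets separately to [-1, 0) and (0, 1]
            neg, pos = d < 0, d > 0
            if neg.any():
                d[neg] /= -d[neg].min()
            if pos.any():
                d[pos] /= d[pos].max()
            out[ch, ys, xs] = d
    return out
```

Where two nuclei touch, the maps switch sign abruptly across the boundary, which is the cue HoVer-Net's post-processing uses to split touching nuclei.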

```python
# Imports
import torch
from monai.networks.nets import hovernet, UNet
from monai.engines import SupervisedTrainer, SupervisedEvaluator
from utils import CellDataset  # project-local dataset class (not shown in this notebook)
from monai.utils.enums import HoVerNetBranch
from monai.utils import set_determinism
from monai.apps.pathology.handlers.utils import from_engine_hovernet
from monai.apps.pathology.engines.utils import PrepareBatchHoVerNet
from monai.apps.pathology.losses import HoVerNetLoss
from monai.transforms import Activationsd, AsDiscreted, Compose, Lambdad
from monai.handlers import (
    MeanDice, LrScheduleHandler, ValidationHandler, CheckpointSaver, StatsHandler, from_engine,
    TensorBoardStatsHandler,
)
```


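```python
# Datasets and loaders (CellDataset comes from the project-local utils module)
train_data = CellDataset('data/train')
test_data = CellDataset('data/test', train=False)
train_loader = torch.utils.data.DataLoader(train_data, batch_size=16, shuffle=True)
# The evaluation data does not need to be shuffled
test_loader = torch.utils.data.DataLoader(test_data, batch_size=16, shuffle=False)
# HoVerNet() defaults to "fast" mode, so cfg["mode"] in run_hover should be "fast"
hover = hovernet.HoVerNet()
```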
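`run_hover`, defined in the next cell, reads all of its settings from a plain dict, but the notebook never shows one. The keys below are exactly those the function accesses; the values are placeholders chosen for illustration, not the settings used in this project, and `make_cfg`, `REQUIRED_CFG_KEYS`, and the log directory are hypothetical.

```python
# Keys read by run_hover; values in make_cfg are illustrative placeholders only
REQUIRED_CFG_KEYS = {"seed", "mode", "use_gpu", "lr", "n_epochs", "val_freq", "save_interval", "amp"}


def make_cfg(**overrides) -> dict:
    """Return a config dict with every key run_hover reads, updated with overrides."""
    cfg = {
        "seed": 42,
        "mode": "fast",        # must match the HoVerNet instance (HoVerNet() defaults to fast)
        "use_gpu": True,
        "lr": 1e-4,
        "n_epochs": 50,
        "val_freq": 1,         # run validation every epoch
        "save_interval": 10,   # save a checkpoint every 10 epochs
        "amp": False,          # automatic mixed precision (only useful on a GPU)
    }
    unknown = set(overrides) - REQUIRED_CFG_KEYS
    if unknown:
        raise KeyError(f"unknown cfg keys: {sorted(unknown)}")
    cfg.update(overrides)
    return cfg


cfg = make_cfg(n_epochs=5)
# run_hover("logs/hovernet", cfg, hover)  # hypothetical log directory
```

Rejecting unknown keys catches typos such as `n_epoch` early, before a long training run starts with a silently ignored setting.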


```python
def run_hover(log_dir, cfg, model):
    """Train and validate a MONAI HoVerNet `model` using the settings in `cfg`.

    The data loaders `train_loader` and `test_loader` are taken from the
    notebook's global scope (see the previous cell).
    """
    set_determinism(seed=cfg["seed"])

    # Patch (input) and output sizes of the two HoVer-Net modes. cfg["mode"] must
    # match the mode the model was built with (HoVerNet() defaults to "fast"), and
    # the dataset must crop images/labels accordingly; the sizes are not used below.
    mode = cfg["mode"].lower()
    if mode == "original":
        cfg["patch_size"] = [270, 270]
        cfg["out_size"] = [80, 80]
    elif mode == "fast":
        cfg["patch_size"] = [256, 256]
        cfg["out_size"] = [164, 164]
    else:
        raise ValueError(f"Unknown HoVer-Net mode: {cfg['mode']!r}")

    # Use the GPU only if it is requested and actually available
    device = torch.device("cuda" if cfg["use_gpu"] and torch.cuda.is_available() else "cpu")

    # --------------------------------------------------------------------------
    # Model, loss, optimizer, lr_scheduler
    # --------------------------------------------------------------------------
    # The engines move each batch to `device`, but not the network itself
    model = model.to(device)

    loss_function = HoVerNetLoss(lambda_hv_mse=1.0)
    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=cfg["lr"], weight_decay=1e-5)
    # Multiply the learning rate by 0.1 (StepLR's default gamma) every 25 epochs
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=25)
    # Turn the nuclear-pixel (NP) branch into a mask: softmax over channels, then argmax
    post_process_np = Compose(
        [
            Activationsd(keys=HoVerNetBranch.NP.value, softmax=True),
            AsDiscreted(keys=HoVerNetBranch.NP.value, argmax=True),
        ]
    )
    post_process = Lambdad(keys="pred", func=post_process_np)

    # --------------------------------------------
    # Ignite Trainer/Evaluator
    # --------------------------------------------
    # Evaluator: keeps the checkpoint with the best validation Dice
    val_handlers = [
        CheckpointSaver(
            save_dir=log_dir,
            save_dict={"model": model},
            save_key_metric=True,
        ),
        StatsHandler(output_transform=lambda x: None),
        TensorBoardStatsHandler(log_dir=log_dir, output_transform=lambda x: None),
    ]
    evaluator = SupervisedEvaluator(
        device=device,
        val_data_loader=test_loader,
        prepare_batch=PrepareBatchHoVerNet(extra_keys=["label_type", "hover_label_inst"]),
        network=model,
        postprocessing=post_process,
        key_val_metric={
            "val_dice": MeanDice(
                include_background=False,
                output_transform=from_engine_hovernet(keys=["pred", "label"], nested_key=HoVerNetBranch.NP.value),
            )
        },
        val_handlers=val_handlers,
        amp=cfg["amp"],
    )

    # Trainer: steps the LR schedule, validates every cfg["val_freq"] epochs, saves checkpoints
    train_handlers = [
        LrScheduleHandler(lr_scheduler=lr_scheduler, print_lr=True),
        ValidationHandler(validator=evaluator, interval=cfg["val_freq"], epoch_level=True),
        CheckpointSaver(
            save_dir=log_dir,
            save_dict={"model": model, "opt": optimizer},
            save_interval=cfg["save_interval"],
            save_final=True,
            final_filename="model.pt",
            epoch_level=True,
        ),
        StatsHandler(tag_name="train_loss", output_transform=from_engine(["loss"], first=True)),
        TensorBoardStatsHandler(
            log_dir=log_dir, tag_name="train_loss", output_transform=from_engine(["loss"], first=True)
        ),
    ]

    trainer = SupervisedTrainer(
        device=device,
        max_epochs=cfg["n_epochs"],
        train_data_loader=train_loader,
        prepare_batch=PrepareBatchHoVerNet(extra_keys=["label_type", "hover_label_inst"]),
        network=model,
        optimizer=optimizer,
        loss_function=loss_function,
        postprocessing=post_process,
        key_train_metric={
            "train_dice": MeanDice(
                include_background=False,
                output_transform=from_engine_hovernet(keys=["pred", "label"], nested_key=HoVerNetBranch.NP.value),
            )
        },
        train_handlers=train_handlers,
        amp=cfg["amp"],
    )
    trainer.run()
```
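The `post_process` transform in `run_hover` reduces the NP branch of each image's prediction to a nucleus mask before `MeanDice` compares it with the label. The numpy function below mirrors that softmax-then-argmax step on a single channel-first array of shape (2, H, W) (background, nucleus); it illustrates the computation and is not MONAI code, and `np_branch_to_mask` is a hypothetical name. The channel axis is kept in the result, as MONAI's `AsDiscrete(argmax=True)` does by default.

```python
import numpy as np


def np_branch_to_mask(logits: np.ndarray) -> np.ndarray:
    """Softmax over the channel axis, then argmax, for one (C, H, W) NP output.

    Returns an integer mask of shape (1, H, W).
    """
    # numerically stable softmax over channels
    shifted = logits - logits.max(axis=0, keepdims=True)
    exp = np.exp(shifted)
    probs = exp / exp.sum(axis=0, keepdims=True)
    # index of the most probable class per pixel, keeping a channel axis
    return probs.argmax(axis=0)[None].astype(np.int64)


logits = np.array([[[2.0, -1.0]], [[0.5, 3.0]]])  # shape (2, 1, 2)
mask = np_branch_to_mask(logits)  # pixel 0 -> background, pixel 1 -> nucleus
```

Because softmax is monotonic, the argmax of the raw logits gives the same mask; the softmax step only matters if the probabilities themselves are kept, for example for thresholding or visualisation.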
