
PyTorch implementation of medical semantic segmentation models (e.g. UNet, UNet++, DUCKNet, ResUNet, ResUNet++), with support for knowledge distillation, distributed training, Optuna, etc.


Introduction

PyTorch implementation of medical segmentation models, supporting multi-GPU training and validation, automatic mixed precision training, knowledge distillation, hyperparameter optimization using Optuna, etc.

Requirements

torch == 1.8.1
segmentation-models-pytorch (optional)
torchmetrics
albumentations
loguru
tqdm
optuna == 4.0.0 (optional)
optuna-integration == 4.0.0 (optional)

If you run into any version conflicts, see requirements.txt. This repo may also work with torch > 1.8.1, but that has not been verified yet.

If you want a minimal reproducible environment, you may run

```bash
pip install -r requirements.txt
```

Supported datasets

Polyp

To obtain reproducible results, you may also download the train-val-test splits provided by DUCKNet's authors here.

Supported models

DUCKNet

Using DUCK-Net for polyp image segmentation [codes]

Abstract: This paper presents a novel supervised convolutional neural network architecture, “DUCK-Net”, capable of effectively learning and generalizing from small amounts of medical images to perform accurate segmentation tasks. Our model utilizes an encoder-decoder structure with a residual downsampling mechanism and a custom convolutional block to capture and process image information at multiple resolutions in the encoder segment. We employ data augmentation techniques to enrich the training set, thus increasing our model's performance. While our architecture is versatile and applicable to various segmentation tasks, in this study, we demonstrate its capabilities specifically for polyp segmentation in colonoscopy images. We evaluate the performance of our method on several popular benchmark datasets for polyp segmentation, Kvasir-SEG, CVC-ClinicDB, CVC-ColonDB, and ETIS-LARIBPOLYPDB showing that it achieves state-of-the-art results in terms of mean Dice coefficient, Jaccard index, Precision, Recall, and Accuracy. Our approach demonstrates strong generalization capabilities, achieving excellent performance even with limited training data.

ResUNet

Road Extraction by Deep Residual U-Net [codes]

Abstract: Road extraction from aerial images has been a hot research topic in the field of remote sensing image analysis. In this letter, a semantic segmentation neural network which combines the strengths of residual learning and U-Net is proposed for road area extraction. The network is built with residual units and has similar architecture to that of U-Net. The benefits of this model is two-fold: first, residual units ease training of deep networks. Second, the rich skip connections within the network could facilitate information propagation, allowing us to design networks with fewer parameters however better performance. We test our network on a public road dataset and compare it with U-Net and other two state of the art deep learning based road extraction methods. The proposed approach outperforms all the comparing methods, which demonstrates its superiority over recently developed state of the arts.

ResUNet++

ResUNet++: An Advanced Architecture for Medical Image Segmentation [codes]

Abstract: Accurate computer-aided polyp detection and segmentation during colonoscopy examinations can help endoscopists resect abnormal tissue and thereby decrease chances of polyps growing into cancer. Towards developing a fully automated model for pixel-wise polyp segmentation, we propose ResUNet++, which is an improved ResUNet architecture for colonoscopic image segmentation. Our experimental evaluations show that the suggested architecture produces good segmentation results on publicly available datasets. Furthermore, ResUNet++ significantly outperforms U-Net and ResUNet, two key state-of-the-art deep learning architectures, by achieving high evaluation scores with a dice coefficient of 81.33%, and a mean Intersection over Union (mIoU) of 79.27% for the Kvasir-SEG dataset and a dice coefficient of 79.55%, and a mIoU of 79.62% with CVC-612 dataset.

UNet

U-Net: Convolutional Networks for Biomedical Image Segmentation [codes]

Abstract: There is large consent that successful training of deep networks requires many thousand annotated training samples. In this paper, we present a network and training strategy that relies on the strong use of data augmentation to use the available annotated samples more efficiently. The architecture consists of a contracting path to capture context and a symmetric expanding path that enables precise localization. We show that such a network can be trained end-to-end from very few images and outperforms the prior best method (a sliding-window convolutional network) on the ISBI challenge for segmentation of neuronal structures in electron microscopic stacks. Using the same network trained on transmitted light microscopy images (phase contrast and DIC) we won the ISBI cell tracking challenge 2015 in these categories by a large margin. Moreover, the network is fast. Segmentation of a 512x512 image takes less than a second on a recent GPU. The full implementation (based on Caffe) and the trained networks are available at this http URL.

UNet++

UNet++: A Nested U-Net Architecture for Medical Image Segmentation [codes]

Abstract: In this paper, we present UNet++, a new, more powerful architecture for medical image segmentation. Our architecture is essentially a deeply-supervised encoder-decoder network where the encoder and decoder sub-networks are connected through a series of nested, dense skip pathways. The re-designed skip pathways aim at reducing the semantic gap between the feature maps of the encoder and decoder sub-networks. We argue that the optimizer would deal with an easier learning task when the feature maps from the decoder and encoder networks are semantically similar. We have evaluated UNet++ in comparison with U-Net and wide U-Net architectures across multiple medical image segmentation tasks: nodule segmentation in the low-dose CT scans of chest, nuclei segmentation in the microscopy images, liver segmentation in abdominal CT scans, and polyp segmentation in colonoscopy videos. Our experiments demonstrate that UNet++ with deep supervision achieves an average IoU gain of 3.9 and 3.4 points over U-Net and wide U-Net, respectively.


If you want to use encoder-decoder structures with pretrained encoders, you may refer to segmentation-models-pytorch (SMP)[1]. This repo also provides easy access to SMP. Just modify the config file, e.g. if you want to train UNet with a ResNet-101 backbone as the teacher model for knowledge distillation:

```python
self.model = 'smp'
self.encoder = 'resnet101'
self.decoder = 'unet'
```

or use command-line arguments

```bash
python main.py --model smp --encoder resnet101 --decoder unet
```

Details of the configurations can also be found in this file.
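
For reference, a minimal sketch of what the `smp` option presumably builds, using the segmentation_models_pytorch API directly (the encoder/decoder names mirror the config above; `in_channels` and `classes` are illustrative assumptions):

```python
import segmentation_models_pytorch as smp

# UNet decoder with a ResNet-101 encoder, mirroring the config above.
model = smp.Unet(
    encoder_name="resnet101",
    encoder_weights="imagenet",  # or None to train from scratch
    in_channels=3,               # RGB colonoscopy images (assumption)
    classes=1,                   # binary polyp mask (assumption)
)
```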

Knowledge Distillation

Currently, only the original knowledge distillation method proposed by Geoffrey Hinton et al.[2] is supported.
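
For reference, a minimal sketch of the Hinton-style distillation loss (the temperature `T`, weight `alpha`, and exact reduction are illustrative assumptions, not necessarily this repo's settings):

```python
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, targets, T=4.0, alpha=0.5):
    # Soft loss: KL divergence between temperature-softened class distributions,
    # computed per pixel (dim=1 is the class dimension) and scaled by T^2
    # so its gradient magnitude stays comparable to the hard loss.
    soft = F.kl_div(
        F.log_softmax(student_logits / T, dim=1),
        F.softmax(teacher_logits / T, dim=1),
        reduction="batchmean",
    ) * (T * T)
    # Hard loss: ordinary cross-entropy against ground-truth masks (long dtype).
    hard = F.cross_entropy(student_logits, targets)
    return alpha * soft + (1 - alpha) * hard
```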

Hyperparameter Optimization

This repo also supports hyperparameter optimization using Optuna[3]. For example, if you have enough computing power and want to search hyperparameters for UNet, you may simply run

```bash
CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 optuna_search.py
```
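
Under the hood, an Optuna search boils down to an objective function that trains once per trial and returns a validation metric. A minimal sketch (the search space and the `train_and_validate` stub are hypothetical; the repo's actual searchable parameters live in its config and optuna_search.py):

```python
import optuna

def train_and_validate(lr, optimizer_name):
    # Hypothetical stand-in for one full training run; replace with the
    # repo's training loop and return the validation Dice score.
    return 0.0

def objective(trial):
    # Illustrative search space, not necessarily the repo's exact one.
    lr = trial.suggest_float("lr", 1e-5, 1e-2, log=True)
    optimizer_name = trial.suggest_categorical("optimizer", ["adam", "sgd"])
    return train_and_validate(lr, optimizer_name)

study = optuna.create_study(direction="maximize")  # maximize validation Dice
study.optimize(objective, n_trials=100)
print(study.best_params)
```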

How to use

DDP training (recommended)

```bash
CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 main.py
```

DP training

```bash
CUDA_VISIBLE_DEVICES=0,1,2,3 python main.py
```
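
For context, a minimal sketch of the DDP setup that main.py presumably performs (with torch 1.8, torch.distributed.launch passes `--local_rank` to each process; the Conv2d is a stand-in for a real segmentation model):

```python
import argparse
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# torch.distributed.launch passes --local_rank to every spawned process.
parser = argparse.ArgumentParser()
parser.add_argument("--local_rank", type=int, default=0)
args = parser.parse_args()

# MASTER_ADDR, MASTER_PORT, etc. are set by the launcher.
dist.init_process_group(backend="nccl")
torch.cuda.set_device(args.local_rank)

model = torch.nn.Conv2d(3, 1, 3).cuda(args.local_rank)  # stand-in for a segmentation model
model = DDP(model, device_ids=[args.local_rank])
# Training then proceeds as usual; remember to wrap the dataset in a
# DistributedSampler so each rank sees a distinct shard.
```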

Performance and checkpoints

Kvasir

| Model | Pretrained | Params (M) | Dice (paper)⁴ | mIoU (paper)⁴ | Dice (mine)⁵ val/test | mIoU (mine)⁵ val/test |
|---|---|---|---|---|---|---|
| DUCKNet-17 | None | 40.1 | 0.9343 | 0.8769 | 0.9227/0.9515 | 0.8612/0.9094 |
| DUCKNet-34 | None | 160.28 | 0.9502 | 0.9051 | 0.9230/0.9573 | 0.8618/0.9198 |
| ResUNet-32¹ | None | 2.06 | n.a. | n.a. | 0.8902/0.9113 | 0.8107/0.8426 |
| ResUNet++ | None | 3.93 | n.a. | n.a. | 0.9238/0.9369 | 0.8629/0.8843 |
| UNet-32² | None | 8.63 | 0.8655 | 0.7629 | 0.9107/0.9444 | 0.8421/0.8972 |
| UNet++ | None | 8.53 | n.a. | n.a. | 0.9235/0.9551⁶ | 0.8625/0.9157⁶ |
| smp-UNet³ | None | 14.33 | n.a. | n.a. | 0.9267/0.9556 | 0.8677/0.9166 |
| smp-UNet | ImageNet | 14.33 | n.a. | n.a. | 0.9341/0.9543 | 0.8799/0.9144 |

Notes:

1. ResUNet-32 stands for ResUNet with base_channel=32. Please note that the original base_channel of ResUNet is 64; I used base_channel=32 in my experiments to reduce computational complexity and VRAM usage.
2. UNet-32 stands for vanilla UNet with base_channel=32. Please note that the original base_channel of UNet is 64; I used base_channel=32 in my experiments to reduce computational complexity and VRAM usage.
3. smp-UNet stands for an SMP model with a UNet decoder. For simplicity, ResNet-18 was chosen as the encoder in my experiments.
4. These results were taken from DUCKNet's paper.
5. These results were obtained by training for 400 epochs with crop size 320x320.
6. These results were obtained using auxiliary heads.

CVC-ClinicDB

| Model | Pretrained | Params (M) | Dice (paper) | mIoU (paper) | Dice (mine) val/test | mIoU (mine) val/test |
|---|---|---|---|---|---|---|
| DUCKNet-17 | None | 40.1 | 0.9450 | 0.8952 | 0.9607/0.9434 | 0.9262/0.8966 |
| DUCKNet-34 | None | 160.28 | 0.9478 | 0.9009 | 0.9674/0.9451 | 0.9382/0.8995 |
| ResUNet-32 | None | 2.06 | n.a. | n.a. | 0.9178/0.9021 | 0.8554/0.8319 |
| ResUNet++ | None | 3.93 | n.a. | n.a. | 0.9564/0.9424 | 0.9187/0.8949 |
| UNet-32 | None | 8.63 | 0.7631 | 0.6169 | 0.9537/0.9424 | 0.9140/0.8949 |
| UNet++ | None | 8.53 | n.a. | n.a. | 0.9581/0.9520 | 0.9217/0.9112 |
| smp-UNet | None | 14.33 | n.a. | n.a. | 0.9662/0.9503 | 0.9361/0.9082 |
| smp-UNet | ImageNet | 14.33 | n.a. | n.a. | 0.9737/0.9566 | 0.9497/0.9190 |

CVC-ColonDB

| Model | Pretrained | Params (M) | Dice (paper) | mIoU (paper) | Dice (mine) val/test | mIoU (mine) val/test |
|---|---|---|---|---|---|---|
| DUCKNet-17 | None | 40.1 | 0.9353 | 0.8785 | 0.9432/0.9357 | 0.8968/0.8847 |
| DUCKNet-34 | None | 160.28 | 0.9230 | 0.8571 | 0.9390/0.9322 | 0.8899/0.8790 |
| ResUNet-32 | None | 2.06 | n.a. | n.a. | 0.7017/0.7551 | 0.6112/0.6600 |
| ResUNet++ | None | 3.93 | n.a. | n.a. | 0.9060/0.8979 | 0.8390/0.8273 |
| UNet-32 | None | 8.63 | 0.8032 | 0.7037 | 0.9125/0.8966 | 0.8486/0.8255 |
| UNet++ | None | 8.53 | n.a. | n.a. | 0.9296/0.9141 | 0.8748/0.8510 |
| smp-UNet | None | 14.33 | n.a. | n.a. | 0.9545/0.9498 | 0.9157/0.9078 |
| smp-UNet | ImageNet | 14.33 | n.a. | n.a. | 0.9676/0.9658 | 0.9388/0.9356 |

ETIS-LaribPolypDB

| Model | Pretrained | Params (M) | Dice (paper) | mIoU (paper) | Dice (mine) val/test | mIoU (mine) val/test |
|---|---|---|---|---|---|---|
| DUCKNet-17 | None | 40.1 | 0.9324 | 0.8734 | 0.8939/0.9013 | 0.8223/0.8323 |
| DUCKNet-34 | None | 160.28 | 0.9354 | 0.8788 | 0.8805/0.8884 | 0.8040/0.8142 |
| ResUNet-32 | None | 2.06 | n.a. | n.a. | 0.6427/0.6524 | 0.5710/0.5748 |
| ResUNet++ | None | 3.93 | n.a. | n.a. | 0.8739/0.8352 | 0.7952/0.7460 |
| UNet-32 | None | 8.63 | 0.7984 | 0.6969 | 0.8294/0.8218 | 0.7403/0.7296 |
| UNet++ | None | 8.53 | n.a. | n.a. | 0.9188/0.8730 | 0.8587/0.7929 |
| smp-UNet | None | 14.33 | n.a. | n.a. | 0.9386/0.8997 | 0.8896/0.8300 |
| smp-UNet | ImageNet | 14.33 | n.a. | n.a. | 0.9740/0.9706 | 0.9504/0.9442 |

Knowledge Distillation

| Dataset | Model | kd_training | Dice val/test | mIoU val/test |
|---|---|---|---|---|
| Kvasir | smp-UNet | teacher | 0.9341/0.9543 | 0.8799/0.9144 |
| Kvasir | UNet-32 | False | 0.9107/0.9444 | 0.8421/0.8972 |
| Kvasir | UNet-32 | True | 0.9144/0.9458 | 0.8478/0.8995 |
| CVC-ClinicDB | smp-UNet | teacher | 0.9737/0.9566 | 0.9497/0.9190 |
| CVC-ClinicDB | UNet-32 | False | 0.9537/0.9424 | 0.9140/0.8949 |
| CVC-ClinicDB | UNet-32 | True | 0.9570/0.9468 | 0.9197/0.9023 |
| CVC-ColonDB | smp-UNet | teacher | 0.9676/0.9658 | 0.9388/0.9356 |
| CVC-ColonDB | UNet-32 | False | 0.9125/0.8966 | 0.8486/0.8255 |
| CVC-ColonDB | UNet-32 | True | 0.9289/0.9131 | 0.8738/0.8496 |
| ETIS-LaribPolypDB | smp-UNet | teacher | 0.9740/0.9706 | 0.9504/0.9442 |
| ETIS-LaribPolypDB | UNet-32 | False | 0.8294/0.8218 | 0.7403/0.7296 |
| ETIS-LaribPolypDB | UNet-32 | True | 0.8547/0.7988 | 0.7706/0.7023 |

Hyperparameter Optimization

| Model | Dataset | Dice (paper) | Dice (random) val/test | Dice (Optuna) val/test |
|---|---|---|---|---|
| ResUNet-32 | Kvasir | n.a. | 0.8902/0.9113 | 0.8933/0.9133 |
| ResUNet-32 | CVC-ClinicDB | n.a. | 0.9178/0.9021 | 0.9490/0.9383 |
| ResUNet-32 | CVC-ColonDB | n.a. | 0.7017/0.7551 | 0.9266/0.9263 |
| ResUNet-32 | ETIS-LaribPolypDB | n.a. | 0.6427/0.6524 | 0.8500/0.8725 |
| ResUNet++ | Kvasir | n.a. | 0.9238/0.9369 | 0.9248/0.9304 |
| ResUNet++ | CVC-ClinicDB | n.a. | 0.9564/0.9424 | 0.9683/0.9685 |
| ResUNet++ | CVC-ColonDB | n.a. | 0.9060/0.8979 | 0.9487/0.9407 |
| ResUNet++ | ETIS-LaribPolypDB | n.a. | 0.8739/0.8352 | 0.9231/0.8905 |
| UNet-32 | Kvasir | 0.8655 | 0.9107/0.9444 | 0.9235/0.9483 |
| UNet-32 | CVC-ClinicDB | 0.7631 | 0.9537/0.9424 | 0.9679/0.9601 |
| UNet-32 | CVC-ColonDB | 0.8032 | 0.9125/0.8966 | 0.9601/0.9529 |
| UNet-32 | ETIS-LaribPolypDB | 0.7984 | 0.8294/0.8218 | 0.9653/0.9411 |
| UNet++ | Kvasir | n.a. | 0.9235/0.9551 | 0.9302/0.9462 |
| UNet++ | CVC-ClinicDB | n.a. | 0.9581/0.9520 | 0.9733/0.9682 |
| UNet++ | CVC-ColonDB | n.a. | 0.9296/0.9141 | 0.9621/0.9511 |
| UNet++ | ETIS-LaribPolypDB | n.a. | 0.9188/0.8730 | 0.9687/0.9549 |
| smp-UNet | Kvasir | n.a. | 0.9341/0.9543 | 0.9368/0.9624 |
| smp-UNet | CVC-ClinicDB | n.a. | 0.9737/0.9566 | 0.9769/0.9727 |
| smp-UNet | CVC-ColonDB | n.a. | 0.9676/0.9658 | 0.9758/0.9634 |
| smp-UNet | ETIS-LaribPolypDB | n.a. | 0.9740/0.9706 | 0.9796/0.9700 |

Note: for random search, the hyperparameters were taken from the default config; for Optuna search, each experiment ran 100 trials.

Prepare the dataset

```text
PolypDataset/
├── Kvasir-SEG/
│   ├── train/
│   │   ├── images/
│   │   └── masks/
│   ├── validation/
│   └── test/
├── CVC-ClinicDB/
├── CVC-ColonDB/
└── ETIS-LaribPolypDB/
```
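
A minimal sketch of a Dataset matching this layout (the class and its arguments are hypothetical, not the repo's actual loader; masks are assumed to share file names with their images):

```python
from pathlib import Path

import cv2
from torch.utils.data import Dataset

class PolypDataset(Dataset):
    def __init__(self, root, dataset="Kvasir-SEG", split="train", transform=None):
        # e.g. PolypDataset/Kvasir-SEG/train/images and .../train/masks
        img_dir = Path(root) / dataset / split / "images"
        self.images = sorted(img_dir.glob("*"))
        self.masks = [p.parents[1] / "masks" / p.name for p in self.images]
        self.transform = transform  # e.g. an albumentations pipeline

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = cv2.cvtColor(cv2.imread(str(self.images[idx])), cv2.COLOR_BGR2RGB)
        mask = cv2.imread(str(self.masks[idx]), cv2.IMREAD_GRAYSCALE)
        if self.transform is not None:
            augmented = self.transform(image=image, mask=mask)
            image, mask = augmented["image"], augmented["mask"]
        return image, mask
```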

References

1. segmentation-models-pytorch
2. Distilling the Knowledge in a Neural Network
3. Optuna: A hyperparameter optimization framework
