Overall architecture of MDViT, which is trained on multi-domain data by optimizing two types of losses:
This is a PyTorch implementation of MDViT: Multi-domain Vision Transformer for Small Medical Image Segmentation Datasets (MICCAI 2023).
This repository also includes several models used for comparison: SwinUnet, UNETR, UTNet, TransFuse, DASE, and USE (see the paper for details on these models).
If you use this code in your research, please consider citing:
@inproceedings{du2023mdvit,
title={{MDViT}: Multi-domain Vision Transformer for Small Medical Image Segmentation Datasets},
author={Du, Siyi and Bayasi, Nourhan and Hamarneh, Ghassan and Garbi, Rafeef},
booktitle={26th International Conference on Medical Image Computing and Computer Assisted Intervention (MICCAI 2023)},
year={2023}
}
This code was implemented with Python 3.8.1, PyTorch v1.8.0, CUDA 11.1, and cuDNN 7.
conda create -n skinlesion python=3.8
conda activate skinlesion # activate the environment, then install all dependencies
cd MDViT/
conda install pytorch==1.11.0 torchvision==0.12.0 torchaudio==0.11.0 cudatoolkit=11.3 -c pytorch
# or see https://pytorch.org/get-started/previous-versions/ for the install command matching your CUDA version
pip install -r requirements.txt
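Because the versions listed above (PyTorch v1.8.0, CUDA 11.1) differ from those in the conda command (PyTorch 1.11.0, CUDA toolkit 11.3), it can help to check what was actually installed. Below is a minimal sketch using only the standard library's `importlib.metadata`; `installed_versions` is a hypothetical helper, not part of this repository.

```python
from importlib import metadata


def installed_versions(packages):
    """Return {package: version string, or None if not installed}."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            # The distribution is not present in the active environment.
            versions[name] = None
    return versions


if __name__ == "__main__":
    for pkg, ver in installed_versions(["torch", "torchvision", "numpy"]).items():
        print(f"{pkg:12s} {ver or 'not installed'}")
```

To also check which CUDA build PyTorch was compiled against, `python -c "import torch; print(torch.version.cuda, torch.cuda.is_available())"` can be run in the activated environment.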
- Run the following command to resize the original images to the same dimensions (512, 512) and store them as .npy files.
python Datasets/process_resize.py
- Use Datasets/create_meta.ipynb to create the CSV metadata files for each dataset.
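The two preparation steps above can be sketched end to end as follows. This is not the repository's `process_resize.py` or `create_meta.ipynb`: it assumes images and masks are already loaded as NumPy arrays, and the file-naming scheme (`<stem>_image.npy`, `<stem>_mask.npy`), the CSV columns (`image`, `mask`, `fold`), and the round-robin fold assignment are all hypothetical. It uses nearest-neighbour sampling, which keeps mask labels discrete; the actual script may interpolate RGB images differently.

```python
import csv
import os
import random

import numpy as np

TARGET_SIZE = (512, 512)  # (height, width), as stated above


def resize_nearest(arr, size=TARGET_SIZE):
    """Nearest-neighbour resize of an (H, W) or (H, W, C) array."""
    out_h, out_w = size
    in_h, in_w = arr.shape[:2]
    # Map every output row/column to a source index by floor division.
    rows = np.arange(out_h) * in_h // out_h
    cols = np.arange(out_w) * in_w // out_w
    return arr[rows[:, None], cols[None, :]]


def save_case(image, mask, out_dir, stem):
    """Resize one image/mask pair and store both as .npy files."""
    os.makedirs(out_dir, exist_ok=True)
    # Hypothetical naming scheme: <stem>_image.npy / <stem>_mask.npy
    np.save(os.path.join(out_dir, f"{stem}_image.npy"), resize_nearest(image))
    np.save(os.path.join(out_dir, f"{stem}_mask.npy"), resize_nearest(mask))


def write_meta_csv(npy_dir, csv_path, num_folds=5, seed=0):
    """List image/mask pairs in a CSV with a cross-validation fold index."""
    stems = []
    for name in os.listdir(npy_dir):
        if not name.endswith("_image.npy"):
            continue
        stem = name[: -len("_image.npy")]
        # Keep only cases whose mask file also exists.
        if os.path.exists(os.path.join(npy_dir, f"{stem}_mask.npy")):
            stems.append(stem)
    stems.sort()  # deterministic order before the seeded shuffle
    random.Random(seed).shuffle(stems)
    with open(csv_path, "w", newline="") as fh:
        writer = csv.writer(fh)
        writer.writerow(["image", "mask", "fold"])
        for i, stem in enumerate(stems):
            # Round-robin fold assignment (an assumption, not the notebook's rule).
            writer.writerow([f"{stem}_image.npy", f"{stem}_mask.npy", i % num_folds])
    return len(stems)
```

Reading the raw image files (for example with an imaging library) before calling `save_case` is omitted here to keep the sketch dependent on NumPy only.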
- MDViT
python -u multi_train_MDViT.py --exp_name test --config_yml Configs/multi_train_local.yml --model MDViT --batch_size 4 --adapt_method Sup --dataset isic2018 PH2 DMF SKD --k_fold 0
- TransFuse or TransFuse+DA
# TransFuse
python -u multi_train_TransFuse.py --exp_name test --config_yml Configs/multi_train_local.yml --model TransFuse --batch_size 4 --adapt_method False --dataset isic2018 PH2 DMF SKD --k_fold 0
# TransFuse+DA
python -u multi_train_TransFuse.py --exp_name test --config_yml Configs/multi_train_local.yml --model TransFuse_adapt --batch_size 4 --adapt_method Sup --dataset isic2018 PH2 DMF SKD --k_fold 0
- BASE or other models
# BASE BASE+DSN SwinUnet UTNet SwinUNETR BASE+DASE BASE+USE
python -u multi_train_TransFuse.py --exp_name test --config_yml Configs/multi_train_local.yml --model TransFuse --batch_size 4 --adapt_method False --dataset isic2018 PH2 DMF SKD --k_fold 0
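Each command above runs a single cross-validation fold, selected by `--k_fold`. To run every fold, a small driver can append `--k_fold` to the same arguments. The sketch below reuses the MDViT command and assumes five folds (0 to 4), a number this README does not state; `fold_commands` and `run_all` are hypothetical helpers, and with `dry_run=True` the driver only prints the commands.

```python
import shlex
import subprocess

# Arguments copied from the MDViT command above, minus --k_fold.
BASE_ARGS = [
    "python", "-u", "multi_train_MDViT.py",
    "--exp_name", "test",
    "--config_yml", "Configs/multi_train_local.yml",
    "--model", "MDViT",
    "--batch_size", "4",
    "--adapt_method", "Sup",
    "--dataset", "isic2018", "PH2", "DMF", "SKD",
]


def fold_commands(base_args, num_folds=5):
    """Return one argument list per fold (the fold count is an assumption)."""
    return [base_args + ["--k_fold", str(k)] for k in range(num_folds)]


def run_all(num_folds=5, dry_run=True):
    """Print each fold's command and, unless dry_run, execute it."""
    for cmd in fold_commands(BASE_ARGS, num_folds):
        print(shlex.join(cmd))
        if not dry_run:
            subprocess.run(cmd, check=True)  # stop at the first failing fold


if __name__ == "__main__":
    run_all(dry_run=True)
```

Whether runs sharing `--exp_name test` overwrite each other's outputs depends on the training script, so a per-fold experiment name may be safer.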