Official PyTorch/GPU implementation of the paper Masked Diffusion as Self-Supervised Representation Learner. This code is based on ddpm-segmentation.
- [April 12, 2024] The code and checkpoints are released.
- [March 7, 2024] Trained a better MDM on FFHQ.
The evaluation is conducted on two medical image segmentation datasets: GlaS and MoNuSeg, and two natural image segmentation datasets collected by ddpm-segmentation: FFHQ-34 and CelebA-19. We use FFHQ as the pre-training dataset for FFHQ-34 and CelebA-19 segmentation.
Before starting, we recommend creating a new conda environment:

```shell
conda env create -f environment.yml
```

Then, activate the environment:

```shell
conda activate masked_diffusion
```
We provide the pre-training settings in the `experiments` folders under `guided_diffusion` and `masked_diffusion`. For example, to pre-train MDM on MoNuSeg, run:

```shell
python masked_diffusion/experiments/MoNuSeg/Train.py
```
The DDPM model trained on FFHQ is adopted from ddpm-segmentation. We provide pre-trained DDPM and MDM models for the GlaS and MoNuSeg datasets, and a pre-trained MDM model for FFHQ. All checkpoints are available on Google Drive.
Before fine-tuning, please download the pre-trained models and place them in the corresponding folders. Then, revise the JSON file in the ./experiments folder and change the dataset name in the script file. Finally, run:

```shell
bash scripts/mdm_glas_monuseg.sh
```
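The exact schema of the experiment JSON depends on the repository; as a minimal sketch of the "change the dataset name" step above (assuming a top-level `"dataset"` key, which is a guess and should be checked against the real config file), the edit could also be done programmatically:

```python
import json
from pathlib import Path


def set_dataset(config_path: str, dataset_name: str) -> dict:
    """Load an experiment config, update its dataset entry, and write it back.

    The "dataset" key is a hypothetical field name; inspect the JSON files
    in ./experiments for the actual schema before relying on this.
    """
    path = Path(config_path)
    config = json.loads(path.read_text())
    config["dataset"] = dataset_name  # assumed key name
    path.write_text(json.dumps(config, indent=2))
    return config
```

For example, `set_dataset("experiments/glas.json", "MoNuSeg")` would switch the target dataset while leaving the remaining hyperparameters untouched.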
Performance in terms of Dice, IoU, and AJI evaluated on GlaS and MoNuSeg:
Performance in terms of mean IoU evaluated on FFHQ-34 and CelebA-19:

If you find this repository useful, please use the following BibTeX entry for citation.
```bibtex
@misc{pan2023masked,
      title={Masked Diffusion as Self-supervised Representation Learner},
      author={Zixuan Pan and Jianxu Chen and Yiyu Shi},
      year={2023},
      eprint={2308.05695},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}
```