Contour-Guided Diffusion Models for Unpaired Image-to-Image Translation
By Yuwen Chen, Nicholas Konz, Hanxue Gu, Haoyu Dong, Yaqian Chen, Lin Li, Jisoo Lee and Maciej Mazurowski
This is the code for our paper ContourDiff: Unpaired Medical Image Translation with Structural Consistency. ContourDiff is a novel framework that leverages domain-invariant anatomical contour representations of images to enable unpaired translation between different domains.
Our method can:
- Enforce precise anatomical consistency even between modalities with severe structural biases (see the example figure below)
- Potentially translate images from arbitrary unseen input domains (i.e., train once, translate any)
Many thanks to Segmentation-guided Diffusion for the inspiration and the code backbone!
Please cite our paper if you use our code or reference our work:
@article{chen2024contourdiff,
title={ContourDiff: Unpaired Image Translation with Contour-Guided Diffusion Models},
author={Chen, Yuwen and Konz, Nicholas and Gu, Hanxue and Dong, Haoyu and Chen, Yaqian and Li, Lin and Lee, Jisoo and Mazurowski, Maciej A},
journal={arXiv preprint arXiv:2403.10786},
year={2024}
}
Please follow the steps below to train your own ContourDiff model!
To extract the contours, run the following command:
python preprocess.py \
--data_directory {DATA_DIRECTORY} \
--domain_img_folder {DOMAIN_IMG_FOLDER} \
--domain_contour_folder {DOMAIN_CONTOUR_FOLDER} \
--domain_meta_path {DOMAIN_META_PATH}
where:
- DATA_DIRECTORY is the directory containing data from multiple domains
- DOMAIN_IMG_FOLDER is the path to the images of a given domain under DATA_DIRECTORY
- DOMAIN_CONTOUR_FOLDER is the path under DATA_DIRECTORY where the extracted contours are saved
- DOMAIN_META_PATH is the path (*.csv) under DATA_DIRECTORY where the meta information is saved
To enable removal of non-anatomical background artifacts, add the --remove_artifact flag.
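The following sketch shows what a contour representation of a slice can look like. It turns a grayscale image into a binary edge map using simple gradient-magnitude thresholding with NumPy. The function name extract_contour and the threshold value are hypothetical. preprocess.py may use a different edge detector and post-processing, including the --remove_artifact step.

```python
import numpy as np


def extract_contour(img: np.ndarray, threshold: float = 0.1) -> np.ndarray:
    """Hypothetical contour extractor based on gradient-magnitude thresholding.

    Returns a uint8 edge map with values 0 (background) and 255 (contour).
    This is NOT necessarily the detector used by preprocess.py.
    """
    img = img.astype(np.float64)
    lo, hi = img.min(), img.max()
    if hi == lo:
        # A constant image has no edges.
        return np.zeros(img.shape, dtype=np.uint8)
    img = (img - lo) / (hi - lo)   # normalize intensities to [0, 1]
    gy, gx = np.gradient(img)      # finite-difference gradients along rows and columns
    magnitude = np.hypot(gx, gy)
    return np.where(magnitude > threshold, 255, 0).astype(np.uint8)


if __name__ == "__main__":
    # Toy "anatomy": a bright square on a dark background.
    slice_2d = np.zeros((8, 8))
    slice_2d[2:6, 2:6] = 1.0
    print(extract_contour(slice_2d))
```

Because only intensity changes survive, the same structure imaged with different contrasts yields similar edge maps. This is the intuition behind using contours as a domain-invariant representation.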
For example, given the data structure below:
DATA_DIRECTORY
├── domain_1
│ ├── images
│ │ ├── img_1.png
│ │ ├── img_2.png
│ │ └── ...
├── domain_2
│ ├── images
│ │ ├── img_1.png
│ │ ├── img_2.png
│ │ └── ...
└── ...
To extract contours for the images of domain 1, set DOMAIN_IMG_FOLDER="domain_1/images". Then, with DOMAIN_CONTOUR_FOLDER="domain_1/contours" and DOMAIN_META_PATH="domain_1/df_meta.csv", the resulting data structure is:
DATA_DIRECTORY
├── domain_1
│ ├── images
│ │ ├── img_1.png
│ │ ├── img_2.png
│ │ └── ...
│ ├── contours
│ │ ├── img_1.png
│ │ ├── img_2.png
│ │ └── ...
│ ├── df_meta.csv
├── domain_2
│ ├── images
│ │ ├── img_1.png
│ │ ├── img_2.png
│ │ └── ...
└── ...
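In the layout above, each contour file keeps the filename of its source image and is placed under DOMAIN_CONTOUR_FOLDER instead. The standard-library sketch below encodes that mapping, which can help check that preprocessing covered every image. The helper names and the *.png filter are assumptions based on the example tree. They are not part of the repository.

```python
from pathlib import Path
from typing import Dict, List


def expected_contour_paths(data_directory: str, img_folder: str,
                           contour_folder: str) -> Dict[Path, Path]:
    """Hypothetical helper: map each image under DATA_DIRECTORY/img_folder to the
    contour file with the same name under DATA_DIRECTORY/contour_folder."""
    root = Path(data_directory)
    img_dir, contour_dir = root / img_folder, root / contour_folder
    return {img: contour_dir / img.name for img in sorted(img_dir.glob("*.png"))}


def missing_contours(data_directory: str, img_folder: str,
                     contour_folder: str) -> List[Path]:
    """Return the images whose expected contour file does not exist yet."""
    mapping = expected_contour_paths(data_directory, img_folder, contour_folder)
    return [img for img, contour in mapping.items() if not contour.exists()]


if __name__ == "__main__":
    for img in missing_contours("DATA_DIRECTORY", "domain_1/images", "domain_1/contours"):
        print("no contour for", img)
```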
To visualize the extracted contours, run contour_checker.ipynb.
To train your own ContourDiff model, run the following command:
CUDA_VISIBLE_DEVICES=0,1,2 python3 train.py \
--input_domain {INPUT_DOMAIN} \
--output_domain {OUTPUT_DOMAIN} \
--data_directory {DATA_DIRECTORY} \
--input_domain_img_folder {INPUT_DOMAIN_IMG_FOLDER} \
--input_domain_contour_folder {INPUT_DOMAIN_CONTOUR_FOLDER} \
--output_domain_img_folder {OUTPUT_DOMAIN_IMG_FOLDER} \
--output_domain_contour_folder {OUTPUT_DOMAIN_CONTOUR_FOLDER} \
--input_domain_meta_path {INPUT_DOMAIN_META_PATH} \
--output_domain_meta_path {OUTPUT_DOMAIN_META_PATH} \
--output_dir {OUTPUT_DIR} \
--contour_guided \
--near_guided \
--near_guided_ratio {NEAR_GUIDED_RATIO}
where:
- INPUT_DOMAIN is the string name of the input domain (e.g. any, CT or MRI)
- OUTPUT_DOMAIN is the string name of the output domain (e.g. CT or MRI)
- DATA_DIRECTORY is the directory containing data from multiple domains
- INPUT_DOMAIN_IMG_FOLDER is the path to the input domain images under DATA_DIRECTORY
- INPUT_DOMAIN_CONTOUR_FOLDER is the path to the input domain contours under DATA_DIRECTORY
- OUTPUT_DOMAIN_IMG_FOLDER is the path to the output domain images under DATA_DIRECTORY
- OUTPUT_DOMAIN_CONTOUR_FOLDER is the path to the output domain contours under DATA_DIRECTORY
- INPUT_DOMAIN_META_PATH is the path (*.csv) to the input domain meta file under DATA_DIRECTORY
- OUTPUT_DOMAIN_META_PATH is the path (*.csv) to the output domain meta file under DATA_DIRECTORY
- OUTPUT_DIR is the absolute path for saving output results, including model checkpoints and visualization samples
- --contour_guided is a flag that enables the contour-guided mode of the diffusion model
- --near_guided is a flag that enables the adjacent-slice-guided mode
- NEAR_GUIDED_RATIO is the ratio at which the adjacent slice is provided as guidance
Notice: Input domain images and contours are used for validation in the training phase.
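One plausible reading of --near_guided_ratio is as follows. During training, the adjacent slice is supplied as extra conditioning for only that fraction of samples. The model therefore also learns to generate a slice from its contour alone, as needed for the first slice of a volume. The sketch below illustrates only this interpretation. The function name build_condition, the zero-filled placeholder and the two-channel layout are assumptions, not the logic in train.py.

```python
import random
from typing import List, Optional

Image = List[List[float]]  # a 2D slice as nested lists, for illustration only


def build_condition(contour: Image, adjacent: Optional[Image],
                    near_guided_ratio: float, rng: random.Random) -> List[Image]:
    """Hypothetical conditioning input: [contour, adjacent slice or zeros]."""
    use_adjacent = adjacent is not None and rng.random() < near_guided_ratio
    if use_adjacent:
        neighbor = adjacent
    else:
        # A zero placeholder keeps the number of conditioning channels fixed.
        neighbor = [[0.0] * len(row) for row in contour]
    return [contour, neighbor]


if __name__ == "__main__":
    rng = random.Random(0)
    contour = [[0.0, 1.0], [1.0, 0.0]]
    adjacent = [[0.2, 0.4], [0.6, 0.8]]
    hits = sum(build_condition(contour, adjacent, 0.5, rng)[1] is adjacent
               for _ in range(1000))
    print(f"adjacent slice used in {hits}/1000 samples")
```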
To translate input domain images with your trained ContourDiff model in the 2D setting, run the following command:
python translation_2d.py \
--input_domain {INPUT_DOMAIN} \
--output_domain {OUTPUT_DOMAIN} \
--data_directory {DATA_DIRECTORY} \
--input_domain_contour_folder {INPUT_DOMAIN_CONTOUR_FOLDER} \
--input_domain_meta_path {INPUT_DOMAIN_META_PATH} \
--num_copy {NUM_COPY} \
--by_volume \
--volume_specifier {VOLUME_SPECIFIER} \
--slice_specifier {SLICE_SPECIFIER} \
--selected_epoch {SELECTED_EPOCH} \
--translating_folder_name {TRANSLATING_FOLDER_NAME} \
--device {DEVICE} \
--num_partition {NUM_PARTITION} \
--partition {PARTITION}
where:
- INPUT_DOMAIN is the string name of the input domain (e.g. any, CT or MRI)
- OUTPUT_DOMAIN is the string name of the output domain (e.g. CT or MRI)
- DATA_DIRECTORY is the directory containing data from multiple domains
- INPUT_DOMAIN_CONTOUR_FOLDER is the path to the input domain contours under DATA_DIRECTORY
- INPUT_DOMAIN_META_PATH is the path (*.csv) to the input domain meta file under DATA_DIRECTORY
- OUTPUT_DIR is the absolute path for saving output results, including model checkpoints and visualization samples
- NUM_COPY is the number of samples generated in each iteration
- --by_volume is a flag that enables slice-by-slice generation within each volume
- VOLUME_SPECIFIER is the name of the meta-file column that identifies each volume (e.g., "volume")
- SLICE_SPECIFIER is the name of the meta-file column that gives the slice number (e.g., "slice")
- SELECTED_EPOCH is the epoch of the checkpoint to load
- TRANSLATING_FOLDER_NAME is the absolute path for storing the translated images
- DEVICE is the GPU device
- NUM_PARTITION is the total number of partitions into which the input domain units (either slices or volumes) are split
- PARTITION is the partition to translate
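With --by_volume, slices are generated in order within each volume. This is why the meta file needs the VOLUME_SPECIFIER and SLICE_SPECIFIER columns. The standard-library sketch below shows such a grouping step. The helper slices_by_volume and the "image" column are hypothetical. The "volume" and "slice" defaults follow the example column names above.

```python
import csv
import io
from collections import defaultdict
from typing import Dict, List


def slices_by_volume(meta_csv: str, volume_specifier: str = "volume",
                     slice_specifier: str = "slice") -> Dict[str, List[dict]]:
    """Group meta-file rows by volume and sort each group by slice number."""
    groups: Dict[str, List[dict]] = defaultdict(list)
    for row in csv.DictReader(io.StringIO(meta_csv)):
        groups[row[volume_specifier]].append(row)
    for rows in groups.values():
        # CSV values are text, so compare slice numbers as integers ("10" > "2").
        rows.sort(key=lambda r: int(r[slice_specifier]))
    return dict(groups)


if __name__ == "__main__":
    meta = "image,volume,slice\nb.png,vol_1,2\na.png,vol_1,1\nc.png,vol_2,1\n"
    for volume, rows in slices_by_volume(meta).items():
        print(volume, [r["image"] for r in rows])
```

To use a real meta file, pass its contents, e.g. Path(INPUT_DOMAIN_META_PATH).read_text().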
To translate input domain images with your trained ContourDiff model in the 3D setting, run the following command:
python translation_3d.py \
--input_domain {INPUT_DOMAIN} \
--output_domain {OUTPUT_DOMAIN} \
--data_directory {DATA_DIRECTORY} \
--input_domain_contour_folder {INPUT_DOMAIN_CONTOUR_FOLDER} \
--input_domain_meta_path {INPUT_DOMAIN_META_PATH} \
--num_copy {NUM_COPY} \
--by_volume \
--volume_specifier {VOLUME_SPECIFIER} \
--slice_specifier {SLICE_SPECIFIER} \
--selected_epoch {SELECTED_EPOCH} \
--translating_folder_name {TRANSLATING_FOLDER_NAME} \
--device {DEVICE} \
--num_partition {NUM_PARTITION} \
--partition {PARTITION} \
--near_guided
where:
- INPUT_DOMAIN is the string name of the input domain (e.g. any, CT or MRI)
- OUTPUT_DOMAIN is the string name of the output domain (e.g. CT or MRI)
- DATA_DIRECTORY is the directory containing data from multiple domains
- INPUT_DOMAIN_CONTOUR_FOLDER is the path to the input domain contours under DATA_DIRECTORY
- INPUT_DOMAIN_META_PATH is the path (*.csv) to the input domain meta file under DATA_DIRECTORY
- OUTPUT_DIR is the absolute path for saving output results, including model checkpoints and visualization samples
- NUM_COPY is the number of samples generated in each iteration; here it is also the number of candidates for the initial slice generation
- --by_volume is a flag that enables slice-by-slice generation within each volume
- VOLUME_SPECIFIER is the name of the meta-file column that identifies each volume (e.g., "volume")
- SLICE_SPECIFIER is the name of the meta-file column that gives the slice number (e.g., "slice")
- SELECTED_EPOCH is the epoch of the checkpoint to load
- TRANSLATING_FOLDER_NAME is the absolute path for storing the translated images
- DEVICE is the GPU device
- NUM_PARTITION is the total number of partitions into which the input domain units (either slices or volumes) are split
- PARTITION is the partition to translate
- --near_guided is a flag that enables the adjacent-slice-guided mode
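In the 3D setting, adjacent-slice guidance makes translation sequential. Each slice is conditioned on the previously translated neighbor, and NUM_COPY candidates are generated for the initial slice. The sketch below covers only that control flow. sample is a stand-in for the diffusion sampler. The selection rule for the initial candidates (a caller-supplied score, highest wins) is an assumption, and translation_3d.py may choose differently.

```python
import random
from typing import Callable, List, Optional, Sequence, TypeVar

C = TypeVar("C")  # contour representation of one slice
S = TypeVar("S")  # translated slice


def translate_volume(contours: Sequence[C],
                     sample: Callable[[C, Optional[S]], S],
                     score: Callable[[S], float],
                     num_copy: int) -> List[S]:
    """Hypothetical adjacent-slice-guided translation of one volume.

    `contours` must already be ordered by slice number; `sample(contour, neighbor)`
    stands in for one diffusion sampling run, with neighbor=None for the first slice.
    """
    if num_copy < 1:
        raise ValueError("num_copy must be >= 1")
    if not contours:
        return []
    # Initial slice: no translated neighbor exists yet, so draw num_copy
    # candidates and keep the one with the highest (assumed) score.
    candidates = [sample(contours[0], None) for _ in range(num_copy)]
    outputs = [max(candidates, key=score)]
    # Remaining slices: condition each one on the previously translated slice.
    for contour in contours[1:]:
        outputs.append(sample(contour, outputs[-1]))
    return outputs


if __name__ == "__main__":
    rng = random.Random(0)

    def toy_sample(contour: float, neighbor: Optional[float]) -> float:
        # Toy "slice": a noisy copy of the contour value, pulled toward the neighbor.
        value = contour + rng.uniform(-1.0, 1.0)
        return value if neighbor is None else (value + neighbor) / 2

    print(translate_volume([1.0, 2.0, 3.0], toy_sample,
                           score=lambda s: -abs(s - 1.0), num_copy=4))
```

The first slice matters most because errors there propagate through the conditioning chain. This makes spending extra samples on it a reasonable trade-off.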
Notice: VOLUME_SPECIFIER and SLICE_SPECIFIER are required to enable --by_volume translation, which means the meta file must include the corresponding columns. NUM_PARTITION and PARTITION are intended for parallel translation; PARTITION must be in the range [0, NUM_PARTITION - 1].
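Each PARTITION in [0, NUM_PARTITION - 1] handles a disjoint subset of units, so several translation processes can run side by side, e.g. one per GPU. The sketch below splits units into contiguous chunks and builds one translation_2d.py command per partition. The chunking rule is an assumption, and the scripts may split differently. The device strings are placeholders, so check the format --device expects. The other required flags are left out for brevity.

```python
import shlex
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def split_partition(units: Sequence[T], num_partition: int, partition: int) -> List[T]:
    """Hypothetical contiguous split; partition sizes differ by at most one unit."""
    if not 0 <= partition < num_partition:
        raise ValueError("partition must be in [0, num_partition - 1]")
    base, extra = divmod(len(units), num_partition)
    start = partition * base + min(partition, extra)
    end = start + base + (1 if partition < extra else 0)
    return list(units[start:end])


def partition_commands(num_partition: int, devices: Sequence[str]) -> List[str]:
    """One translation_2d.py command per partition, cycling over the given devices.
    Other required flags (domains, paths, epoch, ...) are omitted here."""
    return [
        shlex.join(["python", "translation_2d.py",
                    "--num_partition", str(num_partition),
                    "--partition", str(p),
                    "--device", devices[p % len(devices)]])
        for p in range(num_partition)
    ]


if __name__ == "__main__":
    print(split_partition(list(range(10)), 3, 0))  # → [0, 1, 2, 3]
    for cmd in partition_commands(3, ["cuda:0", "cuda:1"]):
        print(cmd)
```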
All code in this repository is released under the Apache 2.0 license.