This repository provides the code for our MICCAI 2023 paper "Punctate White Matter Lesion Segmentation in Preterm Infants Powered by Counterfactually Generative Learning".
The code is implemented in TensorFlow and builds on https://github.com/RicardoZiTseng/3D-MASNet.
As shown in the figure, DeepPWML consists of four modules: the tissue segmentation module (T-SEG), the classification module (CLS), the counterfactual map generator (CMG), and the PWML segmentation module (P-SEG). During training, T-SEG is learned on control data, while the other modules are learned on PWML data. Given an image patch as input, CLS is trained to classify it as positive (containing PWML) or negative (no PWML). Building on CLS, CMG is then trained to produce a counterfactual map that linearly manipulates the input so that the CLS prediction changes.
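The relationship between CLS and CMG can be illustrated with a toy example. This is a minimal NumPy sketch under stated assumptions, not the DeepPWML implementation: `toy_cls` is a hypothetical thresholding rule standing in for the trained classifier, `toy_counterfactual_map` is a hypothetical closed-form stand-in for the learned generator, and `THRESHOLD` and the synthetic bright spot are made up for illustration. It shows the counterfactual patch formed as the input plus an additive map (x + M), and that in this toy the nonzero voxels of a map that flips the prediction from positive to negative line up with the lesion.

```python
import numpy as np

# Hypothetical stand-in for CLS: a patch is "positive" (contains PWML) if any
# voxel is brighter than THRESHOLD. The real CLS is a trained network.
THRESHOLD = 0.8


def toy_cls(patch: np.ndarray) -> int:
    return int(patch.max() > THRESHOLD)


def toy_counterfactual_map(patch: np.ndarray) -> np.ndarray:
    # Hypothetical stand-in for CMG: an additive map that lowers every
    # over-threshold voxel to exactly THRESHOLD and leaves the rest untouched,
    # so the manipulated patch is classified negative.
    return np.minimum(THRESHOLD - patch, 0.0)


rng = np.random.default_rng(0)
patch = rng.uniform(0.0, 0.5, size=(8, 8, 8))  # synthetic background intensities
patch[3:5, 3:5, 3:5] = 1.0                     # synthetic bright 2x2x2 "lesion"

cmap = toy_counterfactual_map(patch)
counterfactual = patch + cmap                  # linear manipulation: x + M

print(toy_cls(patch), toy_cls(counterfactual))  # 1 0
print(np.count_nonzero(cmap))                   # 8
```

In this toy the map is nonzero only where the lesion is, which hints at why a counterfactual map carries localization information; the paper's CMG learns its map rather than computing it in closed form.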
The dataset contains a control group and a PWML group. Files are named as follows: each subject has a T1 image and a label image for segmentation, e.g. "subject-1-T1.nii" (the T1 image) and "subject-1-label.nii" (its label image). Control labels are background (0), CSF (1), GM (2) and WM (3); PWML labels are background (0) and PWML (1).
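The naming scheme and label values above can be checked before training. This is a minimal sketch, assuming files use exactly the `subject-<N>-T1.nii` / `subject-<N>-label.nii` pattern; `pair_subjects`, `check_labels` and the group keys `"control"` / `"pwml"` are hypothetical helpers, not part of the repository.

```python
import re

import numpy as np

# Label conventions from the dataset description.
CONTROL_LABELS = {0: "background", 1: "CSF", 2: "GM", 3: "WM"}
PWML_LABELS = {0: "background", 1: "PWML"}
LABELS_BY_GROUP = {"control": CONTROL_LABELS, "pwml": PWML_LABELS}

# Matches names such as "subject-1-T1.nii" and "subject-1-label.nii".
NAME_PATTERN = re.compile(r"^(subject-\d+)-(T1|label)\.nii$")


def pair_subjects(filenames):
    """Group file names into {subject_id: {"T1": name, "label": name}}."""
    pairs = {}
    for name in filenames:
        match = NAME_PATTERN.match(name)
        if match is None:
            continue  # skip files that do not follow the naming scheme
        subject, kind = match.groups()
        pairs.setdefault(subject, {})[kind] = name
    # Keep only subjects that have both a T1 image and a label image.
    return {s: p for s, p in pairs.items() if "T1" in p and "label" in p}


def check_labels(label_volume, group):
    """Return the label values found; raise ValueError on values outside the group's set."""
    allowed = set(LABELS_BY_GROUP[group])  # KeyError for an unknown group name
    found = set(np.unique(label_volume).astype(int).tolist())
    unexpected = found - allowed
    if unexpected:
        raise ValueError(f"unexpected labels for {group} data: {sorted(unexpected)}")
    return found
```

A label volume can be read with `nibabel.load(path).get_fdata()`, which returns a floating-point array that `check_labels` accepts directly.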
- python 3.8+
- tensorflow 2.4+
- Keras
- nibabel
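The requirements above can be verified with a short script. This is a minimal sketch using only the standard library; the distribution names in `REQUIREMENTS` are assumptions (for example, a GPU build may be installed as `tensorflow-gpu`), and `parse_version` is a simplified, hypothetical parser that only compares the leading numeric parts of a version string.

```python
# Importing importlib.metadata already requires Python 3.8+, the first requirement.
from importlib import metadata


def parse_version(version):
    """Turn "2.4.1" or "2.13.0rc1" into a comparable tuple such as (2, 13, 0)."""
    parts = []
    for piece in version.split(".")[:3]:
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        parts.append(int(digits) if digits else 0)
    return tuple(parts)


# Minimum versions from the requirement list; None means "any version".
REQUIREMENTS = {"tensorflow": (2, 4), "keras": None, "nibabel": None}


def check_environment():
    """Return a list of human-readable problems (empty if everything is found)."""
    problems = []
    for name, minimum in REQUIREMENTS.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            problems.append(f"{name} is not installed")
            continue
        if minimum is not None and parse_version(installed) < minimum:
            problems.append(f"{name} {installed} < {'.'.join(map(str, minimum))}")
    return problems


if __name__ == "__main__":
    for problem in check_environment():
        print("WARNING:", problem)
```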
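- Train T-SEG: `python train_T-SEG.py`
- Train CLS: `python train_CLS.py`
- Train CMG: `python train_CMG.py`
- Train P-SEG: `python train_P-SEG.py`
- Predict with DeepPWML: `python predict_DeepPWML.py`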
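The steps above can be chained with a small driver. This is a minimal sketch, assuming each script runs without command-line arguments from the repository root; `run_pipeline` and its `dry_run` flag are hypothetical conveniences, not part of the repository. The order matters at least for CLS and CMG, since CMG is trained on top of CLS.

```python
import subprocess
import sys

# Script order follows the steps listed above.
STAGES = [
    "train_T-SEG.py",
    "train_CLS.py",
    "train_CMG.py",
    "train_P-SEG.py",
    "predict_DeepPWML.py",
]


def run_pipeline(python=sys.executable, dry_run=False):
    """Run each stage in order and stop at the first failure.

    With dry_run=True the commands are only returned, not executed.
    """
    commands = [[python, script] for script in STAGES]
    if dry_run:
        return commands
    for command in commands:
        print("Running:", " ".join(command))
        # check=True raises subprocess.CalledProcessError if a stage fails,
        # so later stages never run on missing or partial models.
        subprocess.run(command, check=True)
    return commands
```

Call `run_pipeline()` from the repository root to execute every stage, or `run_pipeline(dry_run=True)` to inspect the commands first.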