This is a Chainer implementation of Bayesian Convolutional Neural Networks.
(Keras and PyTorch re-implementations are also available: keras_bayesian_unet, pytorch_bayesian_unet)
In this project, we assume the following two scenarios, especially for medical imaging.
- Two-dimensional segmentation / regression with the 2D U-Net. (e.g., 2D X-ray images, laparoscopic images, and CT slices)
- Three-dimensional segmentation / regression with the 3D U-Net. (e.g., 3D CT volumes)
This project is part of the following works.
@article{hiasa2019automated,
title={Automated muscle segmentation from clinical CT using Bayesian U-net for personalized musculoskeletal Modeling},
author={Hiasa, Yuta and Otake, Yoshito and Takao, Masaki and Ogawa, Takeshi and Sugano, Nobuhiko and Sato, Yoshinobu},
journal={IEEE Transactions on Medical Imaging},
volume={39},
number={4},
pages={1030--1040},
year={2019},
publisher={IEEE},
doi={10.1109/TMI.2019.2940555}
}
@article{sakamoto2020bayesian,
title={Bayesian segmentation of hip and thigh muscles in metal artifact-contaminated CT using convolutional neural network-enhanced normalized metal artifact reduction},
author={Sakamoto, Mitsuki and Hiasa, Yuta and Otake, Yoshito and Takao, Masaki and Suzuki, Yuki and Sugano, Nobuhiko and Sato, Yoshinobu},
journal={Journal of Signal Processing Systems},
volume={92},
number={3},
pages={335--344},
year={2020},
publisher={Springer},
doi={10.1007/s11265-019-01507-z}
}
@inproceedings{hiasa2018surgical,
title={Surgical tools segmentation in laparoscopic images using convolutional neural networks with uncertainty estimation and semi-supervised learning},
author={Hiasa, Y and Otake, Y and Nakatani, S and Harada, H and Kanaji, S and Kakeji, Y and Sato, Yoshinobu},
booktitle={Proc. International Conference of Computer Assisted Radiology and Surgery},
pages={14--15},
year={2018}
}
- Python 3
- CPU or NVIDIA GPU + CUDA and cuDNN
- Chainer 6.5
- Install Chainer and dependencies from https://chainer.org/
- For other requirements, see requirements.txt.
- Install from this repository
git clone https://github.com/yuta-hi/bayesian_unet
cd bayesian_unet
python setup.py install
The data sets we used are medical images, which are difficult to share due to ethical restrictions. We therefore prepared the following examples using synthetic or public data sets.
Approximation of a function (curve regression).
python examples/curve_regression/train_and_test_epistemic.py
python examples/curve_regression/train_and_test_epistemic.py --test_on_test
python examples/curve_regression/train_and_test_epistemic_aleatoric.py
python examples/curve_regression/train_and_test_epistemic_aleatoric.py --test_on_test
Ten-digit classification (MNIST). A subset of the samples is used for training. In the default setting, 1,000 samples are used for training and 1,000 samples for validation. The distributions of predicted variance for correct and wrong predictions on the test data set (10,000 samples) are visualized.
python examples/mnist_classification/train_and_test_epistemic.py
python examples/mnist_classification/train_and_test_epistemic.py --test_on_test
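The comparison described for this example, predicted variance for correct versus wrong predictions, can be sketched roughly as below. This is only an illustration with toy numbers, not the code or results of the example script; the helper name `split_variance_by_correctness` is hypothetical.

```python
import numpy as np

def split_variance_by_correctness(pred_labels, true_labels, pred_variance):
    """Return (variances of correct predictions, variances of wrong predictions)."""
    pred_labels = np.asarray(pred_labels)
    true_labels = np.asarray(true_labels)
    pred_variance = np.asarray(pred_variance, dtype=np.float64)
    correct = pred_labels == true_labels
    return pred_variance[correct], pred_variance[~correct]

# Toy values for illustration only (not outputs of the MNIST example).
pred = [3, 1, 4, 1, 5]
true = [3, 1, 4, 7, 9]
var = [0.01, 0.02, 0.03, 0.20, 0.30]

var_correct, var_wrong = split_variance_by_correctness(pred, true, var)
print('mean variance (correct):', var_correct.mean())
print('mean variance (wrong):  ', var_wrong.mean())
```

The two returned arrays can then be passed to any histogram routine to compare the distributions.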
Segmentation of surgical instruments from laparoscopic images. The data set is downloaded from https://endovissub-instrument.grand-challenge.org/ . The training and test data sets consist of 160 and 140 images, respectively.
python examples/miccai_endovis_segmentation/preprocess.py # download the dataset and convert label format
python examples/miccai_endovis_segmentation/train_and_test_epistemic.py
python examples/miccai_endovis_segmentation/train_and_test_epistemic.py --test_on_test
Aerial-to-map translation. This example focuses on how adversarial training affects uncertainty behavior. It mainly follows the previous work [P. Isola, et al.]. In this example, the generator is replaced with the Bayesian U-Net to obtain uncertainty estimates, and spectral normalization [T. Miyato et al.] is applied to the patch discriminator to stabilize the optimization.
cd examples/map_synthesis
python preprocess.py # download and normalize the dataset
python train_and_test_baseline.py --out logs/baseline
python train_and_test_pix2pix.py --out logs/pix2pix
Note that this is under construction.
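For readers unfamiliar with spectral normalization, the idea can be sketched in plain NumPy as follows. This is a generic illustration of the technique (dividing a weight by its largest singular value, estimated with power iteration), not the code used in this example; the function name `spectral_normalize` and its parameters are hypothetical. In practice, implementations typically run a single power iteration per training step and keep the vector `u` between steps.

```python
import numpy as np

def spectral_normalize(weight, n_power_iterations=50, eps=1e-12, seed=0):
    """Divide a weight array by an estimate of its largest singular value."""
    # Flatten e.g. a conv kernel (out_ch, in_ch, kh, kw) to a 2D matrix.
    w_mat = weight.reshape(weight.shape[0], -1)
    rng = np.random.default_rng(seed)
    u = rng.standard_normal(w_mat.shape[0])
    u /= np.linalg.norm(u) + eps
    for _ in range(n_power_iterations):
        v = w_mat.T @ u
        v /= np.linalg.norm(v) + eps
        u = w_mat @ v
        u /= np.linalg.norm(u) + eps
    sigma = u @ w_mat @ v  # estimate of the largest singular value
    return weight / sigma, sigma

# Toy kernel whose singular values are known to be 3, 2, and 1.
weight = np.diag([3.0, 2.0, 1.0]).reshape(3, 3, 1, 1)
w_sn, sigma = spectral_normalize(weight)
print('estimated spectral norm:', sigma)
```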
Ongoing.
Please follow the descriptions below to define these objects.
- datasets
- data augmentor
- data normalizer
- model
- visualizer
- validator
- inferencer
- (optional) singularity image
You can define your own dataset as shown below. PNG, JPG, BMP, and MetaImage formats (MHD, MHA) are supported.
- [case #1] 2D images
from collections import OrderedDict
import numpy as np
from chainer_bcnn.datasets import ImageDataset
data_root = './data'
patients = ['ID0', 'ID1', 'ID2'] # NOTE: 3 patients
class_list = ['background', 'liver', 'tumor']
augmentor = None # NOTE: please set if you have..
normalizer = None # NOTE: please set if you have..
dtypes = OrderedDict({
'image': np.float32,
'label': np.int32, # NOTE: if categorical label
# 'mask': np.uint8, # NOTE: please set if you have..
})
filenames = OrderedDict({
'image': '{root}/{patient}/*_image.mhd',
'label': '{root}/{patient}/*_label.mhd',
# 'mask' : '{root}/{patient}/*_mask.mhd', # NOTE: please set if you have..
})
dataset = ImageDataset(data_root, patients, classes=class_list,
dtypes=dtypes, filenames=filenames, augmentor=augmentor, normalizer=normalizer)
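The `filenames` templates above combine `{root}` and `{patient}` placeholders with a wildcard. A minimal sketch of how such templates can be expanded, assuming plain `str.format` plus `glob` (how the library resolves them internally is not confirmed here, and `collect_files` is a hypothetical helper):

```python
import glob
import os
import tempfile
from collections import OrderedDict

def collect_files(root, patients, filenames):
    """Expand '{root}/{patient}/...' templates and glob the matches per key."""
    found = OrderedDict((key, []) for key in filenames)
    for patient in patients:
        for key, template in filenames.items():
            pattern = template.format(root=root, patient=patient)
            found[key].extend(sorted(glob.glob(pattern)))
    return found

# Build a throwaway directory tree to demonstrate the matching.
with tempfile.TemporaryDirectory() as root:
    for patient in ['ID0', 'ID1']:
        os.makedirs(os.path.join(root, patient))
        for name in ['000_image.mhd', '000_label.mhd']:
            open(os.path.join(root, patient, name), 'w').close()

    filenames = OrderedDict([
        ('image', '{root}/{patient}/*_image.mhd'),
        ('label', '{root}/{patient}/*_label.mhd'),
    ])
    files = collect_files(root, ['ID0', 'ID1'], filenames)
    n_images, n_labels = len(files['image']), len(files['label'])
    print(n_images, 'images,', n_labels, 'labels')
```

Checking that every key yields the same number of files per patient is a cheap way to catch mismatched image/label pairs before training.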
- [case #2] 3D volumes
from chainer_bcnn.datasets import VolumeDataset
...
dataset = VolumeDataset(data_root, patients, classes=class_list,
dtypes=dtypes, filenames=filenames, augmentor=augmentor, normalizer=normalizer)
- [case #3] Custom dataset
from chainer_bcnn.datasets import BaseDataset
class CustomDataset(BaseDataset):
    ...
    raise NotImplementedError()
You can use a data augmentor based on geometric transformations, which are applied stochastically.
from chainer_bcnn.data.augmentor import DataAugmentor
from chainer_bcnn.data.augmentor import Crop2D, Flip2D, Affine2D
from chainer_bcnn.data.augmentor import Crop3D, Flip3D, Affine3D
augmentor = DataAugmentor()
augmentor.add(Crop2D(size=(300,400)))
augmentor.add(Flip2D(axis=1))
augmentor.add(Affine2D(rotation=15.,
translate=(10.,10.),
shear=0.25,
zoom=(0.8, 1.2),
keep_aspect_ratio=True,
fill_mode=('nearest', 'constant'),
cval=(0.,0.),
interp_order=(3,0)))
augmentor.summary('augment.json')
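The key property of a stochastic geometric augmentation for segmentation is that one random decision is shared by the image and its label, so the pair stays aligned. A minimal NumPy sketch of that idea follows; it is not the library's implementation, and `random_flip_pair` is a hypothetical helper.

```python
import numpy as np

def random_flip_pair(image, label, axis, rng, p=0.5):
    """Flip image and label together along `axis` with probability p."""
    if rng.random() < p:  # a single random draw decides for both arrays
        image = np.flip(image, axis=axis)
        label = np.flip(label, axis=axis)
    return image, label

rng = np.random.default_rng(0)
image = np.arange(12, dtype=np.float32).reshape(1, 3, 4)  # (channel, height, width)
label = (image[0] > 5).astype(np.int32)                   # (height, width)

# axis=-1 is the width axis for both arrays.
aug_image, aug_label = random_flip_pair(image, label, axis=-1, rng=rng)
print(aug_image.shape, aug_label.shape)
```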
You can use a data normalizer based on intensity transformations.
from chainer_bcnn.data.normalizer import Normalizer
from chainer_bcnn.data.normalizer import Clip2D, Subtract2D, Divide2D, Quantize2D
from chainer_bcnn.data.normalizer import Clip3D, Subtract3D, Divide3D, Quantize3D
normalizer = Normalizer()
normalizer.add(Clip2D((-150, 350)))
normalizer.add(Quantize2D(8))
normalizer.add(Subtract2D(0.5))
normalizer.add(Divide2D(1./255.))
normalizer.summary('norm.json')
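As a rough illustration of intensity normalization, the sketch below clips values to a window and maps them onto integer levels of a given bit depth. Whether `Clip2D` and `Quantize2D` behave exactly like this is an assumption; `clip_and_quantize` is a hypothetical stand-alone helper, not part of the library.

```python
import numpy as np

def clip_and_quantize(x, clip_range, n_bit):
    """Clip to [lo, hi], then map linearly onto levels 0 .. 2**n_bit - 1."""
    lo, hi = clip_range
    x = np.clip(np.asarray(x, dtype=np.float32), lo, hi)
    scaled = (x - lo) / (hi - lo)               # now in [0, 1]
    return np.round(scaled * (2 ** n_bit - 1))  # integer-valued levels

# Toy CT-like intensities; values outside the window saturate.
ct_values = np.array([-1000., -150., 100., 350., 2000.])
levels = clip_and_quantize(ct_values, (-150, 350), 8)
print(levels)
```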
- [case #1] Segmentation
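from functools import partial
from chainer.functions import softmax_cross_entropy  # NOTE: assumed to be Chainer's built-in loss
from chainer_bcnn.models import BayesianUNet
from chainer_bcnn.links import Classifier
predictor = BayesianUNet(ndim=2,
out_channels=3,
nlayer=5,
nfilter=32)
# NOTE: class_weight must be defined beforehand (per-class weights, or None)
lossfun = partial(softmax_cross_entropy,
normalize=False, class_weight=class_weight)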
model = Classifier(predictor,
lossfun=lossfun)
- [case #2] Regression
from chainer_bcnn.links import Regressor
from chainer_bcnn.functions.loss import sigmoid_soft_cross_entropy
from chainer.functions import mean_squared_error
...
lossfun = mean_squared_error
# lossfun = sigmoid_soft_cross_entropy # NOTE: if you want..
model = Regressor(predictor,
lossfun=lossfun)
- [case #3] Other problems (e.g., multi-task)
from chainer_bcnn.models import UNetBase
# NOTE: the _default_*_param values are assumed to be defined elsewhere
class MultiTaskUNet(UNetBase):
    def __init__(self,
                 ndim,
                 foo, # TODO
                 bar, # TODO
                 nfilter=32,
                 nlayer=5,
                 conv_param=_default_conv_param,
                 pool_param=_default_pool_param,
                 upconv_param=_default_upconv_param,
                 norm_param=_default_norm_param,
                 activation_param=_default_activation_param,
                 dropout_param=_default_dropout_param,
                 residual=False,
                 ):
        super(MultiTaskUNet, self).__init__(
            ndim,
            nfilter,
            nlayer,
            conv_param,
            pool_param,
            upconv_param,
            norm_param,
            activation_param,
            dropout_param,
            residual,)
        self._foo = foo
        self._bar = bar
        with self.init_scope():
            pass # TODO: foo, bar
    def forward(self, x):
        h = super().forward(x)
        # TODO: foo, bar
        raise NotImplementedError('foo is bar..')
- [case #1] 2D segmentation
import numpy as np
from chainer_bcnn.visualizer import ImageVisualizer
transforms = {
'x': lambda x: x,
'y': lambda x: np.argmax(x, axis=0),
't': lambda x: x,
}
_cmap = np.array([
[0,0,0], # NOTE: background (black)
[1,0,0], # liver (red)
[0,1,0]]) # tumor (green)
cmaps = {
'x': None,
'y': _cmap,
't': _cmap,
}
clims = {
'x': (0., 255.),
'y': None,
't': None,
}
visualizer = ImageVisualizer(transforms=transforms,
cmaps=cmaps,
clims=clims)
- [case #2] 2D regression
import numpy as np
import matplotlib.pyplot as plt
import chainer.functions as F
from chainer_bcnn.visualizer import ImageVisualizer
def alpha_blend(heatmaps, cmap='jet'):
    # Blend per-channel heatmaps (ch, w, h) into a single RGB image (3, w, h).
    assert heatmaps.ndim == 3
    ch, w, h = heatmaps.shape
    ret = np.zeros((3, w, h))
    mapper = plt.get_cmap(cmap, ch)
    for i in range(ch):
        color = np.ones((3, w, h)) \
                * np.asarray(mapper(i)[:3]).reshape(-1,1,1)
        ret += (color * heatmaps[i])
    return ret
transforms = {
'x': None,
'y': lambda x: alpha_blend(F.sigmoid(x).data),
't': lambda x: alpha_blend(x),
}
clims = {
'x': (0., 255.),
'y': (0., 1.),
't': (0., 1.),
}
cmaps = None
visualizer = ImageVisualizer(transforms=transforms,
cmaps=cmaps,
clims=clims)
To visualize 3D volumes, you can pass a volume renderer to the transforms
in the same way as described above.
from chainer_bcnn.extensions import Validator
...
valid_file = 'iter_{.updater.iteration:08}.png'
n_vis = 20 # NOTE: number of samples for visualization
trainer.extend(Validator(valid_iter, model, valid_file,
visualizer=visualizer, n_vis=n_vis,
device=gpu_id))
- [case #1] Segmentation / Classification
from functools import partial
from chainer_bcnn.links import MCSampler
from chainer_bcnn.inference import Inferencer
import chainer.functions as F
mc_iteration = 50
model = MCSampler(predictor, # NOTE: e.g., BayesianUNet
mc_iteration=mc_iteration,
activation=partial(F.softmax, axis=1),
reduce_mean=partial(F.argmax, axis=1),
reduce_var=partial(F.mean, axis=1))
infer = Inferencer(test_iter, model, device=gpu_id)
estimated_labels, predicted_variances = infer.run()
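Conceptually, the sampler above runs the stochastic network several times, averages the softmax outputs, and summarizes their spread. A NumPy-only sketch of that computation follows, assuming the reductions behave as their names suggest (argmax of the mean over the class axis, mean of the variance over the class axis). The toy network and the helper names `mc_summary` and `toy_forward` are hypothetical and stand in for a real predictor.

```python
import numpy as np

def softmax(x, axis):
    x = x - x.max(axis=axis, keepdims=True)  # subtract max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def mc_summary(stochastic_forward, x, mc_iteration):
    """Run a stochastic forward pass repeatedly and summarize the samples."""
    probs = np.stack([softmax(stochastic_forward(x), axis=1)
                      for _ in range(mc_iteration)])  # (T, N, C)
    mean = probs.mean(axis=0)                          # (N, C)
    var = probs.var(axis=0)                            # (N, C)
    labels = mean.argmax(axis=1)       # analogous to reduce_mean above
    uncertainty = var.mean(axis=1)     # analogous to reduce_var above
    return labels, uncertainty

rng = np.random.default_rng(0)
weight = rng.standard_normal((4, 3))  # toy linear "network": 4 features -> 3 classes

def toy_forward(x, drop_ratio=0.5):
    # Dropout stays active at inference time, so each pass differs.
    mask = rng.random(x.shape) >= drop_ratio
    return (x * mask / (1.0 - drop_ratio)) @ weight

x = rng.standard_normal((5, 4))
labels, uncertainty = mc_summary(toy_forward, x, mc_iteration=50)
print(labels.shape, uncertainty.shape)
```

A larger `mc_iteration` gives a smoother variance estimate at the cost of proportionally more forward passes.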
cd recipe
make all