Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/feat-run-examples' into feat-run…
Browse files Browse the repository at this point in the history
…-examples
  • Loading branch information
alainjungo committed Oct 7, 2020
2 parents 1bcd7d1 + b4e30be commit 25f2625
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 54 deletions.
2 changes: 1 addition & 1 deletion docs/examples.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ The following examples illustrate the intended use of pymia:
examples.augmentation.basic

The examples are available as Jupyter notebooks and Python scripts on `GitHub <https://github.com/rundherum/pymia/tree/master/examples/>`_ or directly rendered in the documentation by following the links above.
Furthermore, there exist complete training scripts in TensorFlow and PyTorch at `GitHub <https://github.com/rundherum/pymia/tree/master/examples/training-examples>`_.
Furthermore, there exist complete training scripts in TensorFlow and PyTorch at `./examples/training-examples on GitHub <https://github.com/rundherum/pymia/tree/master/examples/training-examples>`_.
For all examples, 3 tesla MR images of the head of four healthy subjects from the Human Connectome Project (HCP) [VanEssen2013]_ are used.
Each subject has four 3-D images (in the MetaImage and Nifty format) and demographic information provided as a text file.
The images are a T1-weighted MR image, a T2-weighted MR image, a label image (ground truth), and a brain mask image.
Expand Down
85 changes: 45 additions & 40 deletions examples/augmentation/basic.ipynb

Large diffs are not rendered by default.

30 changes: 17 additions & 13 deletions examples/augmentation/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ def __call__(self, sample: dict) -> dict:
sample[entry] = np.expand_dims(np_entry, 0)

# apply batchgenerators transforms
sample = self.transforms(**sample) # todo: make loop over transforms
for t in self.transforms:
sample = t(**sample)

# squeeze samples back to original format
for entry in self.entries:
Expand Down Expand Up @@ -80,12 +81,14 @@ def plot_sample(plot_dir: str, id_: str, sample: dict):


def main(hdf_file, plot_dir):
os.makedirs(plot_dir, exist_ok=True)

# setup the datasource
extractor = extr.DataExtractor(categories=(defs.KEY_IMAGES, defs.KEY_LABELS))
indexing_strategy = extr.SliceIndexing()
dataset = extr.PymiaDatasource(hdf_file, indexing_strategy, extractor)

seed = 42
seed = 1
np.random.seed(seed)
sample_idx = 55

Expand All @@ -106,6 +109,17 @@ def main(hdf_file, plot_dir):
sample = dataset[sample_idx]
plot_sample(plot_dir, 'pymia', sample)

# augmentation with batchgenerators
transforms_augmentation = [BatchgeneratorsTransform([
bg_tfm.spatial_transforms.MirrorTransform(axes=(0, 1), data_key=defs.KEY_IMAGES, label_key=defs.KEY_LABELS),
bg_tfm.noise_transforms.GaussianBlurTransform(blur_sigma=(0.2, 1.0), data_key=defs.KEY_IMAGES, label_key=defs.KEY_LABELS),
])]
train_transforms = tfm.ComposeTransform(
transforms_before_augmentation + transforms_augmentation + transforms_after_augmentation)
dataset.set_transform(train_transforms)
sample = dataset[sample_idx]
plot_sample(plot_dir, 'batchgenerators', sample)

# augmentation with TorchIO
transforms_augmentation = [TorchIOTransform(
[tio.RandomFlip(axes=('LR'), flip_probability=1.0, keys=(defs.KEY_IMAGES, defs.KEY_LABELS), seed=seed),
Expand All @@ -118,16 +132,6 @@ def main(hdf_file, plot_dir):
sample = dataset[sample_idx]
plot_sample(plot_dir, 'torchio', sample)

# augmentation with batchgenerators
transforms_augmentation = [BatchgeneratorsTransform(
bg_tfm.spatial_transforms.Rot90Transform(axes=(0, 1), data_key=defs.KEY_IMAGES, label_key=defs.KEY_LABELS,
p_per_sample=1.0))]
train_transforms = tfm.ComposeTransform(
transforms_before_augmentation + transforms_augmentation + transforms_after_augmentation)
dataset.set_transform(train_transforms)
sample = dataset[sample_idx]
plot_sample(plot_dir, 'batchgenerators', sample)


if __name__ == '__main__':
"""The program's entry point.
Expand All @@ -147,7 +151,7 @@ def main(hdf_file, plot_dir):
parser.add_argument(
'--plot_dir',
type=str,
default='../example-data/log',
default='../example-data/augmentation',
help='Path to the plotting directory.'
)

Expand Down

0 comments on commit 25f2625

Please sign in to comment.