Skip to content

Commit

Permalink
Merge pull request #1002 from levje/levje/segment-no-streamlines
Browse files Browse the repository at this point in the history
[FIX] Empty remaining indices when filtering for a bundle
  • Loading branch information
arnaudbore committed Jun 19, 2024
2 parents 99ad51e + f67cf16 commit 74c7765
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 16 deletions.
61 changes: 61 additions & 0 deletions scilpy/segment/tests/test_tractogram_from_roi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import os
import tempfile
import nibabel as nib
import numpy as np

from numpy.testing import (assert_array_equal,
assert_equal)

from scilpy.segment.tractogram_from_roi import (_extract_vb_one_bundle,
_extract_ib_one_bundle)
from dipy.io.stateful_tractogram import Space, StatefulTractogram


def test_extract_vb_one_bundle():
# Testing extraction of VS corresponding to a bundle using
# an empty tractogram which shouldn't raise any error.
fake_reference = nib.Nifti1Image(
np.zeros((10, 10, 10, 1)), affine=np.eye(4))
# The Space type is not important here
empty_sft = StatefulTractogram([], fake_reference, Space.RASMM)

with tempfile.TemporaryDirectory() as tmp_dir:
fake_mask1_name = os.path.join(tmp_dir, 'fake_mask1.nii.gz')
fake_mask2_name = os.path.join(tmp_dir, 'fake_mask2.nii.gz')
nib.save(nib.Nifti1Image(np.zeros((10, 10, 10)),
affine=np.eye(4), dtype=np.int8), fake_mask1_name)
nib.save(nib.Nifti1Image(np.zeros((10, 10, 10)),
affine=np.eye(4), dtype=np.int8), fake_mask2_name)

vs_ids, wpc_ids, bundle_stats = \
_extract_vb_one_bundle(empty_sft,
fake_mask1_name,
fake_mask2_name,
None, None, None,
None, None, None, None)
assert_array_equal(vs_ids, [])
assert_array_equal(wpc_ids, [])
assert_equal(bundle_stats["VS"], 0)


def test_extract_ib_one_bundle():
# Testing extraction of IS corresponding to a bundle using
# an empty tractogram which shouldn't raise any error.
fake_reference = nib.Nifti1Image(
np.zeros((10, 10, 10, 1)), affine=np.eye(4))
# The Space type is not important here
empty_sft = StatefulTractogram([], fake_reference, Space.RASMM)

with tempfile.TemporaryDirectory() as tmp_dir:
fake_mask1_name = os.path.join(tmp_dir, 'fake_mask1.nii.gz')
fake_mask2_name = os.path.join(tmp_dir, 'fake_mask2.nii.gz')
nib.save(nib.Nifti1Image(np.zeros((10, 10, 10)),
affine=np.eye(4), dtype=np.int8), fake_mask1_name)
nib.save(nib.Nifti1Image(np.zeros((10, 10, 10)),
affine=np.eye(4), dtype=np.int8), fake_mask2_name)

fc_sft, fc_ids = _extract_ib_one_bundle(
empty_sft, fake_mask1_name, fake_mask2_name, None)

assert_equal(len(fc_sft), 0)
assert_array_equal(fc_ids, [])
38 changes: 22 additions & 16 deletions scilpy/segment/tractogram_from_roi.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,16 +351,19 @@ def _extract_vb_one_bundle(
bundle_stats: dict
Dictionary of recognized streamlines statistics
"""
mask_1_img = nib.load(head_filename)
mask_2_img = nib.load(tail_filename)
mask_1 = get_data_as_mask(mask_1_img)
mask_2 = get_data_as_mask(mask_2_img)
if len(sft) > 0:
mask_1_img = nib.load(head_filename)
mask_2_img = nib.load(tail_filename)
mask_1 = get_data_as_mask(mask_1_img)
mask_2 = get_data_as_mask(mask_2_img)

if dilate_endpoints:
mask_1 = binary_dilation(mask_1, iterations=dilate_endpoints)
mask_2 = binary_dilation(mask_2, iterations=dilate_endpoints)
if dilate_endpoints:
mask_1 = binary_dilation(mask_1, iterations=dilate_endpoints)
mask_2 = binary_dilation(mask_2, iterations=dilate_endpoints)

_, vs_ids = filter_grid_roi_both(sft, mask_1, mask_2)
_, vs_ids = filter_grid_roi_both(sft, mask_1, mask_2)
else:
vs_ids = np.array([])

wpc_ids = []
bundle_stats = {"Initial count head to tail": len(vs_ids)}
Expand Down Expand Up @@ -499,16 +502,19 @@ def _extract_ib_one_bundle(sft, mask_1_filename, mask_2_filename,
SFT of remaining streamlines.
"""

mask_1_img = nib.load(mask_1_filename)
mask_2_img = nib.load(mask_2_filename)
mask_1 = get_data_as_mask(mask_1_img)
mask_2 = get_data_as_mask(mask_2_img)
if len(sft) > 0:
mask_1_img = nib.load(mask_1_filename)
mask_2_img = nib.load(mask_2_filename)
mask_1 = get_data_as_mask(mask_1_img)
mask_2 = get_data_as_mask(mask_2_img)

if dilate_endpoints:
mask_1 = binary_dilation(mask_1, iterations=dilate_endpoints)
mask_2 = binary_dilation(mask_2, iterations=dilate_endpoints)
if dilate_endpoints:
mask_1 = binary_dilation(mask_1, iterations=dilate_endpoints)
mask_2 = binary_dilation(mask_2, iterations=dilate_endpoints)

_, fc_ids = filter_grid_roi_both(sft, mask_1, mask_2)
_, fc_ids = filter_grid_roi_both(sft, mask_1, mask_2)
else:
fc_ids = []

fc_sft = sft[fc_ids]
return fc_sft, fc_ids
Expand Down

0 comments on commit 74c7765

Please sign in to comment.