Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FIX: Traverse BIDS hierarchy to find masks, bvals, and bvecs #587

Merged
merged 11 commits into from Nov 9, 2020
Merged
73 changes: 47 additions & 26 deletions AFQ/api.py
Expand Up @@ -279,10 +279,13 @@ def __init__(self,
The parameters for segmentation.
Default: use the default behavior of the seg.Segmentation object.
tracking_params: dict, optional
The parameters for tracking.
Default: use the default behavior of the aft.track function.
Seed mask and seed threshold, if not specified, are replaced
with scalar masks from scalar[0] thresholded to 0.2.
The parameters for tracking. Default: use the default behavior of
the aft.track function. Seed mask and seed threshold, if not
specified, are replaced with scalar masks from scalar[0]
thresholded to 0.2. The ``seed_mask`` and ``stop_mask`` items of
this dict may be ``AFQ.mask.MaskFile`` instances. If ``tracker``
is set to "pft" then ``stop_mask`` should be an instance of
``AFQ.mask.PFTMask``.
clean_params: dict, optional
The parameters for cleaning.
Default: use the default behavior of the seg.clean_bundle
Expand Down Expand Up @@ -509,11 +512,17 @@ def __init__(self,
if session is not None:
results_dir = op.join(results_dir, 'ses-' + session)

dwi_files = bids_layout.get(subject=subject, session=session,
extension='nii.gz',
return_type='filename',
scope=dmriprep,
**bids_filters)
dwi_bids_filters = {
"subject": subject,
"session": session,
"return_type": "filename",
"scope": dmriprep,
"datatype": "dwi",
"extension": "nii.gz",
"suffix": "dwi",
}
dwi_bids_filters.update(bids_filters)
dwi_files = bids_layout.get(**dwi_bids_filters)

if (not len(dwi_files)):
self.logger.warning(
Expand All @@ -522,21 +531,23 @@ def __init__(self,
continue

results_dir_list.append(results_dir)

os.makedirs(results_dir_list[-1], exist_ok=True)
dwi_file_list.append(dwi_files[0])
bvec_file_list.append(
bids_layout.get(subject=subject, session=session,
extension=['bvec', 'bvecs'],
return_type='filename',
scope=dmriprep,
**bids_filters)[0])
bval_file_list.append(
bids_layout.get(subject=subject, session=session,
extension=['bval', 'bvals'],
return_type='filename',
scope=dmriprep,
**bids_filters)[0])

dwi_data_file = dwi_files[0]
dwi_file_list.append(dwi_data_file)

# For bvals and bvecs, use ``get_bval()`` and ``get_bvec()`` to
# walk up the file tree and inherit the closest bval and bvec
# files. Maintain input ``bids_filters`` in case user wants to
# specify acquisition labels, but pop suffix since it is
# already specified inside ``get_bvec()`` and ``get_bval()``
suffix = bids_filters.pop("suffix", None)
bvec_file_list.append(bids_layout.get_bvec(dwi_data_file,
**bids_filters))
bval_file_list.append(bids_layout.get_bval(dwi_data_file,
**bids_filters))
if suffix is not None:
bids_filters["suffix"] = suffix

if custom_tractography_bids_filters is not None:
custom_tract_list.append(
Expand All @@ -554,28 +565,38 @@ def __init__(self,

if isinstance(self.reg_subject, dict):
reg_subject_list.append(
bids_layout.get(subject=subject, session=session,
return_type='filename',
**self.reg_subject)[0])
bids_layout.get_nearest(
dwi_data_file,
**self.reg_subject,
session=session,
subject=subject,
full_search=True,
strict=True,
ignore_strict_entities=["session"]
)
)
else:
reg_subject_list.append(None)

if check_mask_methods(self.tracking_params["seed_mask"]):
self.tracking_params["seed_mask"].find_path(
bids_layout,
dwi_data_file,
subject,
session
)

if check_mask_methods(self.tracking_params["stop_mask"]):
self.tracking_params["stop_mask"].find_path(
bids_layout,
dwi_data_file,
subject,
session
)

self.brain_mask_definition.find_path(
bids_layout,
dwi_data_file,
subject,
session
)
Expand Down
12 changes: 6 additions & 6 deletions AFQ/data.py
Expand Up @@ -1706,8 +1706,8 @@ def organize_cfin_data(path=None):
nib.save(t1_img, op.join(anat_folder, 'sub-01_ses-01_T1w.nii.gz'))
dwi_img, gtab = dpd.read_cfin_dwi()
nib.save(dwi_img, op.join(dwi_folder, 'sub-01_ses-01_dwi.nii.gz'))
np.savetxt(op.join(dwi_folder, 'sub-01_ses-01_dwi.bvecs'), gtab.bvecs)
np.savetxt(op.join(dwi_folder, 'sub-01_ses-01_dwi.bvals'), gtab.bvals)
np.savetxt(op.join(dwi_folder, 'sub-01_ses-01_dwi.bvec'), gtab.bvecs)
np.savetxt(op.join(dwi_folder, 'sub-01_ses-01_dwi.bval'), gtab.bvals)

to_bids_description(
bids_path,
Expand Down Expand Up @@ -1743,8 +1743,8 @@ def organize_stanford_data(path=None):
└── sub-01
└── ses-01
└── dwi
├── sub-01_ses-01_dwi.bvals
├── sub-01_ses-01_dwi.bvecs
├── sub-01_ses-01_dwi.bval
├── sub-01_ses-01_dwi.bvec
└── sub-01_ses-01_dwi.nii.gz

"""
Expand Down Expand Up @@ -1785,8 +1785,8 @@ def organize_stanford_data(path=None):

dwi_img, gtab = dpd.read_stanford_hardi()
nib.save(dwi_img, op.join(dwi_folder, 'sub-01_ses-01_dwi.nii.gz'))
np.savetxt(op.join(dwi_folder, 'sub-01_ses-01_dwi.bvecs'), gtab.bvecs)
np.savetxt(op.join(dwi_folder, 'sub-01_ses-01_dwi.bvals'), gtab.bvals)
np.savetxt(op.join(dwi_folder, 'sub-01_ses-01_dwi.bvec'), gtab.bvecs)
np.savetxt(op.join(dwi_folder, 'sub-01_ses-01_dwi.bval'), gtab.bvals)
else:
logger.info('Dataset is already in place. If you want to fetch it '
+ 'again please first remove the folder '
Expand Down
50 changes: 35 additions & 15 deletions AFQ/mask.py
Expand Up @@ -159,15 +159,35 @@ def __init__(self, suffix, filters={}):
self.filters = filters
self.fnames = {}

def find_path(self, bids_layout, subject, session):
def find_path(self, bids_layout, from_path, subject, session):
if session not in self.fnames:
self.fnames[session] = {}
self.fnames[session][subject] = bids_layout.get(
subject=subject, session=session,
extension='.nii.gz',
return_type='filename',

nearest_mask = bids_layout.get_nearest(
from_path,
**self.filters,
extension=".nii.gz",
suffix=self.suffix,
**self.filters)[0]
session=session,
subject=subject,
full_search=True,
strict=False,
)

self.fnames[session][subject] = nearest_mask
from_path_subject = bids_layout.parse_file_entities(from_path).get(
"subject", None
)
mask_subject = bids_layout.parse_file_entities(nearest_mask).get(
"subject", None
)
if from_path_subject != mask_subject:
raise ValueError(
f"Expected subject IDs to match for the retrieved mask file "
f"and the supplied `from_path` file. Got sub-{mask_subject} "
f"from mask file {nearest_mask} and sub-{from_path_subject} "
f"from `from_path` file {from_path}."
)

def get_path_data_affine(self, afq_object, row):
mask_file = self.fnames[row['ses']][row['subject']]
Expand Down Expand Up @@ -209,7 +229,7 @@ class FullMask(StrInstantiatesMixin):
def __init__(self):
pass

def find_path(self, bids_layout, subject, session):
def find_path(self, bids_layout, from_path, subject, session):
pass

def get_mask(self, afq_object, row):
Expand All @@ -234,7 +254,7 @@ class RoiMask(StrInstantiatesMixin):
def __init__(self):
pass

def find_path(self, bids_layout, subject, session):
def find_path(self, bids_layout, from_path, subject, session):
pass

def get_mask(self, afq_object, row):
Expand Down Expand Up @@ -286,7 +306,7 @@ class B0Mask(StrInstantiatesMixin):
def __init__(self, median_otsu_kwargs={}):
self.median_otsu_kwargs = median_otsu_kwargs

def find_path(self, bids_layout, subject, session):
def find_path(self, bids_layout, from_path, subject, session):
pass

def get_mask(self, afq_object, row):
Expand Down Expand Up @@ -448,7 +468,7 @@ def __init__(self, scalar):
self.scalar = scalar

# overrides MaskFile
def find_path(self, bids_layout, subject, session):
def find_path(self, bids_layout, from_path, subject, session):
pass

# overrides MaskFile
Expand Down Expand Up @@ -536,9 +556,9 @@ def __init__(self, WM_probseg, GM_probseg, CSF_probseg):
"""
self.probsegs = (WM_probseg, GM_probseg, CSF_probseg)

def find_path(self, bids_layout, subject, session):
def find_path(self, bids_layout, from_path, subject, session):
for probseg in self.probsegs:
probseg.find_path(bids_layout, subject, session)
probseg.find_path(bids_layout, from_path, subject, session)

def get_mask(self, afq_object, row):
probseg_imgs = []
Expand All @@ -547,7 +567,7 @@ def get_mask(self, afq_object, row):
data, affine, meta = probseg.get_mask(afq_object, row)
probseg_imgs.append(nib.Nifti1Image(data, affine))
probseg_metas.append(meta)
return probseg_imgs, _, dict(sources=probseg_metas)
return probseg_imgs, None, dict(sources=probseg_metas)


class CombinedMask(StrInstantiatesMixin, CombineMaskMixin):
Expand Down Expand Up @@ -581,9 +601,9 @@ def __init__(self, mask_list, combine="and"):
CombineMaskMixin.__init__(self, combine)
self.mask_list = mask_list

def find_path(self, bids_layout, subject, session):
def find_path(self, bids_layout, from_path, subject, session):
for mask in self.mask_list:
mask.find_path(bids_layout, subject, session)
mask.find_path(bids_layout, from_path, subject, session)

def get_mask(self, afq_object, row):
self.mask_draft = None
Expand Down
4 changes: 2 additions & 2 deletions AFQ/tests/test_api.py
Expand Up @@ -71,11 +71,11 @@ def create_dummy_data(dmriprep_dir, subject, session=None):

np.savetxt(
op.join(
dmriprep_dir, data_dir, 'dwi', 'dwi.bvals'),
dmriprep_dir, data_dir, 'dwi', 'dwi.bval'),
bvals)
np.savetxt(
op.join(
dmriprep_dir, data_dir, 'dwi', 'dwi.bvecs'),
dmriprep_dir, data_dir, 'dwi', 'dwi.bvec'),
bvecs)
nib.save(
nib.Nifti1Image(data, aff),
Expand Down
31 changes: 25 additions & 6 deletions AFQ/tests/test_mask.py
@@ -1,5 +1,7 @@
import os.path as op
import numpy as np
import numpy.testing as npt
import pytest

from bids.layout import BIDSLayout

Expand Down Expand Up @@ -75,13 +77,30 @@ def test_resample_mask():
mask_data.dtype)


def test_find_path():
@pytest.mark.parametrize("subject", ["01", "02"])
@pytest.mark.parametrize("session", ["01", "02"])
def test_find_path(subject, session):
bids_dir = create_dummy_bids_path(2, 2)
print(bids_dir)
bids_layout = BIDSLayout(bids_dir, derivatives=True)

test_dwi_path = bids_layout.get(
subject=subject, session=session, return_type="filename",
suffix="dwi", extension="nii.gz"
)[0]

mask_file = MaskFile("seg", {'scope': 'synthetic'})
mask_file.find_path(bids_layout, '01', '01')
mask_file.find_path(bids_layout, '02', '01')
mask_file.find_path(bids_layout, '01', '02')
mask_file.find_path(bids_layout, '02', '02')
mask_file.find_path(bids_layout, test_dwi_path, subject, session)

assert mask_file.fnames[session][subject] == op.join(
bids_dir, "derivatives", "dmriprep", "sub-" + subject,
"ses-" + session, "anat", "seg.nii.gz"
)

other_sub = "01" if subject == "02" else "02"
with pytest.raises(ValueError):
mask_file.find_path(
bids_layout,
test_dwi_path,
subject=other_sub,
session=session,
)
9 changes: 5 additions & 4 deletions AFQ/tractography.py
@@ -1,3 +1,4 @@
from collections.abc import Iterable
import numpy as np
import nibabel as nib
import dipy.reconst.shm as shm
Expand Down Expand Up @@ -166,11 +167,11 @@ def track(params_file, directions="det", max_angle=30., sphere=None,
"You are using PFT tracking, but did not provide a string ",
"'stop_threshold' input. ",
"Possible inputs are: 'CMC' or 'ACT'")
if not isinstance(stop_mask, tuple):
if not (isinstance(stop_mask, Iterable) and len(stop_mask) == 3):
richford marked this conversation as resolved.
Show resolved Hide resolved
raise RuntimeError(
"You are using PFT tracking, but did not provide a tuple for",
"`stop_mask`",
"input. Expected a (pve_wm, pve_gm, pve_csf) tuple.")
"You are using PFT tracking, but did not provide a length "
"3 iterable for `stop_mask`. "
"Expected a (pve_wm, pve_gm, pve_csf) tuple.")
pves = []
pve_imgs = []
vox_sizes = []
Expand Down
4 changes: 2 additions & 2 deletions docs/source/usage/data.rst
Expand Up @@ -41,6 +41,6 @@ data set in a directory called `stanford_hardi`::
└── sub-01
└── ses-01
└── dwi
├── sub-01_ses-01_dwi.bvals
├── sub-01_ses-01_dwi.bvecs
├── sub-01_ses-01_dwi.bval
├── sub-01_ses-01_dwi.bvec
└── sub-01_ses-01_dwi.nii.gz