Skip to content

Commit

Permalink
Merge pull request #919 from 36000/resample_subject_fix
Browse files Browse the repository at this point in the history
[FIX] don't resample subject-space ROIs unless user provides something
  • Loading branch information
arokem committed Nov 29, 2022
2 parents a44d570 + 65596eb commit bed4915
Show file tree
Hide file tree
Showing 8 changed files with 89 additions and 45 deletions.
41 changes: 18 additions & 23 deletions AFQ/api/bundle_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def __init__(self,
bundle_info=BUNDLES,
seg_algo="afq",
resample_to=None,
resample_subject_to=None,
resample_subject_to=False,
keep_in_memory=False):
"""
Create a bundle dictionary, needed for the segmentation
Expand Down Expand Up @@ -100,8 +100,6 @@ def __init__(self,
If there are bundles in bundle_info with the 'space' attribute
set to 'subject', their images (all ROIs and probability maps)
will be resampled to the affine and shape of this image.
If None, the template will be overriden when passed to
an API class.
If False, no resampling will be done.
Default: None
Expand Down Expand Up @@ -326,17 +324,15 @@ def cond_load(sl): return load_tractogram(
bbox_valid_check=False).streamlines
if not self.keep_in_memory:
old_vals = self.apply_to_rois(key, cond_load)
if self.resample_to:
self._resample_roi(key)
self._resample_roi(key)
else:
if "loaded" not in self._dict[key] or\
not self._dict[key]["loaded"]:
self.apply_to_rois(key, cond_load)
self._dict[key]["loaded"] = True
old_vals = None
if self.resample_to and (
"resampled" not in self._dict[key] or not self._dict[
key]["resampled"]):
if "resampled" not in self._dict[key] or not self._dict[
key]["resampled"]:
self._resample_roi(key)
_item = self._dict[key].copy()
if old_vals is not None:
Expand Down Expand Up @@ -455,23 +451,24 @@ def _resample_roi(self, b_name):
b_name : str
Name of the bundle to be resampled.
"""
if self.resample_to and self.seg_algo == "afq":
if self.seg_algo == "afq":
if "space" not in self._dict[b_name]\
or self._dict[b_name]["space"] == "template":
resample_to = self.resample_to
else:
resample_to = self.resample_subject_to
try:
self.apply_to_rois(
b_name,
afd.read_resample_roi,
resample_to=resample_to)
self._dict[b_name]["resampled"] = True
except AttributeError as e:
if "'ImageFile' object" in str(e):
self._dict[b_name]["resampled"] = False
else:
raise
if resample_to:
try:
self.apply_to_rois(
b_name,
afd.read_resample_roi,
resample_to=resample_to)
self._dict[b_name]["resampled"] = True
except AttributeError as e:
if "'ImageFile' object" in str(e):
self._dict[b_name]["resampled"] = False
else:
raise

def __add__(self, other):
self.gen_all()
Expand Down Expand Up @@ -520,7 +517,7 @@ def __init__(self,
bundle_info=PEDIATRIC_BUNDLES,
seg_algo="afq",
resample_to=None,
resample_subject_to=None,
resample_subject_to=False,
keep_in_memory=False):
"""
Create a pediatric bundle dictionary, needed for the segmentation
Expand Down Expand Up @@ -548,8 +545,6 @@ def __init__(self,
If there are ROIs with the 'space' attribute
set to 'subject', those ROIs will be resampled to the affine
and shape of this image.
If None, the template will be overriden when passed to
an API class.
If False, no resampling will be done.
Default: None
Expand Down
26 changes: 22 additions & 4 deletions AFQ/data/fetch.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@
"fetch_templates", "read_templates",
"fetch_stanford_hardi_tractography",
"read_stanford_hardi_tractography",
"organize_stanford_data"]
"organize_stanford_data",
"fetch_stanford_hardi_lv1"]


afq_home = op.join(op.expanduser('~'), 'AFQ_data')
Expand All @@ -57,14 +58,18 @@


def _make_reusable_fetcher(name, folder, baseurl, remote_fnames, local_fnames,
doc="", **make_fetcher_kwargs):
doc="", md5_list=None, **make_fetcher_kwargs):
def fetcher():
all_files_downloaded = True
for fname in local_fnames:
if not op.exists(op.join(folder, fname)):
all_files_downloaded = False
if all_files_downloaded:
return local_fnames, folder
files = {}
for i, (f, n), in enumerate(zip(remote_fnames, local_fnames)):
files[n] = (baseurl + f, md5_list[i] if
md5_list is not None else None)
return files, folder
else:
return _make_fetcher(
name, folder, baseurl, remote_fnames, local_fnames,
Expand Down Expand Up @@ -367,7 +372,7 @@ def read_resample_roi(roi, resample_to=None, threshold=False):
if isinstance(resample_to, str):
resample_to = nib.load(resample_to)

if np.allclose(resample_to.affine, roi.affine):
if resample_to is False or np.allclose(resample_to.affine, roi.affine):
return roi

as_array = resample(
Expand Down Expand Up @@ -857,6 +862,19 @@ def organize_stanford_data(path=None, clear_previous_afq=False):
"PipelineDescription": {"Name": "freesurfer"}})


fetch_stanford_hardi_lv1 = _make_reusable_fetcher(
"fetch_stanford_hardi_lv1",
op.join(afq_home,
'stanford_hardi',
'derivatives/freesurfer/sub-01/ses-01/anat'),
'https://stacks.stanford.edu/file/druid:ng782rw8378/',
["SUB1_LV1.nii.gz"],
["sub-01_ses-01_desc-LV1_anat.nii.gz"],
md5_list=["e403c602e53e5491414f86af5f08a913"],
doc="Download the LV1 segmentation for the Standord Hardi subject",
unzip=False)


fetch_hcp_atlas_16_bundles = _make_reusable_fetcher(
"fetch_hcp_atlas_16_bundles",
op.join(afq_home,
Expand Down
13 changes: 6 additions & 7 deletions AFQ/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,6 +498,10 @@ def _get_bundle_info(self, bundle, vox_dim, tol):
self.mapping,
bundle_name=bundle)
else:
if isinstance(roi, str):
roi = nib.load(roi)
if isinstance(roi, nib.Nifti1Image):
roi = roi.get_fdata()
warped_roi = roi

if roi_type == 'include':
Expand Down Expand Up @@ -526,10 +530,6 @@ def _get_bundle_info(self, bundle, vox_dim, tol):

# The probability map if doesn't exist is all ones with the same
# shape as the ROIs:
if isinstance(roi, str):
roi = nib.load(roi)
if isinstance(roi, nib.Nifti1Image):
roi = roi.get_fdata()
prob_map = bundle_entry.get(
'prob_map', np.ones(roi.shape))

Expand All @@ -543,7 +543,6 @@ def _get_bundle_info(self, bundle, vox_dim, tol):
self.mapping.transform_inverse(
prob_map.copy(),
interpolation='nearest')

return warped_prob_map, include_rois, exclude_rois,\
include_roi_tols, exclude_roi_tols

Expand Down Expand Up @@ -733,7 +732,6 @@ def segment_afq(self, tg=None):
for end_type in ['start', 'end']:
if end_type in self.bundle_dict[bundle]:
warped_roi = self.bundle_dict[bundle][end_type]

# Create binary masks and warp these into subject's
# DWI space:
if "space" not in self.bundle_dict[bundle]\
Expand All @@ -758,7 +756,8 @@ def segment_afq(self, tg=None):
'endpoint_ROI',
bundle,
f'{end_type}point_as_used.nii.gz'))

else:
warped_roi = warped_roi.get_fdata()
atlas_idx.append(
np.array(np.where(warped_roi > 0)).T)
else:
Expand Down
12 changes: 7 additions & 5 deletions AFQ/tasks/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -649,17 +649,19 @@ def get_bundle_dict(base_fname, dwi, segmentation_params,
def roi_scalar_to_info(roi):
if not isinstance(roi, ImageDefinition):
return roi
roi.find_path(
bids_info["bids_layout"],
dwi,
bids_info["subject"],
bids_info["session"])
if bids_info is not None:
roi.find_path(
bids_info["bids_layout"],
dwi,
bids_info["subject"],
bids_info["session"])
roi_img, _ = roi.get_image_direct(
dwi, bids_info, b0, data_imap=None)
return roi_img
for b_name, b_info in bundle_dict._dict.items():
if "space" in b_info and b_info["space"] == "subject":
bundle_dict.apply_to_rois(b_name, roi_scalar_to_info)
bundle_dict._resample_roi(b_name)
return bundle_dict, reg_template


Expand Down
4 changes: 4 additions & 0 deletions AFQ/tasks/mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,10 @@ def export_rois(base_fname, results_dir, data_imap, mapping, dwi_affine):
mapping,
bundle_name=bundle)
else:
if isinstance(roi, str):
roi = nib.load(roi)
if isinstance(roi, nib.Nifti1Image):
roi = roi.get_fdata()
warped_roi = roi

# Cast to float32,
Expand Down
36 changes: 30 additions & 6 deletions AFQ/tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from AFQ.definitions.image import RoiImage,\
PFTImage, ImageFile
from AFQ.definitions.mapping import SynMap, AffMap, SlrMap
from AFQ.definitions.image import TemplateImage, ImageFile
from AFQ.definitions.image import TemplateImage, ImageFile, LabelledImageFile


def touch(fname, times=None):
Expand Down Expand Up @@ -735,10 +735,34 @@ def test_AFQ_data_waypoint():
t1_path)
shutil.copy(t1_path, t1_path_other)

vista_folder = op.join(
bids_path,
"derivatives/vistasoft/sub-01/ses-01/dwi")
freesurfer_folder = op.join(
bids_path,
"derivatives/freesurfer/sub-01/ses-01/anat")
lv1_files, lv1_folder = afd.fetch_stanford_hardi_lv1()
lv1_fname = op.join(
lv1_folder,
list(lv1_files.keys())[0])
seg_fname = op.join(
freesurfer_folder,
"sub-01_ses-01_seg.nii.gz")
bundle_names = [
"SLF_L", "SLF_R", "ARC_L", "ARC_R", "CST_L", "CST_R", "FP"]
bundle_info = BundleDict(bundle_names)
bundle_info = BundleDict(
bundle_names,
resample_subject_to=nib.load(
op.join(vista_folder, "sub-01_ses-01_dwi.nii.gz")))
del bundle_info["SLF_L"]["include"] # test endpoint ROIs as include
bundle_info["LV1"] = {
"include": [
ImageFile(path=lv1_fname),
LabelledImageFile(
path=seg_fname,
inclusive_labels=[71])],
"space": "subject"
}

tracking_params = dict(odf_model="csd",
seed_mask=RoiImage(),
Expand All @@ -753,9 +777,6 @@ def test_AFQ_data_waypoint():

clean_params = dict(return_idx=True)

vista_folder = op.join(
bids_path,
"derivatives/vistasoft/sub-01/ses-01/dwi")
afq_folder = op.join(bids_path, "derivatives/afq/sub-01/ses-01")
os.makedirs(afq_folder, exist_ok=True)
myafq = ParticipantAFQ(
Expand Down Expand Up @@ -856,13 +877,16 @@ def test_AFQ_data_waypoint():
max_length=1000,
random_seeds=True,
rng_seed=42)
bundle_dict_as_str = (
'BundleDict(["SLF_L", "SLF_R", "ARC_L", '
'"ARC_R", "CST_L", "CST_R", "FP"])')
config = dict(
BIDS_PARAMS=dict(
bids_path=bids_path,
preproc_pipeline='vistasoft'),
DATA=dict(
robust_tensor_fitting=True,
bundle_info=bundle_names),
bundle_info=bundle_dict_as_str),
SEGMENTATION=dict(
scalars=[
"dti_fa",
Expand Down
2 changes: 2 additions & 0 deletions AFQ/tests/test_tractography.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os.path as op
import random
import numpy as np
import numpy.testing as npt
import pytest
Expand Down Expand Up @@ -28,6 +29,7 @@


def test_csd_local_tracking():
random.seed(1234)
for sh_order in [4, 8, 10]:
fname = fit_csd(fdata, fbval, fbvec,
response=((0.0015, 0.0003, 0.0003), 100),
Expand Down
Empty file modified bin/pyAFQ
100644 → 100755
Empty file.

0 comments on commit bed4915

Please sign in to comment.