diff --git a/AFQ/api.py b/AFQ/api.py index 76407d061..c46ca429c 100644 --- a/AFQ/api.py +++ b/AFQ/api.py @@ -442,7 +442,7 @@ def __init__(self, bundle_names = BUNDLES # Create the bundle dict after reg_template has been resolved: - self.reg_template_img, _ = self._reg_img(self.reg_template) + self.reg_template_img, _ = self._reg_img(self.reg_template, False) self.bundle_dict = make_bundle_dict(bundle_names=bundle_names, seg_algo=self.seg_algo, resample_to=self.reg_template_img) @@ -823,11 +823,10 @@ def _get_best_scalar(self): return scalar return self.scalars[0] - def _reg_img(self, img, row=None): + def _reg_img(self, img, mask, row=None): if row is not None and row["reg_subject"] is not None: - return nib.load(row["reg_subject"]), None - - if isinstance(img, str): + img = nib.load(row["reg_subject"]) + elif isinstance(img, str): img_l = img.lower() if img_l == "mni_t2": img = afd.read_mni_template( @@ -868,13 +867,22 @@ def _reg_img(self, img, row=None): else: img = nib.load(img) + if mask: + brain_mask_file = self._brain_mask(row) + brain_mask = nib.load(brain_mask_file).get_fdata().astype(bool) + + masked_data = img.get_fdata() + masked_data[~brain_mask] = 0 + + img = nib.Nifti1Image(masked_data, img.affine) + return img, None def _reg_prealign(self, row): prealign_file = self._get_fname( row, '_prealign_from-DWI_to-MNI_xfm.npy') if self.force_recompute or not op.exists(prealign_file): - reg_subject_img, _ = self._reg_img(self.reg_subject, row) + reg_subject_img, _ = self._reg_img(self.reg_subject, True, row) _, aff = reg.affine_registration( reg_subject_img.get_fdata(), self.reg_template_img.get_fdata(), @@ -936,9 +944,9 @@ def _mapping(self, row): reg_prealign = None reg_template_img, reg_template_sls = \ - self._reg_img(self.reg_template, row) + self._reg_img(self.reg_template, False, row) reg_subject_img, reg_subject_sls = \ - self._reg_img(self.reg_subject, row) + self._reg_img(self.reg_subject, True, row) start_time = time() if self.reg_algo == "slr":