Skip to content

Commit

Permalink
Merge pull request #478 from 36000/apply_brain_mask_for_registration
Browse files Browse the repository at this point in the history
Apply brain mask to subject img before registration
  • Loading branch information
arokem committed Sep 29, 2020
2 parents af44bf6 + 39a039c commit e5ea2a9
Showing 1 changed file with 16 additions and 8 deletions.
24 changes: 16 additions & 8 deletions AFQ/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,7 +442,7 @@ def __init__(self,
bundle_names = BUNDLES

# Create the bundle dict after reg_template has been resolved:
self.reg_template_img, _ = self._reg_img(self.reg_template)
self.reg_template_img, _ = self._reg_img(self.reg_template, False)
self.bundle_dict = make_bundle_dict(bundle_names=bundle_names,
seg_algo=self.seg_algo,
resample_to=self.reg_template_img)
Expand Down Expand Up @@ -823,11 +823,10 @@ def _get_best_scalar(self):
return scalar
return self.scalars[0]

def _reg_img(self, img, row=None):
def _reg_img(self, img, mask, row=None):
if row is not None and row["reg_subject"] is not None:
return nib.load(row["reg_subject"]), None

if isinstance(img, str):
img = nib.load(row["reg_subject"])
elif isinstance(img, str):
img_l = img.lower()
if img_l == "mni_t2":
img = afd.read_mni_template(
Expand Down Expand Up @@ -868,13 +867,22 @@ def _reg_img(self, img, row=None):
else:
img = nib.load(img)

if mask:
brain_mask_file = self._brain_mask(row)
brain_mask = nib.load(brain_mask_file).get_fdata().astype(bool)

masked_data = img.get_fdata()
masked_data[~brain_mask] = 0

img = nib.Nifti1Image(masked_data, img.affine)

return img, None

def _reg_prealign(self, row):
prealign_file = self._get_fname(
row, '_prealign_from-DWI_to-MNI_xfm.npy')
if self.force_recompute or not op.exists(prealign_file):
reg_subject_img, _ = self._reg_img(self.reg_subject, row)
reg_subject_img, _ = self._reg_img(self.reg_subject, True, row)
_, aff = reg.affine_registration(
reg_subject_img.get_fdata(),
self.reg_template_img.get_fdata(),
Expand Down Expand Up @@ -937,9 +945,9 @@ def _mapping(self, row):
reg_prealign = None

reg_template_img, reg_template_sls = \
self._reg_img(self.reg_template, row)
self._reg_img(self.reg_template, False, row)
reg_subject_img, reg_subject_sls = \
self._reg_img(self.reg_subject, row)
self._reg_img(self.reg_subject, True, row)

start_time = time()
if self.reg_algo == "slr":
Expand Down

0 comments on commit e5ea2a9

Please sign in to comment.