Skip to content

Commit

Permalink
Resample input masks on the target anat when B0 shimming (#376)
Browse files Browse the repository at this point in the history
* Resample mask on the target anatomical images if they are different

* Update to output binary masks + debug statements

* Update tests
  • Loading branch information
po09i committed Apr 10, 2022
1 parent ab98259 commit 7a43b9a
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 54 deletions.
8 changes: 3 additions & 5 deletions shimmingtoolbox/cli/b0shim.py
Expand Up @@ -50,8 +50,7 @@ def b0shim_cli():
@click.option('--anat', 'fname_anat', type=click.Path(exists=True), required=True,
help="Anatomical image to apply the correction onto.")
@click.option('--mask', 'fname_mask_anat', type=click.Path(exists=True), required=False,
help="Mask defining the spatial region to shim."
"The coordinate system should be the same as ``anat``'s coordinate system.")
help="Mask defining the spatial region to shim.")
@click.option('--scanner-coil-order', type=click.Choice(['-1', '0', '1', '2']), default='-1', show_default=True,
help="Maximum order of the shim system. Note that specifying 1 will return "
"orders 0 and 1. The 0th order is the f0 frequency.")
Expand Down Expand Up @@ -442,11 +441,10 @@ def _save_to_text_file_static(coil, coefs, list_slices, path_output, o_format, o
@click.option('--resp', 'fname_resp', type=click.Path(exists=True), required=True,
help="Siemens respiratory file containing pressure data.")
@click.option('--mask-static', 'fname_mask_anat_static', type=click.Path(exists=True), required=False,
help="Mask defining the static spatial region to shim."
"The coordinate system should be the same as ``anat``'s coordinate system.")
help="Mask defining the static spatial region to shim.")
@click.option('--mask-riro', 'fname_mask_anat_riro', type=click.Path(exists=True), required=False,
help="Mask defining the time varying (i.e. RIRO, Respiration-Induced Resonance Offset) "
"region to shim. The coordinate system should be the same as ``anat``'s coordinate system.")
"region to shim.")
@click.option('--scanner-coil-order', type=click.Choice(['-1', '0', '1', '2']), default='-1', show_default=True,
help="Maximum order of the shim system. Note that specifying 1 will return "
"orders 0 and 1. The 0th order is the f0 frequency.")
Expand Down
56 changes: 37 additions & 19 deletions shimmingtoolbox/shim/sequencer.py
Expand Up @@ -18,6 +18,7 @@
from shimmingtoolbox.load_nifti import get_acquisition_times
from shimmingtoolbox.pmu import PmuResp
from shimmingtoolbox.masking.mask_utils import resample_mask
from shimmingtoolbox.masking.threshold import threshold
from shimmingtoolbox.coils.coordinates import resample_from_to
from shimmingtoolbox.utils import montage
from shimmingtoolbox.shim.shim_utils import calculate_metric_within_mask
Expand Down Expand Up @@ -94,16 +95,20 @@ def shim_sequencer(nii_fieldmap, nii_anat, nii_mask_anat, slices, coils: ListCoi
raise ValueError("Anatomical image must be in 3d")

# Make sure the mask has the appropriate dimensions
mask = nii_mask_anat.get_fdata()
if mask.ndim != 3:
if nii_mask_anat.get_fdata().ndim != 3:
raise ValueError("Mask image must be in 3d")

# Make sure shape and affine of mask are the same as the anat
if not np.all(mask.shape == anat.shape):
raise ValueError(f"Shape of mask:\n {mask.shape} must be the same as the shape of anat:\n{anat.shape}")
if not np.all(np.isclose(nii_mask_anat.affine, nii_anat.affine)):
raise ValueError(f"Affine of mask:\n{nii_mask_anat.affine}\nmust be the same as the affine of anat:\n"
f"{nii_anat.affine}")
# Resample the input mask on the target anatomical image if they are different
if not np.all(nii_mask_anat.shape == anat.shape) or not np.all(nii_mask_anat.affine == nii_anat.affine):
logger.debug("Resampling mask on the target anat")
nii_mask_anat_soft = resample_from_to(nii_mask_anat, nii_anat, order=1, mode='grid-constant')
tmp_mask = nii_mask_anat_soft.get_fdata()
# Change soft mask into binary mask
tmp_mask = threshold(tmp_mask, thr=0.001)
nii_mask_anat = nib.Nifti1Image(tmp_mask, nii_mask_anat_soft.affine, header=nii_mask_anat_soft.header)

if logger.level <= getattr(logging, 'DEBUG') and path_output is not None:
nib.save(nii_mask_anat, os.path.join(path_output, "mask_static_resampled_on_anat.nii.gz"))

# Select and initialize the optimizer
optimizer = select_optimizer(method, fieldmap, affine_fieldmap, coils)
Expand Down Expand Up @@ -395,20 +400,33 @@ def shim_realtime_pmu_sequencer(nii_fieldmap, json_fmap, nii_anat, nii_static_ma
raise ValueError("Anatomical image must be in 3d")

# Make sure masks have the appropriate dimensions
static_mask = nii_static_mask.get_fdata()
if static_mask.ndim != 3:
if nii_static_mask.get_fdata().ndim != 3:
raise ValueError("static_mask image must be in 3d")
riro_mask = nii_riro_mask.get_fdata()
if riro_mask.ndim != 3:
if nii_riro_mask.get_fdata().ndim != 3:
raise ValueError("riro_mask image must be in 3d")

# Make sure shape and affine of masks are the same as the anat
if not (np.all(riro_mask.shape == anat.shape) and np.all(static_mask.shape == anat.shape)):
raise ValueError(f"Shape of riro mask: {riro_mask.shape} and static mask: {static_mask.shape} "
f"must be the same as the shape of anat: {anat.shape}")
if not (np.all(nii_riro_mask.affine == nii_anat.affine) and np.all(nii_static_mask.affine == nii_anat.affine)):
raise ValueError(f"Affine of riro mask:\n{nii_riro_mask.affine}\nand static mask: {nii_static_mask.affine}\n"
f"must be the same as the affine of anat:\n{nii_anat.affine}")
# Resample the input masks on the target anatomical image if they are different
if not np.all(nii_static_mask.shape == anat.shape) or not np.all(nii_static_mask.affine == nii_anat.affine):
logger.debug("Resampling static mask on the target anat")
nii_static_mask_soft = resample_from_to(nii_static_mask, nii_anat, order=1, mode='grid-constant')
tmp_mask = nii_static_mask_soft.get_fdata()
# Change soft mask into binary mask
tmp_mask = threshold(tmp_mask, thr=0.001)
nii_static_mask = nib.Nifti1Image(tmp_mask, nii_static_mask_soft.affine, header=nii_static_mask_soft.header)

if logger.level <= getattr(logging, 'DEBUG') and path_output is not None:
nib.save(nii_static_mask, os.path.join(path_output, "mask_static_resampled_on_anat.nii.gz"))

if not np.all(nii_riro_mask.shape == anat.shape) or not np.all(nii_riro_mask.affine == nii_anat.affine):
logger.debug("Resampling riro mask on the target anat")
nii_riro_mask_soft = resample_from_to(nii_riro_mask, nii_anat, order=1, mode='grid-constant')
tmp_mask = nii_riro_mask_soft.get_fdata()
# Change soft mask into binary mask
tmp_mask = threshold(tmp_mask, thr=0.001)
nii_riro_mask = nib.Nifti1Image(tmp_mask, nii_riro_mask_soft.affine, header=nii_riro_mask_soft.header)

if logger.level <= getattr(logging, 'DEBUG') and path_output is not None:
nib.save(nii_riro_mask, os.path.join(path_output, "mask_riro_resampled_on_anat.nii.gz"))

# Fetch PMU timing
acq_timestamps = get_acquisition_times(nii_fieldmap, json_fmap)
Expand Down
64 changes: 34 additions & 30 deletions test/shim/test_sequencer.py
Expand Up @@ -220,21 +220,23 @@ def test_shim_sequencer_wrong_mask_dim(self, nii_fieldmap, nii_anat, nii_mask, s
# shim_sequencer(nii_fieldmap, nii_anat, nii_mask, slices, [sph_coil, sph_coil2])
# assert "The coils don't have matching units:" in caplog.text

def test_shim_sequencer_wrong_mask_affine(self, nii_fieldmap, nii_anat, nii_mask, sph_coil, sph_coil2):
def test_shim_sequencer_diff_mask_affine(self, nii_fieldmap, nii_anat, nii_mask, sph_coil, sph_coil2):
# Optimize
slices = [(0, 2), (1,)]
wrong_affine = nii_mask.affine
wrong_affine[0, 0] = 100
nii_wrong_mask = nib.Nifti1Image(nii_mask.get_fdata(), wrong_affine, header=nii_mask.header)
with pytest.raises(ValueError, match="Affine of mask:"):
shim_sequencer(nii_fieldmap, nii_anat, nii_wrong_mask, slices, [sph_coil])
diff_affine = nii_mask.affine
diff_affine[0, 0] = 2
nii_diff_mask = nib.Nifti1Image(nii_mask.get_fdata(), diff_affine, header=nii_mask.header)

currents = shim_sequencer(nii_fieldmap, nii_anat, nii_diff_mask, slices, [sph_coil])
assert_results(nii_fieldmap, nii_anat, nii_diff_mask, [sph_coil], currents, slices)

def test_shim_sequencer_wrong_mask_shape(self, nii_fieldmap, nii_anat, nii_mask, sph_coil, sph_coil2):
def test_shim_sequencer_diff_mask_shape(self, nii_fieldmap, nii_anat, nii_mask, sph_coil, sph_coil2):
# Optimize
slices = [(0, 2), (1,)]
nii_wrong_mask = nib.Nifti1Image(nii_mask.get_fdata()[:5, ...], nii_mask.affine, header=nii_mask.header)
with pytest.raises(ValueError, match="Shape of mask:"):
shim_sequencer(nii_fieldmap, nii_anat, nii_wrong_mask, slices, [sph_coil])
nii_diff_mask = nib.Nifti1Image(nii_mask.get_fdata()[5:, ...], nii_mask.affine, header=nii_mask.header)

currents = shim_sequencer(nii_fieldmap, nii_anat, nii_diff_mask, slices, [sph_coil])
assert_results(nii_fieldmap, nii_anat, nii_diff_mask, [sph_coil], currents, slices)

# def test_speed_huge_matrix(self, nii_fieldmap, nii_anat, nii_mask, sph_coil, sph_coil2):
# # Create 1 huge coil which essentially is siemens basis concatenated 4 times
Expand Down Expand Up @@ -495,32 +497,34 @@ def test_shim_sequencer_rt_wrong_anat_dim(self, nii_fieldmap, json_data, nii_ana
shim_realtime_pmu_sequencer(nii_fieldmap, json_data, nii_wrong_anat, nii_mask_static, nii_mask_riro,
slices, pmu, [coil])

def test_shim_sequencer_rt_wrong_mask_dim(self, nii_fieldmap, json_data, nii_anat, nii_mask_static,
nii_mask_riro, slices, pmu, coil):
def test_shim_sequencer_rt_diff_mask_shape_static(self, nii_fieldmap, json_data, nii_anat, nii_mask_static,
nii_mask_riro, slices, pmu, coil):
# Optimize
nii_wrong_mask = nib.Nifti1Image(nii_mask_static.get_fdata()[:5, ...], nii_mask.affine, header=nii_mask.header)
with pytest.raises(ValueError, match="Shape of riro mask"):
shim_realtime_pmu_sequencer(nii_fieldmap, json_data, nii_anat, nii_wrong_mask, nii_mask_riro,
slices, pmu, [coil])
nii_diff_mask = nib.Nifti1Image(nii_mask_static.get_fdata()[5:, ...], nii_mask_static.affine,
header=nii_mask_static.header)
output = shim_realtime_pmu_sequencer(nii_fieldmap, json_data, nii_anat, nii_diff_mask, nii_mask_riro,
slices, pmu, [coil])
assert output[0].shape == (20, 3)

def test_shim_sequencer_rt_wrong_mask_affine(self, nii_fieldmap, json_data, nii_anat, nii_mask_static,
nii_mask_riro, slices, pmu, coil):
def test_shim_sequencer_rt_diff_mask_shape_riro(self, nii_fieldmap, json_data, nii_anat, nii_mask_static,
nii_mask_riro, slices, pmu, coil):
# Optimize
wrong_affine = nii_mask.affine
wrong_affine[0, 0] = 100
nii_wrong_mask = nib.Nifti1Image(nii_mask_static.get_fdata(), wrong_affine, header=nii_mask_static.header)
with pytest.raises(ValueError, match="Affine of riro mask:"):
shim_realtime_pmu_sequencer(nii_fieldmap, json_data, nii_anat, nii_wrong_mask, nii_mask_riro,
slices, pmu, [coil])
nii_diff_mask = nib.Nifti1Image(nii_mask_riro.get_fdata()[5:, ...], nii_mask_riro.affine,
header=nii_mask_riro.header)

output = shim_realtime_pmu_sequencer(nii_fieldmap, json_data, nii_anat, nii_mask_static, nii_diff_mask,
slices, pmu, [coil])
assert output[0].shape == (20, 3)

def test_shim_sequencer_rt_wrong_mask_shape(self, nii_fieldmap, json_data, nii_anat, nii_mask_static,
def test_shim_sequencer_rt_diff_mask_affine(self, nii_fieldmap, json_data, nii_anat, nii_mask_static,
nii_mask_riro, slices, pmu, coil):
# Optimize
nii_wrong_mask = nib.Nifti1Image(nii_mask_static.get_fdata()[:5, ...], nii_mask_static.affine,
header=nii_mask_static.header)
with pytest.raises(ValueError, match="Shape of riro mask:"):
shim_realtime_pmu_sequencer(nii_fieldmap, json_data, nii_anat, nii_wrong_mask, nii_mask_riro,
slices, pmu, [coil])
diff_affine = nii_mask.affine
diff_affine[0, 0] = 2
nii_diff_mask = nib.Nifti1Image(nii_mask_static.get_fdata(), diff_affine, header=nii_mask_static.header)
output = shim_realtime_pmu_sequencer(nii_fieldmap, json_data, nii_anat, nii_diff_mask, nii_mask_riro,
slices, pmu, [coil])
assert output[0].shape == (20, 3)


def test_shim_realtime_pmu_sequencer_rt_zshim_data():
Expand Down

0 comments on commit 7a43b9a

Please sign in to comment.