Skip to content

Commit

Permalink
Fix signal recovery metrics when using different resolutions between …
Browse files Browse the repository at this point in the history
…anat and fm
  • Loading branch information
4rnaudB committed Jan 22, 2024
1 parent 68697e7 commit b2fac63
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 24 deletions.
68 changes: 46 additions & 22 deletions shimmingtoolbox/shim/sequencer.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,24 +419,18 @@ def eval(self, coef):
merged_coils = self.optimizer.merged_coils

shimmed, corrections, list_shim_slice = self.evaluate_shimming(unshimmed, coef, merged_coils)

shimmed_masked, mask_full_binary = self.calc_shimmed_full_mask(unshimmed, corrections)
if self.path_output is not None:
# fmap space
# Merge the i_shim into one single fieldmap shimmed (correction applied only where it will be applied on
# the fieldmap)
if self.opt_criteria == 'grad':
full_shimmed = np.zeros(unshimmed.shape)
full_Gz = np.zeros(unshimmed.shape)
full_Gz = np.zeros(corrections.shape)
for i_shim in range(len(self.slices)):
slc = self.slices[i_shim]
print('Currently writing: i_shim is ' + str(i_shim) + ', corespondingly, slices are: ' + str(slc))
shimmed_temp = corrections[..., i_shim] + unshimmed
grad_temp = np.gradient(shimmed_temp, axis = 2)
full_Gz[:,:,slc] = grad_temp[:,:,slc]
# Apply the correction weighted according to the mask
full_shimmed[:,:,slc] = shimmed_temp[:,:,slc]
full_Gz[..., i_shim] = np.gradient(shimmed_temp, axis = 2)

shimmed_masked, mask_full_binary = self.calc_shimmed_full_mask(unshimmed, corrections)
full_Gz, _ = self.calc_shimmed_full_mask(np.zeros_like(unshimmed), full_Gz)
# eroded_mask_binary = erode_binary_mask(mask_full_binary,shape='sphere',size=3)
if len(self.slices) == 1:
# TODO: Output json sidecar
Expand All @@ -457,8 +451,8 @@ def eval(self, coef):
# TODO: Add in anat space?
if self.opt_criteria == 'grad':
# Plot gradient realted results
self._plot_static_signal_recovery_mask(unshimmed, full_shimmed, full_Gz, mask_full_binary, self.path_output, self.epi_te)
self._plot_T2_star_mask(unshimmed, full_shimmed, full_Gz, mask_full_binary, self.path_output, self.epi_te)
self._plot_static_signal_recovery_mask(unshimmed, shimmed_masked, full_Gz, mask_full_binary, self.path_output, self.epi_te)
self._plot_T2_star_mask(unshimmed, shimmed_masked, full_Gz, mask_full_binary, self.path_output, self.epi_te)

# Figure that shows unshimmed vs shimmed for each slice
self.plot_full_mask(unshimmed, shimmed_masked, mask_full_binary)
Expand Down Expand Up @@ -609,6 +603,34 @@ def calc_shimmed_full_mask(self, unshimmed, correction):

return shimmed_masked, mask_full_binary

def calc_shimmed_gradient_full_mask(self, gradient):
"""
Calculate the shimmed gradient full mask
Args:
gradient (np.ndarray): Gradient of each shimmed fieldmap slice
Returns:
(tuple) : tuple containing:
* np.ndarray: Masked shimmed fieldmap
* np.ndarray: Binary mask in the fieldmap space
"""
mask_full_binary = np.clip(np.ceil(resample_from_to(self.nii_mask_anat,
self.nii_fieldmap_orig,
order=0,
mode='grid-constant',
cval=0).get_fdata()), 0, 1)

full_correction = np.einsum('ijkl,ijkl->ijk', self.masks_fmap, gradient, optimize='optimizer')
# Calculate the weighted whole mask
mask_weight = np.sum(self.masks_fmap, axis=3)
# Divide by the weighted mask. This is done so that the edges of the soft mask can be shimmed appropriately
full_correction_scaled = np.divide(full_correction, mask_weight, where=mask_full_binary.astype(bool))

# Apply the correction to the unshimmed image
shimmed_masked = full_correction_scaled * mask_full_binary

return shimmed_masked, mask_full_binary

def plot_full_mask(self, unshimmed, shimmed_masked, mask):
"""
Plot and save the static full mask
Expand Down Expand Up @@ -790,20 +812,22 @@ def calculate_signal_loss(B0_map):
mask_erode = erode_binary_mask(mask,shape='sphere',size=3)

# choose selected slices to plot
nonzero_indices = np.nonzero(np.sum(mask_erode,axis=(0,1)))[0];
nonzero_indices = np.nonzero(np.sum(mask_erode,axis=(0,1)))[0]
mt_unshimmed = montage(unshimmed_signal_loss[:,:,nonzero_indices])
mt_unshimmed_masked = montage(unshimmed_signal_loss[:,:,nonzero_indices]*mask_erode[:,:,nonzero_indices])
mt_shimmed_masked = montage(shimmed_signal_loss[:,:,nonzero_indices]*mask_erode[:,:,nonzero_indices])

metric_unshimmed_std = calculate_metric_within_mask(unshimmed_signal_loss, mask_erode, metric='std')
metric_shimmed_std = calculate_metric_within_mask(shimmed_signal_loss, mask_erode, metric='std')
metric_unshimmed_mean = calculate_metric_within_mask(unshimmed_signal_loss, mask_erode, metric='mean')
metric_shimmed_mean = calculate_metric_within_mask(shimmed_signal_loss, mask_erode, metric='mean')
metric_unshimmed_absmean = calculate_metric_within_mask(np.abs(unshimmed_signal_loss), mask_erode, metric='mean')
metric_shimmed_absmean = calculate_metric_within_mask(np.abs(shimmed_signal_loss), mask_erode, metric='mean')

min_value = min(mt_unshimmed_masked.min(), mt_shimmed_masked.min())
max_value = max(mt_unshimmed_masked.max(), mt_shimmed_masked.max())
temp_unshimmed_signal_loss = unshimmed_signal_loss.copy()
temp_unshimmed_signal_loss[unshimmed_signal_loss < 0.1] = np.nan
temp_shimmed_signal_loss = shimmed_signal_loss.copy()
temp_shimmed_signal_loss[unshimmed_signal_loss < 0.1] = np.nan

metric_unshimmed_std = calculate_metric_within_mask(temp_unshimmed_signal_loss, mask_erode, metric='std')
metric_shimmed_std = calculate_metric_within_mask(temp_shimmed_signal_loss, mask_erode, metric='std')
metric_unshimmed_mean = calculate_metric_within_mask(temp_unshimmed_signal_loss, mask_erode, metric='mean')
metric_shimmed_mean = calculate_metric_within_mask(temp_shimmed_signal_loss, mask_erode, metric='mean')
metric_unshimmed_absmean = calculate_metric_within_mask(np.abs(temp_unshimmed_signal_loss), mask_erode, metric='mean')
metric_shimmed_absmean = calculate_metric_within_mask(np.abs(temp_shimmed_signal_loss), mask_erode, metric='mean')

fig = Figure(figsize=(60, 30)) #make the figure larger and higher resolution
fig.suptitle(f"Signal Percentage Loss Map\nFieldmap Coordinate System")
Expand Down
2 changes: 1 addition & 1 deletion shimmingtoolbox/shim/shim_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def calculate_metric_within_mask(array, mask, metric='mean', axis=None):
np.ndarray: Array containing the output metrics, if axis is None, the output is a single value
"""
ma_array = np.ma.array(array, mask=mask == False)

ma_array = np.ma.array(ma_array, mask=np.isnan(ma_array))
if metric == 'mean':
output = np.ma.mean(ma_array, axis=axis)
elif metric == 'std':
Expand Down
4 changes: 3 additions & 1 deletion test/cli/test_cli_b0shim.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,9 @@ def test_cli_dynamic_coils(self, nii_fmap, nii_anat, nii_mask, fm_data, anat_dat
'--fmap', fname_fmap,
'--anat', fname_anat,
'--mask', fname_mask,
'--output', tmp],
'--output', tmp,
'--optimizer-method', 'least_squares',
'--optimizer-criteria', 'grad'],
catch_exceptions=False)

assert res.exit_code == 0
Expand Down

0 comments on commit b2fac63

Please sign in to comment.