Skip to content

Commit

Permalink
add signal recovery function
Browse files Browse the repository at this point in the history
  • Loading branch information
yixinma9 committed Jan 13, 2023
1 parent f577494 commit 3c2920f
Show file tree
Hide file tree
Showing 4 changed files with 143 additions and 23 deletions.
35 changes: 29 additions & 6 deletions shimmingtoolbox/cli/b0shim.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,10 +74,19 @@ def b0shim_cli():
help="Regularization factor for the current when optimizing. A higher coefficient will penalize higher "
"current values while 0 provides no regularization. Not relevant for 'pseudo-inverse' "
"optimizer_method.")
@click.option('--optimizer-criteria', 'opt_criteria', type=click.Choice(['mse', 'mae']), required=False,
########################################################################################################################
@click.option('--weighting-signal-loss', 'w_signal_loss', type=click.FLOAT, required=False, default=0.0, show_default=True,
help="weighting for signal loss recovery. Since there is generally a compromise between B0 inhomogeneity"
" and signal loss recovery, a higher coefficient will put more weights to recover the signal loss over "
"the B0 inhomogeneity.")
@click.option('--epi_echo_time', 'epi_te', type=click.FLOAT, required=False, default=0.0, show_default=True,
help="EPI acquistion parameter Echo Time (TE).")
# add grad option
@click.option('--optimizer-criteria', 'opt_criteria', type=click.Choice(['mse', 'mae','grad']), required=False,
default='mse', show_default=True,
help="Criteria of optimization for the optimizer 'least_squares'."
" mse: Mean Squared Error, mae: Mean Absolute Error")
" mse: Mean Squared Error, mae: Mean Absolute Error, grad: Signal Loss")
########################################################################################################################
@click.option('--mask-dilation-kernel-size', 'dilation_kernel_size', type=click.INT, required=False, default='3',
show_default=True,
help="Number of voxels to consider outside of the masked area. For example, when doing dynamic shimming "
Expand Down Expand Up @@ -128,7 +137,7 @@ def b0shim_cli():
@timeit
def dynamic(fname_fmap, fname_anat, fname_mask_anat, method, opt_criteria, slices, slice_factor, coils,
dilation_kernel_size, scanner_coil_order, fname_sph_constr, fatsat, path_output, o_format_coil,
o_format_sph, output_value_format, reg_factor, verbose):
o_format_sph, output_value_format, reg_factor, w_signal_loss, epi_te, verbose):
""" Static shim by fitting a fieldmap. Use the option --optimizer-method to change the shimming algorithm used to
optimize. Use the options --slices and --slice-factor to change the shimming order/size of the slices.
Expand Down Expand Up @@ -260,6 +269,8 @@ def dynamic(fname_fmap, fname_anat, fname_mask_anat, method, opt_criteria, slice
mask_dilation_kernel='sphere',
mask_dilation_kernel_size=dilation_kernel_size,
reg_factor=reg_factor,
w_signal_loss=w_signal_loss,
epi_te=epi_te,
path_output=path_output)

# Output
Expand Down Expand Up @@ -486,10 +497,20 @@ def _save_to_text_file_static(coil, coefs, list_slices, path_output, o_format, o
@click.option('--optimizer-method', 'method', type=click.Choice(['least_squares', 'pseudo_inverse']), required=False,
default='least_squares', show_default=True,
help="Method used by the optimizer. LS will respect the constraints, PS will not respect the constraints")
@click.option('--optimizer-criteria', 'opt_criteria', type=click.Choice(['mse', 'mae']), required=False,
#########code added by Yixin############################################################################################
@click.option('--weighting-signal-loss', 'w_signal_loss', type=click.FLOAT, required=False, default=0.0, show_default=True,
help="weighting for signal loss recovery. Since there is generally a compromise between B0 inhomogeneity"
" and signal loss recovery, a higher coefficient will put more weights to recover the signal loss over "
"the B0 inhomogeneity.")
@click.option('--epi_echo_time', 'epi_te', type=click.FLOAT, required=False, default=0.0, show_default=True,
help="EPI acquistion parameter Echo Time (TE)")
#########################################################################################################################
# add grad option
@click.option('--optimizer-criteria', 'opt_criteria', type=click.Choice(['mse', 'mae','grad']), required=False,
default='mse', show_default=True,
help="Criteria of optimization for the optimizer 'least_squares'."
" mse: Mean Squared Error, mae: Mean Absolute Error")
" mse: Mean Squared Error, mae: Mean Absolute Error, grad: Signal Loss")
#########################################################################################################################
@click.option('--regularization-factor', 'reg_factor', type=click.FLOAT, required=False, default=0.0, show_default=True,
help="Regularization factor for the current when optimizing. A higher coefficient will penalize higher "
"current values while 0 provides no regularization. Not relevant for 'pseudo-inverse' "
Expand Down Expand Up @@ -539,7 +560,7 @@ def _save_to_text_file_static(coil, coefs, list_slices, path_output, o_format, o
def realtime_dynamic(fname_fmap, fname_anat, fname_mask_anat_static, fname_mask_anat_riro, fname_resp, method,
opt_criteria, slices, slice_factor, coils, dilation_kernel_size, scanner_coil_order,
fname_sph_constr, fatsat, path_output, o_format_coil, o_format_sph, output_value_format,
reg_factor, verbose):
reg_factor, w_signal_loss, epi_te, verbose):
""" Realtime shim by fitting a fieldmap to a pressure monitoring unit. Use the option --optimizer-method to change
the shimming algorithm used to optimize. Use the options --slices and --slice-factor to change the shimming
order/size of the slices.
Expand Down Expand Up @@ -659,6 +680,8 @@ def realtime_dynamic(fname_fmap, fname_anat, fname_mask_anat_static, fname_mask_
mask_dilation_kernel='sphere',
mask_dilation_kernel_size=dilation_kernel_size,
reg_factor=reg_factor,
w_signal_loss=w_signal_loss,
epi_te=epi_te,
path_output=path_output)

coefs_static, coefs_riro, mean_p, p_rms = out
Expand Down
4 changes: 3 additions & 1 deletion shimmingtoolbox/cli/mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,9 +169,11 @@ def sphere(fname_input, fname_output, radius, center, verbose):
"value is equal or less than this threshold. (default: 30)")
@click.option('-v', '--verbose', type=click.Choice(['info', 'debug']), default='info', help="Be more verbose")
def threshold(fname_input, output, thr, verbose):

# Set all loggers
set_all_loggers(verbose)
# test
logger.info("this works third time")

# Prepare the output
create_output_dir(output, is_file=True)
Expand Down
40 changes: 36 additions & 4 deletions shimmingtoolbox/optimizer/lsq_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,17 @@
from shimmingtoolbox.coils.coil import Coil

ListCoil = List[Coil]
allowed_opt_criteria = ['mse', 'mae', 'std']
# add allowed criteria 'grad'
allowed_opt_criteria = ['mse', 'mae', 'std', 'grad']


class LsqOptimizer(Optimizer):
""" Optimizer object that stores coil profiles and optimizes an unshimmed volume given a mask.
Use optimize(args) to optimize a given mask. The algorithm uses a least squares solver to find the best shim.
Use optimize(args) to optimize a given mask. The algorithm uses a least squares solver to find the best shim.allowed_opt_criteria
It supports bounds for each channel as well as a bound for the absolute sum of the channels.
"""

def __init__(self, coils: ListCoil, unshimmed, affine, opt_criteria='mse', reg_factor=0):
def __init__(self, coils: ListCoil, unshimmed, affine, opt_criteria='mse', reg_factor=0, w_signal_loss=0, epi_te=0):
"""
Initializes coils according to input list of Coil
Expand All @@ -38,12 +39,18 @@ def __init__(self, coils: ListCoil, unshimmed, affine, opt_criteria='mse', reg_f
self._initial_guess_method = 'mean'
self.initial_coefs = None
self.reg_factor = reg_factor
self.w_signal_loss = w_signal_loss
self.epi_te = epi_te
self.reg_factor_channel = np.array([max(np.abs(bound)) for bound in self.merged_bounds])

lsq_residual_dict = {
allowed_opt_criteria[0]: self._residuals_mse,
allowed_opt_criteria[1]: self._residuals_mae,
allowed_opt_criteria[2]: self._residuals_std
allowed_opt_criteria[2]: self._residuals_std,
#############################################
####### Yixin add the following code ########
allowed_opt_criteria[3]: self._residuals_grad
#############################################
}
if opt_criteria in lsq_residual_dict:
self._criteria_func = lsq_residual_dict[opt_criteria]
Expand Down Expand Up @@ -86,6 +93,31 @@ def _residuals_mae(self, coef, unshimmed_vec, coil_mat, factor):
# MAE regularized to minimize currents
return np.mean(np.abs(unshimmed_vec + np.sum(coil_mat * coef, axis=1, keepdims=False))) / factor + \
(self.reg_factor * np.mean(np.abs(coef) / self.reg_factor_channel))

def _residuals_grad(self, coef, unshimmed_vec, coil_mat, factor):
""" Objective function to minimize the mean squared error (MSE) and the signal loss function (gradient in z direction)
Args:
coef (numpy.ndarray): 1D array of channel coefficients
factor (float): Devise the result by 'factor'. This allows to scale the output for the minimize function to
avoid positive directional linesearch
Returns:
numpy.ndarray: Residuals for least squares optimization -- equivalent to flattened shimmed vector
"""
#print("w_signal_loss is: " + str(self.w_signal_loss) + "," + " epi_te is: " + str(self.epi_te) + " factor is" + str(factor))
nx,ny,nz,nc = np.shape(self.merged_coils)
shimmed = self.unshimmed + np.sum(self.merged_coils * np.tile(coef,(nx,ny,nz,1)),axis= 3) # need test
signal = 1
# if consider signal loss from x, y, and z
for i in range(0,3):
G = np.gradient(shimmed, axis = i)
signal = signal * np.sinc(self.epi_te * G)
# MSE regularized to minimize currents
#print("" + str(np.shape(signal)))
#print("in this round of optimization, residual from signal loss is : " + str(np.mean(1 - signal) * self.w_signal_loss) + ", residual from B0 inhomogeneity is: " + str(np.mean(np.abs(unshimmed_vec + np.sum(coil_mat * coef, axis=1, keepdims=False)))) + ", residual from current is " + str((self.reg_factor * np.mean(np.abs(coef) / self.reg_factor_channel))))
return np.mean(1 - signal) * self.w_signal_loss + \
np.mean((unshimmed_vec + np.sum(coil_mat * coef, axis=1, keepdims=False)) ** 2) / factor + \
(self.reg_factor * np.mean(np.abs(coef) / self.reg_factor_channel))

def _residuals_mse(self, coef, unshimmed_vec, coil_mat, factor):
""" Objective function to minimize the mean squared error (MSE)
Expand Down

0 comments on commit 3c2920f

Please sign in to comment.