Skip to content

Commit

Permalink
Faster B0 shimming (#423)
Browse files Browse the repository at this point in the history
* Add a faster solver.

Add a new class to allow the user to optimize the currents with scipy.optimize.minimize by using the residuals_mse function, and the jacobian of it

* Use the smaller dataset

* Faster the sequencer
1) Removed the multiprocessing, if the user is using a mac (Spawn method makes the optimization slower)
2) If there is no correction in a slice then it's possible to reduce the calculation time of _eval_static_shim and _cal_anat_orient. By not making all of the time consumming process.

* Make _plot_coef faster
Removes all of the useless subplot where there is no corrections.
Change the size of the figure to make it smaller if there is less subplot to show
Add a subplot to show an example of the current correction in an unshimmed slice
If the bounds are the same in every channel, then don't display it for all of them, but instead plot a single line to represent all of them

* Creating the Shimmed anat orient faster
-Reshape the coils_anat matrix into a 2D matrix
-Change the dimension of the correction matrix to a 2D one (This allow to make the calculation a lot faster, by removing a dimension for the np.sum)
- Make a list of the slices to shim to make the code faster and more readable

* Fixed a typo mistake in calc_shimmed_anat_orient

* Make sure that the merge with master is working
Use the new opt_criteria from master
In Eval_static_shim, only print in the debugging if there is a correction in a slice

* Make sure that the optimization is working with the merge
Inplement a jacobian when the function to minimize is mse
Delete lsq_faster_solver which is now obsolet

* Fix all of the problem with the merging
Removed all of the comment about least-squares-faster
Fixed some typo problem
Put opt_criteria, everytime that the optimize function is called

* Removed the jacobian in real time shimming

* Removed an useless test

* Syntax of comment

Co-authored-by: Alex Dastous <47249340+po09i@users.noreply.github.com>

* Fixed a mispell in a comment

Co-authored-by: Alex Dastous <47249340+po09i@users.noreply.github.com>

* Fixed syntax in a comment

Co-authored-by: Alex Dastous <47249340+po09i@users.noreply.github.com>

* Fixed syntax in a comment

Co-authored-by: Alex Dastous <47249340+po09i@users.noreply.github.com>

* Apply suggestions from code review

Co-authored-by: Alex Dastous <47249340+po09i@users.noreply.github.com>

* Apply suggestions from code review

Co-authored-by: Alex Dastous <47249340+po09i@users.noreply.github.com>

* Fixed a problem with the reviews changes

* Apply the review changes
Put a continue in the for loop of eval_static_shim
Apply the correct indentation

* Fixed the bug in plot_coefs + changes in calc_shimmed_anat_orient
make sure that the program doesn't crash if the slices dimension is different than 1

* Fixed the cases where all the slices are shim in plot_coefs

* Put Multiprocessing for every dataset (May change)

* Refactor _plot_coefs

* Refactor optimizer

* Small fixups

* Comment mp bug that outputs debug even when not in debug

* Move plotting of currents after saving the coefficients

* Remove newline

Co-authored-by: Alex Dastous <47249340+po09i@users.noreply.github.com>
Co-authored-by: Alexandre D'Astous <po09i@hotmail.com>
  • Loading branch information
3 people committed Jan 4, 2023
1 parent f577494 commit ceb5fd5
Show file tree
Hide file tree
Showing 5 changed files with 213 additions and 83 deletions.
162 changes: 122 additions & 40 deletions shimmingtoolbox/cli/b0shim.py
Expand Up @@ -67,7 +67,8 @@ def b0shim_cli():
"'--slice-factor' value is '3', then with the 'sequential' mode, shimming will be performed "
"independently on the following groups: {0,1,2}, {3,4,5}, etc. With the mode 'interleaved', "
"it will be: {0,2,4}, {1,3,5}, etc.")
@click.option('--optimizer-method', 'method', type=click.Choice(['least_squares', 'pseudo_inverse']), required=False,
@click.option('--optimizer-method', 'method', type=click.Choice(['least_squares', 'pseudo_inverse',
]), required=False,
default='least_squares', show_default=True,
help="Method used by the optimizer. LS will respect the constraints, PS will not respect the constraints")
@click.option('--regularization-factor', 'reg_factor', type=click.FLOAT, required=False, default=0.0, show_default=True,
Expand Down Expand Up @@ -252,7 +253,6 @@ def dynamic(fname_fmap, fname_anat, fname_mask_anat, method, opt_criteria, slice
else:
list_slices = define_slices(n_slices, slice_factor, slices)
logger.info(f"The slices to shim are:\n{list_slices}")

# Get shimming coefficients
coefs = shim_sequencer(nii_fmap_orig, nii_anat, nii_mask_anat, list_slices, list_coils,
method=method,
Expand Down Expand Up @@ -333,10 +333,25 @@ def dynamic(fname_fmap, fname_anat, fname_mask_anat, method, opt_criteria, slice
else:
list_fname_output += _save_to_text_file_static(coil, coefs_coil, list_slices, path_output, o_format_coil,
options, coil_number=i_coil)

logger.info(f"Coil txt file(s) are here:\n{os.linesep.join(list_fname_output)}")
logger.info(f"Plotting figure(s)")

# Plot the coefs after outputting the currents to the text file
end_channel = 0
for i_coil, coil in enumerate(list_coils):
# Figure out the start and end channels for a coil to be able to select it from the coefs
n_channels = coil.dim[3]
start_channel = end_channel
end_channel = start_channel + n_channels

if type(coil) != ScannerCoil:
# Select the coefficients for a coil
coefs_coil = copy.deepcopy(coefs[:, start_channel:end_channel])
# Plot a figure of the coefficients
_plot_coefs(coil, list_slices, coefs_coil, path_output, i_coil, bounds=coil.coef_channel_minmax)

logger.info(f"Coil txt file(s) are here:\n{os.linesep.join(list_fname_output)}")
logger.info(f"Finished plotting figure(s)")


def _save_to_text_file_static(coil, coefs, list_slices, path_output, o_format, options, coil_number,
Expand Down Expand Up @@ -746,15 +761,30 @@ def realtime_dynamic(fname_fmap, fname_anat, fname_mask_anat_static, fname_mask_
path_output, o_format_sph, options, i_coil)

else: # Custom coil
list_fname_output += _save_to_text_file_rt(coil, coefs_coil_static, coefs_coil_riro, mean_p, list_slices,
path_output, o_format_coil, options, i_coil)

logger.info(f"Coil txt file(s) are here:\n{os.linesep.join(list_fname_output)}")
logger.info(f"Plotting figure(s)")

# Plot the coefs after outputting the currents to the text file
end_channel = 0
for i_coil, coil in enumerate(list_coils):
# Figure out the start and end channels for a coil to be able to select it from the coefs
n_channels = coil.dim[3]
start_channel = end_channel
end_channel = start_channel + n_channels

if type(coil) != ScannerCoil:
# Select the coefficients for a coil
coefs_coil_static = copy.deepcopy(coefs_static[:, start_channel:end_channel])
coefs_coil_riro = copy.deepcopy(coefs_riro[:, start_channel:end_channel])
# Plot a figure of the coefficients
_plot_coefs(coil, list_slices, coefs_coil_static, path_output, i_coil, coefs_coil_riro,
pres_probe_max=pmu.max - mean_p, pres_probe_min=pmu.min - mean_p,
bounds=coil.coef_channel_minmax)

list_fname_output += _save_to_text_file_rt(coil, coefs_coil_static, coefs_coil_riro, mean_p, list_slices,
path_output, o_format_coil, options, i_coil)

logger.info(f"Coil txt file(s) are here:\n{os.linesep.join(list_fname_output)}")
logger.info(f"Finished plotting figure(s)")


def _save_to_text_file_rt(coil, currents_static, currents_riro, mean_p, list_slices, path_output, o_format,
Expand Down Expand Up @@ -978,11 +1008,29 @@ def _get_current_shim_settings(json_data):
@timeit
def _plot_coefs(coil, slices, static_coefs, path_output, coil_number, rt_coefs=None, pres_probe_min=None,
pres_probe_max=None, units='', bounds=None):
n_shims = static_coefs.shape[0]
fig = Figure(figsize=(8, 4 * n_shims), tight_layout=True)
# Find which slices are not shimmed and group them (smaller file size and reduce the plot saving time)
shimmed_slice_index = []
n_shims = len(slices)
slices_index_wo_shim = []
unused_slice = False
for i_shim in range(n_shims):
# Static case
if np.any(static_coefs[i_shim]):
shimmed_slice_index.append(i_shim)
continue

# Realtime case
if rt_coefs is not None:
if np.any(rt_coefs[i_shim]):
shimmed_slice_index.append(i_shim)
continue

# Get a string with the number of all the unshimmed slices
slices_index_wo_shim.append(i_shim)
unused_slice = True

# Find min and max values of the plots
# Calculate the min and max of the bounds if its an input
# Calculate the min and max of the bounds if it's an input
if bounds is not None:
bounds = np.array(bounds)
min_y = bounds.min()
Expand Down Expand Up @@ -1019,31 +1067,69 @@ def _plot_coefs(coil, slices, static_coefs, path_output, coil_number, rt_coefs=N
if max_y is None or max_y < temp_max:
max_y = np.array(static_coefs).max()

# Create a plot for each shim group
for i_shim in range(n_shims):
ax = fig.add_subplot(n_shims + 1, 1, i_shim + 1)
n_channels = static_coefs.shape[1]
# Plot the currents
n_plots = len(shimmed_slice_index)
if unused_slice:
n_plots += 1

fig = Figure(figsize=(8, 4 * n_plots), tight_layout=True)
for i_plot, slice_index in enumerate(shimmed_slice_index):

# Add realtime component as an errorbar
if rt_coefs is not None:
rt_coef_ishim = rt_coefs[i_shim]
riro = [rt_coef_ishim * -pres_probe_min, rt_coef_ishim * pres_probe_max]
ax.errorbar(range(n_channels), static_coefs[i_shim], yerr=riro, fmt='o', elinewidth=4, capsize=6,
label='static-riro')
# Add static component
rt_coef_tmp = rt_coefs[slice_index]
else:
ax.scatter(range(n_channels), static_coefs[i_shim], marker='o', label='static')
rt_coef_tmp = None

_add_sub_figure(fig, i_plot + 1, n_plots, static_coefs[slice_index], bounds, min_y, max_y, units,
slices[slice_index], rt_coef_tmp, pres_probe_min, pres_probe_max)

# Add a subplot for all the non shimmed slices
if unused_slice:
i_unshimmed_slice = slices_index_wo_shim[0]
slices_wo_shim = tuple(j for i in slices_index_wo_shim for j in slices[i])
_add_sub_figure(fig, n_plots, n_plots, static_coefs[i_unshimmed_slice], bounds,
min_y, max_y, units, slices_wo_shim)

# Save the figure
fname_figure = os.path.join(path_output, f"fig_currents_per_slice_group_coil{coil_number}_{coil.name}.png")
fig.savefig(fname_figure, bbox_inches='tight')
logger.debug(f"Saved figure: {fname_figure}")

# Draw a black line at y=0
ax.hlines(0, 0, 1, transform=ax.get_yaxis_transform(), colors='k')

delta_y = max_y - min_y
def _add_sub_figure(fig, i_plot, n_plots, static_coefs, bounds, min_y, max_y, units, slice_number, rt_coefs=None,
pres_probe_min=None, pres_probe_max=None):
# Make a subplot for slices
# If it's the recap subplot for all the slices where the correction is null then we need to take an index further to
# not have visual problem

# Add bounds on the graph
if bounds is not None:
ax = fig.add_subplot(n_plots, 1, i_plot)
n_channels = len(static_coefs)

# Add realtime component as an errorbar
if rt_coefs is not None:
rt_coef_ishim = rt_coefs
riro = [rt_coef_ishim * -pres_probe_min, rt_coef_ishim * pres_probe_max]
ax.errorbar(range(n_channels), static_coefs, yerr=riro, fmt='o', elinewidth=4, capsize=6,
label='static-riro')
# Add static component
else:
ax.scatter(range(n_channels), static_coefs, marker='o', label='static')

# Draw a black line at y=0
ax.hlines(0, 0, 1, transform=ax.get_yaxis_transform(), colors='k')

delta_y = max_y - min_y
# Add bounds on the graph
if bounds is not None:
len_vline_bounds = 0.01
len_hline_bounds = 0.4
if np.all(bounds[:, 0] == bounds[0, 0]) and np.all(bounds[:, 1] == bounds[0, 1]):
ax.hlines(bounds[0, 0], 0 - len_hline_bounds, n_channels + len_hline_bounds, colors='r',
label='bounds', capstyle='projecting')
ax.hlines(bounds[0, 1], 0 - len_hline_bounds, n_channels + len_hline_bounds, colors='r',
capstyle='projecting')
else:
# Channel 0 used for the legend
len_vline_bounds = 0.01
len_hline_bounds = 0.4
# min
ax.hlines(bounds[0, 0], -len_hline_bounds, len_hline_bounds, colors='r', label='bounds',
capstyle='projecting')
Expand Down Expand Up @@ -1073,18 +1159,14 @@ def _plot_coefs(coil, slices, static_coefs, path_output, coil_number, rt_coefs=N
bounds[i_channel, 1], colors='r', capstyle='projecting')
ax.vlines(i_channel + len_hline_bounds, bounds[i_channel, 1] - (delta_y * len_vline_bounds),
bounds[i_channel, 1], colors='r', capstyle='projecting')

# Set the extent of the plot
ax.set(ylim=(min_y - (0.05 * delta_y), max_y + (0.05 * delta_y)), xlim=(-0.75, n_channels - 0.25),
xticks=range(n_channels))
ax.legend()
ax.set_title(f"Slices: {slices[i_shim]}, Total static current: {np.abs(static_coefs[i_shim]).sum()}")
ax.set_xlabel('Channels')
ax.set_ylabel(f"Coefficients {units}")

fname_figure = os.path.join(path_output, f"fig_currents_per_slice_group_coil{coil_number}_{coil.name}.png")
fig.savefig(fname_figure, bbox_inches='tight')
logger.debug(f"Saved figure: {fname_figure}")
# Set the extent of the plot
ax.set(ylim=(min_y - (0.05 * delta_y), max_y + (0.05 * delta_y)), xlim=(-0.75, n_channels - 0.25),
xticks=range(n_channels))
ax.legend()

ax.set_title(f"Slices: {slice_number}, Total static current: {np.abs(static_coefs).sum()}")
ax.set_xlabel('Channels')
ax.set_ylabel(f"Coefficients {units}")


@click.command(context_settings=CONTEXT_SETTINGS)
Expand Down
8 changes: 4 additions & 4 deletions shimmingtoolbox/masking/mask_utils.py
Expand Up @@ -59,10 +59,10 @@ def resample_mask(nii_mask_from, nii_target, from_slices=None, dilation_kernel='
mask_dilated_in_roi = np.logical_and(mask_dilated, nii_full_mask_target.get_fdata())
nii_mask_dilated = nib.Nifti1Image(mask_dilated_in_roi, nii_mask_target.affine, header=nii_mask_target.header)

if logger.level <= getattr(logging, 'DEBUG') and path_output is not None:
nib.save(nii_mask, os.path.join(path_output, f"fig_mask_{from_slices[0]}.nii.gz"))
nib.save(nii_mask_target, os.path.join(path_output, f"fig_mask_res{from_slices[0]}.nii.gz"))
nib.save(nii_mask_dilated, os.path.join(path_output, f"fig_mask_dilated{from_slices[0]}.nii.gz"))
# if logger.level <= getattr(logging, 'DEBUG') and path_output is not None:
# nib.save(nii_mask, os.path.join(path_output, f"fig_mask_{from_slices[0]}.nii.gz"))
# nib.save(nii_mask_target, os.path.join(path_output, f"fig_mask_res{from_slices[0]}.nii.gz"))
# nib.save(nii_mask_dilated, os.path.join(path_output, f"fig_mask_dilated{from_slices[0]}.nii.gz"))

return nii_mask_dilated

Expand Down
50 changes: 44 additions & 6 deletions shimmingtoolbox/optimizer/lsq_optimizer.py
Expand Up @@ -45,11 +45,21 @@ def __init__(self, coils: ListCoil, unshimmed, affine, opt_criteria='mse', reg_f
allowed_opt_criteria[1]: self._residuals_mae,
allowed_opt_criteria[2]: self._residuals_std
}
if opt_criteria in lsq_residual_dict:
lsq_jacobian_dict = {
allowed_opt_criteria[0]: self._residuals_mse_jacobian,
allowed_opt_criteria[1]: None,
allowed_opt_criteria[2]: None
}

if opt_criteria in allowed_opt_criteria:
self._criteria_func = lsq_residual_dict[opt_criteria]
self._jacobian_func = lsq_jacobian_dict[opt_criteria]
self.opt_criteria = opt_criteria
else:
raise ValueError("Optimization criteria not supported")

self.b = None

@property
def initial_guess_method(self):
return self._initial_guess_method
Expand Down Expand Up @@ -85,7 +95,7 @@ def _residuals_mae(self, coef, unshimmed_vec, coil_mat, factor):

# MAE regularized to minimize currents
return np.mean(np.abs(unshimmed_vec + np.sum(coil_mat * coef, axis=1, keepdims=False))) / factor + \
(self.reg_factor * np.mean(np.abs(coef) / self.reg_factor_channel))
(self.reg_factor * np.mean(np.abs(coef) / self.reg_factor_channel))

def _residuals_mse(self, coef, unshimmed_vec, coil_mat, factor):
""" Objective function to minimize the mean squared error (MSE)
Expand All @@ -104,7 +114,7 @@ def _residuals_mse(self, coef, unshimmed_vec, coil_mat, factor):

# MSE regularized to minimize currents
return np.mean((unshimmed_vec + np.sum(coil_mat * coef, axis=1, keepdims=False)) ** 2) / factor + \
(self.reg_factor * np.mean(np.abs(coef) / self.reg_factor_channel))
(self.reg_factor * np.mean(np.abs(coef) / self.reg_factor_channel))

def _residuals_std(self, coef, unshimmed_vec, coil_mat, factor):
""" Objective function to minimize the standard deviation (STD)
Expand All @@ -123,7 +133,7 @@ def _residuals_std(self, coef, unshimmed_vec, coil_mat, factor):

# STD regularized to minimize currents
return np.std(unshimmed_vec + np.sum(coil_mat * coef, axis=1, keepdims=False)) / factor + \
(self.reg_factor * np.mean(np.abs(coef) / self.reg_factor_channel))
(self.reg_factor * np.mean(np.abs(coef) / self.reg_factor_channel))

def _define_scipy_constraints(self):
return self._define_scipy_coef_sum_max_constraint()
Expand All @@ -148,13 +158,14 @@ def _apply_sum_constraint(inputs, indexes, coef_sum_max):
return constraints

def _scipy_minimize(self, currents_0, unshimmed_vec, coil_mat, scipy_constraints, factor):

currents_sp = opt.minimize(self._criteria_func, currents_0,
args=(unshimmed_vec, coil_mat, factor),
method='SLSQP',
bounds=self.merged_bounds,
constraints=tuple(scipy_constraints),
jac=self._jacobian_func,
options={'maxiter': 1000})

return currents_sp

def get_initial_guess(self):
Expand Down Expand Up @@ -208,6 +219,30 @@ def _initial_guess_zeros(self):

return current_0

def _residuals_mse_jacobian(self, coef, unshimmed_vec, coil_mat, factor):
""" Jacobian of the function that we want to minimize
The function to minimize is :
np.mean((unshimmed_vec + np.sum(coil_mat * coef, axis=1, keepdims=False)) ** 2) / factor+\
(self.reg_factor * np.mean(np.abs(coef) / self.reg_factor_channel))
Args:
coef (numpy.ndarray): 1D array of channel coefficients
unshimmed_vec (numpy.ndarray): 1D flattened array (point) of the masked unshimmed map
coil_mat (numpy.ndarray): 2D flattened array (point, channel) of masked coils
(axis 0 must align with unshimmed_vec)
factor (float): unused but necessary to call the function in scipy.optimize.minimize
Returns:
jacobian (numpy.ndarray) : 1D array of the gradient of the mse function to minimize
"""
jacobian = np.array([
self.b * np.sum((unshimmed_vec + np.sum(coil_mat * coef, axis=1, keepdims=False)) * coil_mat[:, j]) +
np.sign(coef[j]) * (self.reg_factor / (9 * self.reg_factor_channel[j]))
for j in range(coef.size)
])

return jacobian

def optimize(self, mask):
"""
Optimize unshimmed volume by varying current to each channel
Expand Down Expand Up @@ -252,7 +287,9 @@ def optimize(self, mask):
# regularization on the currents has no affect on the output stability factor.
stability_factor = self._criteria_func(self._initial_guess_zeros(), unshimmed_vec, np.zeros_like(coil_mat),
factor=1)

if self.opt_criteria == 'mse':
# This factor is used to calculate the Jacobian of the mse function
self.b = (2 / (unshimmed_vec.size * stability_factor))
currents_sp = self._scipy_minimize(currents_0, unshimmed_vec, coil_mat, scipy_constraints,
factor=stability_factor)

Expand Down Expand Up @@ -385,5 +422,6 @@ def _scipy_minimize(self, currents_0, unshimmed_vec, coil_mat, scipy_constraints
args=(unshimmed_vec, coil_mat, factor),
method='SLSQP',
constraints=tuple(scipy_constraints),
jac=self._jacobian_func,
options={'maxiter': 500})
return currents_sp

0 comments on commit ceb5fd5

Please sign in to comment.