Skip to content

Commit

Permalink
Merge 344e765 into efe6aef
Browse files Browse the repository at this point in the history
  • Loading branch information
po09i committed Sep 1, 2022
2 parents efe6aef + 344e765 commit e2d59a6
Show file tree
Hide file tree
Showing 6 changed files with 167 additions and 51 deletions.
2 changes: 1 addition & 1 deletion examples/realtime_shimming.sh
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ st_mask box --input "../anat/sub-example_unshimmed_e1.nii.gz" --size 40 40 14 --

# Shim
st_b0shim gradient_realtime --fmap "sub-example_fieldmap.nii.gz" --anat "../anat/sub-example_unshimmed_e1.nii.gz" --resp "${INPUT_PATH}/PMUresp_signal.resp" --mask-static "../../derivatives/sub-example/sub-example_anat_mask.nii.gz" --mask-riro "../../derivatives/sub-example/sub-example_anat_mask.nii.gz" --output "../../derivatives/sub-example/gradient_realtime" || exit
st_b0shim realtime --fmap "sub-example_fieldmap.nii.gz" --anat "../anat/sub-example_unshimmed_e1.nii.gz" --mask-static "../../derivatives/sub-example/sub-example_anat_mask.nii.gz" --mask-riro "../../derivatives/sub-example/sub-example_anat_mask.nii.gz" --resp "${INPUT_PATH}/PMUresp_signal.resp" --scanner-coil-order '1' --output-file-format-scanner "gradient" --output "../../derivatives/sub-example/realtime" || exit
st_b0shim realtime-dynamic --fmap "sub-example_fieldmap.nii.gz" --anat "../anat/sub-example_unshimmed_e1.nii.gz" --mask-static "../../derivatives/sub-example/sub-example_anat_mask.nii.gz" --mask-riro "../../derivatives/sub-example/sub-example_anat_mask.nii.gz" --resp "${INPUT_PATH}/PMUresp_signal.resp" --scanner-coil-order '1' --output-file-format-scanner "gradient" --output "../../derivatives/sub-example/realtime" || exit
OUTPUT_PATH="$(dirname "${INPUT_PATH}")/rt_shim_nifti/derivatives/sub-example"
echo -e "\n\033[0;32mOutput is located here: ${OUTPUT_PATH}"

Expand Down
2 changes: 1 addition & 1 deletion shimmingtoolbox/cli/b0shim.py
Original file line number Diff line number Diff line change
Expand Up @@ -1050,7 +1050,7 @@ def _plot_coefs(coil, slices, static_coefs, path_output, coil_number, rt_coefs=N
ax.set(ylim=(min_y - (0.05 * delta_y), max_y + (0.05 * delta_y)), xlim=(-0.75, n_channels - 0.25),
xticks=range(n_channels))
ax.legend()
ax.set_title(f"Slices: {slices[i_shim]}")
ax.set_title(f"Slices: {slices[i_shim]}, Total static current: {np.abs(static_coefs[i_shim]).sum()}")
ax.set_xlabel('Channels')
ax.set_ylabel(f"Coefficients {units}")

Expand Down
2 changes: 1 addition & 1 deletion shimmingtoolbox/masking/mask_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def resample_mask(nii_mask_from, nii_target, from_slices, dilation_kernel='None'
nii_mask_target = resample_from_to(nii_mask, nii_target, order=1, mode='grid-constant', cval=0)

# Resample the full mask onto nii_target
nii_full_mask_target = resample_from_to(nii_mask_from, nii_target, order=1, mode='grid-constant', cval=0)
nii_full_mask_target = resample_from_to(nii_mask_from, nii_target, order=0, mode='grid-constant', cval=0)

# TODO: Deal with soft mask
# Find highest value and stretch to 1
Expand Down
66 changes: 53 additions & 13 deletions shimmingtoolbox/optimizer/lsq_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,20 +30,47 @@ def __init__(self, coils: ListCoil, unshimmed, affine):
"""
super().__init__(coils, unshimmed, affine)
self._initial_guess_method = 'mean'
self.initial_coefs = None

@property
def initial_guess_method(self):
return self._initial_guess_method

@initial_guess_method.setter
def initial_guess_method(self, method):
allowed_methods = ['mean', 'zeros']
def initial_guess_method(self, method, coefs=None):
allowed_methods = ['mean', 'zeros', 'set']
if method not in allowed_methods:
raise ValueError(f"Initial_guess_methos not supported. Supported methods are: {allowed_methods}")
raise ValueError(f"Initial_guess_method not supported. Supported methods are: {allowed_methods}")

if method == 'set':
if coefs is not None:
self.initial_coefs = coefs
else:
raise ValueError(f"There are no coefficients to set")

self._initial_guess_method = method

def _residuals(self, coef, unshimmed_vec, coil_mat, factor):
def _residuals_mae(self, coef, unshimmed_vec, coil_mat, factor, max_current):
""" Objective function to minimize
Args:
coef (numpy.ndarray): 1D array of channel coefficients
unshimmed_vec (numpy.ndarray): 1D flattened array (point) of the masked unshimmed map
coil_mat (numpy.ndarray): 2D flattened array (point, channel) of masked coils
(axis 0 must align with unshimmed_vec)
factor (float): Devise the result by 'factor'. This allows to scale the output for the minimize function to
avoid positive directional linesearch
Returns:
numpy.ndarray: Residuals for least squares optimization -- equivalent to flattened shimmed vector
"""

# MAE regularized to minimize currents
# TODO Possibly regularize by individual channels (If one channel is used with big relative max coef (f0)
# then other currents are relatively not penalized
return np.mean(np.abs(unshimmed_vec + np.sum(coil_mat * coef, axis=1, keepdims=False))) / factor + (0.15 * np.abs(coef).sum() / max_current)

def _residuals_mse(self, coef, unshimmed_vec, coil_mat, factor, max_current):
""" Objective function to minimize
Args:
Expand All @@ -57,10 +84,11 @@ def _residuals(self, coef, unshimmed_vec, coil_mat, factor):
Returns:
numpy.ndarray: Residuals for least squares optimization -- equivalent to flattened shimmed vector
"""
if unshimmed_vec.shape[0] != coil_mat.shape[0]:
ValueError(f"Unshimmed ({unshimmed_vec.shape}) and coil ({coil_mat.shape} arrays do not align on axis 0")

return np.sum(np.abs(unshimmed_vec + np.sum(coil_mat * coef, axis=1, keepdims=False))) / factor
# MSE regularized to minimize currents
# TODO Possibly regularize by individual channels (If one channel is used with big relative max coef (f0)
# then other currents are relatively not penalized
return np.mean((unshimmed_vec + np.sum(coil_mat * coef, axis=1, keepdims=False)) ** 2) / factor + (0.15 * np.abs(coef).sum() / max_current)

def _define_scipy_constraints(self):
return self._define_scipy_coef_sum_max_constraint()
Expand All @@ -85,8 +113,12 @@ def _apply_sum_constraint(inputs, indexes, coef_sum_max):
return constraints

def _scipy_minimize(self, currents_0, unshimmed_vec, coil_mat, scipy_constraints, factor):
currents_sp = opt.minimize(self._residuals, currents_0,
args=(unshimmed_vec, coil_mat, factor),
max_current = 0
for coil in self.coils:
max_current += coil.coef_sum_max

currents_sp = opt.minimize(self._residuals_mae, currents_0,
args=(unshimmed_vec, coil_mat, factor, max_current),
method='SLSQP',
bounds=self.merged_bounds,
constraints=tuple(scipy_constraints),
Expand All @@ -102,13 +134,17 @@ def get_initial_guess(self):

allowed_guess_method = {
'mean': self._initial_guess_mean_bounds,
'zeros': self._initial_guess_zeros
'zeros': self._initial_guess_zeros,
'set': self._initial_guess_set
}

initial_guess = allowed_guess_method[self.initial_guess_method]()

return initial_guess

def _initial_guess_set(self):
return self.initial_coefs

def _initial_guess_mean_bounds(self):
"""
Calculates the initial guess from the bounds, sets it to the mean of the bounds
Expand Down Expand Up @@ -181,7 +217,7 @@ def optimize(self, mask):
module='scipy')
# scipy minimize expects the return value of the residual function to be ~10^0 to 10^1
# --> aiming for 1 then optimizing will lower that
stability_factor = self._residuals(currents_0, unshimmed_vec, np.zeros_like(coil_mat), factor=1)
stability_factor = self._residuals_mae(self._initial_guess_zeros(), unshimmed_vec, np.zeros_like(coil_mat), factor=1, max_current=1)

currents_sp = self._scipy_minimize(currents_0, unshimmed_vec, coil_mat, scipy_constraints,
factor=stability_factor)
Expand Down Expand Up @@ -308,8 +344,12 @@ def _apply_max_pressure_constraint(inputs, i_channel, bound_max, pressure_max):

def _scipy_minimize(self, currents_0, unshimmed_vec, coil_mat, scipy_constraints, factor):
"""Redefined from super() since normal bounds are now constraints"""
currents_sp = opt.minimize(self._residuals, currents_0,
args=(unshimmed_vec, coil_mat, factor),
max_current = 0
for coil in self.coils:
max_current += coil.coef_sum_max

currents_sp = opt.minimize(self._residuals_mae, currents_0,
args=(unshimmed_vec, coil_mat, factor, max_current),
method='SLSQP',
constraints=tuple(scipy_constraints),
options={'maxiter': 500})
Expand Down
135 changes: 105 additions & 30 deletions shimmingtoolbox/shim/sequencer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
from matplotlib.figure import Figure
from mpl_toolkits.axes_grid1 import make_axes_locatable
import json
import multiprocessing as mp
import sys

from shimmingtoolbox.optimizer.lsq_optimizer import LsqOptimizer, PmuLsqOptimizer
from shimmingtoolbox.optimizer.basic_optimizer import Optimizer
Expand All @@ -28,6 +30,11 @@

logger = logging.getLogger(__name__)

if sys.platform == 'linux':
mp.set_start_method('fork', force=True)
else:
mp.set_start_method('spawn', force=True)

supported_optimizers = {
'least_squares_rt': PmuLsqOptimizer,
'least_squares': LsqOptimizer,
Expand Down Expand Up @@ -198,7 +205,8 @@ def _eval_static_shim(opt: Optimizer, nii_fieldmap_orig, nii_mask, coef, slices,
# TODO: Output json sidecar
# TODO: Update the shim settings if Scanner coil?
# Output the resulting fieldmap since it can be calculated over the entire fieldmap
nii_shimmed_fmap = nib.Nifti1Image(shimmed[..., 0], nii_fieldmap_orig.affine, header=nii_fieldmap_orig.header)
nii_shimmed_fmap = nib.Nifti1Image(shimmed[..., 0], nii_fieldmap_orig.affine,
header=nii_fieldmap_orig.header)
fname_shimmed_fmap = os.path.join(path_output, 'fieldmap_calculated_shim.nii.gz')
nib.save(nii_shimmed_fmap, fname_shimmed_fmap)
else:
Expand All @@ -214,6 +222,7 @@ def _eval_static_shim(opt: Optimizer, nii_fieldmap_orig, nii_mask, coef, slices,
_plot_static_full_mask(unshimmed, shimmed_masked, mask_full_binary, path_output)
_plot_static_partial_mask(unshimmed, shimmed, masks_fmap, path_output)
_plot_currents(coef, path_output)
_cal_shimmed_anat_orient(coef, merged_coils, nii_mask, nii_fieldmap_orig, slices, path_output)

if logger.level <= getattr(logging, 'DEBUG'):
# Save to a NIfTI
Expand All @@ -227,6 +236,30 @@ def _eval_static_shim(opt: Optimizer, nii_fieldmap_orig, nii_mask, coef, slices,
nib.save(nii_correction, fname_correction)


def _cal_shimmed_anat_orient(coefs, coils, nii_mask_anat, nii_fieldmap, slices, path_output):
nii_coils = nib.Nifti1Image(coils, nii_fieldmap.affine, header=nii_fieldmap.header)
coils_anat = resample_from_to(nii_coils,
nii_mask_anat,
order=1,
mode='grid-constant',
cval=0).get_fdata()
fieldmap_anat = resample_from_to(nii_fieldmap,
nii_mask_anat,
order=1,
mode='grid-constant',
cval=0).get_fdata()

shimmed_anat_orient = np.zeros_like(fieldmap_anat)
for i_shim in range(len(slices)):
corr = np.sum(coefs[i_shim] * coils_anat, axis=3, keepdims=False)
shimmed_anat_orient[..., slices[i_shim]] = fieldmap_anat[..., slices[i_shim]] + corr[..., slices[i_shim]]

fname_shimmed_anat_orient = os.path.join(path_output, 'fig_shimmed_anat_orient.nii.gz')
nii_shimmed_anat_orient = nib.Nifti1Image(shimmed_anat_orient * nii_mask_anat.get_fdata(), nii_mask_anat.affine,
header=nii_mask_anat.header)
nib.save(nii_shimmed_anat_orient, fname_shimmed_anat_orient)


def _calc_shimmed_full_mask(unshimmed, correction, nii_mask_anat, nii_fieldmap, slices, masks_fmap):
mask_full_binary = np.clip(np.ceil(resample_from_to(nii_mask_anat,
nii_fieldmap,
Expand Down Expand Up @@ -854,38 +887,80 @@ def select_optimizer(method, unshimmed, affine, coils: ListCoil, pmu: PmuResp =

def _optimize(optimizer: Optimizer, nii_mask_anat, slices_anat, shimwise_bounds=None,
dilation_kernel='sphere', dilation_size=3, path_output=None):
# Count number of channels
n_channels = optimizer.merged_coils.shape[3]

# Count shims to perform
n_shims = len(slices_anat)

# Initialize
coefs = np.zeros((n_shims, n_channels))

# For each shim
for i in range(n_shims):
logger.info(f"Shimming shim group: {i + 1} of {n_shims}")
# Create nibabel object of the unshimmed map
nii_unshimmed = nib.Nifti1Image(optimizer.unshimmed, optimizer.unshimmed_affine)

# Create mask in the fieldmap coordinate system from the anat roi mask and slice anat mask
sliced_mask_resampled = resample_mask(nii_mask_anat, nii_unshimmed, slices_anat[i],
dilation_kernel=dilation_kernel,
dilation_size=dilation_size,
path_output=path_output).get_fdata()

# If new bounds are included, change them for each shim
if shimwise_bounds is not None:
optimizer.set_merged_bounds(shimwise_bounds[i])

if np.all(sliced_mask_resampled == 0):
continue

# Optimize using the mask
coefs[i, :] = optimizer.optimize(sliced_mask_resampled)

return coefs
# multiprocessing optimization
_optimize_scope = (
optimizer, nii_mask_anat, slices_anat, dilation_kernel, dilation_size, path_output, shimwise_bounds)

# Default number of workers is set to mp.cpu_count()
# _worker_init gets called by each worker with _optimize_scope as arguments
# _worker_init converts those arguments as globals so they can be accessed in _opt
# This works because each worker has its own version of the global variables
# This allows to use both fork and spawn while not serializing the arguments making it slow
pool = mp.Pool(initializer=_worker_init, initargs=_optimize_scope)
try:
# should be safe to del here. Because at this point all the child processes have forked and inherited their
# copy
results = pool.starmap_async(_opt, [(i,) for i in range(n_shims)]).get(timeout=1200)
except mp.context.TimeoutError:
logger.info("Multiprocessing might have hung, retry the same command")
finally:
pool.close()
pool.join()

# TODO: Add a callback to have a progress bar, otherwise the logger will probably output in a messed up order
results.sort(key=lambda x: x[0])
results_final = [r for i, r in results]

return np.array(results_final)


gl_optimizer = None
gl_nii_mask_anat = None
gl_slices_anat = None
gl_dilation_kernel = None
gl_dilation_size = None
gl_path_output = None
gl_shimwise_bounds = None


def _worker_init(optimizer, nii_mask_anat, slices_anat, dilation_kernel, dilation_size, path_output,
shimwise_bounds):
global gl_optimizer, gl_nii_mask_anat, gl_slices_anat, gl_dilation_kernel
global gl_dilation_size, gl_path_output, gl_shimwise_bounds
gl_optimizer = optimizer
gl_nii_mask_anat = nii_mask_anat
gl_slices_anat = slices_anat
gl_dilation_kernel = dilation_kernel
gl_dilation_size = dilation_size
gl_path_output = path_output
gl_shimwise_bounds = shimwise_bounds


def _opt(i):
logger.info(f"Shimming shim group: {i + 1} of {len(gl_slices_anat)}")
# Create nibabel object of the unshimmed map
nii_unshimmed = nib.Nifti1Image(gl_optimizer.unshimmed, gl_optimizer.unshimmed_affine)

# Create mask in the fieldmap coordinate system from the anat roi mask and slice anat mask
sliced_mask_resampled = resample_mask(gl_nii_mask_anat, nii_unshimmed, gl_slices_anat[i],
dilation_kernel=gl_dilation_kernel,
dilation_size=gl_dilation_size,
path_output=gl_path_output).get_fdata()

# If new bounds are included, change them for each shim
if gl_shimwise_bounds is not None:
gl_optimizer.set_merged_bounds(gl_shimwise_bounds[i])

if np.all(sliced_mask_resampled == 0):
return i, np.zeros(gl_optimizer.merged_coils.shape[-1])

# Optimize using the mask
coef = gl_optimizer.optimize(sliced_mask_resampled)

return i, coef


def update_affine_for_ap_slices(affine, n_slices=1, axis=2):
Expand Down
11 changes: 6 additions & 5 deletions test/masking/test_mask_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,17 @@
)]
)
class TestDilateBinaryMask(object):
def test_dilate_binary_mask_default(self, input_mask):
def test_dilate_binary_mask_cross(self, input_mask):
"""Default is the cross"""
dilated = dilate_binary_mask(input_mask[0])
dilated = dilate_binary_mask(input_mask[0], shape='cross')

# Expected slice 7
expected_slice = np.zeros([10, 10])
expected_slice[2, 5:9] = 1
expected_slice[1:4, 6:8] = 1

assert np.all(expected_slice == dilated[..., 7])
expected_slice[2, 1:4] = 1
expected_slice[1:4, 2] = 1

assert np.all(expected_slice == dilated[..., 3])

def test_dilate_binary_mask_sphere(self, input_mask):

Expand Down

0 comments on commit e2d59a6

Please sign in to comment.