Skip to content

Commit

Permalink
Merge bfad338 into 59897aa
Browse files Browse the repository at this point in the history
  • Loading branch information
po09i committed Jun 6, 2022
2 parents 59897aa + bfad338 commit 599deb5
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 36 deletions.
2 changes: 0 additions & 2 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@ on:
branches:
- master
pull_request:
branches:
- master
# Allows you to run this workflow manually from the Actions tab
workflow_dispatch:

Expand Down
2 changes: 1 addition & 1 deletion shimmingtoolbox/masking/mask_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def resample_mask(nii_mask_from, nii_target, from_slices, dilation_kernel='None'
nii_mask_target = resample_from_to(nii_mask, nii_target, order=1, mode='grid-constant', cval=0)

# Resample the full mask onto nii_target
nii_full_mask_target = resample_from_to(nii_mask_from, nii_target, order=1, mode='grid-constant', cval=0)
nii_full_mask_target = resample_from_to(nii_mask_from, nii_target, order=0, mode='grid-constant', cval=0)

# TODO: Deal with soft mask
# Find highest value and stretch to 1
Expand Down
3 changes: 0 additions & 3 deletions shimmingtoolbox/optimizer/lsq_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,6 @@ def _residuals(self, coef, unshimmed_vec, coil_mat, factor):
Returns:
numpy.ndarray: Residuals for least squares optimization -- equivalent to flattened shimmed vector
"""
if unshimmed_vec.shape[0] != coil_mat.shape[0]:
ValueError(f"Unshimmed ({unshimmed_vec.shape}) and coil ({coil_mat.shape} arrays do not align on axis 0")

return np.sum(np.abs(unshimmed_vec + np.sum(coil_mat * coef, axis=1, keepdims=False))) / factor

def _define_scipy_constraints(self):
Expand Down
108 changes: 78 additions & 30 deletions shimmingtoolbox/shim/sequencer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
from matplotlib.figure import Figure
from mpl_toolkits.axes_grid1 import make_axes_locatable
import json
import multiprocessing as mp
import sys

from shimmingtoolbox.optimizer.lsq_optimizer import LsqOptimizer, PmuLsqOptimizer
from shimmingtoolbox.optimizer.basic_optimizer import Optimizer
Expand All @@ -28,6 +30,11 @@

logger = logging.getLogger(__name__)

if sys.platform == 'linux':
mp.set_start_method('fork', force=True)
else:
mp.set_start_method('spawn', force=True)

supported_optimizers = {
'least_squares_rt': PmuLsqOptimizer,
'least_squares': LsqOptimizer,
Expand Down Expand Up @@ -198,7 +205,8 @@ def _eval_static_shim(opt: Optimizer, nii_fieldmap_orig, nii_mask, coef, slices,
# TODO: Output json sidecar
# TODO: Update the shim settings if Scanner coil?
# Output the resulting fieldmap since it can be calculated over the entire fieldmap
nii_shimmed_fmap = nib.Nifti1Image(shimmed[..., 0], nii_fieldmap_orig.affine, header=nii_fieldmap_orig.header)
nii_shimmed_fmap = nib.Nifti1Image(shimmed[..., 0], nii_fieldmap_orig.affine,
header=nii_fieldmap_orig.header)
fname_shimmed_fmap = os.path.join(path_output, 'fieldmap_calculated_shim.nii.gz')
nib.save(nii_shimmed_fmap, fname_shimmed_fmap)
else:
Expand Down Expand Up @@ -854,38 +862,78 @@ def select_optimizer(method, unshimmed, affine, coils: ListCoil, pmu: PmuResp =

def _optimize(optimizer: Optimizer, nii_mask_anat, slices_anat, shimwise_bounds=None,
dilation_kernel='sphere', dilation_size=3, path_output=None):
# Count number of channels
n_channels = optimizer.merged_coils.shape[3]

# Count shims to perform
n_shims = len(slices_anat)

# Initialize
coefs = np.zeros((n_shims, n_channels))

# For each shim
for i in range(n_shims):
logger.info(f"Shimming shim group: {i + 1} of {n_shims}")
# Create nibabel object of the unshimmed map
nii_unshimmed = nib.Nifti1Image(optimizer.unshimmed, optimizer.unshimmed_affine)

# Create mask in the fieldmap coordinate system from the anat roi mask and slice anat mask
sliced_mask_resampled = resample_mask(nii_mask_anat, nii_unshimmed, slices_anat[i],
dilation_kernel=dilation_kernel,
dilation_size=dilation_size,
path_output=path_output).get_fdata()

# If new bounds are included, change them for each shim
if shimwise_bounds is not None:
optimizer.set_merged_bounds(shimwise_bounds[i])

if np.all(sliced_mask_resampled == 0):
continue

# Optimize using the mask
coefs[i, :] = optimizer.optimize(sliced_mask_resampled)

return coefs
# multiprocessing optimization
_optimize_scope = (
optimizer, nii_mask_anat, slices_anat, dilation_kernel, dilation_size, path_output, shimwise_bounds)
try:
# Default number of workers is set to mp.cpu_count()
# _worker_init gets called by each worker with _optimize_scope as arguments
# _worker_init converts those arguments as globals so they can be accessed in _opt
# This works because each worker has its own version of the global variables
# This allows to use both fork and spawn while not serializing the arguments making it slow
with mp.Pool(initializer=_worker_init, initargs=_optimize_scope) as pool:
# should be safe to del here. Because at this point all the child processes have forked and inherited their
# copy
results = pool.starmap_async(_opt, [(i,) for i in range(n_shims)]).get(timeout=1200)
except mp.context.TimeoutError:
logger.info("Multiprocessing might have hung, retry the same command")

# TODO: Add a callback to have a progress bar, otherwise the logger will probably output in a messed up order
results.sort(key=lambda x: x[0])
results_final = [r for i, r in results]

return np.array(results_final)


gl_optimizer = None
gl_nii_mask_anat = None
gl_slices_anat = None
gl_dilation_kernel = None
gl_dilation_size = None
gl_path_output = None
gl_shimwise_bounds = None


def _worker_init(optimizer, nii_mask_anat, slices_anat, dilation_kernel, dilation_size, path_output,
shimwise_bounds):

global gl_optimizer, gl_nii_mask_anat, gl_slices_anat, gl_dilation_kernel
global gl_dilation_size, gl_path_output, gl_shimwise_bounds
gl_optimizer = optimizer
gl_nii_mask_anat = nii_mask_anat
gl_slices_anat = slices_anat
gl_dilation_kernel = dilation_kernel
gl_dilation_size = dilation_size
gl_path_output = path_output
gl_shimwise_bounds = shimwise_bounds


def _opt(i):

logger.info(f"Shimming shim group: {i + 1} of {len(gl_slices_anat)}")
# Create nibabel object of the unshimmed map
nii_unshimmed = nib.Nifti1Image(gl_optimizer.unshimmed, gl_optimizer.unshimmed_affine)

# Create mask in the fieldmap coordinate system from the anat roi mask and slice anat mask
sliced_mask_resampled = resample_mask(gl_nii_mask_anat, nii_unshimmed, gl_slices_anat[i],
dilation_kernel=gl_dilation_kernel,
dilation_size=gl_dilation_size,
path_output=gl_path_output).get_fdata()

# If new bounds are included, change them for each shim
if gl_shimwise_bounds is not None:
gl_optimizer.set_merged_bounds(gl_shimwise_bounds[i])

if np.all(sliced_mask_resampled == 0):
return i, np.zeros(gl_optimizer.merged_coils.shape[-1])

# Optimize using the mask
coef = gl_optimizer.optimize(sliced_mask_resampled)

return i, coef


def update_affine_for_ap_slices(affine, n_slices=1, axis=2):
Expand Down

0 comments on commit 599deb5

Please sign in to comment.