diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 6a36d273..aaea1b5c 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -6,8 +6,6 @@ on: branches: - master pull_request: - branches: - - master # Allows you to run this workflow manually from the Actions tab workflow_dispatch: diff --git a/shimmingtoolbox/masking/mask_utils.py b/shimmingtoolbox/masking/mask_utils.py index 0c5370ff..a6dacbcd 100644 --- a/shimmingtoolbox/masking/mask_utils.py +++ b/shimmingtoolbox/masking/mask_utils.py @@ -44,7 +44,7 @@ def resample_mask(nii_mask_from, nii_target, from_slices, dilation_kernel='None' nii_mask_target = resample_from_to(nii_mask, nii_target, order=1, mode='grid-constant', cval=0) # Resample the full mask onto nii_target - nii_full_mask_target = resample_from_to(nii_mask_from, nii_target, order=1, mode='grid-constant', cval=0) + nii_full_mask_target = resample_from_to(nii_mask_from, nii_target, order=0, mode='grid-constant', cval=0) # TODO: Deal with soft mask # Find highest value and stretch to 1 diff --git a/shimmingtoolbox/optimizer/lsq_optimizer.py b/shimmingtoolbox/optimizer/lsq_optimizer.py index 808eeb62..c0c00de1 100644 --- a/shimmingtoolbox/optimizer/lsq_optimizer.py +++ b/shimmingtoolbox/optimizer/lsq_optimizer.py @@ -57,9 +57,6 @@ def _residuals(self, coef, unshimmed_vec, coil_mat, factor): Returns: numpy.ndarray: Residuals for least squares optimization -- equivalent to flattened shimmed vector """ - if unshimmed_vec.shape[0] != coil_mat.shape[0]: - ValueError(f"Unshimmed ({unshimmed_vec.shape}) and coil ({coil_mat.shape} arrays do not align on axis 0") - return np.sum(np.abs(unshimmed_vec + np.sum(coil_mat * coef, axis=1, keepdims=False))) / factor def _define_scipy_constraints(self): diff --git a/shimmingtoolbox/shim/sequencer.py b/shimmingtoolbox/shim/sequencer.py index fcbb1ce8..deb49ef9 100644 --- a/shimmingtoolbox/shim/sequencer.py +++ b/shimmingtoolbox/shim/sequencer.py @@ -12,6 +12,8 @@ from matplotlib.figure import Figure from mpl_toolkits.axes_grid1 import make_axes_locatable import json +import multiprocessing as mp +import sys from shimmingtoolbox.optimizer.lsq_optimizer import LsqOptimizer, PmuLsqOptimizer from shimmingtoolbox.optimizer.basic_optimizer import Optimizer @@ -28,6 +30,11 @@ logger = logging.getLogger(__name__) +if sys.platform == 'linux': + mp.set_start_method('fork', force=True) +else: + mp.set_start_method('spawn', force=True) + supported_optimizers = { 'least_squares_rt': PmuLsqOptimizer, 'least_squares': LsqOptimizer, @@ -198,7 +205,8 @@ def _eval_static_shim(opt: Optimizer, nii_fieldmap_orig, nii_mask, coef, slices, # TODO: Output json sidecar # TODO: Update the shim settings if Scanner coil? # Output the resulting fieldmap since it can be calculated over the entire fieldmap - nii_shimmed_fmap = nib.Nifti1Image(shimmed[..., 0], nii_fieldmap_orig.affine, header=nii_fieldmap_orig.header) + nii_shimmed_fmap = nib.Nifti1Image(shimmed[..., 0], nii_fieldmap_orig.affine, + header=nii_fieldmap_orig.header) fname_shimmed_fmap = os.path.join(path_output, 'fieldmap_calculated_shim.nii.gz') nib.save(nii_shimmed_fmap, fname_shimmed_fmap) else: @@ -854,38 +862,78 @@ def select_optimizer(method, unshimmed, affine, coils: ListCoil, pmu: PmuResp = def _optimize(optimizer: Optimizer, nii_mask_anat, slices_anat, shimwise_bounds=None, dilation_kernel='sphere', dilation_size=3, path_output=None): - # Count number of channels - n_channels = optimizer.merged_coils.shape[3] - # Count shims to perform n_shims = len(slices_anat) - # Initialize - coefs = np.zeros((n_shims, n_channels)) - - # For each shim - for i in range(n_shims): - logger.info(f"Shimming shim group: {i + 1} of {n_shims}") - # Create nibabel object of the unshimmed map - nii_unshimmed = nib.Nifti1Image(optimizer.unshimmed, optimizer.unshimmed_affine) - - # Create mask in the fieldmap coordinate system from the anat roi mask and slice anat mask - sliced_mask_resampled = resample_mask(nii_mask_anat, nii_unshimmed, slices_anat[i], - dilation_kernel=dilation_kernel, - dilation_size=dilation_size, - path_output=path_output).get_fdata() - - # If new bounds are included, change them for each shim - if shimwise_bounds is not None: - optimizer.set_merged_bounds(shimwise_bounds[i]) - - if np.all(sliced_mask_resampled == 0): - continue - - # Optimize using the mask - coefs[i, :] = optimizer.optimize(sliced_mask_resampled) - - return coefs + # multiprocessing optimization + _optimize_scope = ( + optimizer, nii_mask_anat, slices_anat, dilation_kernel, dilation_size, path_output, shimwise_bounds) + try: + # Default number of workers is set to mp.cpu_count() + # _worker_init gets called by each worker with _optimize_scope as arguments + # _worker_init converts those arguments as globals so they can be accessed in _opt + # This works because each worker has its own version of the global variables + # This allows to use both fork and spawn while not serializing the arguments making it slow + with mp.Pool(initializer=_worker_init, initargs=_optimize_scope) as pool: + # should be safe to del here. Because at this point all the child processes have forked and inherited their + # copy + results = pool.starmap_async(_opt, [(i,) for i in range(n_shims)]).get(timeout=1200) + except mp.context.TimeoutError: + logger.info("Multiprocessing might have hung, retry the same command") + + # TODO: Add a callback to have a progress bar, otherwise the logger will probably output in a messed up order + results.sort(key=lambda x: x[0]) + results_final = [r for i, r in results] + + return np.array(results_final) + + +gl_optimizer = None +gl_nii_mask_anat = None +gl_slices_anat = None +gl_dilation_kernel = None +gl_dilation_size = None +gl_path_output = None +gl_shimwise_bounds = None + + +def _worker_init(optimizer, nii_mask_anat, slices_anat, dilation_kernel, dilation_size, path_output, + shimwise_bounds): + + global gl_optimizer, gl_nii_mask_anat, gl_slices_anat, gl_dilation_kernel + global gl_dilation_size, gl_path_output, gl_shimwise_bounds + gl_optimizer = optimizer + gl_nii_mask_anat = nii_mask_anat + gl_slices_anat = slices_anat + gl_dilation_kernel = dilation_kernel + gl_dilation_size = dilation_size + gl_path_output = path_output + gl_shimwise_bounds = shimwise_bounds + + +def _opt(i): + + logger.info(f"Shimming shim group: {i + 1} of {len(gl_slices_anat)}") + # Create nibabel object of the unshimmed map + nii_unshimmed = nib.Nifti1Image(gl_optimizer.unshimmed, gl_optimizer.unshimmed_affine) + + # Create mask in the fieldmap coordinate system from the anat roi mask and slice anat mask + sliced_mask_resampled = resample_mask(gl_nii_mask_anat, nii_unshimmed, gl_slices_anat[i], + dilation_kernel=gl_dilation_kernel, + dilation_size=gl_dilation_size, + path_output=gl_path_output).get_fdata() + + # If new bounds are included, change them for each shim + if gl_shimwise_bounds is not None: + gl_optimizer.set_merged_bounds(gl_shimwise_bounds[i]) + + if np.all(sliced_mask_resampled == 0): + return i, np.zeros(gl_optimizer.merged_coils.shape[-1]) + + # Optimize using the mask + coef = gl_optimizer.optimize(sliced_mask_resampled) + + return i, coef def update_affine_for_ap_slices(affine, n_slices=1, axis=2):