From 6fa1c55757d96fc3f80bff85b73978fb57a2fa63 Mon Sep 17 00:00:00 2001
From: Alexandre D'Astous
Date: Wed, 11 May 2022 17:26:57 -0400
Subject: [PATCH 01/10] Remove unnecessary check, speed things up

---
 shimmingtoolbox/optimizer/lsq_optimizer.py | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/shimmingtoolbox/optimizer/lsq_optimizer.py b/shimmingtoolbox/optimizer/lsq_optimizer.py
index 808eeb62..c0c00de1 100644
--- a/shimmingtoolbox/optimizer/lsq_optimizer.py
+++ b/shimmingtoolbox/optimizer/lsq_optimizer.py
@@ -57,9 +57,6 @@ def _residuals(self, coef, unshimmed_vec, coil_mat, factor):
         Returns:
             numpy.ndarray: Residuals for least squares optimization -- equivalent to flattened shimmed vector
         """
-        if unshimmed_vec.shape[0] != coil_mat.shape[0]:
-            ValueError(f"Unshimmed ({unshimmed_vec.shape}) and coil ({coil_mat.shape} arrays do not align on axis 0")
-
         return np.sum(np.abs(unshimmed_vec + np.sum(coil_mat * coef, axis=1, keepdims=False))) / factor
 
     def _define_scipy_constraints(self):
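The deleted lines built a ValueError without raising it, so they never actually guarded anything, and a real axis-0 mismatch surfaces on its own as a NumPy broadcasting error in the expression that stays. A minimal sketch with toy arrays (shapes and values invented for illustration, not shimmingtoolbox code):

```python
import numpy as np

# Toy data: 3 voxels, 4 coil channels (shapes assumed for illustration)
unshimmed_vec = np.array([10.0, -5.0, 3.0])   # (n_voxels,)
coil_mat = np.ones((3, 4))                    # (n_voxels, n_channels)
coef = np.full(4, 0.5)                        # (n_channels,)
factor = 1.0

# Same expression as the kept return line: sum of |unshimmed + coil contribution|, scaled by factor
residual = np.sum(np.abs(unshimmed_vec + np.sum(coil_mat * coef, axis=1, keepdims=False))) / factor
print(residual)  # 12.0 + 3.0 + 5.0 = 20.0

# A genuine axis-0 mismatch fails loudly on its own, no explicit check needed
try:
    np.abs(unshimmed_vec[:2] + np.sum(coil_mat * coef, axis=1))
except ValueError as err:
    print(err)  # operands could not be broadcast together with shapes (2,) (3,)
```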
From e2e1b407ffeed6c028b58667b04cae0cf6480f3d Mon Sep 17 00:00:00 2001
From: Alexandre D'Astous
Date: Wed, 11 May 2022 17:33:13 -0400
Subject: [PATCH 02/10] Try multiprocessing

---
 shimmingtoolbox/shim/sequencer.py | 51 ++++++++++++++++++-------------
 1 file changed, 29 insertions(+), 22 deletions(-)

diff --git a/shimmingtoolbox/shim/sequencer.py b/shimmingtoolbox/shim/sequencer.py
index df81e2a5..3951a9c8 100644
--- a/shimmingtoolbox/shim/sequencer.py
+++ b/shimmingtoolbox/shim/sequencer.py
@@ -12,6 +12,7 @@
 from matplotlib.figure import Figure
 from mpl_toolkits.axes_grid1 import make_axes_locatable
 import json
+import multiprocessing as mp
 
 from shimmingtoolbox.optimizer.lsq_optimizer import LsqOptimizer, PmuLsqOptimizer
 from shimmingtoolbox.optimizer.basic_optimizer import Optimizer
@@ -831,38 +832,44 @@ def select_optimizer(method, unshimmed, affine, coils: ListCoil, pmu: PmuResp =
 def _optimize(optimizer: Optimizer, nii_mask_anat, slices_anat, shimwise_bounds=None,
               dilation_kernel='sphere', dilation_size=3, path_output=None):
-    # Count number of channels
-    n_channels = optimizer.merged_coils.shape[3]
     # Count shims to perform
     n_shims = len(slices_anat)
 
-    # Initialize
-    coefs = np.zeros((n_shims, n_channels))
+    # multiprocessing optimization
+    mp.set_start_method('spawn', force=True)
+    with mp.Pool(mp.cpu_count()) as pool:
+        results = pool.starmap_async(_opt, [(i, optimizer, nii_mask_anat, slices_anat, dilation_kernel, dilation_size,
+                                             path_output, shimwise_bounds) for i in range(n_shims)]).get()
+    # TODO: Add a callback to have a progress bar, otherwise the logger will probably output in a messed up order
+    results.sort(key=lambda x: x[0])
+    results_final = [r for i, r in results]
 
-    # For each shim
-    for i in range(n_shims):
-        logger.info(f"Shimming shim group: {i + 1} of {n_shims}")
-        # Create nibabel object of the unshimmed map
-        nii_unshimmed = nib.Nifti1Image(optimizer.unshimmed, optimizer.unshimmed_affine)
+    return np.array(results_final)
 
-        # Create mask in the fieldmap coordinate system from the anat roi mask and slice anat mask
-        sliced_mask_resampled = resample_mask(nii_mask_anat, nii_unshimmed, slices_anat[i],
-                                              dilation_kernel=dilation_kernel,
-                                              dilation_size=dilation_size,
-                                              path_output=path_output).get_fdata()
 
-        # If new bounds are included, change them for each shim
-        if shimwise_bounds is not None:
-            optimizer.set_merged_bounds(shimwise_bounds[i])
+def _opt(i, optimizer, nii_mask_anat, slices_anat, dilation_kernel, dilation_size, path_output, shimwise_bounds):
+    logger.info(f"Shimming shim group: {i + 1} of {len(slices_anat)}")
+    # Create nibabel object of the unshimmed map
+    nii_unshimmed = nib.Nifti1Image(optimizer.unshimmed, optimizer.unshimmed_affine)
 
-        if np.all(sliced_mask_resampled == 0):
-            continue
+    # Create mask in the fieldmap coordinate system from the anat roi mask and slice anat mask
+    sliced_mask_resampled = resample_mask(nii_mask_anat, nii_unshimmed, slices_anat[i],
+                                          dilation_kernel=dilation_kernel,
+                                          dilation_size=dilation_size,
+                                          path_output=path_output).get_fdata()
 
-        # Optimize using the mask
-        coefs[i, :] = optimizer.optimize(sliced_mask_resampled)
+    # If new bounds are included, change them for each shim
+    if shimwise_bounds is not None:
+        optimizer.set_merged_bounds(shimwise_bounds[i])
 
-    return coefs
+    if np.all(sliced_mask_resampled == 0):
+        return i, np.zeros(optimizer.merged_coils.shape[-1])
+
+    # Optimize using the mask
+    coef = optimizer.optimize(sliced_mask_resampled)
+
+    return i, coef
 
 
 def update_affine_for_ap_slices(affine, n_slices=1, axis=2):

From 9c2b28d3382982dfc81684427d9512a12e27a360 Mon Sep 17 00:00:00 2001
From: Alexandre D'Astous
Date: Sat, 14 May 2022 17:08:05 -0400
Subject: [PATCH 03/10] Use fork to allow logging, use logging library to avoid locking

---
 setup.py                          | 1 +
 shimmingtoolbox/shim/sequencer.py | 8 +++++++-
 2 files changed, 8 insertions(+), 1 deletion(-)

diff --git a/setup.py b/setup.py
index 8c62911b..4f689e5a 100644
--- a/setup.py
+++ b/setup.py
@@ -44,6 +44,7 @@
         "scipy~=1.6.0",
         "tqdm",
         "matplotlib~=3.1.2",
+        "multiprocessing-logging~=0.3.3"
         "psutil~=5.7.3",
         "pytest~=4.6.3",
         "pytest-cov~=2.5.1",
diff --git a/shimmingtoolbox/shim/sequencer.py b/shimmingtoolbox/shim/sequencer.py
index 3951a9c8..5a03cfea 100644
--- a/shimmingtoolbox/shim/sequencer.py
+++ b/shimmingtoolbox/shim/sequencer.py
@@ -13,6 +13,7 @@
 from mpl_toolkits.axes_grid1 import make_axes_locatable
 import json
 import multiprocessing as mp
+import multiprocessing_logging
 
 from shimmingtoolbox.optimizer.lsq_optimizer import LsqOptimizer, PmuLsqOptimizer
 from shimmingtoolbox.optimizer.basic_optimizer import Optimizer
@@ -28,6 +29,7 @@
 ListCoil = List[Coil]
 
 logger = logging.getLogger(__name__)
+multiprocessing_logging.install_mp_handler(logger)
 
 supported_optimizers = {
     'least_squares_rt': PmuLsqOptimizer,
@@ -837,10 +839,14 @@ def _optimize(optimizer: Optimizer, nii_mask_anat, slices_anat, shimwise_bounds=
     n_shims = len(slices_anat)
 
     # multiprocessing optimization
-    mp.set_start_method('spawn', force=True)
+    mp.set_start_method('fork', force=True)
     with mp.Pool(mp.cpu_count()) as pool:
         results = pool.starmap_async(_opt, [(i, optimizer, nii_mask_anat, slices_anat, dilation_kernel, dilation_size,
                                              path_output, shimwise_bounds) for i in range(n_shims)]).get()
+
+    pool.close()
+    pool.join()
+
     # TODO: Add a callback to have a progress bar, otherwise the logger will probably output in a messed up order
     results.sort(key=lambda x: x[0])
     results_final = [r for i, r in results]

From de12e0c1ce62f3b512f677cb53be0ea6117c935e Mon Sep 17 00:00:00 2001
From: Alexandre D'Astous
Date: Sun, 15 May 2022 12:02:15 -0400
Subject: [PATCH 04/10] Missing comma

---
 setup.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/setup.py b/setup.py
index 4f689e5a..b540e70c 100644
--- a/setup.py
+++ b/setup.py
@@ -44,7 +44,7 @@
         "scipy~=1.6.0",
        "tqdm",
         "matplotlib~=3.1.2",
-        "multiprocessing-logging~=0.3.3"
+        "multiprocessing-logging~=0.3.3",
         "psutil~=5.7.3",
         "pytest~=4.6.3",
         "pytest-cov~=2.5.1",
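The comma matters because Python concatenates adjacent string literals, so without it the two requirements collapse into a single invalid specifier. A standalone illustration (not the project's setup.py):

```python
# Without the trailing comma, Python joins adjacent string literals:
install_requires_broken = [
    "multiprocessing-logging~=0.3.3"
    "psutil~=5.7.3",
]
print(install_requires_broken)   # ['multiprocessing-logging~=0.3.3psutil~=5.7.3']

# With the comma, the two requirements stay separate:
install_requires_fixed = [
    "multiprocessing-logging~=0.3.3",
    "psutil~=5.7.3",
]
print(install_requires_fixed)    # ['multiprocessing-logging~=0.3.3', 'psutil~=5.7.3']
```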
From 2bfa0ed1d650344526befc3021ecd63797d6c94c Mon Sep 17 00:00:00 2001
From: Alexandre D'Astous
Date: Wed, 18 May 2022 13:38:45 -0400
Subject: [PATCH 05/10] Add timeout to multiprocessing; this will raise an exception if shimming takes more than 20 minutes

---
 shimmingtoolbox/shim/sequencer.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/shimmingtoolbox/shim/sequencer.py b/shimmingtoolbox/shim/sequencer.py
index 5a03cfea..0d5f32ea 100644
--- a/shimmingtoolbox/shim/sequencer.py
+++ b/shimmingtoolbox/shim/sequencer.py
@@ -841,8 +841,12 @@ def _optimize(optimizer: Optimizer, nii_mask_anat, slices_anat, shimwise_bounds=
     # multiprocessing optimization
     mp.set_start_method('fork', force=True)
     with mp.Pool(mp.cpu_count()) as pool:
-        results = pool.starmap_async(_opt, [(i, optimizer, nii_mask_anat, slices_anat, dilation_kernel, dilation_size,
-                                             path_output, shimwise_bounds) for i in range(n_shims)]).get()
+        try:
+            results = pool.starmap_async(_opt,
+                                         [(i, optimizer, nii_mask_anat, slices_anat, dilation_kernel, dilation_size,
+                                           path_output, shimwise_bounds) for i in range(n_shims)]).get(timeout=1200)
+        except mp.context.TimeoutError:
+            logger.info("Multiprocessing might have hung, retry the same command")
 
     pool.close()
     pool.join()

From bccc21c017ddaf568d0e77bf366cb274c955848d Mon Sep 17 00:00:00 2001
From: Alexandre D'Astous
Date: Wed, 18 May 2022 16:35:12 -0400
Subject: [PATCH 06/10] Switch back to spawn, macOS seems to hang every time with fork

---
 shimmingtoolbox/shim/sequencer.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/shimmingtoolbox/shim/sequencer.py b/shimmingtoolbox/shim/sequencer.py
index 0d5f32ea..1597cd58 100644
--- a/shimmingtoolbox/shim/sequencer.py
+++ b/shimmingtoolbox/shim/sequencer.py
@@ -839,17 +839,19 @@ def _optimize(optimizer: Optimizer, nii_mask_anat, slices_anat, shimwise_bounds=
     n_shims = len(slices_anat)
 
     # multiprocessing optimization
-    mp.set_start_method('fork', force=True)
+    mp.set_start_method('spawn', force=True)
     with mp.Pool(mp.cpu_count()) as pool:
         try:
             results = pool.starmap_async(_opt,
                                          [(i, optimizer, nii_mask_anat, slices_anat, dilation_kernel, dilation_size,
                                            path_output, shimwise_bounds) for i in range(n_shims)]).get(timeout=1200)
+            pool.close()
+            pool.join()
+
         except mp.context.TimeoutError:
             logger.info("Multiprocessing might have hung, retry the same command")
 
-    pool.close()
-    pool.join()
+
 
     # TODO: Add a callback to have a progress bar, otherwise the logger will probably output in a messed up order
     results.sort(key=lambda x: x[0])
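Two multiprocessing details these commits rely on: `AsyncResult.get(timeout=...)` raises `multiprocessing.TimeoutError` when the workers do not finish in time, and the start method decides how workers get their state ('fork' copies the parent process, 'spawn' starts a fresh interpreter and pickles the task arguments, and is what the commit found reliable on macOS). A small standalone sketch of both, with a toy worker and an illustrative 5-second timeout:

```python
import multiprocessing as mp
import time


def slow_square(x):
    time.sleep(0.2)  # stand-in for one shim-group optimization
    return x * x


if __name__ == '__main__':
    # 'spawn' starts clean interpreters (default on macOS/Windows); 'fork' copies the parent (POSIX only)
    ctx = mp.get_context('spawn')
    with ctx.Pool(processes=4) as pool:
        async_result = pool.starmap_async(slow_square, [(i,) for i in range(8)])
        try:
            # Block at most 5 seconds here (the patch uses 1200 s for a whole shim run)
            results = async_result.get(timeout=5)
            print(results)
        except mp.TimeoutError:
            print("Workers did not finish in time, consider retrying or raising the timeout")
```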
From 1dd1876703744c81b1d4a93478d126ab9d8b8c93 Mon Sep 17 00:00:00 2001
From: Alexandre D'Astous
Date: Wed, 18 May 2022 16:57:41 -0400
Subject: [PATCH 07/10] Remove mp-logging

---
 shimmingtoolbox/shim/sequencer.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/shimmingtoolbox/shim/sequencer.py b/shimmingtoolbox/shim/sequencer.py
index 1597cd58..f9f079ed 100644
--- a/shimmingtoolbox/shim/sequencer.py
+++ b/shimmingtoolbox/shim/sequencer.py
@@ -13,7 +13,6 @@
 from mpl_toolkits.axes_grid1 import make_axes_locatable
 import json
 import multiprocessing as mp
-import multiprocessing_logging
 
 from shimmingtoolbox.optimizer.lsq_optimizer import LsqOptimizer, PmuLsqOptimizer
 from shimmingtoolbox.optimizer.basic_optimizer import Optimizer
@@ -28,7 +27,6 @@
 ListCoil = List[Coil]
 
 logger = logging.getLogger(__name__)
-multiprocessing_logging.install_mp_handler(logger)
 
 supported_optimizers = {
     'least_squares_rt': PmuLsqOptimizer,

From e3cecacf846184dfa37dd440b73ba48f259b610a Mon Sep 17 00:00:00 2001
From: Alexandre D'Astous
Date: Wed, 18 May 2022 16:58:57 -0400
Subject: [PATCH 08/10] Refactor try block

---
 shimmingtoolbox/shim/sequencer.py | 10 ++++------
 1 file changed, 4 insertions(+), 6 deletions(-)

diff --git a/shimmingtoolbox/shim/sequencer.py b/shimmingtoolbox/shim/sequencer.py
index f9f079ed..4a270220 100644
--- a/shimmingtoolbox/shim/sequencer.py
+++ b/shimmingtoolbox/shim/sequencer.py
@@ -838,18 +838,16 @@ def _optimize(optimizer: Optimizer, nii_mask_anat, slices_anat, shimwise_bounds=
     # multiprocessing optimization
     mp.set_start_method('spawn', force=True)
-    with mp.Pool(mp.cpu_count()) as pool:
-        try:
+    try:
+        with mp.Pool(mp.cpu_count()) as pool:
             results = pool.starmap_async(_opt,
                                          [(i, optimizer, nii_mask_anat, slices_anat, dilation_kernel, dilation_size,
                                            path_output, shimwise_bounds) for i in range(n_shims)]).get(timeout=1200)
             pool.close()
             pool.join()
 
-        except mp.context.TimeoutError:
-            logger.info("Multiprocessing might have hung, retry the same command")
-
-
+    except mp.context.TimeoutError:
+        logger.info("Multiprocessing might have hung, retry the same command")
 
     # TODO: Add a callback to have a progress bar, otherwise the logger will probably output in a messed up order
     results.sort(key=lambda x: x[0])
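On the refactor itself: a `multiprocessing.Pool` used as a context manager calls `terminate()` on exit, so keeping `close()`/`join()` inside the block lets the workers exit cleanly before teardown, while the surrounding `try` only has to catch the timeout. A stripped-down sketch of the same shape (toy worker and timeout value chosen for illustration):

```python
import multiprocessing as mp


def work(i):
    return i * i


if __name__ == '__main__':
    tasks = [(i,) for i in range(8)]
    try:
        # Pool.__exit__ calls terminate(); close()/join() inside the block lets
        # workers finish and shut down cleanly before the context manager exits.
        with mp.Pool(mp.cpu_count()) as pool:
            results = pool.starmap_async(work, tasks).get(timeout=60)
            pool.close()
            pool.join()
    except mp.TimeoutError:
        print("Possibly hung; retry the command or increase the timeout")
    else:
        print(results)
```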
From 710e30a4dc0096be363a292e4f77a3e394709984 Mon Sep 17 00:00:00 2001
From: Nick
Date: Sun, 5 Jun 2022 21:15:15 -0500
Subject: [PATCH 09/10] Use multiprocessing's fork method to speed up optimizing (#390)

* Use multiprocessing's fork method to speed up

- avoid python interpreter's startup time
- avoid passing large data via pickle by using a global that can be inherited

* Temporarily disable only testing on master PRs

* Cleanup and add comments

* Remove logging extension since not supported on macs

* Add timeout and move start process to top of the file

* Use initializer and initargs in mp.Pool

* Use fork for linux and spawn for other platforms

Co-authored-by: Alexandre D'Astous
---
 .github/workflows/tests.yml       |  2 -
 setup.py                          |  1 -
 shimmingtoolbox/shim/sequencer.py | 75 ++++++++++++++++++++++---------
 3 files changed, 54 insertions(+), 24 deletions(-)

diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index 6a36d273..aaea1b5c 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -6,8 +6,6 @@ on:
     branches:
       - master
   pull_request:
-    branches:
-      - master
 
   # Allows you to run this workflow manually from the Actions tab
   workflow_dispatch:
diff --git a/setup.py b/setup.py
index b540e70c..8c62911b 100644
--- a/setup.py
+++ b/setup.py
@@ -44,7 +44,6 @@
         "scipy~=1.6.0",
         "tqdm",
         "matplotlib~=3.1.2",
-        "multiprocessing-logging~=0.3.3",
         "psutil~=5.7.3",
         "pytest~=4.6.3",
         "pytest-cov~=2.5.1",
diff --git a/shimmingtoolbox/shim/sequencer.py b/shimmingtoolbox/shim/sequencer.py
index 9245d63c..823ab549 100644
--- a/shimmingtoolbox/shim/sequencer.py
+++ b/shimmingtoolbox/shim/sequencer.py
@@ -13,6 +13,7 @@
 from mpl_toolkits.axes_grid1 import make_axes_locatable
 import json
 import multiprocessing as mp
+import sys
 
 from shimmingtoolbox.optimizer.lsq_optimizer import LsqOptimizer, PmuLsqOptimizer
 from shimmingtoolbox.optimizer.basic_optimizer import Optimizer
@@ -29,6 +30,11 @@
 
 logger = logging.getLogger(__name__)
 
+if sys.platform == 'linux':
+    mp.set_start_method('fork', force=True)
+else:
+    mp.set_start_method('spawn', force=True)
+
 supported_optimizers = {
     'least_squares_rt': PmuLsqOptimizer,
     'least_squares': LsqOptimizer,
@@ -199,7 +205,8 @@ def _eval_static_shim(opt: Optimizer, nii_fieldmap_orig, nii_mask, coef, slices,
         # TODO: Output json sidecar
         # TODO: Update the shim settings if Scanner coil?
         # Output the resulting fieldmap since it can be calculated over the entire fieldmap
-        nii_shimmed_fmap = nib.Nifti1Image(shimmed[..., 0], nii_fieldmap_orig.affine, header=nii_fieldmap_orig.header)
+        nii_shimmed_fmap = nib.Nifti1Image(shimmed[..., 0], nii_fieldmap_orig.affine,
+                                           header=nii_fieldmap_orig.header)
         fname_shimmed_fmap = os.path.join(path_output, 'fieldmap_calculated_shim.nii.gz')
         nib.save(nii_shimmed_fmap, fname_shimmed_fmap)
     else:
@@ -855,20 +862,22 @@ def select_optimizer(method, unshimmed, affine, coils: ListCoil, pmu: PmuResp =
 def _optimize(optimizer: Optimizer, nii_mask_anat, slices_anat, shimwise_bounds=None,
               dilation_kernel='sphere', dilation_size=3, path_output=None):
-
     # Count shims to perform
     n_shims = len(slices_anat)
 
     # multiprocessing optimization
-    mp.set_start_method('spawn', force=True)
+    _optimize_scope = (
+        optimizer, nii_mask_anat, slices_anat, dilation_kernel, dilation_size, path_output, shimwise_bounds)
 
     try:
-        with mp.Pool(mp.cpu_count()) as pool:
-            results = pool.starmap_async(_opt,
-                                         [(i, optimizer, nii_mask_anat, slices_anat, dilation_kernel, dilation_size,
-                                           path_output, shimwise_bounds) for i in range(n_shims)]).get(timeout=1200)
-            pool.close()
-            pool.join()
-
+        # Default number of workers is set to mp.cpu_count()
+        # _worker_init gets called by each worker with _optimize_scope as arguments
+        # _worker_init converts those arguments as globals so they can be accessed in _opt
+        # This works because each worker has its own version of the global variables
+        # This allows to use both fork and spawn while not serializing the arguments making it slow
+        with mp.Pool(initializer=_worker_init, initargs=_optimize_scope) as pool:
+            # should be safe to del here. Because at this point all the child processes have forked and inherited their
+            # copy
+            results = pool.starmap_async(_opt, [(i,) for i in range(n_shims)]).get(timeout=1200)
     except mp.context.TimeoutError:
         logger.info("Multiprocessing might have hung, retry the same command")
@@ -879,26 +888,50 @@ def _optimize(optimizer: Optimizer, nii_mask_anat, slices_anat, shimwise_bounds=
     return np.array(results_final)
 
 
-def _opt(i, optimizer, nii_mask_anat, slices_anat, dilation_kernel, dilation_size, path_output, shimwise_bounds):
-    logger.info(f"Shimming shim group: {i + 1} of {len(slices_anat)}")
+gl_optimizer = None
+gl_nii_mask_anat = None
+gl_slices_anat = None
+gl_dilation_kernel = None
+gl_dilation_size = None
+gl_path_output = None
+gl_shimwise_bounds = None
+
+
+def _worker_init(optimizer, nii_mask_anat, slices_anat, dilation_kernel, dilation_size, path_output,
+                 shimwise_bounds):
+
+    global gl_optimizer, gl_nii_mask_anat, gl_slices_anat, gl_dilation_kernel
+    global gl_dilation_size, gl_path_output, gl_shimwise_bounds
+    gl_optimizer = optimizer
+    gl_nii_mask_anat = nii_mask_anat
+    gl_slices_anat = slices_anat
+    gl_dilation_kernel = dilation_kernel
+    gl_dilation_size = dilation_size
+    gl_path_output = path_output
+    gl_shimwise_bounds = shimwise_bounds
+
+
+def _opt(i):
+
+    logger.info(f"Shimming shim group: {i + 1} of {len(gl_slices_anat)}")
     # Create nibabel object of the unshimmed map
-    nii_unshimmed = nib.Nifti1Image(optimizer.unshimmed, optimizer.unshimmed_affine)
+    nii_unshimmed = nib.Nifti1Image(gl_optimizer.unshimmed, gl_optimizer.unshimmed_affine)
 
     # Create mask in the fieldmap coordinate system from the anat roi mask and slice anat mask
-    sliced_mask_resampled = resample_mask(nii_mask_anat, nii_unshimmed, slices_anat[i],
-                                          dilation_kernel=dilation_kernel,
-                                          dilation_size=dilation_size,
-                                          path_output=path_output).get_fdata()
+    sliced_mask_resampled = resample_mask(gl_nii_mask_anat, nii_unshimmed, gl_slices_anat[i],
+                                          dilation_kernel=gl_dilation_kernel,
+                                          dilation_size=gl_dilation_size,
+                                          path_output=gl_path_output).get_fdata()
 
     # If new bounds are included, change them for each shim
-    if shimwise_bounds is not None:
-        optimizer.set_merged_bounds(shimwise_bounds[i])
+    if gl_shimwise_bounds is not None:
+        gl_optimizer.set_merged_bounds(gl_shimwise_bounds[i])
 
     if np.all(sliced_mask_resampled == 0):
-        return i, np.zeros(optimizer.merged_coils.shape[-1])
+        return i, np.zeros(gl_optimizer.merged_coils.shape[-1])
 
     # Optimize using the mask
-    coef = optimizer.optimize(sliced_mask_resampled)
+    coef = gl_optimizer.optimize(sliced_mask_resampled)
 
     return i, coef
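The key mechanism in this patch is `Pool(initializer=..., initargs=...)`: the initializer runs once per worker and publishes large, read-only objects as module-level globals in that worker, so each task only carries a small index instead of re-pickling the data. A self-contained sketch of the pattern (toy array and invented names, not the sequencer's actual objects):

```python
import multiprocessing as mp

import numpy as np

# Per-worker global, filled in by the initializer (one copy per worker process)
gl_big_array = None


def _worker_init(big_array):
    # Runs once in every worker; publishes the shared, read-only data as a global
    global gl_big_array
    gl_big_array = big_array


def _task(i):
    # Only the small index travels through the task queue; the data is already here
    return i, float(gl_big_array[i].sum())


if __name__ == '__main__':
    big_array = np.random.rand(32, 100_000)   # stand-in for coil profiles / field maps
    with mp.Pool(initializer=_worker_init, initargs=(big_array,)) as pool:
        results = pool.starmap_async(_task, [(i,) for i in range(32)]).get(timeout=60)
    results.sort(key=lambda x: x[0])          # restore task order, as the sequencer does
    print([round(v, 1) for _, v in results])
```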
From bfad338ba97044bfcaae233efae1ce90009a8295 Mon Sep 17 00:00:00 2001
From: Alexandre D'Astous
Date: Sun, 5 Jun 2022 23:06:57 -0400
Subject: [PATCH 10/10] Resample with order 0

---
 shimmingtoolbox/masking/mask_utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/shimmingtoolbox/masking/mask_utils.py b/shimmingtoolbox/masking/mask_utils.py
index 0c5370ff..a6dacbcd 100644
--- a/shimmingtoolbox/masking/mask_utils.py
+++ b/shimmingtoolbox/masking/mask_utils.py
@@ -44,7 +44,7 @@ def resample_mask(nii_mask_from, nii_target, from_slices, dilation_kernel='None'
     nii_mask_target = resample_from_to(nii_mask, nii_target, order=1, mode='grid-constant', cval=0)
 
     # Resample the full mask onto nii_target
-    nii_full_mask_target = resample_from_to(nii_mask_from, nii_target, order=1, mode='grid-constant', cval=0)
+    nii_full_mask_target = resample_from_to(nii_mask_from, nii_target, order=0, mode='grid-constant', cval=0)
 
     # TODO: Deal with soft mask
     # Find highest value and stretch to 1
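Switching from `order=1` to `order=0` changes the full-mask resampling from trilinear to nearest-neighbour interpolation, which keeps a binary mask binary instead of leaving fractional values along its edges. A standalone illustration with nibabel (synthetic images and affines, not shimmingtoolbox data; `mode='grid-constant'` assumes a SciPy version that supports the grid modes):

```python
import nibabel as nib
import numpy as np
from nibabel.processing import resample_from_to

# Synthetic binary mask on a 2 mm grid and a finer 1 mm target grid (toy example)
mask = np.zeros((10, 10, 10), dtype=np.uint8)
mask[3:7, 3:7, 3:7] = 1
nii_mask = nib.Nifti1Image(mask, np.diag([2, 2, 2, 1]))
nii_target = nib.Nifti1Image(np.zeros((20, 20, 20)), np.diag([1, 1, 1, 1]))

# order=1 (trilinear) creates fractional edge values; order=0 (nearest) stays binary
soft = resample_from_to(nii_mask, nii_target, order=1, mode='grid-constant', cval=0)
hard = resample_from_to(nii_mask, nii_target, order=0, mode='grid-constant', cval=0)

print(np.unique(soft.get_fdata()).size > 2)   # True: in-between values appear
print(np.unique(hard.get_fdata()))            # [0. 1.]: still a clean binary mask
```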