Skip to content

Commit

Permalink
Put Multiprocessing for every dataset (May change)
Browse files Browse the repository at this point in the history
  • Loading branch information
Aurelienpujolm committed Dec 16, 2022
1 parent ae01e66 commit 869ad8d
Showing 1 changed file with 40 additions and 40 deletions.
80 changes: 40 additions & 40 deletions shimmingtoolbox/shim/sequencer.py
Expand Up @@ -934,52 +934,52 @@ def select_optimizer(method, unshimmed, affine, coils: ListCoil, opt_criteria, p

return optimizer


@timeit
def _optimize(optimizer: Optimizer, nii_mask_anat, slices_anat, opt_criteria, shimwise_bounds=None,
dilation_kernel='sphere', dilation_size=3, path_output=None):
# Count shims to perform
n_shims = len(slices_anat)
# If the method is the mse with jacobian, it's faster to not do the multiprocessing on mac computer,
# But it's faster on linux ones
if opt_criteria == 'mse' and sys.platform != 'linux':
_worker_init(optimizer, nii_mask_anat, slices_anat, dilation_kernel, dilation_size, path_output,
shimwise_bounds)
# If the method is the mse with jacobian, it's faster to not do the multiprocessing on mac computer for smaller
# dataset, but it's not the case for big ones. So for now, I will put mp for every dataset, but it can change
#if opt_criteria == 'mse' and sys.platform != 'linux':
#_worker_init(optimizer, nii_mask_anat, slices_anat, dilation_kernel, dilation_size, path_output,
#shimwise_bounds)
#results = []
#for i in range(n_shims):
#result = _opt(i)
#results.append(result)
#else:
# multiprocessing optimization
_optimize_scope = (
optimizer, nii_mask_anat, slices_anat, dilation_kernel, dilation_size, path_output, shimwise_bounds)

# Default number of workers is set to mp.cpu_count()
# _worker_init gets called by each worker with _optimize_scope as arguments
# _worker_init converts those arguments as globals, so they can be accessed in _opt
# This works because each worker has its own version of the global variables
# This allows to use both fork and spawn while not serializing the arguments making it slow
# It also allows to give as input only 1 iterable (range(n_shims))) so 'starmap' does not have to be used

# 'imap_unordered' is used since a worker returns the value when it is done instead of waiting for the whole
# call to 'map', 'starmap' to finish. This allows to show progress. 'imap' is similar to 'imap_unordered' but
# since it returns in order, the progress is less accurate. Even though 'map_async' and 'starmap_async' do
# not block, the whole call needs to be finished to access the results (results.get()). A whole discussion
# thread is available here: https://stackoverflow.com/questions/26520781/multiprocessing-pool-whats-the
# -difference-between-map-async-and-imap
pool = mp.Pool(initializer=_worker_init, initargs=_optimize_scope)
try:

results = []
for i in range(n_shims):
result = _opt(i)
print(f"\rProgress 0.0%")
for i, result in enumerate(pool.imap_unordered(_opt, range(n_shims))):
print(f"\rProgress {np.round((i + 1) / n_shims * 100)}%")
results.append(result)
else:
# multiprocessing optimization
_optimize_scope = (
optimizer, nii_mask_anat, slices_anat, dilation_kernel, dilation_size, path_output, shimwise_bounds)

# Default number of workers is set to mp.cpu_count()
# _worker_init gets called by each worker with _optimize_scope as arguments
# _worker_init converts those arguments as globals, so they can be accessed in _opt
# This works because each worker has its own version of the global variables
# This allows to use both fork and spawn while not serializing the arguments making it slow
# It also allows to give as input only 1 iterable (range(n_shims))) so 'starmap' does not have to be used

# 'imap_unordered' is used since a worker returns the value when it is done instead of waiting for the whole
# call to 'map', 'starmap' to finish. This allows to show progress. 'imap' is similar to 'imap_unordered' but
# since it returns in order, the progress is less accurate. Even though 'map_async' and 'starmap_async' do
# not block, the whole call needs to be finished to access the results (results.get()). A whole discussion
# thread is available here: https://stackoverflow.com/questions/26520781/multiprocessing-pool-whats-the
# -difference-between-map-async-and-imap
pool = mp.Pool(initializer=_worker_init, initargs=_optimize_scope)
try:

results = []
print(f"\rProgress 0.0%")
for i, result in enumerate(pool.imap_unordered(_opt, range(n_shims))):
print(f"\rProgress {np.round((i + 1) / n_shims * 100)}%")
results.append(result)

except mp.context.TimeoutError:
logger.info("Multiprocessing might have hung, retry the same command")
finally:
pool.close()
pool.join()

except mp.context.TimeoutError:
logger.info("Multiprocessing might have hung, retry the same command")
finally:
pool.close()
pool.join()

results.sort(key=lambda x: x[0])
results_final = [r for i, r in results]
Expand Down

0 comments on commit 869ad8d

Please sign in to comment.