Skip to content

Commit 7c0acfc

Browse files
authored
Add fork start_method support for xmp.spawn (#4236)
* Add fork start_method support for xmp.spawn
* Fix spacing format
* Make start_method keyword-only for _run_multiprocess
1 parent 1251261 commit 7c0acfc

File tree

2 files changed

+12
-5
lines changed

2 files changed

+12
-5
lines changed

torch_xla/distributed/xla_multiprocessing.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -386,7 +386,7 @@ def spawn(fn,
386386
return None.
387387
"""
388388
if pjrt.using_pjrt():
389-
return pjrt.spawn(fn, args)
389+
return pjrt.spawn(fn, start_method, args)
390390

391391
if not _is_xla_config():
392392
# If this is not an XLA setup, jump to normal multi-processing.

torch_xla/experimental/pjrt.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -192,14 +192,19 @@ def _thread_fn(device: torch.device):
192192

193193

194194
@requires_pjrt
195-
def _run_multiprocess(fn: Callable[..., R], *args, **kwargs) -> Dict[int, R]:
195+
def _run_multiprocess(fn: Callable[..., R],
196+
*args,
197+
start_method: str = 'spawn',
198+
**kwargs) -> Dict[int, R]:
196199
"""Runs `fn` on all devices available to PjRt.
197200
198201
Spawns one process per physical device (e.g. TPU chip).
199202
200203
Args:
201204
fn: Function to run on all devices
202205
args: args to pass to `fn`
206+
start_method: The Python `multiprocessing` process creation method.
207+
Default: `spawn`
203208
kwargs: kwargs to pass to `fn`
204209
205210
Returns:
@@ -213,7 +218,7 @@ def _run_multiprocess(fn: Callable[..., R], *args, **kwargs) -> Dict[int, R]:
213218

214219
with concurrent.futures.ProcessPoolExecutor(
215220
max_workers=num_processes,
216-
mp_context=torch.multiprocessing.get_context('spawn')) as executor:
221+
mp_context=torch.multiprocessing.get_context(start_method)) as executor:
217222

218223
mp_fn = functools.partial(
219224
_run_thread_per_device,
@@ -239,15 +244,17 @@ def __call__(self) -> None:
239244
self.fn(global_ordinal(), *self.args, **self.kwargs)
240245

241246

242-
def spawn(fn: Callable, args: Tuple = ()) -> None:
247+
def spawn(fn: Callable, start_method: str = 'spawn', args: Tuple = ()) -> None:
243248
"""Run functions compatible with xmp.spawn.
244249
245250
Args:
246251
fn: Callable that takes the process index as the first argument.
247252
args: args to pass to `fn`
253+
start_method: The Python `multiprocessing` process creation method.
254+
Default: `spawn`
248255
"""
249256
spawn_fn = _SpawnFn(fn, *args)
250-
_run_multiprocess(spawn_fn)
257+
_run_multiprocess(spawn_fn, start_method=start_method)
251258

252259

253260
def broadcast_master_param(model: nn.Module) -> None:

0 commit comments

Comments
 (0)