@@ -192,14 +192,19 @@ def _thread_fn(device: torch.device):
192192
193193
194194@requires_pjrt
195- def _run_multiprocess (fn : Callable [..., R ], * args , ** kwargs ) -> Dict [int , R ]:
195+ def _run_multiprocess (fn : Callable [..., R ],
196+ * args ,
197+ start_method : str = 'spawn' ,
198+ ** kwargs ) -> Dict [int , R ]:
196199 """Runs `fn` on all devices available to PjRt.
197200
198201 Spawns one process per physical device (e.g. TPU chip).
199202
200203 Args:
201204 fn: Function to run on all devices
202205 args: args to pass to `fn`
206+ start_method: The Python `multiprocessing` process creation method.
207+ Default: `spawn`
203208 kwargs: kwargs to pass to `fn`
204209
205210 Returns:
@@ -213,7 +218,7 @@ def _run_multiprocess(fn: Callable[..., R], *args, **kwargs) -> Dict[int, R]:
213218
214219 with concurrent .futures .ProcessPoolExecutor (
215220 max_workers = num_processes ,
216- mp_context = torch .multiprocessing .get_context ('spawn' )) as executor :
221+ mp_context = torch .multiprocessing .get_context (start_method )) as executor :
217222
218223 mp_fn = functools .partial (
219224 _run_thread_per_device ,
@@ -239,15 +244,17 @@ def __call__(self) -> None:
239244 self .fn (global_ordinal (), * self .args , ** self .kwargs )
240245
241246
242- def spawn (fn : Callable , args : Tuple = ()) -> None :
247+ def spawn (fn : Callable , start_method : str = 'spawn' , args : Tuple = ()) -> None :
243248 """Run functions compatible with xmp.spawn.
244249
245250 Args:
246251 fn: Callable that takes the process index as the first argument.
247252 args: args to pass to `fn`
253+ start_method: The Python `multiprocessing` process creation method.
254+ Default: `spawn`
248255 """
249256 spawn_fn = _SpawnFn (fn , * args )
250- _run_multiprocess (spawn_fn )
257+ _run_multiprocess (spawn_fn , start_method = start_method )
251258
252259
253260def broadcast_master_param (model : nn .Module ) -> None :
0 commit comments