2222
# Context ids observed by this worker via _set_rpc_done; consumed by
# _all_contexts_cleaned_up to verify that remote cleanup happened.
known_context_ids = []


# Send rpc done info and context_id to
# dst_rank = (self.rank + rank_distance) % self.world_size.
# We don't need a lock here since the GIL is held while executing remote
# python UDFs, so access is serialized across several workers.
def _set_rpc_done(ctx_id, rank_distance):
    """Record that an RPC from a worker `rank_distance` away has completed.

    Sets the per-distance done flag and context id, and appends the
    context id to the global list used for cross-worker cleanup checks.
    """
    global rpc_done
    global ctx_ids
    global known_context_ids
    rpc_done[rank_distance] = True
    ctx_ids[rank_distance] = ctx_id
    known_context_ids.append(ctx_id)
3837
3938
4039def my_py_add (t1 , t2 ):
@@ -52,7 +51,7 @@ def my_py_nested_call(t1, t2, dst, world_size, hops):
5251# after dist autograd context is cleaned up, it should be cleaned up on other
5352# nodes. This helper allows timeout_seconds for those RPCs to be completed, and
5453# ensures that all the contexts have been cleaned up in that timeframe.any
55- def _all_contexts_cleaned_up (num_contexts , timeout_seconds = 10 ):
54+ def _all_contexts_cleaned_up (timeout_seconds = 10 ):
5655 global known_context_ids
5756 start = time .time ()
5857 context_id_to_raised = {}
@@ -62,10 +61,10 @@ def _all_contexts_cleaned_up(num_contexts, timeout_seconds=10):
6261 dist_autograd ._retrieve_context (context_id )
6362 except RuntimeError :
6463 context_id_to_raised [context_id ] = True
65- if len (context_id_to_raised ) == num_contexts :
64+ if len (context_id_to_raised ) == len ( known_context_ids ) :
6665 break
6766 # all contexts have been cleaned up if trying to retrieve any context resulted in a RuntimeError.
68- success = len (context_id_to_raised ) == num_contexts and all (context_id_to_raised .values ())
67+ success = len (context_id_to_raised ) == len ( known_context_ids ) and all (context_id_to_raised .values ())
6968 return success
7069
7170
@@ -453,19 +452,35 @@ def test_rpc_complex_args(self):
453452
454453 @dist_init (setup_model_parallel = True )
455454 def test_context_cleanup_many_workers (self ):
456- global known_context_ids
457455 dst_ranks = {rank for rank in range (self .world_size ) if rank != self .rank }
458456 with dist_autograd .context () as context_id :
459457 t1 = torch .ones (3 , 3 , requires_grad = True )
460458 t2 = torch .zeros (3 , 3 , requires_grad = True )
461459 for dst_rank in dst_ranks :
462460 ret = rpc .rpc_sync ("worker{}" .format (dst_rank ), torch .add , args = (t1 , t2 ))
463- rpc .rpc_sync ("worker{}" .format (dst_rank ), _store_context_id , args = (context_id ,))
461+ rpc .rpc_sync ("worker{}" .format (dst_rank ), _set_rpc_done , args = (context_id , 1 ))
464462 # the thread's context id should be cleaned up
465463 with self .assertRaises (RuntimeError ):
466464 dist_autograd ._retrieve_context (context_id )
467465 # check that all contexts have been cleaned up.
468- success = _all_contexts_cleaned_up (num_contexts = len (dst_ranks ))
466+ success = _all_contexts_cleaned_up ()
467+ self .assertTrue (success )
468+
469+ @dist_init (setup_model_parallel = True )
470+ def test_context_cleanup_nested_rpc (self ):
471+ dst_rank = (self .rank + 1 ) % self .world_size
472+ nested_dst_rank = (dst_rank + 1 ) % self .world_size
473+ with dist_autograd .context () as context_id :
474+ t1 = torch .ones (3 , 3 , requires_grad = True )
475+ t2 = torch .zeros (3 , 3 , requires_grad = True )
476+ rpc .rpc_sync ("worker{}" .format (dst_rank ),
477+ my_py_nested_call , args = (t1 , t2 , dst_rank , self .world_size , 0 ))
478+ # tell next worker and nested next worker to store this context id
479+ # so we can verify that it has been cleaned up
480+ rpc .rpc_sync ("worker{}" .format (dst_rank ), _set_rpc_done , args = (context_id , 1 ))
481+ rpc .rpc_sync ("worker{}" .format (nested_dst_rank ), _set_rpc_done , args = (context_id , 2 ))
482+ dist .barrier () # let all nodes finish sending their RPCs
483+ success = _all_contexts_cleaned_up ()
469484 self .assertTrue (success )
470485
471486 @dist_init (setup_model_parallel = True )
@@ -477,7 +492,7 @@ def test_worker_ids_recorded(self):
477492 t1 = torch .ones (3 , 3 , requires_grad = False )
478493 t2 = torch .zeros (3 , 3 , requires_grad = False )
479494 for dst_rank in dst_ranks :
480- ret = rpc .rpc_sync ("worker{}" .format (dst_rank ), torch .add , args = (t1 , t2 ))
495+ rpc .rpc_sync ("worker{}" .format (dst_rank ), torch .add , args = (t1 , t2 ))
481496 rpc .rpc_sync (
482497 "worker{}" .format (dst_rank ), _set_rpc_done , args = (context_id , 1 )
483498 )
0 commit comments