
Commit 7e492b6

[distributed] add test to ensure that dist autograd contexts are cleaned up in case of nested rpcs
Pull Request resolved: #28485

This diff adds a test to ensure that when we have multiple nested RPCs inside a dist autograd context, the context created on a downstream worker as a result of a nested RPC is cleaned up after the node that created the original context exits the context manager. For example, worker 0 might send an RPC to worker 1 that results in an RPC to worker 2, so worker 2 will hold worker 0's context even though worker 0 never talked to worker 2 directly. This test ensures that the context on worker 2 is also cleaned up.

ghstack-source-id: 92611018
Differential Revision: [D18079212](https://our.internmc.facebook.com/intern/diff/D18079212/)
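To make the scenario concrete, here is a minimal, hypothetical sketch of the call chain described above. The helper `_nested_add` is illustrative only (the test itself uses `my_py_nested_call` from `test/dist_autograd_test.py`, whose body is not shown in this diff); the `"worker{rank}"` naming follows the convention used in the test.

import torch
import torch.distributed.autograd as dist_autograd
import torch.distributed.rpc as rpc

# Hypothetical remote UDF: when executed on worker 1, it issues a further RPC
# to worker 2. It runs inside the caller's dist autograd context, so worker 2
# also ends up holding a context keyed by worker 0's context_id.
def _nested_add(t1, t2, next_dst):
    return rpc.rpc_sync("worker{}".format(next_dst), torch.add, args=(t1, t2))

# On worker 0 (sketch):
#     with dist_autograd.context() as context_id:
#         t1 = torch.ones(3, 3, requires_grad=True)
#         t2 = torch.zeros(3, 3, requires_grad=True)
#         # worker 0 -> worker 1 -> worker 2: worker 0 never talks to worker 2
#         # directly, yet worker 2 now holds worker 0's context.
#         rpc.rpc_sync("worker1", _nested_add, args=(t1, t2, 2))
#     # Exiting the context manager should eventually clean up the context on
#     # worker 1 *and* worker 2, which is what the new test verifies.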
1 parent d04973b commit 7e492b6

File tree: 1 file changed

test/dist_autograd_test.py

Lines changed: 27 additions & 12 deletions
@@ -22,19 +22,18 @@
 
 known_context_ids = []
 
-# we don't need a lock here since the GIL is held while executing remote
-# python UDFs, so access to known_context_ids is serialized across several workers.
-def _store_context_id(context_id):
-    global known_context_ids
-    known_context_ids.append(context_id)
 
 # Send rpc done info and context_id to
 # dst_rank = (self.rank + rank_distance) % self.world_size
+# we don't need a lock here since the GIL is held while executing remote
+# python UDFs, so access is serialized across several workers.
 def _set_rpc_done(ctx_id, rank_distance):
     global rpc_done
     global ctx_ids
+    global known_context_ids
     rpc_done[rank_distance] = True
     ctx_ids[rank_distance] = ctx_id
+    known_context_ids.append(ctx_id)
 
 
 def my_py_add(t1, t2):
@@ -52,7 +51,7 @@ def my_py_nested_call(t1, t2, dst, world_size, hops):
 # after dist autograd context is cleaned up, it should be cleaned up on other
 # nodes. This helper allows timeout_seconds for those RPCs to be completed, and
 # ensures that all the contexts have been cleaned up in that timeframe.any
-def _all_contexts_cleaned_up(num_contexts, timeout_seconds=10):
+def _all_contexts_cleaned_up(timeout_seconds=10):
     global known_context_ids
     start = time.time()
     context_id_to_raised = {}
@@ -62,10 +61,10 @@ def _all_contexts_cleaned_up(num_contexts, timeout_seconds=10):
                 dist_autograd._retrieve_context(context_id)
             except RuntimeError:
                 context_id_to_raised[context_id] = True
-        if len(context_id_to_raised) == num_contexts:
+        if len(context_id_to_raised) == len(known_context_ids):
             break
     # all contexts have been cleaned up if trying to retrieve any context resulted in a RuntimeError.
-    success = len(context_id_to_raised) == num_contexts and all(context_id_to_raised.values())
+    success = len(context_id_to_raised) == len(known_context_ids) and all(context_id_to_raised.values())
     return success
 
 
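For readability, here is a minimal sketch of the polling helper as it would read after this change. The `while`/`for`/`try` scaffolding between the two hunks above is not part of the diff, so those lines are an assumption about the surrounding code, not a definitive reconstruction.

import time

import torch.distributed.autograd as dist_autograd

known_context_ids = []  # populated via _set_rpc_done RPCs from peer workers

def _all_contexts_cleaned_up(timeout_seconds=10):
    global known_context_ids
    start = time.time()
    context_id_to_raised = {}
    # Assumed polling loop: keep probing until every known context is gone or
    # the timeout expires (these loop lines are not shown in the hunks above).
    while time.time() - start < timeout_seconds:
        for context_id in known_context_ids:
            try:
                dist_autograd._retrieve_context(context_id)
            except RuntimeError:
                # A RuntimeError means this context has already been cleaned up.
                context_id_to_raised[context_id] = True
        if len(context_id_to_raised) == len(known_context_ids):
            break
    # all contexts have been cleaned up only if every retrieval attempt raised.
    success = len(context_id_to_raised) == len(known_context_ids) and all(context_id_to_raised.values())
    return success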
@@ -453,19 +452,35 @@ def test_rpc_complex_args(self):
 
     @dist_init(setup_model_parallel=True)
     def test_context_cleanup_many_workers(self):
-        global known_context_ids
         dst_ranks = {rank for rank in range(self.world_size) if rank != self.rank}
         with dist_autograd.context() as context_id:
             t1 = torch.ones(3, 3, requires_grad=True)
             t2 = torch.zeros(3, 3, requires_grad=True)
             for dst_rank in dst_ranks:
                 ret = rpc.rpc_sync("worker{}".format(dst_rank), torch.add, args=(t1, t2))
-                rpc.rpc_sync("worker{}".format(dst_rank), _store_context_id, args=(context_id,))
+                rpc.rpc_sync("worker{}".format(dst_rank), _set_rpc_done, args=(context_id, 1))
         # the thread's context id should be cleaned up
         with self.assertRaises(RuntimeError):
             dist_autograd._retrieve_context(context_id)
         # check that all contexts have been cleaned up.
-        success = _all_contexts_cleaned_up(num_contexts=len(dst_ranks))
+        success = _all_contexts_cleaned_up()
+        self.assertTrue(success)
+
+    @dist_init(setup_model_parallel=True)
+    def test_context_cleanup_nested_rpc(self):
+        dst_rank = (self.rank + 1) % self.world_size
+        nested_dst_rank = (dst_rank + 1) % self.world_size
+        with dist_autograd.context() as context_id:
+            t1 = torch.ones(3, 3, requires_grad=True)
+            t2 = torch.zeros(3, 3, requires_grad=True)
+            rpc.rpc_sync("worker{}".format(dst_rank),
+                         my_py_nested_call, args=(t1, t2, dst_rank, self.world_size, 0))
+            # tell next worker and nested next worker to store this context id
+            # so we can verify that it has been cleaned up
+            rpc.rpc_sync("worker{}".format(dst_rank), _set_rpc_done, args=(context_id, 1))
+            rpc.rpc_sync("worker{}".format(nested_dst_rank), _set_rpc_done, args=(context_id, 2))
+            dist.barrier() # let all nodes finish sending their RPCs
+        success = _all_contexts_cleaned_up()
         self.assertTrue(success)
 
     @dist_init(setup_model_parallel=True)
@@ -477,7 +492,7 @@ def test_worker_ids_recorded(self):
             t1 = torch.ones(3, 3, requires_grad=False)
             t2 = torch.zeros(3, 3, requires_grad=False)
             for dst_rank in dst_ranks:
-                ret = rpc.rpc_sync("worker{}".format(dst_rank), torch.add, args=(t1, t2))
+                rpc.rpc_sync("worker{}".format(dst_rank), torch.add, args=(t1, t2))
                 rpc.rpc_sync(
                     "worker{}".format(dst_rank), _set_rpc_done, args=(context_id, 1)
                 )
