
Commit

Update on "AsyncCollectiveTensor: wait on pending collectives before executing compiled fw"

This patch is relatively low-LoC, but I'm not very satisfied with it (internal post coming soon).

[ghstack-poisoned]
bdhirsh committed May 7, 2024
1 parent c793008 commit edba89d
Showing 4 changed files with 35 additions and 20 deletions.
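For context on the commit title: an AsyncCollectiveTensor is the tensor subclass returned by the functional collectives API (and by DTensor.redistribute(..., async_op=True).to_local(), as in the new test below) while the collective is still in flight. A minimal sketch of how such a tensor shows up in eager mode; it assumes a process group has already been initialized on every rank (e.g. via torchrun + dist.init_process_group):

import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as funcol

# Assumes dist.init_process_group(...) has already run on every rank.
x = torch.ones(4)
y = funcol.all_reduce(x, "sum", dist.group.WORLD)

# `y` may be an AsyncCollectiveTensor: the all_reduce has been issued but not
# necessarily completed. Most accesses (repr, tolist, ordinary ops) trigger a
# wait_tensor under the hood; an explicit wait() returns the synced tensor.
if isinstance(y, funcol.AsyncCollectiveTensor):
    y = y.wait()
print(y)

Under torch.compile, this commit makes the same synchronization show up as a wait_tensor node inside the AOT forward graph (see the new test below) instead of being handled outside the graph.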
28 changes: 27 additions & 1 deletion test/distributed/_tensor/test_dtensor_compile.py
@@ -60,7 +60,7 @@ def forward(self, input):


 def extract_graph(fx_g, _, graph_cell):
-    graph_cell[0] = fx_g
+    graph_cell[0] = fx_g.code
     return fx_g


@@ -368,6 +368,32 @@ def fn(x_dt):
         res = opt_fn(x_dt)
         self.assertEqual(ref, res)

+    def test_graph_input_is_async(self):
+        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
+
+        def fn(x):
+            return x.sin().sin()
+
+        opt_fn = torch.compile(fn, backend=aot_eager_graph, fullgraph=True)
+
+        x = torch.randn(4, 4, requires_grad=True)
+        x_dt = DTensor.from_local(x, mesh, [Shard(0)], run_check=False)
+        x2 = x_dt.redistribute(mesh, [Replicate()], async_op=True)
+        x2 = x2.to_local()
+        out = opt_fn(x2)
+        # The important part: we get a wait_tensor() in the graph.
+        # At runtime, the input to the graph is an AsyncCollectiveTensor,
+        # and inside the graph we need to issue a wait() to synchronize.
+        self.assertExpectedInline(
+            str(fw_graph_cell[0]).strip(),
+            """\
+def forward(self, primals_1):
+    wait_tensor = torch.ops._c10d_functional.wait_tensor.default(primals_1)
+    sin = torch.ops.aten.sin.default(wait_tensor)
+    sin_1 = torch.ops.aten.sin.default(sin); sin = None
+    return [sin_1, primals_1, wait_tensor]""",
+        )
+
     @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
     def test_dtensor_partial_placement_graph_output(self):
         mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
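The test above relies on aot_eager_graph and fw_graph_cell, which are defined earlier in the test file and are not shown in this diff. Their wiring presumably looks something like the sketch below (the names, imports, and exact signatures here are assumptions inferred from how extract_graph is used): the fw_compiler callback receives the AOT-generated forward graph, so fw_graph_cell[0] ends up holding the printed forward code that assertExpectedInline checks.

import functools

from functorch.compile import min_cut_rematerialization_partition
from torch._dynamo.backends.common import aot_autograd

fw_graph_cell = [None]
bw_graph_cell = [None]

# extract_graph (see the hunk above) stores fx_g.code in the given cell and
# returns the graph module unchanged, so compilation proceeds as usual.
aot_eager_graph = aot_autograd(
    fw_compiler=functools.partial(extract_graph, graph_cell=fw_graph_cell),
    bw_compiler=functools.partial(extract_graph, graph_cell=bw_graph_cell),
    partition_fn=min_cut_rematerialization_partition,
)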
2 changes: 0 additions & 2 deletions torch/_functorch/_aot_autograd/runtime_wrappers.py
@@ -105,8 +105,6 @@ def runtime_wrapper(args: List[Any]):

        # stash a ref to each input tensor we plan to use after the compiled function
        orig_inputs = {i: args[i] for i in epilogue_args_idx}
-        for i in runtime_metadata.async_collective_inp_indices:
-            args[i] = args[i].trigger_wait()

        if trace_joint:
            args_ = list(args)
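Before this commit, the wait happened outside the graph: the runtime wrapper used the (also-deleted) async_collective_inp_indices from schemas.py below to force a wait on each pending input before calling the compiled function. Roughly, as a standalone illustration (run_compiled and its arguments are hypothetical names for exposition only):

def run_compiled(compiled_fn, args, async_collective_inp_indices):
    # Old behavior (removed here): synchronize pending collectives in the
    # Python epilogue, before the compiled function ever sees the inputs.
    for i in async_collective_inp_indices:
        args[i] = args[i].trigger_wait()
    return compiled_fn(args)

After this change, the inputs are passed through as-is and the graph itself contains the wait_tensor call, as shown by the expected forward graph in the new test.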
8 changes: 0 additions & 8 deletions torch/_functorch/_aot_autograd/schemas.py
@@ -308,14 +308,6 @@ def __post_init__(self):
        # When keep_input_mutations is set, we don't need to worry about our epilogue
        # handling data-only mutations, because we keep them directly in the graph.

-        from torch.distributed._functional_collectives import AsyncCollectiveTensor
-
-        self.async_collective_inp_indices = [
-            i
-            for i, x in enumerate(self.subclass_inp_meta)
-            if isinstance(x, SubclassCreationMeta)
-            and isinstance(x.original_subclass, AsyncCollectiveTensor)
-        ]
        mutated_inp_runtime_indices = [
            i
            for i, m in enumerate(self.input_info)
17 changes: 8 additions & 9 deletions torch/distributed/_functional_collectives.py
@@ -590,8 +590,7 @@ def __tensor_flatten__(self):
         return ["elem"], None

     def tolist(self):
-        self.trigger_wait()
-        return self.elem.tolist()
+        return self.trigger_wait().tolist()

     @staticmethod
     def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):
@@ -600,18 +599,18 @@ def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):
         return AsyncCollectiveTensor(elem)

     def __repr__(self):
-        self.trigger_wait()
-        return f"AsyncCollectiveTensor({self.elem})"
+        return f"AsyncCollectiveTensor({self.trigger_wait()})"

     def trigger_wait(self):
         if not self.completed:
-            wait_tensor(self.elem)
+            out = wait_tensor(self.elem)
             self.completed = True
-        return self.elem
+            return out
+        else:
+            return self.elem

     def wait(self) -> torch.Tensor:
-        wait_tensor(self.elem)
-        return self.elem
+        return wait_tensor(self.elem)

     def _get_acs_underlying_tensor(self):
         """This method enables _functional_collectives_impl to test if a tensor is an ACS"""
@@ -631,7 +630,7 @@ def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
         def unwrap(e: AsyncCollectiveTensor):
             # wait_tensor is idempotent and will do stream sync only once
             if not is_view_op:
-                e.trigger_wait()
+                return e.trigger_wait()
             return e.elem

         def wrap(e: torch.Tensor):
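The common thread in these hunks is that trigger_wait() and wait() now return the tensor produced by wait_tensor(self.elem) rather than handing back self.elem after a side-effecting wait, and the callers (tolist, __repr__, the unwrap helper) thread that result through. A toy, single-process stand-in for the pattern; ToyAsyncTensor and fake_wait are hypothetical names, and the real class wraps a genuinely pending collective:

import torch


def fake_wait(t: torch.Tensor) -> torch.Tensor:
    # Stands in for torch.ops._c10d_functional.wait_tensor, which blocks on
    # the pending collective and returns the synchronized tensor.
    return t


class ToyAsyncTensor:
    def __init__(self, elem: torch.Tensor):
        self.elem = elem
        self.completed = False

    def trigger_wait(self) -> torch.Tensor:
        # Mirror the patched method: return the waited result, not self.elem.
        if not self.completed:
            out = fake_wait(self.elem)
            self.completed = True
            return out
        return self.elem

    def tolist(self):
        # Route accessors through trigger_wait()'s return value.
        return self.trigger_wait().tolist()


t = ToyAsyncTensor(torch.arange(3))
print(t.tolist())  # [0, 1, 2]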
