Inductor support for aten::all_reduce #93111
Changes from 38 commits
@@ -0,0 +1,236 @@
# Owner(s): ["module: dynamo"]
import functools
import unittest
from unittest.mock import patch
import torch
from torch._C import FileCheck
from torch._dispatch.python import enable_python_dispatcher
import torch._dynamo
import torch._dynamo.test_case
from torch._dynamo.utils import same
from torch._dynamo.testing import CompileCounter
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing._internal.common_distributed import (
    DynamoDistributedSingleProcTestCase,
    DynamoDistributedMultiProcTestCase,
    _dynamo_dist_per_rank_init,
    requires_nccl,
    skip_if_lt_x_gpu
)
from torch._inductor.compile_fx import compile_fx as inductor_compile_fx
from torch._inductor.utils import has_triton, run_and_get_triton_code
import torch._dynamo.logging

# LOL if you don't remember to import this, then the op isn't registered and it hits
# the no-op C++ kernel that i am forced to implement despite not using it
import torch.distributed._functional_collectives


@requires_nccl()
class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
    """
    Run correctness checks in multi-proc runner, mark with minimum # GPUs to run under
    """
    def get_world_trs(self):
        return {
            "tag": "",
            "ranks": list(range(self.world_size)),
            "group_size": self.world_size,
        }

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @skip_if_lt_x_gpu(2)
    # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
    @patch.object(torch._inductor.config, "compile_threads", 1)
    def test_allreduce_inductor(self):
        """
        This matmul/cat/allreduce pattern is one we aim to optimize.
        """

        def matmul_cat_col(a, b, c, d, e, f, *, tag, ranks, group_size):
            x = torch.matmul(a, b)
            y = torch.matmul(c, d)
            z = torch.cat((x, y))
            ar = torch.ops.aten.all_reduce(z, "sum", tag, ranks, group_size)
            g = torch.matmul(e, f)
            ar = torch.ops.aten.wait_tensor(ar)
            out = torch.add(ar, g.repeat(2, 1))
            return (out, )

        def compile(func, example_inputs):
            graph = make_fx(func)(*example_inputs)
            return inductor_compile_fx(graph, example_inputs)

        with _dynamo_dist_per_rank_init(self.rank, self.world_size):

            matmul_cat_col = functools.partial(
                matmul_cat_col,
                **self.get_world_trs(),
            )
            inputs = (torch.ones(4, 4, device="cuda") + self.rank,) * 6

            # non-ideally, i seem to need to enable this at user level in order to construct a torchdispatch subclass
            # inside py registered collective ops
            with enable_python_dispatcher():
                eager_out = matmul_cat_col(*inputs)
                compiled_matmul_cat_col = compile(matmul_cat_col, inputs)
                inductor_out = compiled_matmul_cat_col(*inputs)
                assert same(eager_out, inductor_out, tol=0.001)


@requires_nccl()
class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
    """
    Prefer single-proc test runner for basic tests as it is easier to work with.
    """
    def get_world_trs(self, world_size=1):
        return {
            "tag": "",
            "ranks": list(range(world_size)),
            "group_size": world_size,
        }

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    def test_inductor_single_op(self):
        torch._inductor.config.debug = True

        def func(inp, *, tag, ranks, group_size):
            ar = torch.ops.aten.all_reduce(inp, "sum", tag, ranks, group_size)
            ar = torch.ops.aten.wait_tensor(ar)
            return ar

        inputs = torch.ones(4, 4, device="cuda")

        with enable_python_dispatcher():
            compiled = torch.compile(func)
            out = compiled(inputs, **self.get_world_trs())
            code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs())
            FileCheck() \
                .check("buf0 = empty_strided") \
                .check("buf0.copy_(arg0_1)") \
                .check("buf0_work = dist.all_reduce(buf0") \
                .check("buf0_work.wait()") \
                .check("return (buf1, )") \
                .run(code)
            correct = func(inputs, **self.get_world_trs())
            assert same(out, correct)

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    def test_inductor_steal_buffer(self):
        """
        it's ok and optimal if inductor allreduce mutates the buffer of an intermediate
        that isn't going to be used again
        """
        torch._inductor.config.debug = True

        def func(inp, *, tag, ranks, group_size):
            x = inp + 1
            ar = torch.ops.aten.all_reduce(x, "sum", tag, ranks, group_size)
            ar = torch.ops.aten.wait_tensor(ar)
            # ensure other is not incorrectly aliasing ar's buffer
            other = torch.ones_like(inp) + 22
            return ar, other

        inputs = torch.ones(4, 4, device="cuda")

        with enable_python_dispatcher():
            compiled = torch.compile(func)
            code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs())
            FileCheck() \
                .check("buf1 = buf0; del buf0 # reuse") \
                .check_not("buf1.copy_(") \
                .check("buf1_work = dist.all_reduce(buf1") \
                .check("buf1_work.wait()") \
                .check("buf2 = buf1") \
                .check("buf3 = empty_strided") \
                .check("return (buf2, buf3") \
                .run(code)
            out = compiled(inputs, **self.get_world_trs())
            correct = func(inputs, **self.get_world_trs())
            assert same(out, correct)

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    def test_inductor_doesnt_mutate_shared(self):
        """
        make sure that an intermediate that's going to be reused isn't mutated unless copied
        """
        torch._inductor.config.debug = True

        def func(inp, *, tag, ranks, group_size):
            x = inp + 1
            ar = torch.ops.aten.all_reduce(x, "sum", tag, ranks, group_size)
            y = x + 2
            ar = torch.ops.aten.wait_tensor(ar)
            # ensure other is not incorrectly aliasing ar's buffer
            other = torch.ones_like(inp) + 22
            return ar, y, other

        inputs = torch.ones(4, 4, device="cuda")

        with enable_python_dispatcher():
            compiled = torch.compile(func)
            code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs())
            FileCheck() \
                .check("buf0 = empty_strided(") \
                .check("buf2 = empty_strided") \
                .check("triton__0.run(arg0_1, buf0, buf2") \
                .check_not("copy_(") \
                .check("buf1 = buf0; del buf0 # reuse") \
                .check("buf1_work = dist.all_reduce(buf1") \
                .check("buf1_work.wait()") \
                .check("buf3 = buf1") \
                .check("return (buf3, buf2, buf4") \
                .run(code)
            out = compiled(inputs, **self.get_world_trs())
            correct = func(inputs, **self.get_world_trs())
            assert same(out, correct)

    def test_dynamo_trace_allreduce(self):
        def func(inp, *, tag, ranks, group_size):
            ar = torch.ops.aten.all_reduce(inp, "sum", tag, ranks, group_size)

Review comment: dynamo works because we are calling the aten op not the functional collective directly, so we get around the AsyncTensor subclass?

Reply: oh, this is not the real dynamo support. See a later PR in this stack. I change this test to call the real collective and make changes to dynamo to fix it.

            return ar

        inputs = torch.ones(4, 4, device="cuda")
        counter = CompileCounter()
        with enable_python_dispatcher():
            compiled = torch.compile(func, backend=counter)
            out = compiled(inputs, **self.get_world_trs())
            correct = func(inputs, **self.get_world_trs())
            assert counter.frame_count == 1
            assert counter.op_count == 1
            assert same(out, correct)

    def test_backwards(self):
        """
        It's probably not that common to need backwards support for collectives.

        However, I wanted to at least see if it was possible to support it as a design goal.
        """
        def func(inp, *, tag, ranks, group_size):
            ar = torch.ops.aten.all_reduce(inp, "sum", tag, ranks, group_size)
            return ar

        input = torch.ones(4, 4, device="cuda", requires_grad=True)
        with enable_python_dispatcher():
            # TODO implement backwards
            with self.assertRaisesRegex(RuntimeError, "derivative for aten::all_reduce is not implemented"):
                compiled = torch.compile(func, backend="aot_eager")  # inductor bug with single-op allreduce graph
                out = compiled(input, **self.get_world_trs())
                out.sum().backward()

Review comment: Oh I thought we didn't implement the allreduce backward yet, so it's a dummy function right now and we just test the correctness of dummy function to see if it could work here?

Reply: yea, i should delete this test. I asked Rodrigo to cover this in his own test file, and he did. Also, he dropped backward support for now so i think his test is a stub. we will add it later.


                correct_input = input.clone().detach().requires_grad_()
                correct = func(correct_input, **self.get_world_trs())
                correct.sum().backward()
                assert same(out, correct)
                assert same(input.grad, correct_input.grad)

    def test_meta(self):
        x = torch.rand((2, 3, 4), device="meta")
        out = torch.ops.aten.all_reduce(x, "sum", **self.get_world_trs())
        assert x.size() == out.size()


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()
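
An aside on the test_backwards discussion above: this PR registers no derivative for aten::all_reduce, and the test only asserts that the expected RuntimeError is raised. As a rough sketch of the math involved (not part of this change; the class name below is purely illustrative), a sum all_reduce could expose a derivative whose backward is itself a sum all_reduce of the incoming gradient:

import torch
import torch.distributed._functional_collectives  # registers the aten ops used below

class _AllReduceSumWithGrad(torch.autograd.Function):
    """Illustrative only: shows the math, not how a derivative would actually be registered."""

    @staticmethod
    def forward(ctx, inp, tag, ranks, group_size):
        ctx.pg_args = (tag, ranks, group_size)
        ar = torch.ops.aten.all_reduce(inp, "sum", tag, ranks, group_size)
        return torch.ops.aten.wait_tensor(ar)

    @staticmethod
    def backward(ctx, grad_out):
        tag, ranks, group_size = ctx.pg_args
        # Every rank's output depends on every rank's input, so the local gradient
        # is the sum of the output gradients across ranks, i.e. another all_reduce.
        g = torch.ops.aten.all_reduce(grad_out, "sum", tag, ranks, group_size)
        return torch.ops.aten.wait_tensor(g), None, None, None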
@@ -78,8 +78,6 @@
Tensors backed by views add one more indirection to the IR.
TensorBox -> View -> StorageBox -> Buffer
In these cases, the underlying StorageBox/Buffer will be shared with the pre-view TensorBox.

For metadata mutation (e.g. as_strided_) we swing the TensorBox pointer.
"""

@@ -4202,3 +4200,109 @@ def debug_str(self, name="block"):
            "",
            code.strip().replace("def forward(", f"def {name}("),
        )


class Wait(ExternKernel):
    """
    Wait should not be used by itself. It should always be constructed in tandem
    with a collective op that produces a work to wait on.
    """

    def __init__(
        self,
        layout,
        inputs,
        constant_args=(),
    ):
        super().__init__(None, layout, inputs, constant_args)
        self.name = V.graph.register_buffer(self)

    def should_allocate(self):
        return False

    def codegen(self, wrapper):
        (input_collective,) = [t.codegen_reference() for t in self.inputs]
        work = f"{input_collective}_work"  # hacky way to name work objs..
        wrapper.writeline(f"{work}.wait()")

        # wait op still needs to produce a 'buffer' that represents the tensor output.
        # this is a symbolic gesture, and it gets handled by WrapperCodegen.
        # codegen outputs a '# reuse' line that assigns the input buffer here ('input_collective')
        # to a new name (`self.get_name()`) and `del`s the old name.
        wrapper.writeline(f"{self.get_name()} = {input_collective}")

    @classmethod
    def create(cls, collective_op: "TensorBox"):
        return Wait(
            layout=collective_op.get_layout(),
            inputs=[collective_op],
        )

    def get_alias_names(self):
        # Signal to codegen that our output buffer isn't safe to reuse
        return [self.inputs[0].codegen_reference()]


class AllReduce(ExternKernel):
    def __init__(
        self,
        layout,
        inputs,
        constant_args=(),
    ):
        super().__init__(None, layout, inputs, constant_args)
        self.name = V.graph.register_buffer(self)

    def should_allocate(self):
        return True

    @classmethod
    def create(
        cls, x: "TensorBox", reduce_op: str, tag: str, ranks: List[int], group_size: int
    ):
        x = cls.realize_input(x)

        # is there a difference between literally using x.data.layout below, vs
        # creating a new one that has the same properties?
        new_layout = FlexibleLayout(x.get_device(), x.get_dtype(), x.get_size())

        # AllReduce returns a 'work' object. But Inductor's scheduler doesn't need to know
        # about that, and we just pretend for scheduling purposes that the work obj is a 1-elem tensor.
        # Nobody should consume the output of AllReduce except 'Wait', which we control here.
        return AllReduce(
            layout=new_layout,
            inputs=[x],
            constant_args=[reduce_op, tag, ranks, group_size],
        )

    def codegen(self, wrapper):
        wrapper.add_import_once("import torch.distributed as dist")

Review comment: So right now it generates the triton python code I suppose, would it be possible to generate a C++ kernel in the future?

Reply: all of the python code i'm generating here is not triton code. Triton is generated one layer deeper in inductor, when it does a 'fusion' of some ops and then codegens a kernel. This code here goes into the 'top level wrapper' script inductor generates, which is what calls the generated triton kernels and also calls other eager ops or allocations etc. The python wrapper code can also be changed to C++, and that's part of the 'aot inductor' workstream.

        wrapper.add_import_once(
            "from torch.distributed._functional_collectives import _str_to_reduce_op"
        )
        wrapper.add_import_once(
            "from torch.distributed.distributed_c10d import _find_or_create_pg_by_ranks_and_tag"
        )

Review comment: I find using c10d internals a bit problematic but we can iterate over this later.

Reply: happy to iterate. but you'll have to be more specific about the problem :)

        # extract references to our args in string form for codegen output
        (input_name,) = [t.codegen_reference() for t in self.inputs]
        output_name = self.get_name()
        reduce_op, tag, ranks, group_size = self.constant_args

        # TODO: avoid more than one ref of the same pg (even though they are cached inside the api)
        wrapper.writeline(
            f"{output_name}_pg = _find_or_create_pg_by_ranks_and_tag('{tag}', {ranks}, {group_size})"
        )

Review comment: This should be cached across invocations.

Reply: yea, currently we will be constructing more than one obj that really are the same pg (and calling _find_or_create more than one time for the same pg). is this a serious problem at all? I assumed it is 'safe' but also not ideal. My todo above was framed as a cleanup for later. But if you see a more serious issue let me know.

        # We must copy our input buffer sometimes, but the scheduler will help us find opportunities
        # to reuse the input buffer. (This requires no other users of the input buffer.)
        if not wrapper.did_reuse(self, self.inputs[0]):
            wrapper.writeline(f"{output_name}.copy_({input_name})")

        # At this point, output_name points to a buffer that is either
        # (1) the input buffer, which we're allowed to inplace modify
        # (2) a freshly allocated buffer, which we've copied the input into above
        wrapper.writeline(
            f"{output_name}_work = dist.all_reduce({output_name}, async_op=True,"
            f" group={output_name}_pg, op=_str_to_reduce_op('{str(reduce_op)}'))"
        )
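
On the caching point raised in the review exchange above: one possible shape for memoizing the group lookup, assuming the wrapper emitted a small helper instead of a fresh _find_or_create_pg_by_ranks_and_tag call per collective, is sketched below. The _cached_pg helper is hypothetical and not part of this PR.

import functools
from torch.distributed.distributed_c10d import _find_or_create_pg_by_ranks_and_tag

@functools.lru_cache(maxsize=None)
def _cached_pg(tag, ranks, group_size):
    # lru_cache needs hashable args, so the generated code would pass ranks as a tuple
    return _find_or_create_pg_by_ranks_and_tag(tag, list(ranks), group_size)

# A wrapper line such as
#   buf0_pg = _find_or_create_pg_by_ranks_and_tag('', [0, 1], 2)
# would then become
#   buf0_pg = _cached_pg('', (0, 1), 2)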

Review comment: pretty interested on how SingleProcTestCase works for collective, is it doing allreduce on a single rank?

Reply: it is calling the allreduce op but maybe the nccl kernel is skipped. these tests are only measuring whether inductor generates the right code to call into dist.* apis. I assume the apis will work as intended. Above there is one 'real' integration test that runs multi-proc.
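
Concretely, the "right code" those single-proc tests check for is the top-level wrapper output. Reconstructed from the writelines in AllReduce.codegen and Wait.codegen plus the FileCheck patterns in test_inductor_single_op, it looks roughly like the sketch below for the single-proc case (tag="", ranks=[0], group_size=1). Buffer names follow the FileCheck patterns; the empty_strided alias, the allocation arguments, and an already-initialized default process group are assumptions rather than output copied from the PR.

import torch
import torch.distributed as dist
from torch.distributed._functional_collectives import _str_to_reduce_op
from torch.distributed.distributed_c10d import _find_or_create_pg_by_ranks_and_tag

empty_strided = torch.empty_strided  # the real wrapper prologue provides an alias like this

def call(args):
    arg0_1, = args
    # AllReduce.codegen: allocate the output (no buffer reuse was possible here), copy the
    # input in, look up the process group, then issue the async collective
    buf0 = empty_strided((4, 4), (4, 1), dtype=torch.float32, device="cuda")
    buf0.copy_(arg0_1)
    buf0_pg = _find_or_create_pg_by_ranks_and_tag("", [0], 1)
    buf0_work = dist.all_reduce(buf0, async_op=True, group=buf0_pg, op=_str_to_reduce_op("sum"))
    # Wait.codegen: block on the work object, then rebind the buffer to the wait op's name
    buf0_work.wait()
    buf1 = buf0
    return (buf1, )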