Inductor support for aten::all_reduce #93111

Status: Closed. This PR wants to merge 41 commits (changes shown are from 38 of them).

Commits (41):
ba04c86  Inductor support for aten::all_reduce  (wconstab, Jan 26, 2023)

The remaining 40 commits are all titled Update on "Inductor support for aten::all_reduce" (wconstab):
c77ba15 (Jan 28, 2023), aa76f6d (Jan 30, 2023), c83ee3e (Jan 30, 2023), a673f6b (Jan 30, 2023),
dfcf748 (Jan 31, 2023), 7ce7416 (Jan 31, 2023), 851a8fa (Feb 2, 2023), 793be98 (Feb 2, 2023),
9ad496e (Feb 2, 2023), edad4be (Feb 3, 2023), 0d111e6 (Feb 6, 2023), 16d3894 (Feb 6, 2023),
732d27e (Feb 7, 2023), 4a78f87 (Feb 7, 2023), bd814d2 (Feb 7, 2023), fc94abc (Feb 7, 2023),
48f3544 (Feb 8, 2023), ab412f4 (Feb 8, 2023), e0055d9 (Feb 8, 2023), 576e292 (Feb 8, 2023),
8b9adc1 (Feb 8, 2023), e94ebfa (Feb 8, 2023), dc12a9d (Feb 8, 2023), 9d626fa (Feb 8, 2023),
47eb8c6 (Feb 8, 2023), af99f73 (Feb 8, 2023), 0c66c3a (Feb 9, 2023), 731a98b (Feb 9, 2023),
fe76e7e (Feb 9, 2023), 1d846b0 (Feb 9, 2023), 3eeb44a (Feb 10, 2023), f33314e (Feb 13, 2023),
eb45a40 (Feb 13, 2023), 7b1026b (Feb 14, 2023), 1402b00 (Feb 15, 2023), c921522 (Feb 16, 2023),
da0ac0e (Feb 16, 2023), 43b6f8c (Feb 16, 2023), 9618b3d (Feb 16, 2023), 0a61bca (Feb 16, 2023)
2 changes: 1 addition & 1 deletion .ci/pytorch/test.sh
@@ -249,7 +249,7 @@ test_dynamo_shard() {
test_inductor_distributed() {
  # this runs on both single-gpu and multi-gpu instances. It should be smart about skipping tests
  # that aren't supported if the required number of GPUs isn't available
- PYTORCH_TEST_WITH_INDUCTOR=0 python test/run_test.py --include distributed/test_dynamo_distributed --verbose
+ PYTORCH_TEST_WITH_INDUCTOR=0 python test/run_test.py --include distributed/test_dynamo_distributed --include distributed/test_traceable_collectives --verbose
assert_git_not_dirty
}

1 change: 1 addition & 0 deletions .github/labeler.yml
@@ -16,6 +16,7 @@
- torch/_subclasses/fake_utils.py
- torch/_subclasses/meta_utils.py
- test/distributed/test_dynamo_distributed.py
- test/distributed/test_traceable_collectives.py
- functorch/_src/partitioners.py
- functorch/_src/aot_autograd.py

236 changes: 236 additions & 0 deletions test/distributed/test_traceable_collectives.py
@@ -0,0 +1,236 @@
# Owner(s): ["module: dynamo"]
import functools
import unittest
from unittest.mock import patch
import torch
from torch._C import FileCheck
from torch._dispatch.python import enable_python_dispatcher
import torch._dynamo
import torch._dynamo.test_case
from torch._dynamo.utils import same
from torch._dynamo.testing import CompileCounter
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing._internal.common_distributed import (
DynamoDistributedSingleProcTestCase,
DynamoDistributedMultiProcTestCase,
_dynamo_dist_per_rank_init,
requires_nccl,
skip_if_lt_x_gpu
)
from torch._inductor.compile_fx import compile_fx as inductor_compile_fx
from torch._inductor.utils import has_triton, run_and_get_triton_code
import torch._dynamo.logging

# NOTE: this import is required; without it the op isn't registered and the call falls through to
# the no-op C++ kernel that exists only as a registration placeholder
import torch.distributed._functional_collectives


@requires_nccl()
class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
"""
Run correctness checks in the multi-proc runner; mark each test with the minimum # of GPUs it needs.
"""
def get_world_trs(self):
return {
"tag": "",
"ranks": list(range(self.world_size)),
"group_size": self.world_size,
}

@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
@skip_if_lt_x_gpu(2)
# TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
@patch.object(torch._inductor.config, "compile_threads", 1)
def test_allreduce_inductor(self):
"""
The matmul/cat/allreduce sequence here is a pattern we aim to optimize.
"""

def matmul_cat_col(a, b, c, d, e, f, *, tag, ranks, group_size):
x = torch.matmul(a, b)
y = torch.matmul(c, d)
z = torch.cat((x, y))
ar = torch.ops.aten.all_reduce(z, "sum", tag, ranks, group_size)
g = torch.matmul(e, f)
ar = torch.ops.aten.wait_tensor(ar)
out = torch.add(ar, g.repeat(2, 1))
return (out, )

def compile(func, example_inputs):
graph = make_fx(func)(*example_inputs)
return inductor_compile_fx(graph, example_inputs)

with _dynamo_dist_per_rank_init(self.rank, self.world_size):

matmul_cat_col = functools.partial(
matmul_cat_col,
**self.get_world_trs(),
)
inputs = (torch.ones(4, 4, device="cuda") + self.rank,) * 6

# Not ideal: the python dispatcher seems to need to be enabled at the user level in order to
# construct a torch_dispatch subclass inside the py-registered collective ops
with enable_python_dispatcher():
eager_out = matmul_cat_col(*inputs)
compiled_matmul_cat_col = compile(matmul_cat_col, inputs)
inductor_out = compiled_matmul_cat_col(*inputs)
assert same(eager_out, inductor_out, tol=0.001)


@requires_nccl()
class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
Review comment (Contributor):
Pretty interested in how SingleProcTestCase works for collectives; is it doing allreduce on a single rank?

Reply (Author):
It is calling the allreduce op, but maybe the nccl kernel is skipped. These tests only measure whether Inductor generates the right code to call into the dist.* APIs; I assume the APIs themselves work as intended. Above there is one 'real' integration test that runs multi-proc.

"""
Prefer single-proc test runner for basic tests as it is easier to work with.
"""
def get_world_trs(self, world_size=1):
return {
"tag": "",
"ranks": list(range(world_size)),
"group_size": world_size,
}

@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
def test_inductor_single_op(self):
torch._inductor.config.debug = True

def func(inp, *, tag, ranks, group_size):
ar = torch.ops.aten.all_reduce(inp, "sum", tag, ranks, group_size)
ar = torch.ops.aten.wait_tensor(ar)
return ar

inputs = torch.ones(4, 4, device="cuda")

with enable_python_dispatcher():
compiled = torch.compile(func)
out = compiled(inputs, **self.get_world_trs())
code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs())
FileCheck() \
.check("buf0 = empty_strided") \
.check("buf0.copy_(arg0_1)") \
.check("buf0_work = dist.all_reduce(buf0") \
.check("buf0_work.wait()") \
.check("return (buf1, )") \
.run(code)
correct = func(inputs, **self.get_world_trs())
assert same(out, correct)

@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
def test_inductor_steal_buffer(self):
"""
it's ok and optimal if inductor allreduce mutates the buffer of an intermediate
that isn't going to be used again
"""
torch._inductor.config.debug = True

def func(inp, *, tag, ranks, group_size):
x = inp + 1
ar = torch.ops.aten.all_reduce(x, "sum", tag, ranks, group_size)
ar = torch.ops.aten.wait_tensor(ar)
# ensure other is not incorrectly aliasing ar's buffer
other = torch.ones_like(inp) + 22
return ar, other

inputs = torch.ones(4, 4, device="cuda")

with enable_python_dispatcher():
compiled = torch.compile(func)
code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs())
FileCheck() \
.check("buf1 = buf0; del buf0 # reuse") \
.check_not("buf1.copy_(") \
.check("buf1_work = dist.all_reduce(buf1") \
.check("buf1_work.wait()") \
.check("buf2 = buf1") \
.check("buf3 = empty_strided") \
.check("return (buf2, buf3") \
.run(code)
out = compiled(inputs, **self.get_world_trs())
correct = func(inputs, **self.get_world_trs())
assert same(out, correct)

@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
def test_inductor_doesnt_mutate_shared(self):
"""
Make sure an intermediate that is going to be reused isn't mutated unless it is copied first.
"""
torch._inductor.config.debug = True

def func(inp, *, tag, ranks, group_size):
x = inp + 1
ar = torch.ops.aten.all_reduce(x, "sum", tag, ranks, group_size)
y = x + 2
ar = torch.ops.aten.wait_tensor(ar)
# ensure other is not incorrectly aliasing ar's buffer
other = torch.ones_like(inp) + 22
return ar, y, other

inputs = torch.ones(4, 4, device="cuda")

with enable_python_dispatcher():
compiled = torch.compile(func)
code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs())
FileCheck() \
.check("buf0 = empty_strided(") \
.check("buf2 = empty_strided") \
.check("triton__0.run(arg0_1, buf0, buf2") \
.check_not("copy_(") \
.check("buf1 = buf0; del buf0 # reuse") \
.check("buf1_work = dist.all_reduce(buf1") \
.check("buf1_work.wait()") \
.check("buf3 = buf1") \
.check("return (buf3, buf2, buf4") \
.run(code)
out = compiled(inputs, **self.get_world_trs())
correct = func(inputs, **self.get_world_trs())
assert same(out, correct)

def test_dynamo_trace_allreduce(self):
def func(inp, *, tag, ranks, group_size):
ar = torch.ops.aten.all_reduce(inp, "sum", tag, ranks, group_size)
Review comment (Contributor):
Dynamo works because we are calling the aten op rather than the functional collective directly, so we get around the AsyncTensor subclass?

Reply (Author):
Oh, this is not the real dynamo support; see a later PR in this stack, where I change this test to call the real collective and make the dynamo changes needed to fix it.

return ar

inputs = torch.ones(4, 4, device="cuda")
counter = CompileCounter()
with enable_python_dispatcher():
compiled = torch.compile(func, backend=counter)
out = compiled(inputs, **self.get_world_trs())
correct = func(inputs, **self.get_world_trs())
assert counter.frame_count == 1
assert counter.op_count == 1
assert same(out, correct)

def test_backwards(self):
"""
It's probably not that common to need backwards support for collectives.

However, I wanted to at least see if it was possible to support it as a design goal.
"""
def func(inp, *, tag, ranks, group_size):
ar = torch.ops.aten.all_reduce(inp, "sum", tag, ranks, group_size)
return ar

input = torch.ones(4, 4, device="cuda", requires_grad=True)
with enable_python_dispatcher():
# TODO implement backwards
with self.assertRaisesRegex(RuntimeError, "derivative for aten::all_reduce is not implemented"):
compiled = torch.compile(func, backend="aot_eager") # inductor bug with single-op allreduce graph
out = compiled(input, **self.get_world_trs())
out.sum().backward()
Review comment (Contributor):
Oh, I thought we didn't implement the allreduce backward yet, so it's a dummy function right now and we just test whether that dummy function could work here?

Reply (Author):
Yeah, I should delete this test. I asked Rodrigo to cover this in his own test file, and he did. He also dropped backward support for now, so I think his test is a stub; we will add it later.


correct_input = input.clone().detach().requires_grad_()
correct = func(correct_input, **self.get_world_trs())
correct.sum().backward()
assert same(out, correct)
assert same(input.grad, correct_input.grad)

def test_meta(self):
x = torch.rand((2, 3, 4), device="meta")
out = torch.ops.aten.all_reduce(x, "sum", **self.get_world_trs())
assert x.size() == out.size()


if __name__ == "__main__":
from torch._dynamo.test_case import run_tests

run_tests()
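
(As wired up in .ci/pytorch/test.sh above, this file can also be run locally via
`python test/run_test.py --include distributed/test_traceable_collectives --verbose`; the
multi-proc test additionally assumes NCCL, at least 2 GPUs, and a GPU recent enough for Triton.)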
108 changes: 106 additions & 2 deletions torch/_inductor/ir.py
@@ -78,8 +78,6 @@
Tensors backed by views add one more indirection to the IR.
TensorBox -> View -> StorageBox -> Buffer
In these cases, the underlying StorageBox/Buffer will be shared with the pre-view TensorBox.

For metadata mutation (e.g. as_strided_) we swing the TensorBox pointer.
"""


@@ -4202,3 +4200,109 @@ def debug_str(self, name="block"):
"",
code.strip().replace("def forward(", f"def {name}("),
)


class Wait(ExternKernel):
"""
Wait should not be used by itself. It should always be constructed in tandem
with a collective op that produces a work to wait on.
"""

def __init__(
self,
layout,
inputs,
constant_args=(),
):
super().__init__(None, layout, inputs, constant_args)
self.name = V.graph.register_buffer(self)

def should_allocate(self):
return False

def codegen(self, wrapper):
(input_collective,) = [t.codegen_reference() for t in self.inputs]
work = f"{input_collective}_work" # hacky way to name work objs..
wrapper.writeline(f"{work}.wait()")

# wait op still needs to produce a 'buffer' that represents the tensor output.
# this is a symbolic gesture, and it gets handled by WrapperCodegen.
# codegen outputs a '# reuse' line that assigns the input buffer here ('input_collective')
# to a new name (`self.get_name()`) and `del`s the old name.
wrapper.writeline(f"{self.get_name()} = {input_collective}")
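# Illustrative example (buffer names assumed, matching the FileCheck patterns in the tests
# above): for an input collective buffer 'buf0' whose Wait node is named 'buf1', this method
# emits roughly:
#   buf0_work.wait()
#   buf1 = buf0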

@classmethod
def create(cls, collective_op: "TensorBox"):
return Wait(
layout=collective_op.get_layout(),
inputs=[collective_op],
)

def get_alias_names(self):
# Signal to codegen that our output buffer isn't safe to reuse
return [self.inputs[0].codegen_reference()]


class AllReduce(ExternKernel):
def __init__(
self,
layout,
inputs,
constant_args=(),
):
super().__init__(None, layout, inputs, constant_args)
self.name = V.graph.register_buffer(self)

def should_allocate(self):
return True

@classmethod
def create(
cls, x: "TensorBox", reduce_op: str, tag: str, ranks: List[int], group_size: int
):
x = cls.realize_input(x)

# is there a difference between literally using x.data.layout below, vs
# creating a new one that has the same properties?
new_layout = FlexibleLayout(x.get_device(), x.get_dtype(), x.get_size())

# AllReduce returns a 'work' object. But Inductor's scheduler doesn't need to know
# about that, and we just pretend for scheduling purposes that the work obj is a 1-elem tensor.
# Nobody should consume the output of AllReduce except 'Wait', which we control here.
return AllReduce(
layout=new_layout,
inputs=[x],
constant_args=[reduce_op, tag, ranks, group_size],
)

def codegen(self, wrapper):
wrapper.add_import_once("import torch.distributed as dist")
Review comment (Contributor):
So right now it generates the Triton Python code, I suppose; would it be possible to generate a C++ kernel in the future?

Reply (Author):
None of the Python code generated here is Triton code. Triton is generated one layer deeper in Inductor, when it fuses some ops and codegens a kernel. The code here goes into the 'top level wrapper' script Inductor generates, which calls the generated Triton kernels and also calls other eager ops, allocations, etc. (see the wrapper sketch at the end of this diff). The Python wrapper code can also be changed to C++, and that's part of the 'aot inductor' workstream.

wrapper.add_import_once(
"from torch.distributed._functional_collectives import _str_to_reduce_op"
)
wrapper.add_import_once(
"from torch.distributed.distributed_c10d import _find_or_create_pg_by_ranks_and_tag"
Review comment (Contributor):
I find using c10d internals a bit problematic, but we can iterate on this later.

Reply (Author):
Happy to iterate, but you'll have to be more specific about the problem :)

)

# extract references to our args in string form for codegen output
(input_name,) = [t.codegen_reference() for t in self.inputs]
output_name = self.get_name()
reduce_op, tag, ranks, group_size = self.constant_args

# TODO: avoid more than one ref of the same pg (even though they are cached inside the api)
wrapper.writeline(
f"{output_name}_pg = _find_or_create_pg_by_ranks_and_tag('{tag}', {ranks}, {group_size})"
Review comment (Contributor):
This should be cached across invocations.

Reply (Author):
Yeah, currently we will construct more than one object for what is really the same pg (and call _find_or_create more than once for the same pg). Is this a serious problem? I assumed it is 'safe' but not ideal; my TODO above framed it as a cleanup for later. If you see a more serious issue, let me know.

)

# We must copy our input buffer sometimes, but the scheduler will help us find opportunities
# to reuse the input buffer. (This requires no other users of the input buffer.)
if not wrapper.did_reuse(self, self.inputs[0]):
wrapper.writeline(f"{output_name}.copy_({input_name})")

# At this point, output_name points to a buffer that is either
# (1) the input buffer, which we're allowed to inplace modify
# (2) a freshly allocated buffer, which we've copied the input into above
wrapper.writeline(
f"{output_name}_work = dist.all_reduce({output_name}, async_op=True,"
f" group={output_name}_pg, op=_str_to_reduce_op('{str(reduce_op)}'))"
)
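
For reference, below is a rough sketch of the wrapper code this codegen path produces for the
single-op case, reconstructed from the FileCheck assertions in test_inductor_single_op above and
the writelines in AllReduce/Wait. Buffer names, the shape/stride arguments, and the use of
torch.empty_strided are illustrative assumptions; the exact output depends on the graph.

import torch
import torch.distributed as dist
from torch.distributed._functional_collectives import _str_to_reduce_op
from torch.distributed.distributed_c10d import _find_or_create_pg_by_ranks_and_tag

def call(args):
    arg0_1, = args
    # AllReduce.should_allocate() is True, so the scheduler gives it an output buffer ...
    buf0 = torch.empty_strided((4, 4), (4, 1), device="cuda", dtype=torch.float32)
    # ... and AllReduce.codegen copies the input in, since arg0_1 was not reused in place
    buf0.copy_(arg0_1)
    buf0_pg = _find_or_create_pg_by_ranks_and_tag("", [0], 1)
    buf0_work = dist.all_reduce(buf0, async_op=True, group=buf0_pg, op=_str_to_reduce_op("sum"))
    # Wait.codegen: block on the work object, then alias the collective's buffer as its output
    buf0_work.wait()
    buf1 = buf0
    return (buf1, )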