Add torch.while_loop support to AOT Inductor (#123586)
Summary: Previously, `torch.while_loop` was supported only in JIT Inductor (added in #122069). This PR extends that support to AOT Inductor.
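
For context, below is a minimal sketch of the kind of model this change targets and of one way it could be compiled ahead of time. The `Simple` module and the `torch._export.aot_compile` entry point are illustrative assumptions, not code from this PR; the actual tests exercise the path through `AOTIRunnerUtil` and `check_model_with_multiple_inputs`.

```
import torch
from torch._export import aot_compile


class Simple(torch.nn.Module):
    def forward(self, ci, a, b):
        def cond_fn(i, x, y):
            return i > 0

        def body_fn(i, x, y):
            return i - 1, x + y, y - x

        return torch._higher_order_ops.while_loop(cond_fn, body_fn, (ci, a, b))


device = "cuda" if torch.cuda.is_available() else "cpu"
model = Simple().to(device)
example_inputs = (
    torch.tensor(5, device=device),      # carried loop counter
    torch.randn(10, 20, device=device),  # carried tensor a
    torch.randn(10, 20, device=device),  # carried tensor b
)
# Ahead-of-time compile to a shared library; the path to the .so is returned.
so_path = aot_compile(model, example_inputs)
print(so_path)
```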

Test Plan:

```
$ python test/inductor/test_aot_inductor.py -k test_while_loop
...
----------------------------------------------------------------------
Ran 24 tests in 129.236s

OK (skipped=8)

$ python test/inductor/test_control_flow.py
...
----------------------------------------------------------------------
Ran 50 tests in 136.199s

OK
```

Pull Request resolved: #123586
Approved by: https://github.com/jansel, https://github.com/chenyang78
aakhundov authored and pytorchmergebot committed Apr 9, 2024
1 parent 3908ebc commit c773913
Showing 5 changed files with 244 additions and 16 deletions.
109 changes: 107 additions & 2 deletions test/inductor/test_aot_inductor.py
@@ -56,11 +56,21 @@
try:
try:
from .test_aot_inductor_utils import AOTIRunnerUtil
from .test_control_flow import CondModels, prepend_predicates
from .test_control_flow import (
CondModels,
prepend_counters,
prepend_predicates,
WhileLoopModels,
)
from .test_torchinductor import copy_tests, requires_multigpu, TestFailure
except ImportError:
from test_aot_inductor_utils import AOTIRunnerUtil
from test_control_flow import CondModels, prepend_predicates
from test_control_flow import (
CondModels,
prepend_counters,
prepend_predicates,
WhileLoopModels,
)
from test_torchinductor import copy_tests, requires_multigpu, TestFailure
except (unittest.SkipTest, ImportError) as e:
if __name__ == "__main__":
@@ -969,6 +979,96 @@ def test_cond_non_tensor_predicates(self, dynamic):
dynamic_shapes=dynamic_shapes,
)

@skipIfRocm
def test_while_loop_simple(self):
inputs = (
torch.randn((10, 20), device=self.device),
torch.randn((10, 20), device=self.device),
)
dim0_ab = Dim("s0", min=2, max=1024)
dynamic_shapes = {
"ci": {},
"a": {0: dim0_ab, 1: None},
"b": {0: dim0_ab, 1: None},
}
self.check_model_with_multiple_inputs(
WhileLoopModels.Simple(),
prepend_counters(inputs),
dynamic_shapes=dynamic_shapes,
)

@skipIfRocm
def test_while_loop_nested(self):
inputs = (
torch.randn((10, 20), device=self.device),
torch.randn((10, 20), device=self.device),
)
dim0_ab = Dim("s0", min=2, max=1024)
dynamic_shapes = {
"ci": {},
"cj": {},
"a": {0: dim0_ab, 1: None},
"b": {0: dim0_ab, 1: None},
}
self.check_model_with_multiple_inputs(
WhileLoopModels.Nested(),
prepend_counters(inputs, num_counters=2),
dynamic_shapes=dynamic_shapes,
)

@skipIfRocm
def test_while_loop_with_outer_code(self):
inputs = (
torch.randn((10, 20), device=self.device),
torch.randn((10, 20), device=self.device),
)
dim0_ab = Dim("s0", min=2, max=1024)
dynamic_shapes = {
"c": {},
"a": {0: dim0_ab, 1: None},
"b": {0: dim0_ab, 1: None},
}
self.check_model_with_multiple_inputs(
WhileLoopModels.OuterCode(),
prepend_counters(inputs),
dynamic_shapes=dynamic_shapes,
)

@skipIfRocm
def test_while_loop_with_parameters(self):
inputs = (torch.randn((10, 20), device=self.device),)
dim0_a = Dim("s0", min=2, max=1024)
dynamic_shapes = {
"c": {},
"a": {0: dim0_a, 1: None},
}
self.check_model_with_multiple_inputs(
WhileLoopModels.Parameters(self.device),
prepend_counters(inputs),
dynamic_shapes=dynamic_shapes,
)

@skipIfRocm
def test_while_loop_with_outer_buffers(self):
inputs = (
torch.randn((10, 20), device=self.device),
torch.randn((10, 20), device=self.device),
)
# dynamic shapes don't work now due to
# https://github.com/pytorch/pytorch/issues/123596
# dim0_ab = Dim("s0", min=2, max=1024)
# dynamic_shapes = {
# "c": {},
# "a": {0: dim0_ab, 1: None},
# "b": {0: dim0_ab, 1: None},
# }
dynamic_shapes = None
self.check_model_with_multiple_inputs(
WhileLoopModels.OuterBuffers(),
prepend_counters(inputs),
dynamic_shapes=dynamic_shapes,
)

@config.patch({"is_predispatch": True})
def test_constant(self):
class M(torch.nn.Module):
@@ -2590,6 +2690,11 @@ def fail_non_abi_compatible_cuda(is_skip=False):
# https://github.com/pytorch/pytorch/issues/122991
"test_runtime_checks_complex": fail_with_and_without_stack_allocation(is_skip=True),
"test_runtime_checks_fp8": fail_with_and_without_stack_allocation(is_skip=True),
"test_while_loop_simple": fail_stack_allocation(is_skip=True),
"test_while_loop_nested": fail_stack_allocation(is_skip=True),
"test_while_loop_with_outer_code": fail_stack_allocation(is_skip=True),
"test_while_loop_with_parameters": fail_stack_allocation(is_skip=True),
"test_while_loop_with_outer_buffers": fail_stack_allocation(is_skip=True),
}

CUDA_TEST_FAILURES = {
49 changes: 43 additions & 6 deletions test/inductor/test_control_flow.py
@@ -197,8 +197,11 @@ def _run_test(

for inputs in input_sets:
for inputs_with_predicates in prepend_predicates(inputs, num_predicates):
cloned_inputs = [inp.clone() for inp in inputs_with_predicates]
result = model(*inputs_with_predicates)
result_compiled = compiled_model(*inputs_with_predicates)
# inputs must not be mutated
torch.testing.assert_close(cloned_inputs, inputs_with_predicates)
torch.testing.assert_close(result, result_compiled)

self.assertEqual(cnt.frame_count, 1, "only one compilation expected")
@@ -488,8 +491,6 @@ def body_fn_nested(i2, j2, x2, y2):

return torch._higher_order_ops.while_loop(cond_fn, body_fn, (ci, cj, a, b))

# TODO(aakhundov): add while_loop test with parameters
# once dynamo / export allows while_loop closure capture
class Parameters(torch.nn.Module):
class InnerModel(torch.nn.Module):
def __init__(self, device):
@@ -526,7 +527,9 @@ def body_fn(c, x, y):
return f * g / 1.41

# TODO(aakhundov): add while_loop test with outer buffers
# once dynamo / export allows while_loop closure capture
# with dynamic=True once dynamo / export allows while_loop
# closure capture with mark_dynamic:
# https://github.com/pytorch/pytorch/issues/123596
class OuterBuffers(torch.nn.Module):
def forward(self, c, a, b):
d = a * 2
@@ -570,8 +573,12 @@ def _run_test(

for inputs in input_sets:
for inputs_with_counters in prepend_counters(inputs, num_counters):
cloned_inputs = [inp.clone() for inp in inputs_with_counters]
result = model(*inputs_with_counters)
result_compiled = compiled_model(*inputs_with_counters)
with torch.no_grad():
result_compiled = compiled_model(*inputs_with_counters)
# inputs must not be mutated
torch.testing.assert_close(cloned_inputs, inputs_with_counters)
torch.testing.assert_close(
result, result_compiled, atol=1e-4, rtol=1e-4
)
@@ -597,7 +604,7 @@ def test_while_loop_simple_control_flow(self, device, dynamic):
@parametrize("device", ["cpu", "cuda"])
@parametrize("dynamic", [False, True])
def test_while_loop_nested_control_flow(self, device, dynamic):
# while_loop control flow without nesting
# while_loop control flow with nesting
self._run_test(
model=WhileLoopModels.Nested(),
inputs=(
@@ -613,7 +620,7 @@ def test_while_loop_nested_control_flow(self, device, dynamic):
@parametrize("device", ["cpu", "cuda"])
@parametrize("dynamic", [False, True])
def test_while_loop_with_outer_code(self, device, dynamic):
# while_loop control flow without nesting
# while_loop control flow with outer code
self._run_test(
model=WhileLoopModels.OuterCode(),
inputs=(
@@ -624,6 +631,36 @@ def test_while_loop_with_outer_code(self, device, dynamic):
dynamic=dynamic,
)

@skipIfRocm
@requires_cuda
@parametrize("device", ["cpu", "cuda"])
@parametrize("dynamic", [False, True])
def test_while_loop_with_parameters(self, device, dynamic):
# while_loop control flow with parameters
self._run_test(
model=WhileLoopModels.Parameters(device),
inputs=(torch.randn(10, 20),),
device=device,
dynamic=dynamic,
)

@requires_cuda
@parametrize("device", ["cpu", "cuda"])
# dynamic=True doesn't work now due to
# https://github.com/pytorch/pytorch/issues/123596
@parametrize("dynamic", [False])
def test_while_loop_with_outer_buffers(self, device, dynamic):
# while_loop control flow with outer buffers
self._run_test(
model=WhileLoopModels.OuterBuffers(),
inputs=(
torch.randn(10, 20),
torch.randn(10, 20),
),
device=device,
dynamic=dynamic,
)


instantiate_parametrized_tests(CondTests)
instantiate_parametrized_tests(WhileLoopTests)
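The `WhileLoopTests` above drive each model through `prepend_counters`, which, by analogy with `prepend_predicates`, yields the input sets with scalar counter tensors prepended so the loop runs for several different iteration counts. Below is a hypothetical stand-in for that helper; the real one lives in `test_control_flow.py` and may choose different counter values.

```
import itertools

import torch


def prepend_counters_sketch(inputs, num_counters=1, counter_values=(0, 1, 5)):
    # Hypothetical helper: yield copies of `inputs` with one scalar counter
    # tensor per counter prepended, covering several iteration counts.
    device = inputs[0].device
    for values in itertools.product(counter_values, repeat=num_counters):
        counters = [torch.tensor(v, device=device) for v in values]
        yield [*counters, *inputs]


# usage: each yielded list is a full argument set for a while_loop model
for args in prepend_counters_sketch((torch.randn(10, 20),), num_counters=1):
    print([tuple(a.shape) for a in args])
```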
75 changes: 75 additions & 0 deletions torch/_inductor/codegen/cpp_wrapper_cpu.py
@@ -1696,6 +1696,9 @@ def codegen_subgraph_suffix(self, subgraph, outer_inputs, outer_outputs):
# to the conditional output (outer_output), as RAIIAtenTensorHandle's copy
# constructor is deleted.
src = f"std::move({src})"
# in case the outer_output carried a value
# before (e.g., in the while_loop codegen)
self.writeline(f"{outer_output}.reset();")
self.writeline(f"{outer_output} = {src}{self.ending}")

def codegen_conditional(self, conditional):
@@ -1740,6 +1743,78 @@ def codegen_conditional(self, conditional):
self.writeline(ExitSubgraphLine(self))
self.writeline("}")

def codegen_while_loop(self, while_loop):
name = while_loop.get_name()
outer_carried_inputs = [
buf.codegen_reference() for buf in while_loop.carried_inputs
]
outer_additional_inputs = [
buf.codegen_reference() for buf in while_loop.additional_inputs
]
cond_result_name = f"{name}_cond_result"

if config.abi_compatible:
self.writeline(f"RAIIAtenTensorHandle {cond_result_name};")

cond_outer_inputs = []
for inp, out in zip(outer_carried_inputs, while_loop.outputs):
# in ABI-compatible mode, the carried inputs are codegened
# as buffers outside the while loop and set to the initial
# values. at the end of each while_loop iteration, they
# will be assigned the carried values.
out_name = out.get_name()
self.writeline(f"AtenTensorHandle {out_name}_handle;")
self.writeline(
f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_assign_tensors_out({inp}, &{out_name}_handle));"
)
self.writeline(f"RAIIAtenTensorHandle {out_name}({out_name}_handle);")
cond_outer_inputs.append(out_name)

# additional inputs will be assigned within the while_loop
# iteration directly from the corresponding outer graph buffers
cond_outer_inputs.extend(outer_additional_inputs)
else:
self.writeline(f"at::Tensor {cond_result_name};")
self.writeline(f"at::Tensor {name}[{len(outer_carried_inputs)}];")
for i, inp in enumerate(outer_carried_inputs):
# set the initial state before the loop
self.writeline(f"{name}[{i}] = {inp};")

cond_outer_inputs = [
*[f"{name}[{i}]" for i in range(len(outer_carried_inputs))],
*outer_additional_inputs,
]

cond_outer_outputs = [cond_result_name]
body_outer_inputs = list(cond_outer_inputs)
body_outer_outputs = body_outer_inputs[: len(outer_carried_inputs)]

self.writeline("while (1) {")
self.writeline(EnterSubgraphLine(self, while_loop.cond_subgraph.graph))
self.codegen_subgraph(
while_loop.cond_subgraph, cond_outer_inputs, cond_outer_outputs
)

if config.abi_compatible:
cond_result = f"{cond_result_name}_scalar"
self.writeline(f"bool {cond_result};")
# in ABI-compatible mode, we need to use the ABI shim function
# to extract a C++ bool from the underlying scalar bool Tensor
self.writeline(
f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_item_bool({cond_result_name}, &{cond_result}));"
)
else:
cond_result = f"{cond_result_name}.item<bool>()"
self.writeline(f"if (!{cond_result}) break;")

self.writeline(ExitSubgraphLine(self))
self.writeline(EnterSubgraphLine(self, while_loop.body_subgraph.graph))
self.codegen_subgraph(
while_loop.body_subgraph, body_outer_inputs, body_outer_outputs
)
self.writeline(ExitSubgraphLine(self))
self.writeline("}")

def generate_extern_kernel_args_decl_if_needed(
self, op_overload, raw_args, output_args
):
18 changes: 10 additions & 8 deletions torch/_inductor/codegen/wrapper.py
@@ -1550,22 +1550,24 @@ def codegen_while_loop(self, while_loop):
outer_additional_inputs = [
buf.codegen_reference() for buf in while_loop.additional_inputs
]
outer_inputs = outer_carried_inputs + outer_additional_inputs

self.writeline(f"{name} = [None] * {len(outer_inputs)}")
for i, inp in enumerate(outer_inputs):
self.writeline(f"{name} = [None] * {len(outer_carried_inputs)}")
for i, inp in enumerate(outer_carried_inputs):
# set the initial state before the loop
self.writeline(f"{name}[{i}] = {inp}")

cond_outer_inputs = [f"{name}[{i}]" for i in range(len(outer_inputs))]
cond_outer_inputs = [
*[f"{name}[{i}]" for i in range(len(outer_carried_inputs))],
*outer_additional_inputs,
]
cond_outer_outputs = [f"{name}_cond_result"]
body_outer_inputs = list(
cond_outer_inputs
) # same inputs for cond_fn and body_fn

# Carry over the state from body_fn. Note: We only carry over the carried_inputs part of the inputs,
# the additional ones are passed in as they're before.
body_outer_outputs = [f"{name}[{i}]" for i in range(len(outer_carried_inputs))]
# Carry over the state from body_fn. Note: We only carry over
# the carried_inputs part of the inputs; the additional ones
# are passed in as they were before.
body_outer_outputs = body_outer_inputs[: len(outer_carried_inputs)]

self.writeline("while True:")
self.writeline(EnterSubgraphLine(self, while_loop.cond_subgraph.graph))
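To make the change above concrete, here is a hand-written illustration (not actual Inductor output) of the structure of the Python wrapper code `codegen_while_loop` now emits: the state list holds only the carried inputs, while additional (closure) inputs are passed into both subgraph calls unchanged on every iteration. The `cond_graph`/`body_graph` callables and the `arg*`/`buf0` names are stand-ins.

```
import torch


def cond_graph(i, x, scale):
    # stand-in for the compiled cond subgraph
    return i > 0


def body_graph(i, x, scale):
    # stand-in for the compiled body subgraph
    return i - 1, x * scale


def emitted_while_loop(arg0_1, arg1_1, arg2_1):
    # arg0_1, arg1_1: carried inputs; arg2_1: additional input
    buf0 = [None] * 2  # one slot per *carried* input only
    buf0[0] = arg0_1
    buf0[1] = arg1_1
    while True:
        # cond_fn and body_fn both see carried state + additional inputs
        if not cond_graph(buf0[0], buf0[1], arg2_1):
            break
        # only the carried slots are written back each iteration
        buf0[0], buf0[1] = body_graph(buf0[0], buf0[1], arg2_1)
    return buf0


print(emitted_while_loop(torch.tensor(3), torch.ones(4), torch.tensor(2.0)))
```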
9 changes: 9 additions & 0 deletions torch/_inductor/ir.py
@@ -7365,6 +7365,15 @@ def create(
for i, output in enumerate(body_outputs)
]

for inp, out in zip(carried_inputs, outputs):
if inp.get_name() in V.graph.graph_inputs:
# if a carried input of the while_loop is a graph input,
# it can be returned as is when the number of iterations
# is zero. due to this, we can't (generally) reuse the
# output buffers corresponding to the graph inputs, as
# the inputs may end up being mutated.
V.graph.never_reuse_buffers.add(out.get_name())

while_loop.outputs = outputs
return outputs

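The rationale for marking these outputs in `never_reuse_buffers` can be seen with a zero-iteration loop: `while_loop` then returns the carried inputs themselves, so reusing the output buffers could end up mutating user-visible graph inputs. A small eager-mode illustration (hypothetical model, not from the test suite):

```
import torch


def cond_fn(i, x):
    return i > 0


def body_fn(i, x):
    return i - 1, x + 1


ci = torch.tensor(0)  # zero iterations: the loop body never runs
a = torch.randn(4)
i_out, x_out = torch._higher_order_ops.while_loop(cond_fn, body_fn, (ci, a))
# x_out is semantically just `a`; writing results into a reused buffer here
# could clobber the caller's input, which Inductor must never do.
torch.testing.assert_close(x_out, a)
```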
