[cudagraph] add config for cudagraph managed input mutation support (#124754)

Summary: #123231 adds cudagraph support for more types of functions (i.e., cudagraph managed input mutation). These newly supported functions may have mutated static inputs, leading to assertion errors in workloads that previously skipped cudagraphs. This diff adds a config to opt in to the new feature.
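For context, a minimal opt-in sketch assuming the config flags land as shown in this diff; the `mut` function and the `reduce-overhead` compile mode below are illustrative, not part of the change:

```python
import torch

# Opt in to cudagraph support for mutation on inputs managed by a prior
# cudagraph pool. Only these three config flags come from this commit;
# everything else here is an assumed usage example.
torch._dynamo.config.cudagraph_backend_keep_input_mutation = True
torch._dynamo.config.cudagraph_backend_support_input_mutation = True
torch._inductor.config.triton.cudagraph_support_input_mutation = True

def mut(x):
    x.add_(2)  # in-place mutation on a (potentially cudagraph-managed) input
    return x

# "reduce-overhead" routes through inductor with cudagraph trees enabled.
compiled_mut = torch.compile(mut, mode="reduce-overhead")

for _ in range(3):
    torch.compiler.cudagraph_mark_step_begin()
    inp = torch.rand([4], device="cuda")
    out = compiled_mut(inp)
```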

Test Plan: ci

Differential Revision: D56481353

Pull Request resolved: #124754
Approved by: https://github.com/eellison
BoyuanFeng authored and pytorchmergebot committed Apr 24, 2024
1 parent bee924d commit b91f83f
Showing 6 changed files with 103 additions and 0 deletions.
38 changes: 38 additions & 0 deletions test/inductor/test_cudagraph_trees.py
@@ -294,6 +294,8 @@ def foo(x, y):

@parametrize("backend", ("inductor", "cudagraphs"))
@torch._dynamo.config.patch("cudagraph_backend_keep_input_mutation", True)
@torch._dynamo.config.patch("cudagraph_backend_support_input_mutation", True)
@torch._inductor.config.patch("triton.cudagraph_support_input_mutation", True)
def test_mutation_on_inp(self, backend):
def foo(x):
x.add_(2)
@@ -339,6 +341,38 @@ def foo(mod, x):

@parametrize("backend", ("inductor", "cudagraphs"))
@torch._dynamo.config.patch("cudagraph_backend_keep_input_mutation", True)
@torch._dynamo.config.patch("cudagraph_backend_support_input_mutation", False)
@torch._inductor.config.patch("triton.cudagraph_support_input_mutation", False)
def test_mutation_cudagraph_managed_tensors_config(self, backend):
def foo(x):
return x + 1

def mut(x):
x.add_(2)
return x

def non_mut(x):
return x.add(2)

mut = get_compile_fn(backend)(mut)
foo = get_compile_fn(backend)(foo)

with capture_stderr() as captured_output:
for i in range(3):
torch.compiler.cudagraph_mark_step_begin()
inp = torch.rand([4], device="cuda")

tmp = foo(inp)
mut_out = mut(tmp)
self.assertEqual(mut_out, non_mut(foo(inp)))
FileCheck().check_count(
"skipping cudagraphs due to mutation on input.", 1, exactly=True
).run(captured_output[0])

@parametrize("backend", ("inductor", "cudagraphs"))
@torch._dynamo.config.patch("cudagraph_backend_keep_input_mutation", True)
@torch._dynamo.config.patch("cudagraph_backend_support_input_mutation", True)
@torch._inductor.config.patch("triton.cudagraph_support_input_mutation", True)
def test_mutation_cudagraph_managed_tensors(self, backend):
def foo(x):
return x + 1
@@ -380,6 +414,8 @@ def non_mut(x):

@parametrize("backend", ("inductor", "cudagraphs"))
@torch._dynamo.config.patch("cudagraph_backend_keep_input_mutation", True)
@torch._dynamo.config.patch("cudagraph_backend_support_input_mutation", True)
@torch._inductor.config.patch("triton.cudagraph_support_input_mutation", True)
def test_mutation_cudagraph_managed_tensor_warn(self, backend):
def foo(x):
return x.add_(1)
@@ -403,6 +439,8 @@ def inp():

@parametrize("backend", ("inductor", "cudagraphs"))
@torch._dynamo.config.patch("cudagraph_backend_keep_input_mutation", True)
@torch._dynamo.config.patch("cudagraph_backend_support_input_mutation", True)
@torch._inductor.config.patch("triton.cudagraph_support_input_mutation", True)
def test_mutation_cudagraph_managed_tensor_warn_only_once(self, backend):
def foo(x):
return x + 1
19 changes: 19 additions & 0 deletions torch/_dynamo/backends/cudagraphs.py
@@ -6,11 +6,13 @@
from typing import Dict, List, Optional

import torch
from torch._dynamo import config
from torch._dynamo.backends.common import aot_autograd
from torch._dynamo.backends.debugging import boxed_nop
from torch._inductor.cudagraph_utils import (
BoxedDeviceIndex,
check_multiple_devices_or_any_cpu_nodes,
get_mutation_stack_trace,
get_placeholders,
)
from torch._inductor.utils import (
@@ -74,7 +76,24 @@ def get_device_node_mapping(gm: torch.fx.GraphModule):
return device_node_mapping


def check_for_mutation_ignore_cuda_graph_managed_tensor(
aot_model: torch.fx.GraphModule, num_fixed
) -> Optional[str]:
mutation_indices = find_input_mutations(aot_model.graph) - set(range(num_fixed))
if not mutation_indices:
return None

placeholders = [node for node in aot_model.graph.nodes if node.op == "placeholder"]
return get_mutation_stack_trace(placeholders, mutation_indices)


def check_for_skip(aot_model: torch.fx.GraphModule, num_fixed) -> Optional[str]:
if not config.cudagraph_backend_support_input_mutation:
if mut_skip := check_for_mutation_ignore_cuda_graph_managed_tensor(
aot_model, num_fixed
):
return mut_skip

if skip := check_multiple_devices_or_any_cpu_nodes(
get_device_node_mapping(aot_model)
):
3 changes: 3 additions & 0 deletions torch/_dynamo/config.py
@@ -384,6 +384,9 @@ def default_debug_dir_root():
# can prevent cudagraphing.
cudagraph_backend_keep_input_mutation = False

# enable cudagraph support for mutated inputs from prior cudagraph pool
cudagraph_backend_support_input_mutation = False

# When True, only ops that have the torch.Tag.pt2_compliant tag
# will be allowed into the graph; all other ops will be disallowed
# and will fall back to eager-mode PyTorch. Useful to ensure
18 changes: 18 additions & 0 deletions torch/_inductor/compile_fx.py
@@ -513,7 +513,25 @@ def compile_fx_inner(
if isinstance(t, torch.Tensor)
)

if not config.triton.cudagraph_support_input_mutation:
# Skip support for mutation on cudagraph-managed tensors
from torch._inductor.cudagraph_utils import (
check_for_mutation_ignore_cuda_graph_managed_tensor,
)

has_mutation_str = check_for_mutation_ignore_cuda_graph_managed_tensor(
gm, compiled_graph, num_fixed
)
has_mutation = has_mutation_str is not None

if has_mutation:
compiled_graph.disabled_cudagraphs_reason = has_mutation_str
else:
# Check mutation later to support cudagraph-managed tensors
has_mutation = None

cudagraph_tests = [
(not has_mutation, "mutated inputs"),
(not has_incompatible_cudagraph_ops(gm), "incompatible ops"),
(not complex_memory_overlap_inputs, "complex memory overlap"),
(
3 changes: 3 additions & 0 deletions torch/_inductor/config.py
@@ -603,6 +603,9 @@ class triton:
# TODO - need to debug why this prevents cleanup
cudagraph_trees_history_recording = False

# Enable cudagraph support for mutated inputs from prior cudagraph pool
cudagraph_support_input_mutation = False

# synchronize after cudagraph invocation
force_cudagraph_sync = False

22 changes: 22 additions & 0 deletions torch/_inductor/cudagraph_utils.py
@@ -134,3 +134,25 @@ class BoxedDeviceIndex:
def set(self, device_idx: Optional[int]):
assert device_idx is None or isinstance(device_idx, int)
self.value = device_idx


def check_for_mutation_ignore_cuda_graph_managed_tensor(
gm: torch.fx.GraphModule, compiled_graph, num_fixed: int
) -> Optional[str]:
default_msg = format_default_skip_message("mutated inputs")

# doesn't work for non-trees because the warmup run would apply mutation twice
if torch._inductor.config.triton.cudagraph_trees:
# checking if mutation is only on parameters/static inputs
mutation_indices = [
idx for idx in compiled_graph.mutated_input_idxs if idx >= num_fixed
]
has_mutation = len(mutation_indices) != 0
if not has_mutation:
return None
placeholders = [node for node in gm.graph.nodes if node.op == "placeholder"]
return get_mutation_stack_trace(placeholders, mutation_indices)

else:
has_mutation = len(compiled_graph.mutated_inputs) != 0
return None if not has_mutation else default_msg
