22 changes: 22 additions & 0 deletions test/inductor/test_cudagraph_trees.py
@@ -247,6 +247,28 @@ def forward(self, input, other, out):
"skipping cudagraphs due to mutaton on input. Found from"
).check("torch.logical_xor").run(captured_output[0])

@requires_multigpu()
def test_multiple_devices_msg(self):
@torch.compile()
def foo(x, y):
return (x + 1, y + 2)

with capture_stderr() as captured_output:
foo(torch.ones([10], device="cuda"), torch.ones([20]))

FileCheck().check("skipping cudagraphs due to cpu device.").check(
"y + 2"
).run(captured_output[0])

with capture_stderr() as captured_output:
foo(
torch.ones([10], device="cuda:0"), torch.ones([10], device="cuda:1")
)

FileCheck().check("skipping cudagraphs due to multiple devices").run(
captured_output[0]
)

def test_mutation(self):
@torch.compile()
def foo(x):
17 changes: 9 additions & 8 deletions torch/_inductor/compile_fx.py
@@ -360,7 +360,6 @@ def compile_fx_inner(
compiled_graph.disabled_cudagraphs_reason = has_mutation_str

cudagraph_tests = [
(set(compiled_graph.device_types) == {"cuda"}, "non-cuda device in graph"),
(not has_mutation, "mutated inputs"),
(not has_incompatible_cudagraph_ops(gm), "incompatible ops"),
(not complex_memory_overlap_inputs, "complex memory overlap"),
@@ -370,13 +369,6 @@
),
"non-Tensor inputs",
),
(
(
len(compiled_graph.device_idxs) == 1
or not config.triton.cudagraph_trees
),
"multiple device indices with cudagraph_trees",
),
]
cudagraph_fail_reasons = [s for b, s in cudagraph_tests if not b]

@@ -564,6 +556,15 @@ def fx_codegen_and_compile(
if V.aot_compilation is True:
return compiled_fn

if cudagraphs and not V.graph.disable_cudagraphs_reason:
from torch._inductor.cudagraph_utils import (
check_lowering_disable_cudagraph,
)

V.graph.disable_cudagraphs_reason = check_lowering_disable_cudagraph(
V.graph.device_node_mapping
)

compiled_graph = CompiledFxGraph(
compiled_fn, graph, output_strides, V.graph.disable_cudagraphs_reason
)
36 changes: 35 additions & 1 deletion torch/_inductor/cudagraph_utils.py
@@ -1,4 +1,4 @@
from typing import Optional
from typing import Dict, Optional

import torch
from torch._inductor.codecache import CompiledFxGraph
@@ -54,3 +54,37 @@ def check_for_mutation(
else:
has_mutation = len(compiled_graph.mutated_inputs) != 0
return None if not has_mutation else default_msg


def get_use_stack_trace(node) -> Optional[str]:
for use in node.users:
if stack_trace := use.meta.get("stack_trace", None):
return stack_trace
return None


def check_multiple_devices_or_any_cpu_nodes(
device_node_mapping: Dict[torch.device, torch.fx.Node]
) -> Optional[str]:
if cpu_node := device_node_mapping.get(torch.device("cpu")):
if stack_trace := get_use_stack_trace(cpu_node):
return format_default_skip_message(
f"cpu device. Found from : \n {stack_trace}"
)

return format_default_skip_message("cpu device")

if (
len(device_node_mapping) == 1
and next(iter(device_node_mapping.keys())).type == "cuda"
):
return None

keys_repr = (repr(key) for key in device_node_mapping.keys())
return format_default_skip_message(f"multiple devices: {', '.join(keys_repr)}")


def check_lowering_disable_cudagraph(
device_node_mapping: Dict[torch.device, torch.fx.Node]
):
return check_multiple_devices_or_any_cpu_nodes(device_node_mapping)
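
Reviewer note: as a minimal sketch (not part of this PR), here is how the new helper classifies a device-to-node mapping, assuming a PyTorch build that already includes this change. The placeholder FX node below is only a stand-in for whatever node `GraphLowering` recorded per device.

```python
# Illustrative usage of the new helper; the placeholder node stands in for
# the node that GraphLowering.add_device_info happened to record per device.
import torch
import torch.fx
from torch._inductor.cudagraph_utils import check_multiple_devices_or_any_cpu_nodes

graph = torch.fx.Graph()
node = graph.placeholder("x")

# Single CUDA device: no skip reason, cudagraphs stay eligible.
print(check_multiple_devices_or_any_cpu_nodes({torch.device("cuda:0"): node}))  # None

# Any CPU node: "skipping cudagraphs due to cpu device." (with a stack trace
# appended when one of the node's users carries stack_trace metadata).
print(check_multiple_devices_or_any_cpu_nodes({torch.device("cpu"): node}))

# More than one device: "skipping cudagraphs due to multiple devices: ..."
print(
    check_multiple_devices_or_any_cpu_nodes(
        {torch.device("cuda:0"): node, torch.device("cuda:1"): node}
    )
)
```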
8 changes: 6 additions & 2 deletions torch/_inductor/graph.py
@@ -258,8 +258,10 @@ def __init__(
[]
) # This is the linemap used by the profiler to mark custom compiled kernels getting run
# Used if lowering encounters cases where cudagraphs are not supported
self.disable_cudagraphs = False
self.disable_cudagraphs_reason = ""
self.disable_cudagraphs_reason: Optional[str] = None

# only keeping one node per device for stack trace purposes
self.device_node_mapping: Dict[torch.device, torch.fx.Node] = {}
self.orig_gm: torch.fx.GraphModule = gm.__copy__()
self.dynamo_flat_name_to_original_fqn = self.module.meta.get(
"dynamo_flat_name_to_original_fqn", {}
Expand Down Expand Up @@ -488,6 +490,8 @@ def add_device_info(self, device: torch.device):
self.device_types.add(device.type)
if device.index is not None:
self.device_idxs.add(device.index)
if V.graph.current_node and device not in self.device_node_mapping:
self.device_node_mapping[device] = V.graph.current_node

@property
def fake_mode(self):
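
Reviewer note: a small, self-contained sketch (not PyTorch code) of the bookkeeping `add_device_info` now does: only the first node seen on each device is kept, so the skip messages above can cite one representative stack trace per device. The string values here stand in for `torch.fx.Node` objects.

```python
# Toy model of device_node_mapping: the first node per device wins.
from typing import Dict

import torch

device_node_mapping: Dict[torch.device, str] = {}  # str stands in for torch.fx.Node

def record_device_node(device: torch.device, current_node: str) -> None:
    # Mirrors the diff: record a node only if this device has none yet.
    if current_node and device not in device_node_mapping:
        device_node_mapping[device] = current_node

record_device_node(torch.device("cuda", 0), "first_cuda_node")
record_device_node(torch.device("cuda", 0), "later_cuda_node")  # ignored
record_device_node(torch.device("cpu"), "cpu_node")
print(device_node_mapping)
# {device(type='cuda', index=0): 'first_cuda_node', device(type='cpu'): 'cpu_node'}
```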
1 change: 0 additions & 1 deletion torch/_inductor/ir.py
@@ -3656,7 +3656,6 @@ def unflatten_args(new_tensor_args, new_non_tensor_args):
)
for t in example_out_li:
if isinstance(t, torch.Tensor) and t.is_sparse:
V.graph.disable_cudagraphs = True
msg = "sparsity not handled. Please file issue for sparse inference weights."
if stack_trace := V.graph.current_node.meta.get("stack_trace", None):
msg = f"{msg} Found from : \n {stack_trace}"
1 change: 0 additions & 1 deletion torch/_inductor/lowering.py
@@ -2986,7 +2986,6 @@ def index_put_as_masked_fill(self, indices, value, accumulate):
def index_put_fallback(self, indices, values, accumulate):
deterministic = torch.are_deterministic_algorithms_enabled()
if is_triton(values) and (accumulate or deterministic):
V.graph.disable_cudagraphs = True
msg = (
"index put with accumulate."
if not deterministic