Skip to content

Commit

Permalink
Better graph break msg (and warning) on Dynamo x Python C++ extension (
Browse files Browse the repository at this point in the history
…#127301)

Dynamo graph breaks on Python C/C++ extensions (e.g. pybind-wrapped
functions). The usual way to handle this is to turn those extensions
into custom ops. This PR adds a nicer graph break message and also
changes it to unconditionally warn on this graph break (because graph
break messages are usually not visible).

Fixes #126799

Test Plan:
- new test

Pull Request resolved: #127301
Approved by: https://github.com/jansel
ghstack dependencies: #127291, #127292, #127400, #127423
  • Loading branch information
zou3519 authored and pytorchmergebot committed May 30, 2024
1 parent c9beea1 commit ffe506e
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 2 deletions.
33 changes: 33 additions & 0 deletions test/dynamo/test_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import torch.onnx.operators

import torch.utils._pytree as pytree
import torch.utils.cpp_extension
from torch import Tensor
from torch._C import FileCheck
from torch._dynamo import allow_in_graph
Expand Down Expand Up @@ -223,6 +224,38 @@ def fn(x):
with self.assertRaises(TypeError):
fn(torch.randn(16))

def test_cpp_extension_recommends_custom_ops(self):
    """Calling a pybind-style C++ extension function under torch.compile
    should graph break with a message (and a UserWarning) that points the
    user at the custom-operators docs.
    """
    # Minimal C++ extension source; `foobar` is opaque to Dynamo because it
    # is a compiled (PyCapsule-backed) function, not Python bytecode.
    cpp_source = """
#include <torch/extension.h>
at::Tensor foobar(const at::Tensor& x) {
return x.clone();
}
"""
    # JIT-compile and load the extension in-process.
    module = torch.utils.cpp_extension.load_inline(
        name="mylib",
        cpp_sources=cpp_source,
        functions="foobar",
        verbose=True,
    )

    x = torch.ones(2, 2, requires_grad=True)
    # Reset the graph-break counters so the count asserted below is
    # attributable to this test alone.
    counters.clear()

    @torch.compile(backend="eager")
    def f(x):
        return module.foobar(x)

    # The graph break must also surface as a UserWarning carrying the
    # custom-operators docs URL (graph-break messages alone are usually
    # not visible to users).
    with self.assertWarnsOnceRegex(
        UserWarning, ".*https://pytorch.org/docs/main/notes/custom_operators.html.*"
    ):
        f(x)
    # Exactly one graph break should have been recorded, and its message
    # should recommend wrapping the extension in a custom op.
    self.assertEqual(len(counters["graph_break"]), 1)
    first_graph_break = list(counters["graph_break"].keys())[0]
    self.assertExpectedInline(
        first_graph_break,
        """Graph break due to unsupported builtin mylib.PyCapsule.foobar. This function is either a Python builtin (e.g. _warnings.warn) or a third-party C/C++ Python extension (perhaps created with pybind). If it is a Python builtin, please file an issue on GitHub so the PyTorch team can add support for it and see the next case for a workaround. If it is a third-party C/C++ Python extension, please either wrap it into a PyTorch-understood custom operator (see https://pytorch.org/docs/main/notes/custom_operators.html for more details) or, if it is traceable, use torch.compiler.allow_in_graph.""",
    )

def test_callpacked(self):
def call_packed(args):
a, b, c = args
Expand Down
26 changes: 24 additions & 2 deletions torch/_dynamo/variables/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import inspect
import itertools
import types
import warnings
from typing import Dict, List, Optional, TYPE_CHECKING, Union

import torch
Expand Down Expand Up @@ -634,9 +635,30 @@ def wraps(fn):
else:
try:
path = inspect.getfile(self.value)
msg = f"'skip function {self.value.__qualname__} in file {path}'"
except TypeError:
path = f"Builtin {self.value.__name__}"
msg = f"'skip function {self.value.__qualname__} in file {path}'"
known_python_builtin_modules = {"_abc", "_warnings"}
if self.value.__module__ in known_python_builtin_modules:
msg = (
f"Graph break due to unsupported Python builtin {self.value.__module__}.{self.value.__qualname__}. "
f"Please file an issue on GitHub "
f"so the PyTorch team can add support for it. "
)
else:
msg = (
f"Graph break due to unsupported builtin {self.value.__module__}.{self.value.__qualname__}. "
f"This function is either a Python builtin (e.g. _warnings.warn) "
f"or a third-party C/C++ Python extension (perhaps created with pybind). "
f"If it is a Python builtin, please file an issue on GitHub "
f"so the PyTorch team can add support for it and see the next case for a workaround. "
f"If it is a third-party C/C++ Python extension, please "
f"either wrap it into a PyTorch-understood custom operator "
f"(see https://pytorch.org/docs/main/notes/custom_operators.html "
f"for more details) or, if it is traceable, use "
f"torch.compiler.allow_in_graph."
)
# also warn on it because most users won't see the graph break message
warnings.warn(msg)
msg += f"', {self.reason}'" if self.reason else ""
unimplemented(msg)

Expand Down

0 comments on commit ffe506e

Please sign in to comment.