[Dynamo][15/N] Merge allow_in_graph/inline/skip trace rules check into trace_rule.lookup (#118971)

This PR finally merges the allow_in_graph/inline/skip trace rule checks into ```trace_rules.lookup_inner```, where we can define and look up trace rules at both the function level and the file level. Going forward, this is the central place where we define and consult Dynamo trace rules for any function.
* ```trace_rules.lookup``` is the API that can return allow_in_graph, inline, or skip.
* ```skipfiles.check``` is the API that can return inline or skip, since several call sites only need the inline/skip check.
  *  I'll move ```skipfiles.check``` to ```trace_rules.check``` as one of the follow-ups.
* Both functions consult ```trace_rules.lookup_inner``` to get the tracing rule (a minimal usage sketch follows this list).
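
Not part of this commit's diff, just for illustration: a minimal sketch of consulting the two entry points above, assuming a PyTorch build that contains this change (exact results may vary by build).

```python
# Hedged sketch: consulting the two lookup entry points described above.
import torch
import torch._dynamo.skipfiles as skipfiles
import torch._dynamo.trace_rules as trace_rules
from torch._dynamo.variables import TorchInGraphFunctionVariable

# Function-level rule: torch.sin is traced as an in-graph torch function.
print(trace_rules.lookup(torch.sin) is TorchInGraphFunctionVariable)  # expected: True

# File-level rule: skipfiles.check reports whether Dynamo should skip
# (rather than inline) a function, based on the file it is defined in.
# torch._dynamo.utils.istype is skipped by default per the tests below.
print(skipfiles.check(torch._dynamo.utils.istype))
```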

To avoid a single big PR, I left a few items as follow-ups:
* Remove ```skipfiles.py``` and merge the code into ```trace_rules.py```.
* We currently do a double check in ```symbolic_convert.check_inlineable```; I'll refactor and simplify it so that we only do the inline/skip check before generating ```SkipFilesVariable``` and ```UserFunctionVariable```.
* Rename ```SkipFilesVariable``` to ```SkipFunctionVariable```, since we only handle functions.
* The inline/skip reasons are not logged in some cases, since the new lookup framework doesn't always return them. I'll refactor logging to record the inline/skip reason in the next step.
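
As a concrete illustration of the central rule definition mentioned above, and of what replaces the removed ```skipfiles.FUNC_INLINELIST```: forcing a skipped function to be inlined is now done with a function-level rule in ```trace_rules.manual_torch_name_rule_map```. The hedged sketch below mirrors the updated tests in this PR; it patches internal Dynamo state and is not a supported public API.

```python
# Hedged sketch mirroring the updated tests: force-inline a function that
# Dynamo would otherwise skip by adding a function-level trace rule.
import unittest.mock

import torch
import torch._dynamo.trace_rules as trace_rules
from torch._dynamo.variables import UserFunctionVariable

rule_map = trace_rules.manual_torch_name_rule_map.copy()
# UserFunctionVariable as the rule means "inline this function".
rule_map["torch._dynamo.utils.istype"] = UserFunctionVariable

with unittest.mock.patch(
    "torch._dynamo.trace_rules.torch_name_rule_map",
    [
        rule_map,
        trace_rules.torch_c_binding_in_graph_functions,
        trace_rules.torch_non_c_binding_in_graph_functions,
    ],
), unittest.mock.patch(
    "torch._dynamo.trace_rules.get_torch_obj_rule_map",
    trace_rules.get_torch_obj_rule_map.__wrapped__,  # bypass functools.lru_cache
):
    # Compiled code that hits torch._dynamo.utils.istype would now inline it.
    ...
```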

Pull Request resolved: #118971
Approved by: https://github.com/jansel
yanboliang authored and pytorchmergebot committed Feb 7, 2024
1 parent 284b0b5 commit 0f478d9
Showing 12 changed files with 147 additions and 175 deletions.
75 changes: 36 additions & 39 deletions test/dynamo/test_trace_rules.py
@@ -12,18 +12,15 @@
import torch._dynamo.config as config
import torch._dynamo.test_case
import torch._functorch.deprecated as deprecated_func
from torch._dynamo.skipfiles import (
FUNC_INLINELIST,
LEGACY_MOD_INLINELIST,
MOD_INLINELIST,
)
from torch._dynamo.skipfiles import LEGACY_MOD_INLINELIST, MOD_INLINELIST
from torch._dynamo.trace_rules import (
load_object,
manual_torch_name_rule_map,
torch_c_binding_in_graph_functions,
torch_non_c_binding_in_graph_functions,
)
from torch._dynamo.utils import hashable, is_safe_constant, istype
from torch._dynamo.variables import TorchInGraphFunctionVariable
from torch._dynamo.variables import TorchInGraphFunctionVariable, UserFunctionVariable

try:
from .utils import create_dummy_module_and_function
@@ -282,19 +279,6 @@ def _find_torch_objects(module):
)


def gen_get_func_inlinelist(dummy_func_inlinelist):
def get_func_inlinelist():
inlinelist = set()
for f in dummy_func_inlinelist:
module_name, fn_name = f.rsplit(".", 1)
m = importlib.import_module(module_name)
fn = getattr(m, fn_name)
inlinelist.add(fn.__code__)
return inlinelist

return get_func_inlinelist


class TraceRuleTests(torch._dynamo.test_case.TestCase):
def _check_set_equality(self, generated, used, rule_map, ignored_set):
x = generated - used
@@ -321,13 +305,6 @@ def test_skipfiles_inlinelist(self):
isinstance(importlib.import_module(m), types.ModuleType),
f"{m} from skipfiles.MOD_INLINELIST/LEGACY_MOD_INLINELIST is not a python module, please check and correct it.",
)
for f in FUNC_INLINELIST:
module_name, fn_name = f.rsplit(".", 1)
m = importlib.import_module(module_name)
self.assertTrue(
isinstance(getattr(m, fn_name), types.FunctionType),
f"{f} from skipfiles.FUNC_INLINELIST is not a python function, please check and correct it.",
)

def test_torch_name_rule_map_updated(self):
# Generate the allowed objects based on heuristic defined in `allowed_functions.py`,
Expand Down Expand Up @@ -363,48 +340,68 @@ def test_torch_name_rule_map_updated(self):
)
)

def test_func_inlinelist_torch_function(self):
def test_force_inline_torch_function(self):
# `torch._dynamo.utils.istype` is skipped by default
def fn(x):
if istype(x, torch.Tensor):
return x + 1
else:
return x - 1

func_inlinelist = torch._dynamo.skipfiles.FUNC_INLINELIST.copy()
func_inlinelist.add("torch._dynamo.utils.istype")
_manual_torch_name_rule_map = manual_torch_name_rule_map.copy()
# Force inline `torch._dynamo.utils.istype` by setting trace rule.
_manual_torch_name_rule_map["torch._dynamo.utils.istype"] = UserFunctionVariable

_torch_name_rule_map = [
_manual_torch_name_rule_map,
torch_c_binding_in_graph_functions,
torch_non_c_binding_in_graph_functions,
]

self.assertTrue(
"torch._dynamo" not in torch._dynamo.skipfiles.LEGACY_MOD_INLINELIST
)
self.assertTrue("torch._dynamo" not in torch._dynamo.skipfiles.MOD_INLINELIST)

with unittest.mock.patch(
"torch._dynamo.skipfiles.get_func_inlinelist",
gen_get_func_inlinelist(func_inlinelist),
"torch._dynamo.trace_rules.torch_name_rule_map",
_torch_name_rule_map,
), unittest.mock.patch(
"torch._dynamo.trace_rules.get_torch_obj_rule_map",
torch._dynamo.trace_rules.get_torch_obj_rule_map.__wrapped__, # bypass functools.lru_cache
):
x = torch.rand(3)
opt_fn = torch.compile(backend="eager", fullgraph=True)(fn)
ref = fn(x)
res = opt_fn(x)
self.assertEqual(ref, res)

def test_func_inlinelist_third_party_function(self):
def test_force_inline_custom_function(self):
mod, func = create_dummy_module_and_function()

def fn(x):
return func(x)

func_inlinelist = torch._dynamo.skipfiles.FUNC_INLINELIST.copy()
func_inlinelist.add(f"{mod.__name__}.{func.__name__}")
_manual_torch_name_rule_map = manual_torch_name_rule_map.copy()
# Force inline `mod.func` by setting trace rule.
_manual_torch_name_rule_map[
f"{mod.__name__}.{func.__name__}"
] = UserFunctionVariable

_torch_name_rule_map = [
_manual_torch_name_rule_map,
torch_c_binding_in_graph_functions,
torch_non_c_binding_in_graph_functions,
]

with unittest.mock.patch(
"torch._dynamo.skipfiles.get_func_inlinelist",
gen_get_func_inlinelist(func_inlinelist),
"torch._dynamo.trace_rules.torch_name_rule_map",
_torch_name_rule_map,
), unittest.mock.patch(
"torch._dynamo.skipfiles.SKIP_DIRS",
torch._dynamo.skipfiles.SKIP_DIRS.copy(),
"torch._dynamo.trace_rules.get_torch_obj_rule_map",
torch._dynamo.trace_rules.get_torch_obj_rule_map.__wrapped__,
):
# First adding the module to SKIP_DIRS so that it will be skipped.
# First adding the module to SKIP_DIRS so that it will be skipped by default.
torch._dynamo.skipfiles.add(mod.__name__)
x = torch.rand(3)
opt_fn = torch.compile(backend="eager", fullgraph=True)(fn)
5 changes: 3 additions & 2 deletions torch/_dynamo/decorators.py
@@ -93,7 +93,7 @@ def fn(a):
if isinstance(fn, (list, tuple)):
return [allow_in_graph(x) for x in fn]
assert callable(fn), "allow_in_graph expects a callable"
if trace_rules.lookup(fn) != variables.TorchInGraphFunctionVariable:
if trace_rules.lookup_callable(fn) != variables.TorchInGraphFunctionVariable:
trace_rules._disallowed_callable_ids.remove(id(fn))
trace_rules._allowed_callable_ids.add(id(fn))
return fn
@@ -106,8 +106,9 @@ def inner(fn):
assert callable(fn), "disallow_in_graph expects a callable"
if (
throw_if_not_allowed
and trace_rules.lookup_callable(fn)
!= variables.TorchInGraphFunctionVariable
and trace_rules.lookup(fn) != variables.TorchInGraphFunctionVariable
and fn not in trace_rules._allowed_callable_ids
):
raise IncorrectUsage(
"disallow_in_graph is expected to be used on an already allowed callable (like torch.* ops). "
51 changes: 16 additions & 35 deletions torch/_dynamo/skipfiles.py
@@ -37,8 +37,10 @@
from ..utils import _config_module
from .utils import getfile

from .variables.functions import (
from .variables import (
FunctorchVmapHigherOrderVariable,
NestedUserFunctionVariable,
SkipFilesVariable,
UserFunctionVariable,
UserMethodVariable,
)
@@ -160,17 +162,6 @@ def _module_dir(m: types.ModuleType):
return file and _strip_init_py(file)


# TODO: Add a decoractor for easily adding functions to FUNC_INLINELIST
# after resolving all circular import issues.
FUNC_INLINELIST = {
"torch._constrain_as_size",
"torch._constrain_as_value",
"torch._tensor._convert",
"torch.backends.mha.get_fastpath_enabled",
"torch.jit._unwrap_optional",
}


# These are legacy workarounds, don't add new modules to this list.
# Please use the MOD_INLINELIST instead to force inline functions under particular modules.
LEGACY_MOD_INLINELIST = {
@@ -240,18 +231,6 @@ def _module_dir(m: types.ModuleType):
MOD_INLINELIST.add("torch.distributed._functional_collectives")


# TODO: support adding bound method into this list
@functools.lru_cache(None)
def get_func_inlinelist():
inlinelist = set()
for f in FUNC_INLINELIST:
module_name, fn_name = f.rsplit(".", 1)
m = importlib.import_module(module_name)
fn = getattr(m, fn_name)
inlinelist.add(fn.__code__)
return inlinelist


@functools.lru_cache(None)
def get_legacy_mod_inlinelist():
inlinelist = set()
@@ -401,20 +380,22 @@ def check_verbose(obj, is_inlined_call=False):
)
else:
fi = FunctionInfo(obj, None, getfile(obj), None)
# Go through function based skip/inline rules.
if fi.code in get_func_inlinelist():

# Consult the central trace rules defined in torch._dynamo.trace_rules.
rule = torch._dynamo.trace_rules.lookup_inner(
fi.py_obj, fi.name, fi.filename, is_inlined_call
)
if rule in [UserFunctionVariable, FunctorchVmapHigherOrderVariable]:
return SkipResult(
False,
"inlined according skipfiles.FUNC_INLINELIST",
"inlined according trace_rules.lookup",
)
else:
assert rule == SkipFilesVariable, rule
return SkipResult(
True,
"skipped according trace_rules.lookup",
)
if is_inlined_call:
if fi.name == "patched_init":
return SkipResult(True, "patched init cannot be inlined.")
elif fi.name == "__torch_function__":
return SkipResult(False, "allow inlining __torch_function__")

# Go through file based skip/inline rules.
return check_file(fi.filename, is_inlined_call)


def check(obj, is_inlined_call=False):
3 changes: 3 additions & 0 deletions torch/_dynamo/symbolic_convert.py
@@ -102,6 +102,7 @@
InlinedClosureVariable,
NullVariable,
PythonModuleVariable,
SkipFilesVariable,
UnknownVariable,
)
from .variables.nn_module import NNModuleVariable
@@ -2290,6 +2291,8 @@ def check_inlineable(func):
def inline_call_(
parent, func: VariableTracker, args: List[VariableTracker], kwargs
):
if isinstance(func, SkipFilesVariable):
unimplemented("inline with functions in skip files")
assert isinstance(
func,
(UserFunctionVariable, NestedUserFunctionVariable),
69 changes: 54 additions & 15 deletions torch/_dynamo/trace_rules.py
@@ -17,9 +17,10 @@

import torch

from .utils import hashable, NP_SUPPORTED_MODULES, unwrap_if_wrapper
from .utils import getfile, hashable, NP_SUPPORTED_MODULES, unwrap_if_wrapper

from .variables import (
BuiltinVariable,
FunctorchVmapHigherOrderVariable,
SkipFilesVariable,
TorchInGraphFunctionVariable,
@@ -151,6 +152,11 @@
"torch._functorch.vmap.unwrap_batched": UserFunctionVariable,
"torch._functorch.vmap.vmap_impl": FunctorchVmapHigherOrderVariable,
"torch._functorch.vmap.wrap_batched": UserFunctionVariable,
"torch._constrain_as_size": UserFunctionVariable,
"torch._constrain_as_value": UserFunctionVariable,
"torch._tensor._convert": UserFunctionVariable,
"torch.jit._unwrap_optional": UserFunctionVariable,
"torch.backends.mha.get_fastpath_enabled": UserFunctionVariable,
}


@@ -2062,8 +2068,6 @@
"torch._check_with",
"torch._check",
"torch._compile._disable_dynamo",
"torch._constrain_as_size",
"torch._constrain_as_value",
"torch._functorch.apis.chunk_vmap",
"torch._functorch.autograd_function.custom_function_call_functionalize",
"torch._functorch.autograd_function.custom_function_call_grad",
@@ -2765,8 +2769,7 @@ def load_object(name):
else:
assert len(x) == 1, f"Invalid obj name {name}"
val = _load_obj_from_str(x[0])
if hasattr(val, "__wrapped__"):
val = val.__wrapped__
val = unwrap_if_wrapper(val)
except (AttributeError, ImportError):
val = None
return val
@@ -2969,23 +2972,59 @@ def is_numpy(obj) -> bool:
return isinstance(obj, (np.ndarray, np.generic)) or id(obj) in _numpy_function_ids


"""
Main entry point for looking up the trace rule (the Dynamo variable) for a given callable object.
"""


def lookup_callable(obj):
if not hashable(obj):
return None
# Custom allow/disallow in graph takes precedence over the general lookup.
if is_callable_disallowed(obj):
return SkipFilesVariable
if is_callable_allowed(obj):
return TorchInGraphFunctionVariable
if is_builtin_callable(obj):
return BuiltinVariable


"""
Main entry point for looking up the trace rule (the Dynamo variable) for a given function object.
E.g, the lookup result of `torch.sin` is `TorchInGraphFunctionVariable`.
"""


def lookup(obj):
# Unwrap if it's a functools.lru_cache wrapper
obj = unwrap_if_wrapper(obj)
return lookup_inner(obj)


def lookup_inner(obj, name=None, filename=None, is_direct_call=True):
# Step 1: lookup obj's tracing rule in `torch_name_rule_map`.
# The rules defined in `torch_name_rule_map` mainly includes two parts:
# - Manually defined rules for any functions.
# - The list of torch in graph functions.
if not hashable(obj):
return None
# Custom allow/disallow in graph takes precedence over the `torch_name_rule_map`.
if callable(obj) and is_callable_disallowed(obj):
if obj is not None:
if is_aten_op_or_tensor_method(obj):
return TorchInGraphFunctionVariable
rule = get_torch_obj_rule_map().get(obj, None)
if rule is not None:
return rule

# Step 2: lookup obj's tracing rule by function name.
if is_direct_call:
if name == "patched_init":
return SkipFilesVariable
elif name == "__torch_function__":
return UserFunctionVariable

# Step 3: lookup obj's tracing rule by filename.
if filename is None:
filename = getfile(obj)

if torch._dynamo.skipfiles.check_file(filename, is_direct_call).skipped:
return SkipFilesVariable
if callable(obj) and is_callable_allowed(obj):
return TorchInGraphFunctionVariable
if is_aten_op_or_tensor_method(obj):
return TorchInGraphFunctionVariable
rule = get_torch_obj_rule_map().get(obj, None)
return rule
else:
return UserFunctionVariable
