Move ShapeEnv config out of dynamo (#112933)
Previously there was a circular dependency between fx and dynamo that happened
to work out since ShapeEnv didn't access the config at module init time.

Pull Request resolved: #112933
Approved by: https://github.com/ezyang
peterbell10 authored and pytorchmergebot committed Nov 7, 2023
1 parent b4dbb02 commit 65ecb36
Showing 15 changed files with 82 additions and 64 deletions.
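
In concrete terms, the commit moves the translation-validation and ShapeEnv debug flags from torch._dynamo.config into a new torch.fx.experimental._config module, so torch.fx no longer has to import dynamo. A minimal sketch of the new access path, assuming a PyTorch build that includes this commit:

```python
# Old path (the fx -> dynamo -> fx cycle this commit removes):
#   torch._dynamo.config.translation_validation = True

# New path: the fx package owns its own config module.
from torch.fx.experimental import _config

_config.translation_validation = True
```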
4 changes: 2 additions & 2 deletions benchmarks/dynamo/common.py
@@ -3390,7 +3390,7 @@ def run(runner, args, original_dir=None):
args.repeat = 2

# Set translation validation on by default on CI accuracy runs.
-torch._dynamo.config.translation_validation = True
+torch.fx.experimental._config.translation_validation = True

if args.dynamic_ci_skips_only:
# Test only the incremental set of jobs whose skipped was
@@ -3712,7 +3712,7 @@ def run(runner, args, original_dir=None):

if args.no_translation_validation:
# Overwrite 'translation_validation' config, if specified.
-torch._dynamo.config.translation_validation = False
+torch.fx.experimental._config.translation_validation = False

experiment = functools.partial(experiment, args, runner.model_iter_fn)

4 changes: 0 additions & 4 deletions test/dynamo/test_config.py
@@ -67,7 +67,6 @@ def test_config_compile_ignored(self):
# (no silent change to compilation behaviour)
"cache_size_limit",
"accumulated_cache_size_limit",
"print_specializations",
"replay_record_enabled",
"cprofile", # only wraps _compile, not graph
"repro_after",
@@ -79,9 +78,6 @@ def test_config_compile_ignored(self):
"report_guard_failures",
"report_all_guard_failures",
"base_dir", # used for minifying / logging
"translation_validation",
"translation_validation_timeout",
"translation_validation_no_bisect",
"DEBUG_DIR_VAR_NAME",
"debug_dir_root",
}
5 changes: 3 additions & 2 deletions test/dynamo/test_dynamic_shapes.py
@@ -4,6 +4,7 @@

from torch._dynamo import config
from torch._dynamo.testing import make_test_cls_with_patches
+from torch.fx.experimental import _config as fx_config
from torch.testing._internal.common_utils import TEST_Z3

try:
@@ -44,8 +45,8 @@ def make_dynamic_cls(cls):
suffix,
(config, "assume_static_by_default", False),
(config, "specialize_int", False),
(config, "translation_validation", TEST_Z3),
(config, "check_shape_env_recorded_events", True),
(fx_config, "translation_validation", TEST_Z3),
(fx_config, "check_shape_env_recorded_events", True),
xfail_prop="_expected_failure_dynamic",
)

12 changes: 8 additions & 4 deletions test/dynamo/test_exc.py
@@ -192,11 +192,13 @@ def fn001(x):

@skipIf(not TEST_Z3, "z3 not installed")
@torch._dynamo.config.patch(
-inject_EVALUATE_EXPR_flip_equality_TESTING_ONLY=True,
assume_static_by_default=False,
+suppress_errors=False,
+)
+@torch.fx.experimental._config.patch(
+inject_EVALUATE_EXPR_flip_equality_TESTING_ONLY=True,
translation_validation=True,
translation_validation_no_bisect=True,
-suppress_errors=False,
)
def test_trigger_on_error(self):
from torch.fx.experimental.validator import ValidationException
@@ -262,11 +264,13 @@ def fn(x, shape):

@skipIf(not TEST_Z3, "z3 not installed")
@torch._dynamo.config.patch(
-inject_EVALUATE_EXPR_flip_equality_TESTING_ONLY=True,
assume_static_by_default=False,
-translation_validation=True,
suppress_errors=False,
)
+@torch.fx.experimental._config.patch(
+inject_EVALUATE_EXPR_flip_equality_TESTING_ONLY=True,
+translation_validation=True,
+)
def test_trigger_bisect_on_error(self):
from torch.fx.experimental.validator import BisectValidationException

1 change: 1 addition & 0 deletions test/dynamo/test_trace_rules.py
@@ -79,6 +79,7 @@
"torch.package.package_exporter.PackageExporter",
"torch.serialization._opener",
"torch.sparse.check_sparse_tensor_invariants",
"torch.utils._config_module.ContextDecorator",
"torch.utils._contextlib._DecoratorContextManager",
"torch.utils._device.DeviceContext",
"torch.utils._python_dispatch.TorchDispatchMode",
2 changes: 1 addition & 1 deletion test/test_proxy_tensor.py
@@ -1612,7 +1612,7 @@ def f(a, b, c, d, e):
fx_g = _trace(f, 2, 4, 8, 16, 32)
self.assertExpectedInline(show_guards(fx_g), """""")

-@torch._dynamo.config.patch(translation_validation=True)
+@torch.fx.experimental._config.patch(translation_validation=True)
def test_constant_specialization(self):
def f(t):
assert t.shape[0] == 10
34 changes: 1 addition & 33 deletions torch/_dynamo/config.py
@@ -141,9 +141,6 @@
# Rewrite assert statement in python with torch._assert
rewrite_assert_with_torch_assert = True

-# [@compile_ignored: debug] Show a warning for every specialization
-print_specializations = False
-
# Disable dynamo
disable = os.environ.get("TORCH_COMPILE_DISABLE", False)

@@ -277,30 +274,6 @@
# [@compile_ignored: debug] root folder of the project
base_dir = dirname(dirname(dirname(abspath(__file__))))

-# [@compile_ignored: debug] Uses z3 for validating the guard optimization transformations.
-translation_validation = (
-os.environ.get("TORCHDYNAMO_TRANSLATION_VALIDATION", "0") == "1"
-)
-# Timeout (in milliseconds) for z3 finding a solution.
-# [@compile_ignored: debug]
-translation_validation_timeout = int(
-os.environ.get("TORCHDYNAMO_TRANSLATION_VALIDATION_TIMEOUT", "600000")
-)
-# Disables bisection for translation validation.
-#
-# Translation validation bisection is enabled by default, if translation validation
-# is also enabled. This should help find guard simplification issues. However,
-# since validation uses Z3 for bisecting, it might take a lot of time.
-#
-# Set this configuration option to skip bisecting.
-# [@compile_ignored: debug]
-translation_validation_no_bisect = (
-os.environ.get("TORCHDYNAMO_TRANSLATION_NO_BISECT", "0") == "1"
-)
-# Checks whether replaying ShapeEnv events on a freshly constructed one yields
-# a ShapeEnv with the same state. This should be used only in testing.
-check_shape_env_recorded_events = False
-
# Trace through NumPy or graphbreak
trace_numpy = True

@@ -361,11 +334,6 @@ def is_fbcode():
# used for testing
inject_BUILD_SET_unimplemented_TESTING_ONLY = False

-# wraps (un)equalities with 'Not' class after recording the correct expression
-# in the FX graph. This should incorrectly construct the divisible and replacement
-# lists, and incorrectly issue guards.
-inject_EVALUATE_EXPR_flip_equality_TESTING_ONLY = False
-
_autograd_backward_strict_mode_banned_ops = [
"stride",
"requires_grad",
@@ -383,6 +351,6 @@ def is_fbcode():
# WARNING: this is an experimental flag and is subject to change.
_experimental_support_context_fn_in_torch_utils_checkpoint = False

-from .config_utils import install_config_module
+from torch.utils._config_module import install_config_module

install_config_module(sys.modules[__name__])
2 changes: 2 additions & 0 deletions torch/_dynamo/debug_utils.py
@@ -252,9 +252,11 @@ def generate_config_string(*, stable_output=False):
import torch._dynamo.config
import torch._inductor.config
import torch._functorch.config
+import torch.fx.experimental._config
{torch._dynamo.config.codegen_config()}
{torch._inductor.config.codegen_config()}
{torch._functorch.config.codegen_config()}
+{torch.fx.experimental._config.codegen_config()}
"""


2 changes: 1 addition & 1 deletion torch/_functorch/config.py
@@ -34,7 +34,7 @@
max_dist_from_bw = 3


-from .._dynamo.config_utils import install_config_module
+from torch.utils._config_module import install_config_module

# adds patch, save_config, invalid config checks, etc
install_config_module(sys.modules[__name__])
2 changes: 1 addition & 1 deletion torch/_inductor/config.py
@@ -615,7 +615,7 @@ class trace:
}


-from .._dynamo.config_utils import install_config_module
+from torch.utils._config_module import install_config_module

# adds patch, save_config, etc
install_config_module(sys.modules[__name__])
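
The dynamo, functorch, and inductor config files above now share the same two-line footer. As a hedged illustration of the shared pattern, here is what a hypothetical config module of your own could look like; my_package, my_flag, and my_limit are illustrative names, not part of this commit:

```python
# hypothetical my_package/config.py
import sys

my_flag = True  # an example tunable
my_limit = 128  # another example tunable

from torch.utils._config_module import install_config_module

# Turns this plain module into a ConfigModule, adding patch(),
# save_config(), codegen_config(), and invalid-key checks.
install_config_module(sys.modules[__name__])
```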
39 changes: 39 additions & 0 deletions torch/fx/experimental/_config.py
@@ -0,0 +1,39 @@
import os
import sys

# [@compile_ignored: debug] Uses z3 for validating the guard optimization transformations.
translation_validation = (
os.environ.get("TORCHDYNAMO_TRANSLATION_VALIDATION", "0") == "1"
)
# Timeout (in milliseconds) for z3 finding a solution.
# [@compile_ignored: debug]
translation_validation_timeout = int(
os.environ.get("TORCHDYNAMO_TRANSLATION_VALIDATION_TIMEOUT", "600000")
)
# Disables bisection for translation validation.
#
# Translation validation bisection is enabled by default, if translation validation
# is also enabled. This should help find guard simplification issues. However,
# since validation uses Z3 for bisecting, it might take a lot of time.
#
# Set this configuration option to skip bisecting.
# [@compile_ignored: debug]
translation_validation_no_bisect = (
os.environ.get("TORCHDYNAMO_TRANSLATION_NO_BISECT", "0") == "1"
)
# Checks whether replaying ShapeEnv events on a freshly constructed one yields
# a ShapeEnv with the same state. This should be used only in testing.
check_shape_env_recorded_events = False


# [@compile_ignored: debug] Show a warning for every specialization
print_specializations = False

# wraps (un)equalities with 'Not' class after recording the correct expression
# in the FX graph. This should incorrectly construct the divisible and replacement
# lists, and incorrectly issue guards.
inject_EVALUATE_EXPR_flip_equality_TESTING_ONLY = False

from torch.utils._config_module import install_config_module

install_config_module(sys.modules[__name__])
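
Because install_config_module runs at the bottom of this new file, it gains the same patch() helper the tests above use. A small usage sketch, assuming z3 is installed when validation actually runs:

```python
from torch.fx.experimental import _config

# Scoped override, restored when the block exits:
with _config.patch(translation_validation=True):
    ...  # tracing in here runs with z3 translation validation

# Process-wide alternative, read once at import time:
#   TORCHDYNAMO_TRANSLATION_VALIDATION=1 python your_script.py
```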
9 changes: 5 additions & 4 deletions torch/fx/experimental/symbolic_shapes.py
@@ -20,6 +20,7 @@
import torch
import torch.fx
import torch.fx.traceback as fx_traceback
+from torch.fx.experimental import _config as config

from torch.fx.experimental.recording import (
FakeTensorMeta,
@@ -1445,15 +1446,15 @@ def __init__(
if should_record_events is not None
else (
self._translation_validation_enabled
-and not torch._dynamo.config.translation_validation_no_bisect
+and not config.translation_validation_no_bisect
)
)

# Enable event recording check if both:
# - It should record events
# - The recording check is enabled
self.check_recorded_events = (
-self.should_record_events and torch._dynamo.config.check_shape_env_recorded_events
+self.should_record_events and config.check_shape_env_recorded_events
)

# This will make sure we only record the top-level function call.
@@ -3039,7 +3040,7 @@ def _set_replacement(self, a: "sympy.Symbol", expr: "sympy.Expr") -> None:
Adds or updates a replacement for a symbol.
Use this instead of `self.replacements[a] = expr`.
"""
-if torch._dynamo.config.print_specializations and isinstance(expr, (sympy.Integer, sympy.Float)):
+if config.print_specializations and isinstance(expr, (sympy.Integer, sympy.Float)):
# specializing to a constant, which is likely unexpected

# NOTE(avik): It is possible that we try logging the same specialization multiple times, e.g.,
@@ -3305,7 +3306,7 @@ def evaluate_expr(self, orig_expr: "sympy.Expr", hint=None, fx_node=None):
self._check_frozen(expr, concrete_val)

if (
-torch._dynamo.config.inject_EVALUATE_EXPR_flip_equality_TESTING_ONLY
+config.inject_EVALUATE_EXPR_flip_equality_TESTING_ONLY
and isinstance(hint, bool)
and isinstance(expr, (sympy.Eq, sympy.Ne))
):
4 changes: 2 additions & 2 deletions torch/fx/experimental/validator.py
@@ -546,7 +546,7 @@ def validate(self) -> None:
"ValidationException", "BisectValidationException",
]

-from torch._dynamo import config
+from torch.fx.experimental import _config as config

def translation_validation_enabled() -> bool:
# Checks every time this function is called, in case the Dynamo
@@ -684,7 +684,7 @@ def check_node_fails(node: torch.fx.Node) -> Optional[ValidationException]:
log.info("translation validation succeeded: no errors found.")
return

-if not shape_env.should_record_events or torch._dynamo.config.translation_validation_no_bisect:
+if not shape_env.should_record_events or config.translation_validation_no_bisect:
# Bisection is off.
# Return the last ValidationException we got.
raise last_exception
4 changes: 2 additions & 2 deletions torch/testing/_internal/common_utils.py
@@ -1339,7 +1339,7 @@ def wrapper(*args, **kwargs):
TEST_WITH_TV = os.getenv('PYTORCH_TEST_WITH_TV') == '1'

if TEST_WITH_TV:
-torch._dynamo.config.translation_validation = True
+torch.fx.experimental._config.translation_validation = True

# Some tests take too long when dynamic_shapes is combined with
# translation_validation. Whenever that happens, we solve that by
@@ -1349,7 +1349,7 @@ def disable_translation_validation_if_dynamic_shapes(fn):
def wrapper(*args, **kwargs):
if torch._dynamo.config.dynamic_shapes:
# Turning TV off due to high latency on dynamic shapes.
-torch._dynamo.config.translation_validation = False
+torch.fx.experimental._config.translation_validation = False
return fn(*args, **kwargs)
return wrapper

22 changes: 14 additions & 8 deletions torch/_dynamo/config_utils.py → torch/utils/_config_module.py
@@ -45,8 +45,8 @@ def visit(source, dest, prefix):
else:
raise AssertionError(f"Unhandled config {key}={value} ({type(value)})")

-config = dict()
-default = dict()
+config: Dict[str, Any] = dict()
+default: Dict[str, Any] = dict()

compile_ignored_keys = get_assignments_with_compile_ignored_comments(module)

@@ -115,6 +115,8 @@ class ConfigModule(ModuleType):
_allowed_keys: Set[str]
_bypass_keys: Set[str]
_compile_ignored_keys: Set[str]
+_is_dirty: bool
+_hash_digest: bytes

def __init__(self):
raise NotImplementedError(
@@ -177,10 +179,8 @@ def get_hash(self):

def to_dict(self):
warnings.warn(
-(
-"config.to_dict() has been deprecated. It may no longer change the underlying config.",
-"use config.shallow_copy_dict() or config.get_config_copy() instead",
-),
+"config.to_dict() has been deprecated. It may no longer change the underlying config."
+" use config.shallow_copy_dict() or config.get_config_copy() instead",
DeprecationWarning,
)
return self.shallow_copy_dict()
@@ -227,7 +227,7 @@ def foo(...):
changes = kwargs
assert arg2 is None
assert isinstance(changes, dict), f"expected `dict` got {type(changes)}"
-prior = {}
+prior: Dict[str, Any] = {}
config = self
dirty = False

@@ -257,10 +257,16 @@ class ContextDecorator(contextlib.ContextDecorator):
`unittest.TestCase`
"""

+def __enter__(self):
+raise NotImplementedError("NYI")
+
+def __exit__(self, exc_type, exc_val, exc_tb):
+raise NotImplementedError("NYI")
+
def __call__(self, func):
if isinstance(func, type) and issubclass(func, unittest.TestCase):

-class _TestCase(func):
+class _TestCase(func): # type: ignore[valid-type, misc]
@classmethod
def setUpClass(cls):
self.__enter__()
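
The _TestCase subclass above is what allows a single config patch to wrap an entire test class: setUpClass enters the patch and tearDownClass exits it. A hedged sketch of that usage; ExampleTests is an illustrative name, and config.patch returns the ContextDecorator shown above:

```python
import unittest

import torch._dynamo.config


@torch._dynamo.config.patch(suppress_errors=False)
class ExampleTests(unittest.TestCase):
    # The decorator subclasses this TestCase so the override is entered in
    # setUpClass and exited in tearDownClass, covering every test method.
    def test_value_is_patched(self):
        self.assertFalse(torch._dynamo.config.suppress_errors)
```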
