30 changes: 30 additions & 0 deletions test/dynamo/test_config.py
@@ -58,6 +58,36 @@ def fn(a, b):
# one graph now, as we didn't wait for recompile
self.assertEqual(cnt_dynamic.frame_count, 1)

def test_config_compile_ignored(self):
# Remove from this list if no longer relevant
dynamo_guarded_config_ignorelist = {
"log_file_name",
"verbose",
"verify_correctness", # will not affect model, will raise RuntimeError
# (no silent change to compilation behaviour)
"cache_size_limit",
"accumulated_cache_size_limit",
"print_specializations",
"replay_record_enabled",
"cprofile", # only wraps _compile, not graph
"repro_after",
"repro_level",
"repro_forward_only",
"repro_tolerance",
"same_two_models_use_fp64",
"error_on_recompile", # safe because: will throw error
"report_guard_failures",
"report_all_guard_failures",
"base_dir", # used for minifying / logging
"translation_validation",
"translation_validation_timeout",
"translation_validation_no_bisect",
"DEBUG_DIR_VAR_NAME",
"debug_dir_root",
}
for k in dynamo_guarded_config_ignorelist:
assert k in torch._dynamo.config._compile_ignored_keys


if __name__ == "__main__":
from torch._dynamo.test_case import run_tests
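The ignorelist above mirrors the `# [@compile_ignored: ...]` markers added to torch/_dynamo/config.py later in this diff. A minimal sketch of the intended effect, assuming compile-ignored keys are excluded from whatever config state Dynamo compares when deciding to recompile (the snippet and its expectation are illustrative, not part of this PR):

```python
import torch
import torch._dynamo
from torch._dynamo.testing import CompileCounter

cnt = CompileCounter()

@torch.compile(backend=cnt)
def fn(x):
    return x + 1

fn(torch.ones(3))

# Toggling a compile-ignored option (e.g. `verbose`) between calls should not,
# by itself, force a recompilation of `fn`.
torch._dynamo.config.verbose = True
fn(torch.ones(3))
assert cnt.frame_count == 1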
6 changes: 3 additions & 3 deletions test/inductor/test_config.py
@@ -32,15 +32,15 @@ def tearDown(self):
def test_set(self):
config.max_fusion_size = 13337
self.assertEqual(config.max_fusion_size, 13337)
self.assertEqual(config.to_dict()["max_fusion_size"], 13337)
config.to_dict()["max_fusion_size"] = 32
self.assertEqual(config.shallow_copy_dict()["max_fusion_size"], 13337)
Contributor: Hmm, is this BC breaking?

Collaborator (author): I added a deprecation warning to to_dict.

config.max_fusion_size = 32
self.assertEqual(config.max_fusion_size, 32)

# a nested config
prior = config.triton.cudagraphs
config.triton.cudagraphs = not prior
self.assertEqual(config.triton.cudagraphs, not prior)
self.assertEqual(config.to_dict()["triton.cudagraphs"], not prior)
self.assertEqual(config.shallow_copy_dict()["triton.cudagraphs"], not prior)

def test_save_load(self):
config.max_fusion_size = 123
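These assertions exercise the behaviour change discussed in the comment thread above: to_dict() previously returned the live config dict, so writing into it mutated the config, whereas shallow_copy_dict() returns a detached copy. A short sketch of the old and new patterns, using the inductor config as an example:

```python
from torch._inductor import config

# Old pattern, now deprecated: to_dict() handed back the live dict, so this
# write used to change the config itself. After this PR it emits a
# DeprecationWarning and the write has no effect on the config.
config.to_dict()["max_fusion_size"] = 32

# New patterns:
snapshot = config.shallow_copy_dict()   # detached, read-only snapshot
config.max_fusion_size = 32             # mutate via attribute assignment
config.load_config(snapshot)            # restore a previously captured state
```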
2 changes: 1 addition & 1 deletion torch/__init__.py
@@ -1591,7 +1591,7 @@ def apply_options(self, options: Optional[Dict[str, Any]]):
return

from torch._inductor import config
current_config: Dict[str, Any] = config.to_dict() # type: ignore[attr-defined]
current_config: Dict[str, Any] = config.shallow_copy_dict() # type: ignore[attr-defined]

for key, val in options.items():
attr_name = key.replace("-", "_")
47 changes: 33 additions & 14 deletions torch/_dynamo/config.py
@@ -8,20 +8,20 @@
import torch
from . import external_utils


# to configure logging for dynamo, aot, and inductor
# use the following API in the torch._logging module
# torch._logging.set_logs(dynamo=<level>, aot=<level>, inductor=<level>)
# or use the environment variable TORCH_LOGS="dynamo,aot,inductor" (use a prefix + to indicate higher verbosity)
# see this design doc for more detailed info
# Design doc: https://docs.google.com/document/d/1ZRfTWKa8eaPq1AxaiHrq4ASTPouzzlPiuquSBEJYwS8/edit#
# the name of a file to write the logs to
# [@compile_ignored: debug]
log_file_name = None

# Verbose will print full stack traces on warnings and errors
# [@compile_ignored: debug] Verbose will print full stack traces on warnings and errors
verbose = os.environ.get("TORCHDYNAMO_VERBOSE", "0") == "1"

# verify the correctness of optimized backend
# [@compile_ignored: runtime_behaviour] verify the correctness of optimized backend
verify_correctness = False

# need this many ops to create an FX graph
@@ -35,8 +35,10 @@
# controls the maximum number of cache entries with a guard on same ID_MATCH'd
# object. It also controls the maximum size of cache entries if they don't have
# any ID_MATCH'd guards.
# [@compile_ignored: runtime_behaviour]
cache_size_limit = 8
# controls the maximum number of entries for a code object.

# [@compile_ignored: runtime_behaviour] controls the maximum number of entries for a code object.
accumulated_cache_size_limit = 64

# whether or not to specialize on int inputs. This only has an effect with
@@ -133,18 +133,19 @@

# Record and write an execution record of the current frame to a file
# if an exception is encountered
# [@compile_ignored: debug]
replay_record_enabled = os.environ.get("TORCH_COMPILE_DEBUG", "0") == "1"

# Rewrite assert statement in python with torch._assert
rewrite_assert_with_torch_assert = True

# Show a warning for every specialization
# [@compile_ignored: debug] Show a warning for every specialization
print_specializations = False

# Disable dynamo
disable = os.environ.get("TORCH_COMPILE_DISABLE", False)

# Get a cprofile trace of Dynamo
# [@compile_ignored: runtime_behaviour] Get a cprofile trace of Dynamo
cprofile = os.environ.get("TORCH_COMPILE_CPROFILE", False)

# legacy config, does nothing now!
@@ -166,12 +166,15 @@
# None - Minifier is switched off
# dynamo - Runs minifier on the TorchDynamo produced graphs, if compilation fails
# aot - Runs minifier on the Aot Autograd produced graphs, if compilation fails
# [@compile_ignored: debug]
repro_after = os.environ.get("TORCHDYNAMO_REPRO_AFTER", None)

# Compiler compilation debug info
# 1: Dumps the original graph out to repro.py if compilation fails
# 2: Dumps a minifier_launcher.py if compilation fails.
# 3: Always dumps a minifier_launcher.py. Good for segfaults.
# 4: Dumps a minifier_launcher.py if the accuracy fails.
# [@compile_ignored: debug]
repro_level = int(os.environ.get("TORCHDYNAMO_REPRO_LEVEL", 2))

# By default, we try to detect accuracy failure by running both forward
@@ -182,16 +182,19 @@
# backwards step
# TODO: Detect this situation automatically so the user doesn't need
# to manually configure this
# [@compile_ignored: debug]
repro_forward_only = os.environ.get("TORCHDYNAMO_REPRO_FORWARD_ONLY") == "1"

# The tolerance we should use when testing if a compiled graph
# has diverged so that we should treat it as an accuracy failure
# [@compile_ignored: debug]
repro_tolerance = 1e-3

# If True, when testing if two models are the same, we will test them against
# a third fp64 reference and only report a problem if the RMSE relative to the
# fp64 is greater. However, this will use more memory; you may disable this
# if memory usage is too high.
# [@compile_ignored: runtime_behaviour]
same_two_models_use_fp64 = True

# Not all backends support scalars. Some calls on torch.Tensor (like .item()) return a scalar type.
@@ -254,23 +254,26 @@

# If true, error if we try to compile a function that has
# been seen before.
# [@compile_ignored: runtime_behaviour]
error_on_recompile = False

# reports why guards fail. Useful to identify the guards failing frequently and
# causing recompilations.
# [@compile_ignored: debug]
report_guard_failures = os.environ.get("TORCHDYNAMO_REPORT_GUARD_FAILURES") == "1"

# Whether to report all guard failures or just the first one that fails
# [@compile_ignored: debug] Whether to report all guard failures or just the first one that fails
report_all_guard_failures = False

# root folder of the project
# [@compile_ignored: debug] root folder of the project
base_dir = dirname(dirname(dirname(abspath(__file__))))

# Uses z3 for validating the guard optimizations transformations.
# [@compile_ignored: debug] Uses z3 for validating the guard optimizations transformations.
translation_validation = (
os.environ.get("TORCHDYNAMO_TRANSLATION_VALIDATION", "0") == "1"
)
# Timeout (in milliseconds) for z3 finding a solution.
# [@compile_ignored: debug]
translation_validation_timeout = int(
os.environ.get("TORCHDYNAMO_TRANSLATION_VALIDATION_TIMEOUT", "600000")
)
@@ -281,6 +281,7 @@
# since validation uses Z3 for bisecting, it might take a lot of time.
#
# Set this configuration option so as to avoid bisecting.
# [@compile_ignored: debug]
translation_validation_no_bisect = (
os.environ.get("TORCHDYNAMO_TRANSLATION_NO_BISECT", "0") == "1"
)
@@ -308,16 +308,22 @@ def is_fbcode():
return not hasattr(torch.version, "git_version")


DEBUG_DIR_VAR_NAME = "TORCH_COMPILE_DEBUG_DIR"
DEBUG_DIR_VAR_NAME = "TORCH_COMPILE_DEBUG_DIR" # [@compile_ignored: debug]

if DEBUG_DIR_VAR_NAME in os.environ:
debug_dir_root = os.path.join(os.environ[DEBUG_DIR_VAR_NAME], "torch_compile_debug")
debug_dir_root = os.path.join( # [@compile_ignored: debug]
os.environ[DEBUG_DIR_VAR_NAME], "torch_compile_debug"
)
elif is_fbcode():
debug_dir_root = os.path.join(tempfile.gettempdir(), "torch_compile_debug")
debug_dir_root = os.path.join( # [@compile_ignored: debug]
tempfile.gettempdir(), "torch_compile_debug"
)
else:
debug_dir_root = os.path.join(os.getcwd(), "torch_compile_debug")

debug_dir_root = os.path.join( # [@compile_ignored: debug]
os.getcwd(), "torch_compile_debug"
)

# [@compile_ignored: debug]
_save_config_ignore = {
"repro_after",
"repro_level",
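The markers in this file appear in two positions, and the parser added in config_utils.py below recognises both: a comment on the line immediately preceding the assignment, and a trailing comment on the same line as the assignment (including multi-line assignments whose `=` shares a line with the comment). A minimal sketch of the convention, with hypothetical option names:

```python
import os

# [@compile_ignored: debug]
my_debug_flag = False          # recognised: marker on the immediately preceding line

my_log_path = None  # [@compile_ignored: debug] recognised: trailing comment

my_timeout = int(              # [@compile_ignored: runtime_behaviour]
    os.environ.get("MY_TIMEOUT", "60")
)                              # recognised: the comment shares a line with the `=`
```

A marker that is not adjacent to any assignment trips the parser's `unconsumed @compile_ignored` assertion, so stale markers fail loudly rather than being silently dropped.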
75 changes: 70 additions & 5 deletions torch/_dynamo/config_utils.py
@@ -1,8 +1,12 @@
import contextlib

import copy
import inspect
import io
import pickle
import tokenize
import unittest
import warnings
from types import FunctionType, ModuleType
from typing import Any, Dict, Set
from unittest import mock
@@ -42,13 +46,61 @@ def visit(source, dest, prefix):

config = dict()
default = dict()

compile_ignored_keys = get_assignments_with_compile_ignored_comments(module)

visit(module, module, "")
module._config = config
module._default = default
module._allowed_keys = set(config.keys())
module._compile_ignored_keys = compile_ignored_keys
module.__class__ = ConfigModuleInstance


COMPILE_IGNORED_MARKER = "@compile_ignored"


# Gets all the keys (i.e. assignments) with a @compile_ignored comment
def get_assignments_with_compile_ignored_comments(module):
source_code = inspect.getsource(module)
assignments = set()

# Tokenize the source code to retrieve comments
tokens = tokenize.tokenize(io.BytesIO(source_code.encode("utf-8")).readline)
current_comment = "", -1
prev_name = ""
prev_assigned = "", -1

for token in tokens:
if token.type == tokenize.COMMENT:
maybe_current = token.string.strip()
if COMPILE_IGNORED_MARKER in maybe_current:
assert current_comment == (
"",
-1,
), f"unconsumed {COMPILE_IGNORED_MARKER}"
current_comment = maybe_current, token.start[0]
if token.start[0] == prev_assigned[1]:
# Check if the current assignment is followed with
# a same-line comment with COMPILE_IGNORED_MARKER
assignments.add(prev_assigned[0])
current_comment = "", -1 # reset
elif token.type == tokenize.NAME:
prev_name = token.string
elif token.type == tokenize.OP and token.string == "=":
prev_assigned = prev_name, token.start[0]
# Check if the current assignment follows a comment
# with COMPILE_IGNORED_MARKER
if (
COMPILE_IGNORED_MARKER in current_comment[0]
and current_comment[1] == token.start[0] - 1
):
assignments.add(prev_name)
current_comment = "", -1 # reset
assert current_comment == ("", -1), f"unconsumed {COMPILE_IGNORED_MARKER}"
return assignments


class ConfigModule(ModuleType):
# The default values of the configuration settings. This can be used to
# determine if the config has been changed or not.
@@ -59,6 +111,7 @@ class ConfigModule(ModuleType):
_config: Dict[str, Any]
_allowed_keys: Set[str]
_bypass_keys: Set[str]
_compile_ignored_keys: Set[str]

def __init__(self):
raise NotImplementedError(
@@ -106,12 +159,24 @@ def codegen_config(self):
lines.append(f"{mod}.{k} = {v!r}")
return "\n".join(lines)

def load_config(self, data):
"""Restore from a prior call to save_config()"""
self.to_dict().update(pickle.loads(data))

def to_dict(self):
return self._config
warnings.warn(
"config.to_dict() has been deprecated. It may no longer change the underlying config; "
"use config.shallow_copy_dict() or config.get_config_copy() instead.",
DeprecationWarning,
)
Contributor: ok, this seems legit.

return self.shallow_copy_dict()

def shallow_copy_dict(self):
return {**self._config}

def load_config(self, config):
"""Restore from a prior call to save_config() or shallow_copy_dict()"""
if not isinstance(config, dict):
config = pickle.loads(config)
self._config.update(config)

def get_config_copy(self):
return copy.deepcopy(self._config)
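The test_minifier_common.py change below is the first caller of this round-trip: it snapshots the full config with shallow_copy_dict() and restores it with load_config(), rather than reaching into the private _config attribute. A minimal sketch of that save/restore pattern (the config tweak in the middle is just a placeholder):

```python
import torch._dynamo

# Capture the complete current state (save_config() omits some fields,
# but shallow_copy_dict() does not).
saved = torch._dynamo.config.shallow_copy_dict()
try:
    torch._dynamo.config.verbose = True   # temporary, experiment-local tweak
    # ... run whatever needs the modified config ...
finally:
    torch._dynamo.config.load_config(saved)  # restore every key verbatim
```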
8 changes: 4 additions & 4 deletions torch/_dynamo/test_minifier_common.py
@@ -97,8 +97,8 @@ def _maybe_subprocess_run(self, args, *, isolate, cwd=None):

# NB: Can't use save_config because that will omit some fields,
# but we must save and reset ALL fields
dynamo_config = torch._dynamo.config._config.copy()
inductor_config = torch._inductor.config._config.copy()
dynamo_config = torch._dynamo.config.shallow_copy_dict()
inductor_config = torch._inductor.config.shallow_copy_dict()
try:
stderr = io.StringIO()
log_handler = logging.StreamHandler(stderr)
@@ -122,8 +122,8 @@ def _maybe_subprocess_run(self, args, *, isolate, cwd=None):
# around
torch._dynamo.reset()
finally:
object.__setattr__(torch._dynamo.config, "_config", dynamo_config)
object.__setattr__(torch._inductor.config, "_config", inductor_config)
torch._dynamo.config.load_config(dynamo_config)
torch._inductor.config.load_config(inductor_config)

# TODO: return a more appropriate data structure here
return subprocess.CompletedProcess(
2 changes: 1 addition & 1 deletion torch/_inductor/__init__.py
@@ -102,7 +102,7 @@ def list_options() -> List[str]:

from torch._inductor import config

current_config: Dict[str, Any] = config.to_dict() # type: ignore[attr-defined]
current_config: Dict[str, Any] = config.shallow_copy_dict() # type: ignore[attr-defined]

return list(current_config.keys())
