Make inductor config hashing more portable (#127022)
Summary: masnesral and I noticed that the config contains non-portable artifacts. Let's fix that.

Test Plan: ad hoc testing

Differential Revision: D57748025

Pull Request resolved: #127022
Approved by: https://github.com/masnesral
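
For context, here is a minimal standalone sketch (plain Python with hypothetical config values; not code from this PR) of the problem the change addresses: when a machine-specific value such as an absolute path leaks into the pickled config, two machines with identical settings compute different cache keys.

# Sketch: machine-specific values poison a pickle-based cache key.
import hashlib
import pickle

def cache_key(config: dict) -> str:
    return hashlib.sha256(pickle.dumps(config, protocol=2)).hexdigest()

machine_a = {"max_autotune": True, "cuda.cutlass_dir": "/home/alice/cutlass"}
machine_b = {"max_autotune": True, "cuda.cutlass_dir": "/home/bob/cutlass"}
assert cache_key(machine_a) != cache_key(machine_b)  # same settings, different keys

# Dropping the non-portable entry restores a shared key.
strip = lambda cfg: {k: v for k, v in cfg.items() if k != "cuda.cutlass_dir"}
assert cache_key(strip(machine_a)) == cache_key(strip(machine_b))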
oulgen authored and pytorchmergebot committed May 25, 2024
1 parent 6656377 commit 52bcf12
Showing 4 changed files with 33 additions and 11 deletions.
17 changes: 8 additions & 9 deletions torch/_inductor/codecache.py
@@ -522,7 +522,13 @@ def dumps(cls, obj) -> bytes:
         """
         with io.BytesIO() as stream:
             pickler = cls(stream)
-            pickler.dump(obj)
+            try:
+                pickler.dump(obj)
+            except (TypeError, AttributeError) as e:
+                # Some config options are callables, e.g., post_grad_custom_pre_pass,
+                # and may not pickle.
+                log.warning("Can't pickle", exc_info=True)
+                raise BypassFxGraphCache from e
             return stream.getvalue()
 
     @classmethod
@@ -661,14 +667,7 @@ def __init__(
         # Also hash on various system info (including the triton compiler version).
         self.torch_version = torch_key()
         self.system_info = CacheBase.get_system()
-
-        try:
-            self.inductor_config = config.save_config()
-        except (TypeError, AttributeError) as e:
-            # Some config options are callables, e.g., post_grad_custom_pre_pass,
-            # and may not pickle.
-            log.debug("Can't pickle inductor config: %s", e)
-            raise BypassFxGraphCache from e
+        self.inductor_config = config.save_config_portable()
 
     def debug_str(self) -> str:
         """
13 changes: 11 additions & 2 deletions torch/_inductor/config.py
@@ -875,10 +875,19 @@ class trace:
     log_autotuning_results: bool = False
 
 
-_save_config_ignore = {
+_save_config_ignore = [
     # workaround: "Can't pickle <function ...>"
     "trace.upload_tar",
-}
+]
+
+_cache_config_ignore_prefix = [
+    # trace functions are not relevant to config caching
+    "trace",
+    # uses absolute path
+    "cuda.cutlass_dir",
+    # not relevant
+    "compile_threads",
+]
 
 if TYPE_CHECKING:
     from torch.utils._config_typing import *  # noqa: F401, F403
13 changes: 13 additions & 0 deletions torch/utils/_config_module.py
@@ -156,6 +156,19 @@ def save_config(self) -> bytes:
             config.pop(key)
         return pickle.dumps(config, protocol=2)
 
+    def save_config_portable(self) -> Dict[str, Any]:
+        """Convert config to portable format"""
+        config: Dict[str, Any] = {}
+        for key in sorted(self._config):
+            if key.startswith("_"):
+                continue
+            if any(
+                key.startswith(e) for e in self._config["_cache_config_ignore_prefix"]
+            ):
+                continue
+            config[key] = self._config[key]
+        return config
+
     def codegen_config(self) -> str:
         """Convert config to Python statements that replicate current config.
         This does NOT include config settings that are at default values.
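
A hedged usage sketch of the new method (assumes a PyTorch build containing this change, with config options at their picklable defaults): the returned plain dict excludes ignored prefixes and can be hashed deterministically, since keys are inserted in sorted order.

import hashlib
import pickle

from torch._inductor import config

portable = config.save_config_portable()  # plain Dict[str, Any]
assert not any(k.startswith("trace") for k in portable)
assert "cuda.cutlass_dir" not in portable

digest = hashlib.sha256(pickle.dumps(portable, protocol=2)).hexdigest()
print("config component of the cache key:", digest[:16])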
1 change: 1 addition & 0 deletions torch/utils/_config_typing.pyi
@@ -23,6 +23,7 @@ Note that the import should happen before the call to install_config_module(), o
 assert TYPE_CHECKING, "Do not use at runtime"
 
 def save_config() -> bytes: ...
+def save_config_portable() -> Dict[str, Any]: ...
 def codegen_config() -> str: ...
 def get_hash() -> bytes: ...
 def to_dict() -> Dict[str, Any]: ...
