diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py
index 810ffa40255a..86929421fc4b 100644
--- a/torch/_inductor/codecache.py
+++ b/torch/_inductor/codecache.py
@@ -522,7 +522,13 @@ def dumps(cls, obj) -> bytes:
         """
         with io.BytesIO() as stream:
             pickler = cls(stream)
-            pickler.dump(obj)
+            try:
+                pickler.dump(obj)
+            except (TypeError, AttributeError) as e:
+                # Some configs options are callables, e.g., post_grad_custom_pre_pass,
+                # and may not pickle.
+                log.warning("Can't pickle", exc_info=True)
+                raise BypassFxGraphCache from e
         return stream.getvalue()
 
     @classmethod
@@ -661,14 +667,7 @@ def __init__(
         # Also hash on various system info (including the triton compiler version).
         self.torch_version = torch_key()
         self.system_info = CacheBase.get_system()
-
-        try:
-            self.inductor_config = config.save_config()
-        except (TypeError, AttributeError) as e:
-            # Some configs options are callables, e.g., post_grad_custom_pre_pass,
-            # and may not pickle.
-            log.debug("Can't pickle inductor config: %s", e)
-            raise BypassFxGraphCache from e
+        self.inductor_config = config.save_config_portable()
 
     def debug_str(self) -> str:
         """
diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py
index a7dc6cd026fe..e15181ce0263 100644
--- a/torch/_inductor/config.py
+++ b/torch/_inductor/config.py
@@ -875,10 +875,19 @@ class trace:
     log_autotuning_results: bool = False
 
 
-_save_config_ignore = {
+_save_config_ignore = [
     # workaround: "Can't pickle <function ...>"
     "trace.upload_tar",
-}
+]
+
+_cache_config_ignore_prefix = [
+    # trace functions are not relevant to config caching
+    "trace",
+    # uses absolute path
+    "cuda.cutlass_dir",
+    # not relevant
+    "compile_threads",
+]
 
 if TYPE_CHECKING:
     from torch.utils._config_typing import *  # noqa: F401, F403
diff --git a/torch/utils/_config_module.py b/torch/utils/_config_module.py
index ef0478535772..f468e2d84890 100644
--- a/torch/utils/_config_module.py
+++ b/torch/utils/_config_module.py
@@ -156,6 +156,19 @@ def save_config(self) -> bytes:
             config.pop(key)
         return pickle.dumps(config, protocol=2)
 
+    def save_config_portable(self) -> Dict[str, Any]:
+        """Convert config to portable format"""
+        config: Dict[str, Any] = {}
+        for key in sorted(self._config):
+            if key.startswith("_"):
+                continue
+            if any(
+                key.startswith(e) for e in self._config["_cache_config_ignore_prefix"]
+            ):
+                continue
+            config[key] = self._config[key]
+        return config
+
     def codegen_config(self) -> str:
         """Convert config to Python statements that replicate current config.
         This does NOT include config settings that are at default values.
diff --git a/torch/utils/_config_typing.pyi b/torch/utils/_config_typing.pyi
index c31eb5f34a59..b2d99e67fabb 100644
--- a/torch/utils/_config_typing.pyi
+++ b/torch/utils/_config_typing.pyi
@@ -23,6 +23,7 @@ Note that the import should happen before the call to install_config_module(), o
 assert TYPE_CHECKING, "Do not use at runtime"
 
 def save_config() -> bytes: ...
+def save_config_portable() -> Dict[str, Any]: ...
 def codegen_config() -> str: ...
 def get_hash() -> bytes: ...
 def to_dict() -> Dict[str, Any]: ...