diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py
index 446cfa9fa8daf..2f7feb0752cc2 100644
--- a/test/dynamo/test_misc.py
+++ b/test/dynamo/test_misc.py
@@ -26,9 +26,9 @@
 import numpy as np
 
 import torch
-
-import torch._dynamo.test_case
 import torch._dynamo.testing
+
+import torch._inductor.test_case
 import torch.onnx.operators
 
 import torch.utils._pytree as pytree
@@ -151,7 +151,7 @@ def __getattr__(self, key):
         return self.__dict__[f"pfx_{key}"]
 
 
-class MiscTests(torch._dynamo.test_case.TestCase):
+class MiscTests(torch._inductor.test_case.TestCase):
     def test_get_cache_entry(self):
         def f(x):
             return x + 1
diff --git a/test/dynamo/test_structured_trace.py b/test/dynamo/test_structured_trace.py
index deb2a2d548fe5..07f541edbe236 100644
--- a/test/dynamo/test_structured_trace.py
+++ b/test/dynamo/test_structured_trace.py
@@ -16,10 +16,11 @@
 import torch._logging.structured
 import torch.distributed as dist
 
+from torch._inductor.test_case import TestCase
+
 from torch._logging._internal import TorchLogsFormatter
 from torch.nn.parallel import DistributedDataParallel as DDP
-
-from torch.testing._internal.common_utils import find_free_port, TestCase
+from torch.testing._internal.common_utils import find_free_port
 from torch.testing._internal.inductor_utils import HAS_CUDA
 
 requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda")
diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py
index 31d83b3172ec6..7578dff26438e 100644
--- a/torch/_inductor/codecache.py
+++ b/torch/_inductor/codecache.py
@@ -482,6 +482,14 @@ def _reduce_symint(s):
     return (_ident, (str(s),))
 
 
+def _reduce_unsupported(s):
+    """
+    See FxGraphCachePickler. Custom reducer to handle any objects that we don't
+    support and therefore raise to bypass caching.
+    """
+    raise BypassFxGraphCache
+
+
 class FxGraphCachePickler(pickle.Pickler):
     """
     Custom pickler to customize the pickling of some objects (Tensors), only for the
@@ -494,6 +502,9 @@ class FxGraphCachePickler(pickle.Pickler):
     dispatch_table[FakeTensor] = _reduce_fake_tensor
     dispatch_table[torch.Tensor] = _reduce_tensor
     dispatch_table[torch.SymInt] = _reduce_symint
+    dispatch_table[
+        torch.fx.experimental._backward_state.BackwardState
+    ] = _reduce_unsupported
 
     @classmethod
     def dumps(cls, obj) -> bytes:
@@ -893,7 +904,6 @@ def load(
         Load a compiled graph from the cache. If a cached entry does not exist,
         compile the graph and save it to the cache.
         """
-        compiled_graph = None
 
         try:
             FxGraphCache._check_can_cache(gm)
diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py
index 0630614fb3247..748737a9e1f94 100644
--- a/torch/fx/experimental/symbolic_shapes.py
+++ b/torch/fx/experimental/symbolic_shapes.py
@@ -2385,6 +2385,7 @@ def check_equal(self, other: "ShapeEnv") -> None:
             "source_name_to_debug_name",
             "_prev_cache_key",
             "_version_counter",
+            "dim_constraints",
         )
 
         # Mapping of the value of each to-be-compared field into the values that
diff --git a/torch/testing/_internal/logging_utils.py b/torch/testing/_internal/logging_utils.py
index 8bf762a6577e8..f97d0281b139b 100644
--- a/torch/testing/_internal/logging_utils.py
+++ b/torch/testing/_internal/logging_utils.py
@@ -7,6 +7,7 @@
 import torch._logging
 import torch._logging._internal
 from torch._dynamo.utils import LazyString
+from torch._inductor import config as inductor_config
 
 import logging
 import io
@@ -74,6 +75,7 @@ def append_setting(name, level):
 
 # that the logs are setup correctly and capturing the correct records.
 def make_logging_test(**kwargs):
     def wrapper(fn):
+        @inductor_config.patch({"fx_graph_cache": False})
         def test_fn(self):
             torch._dynamo.reset()
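
Illustrative sketch (not part of the patch above): the codecache.py hunks register _reduce_unsupported in FxGraphCachePickler.dispatch_table so that trying to pickle a BackwardState raises BypassFxGraphCache, which the caching layer catches to fall back to compiling without the cache. The standalone example below mimics that dispatch_table-based bypass with hypothetical names (BypassCache, Unsupported, CachePickler, cache_key); it is a minimal sketch of the pattern, not the real torch._inductor.codecache code.

import copyreg
import io
import pickle


class BypassCache(Exception):
    """Raised to signal that an object cannot be part of a cache key."""


class Unsupported:
    """Stand-in for an object we refuse to hash (e.g. mutable runtime state)."""


def _reduce_unsupported(obj):
    # Registered in the dispatch_table below; raising here aborts pickling.
    raise BypassCache


class CachePickler(pickle.Pickler):
    # Documented pickle pattern: a subclass-level dispatch_table overrides how
    # specific types are reduced.
    dispatch_table = copyreg.dispatch_table.copy()
    dispatch_table[Unsupported] = _reduce_unsupported

    @classmethod
    def dumps(cls, obj) -> bytes:
        with io.BytesIO() as stream:
            cls(stream).dump(obj)
            return stream.getvalue()


def cache_key(obj):
    try:
        return CachePickler.dumps(obj)  # hashable byte string on success
    except BypassCache:
        return None  # caller skips the cache entirely


print(cache_key({"x": 1}) is not None)   # True: picklable, cacheable
print(cache_key({"x": Unsupported()}))   # None: bypassed

Raising from the reducer aborts hashing as soon as an unsupported object is reached anywhere inside a nested structure, which is simpler than threading a sentinel value back through the pickling recursion.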