Skip to content

Commit

Permalink
Enable fx graph cache in torch_test.py when using PYTORCH_TEST_WITH_I…
Browse files Browse the repository at this point in the history
…NDUCTOR=1 (#122010)

Pull Request resolved: #122010
Approved by: https://github.com/eellison
  • Loading branch information
masnesral authored and pytorchmergebot committed Mar 19, 2024
1 parent 18d94d7 commit 6502c88
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion test/test_torch.py
Expand Up @@ -31,7 +31,7 @@
from torch.testing import make_tensor

from torch.testing._internal.common_utils import ( # type: ignore[attr-defined]
TEST_WITH_TORCHINDUCTOR, TestCase, TEST_WITH_ROCM, run_tests, IS_JETSON,
TEST_WITH_TORCHINDUCTOR, TEST_WITH_ROCM, run_tests, IS_JETSON,
IS_WINDOWS, IS_FILESYSTEM_UTF8_ENCODING, NO_MULTIPROCESSING_SPAWN,
IS_SANDCASTLE, IS_FBCODE, IS_REMOTE_GPU, skipIfTorchInductor, load_tests, slowTest, slowTestIf,
TEST_WITH_CROSSREF, skipIfTorchDynamo, skipRocmIfTorchInductor, set_default_dtype,
Expand Down Expand Up @@ -64,6 +64,11 @@
)
from torch.testing._internal.two_tensor import TwoTensor

if TEST_WITH_TORCHINDUCTOR:
from torch._inductor.test_case import TestCase
else:
from torch.testing._internal.common_utils import TestCase # type: ignore[assignment]


# Protects against includes accidentally setting the default dtype
assert torch.get_default_dtype() is torch.float32
Expand Down

0 comments on commit 6502c88

Please sign in to comment.