Make 'pytest test/inductor/test_memory_planning.py' work (#126397)
There's still another naughty direct test_* import; I'm out of patience
right now, though.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: #126397
Approved by: https://github.com/peterbell10, https://github.com/int3
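
For context, the offending pattern and its replacement, condensed from the hunks below (a side-by-side sketch of the import change, not an extra change):

# Before: a direct import from a sibling test_* module, the "naughty"
# pattern the commit message refers to, which the pytest invocation trips on.
from test_torchinductor import run_and_get_cpp_code

# After: the helper now lives in the library proper, so the test module
# imports it like any other torch._inductor utility.
from torch._inductor.utils import run_and_get_cpp_code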
ezyang authored and ZelboK committed May 19, 2024
1 parent 4c93c7a commit f1897d4
Showing 3 changed files with 27 additions and 26 deletions.
6 changes: 3 additions & 3 deletions test/inductor/test_memory_planning.py
@@ -2,6 +2,8 @@

 import sys

+import unittest
+
 from torch.testing._internal.common_utils import IS_CI, IS_WINDOWS, skipIfRocm
 from torch.testing._internal.inductor_utils import HAS_CUDA
@@ -13,14 +15,12 @@
     sys.exit(0)
     raise unittest.SkipTest("requires sympy/functorch/filelock")  # noqa: F821

-import unittest
-
 import torch
-from test_torchinductor import run_and_get_cpp_code
 from torch._C import FileCheck
 from torch._dynamo.utils import same
 from torch._inductor import config
 from torch._inductor.test_case import run_tests, TestCase
+from torch._inductor.utils import run_and_get_cpp_code
 from torch.export import Dim
 from torch.utils._triton import has_triton
24 changes: 1 addition & 23 deletions test/inductor/test_torchinductor.py
@@ -46,6 +46,7 @@
     aoti_eager_cache_dir,
     load_aoti_eager_cache,
     run_and_get_code,
+    run_and_get_cpp_code,
     run_and_get_triton_code,
 )
 from torch._inductor.virtualized import V
@@ -342,29 +343,6 @@ def clone_preserve_strides(x, device=None):
     return out


-def run_and_get_cpp_code(fn, *args, **kwargs):
-    # We use the patch context manager instead of using it as a decorator.
-    # In this way, we can ensure that the attribute is patched and unpatched correctly
-    # even if this run_and_get_cpp_code function is called multiple times.
-    with patch.object(config, "debug", True):
-        torch._dynamo.reset()
-        import io
-        import logging
-
-        log_capture_string = io.StringIO()
-        ch = logging.StreamHandler(log_capture_string)
-        from torch._inductor.graph import output_code_log
-
-        output_code_log.addHandler(ch)
-        prev_level = output_code_log.level
-        output_code_log.setLevel(logging.DEBUG)
-        result = fn(*args, **kwargs)
-        s = log_capture_string.getvalue()
-        output_code_log.setLevel(prev_level)
-        output_code_log.removeHandler(ch)
-        return result, s
-
-
 def check_model(
     self: TestCase,
     model,
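
One detail worth noting: the deleted copy used the bare name `patch` (presumably `from unittest.mock import patch` at the top of test_torchinductor.py), while the relocated copy added to torch/_inductor/utils.py below spells out `unittest.mock.patch.object`, so utils.py needs no extra import of the `patch` name. The function body is otherwise unchanged.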
23 changes: 23 additions & 0 deletions torch/_inductor/utils.py
@@ -1696,3 +1696,26 @@ def aoti_compile_with_persistent_cache(
         return kernel_lib_path
     except Exception as e:
         return ""
+
+
+def run_and_get_cpp_code(fn, *args, **kwargs):
+    # We use the patch context manager instead of using it as a decorator.
+    # In this way, we can ensure that the attribute is patched and unpatched correctly
+    # even if this run_and_get_cpp_code function is called multiple times.
+    with unittest.mock.patch.object(config, "debug", True):
+        torch._dynamo.reset()
+        import io
+        import logging
+
+        log_capture_string = io.StringIO()
+        ch = logging.StreamHandler(log_capture_string)
+        from torch._inductor.graph import output_code_log
+
+        output_code_log.addHandler(ch)
+        prev_level = output_code_log.level
+        output_code_log.setLevel(logging.DEBUG)
+        result = fn(*args, **kwargs)
+        s = log_capture_string.getvalue()
+        output_code_log.setLevel(prev_level)
+        output_code_log.removeHandler(ch)
+        return result, s
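
A minimal usage sketch of the relocated helper, not part of the commit; the toy function and the `cpp_wrapper` setting are illustrative assumptions (the helper itself just captures whatever `output_code_log` emits while `fn` runs, which is C++ source when the C++ wrapper codegen path is active):

import torch
from torch._inductor import config
from torch._inductor.utils import run_and_get_cpp_code

def add_one(x):
    return x + 1

# Compile under the C++ wrapper so the logged output code is C++ rather
# than Triton/Python; run once and capture both result and generated source.
with config.patch(cpp_wrapper=True):
    result, cpp_code = run_and_get_cpp_code(torch.compile(add_one), torch.ones(4))
print(cpp_code)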
