[compiled autograd] Fix flaky tests
ghstack-source-id: 9e999edf4e9a1e41c381fdf20063338a6eb2f313
Pull Request resolved: #126144
xmfan committed May 14, 2024
1 parent bd3cbdb · commit f13bfd8
Showing 3 changed files with 66 additions and 3 deletions.
56 changes: 56 additions & 0 deletions test/inductor/test_compiled_autograd.py
@@ -1,5 +1,6 @@
 # Owner(s): ["module: inductor"]
 import functools
+import logging
 import re
 import sys
 import unittest
@@ -51,6 +52,14 @@ def hook3(gI, gO):


 class TestCompiledAutograd(TestCase):
+    def setUp(self) -> None:
+        super().setUp()
+        compiled_autograd.reset()
+
+    def tearDown(self) -> None:
+        super().tearDown()
+        compiled_autograd.reset()
+
     def check_output_and_recompiles(
         self, fn, count=1, compiler_fn=compiler_fn, compile_fn=False
     ):
@@ -322,6 +331,7 @@ def bytecode_hook(code, out_code):
             handle.remove()
 
     def test_inputs_aliasing_bytecode_stack_restore(self):
+        logging.getLogger().setLevel(logging.WARNING)
         from torch.testing._internal.logging_tensor import LoggingTensor
 
         # Create a graph that allows inputs stealing
@@ -752,6 +762,52 @@ def backward(ctx, gO_1, gO_2, gO_3):

         self.check_output_and_recompiles(fn, count=2)
 
+    @unittest.skipIf(not HAS_CUDA, "requires cuda")
+    def test_logging_tensor_flaky(self) -> None:
+        # when you first run some test using triton and then run test_inputs_aliasing_bytecode_stack_restore
+        # resulting in:
+        #   - pytest: `TypeError: unsupported operand type(s) for +: 'Tensor' and 'LoggingTensor'`
+        #   - python: `TypeError: not all arguments converted during string formatting`
+
+        # 1. some triton involving test
+        def fn():
+            def _fn(x):
+                return x
+
+            x = torch.arange(
+                1, 10, requires_grad=True, dtype=torch.float16, device="cuda"
+            )
+            out = _fn(x)
+            loss = out.sum()
+            loss.backward()
+
+        with compiled_autograd.enable(compiler_fn):
+            fn()
+
+        logging.getLogger().setLevel(
+            logging.WARNING
+        )  # triton setup overwrote it to INFO
+        # 2. test_inputs_aliasing_bytecode_stack_restore
+        from torch.testing._internal.logging_tensor import LoggingTensor
+
+        def forward(inputs):
+            add = inputs[0] + 1
+            add_1 = add + inputs[1]
+            out = add_1.cpu()
+            return (out,)
+
+        gm = torch.fx.symbolic_trace(forward)
+        print(gm.print_readable())
+        torch._dynamo.utils.set_locals_to_steal(gm, ["inputs"])
+        compiled_fn = torch.compile(gm)
+
+        inputs = [
+            torch.ones(1000000, dtype=torch.float32),
+            LoggingTensor(torch.ones(1)),
+        ]
+
+        compiled_fn(inputs)
+
     @unittest.skipIf(not HAS_CUDA, "requires cuda")
     def test_custom_fn_output_metadata(self):
         def my_compiler_fn(gm):
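For context on what the new test_logging_tensor_flaky test reproduces: log_input() passes positional args that only LoggingTensorHandler knows how to render, so once the root logger drops to INFO (as the triton setup does) and a stock handler picks the record up, plain %-formatting fails. A minimal standalone sketch of that clash (not part of the diff; it uses only the standard logging module and a placeholder string instead of a real tensor):

# Minimal sketch (not part of the diff) of the formatting clash behind the flake:
# LoggingTensor's log_input() emits records whose positional args only its own
# handler understands, so a default %-style formatter cannot consume them.
import logging

# Roughly the kind of record log_input("x", var) produces on the
# "LoggingTensor" logger: msg has no %-placeholders, but args is non-empty.
record = logging.LogRecord(
    name="LoggingTensor",
    level=logging.INFO,
    pathname="sketch.py",
    lineno=0,
    msg="input",
    args=(("x",), {}, "<tensor>"),
    exc_info=None,
)

try:
    record.getMessage()  # what any default Formatter would eventually call
except TypeError as e:
    print(e)  # not all arguments converted during string formatting

While the root logger stays at WARNING the record is never handled and the mismatch goes unnoticed, which is why both the new test and test_inputs_aliasing_bytecode_stack_restore pin the level back to WARNING.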
7 changes: 7 additions & 0 deletions torch/_dynamo/compiled_autograd.py
@@ -319,3 +319,10 @@ def disable():
         if prior:
             compiled_autograd_enabled = True
         torch._C._dynamo.compiled_autograd.set_autograd_compiler(prior)
+
+
+# return to starting state of a new process
+def reset() -> None:
+    compiled_autograd_enable = False
+    assert compiled_autograd_enabled_count == 0
+    torch._C._dynamo.compiled_autograd.set_autograd_compiler(None)
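A rough usage sketch of the new reset() helper (not part of the diff; the wrapper name and the choice of torch.compile as the compiler are illustrative), mirroring how the setUp/tearDown hooks above bracket each test:

# Rough usage sketch (not part of the diff): run a backward pass under compiled
# autograd and return the process to a clean state afterwards, the same way the
# new setUp/tearDown hooks do around each test.
import torch
from torch._dynamo import compiled_autograd


def run_with_compiled_autograd(fn, compiler_fn=torch.compile):
    compiled_autograd.reset()  # start from the state of a fresh process
    try:
        with compiled_autograd.enable(compiler_fn):
            return fn()
    finally:
        compiled_autograd.reset()  # drop the registered compiler even on failure


def example():
    x = torch.ones(4, requires_grad=True)
    loss = (x * x).sum()
    loss.backward()
    return x.grad


grad = run_with_compiled_autograd(example)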
6 changes: 3 additions & 3 deletions torch/testing/_internal/logging_tensor.py
@@ -11,6 +11,7 @@
 import functools
 from torch._C._profiler import gather_traceback, symbolize_tracebacks
 
+logger = logging.getLogger("LoggingTensor")
 
 _dtype_abbrs = {
     torch.bfloat16: "bf16",
@@ -135,8 +136,8 @@ def emit(self, record):
         if self.tracebacks_list is not None:
             self.tracebacks_list.append(record.traceback)
 
-def log_input(name: str, var: object):
-    logging.getLogger("LoggingTensor").info("input", (name,), {}, var)  # noqa: PLE1205
+def log_input(name: str, var: object) -> None:
+    logger.info("input", (name,), {}, var)  # noqa: PLE1205
 
 class GatherTraceback(logging.Filter):
     def __init__(self, python=True, script=True, cpp=False):
@@ -151,7 +152,6 @@ def filter(self, record):
 @contextlib.contextmanager
 def capture_logs(is_mode=False, python_tb=False, script_tb=False, cpp_tb=False) -> Iterator[List[str]]:
     collect_traceback = python_tb or script_tb or cpp_tb
-    logger = logging.getLogger("LoggingTensor")
     log_list: List[str] = []
     tracebacks_list: List[str] = []
     handler = LoggingTensorHandler(
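A short usage sketch (not part of the diff) of the logging_tensor API after this change; log_input() and capture_logs() now share the module-level "LoggingTensor" logger instead of each resolving it by name:

# Hedged sketch (not part of the diff): exercise LoggingTensor through the
# shared module-level logger, capturing what gets recorded.
import torch
from torch.testing._internal.logging_tensor import (
    LoggingTensor,
    capture_logs,
    log_input,
)

with capture_logs() as logs:
    x = LoggingTensor(torch.ones(2))
    log_input("x", x)
    y = x + x  # dispatched ops are recorded through the shared logger

print("\n".join(logs))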
