[compiled autograd] Fix LoggingTensor flaky test #126144

59 changes: 59 additions & 0 deletions test/inductor/test_compiled_autograd.py
@@ -51,6 +51,14 @@ def hook3(gI, gO):


class TestCompiledAutograd(TestCase):
    def setUp(self):
        super().setUp()
        compiled_autograd.reset()

    def tearDown(self):
        super().tearDown()
        compiled_autograd.reset()

    def check_output_and_recompiles(
        self, fn, count=1, compiler_fn=compiler_fn, compile_fn=False
    ):
@@ -322,6 +330,9 @@ def bytecode_hook(code, out_code):
        handle.remove()

    def test_inputs_aliasing_bytecode_stack_restore(self):
        import logging

        logging.getLogger().setLevel(logging.WARNING)
        from torch.testing._internal.logging_tensor import LoggingTensor

        # Create a graph that allows inputs stealing
@@ -752,6 +763,54 @@ def backward(ctx, gO_1, gO_2, gO_3):

        self.check_output_and_recompiles(fn, count=2)

    @unittest.skipIf(not HAS_CUDA, "requires cuda")
    def test_logging_tensor_flaky(self):
        # If a test that uses triton runs first and test_inputs_aliasing_bytecode_stack_restore
        # runs afterwards in the same process, the latter fails with:
        # - pytest: `TypeError: unsupported operand type(s) for +: 'Tensor' and 'LoggingTensor'`
        # - python: `TypeError: not all arguments converted during string formatting`

        # 1. some test involving triton
        def fn():
            def _fn(x):
                return x

            x = torch.arange(
                1, 10, requires_grad=True, dtype=torch.float16, device="cuda"
            )
            out = _fn(x)
            loss = out.sum()
            loss.backward()

        with compiled_autograd.enable(compiler_fn):
            fn()

        import logging

        logging.getLogger().setLevel(
            logging.WARNING
        )  # triton setup overwrote it to INFO
        # 2. test_inputs_aliasing_bytecode_stack_restore
        from torch.testing._internal.logging_tensor import LoggingTensor

        def forward(inputs):
            add = inputs[0] + 1
            add_1 = add + inputs[1]
            out = add_1.cpu()
            return (out,)

        gm = torch.fx.symbolic_trace(forward)
        print(gm.print_readable())
        torch._dynamo.utils.set_locals_to_steal(gm, ["inputs"])
        compiled_fn = torch.compile(gm)

        inputs = [
            torch.ones(1000000, dtype=torch.float32),
            LoggingTensor(torch.ones(1)),
        ]

        compiled_fn(inputs)

    @unittest.skipIf(not HAS_CUDA, "requires cuda")
    def test_custom_fn_output_metadata(self):
        def my_compiler_fn(gm):
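For context on the second error the test comment lists: `log_input` calls `Logger.info` with extra positional arguments that only the custom `LoggingTensorHandler` knows how to consume. Once the triton setup raises the root logger to INFO, those records can reach a stock handler, which tries to %-format the extra arguments into the message string and fails. A minimal standalone sketch of that mechanism (plain Python, not code from this PR):

```python
import logging

# Build a record shaped like the one log_input() produces:
# msg="input" with args=(("x",), {}, "var"), i.e. extra positional
# arguments with no matching % placeholders in the message string.
record = logging.LogRecord(
    name="LoggingTensor",
    level=logging.INFO,
    pathname=__file__,
    lineno=0,
    msg="input",
    args=(("x",), {}, "var"),
    exc_info=None,
)

try:
    # A standard handler formats via record.getMessage(), which
    # evaluates "input" % (("x",), {}, "var").
    record.getMessage()
except TypeError as e:
    print(e)  # not all arguments converted during string formatting
```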
7 changes: 7 additions & 0 deletions torch/_dynamo/compiled_autograd.py
@@ -319,3 +319,10 @@ def disable():
    if prior:
        compiled_autograd_enabled = True
    torch._C._dynamo.compiled_autograd.set_autograd_compiler(prior)


# return to the starting state of a new process
def reset():
    # global is required so the assignment rebinds the module-level flag
    global compiled_autograd_enabled
    compiled_autograd_enabled = False
    assert compiled_autograd_enabled_count == 0
    torch._C._dynamo.compiled_autograd.set_autograd_compiler(None)
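A note on the `global` declaration in `reset()`: without it, the assignment would bind a new local name and leave the module-level flag untouched, making the reset a silent no-op for that flag. A tiny standalone demonstration of the pitfall (plain Python, unrelated to torch):

```python
enabled = True

def broken_reset():
    enabled = False  # binds a new local; the module-level flag is untouched

def fixed_reset():
    global enabled
    enabled = False  # rebinds the module-level flag

broken_reset()
assert enabled is True   # the "reset" had no effect
fixed_reset()
assert enabled is False  # now it does
```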
4 changes: 2 additions & 2 deletions torch/testing/_internal/logging_tensor.py
@@ -11,6 +11,7 @@
import functools
from torch._C._profiler import gather_traceback, symbolize_tracebacks

+logger = logging.getLogger("LoggingTensor")

_dtype_abbrs = {
    torch.bfloat16: "bf16",
@@ -136,7 +137,7 @@ def emit(self, record):
        self.tracebacks_list.append(record.traceback)

def log_input(name: str, var: object):
-    logging.getLogger("LoggingTensor").info("input", (name,), {}, var)  # noqa: PLE1205
+    logger.info("input", (name,), {}, var)  # noqa: PLE1205

class GatherTraceback(logging.Filter):
    def __init__(self, python=True, script=True, cpp=False):

@@ -151,7 +152,6 @@ def filter(self, record):
@contextlib.contextmanager
def capture_logs(is_mode=False, python_tb=False, script_tb=False, cpp_tb=False) -> Iterator[List[str]]:
    collect_traceback = python_tb or script_tb or cpp_tb
-    logger = logging.getLogger("LoggingTensor")
    log_list: List[str] = []
    tracebacks_list: List[str] = []
    handler = LoggingTensorHandler(
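The logging_tensor.py change is a refactor rather than a behavior change: `logging.getLogger(name)` returns the same cached singleton for a given name on every call, so hoisting the lookup to module scope simply deduplicates it between `log_input` and `capture_logs`. A quick demonstration of that idempotence:

```python
import logging

# getLogger returns one cached Logger per name, so a module-level
# `logger = logging.getLogger("LoggingTensor")` and a lookup inside a
# function refer to the exact same object.
a = logging.getLogger("LoggingTensor")
b = logging.getLogger("LoggingTensor")
assert a is b
```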