Update base for Update on "[inductor] Clear cache on ctx manager exit"
FIXES #126128.

Right now, we only clear the cache on ctx manager enter, so the state is left bad after exit unless we call fresh_inductor_cache again, which is usually fine in tests.

Cue the compiled autograd tests when going from TestCompiledAutograd -> TestAutogradWithCompiledAutograd: TestCompiledAutograd uses the ctx manager, but TestAutogradWithCompiledAutograd doesn't.
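To illustrate the intended behavior, here is a minimal sketch of a fresh-cache ctx manager that also tears the cache down on exit (hypothetical names and env var, not the actual fresh_inductor_cache implementation):

import contextlib
import os
import shutil
import tempfile

@contextlib.contextmanager
def fresh_cache_sketch(env_var: str = "SKETCH_CACHE_DIR"):
    # replace the cache with a fresh temporary directory on enter
    tmp_dir = tempfile.mkdtemp()
    prev = os.environ.get(env_var)
    os.environ[env_var] = tmp_dir
    try:
        yield tmp_dir
    finally:
        # the fix: also clear on exit, so code that runs afterwards without
        # the ctx manager (e.g. other tests) doesn't inherit stale cache state
        if prev is None:
            os.environ.pop(env_var, None)
        else:
            os.environ[env_var] = prev
        shutil.rmtree(tmp_dir, ignore_errors=True)

With this shape, tests that run after the ctx manager without using it start from a clean slate instead of the leftover temporary cache.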

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang

[ghstack-poisoned]
xmfan committed May 14, 2024
1 parent b404ae5 commit 42eb91a
Showing 3 changed files with 6 additions and 9 deletions.
11 changes: 4 additions & 7 deletions test/inductor/test_compiled_autograd.py
@@ -1,5 +1,6 @@
 # Owner(s): ["module: inductor"]
 import functools
+import logging
 import re
 import sys
 import unittest
@@ -51,11 +52,11 @@ def hook3(gI, gO):
 
 
 class TestCompiledAutograd(TestCase):
-    def setUp(self):
+    def setUp(self) -> None:
         super().setUp()
         compiled_autograd.reset()
 
-    def tearDown(self):
+    def tearDown(self) -> None:
         super().tearDown()
         compiled_autograd.reset()
 
@@ -330,8 +331,6 @@ def bytecode_hook(code, out_code):
         handle.remove()
 
     def test_inputs_aliasing_bytecode_stack_restore(self):
-        import logging
-
         logging.getLogger().setLevel(logging.WARNING)
         from torch.testing._internal.logging_tensor import LoggingTensor
 
@@ -764,7 +763,7 @@ def backward(ctx, gO_1, gO_2, gO_3):
         self.check_output_and_recompiles(fn, count=2)
 
     @unittest.skipIf(not HAS_CUDA, "requires cuda")
-    def test_logging_tensor_flaky(self):
+    def test_logging_tensor_flaky(self) -> None:
         # when you first run some test using triton and then run test_inputs_aliasing_bytecode_stack_restore
         # resulting in:
         # - pytest: `TypeError: unsupported operand type(s) for +: 'Tensor' and 'LoggingTensor'`
@@ -785,8 +784,6 @@ def _fn(x):
         with compiled_autograd.enable(compiler_fn):
             fn()
 
-        import logging
-
         logging.getLogger().setLevel(
             logging.WARNING
         ) # triton setup overwrote it to INFO
2 changes: 1 addition & 1 deletion torch/_dynamo/compiled_autograd.py
@@ -322,7 +322,7 @@ def disable():
 
 
 # return to starting state of a new process
-def reset():
+def reset() -> None:
     compiled_autograd_enable = False
     assert compiled_autograd_enabled_count == 0
     torch._C._dynamo.compiled_autograd.set_autograd_compiler(None)
2 changes: 1 addition & 1 deletion torch/testing/_internal/logging_tensor.py
@@ -136,7 +136,7 @@ def emit(self, record):
         if self.tracebacks_list is not None:
             self.tracebacks_list.append(record.traceback)
 
-def log_input(name: str, var: object):
+def log_input(name: str, var: object) -> None:
     logger.info("input", (name,), {}, var) # noqa: PLE1205
 
 class GatherTraceback(logging.Filter):
