component-level configurable logging for dynamo, inductor, aot (#94858)
Summary: Adds NNC-like logging that is configured through the env var `TORCH_LOGS`.

Examples:
- `TORCH_LOGS="dynamo,guards" python script.py` prints dynamo logs at level INFO, along with the guards of all functions that are compiled.
- `TORCH_LOGS="+dynamo,guards,graph" python script.py` prints dynamo logs at level DEBUG, along with the guards and graphs (in tabular format) of all graphs that are compiled.

[More examples with full output](https://gist.github.com/mlazos/b17f474457308ce15e88c91721ac1cce)

Implementation: The implementation parses the log settings from the environment, finds any components (aot, dynamo, inductor) or other loggable objects (guards, graph, etc.), and generates a log_state object. This object contains all of the enabled artifacts and a qualified-log-name -> level mapping. _init_logs then adds handlers to the highest-level logs (the registered logs) and sets any artifact loggers to level DEBUG if the artifact is enabled.

Note: set_logs is an alternative API for manipulating the log_state, but if the environment contains TORCH_LOGS, the environment settings are prioritized.

Adding a new log: To add a new log, a dev should add their log name to torch._logging._registrations (there are examples there already).

Adding a new artifact: To add a new artifact, a dev should add their artifact name to torch._logging._registrations as well. Additionally, wherever the artifact is logged, `torch._logging.getArtifactLogger(__name__, <artifact_name>)` should be used instead of the standard logging implementation.

[design doc](https://docs.google.com/document/d/1ZRfTWKa8eaPq1AxaiHrq4ASTPouzzlPiuquSBEJYwS8/edit#)

Pull Request resolved: #94858
Approved by: https://github.com/ezyang
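For illustration, a minimal sketch of the `set_logs` path described above (the exact keyword set is derived from the registrations, so treat the arguments here as representative rather than exhaustive):

```python
import logging

import torch
import torch._dynamo
import torch._logging

# Components take a logging level; artifacts take a boolean.
# Roughly equivalent to running with TORCH_LOGS="+dynamo,guards".
torch._logging.set_logs(dynamo=logging.DEBUG, guards=True)

@torch._dynamo.optimize("eager")
def f(x):
    return x.sin()

f(torch.randn(4))  # dynamo debug logs and guards print during compilation
```

As noted above, if `TORCH_LOGS` is set in the environment, those settings take precedence over this call.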
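Similarly, a sketch of the artifact-logging pattern for a new artifact. Assume a hypothetical artifact name `"my_artifact"` has already been added to torch._logging._registrations; the logging site would then look like:

```python
import torch._logging

# getArtifactLogger replaces logging.getLogger at the logging site;
# records only flow when the artifact is enabled, e.g. TORCH_LOGS="my_artifact".
# "my_artifact" is a hypothetical name used purely for illustration.
my_artifact_log = torch._logging.getArtifactLogger(__name__, "my_artifact")


def lower_graph(gm):
    # Emitted only when the artifact is enabled via TORCH_LOGS or set_logs.
    my_artifact_log.debug("about to lower: %s", gm)
```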
1 parent 086ce76 · commit a1c46e5 · 18 changed files with 909 additions and 157 deletions.
@@ -0,0 +1,153 @@
# Owner(s): ["module: dynamo"]
import contextlib
import functools
import logging
import unittest.mock

import torch
import torch._dynamo.test_case
import torch._dynamo.testing

from torch.testing._internal.inductor_utils import HAS_CUDA
from torch.testing._internal.logging_utils import (
    LoggingTestCase,
    make_logging_test,
    make_settings_test,
)

requires_cuda = functools.partial(unittest.skipIf, not HAS_CUDA, "requires cuda")


def example_fn(a):
    output = a.mul(torch.ones(1000, 1000))
    output = output.add(torch.ones(1000, 1000))
    return output


def dynamo_error_fn(a):
    output = a.mul(torch.ones(1000, 1000))
    output = output.add(torch.ones(10, 10))
    return output


def inductor_error_fn(a):
    output = torch.round(a)
    return output


def inductor_schedule_fn(a):
    output = a.add(torch.ones(1000, 1000, device="cuda"))
    return output


ARGS = (torch.ones(1000, 1000, requires_grad=True),)


def multi_record_test(num_records, **kwargs):
    @make_logging_test(**kwargs)
    def fn(self, records):
        fn_opt = torch._dynamo.optimize("inductor")(example_fn)
        fn_opt(*ARGS)
        self.assertEqual(len(records), num_records)

    return fn


def within_range_record_test(num_records_lower, num_records_higher, **kwargs):
    @make_logging_test(**kwargs)
    def fn(self, records):
        fn_opt = torch._dynamo.optimize("inductor")(example_fn)
        fn_opt(*ARGS)
        self.assertGreaterEqual(len(records), num_records_lower)
        self.assertLessEqual(len(records), num_records_higher)

    return fn


def single_record_test(**kwargs):
    return multi_record_test(1, **kwargs)


class LoggingTests(LoggingTestCase):
    test_bytecode = multi_record_test(2, bytecode=True)
    test_output_code = multi_record_test(1, output_code=True)

    @requires_cuda()
    @make_logging_test(schedule=True)
    def test_schedule(self, records):
        fn_opt = torch._dynamo.optimize("inductor")(inductor_schedule_fn)
        fn_opt(torch.ones(1000, 1000, device="cuda"))
        self.assertGreater(len(records), 0)
        self.assertLess(len(records), 5)

    test_dynamo_debug = within_range_record_test(30, 50, dynamo=logging.DEBUG)
    test_dynamo_info = within_range_record_test(2, 10, dynamo=logging.INFO)

    @make_logging_test(dynamo=logging.ERROR)
    def test_dynamo_error(self, records):
        try:
            fn_opt = torch._dynamo.optimize("inductor")(dynamo_error_fn)
            fn_opt(*ARGS)
        except Exception:
            pass
        self.assertEqual(len(records), 1)

    test_aot = within_range_record_test(2, 6, aot=logging.INFO)
    test_inductor_debug = within_range_record_test(3, 15, inductor=logging.DEBUG)
    test_inductor_info = within_range_record_test(2, 4, inductor=logging.INFO)

    @make_logging_test(dynamo=logging.ERROR)
    def test_inductor_error(self, records):
        exitstack = contextlib.ExitStack()
        import torch._inductor.lowering

        def throw(x):
            raise AssertionError()

        # inject an error in the lowerings
        dict_entries = {}
        for x in list(torch._inductor.lowering.lowerings.keys()):
            if "round" in x.__name__:
                dict_entries[x] = throw

        exitstack.enter_context(
            unittest.mock.patch.dict(torch._inductor.lowering.lowerings, dict_entries)
        )

        try:
            fn_opt = torch._dynamo.optimize("inductor")(inductor_error_fn)
            fn_opt(*ARGS)
        except Exception:
            pass
        self.assertEqual(len(records), 1)
        self.assertIsInstance(records[0].msg, str)

        exitstack.close()

    # check that logging to a child log of a registered logger
    # does not register it and result in duplicated records
    @make_settings_test("torch._dynamo.output_graph")
    def test_open_registration_with_registered_parent(self, records):
        logger = logging.getLogger("torch._dynamo.output_graph")
        logger.info("hi")
        self.assertEqual(len(records), 1)

    # check logging to a random log that is not a child log of a registered
    # logger registers it and sets handlers properly
    @make_settings_test("torch.utils")
    def test_open_registration(self, records):
        logger = logging.getLogger("torch.utils")
        logger.info("hi")
        self.assertEqual(len(records), 1)


# single record tests
exclusions = {"bytecode", "output_code", "schedule"}
for name in torch._logging._internal.log_registry.artifact_names:
    if name not in exclusions:
        setattr(LoggingTests, f"test_{name}", single_record_test(**{name: True}))

if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()
@@ -0,0 +1,20 @@
# Owner(s): ["module: dynamo"]
import torch
from torch.testing._internal.logging_utils import LoggingTestCase, make_logging_test
from torch._functorch.aot_autograd import aot_function
from torch._functorch.compilers import nop
import logging


class TestAOTLogging(LoggingTestCase):
    @make_logging_test(aot=logging.DEBUG)
    def test_logging(self, records):
        def f(x):
            return torch.sin(x)

        compiled_f = aot_function(
            f,
            fw_compiler=nop,
            bw_compiler=nop,
        )
        compiled_f(torch.randn(3))
        self.assertGreater(len(records), 0)