component-level configurable logging for dynamo, inductor, aot (#94858)
Summary:

Adds NNC-like logging that is configured through the `TORCH_LOGS` environment variable.
Examples:
`TORCH_LOGS="dynamo,guards" python script.py` - prints dynamo logs at level INFO, plus the guards of every function that is compiled

`TORCH_LOGS="+dynamo,guards,graph" python script.py` - prints dynamo logs at level DEBUG, plus the guards and graphs (in tabular format) of every graph that is compiled

[More examples with full output](https://gist.github.com/mlazos/b17f474457308ce15e88c91721ac1cce)
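
To see which names are accepted, the registry populated by this PR can be inspected directly. A quick exploratory snippet (`log_registry` and `artifact_names` are internal attributes also used by the new tests, so treat them as unstable):

```python
import torch  # importing torch runs torch._logging._init_logs()

# Internal registry populated by torch._logging._registrations at import time.
registry = torch._logging._internal.log_registry

# Artifact names that can be listed in TORCH_LOGS ("guards", "graph", "bytecode", ...).
print(sorted(registry.artifact_names))
```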

Implementation:
The implementation parses the log settings from the environment, finds any components (aot, dynamo, inductor) or other loggable artifacts (guards, graph, etc.), and generates a log_state object. This object contains the set of enabled artifacts and a mapping from qualified logger name to level. _init_logs then adds handlers to the highest-level (registered) logs and sets any artifact logger to level DEBUG if its artifact is enabled.
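
A minimal sketch of that parsing step, for intuition only (`LogState`, `parse_torch_logs`, and the two lookup tables below are illustrative stand-ins, not the real internals of torch/_logging/_internal.py):

```python
import logging
from dataclasses import dataclass, field
from typing import Dict, Set

# Illustrative registries; the real ones are built by torch._logging._registrations.
COMPONENT_TO_QNAME = {
    "dynamo": "torch._dynamo",
    "aot": "torch._functorch.aot_autograd",
    "inductor": "torch._inductor",
}
ARTIFACT_NAMES = {"guards", "graph", "bytecode", "output_code", "schedule"}


@dataclass
class LogState:
    log_qname_to_level: Dict[str, int] = field(default_factory=dict)
    enabled_artifacts: Set[str] = field(default_factory=set)


def parse_torch_logs(value: str) -> LogState:
    state = LogState()
    for entry in filter(None, value.split(",")):
        verbose = entry.startswith("+")  # "+" bumps the component to DEBUG
        name = entry.lstrip("+")
        if name in COMPONENT_TO_QNAME:
            level = logging.DEBUG if verbose else logging.INFO
            state.log_qname_to_level[COMPONENT_TO_QNAME[name]] = level
        elif name in ARTIFACT_NAMES:
            state.enabled_artifacts.add(name)
        else:
            raise ValueError(f"unknown log or artifact name: {name}")
    return state


# TORCH_LOGS="+dynamo,guards,graph" would produce:
print(parse_torch_logs("+dynamo,guards,graph"))
```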

Note: set_logs is an alternative API for manipulating the log_state; if TORCH_LOGS is set in the environment, the environment settings take priority.
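
For example (assuming the keyword-argument form of `torch._logging.set_logs`, where component names take a logging level and artifact names take a boolean):

```python
import logging
import torch

# Roughly the same effect as TORCH_LOGS="+dynamo,guards" -- but only if TORCH_LOGS
# is unset; when the environment variable is present, its settings win.
torch._logging.set_logs(dynamo=logging.DEBUG, guards=True)
```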

Adding a new log:
To add a new log, a dev should add their log name to torch._logging._registrations (there are examples there already).
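
For illustration, a registration might look roughly like this (the `register_log` helper and the component name `mycomponent` below are a sketch; copy the exact pattern from the existing entries in `_registrations.py`):

```python
# In torch/_logging/_registrations.py (sketch; follow the existing entries).
from torch._logging._internal import register_log

# Hypothetical component: after this, TORCH_LOGS="mycomponent" (or "+mycomponent")
# controls the logger registered under the qualified name "torch._mycomponent".
register_log("mycomponent", "torch._mycomponent")
```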

Adding a new artifact:
To add a new artifact, a dev should add their artifact name to torch._logging._registrations as well.
Additionally, wherever the artifact is logged, `torch._logging.getArtifactLogger(__name__, <artifact_name>)` should be used instead of the standard logging implementation.
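
A sketch of both steps, with a hypothetical artifact name (`getArtifactLogger` is the API named above; the `register_artifact` helper should be checked against the existing entries in `_registrations.py`):

```python
# Step 1: register the artifact name in torch/_logging/_registrations.py.
from torch._logging._internal import register_artifact

register_artifact("my_artifact")  # hypothetical artifact name

# Step 2: at the logging site, use an artifact logger instead of logging.getLogger.
import torch._logging

my_artifact_log = torch._logging.getArtifactLogger(__name__, "my_artifact")
my_artifact_log.debug("only emitted when the 'my_artifact' artifact is enabled")
```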

[design doc](https://docs.google.com/document/d/1ZRfTWKa8eaPq1AxaiHrq4ASTPouzzlPiuquSBEJYwS8/edit#)

Pull Request resolved: #94858
Approved by: https://github.com/ezyang
mlazos authored and pytorchmergebot committed Mar 18, 2023
1 parent 086ce76 commit a1c46e5
Showing 18 changed files with 909 additions and 157 deletions.
1 change: 1 addition & 0 deletions .lintrunner.toml
@@ -862,6 +862,7 @@ include_patterns = [
    'test/test_value_ranges.py',
    'torch/utils/_sympy/interp.py',
    'torch/utils/_sympy/reference.py',
    'torch/_logging/**/*.py',
    'torch/nn/parallel/distributed.py',
]
command = [
153 changes: 153 additions & 0 deletions test/dynamo/test_logging.py
@@ -0,0 +1,153 @@
# Owner(s): ["module: dynamo"]
import contextlib
import functools
import logging
import unittest.mock

import torch
import torch._dynamo.test_case
import torch._dynamo.testing

from torch.testing._internal.inductor_utils import HAS_CUDA
from torch.testing._internal.logging_utils import (
    LoggingTestCase,
    make_logging_test,
    make_settings_test,
)

requires_cuda = functools.partial(unittest.skipIf, not HAS_CUDA, "requires cuda")


def example_fn(a):
    output = a.mul(torch.ones(1000, 1000))
    output = output.add(torch.ones(1000, 1000))
    return output


def dynamo_error_fn(a):
    output = a.mul(torch.ones(1000, 1000))
    output = output.add(torch.ones(10, 10))
    return output


def inductor_error_fn(a):
    output = torch.round(a)
    return output


def inductor_schedule_fn(a):
    output = a.add(torch.ones(1000, 1000, device="cuda"))
    return output


ARGS = (torch.ones(1000, 1000, requires_grad=True),)


def multi_record_test(num_records, **kwargs):
    @make_logging_test(**kwargs)
    def fn(self, records):
        fn_opt = torch._dynamo.optimize("inductor")(example_fn)
        fn_opt(*ARGS)
        self.assertEqual(len(records), num_records)

    return fn


def within_range_record_test(num_records_lower, num_records_higher, **kwargs):
    @make_logging_test(**kwargs)
    def fn(self, records):
        fn_opt = torch._dynamo.optimize("inductor")(example_fn)
        fn_opt(*ARGS)
        self.assertGreaterEqual(len(records), num_records_lower)
        self.assertLessEqual(len(records), num_records_higher)

    return fn


def single_record_test(**kwargs):
    return multi_record_test(1, **kwargs)


class LoggingTests(LoggingTestCase):
    test_bytecode = multi_record_test(2, bytecode=True)
    test_output_code = multi_record_test(1, output_code=True)

    @requires_cuda()
    @make_logging_test(schedule=True)
    def test_schedule(self, records):
        fn_opt = torch._dynamo.optimize("inductor")(inductor_schedule_fn)
        fn_opt(torch.ones(1000, 1000, device="cuda"))
        self.assertGreater(len(records), 0)
        self.assertLess(len(records), 5)

    test_dynamo_debug = within_range_record_test(30, 50, dynamo=logging.DEBUG)
    test_dynamo_info = within_range_record_test(2, 10, dynamo=logging.INFO)

    @make_logging_test(dynamo=logging.ERROR)
    def test_dynamo_error(self, records):
        try:
            fn_opt = torch._dynamo.optimize("inductor")(dynamo_error_fn)
            fn_opt(*ARGS)
        except Exception:
            pass
        self.assertEqual(len(records), 1)

    test_aot = within_range_record_test(2, 6, aot=logging.INFO)
    test_inductor_debug = within_range_record_test(3, 15, inductor=logging.DEBUG)
    test_inductor_info = within_range_record_test(2, 4, inductor=logging.INFO)

    @make_logging_test(dynamo=logging.ERROR)
    def test_inductor_error(self, records):
        exitstack = contextlib.ExitStack()
        import torch._inductor.lowering

        def throw(x):
            raise AssertionError()

        # inject an error in the lowerings
        dict_entries = {}
        for x in list(torch._inductor.lowering.lowerings.keys()):
            if "round" in x.__name__:
                dict_entries[x] = throw

        exitstack.enter_context(
            unittest.mock.patch.dict(torch._inductor.lowering.lowerings, dict_entries)
        )

        try:
            fn_opt = torch._dynamo.optimize("inductor")(inductor_error_fn)
            fn_opt(*ARGS)
        except Exception:
            pass
        self.assertEqual(len(records), 1)
        self.assertIsInstance(records[0].msg, str)

        exitstack.close()

    # check that logging to a child log of a registered logger
    # does not register it and result in duplicated records
    @make_settings_test("torch._dynamo.output_graph")
    def test_open_registration_with_registered_parent(self, records):
        logger = logging.getLogger("torch._dynamo.output_graph")
        logger.info("hi")
        self.assertEqual(len(records), 1)

    # check logging to a random log that is not a child log of a registered
    # logger registers it and sets handlers properly
    @make_settings_test("torch.utils")
    def test_open_registration(self, records):
        logger = logging.getLogger("torch.utils")
        logger.info("hi")
        self.assertEqual(len(records), 1)


# single record tests
exclusions = {"bytecode", "output_code", "schedule"}
for name in torch._logging._internal.log_registry.artifact_names:
    if name not in exclusions:
        setattr(LoggingTests, f"test_{name}", single_record_test(**{name: True}))

if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()
20 changes: 20 additions & 0 deletions test/functorch/test_logging.py
@@ -0,0 +1,20 @@
# Owner(s): ["module: dynamo"]
import torch
from torch.testing._internal.logging_utils import LoggingTestCase, make_logging_test
from torch._functorch.aot_autograd import aot_function
from torch._functorch.compilers import nop
import logging

class TestAOTLogging(LoggingTestCase):

    @make_logging_test(aot=logging.DEBUG)
    def test_logging(self, records):
        def f(x):
            return torch.sin(x)
        compiled_f = aot_function(
            f,
            fw_compiler=nop,
            bw_compiler=nop
        )
        compiled_f(torch.randn(3))
        self.assertGreater(len(records), 0)
4 changes: 4 additions & 0 deletions torch/__init__.py
@@ -1639,3 +1639,7 @@ def _sparse_coo_tensor_unsafe(*args, **kwargs):
        'use torch.sparse_coo_tensor(..., check_invariants=False) instead.')
    kwargs['check_invariants'] = False
    return torch.sparse_coo_tensor(*args, **kwargs)


from . import _logging
_logging._init_logs()
23 changes: 11 additions & 12 deletions torch/_dynamo/config.py
@@ -9,28 +9,24 @@

from .logging import get_loggers_level, set_loggers_level

# log level (levels print what it says + all levels listed below it)
# logging.DEBUG print full traces <-- lowest level + print tracing of every instruction
# logging.INFO print the steps that dynamo is running and optionally, compiled functions + graphs
# logging.WARN print warnings (including graph breaks)
# logging.ERROR print exceptions (and what user code was being processed when it occurred)

# Note (mlazos): This is deprecated and will be removed very soon
# to configure logging for dynamo, aot, and inductor
# use the following API in the torch._logging module
# torch._logging.set_logs(dynamo=<level>, aot=<level>, inductor=<level>)
# or use the environment variable TORCH_LOGS="dynamo,aot,inductor" (use a prefix + to indicate higher verbosity)
# see this design doc for more detailed info
# Design doc: https://docs.google.com/document/d/1ZRfTWKa8eaPq1AxaiHrq4ASTPouzzlPiuquSBEJYwS8/edit#
log_level = property(
    lambda _: get_loggers_level(), lambda _, lvl: set_loggers_level(lvl)
)

# log compiled function + graphs at level INFO
output_code = False

# the name of a file to write the logs to
log_file_name = None

# Verbose will print full stack traces on warnings and errors
verbose = os.environ.get("TORCHDYNAMO_VERBOSE", "0") == "1"

# If true, traced graph outputs will be outputted as Python GraphModule code.
# If false, traced graph outputs will be outputted in tabular form.
output_graph_code = False

# verify the correctness of optimized backend
verify_correctness = False

@@ -59,6 +55,9 @@
    torch._utils.is_compiling: True,
}

# Here for bw compat, will be removed (mlazos)
# see above notes for log_level on how to configure the new logging system
output_code = None

# don't specialize on shapes and strides and put shape ops in graph
dynamic_shapes = os.environ.get("TORCHDYNAMO_DYNAMIC_SHAPES") == "1"
54 changes: 30 additions & 24 deletions torch/_dynamo/convert_frame.py
@@ -7,6 +7,7 @@
from typing import Dict, Optional, Set

import torch
import torch._logging
from torch._guards import tracing
from torch.fx.graph_module import _forward_from_src as original_forward_from_src

@@ -38,15 +39,18 @@
    gen_record_file_name,
    guard_failures,
    increment_frame,
    init_logging,
    is_namedtuple,
    istype,
    orig_code_map,
    reset_graph_break_dup_checker,
    setup_compile_debug,
    troubleshooting_url,
    write_record_to_file,
)

log = logging.getLogger(__name__)
guards_log = torch._logging.getArtifactLogger(__name__, "guards")
bytecode_log = torch._logging.getArtifactLogger(__name__, "bytecode")


class Tracker:
@@ -101,9 +105,11 @@ def _fn(*args, **kwargs):
            cuda_rng_state = torch.cuda.get_rng_state()
        prior_fwd_from_src = torch.fx.graph_module._forward_from_src
        torch.fx.graph_module._forward_from_src = fx_forward_from_src_skip_result
        cleanup = setup_compile_debug()
        try:
            return fn(*args, **kwargs)
        finally:
            cleanup.close()
            torch._C._set_grad_enabled(prior_grad_mode)
            torch.random.set_rng_state(rng_state)
            if torch.cuda.is_available():
@@ -195,7 +201,7 @@ def convert_frame_assert(
    export: bool = False,
):
    """Fully convert a frame into an FX graph"""
    init_logging()
    reset_graph_break_dup_checker()

    def _convert_frame_assert(frame: types.FrameType, cache_size: int, hooks: Hooks):
        increment_frame()
@@ -339,25 +345,26 @@ def transform(instructions, code_options):
                return None
        output_codes.add(out_code)

        if config.output_code:
            log.info(
                format_bytecode(
                    "ORIGINAL BYTECODE",
                    code.co_name,
                    code.co_filename,
                    code.co_firstlineno,
                    code,
                ),
            )
            log.info(
                format_bytecode(
                    "MODIFIED BYTECODE",
                    code.co_name,
                    code.co_filename,
                    code.co_firstlineno,
                    out_code,
                ),
            )
        def log_bytecode(prefix, name, filename, line_no, code):
            if bytecode_log.isEnabledFor(logging.DEBUG):
                bytecode_log.debug(
                    format_bytecode(prefix, name, filename, line_no, code)
                )

        log_bytecode(
            "ORIGINAL BYTECODE",
            code.co_name,
            code.co_filename,
            code.co_firstlineno,
            code,
        )
        log_bytecode(
            "MODIFIED BYTECODE",
            code.co_name,
            code.co_filename,
            code.co_firstlineno,
            out_code,
        )

        assert output is not None
        assert output.guards is not None
@@ -371,12 +378,12 @@ def transform(instructions, code_options):

        guarded_code = GuardedCode(out_code, check_fn.check_fn)

        if config.output_code:
        if guards_log.isEnabledFor(logging.DEBUG):
            guard_str = "GUARDS:\n"
            guard_str += "\n".join(
                [f" - {str(guard)}" for guard in sorted(output.guards)]
            )
            log.info(guard_str)
            guards_log.debug(guard_str)

        if hooks.guard_export_fn is not None:
            hooks.guard_export_fn(output.guards)
@@ -423,7 +430,6 @@ def replay(filename):

    original_replay_val = config.replay_record_enabled
    config.replay_record_enabled = False
    init_logging()
    with open(filename, "rb") as in_file:
        record = ExecutionRecord.load(in_file)
        record.globals = {
10 changes: 8 additions & 2 deletions torch/_dynamo/exc.py
@@ -159,8 +159,14 @@ def format_error_msg(exc, code, record_filename=None, frame=None):
    msg = os.linesep * 2

    if config.verbose:
        msg = format_bytecode(
            "WON'T CONVERT", code.co_name, code.co_filename, code.co_firstlineno, code
        msg = str(
            format_bytecode(
                "WON'T CONVERT",
                code.co_name,
                code.co_filename,
                code.co_firstlineno,
                code,
            )
        )
        msg += "=" * 10 + " TorchDynamo Stack Trace " + "=" * 10 + "\n"
        msg += format_exc()
