Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add "all" option to logging #100664

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 3 additions & 0 deletions docs/source/logging.rst
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ The following components and artifacts are configurable through the ``TORCH_LOGS
variable (see torch._logging.set_logs for the python API):

Components:
``all``
Special component which configures the default log level of all components. Default: ``logging.WARN``

``dynamo``
The log level for the TorchDynamo component. Default: ``logging.WARN``

Expand Down
14 changes: 14 additions & 0 deletions test/dynamo/test_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,20 @@ def test_open_registration_python_api(self, records):
logger.info("hi")
self.assertEqual(len(records), 1)

@make_logging_test(all=logging.DEBUG, dynamo=logging.INFO)
def test_all(self, _):
registry = torch._logging._internal.log_registry
state = torch._logging._internal.log_state

dynamo_qname = registry.log_alias_to_log_qname["dynamo"]
for logger_qname in torch._logging._internal.log_registry.get_log_qnames():
logger = logging.getLogger(logger_qname)

if logger_qname == dynamo_qname:
self.assertEqual(logger.level, logging.INFO)
else:
self.assertEqual(logger.level, logging.DEBUG)


# single record tests
exclusions = {
Expand Down
62 changes: 51 additions & 11 deletions torch/_logging/_internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,9 +118,11 @@ def clear(self):

def set_logs(
*,
dynamo: int = DEFAULT_LOG_LEVEL,
aot: int = DEFAULT_LOG_LEVEL,
inductor: int = DEFAULT_LOG_LEVEL,
all: Optional[int] = None,
dynamo: Optional[int] = None,
aot: Optional[int] = None,
dynamic: int = None,
inductor: int = None,
bytecode: bool = False,
aot_graphs: bool = False,
aot_joint_graph: bool = False,
Expand Down Expand Up @@ -169,15 +171,21 @@ def set_logs(
is set to a log level less than or equal to the log level of the artifact.

Keyword args:
dynamo (:class:`int`):
all (:class:`Optional[int]`):
The default log level for all components. Default: ``logging.WARN``

dynamo (:class:`Optional[int]`):
The log level for the TorchDynamo component. Default: ``logging.WARN``

aot (:class:`int`):
aot (:class:`Optional[int]`):
The log level for the AOTAutograd component. Default: ``logging.WARN``

inductor (:class:`int`):
inductor (:class:`Optional[int]`):
The log level for the TorchInductor component. Default: ``logging.WARN``

dynamic (:class:`Optional[int]`):
The log level for dynamic shapes. Default: ``logging.WARN``

bytecode (:class:`bool`):
Whether to emit the original and generated bytecode from TorchDynamo.
Default: ``False``
Expand Down Expand Up @@ -250,7 +258,25 @@ def set_logs(
modules = modules or {}

def _set_logs(**kwargs):
default_level = kwargs.pop("all", None)
if default_level:
if default_level not in logging._levelToName:
raise ValueError(
f"Unrecognized log level for kwarg all: {default_level}, valid level values "
f"are: {','.join([str(k) for k in logging._levelToName.keys()])}"
)

# add any missing aliases to kwargs
for alias in log_registry.log_alias_to_log_qname.keys():
if alias not in kwargs:
kwargs[alias] = default_level
else:
default_level = DEFAULT_LOG_LEVEL

for alias, val in itertools.chain(kwargs.items(), modules.items()): # type: ignore[union-attr]
if val is None:
val = default_level

if log_registry.is_artifact(alias):
if val:
log_state.enable_artifact(alias)
Expand All @@ -260,10 +286,10 @@ def _set_logs(**kwargs):
f"Unrecognized log level for log {alias}: {val}, valid level values "
f"are: {','.join([str(k) for k in logging._levelToName.keys()])}"
)
if val != DEFAULT_LOG_LEVEL:
log_state.enable_log(
log_registry.log_alias_to_log_qname[alias], val
)

log_state.enable_log(log_registry.log_alias_to_log_qname[alias], val)
elif alias == "all":
continue
else:
raise ValueError(
f"Unrecognized log or artifact name passed to set_logs: {alias}"
Expand All @@ -272,9 +298,11 @@ def _set_logs(**kwargs):
_init_logs()

_set_logs(
all=all,
dynamo=dynamo,
aot=aot,
inductor=inductor,
dynamic=dynamic,
bytecode=bytecode,
aot_graphs=aot_graphs,
aot_joint_graph=aot_joint_graph,
Expand Down Expand Up @@ -357,7 +385,9 @@ def _validate_settings(settings):
def _invalid_settings_err_msg(settings):
entities = "\n " + "\n ".join(
itertools.chain(
log_registry.log_alias_to_log_qname.keys(), log_registry.artifact_names
["all"],
log_registry.log_alias_to_log_qname.keys(),
log_registry.artifact_names,
)
)
msg = (
Expand Down Expand Up @@ -392,14 +422,24 @@ def get_name_level_pair(name):
return clean_name, level

log_state = LogState()

for name in log_names:
name, level = get_name_level_pair(name)
if name == "all":
for log_qname in log_registry.get_log_qnames():
log_state.enable_log(log_qname, level)

for name in log_names:
name, level = get_name_level_pair(name)

if log_registry.is_log(name):
assert level is not None
log_qname = log_registry.log_alias_to_log_qname[name]
log_state.enable_log(log_qname, level)
elif log_registry.is_artifact(name):
log_state.enable_artifact(name)
elif name == "all":
continue
elif _is_valid_module(name):
if not _has_registered_parent(name):
log_registry.register_log(name, name)
Expand Down