Nicer logs for dynamic shapes #99277

Closed
Changes from 2 commits
29 changes: 25 additions & 4 deletions torch/_dynamo/variables/builder.py
@@ -4,12 +4,15 @@
import enum
import functools
import inspect

import logging
import operator
import re
import types
from typing import Any, List, NamedTuple, Optional, Union

import torch
import torch._logging

from torch import SymInt
from torch._guards import GuardSource
@@ -105,6 +108,7 @@
)
from .user_defined import UserDefinedClassVariable, UserDefinedObjectVariable

log = logging.getLogger(__name__)

DimList = List

@@ -1176,19 +1180,37 @@ def wrap_to_fake_tensor_and_record(
# If there is no entry for this source, add the tensor to frame state with its current static size.
# E.g., {} -> {"x": [2, 4]}
curr_sizes = list(e.size())
log.debug("Registered static shapes %s for %s", curr_sizes, name)

Collaborator (Author):
This one gets very noisy, so maybe we keep it at debug and make the others info? Open to bikeshedding. @ezyang, please opine.

Contributor:
It is hard to say without actually using the logs. Keeping everything at debug is a safe starting point, assuming this doesn't trigger too much (I don't think it does, since we already debug-trace every opcode in dynamo). Note that you can selectively turn up logging for ONLY this module with TORCH_LOGS=+torch._dynamo.variables.builder; a sketch follows below.
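
For concreteness, a minimal sketch of that selective bump from plain Python, assuming only the standard library logging hierarchy that dynamo's module-level loggers hang off; the TORCH_LOGS form quoted above is the supported route, and torch._logging manages handlers when it is used.

import logging

# Ensure some handler exists for an ad-hoc run; TORCH_LOGS normally takes care of this.
logging.basicConfig()

# Raise only this module's logger to DEBUG; the rest of dynamo keeps its level.
logging.getLogger("torch._dynamo.variables.builder").setLevel(logging.DEBUG)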

Contributor:
I would bikeshed the message here though; "registered" isn't very descriptive. In the log message, I think I probably want the words frame_state in it. Maybe frame_state[name] = curr_sizes. But in that case, why don't we move the log message down to where we actually do the setter, and then delete all the other branches?

else:
curr_sizes = tx.output.frame_state[name]
if curr_sizes is not None:
curr_ndim = len(curr_sizes)
if e.ndim != len(curr_sizes):
# If there is already an entry, and the dim mismatches, replace the frame state entry with None.
# E.g. {"x": [2, 3, 4]} -> {"x": None}
curr_sizes = None
log.debug(
"Registered fully dynamic shape for %s due to ndim change %s->%s",
name,
curr_ndim,
e.ndim,
)
else:
# If there is already an entry, and the dim matches, for every size in the frame state which
disagrees with the current static size, replace it with None. E.g., {"x": [2, 3]} -> {"x": [2, None]}
for i, dim in enumerate(curr_sizes):
if e.size()[i] != dim:
curr_sizes[i] = None
for i, curr_dim in enumerate(curr_sizes):
new_dim = e.size()[i]
# None is fine here; we could have seen this and marked it dynamic earlier.
if curr_dim is not None:
if new_dim != curr_dim:
log.debug(
"Registered dynamic dim for %s due to size change change %s->%s at dim %s",
name,
curr_dim,
new_dim,
i,
)
curr_sizes[i] = None

tx.output.frame_state[name] = curr_sizes

@@ -1224,7 +1246,6 @@ def wrap_to_fake_tensor_and_record(

# NB: both static and dynamic have precedence over
automatic_dynamic = curr_sizes is None or curr_sizes[i] is None

# We will process constraints first, as they will imply that we
# have a dynamic dimension
# Precedence: export constraints > eager constraints
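
To make the control flow in the hunk above easier to follow, here is a standalone sketch of the frame-state bookkeeping it performs, with the log emitted at the single point where the entry is written, along the lines the review comment suggests. This is illustrative only: FrameState is modeled as a plain dict, and update_frame_state is a hypothetical helper, not dynamo's actual API.

import logging
from typing import Dict, List, Optional

log = logging.getLogger(__name__)

# Maps a tensor's source name to a per-dim size list (None marks a dim already
# observed to vary) or to None (the whole tensor is treated as dynamic).
FrameState = Dict[str, Optional[List[Optional[int]]]]


def update_frame_state(frame_state: FrameState, name: str, new_sizes: List[int]) -> None:
    curr: Optional[List[Optional[int]]]
    if name not in frame_state:
        # First sighting: record the current sizes as static.
        curr = list(new_sizes)
    else:
        prev = frame_state[name]
        if prev is None or len(prev) != len(new_sizes):
            # Already fully dynamic, or the ndim changed: mark the tensor fully dynamic.
            curr = None
        else:
            # Same ndim: any dim whose size disagrees with the recorded one
            # (or was already None) stays/becomes None.
            curr = [
                old if old is not None and old == new else None
                for old, new in zip(prev, new_sizes)
            ]
    # Single setter site, so one log line covers every branch.
    log.debug("frame_state[%s] = %s", name, curr)
    frame_state[name] = curr

Dims that end up as None here are exactly the ones the second hunk treats as automatic_dynamic.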
4 changes: 2 additions & 2 deletions torch/fx/experimental/symbolic_shapes.py
@@ -1945,10 +1945,10 @@ def hint():
msg = f" {len(error_msgs) + 1}. {msg()}"
error_msgs.append(msg)
if len(error_msgs) > 0:
log.warning("Warning only constraints violated %s", warn_msgs)
log.info("Warning only constraints violated %s", warn_msgs)

Contributor:
Wait, this looks like a botched merge, we shouldn't warn here at all, we're going to raise an error!

raise ConstraintViolationError(f"Constraints violated!\n{error_msgs}")
elif len(warn_msgs) > 0:
log.warning("%s Warning only constraints violated", len(warn_msgs))
log.info("%s Warning only constraints violated", len(warn_msgs))

Contributor:
This probably suppresses too much, but we can figure out a better strategy later (maybe only warn once per frame; see the sketch below).


return exprs

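
Regarding the "only warn once per frame" idea above, one illustrative sketch follows; the frame_key parameter and the module-level cache are hypothetical and not part of symbolic_shapes.py.

import logging
from typing import Set

log = logging.getLogger(__name__)

# Hypothetical cache of frames we have already warned about.
_warned_frames: Set[str] = set()


def warn_once_per_frame(frame_key: str, num_warn_only: int) -> None:
    # Emit the "N Warning only constraints violated" message at most once per
    # frame, instead of on every recompilation of that frame.
    if frame_key in _warned_frames:
        return
    _warned_frames.add(frame_key)
    log.warning("%s Warning only constraints violated", num_warn_only)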