Print per-tensor guard messages for TENSOR_MATCH (#107562)
The new guard messages look like:

```
check_tensor(L['y'], Tensor, DispatchKeySet(CPU, BackendSelect, ADInplaceOrView, AutogradCPU), torch.float32, device=None, requires_grad=False, size=[3], stride=[1])  # _dynamo/variables/builder.py:1237 in wrap_fx_proxy_cls
```
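To see these messages locally, here is a minimal sketch (assuming the `guards` logging artifact that `guards_log` in this file writes to, reachable via `torch._logging.set_logs` or, equivalently, the `TORCH_LOGS="guards"` environment variable):

```
# Minimal sketch: surface the new per-tensor guard messages.
# Assumes the "guards" logging artifact (the one guards_log in
# torch/_dynamo/guards.py logs to) is available in this build.
import torch
import torch._logging

torch._logging.set_logs(guards=True)

@torch.compile
def f(y):
    return y * 2

f(torch.randn(3))  # the guard log should include a check_tensor(...) line for L['y']
```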

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: #107562
Approved by: https://github.com/anijain2305, https://github.com/jansel
ghstack dependencies: #107505, #107516, #107530, #107532
ezyang authored and pytorchmergebot committed Aug 21, 2023
1 parent 3336aa1 commit ad07a4b
Showing 1 changed file with 32 additions and 4 deletions.
torch/_dynamo/guards.py (32 additions & 4 deletions)
@@ -219,6 +219,7 @@ def __init__(
         # TODO: something here
         self.tensor_check_names: List[str] = []
         self.tensor_check_examples: List[torch.Tensor] = []
+        self.tensor_check_guards: List[Guard] = []

         self.check_fn_manager: CheckFunctionManager = check_fn_manager

@@ -655,6 +656,7 @@ def TENSOR_MATCH(self, guard: Guard, value=None):
         else:
             self.tensor_check_names.append(tensor_name)
             self.tensor_check_examples.append(value)
+            self.tensor_check_guards.append(guard)

         # A frame is valid for reuse with dynamic dimensions if the new dynamic dimensions are a
         # strict subset of the old.
@@ -941,7 +943,7 @@ def compile_check_fn(
         code_parts = ["___guarded_code.valid"]
         base = os.path.dirname(__file__)

-        def add_code_part(code, guard):
+        def add_code_part(code, guard, log_only=False):
             if guards_log.isEnabledFor(logging.DEBUG):
                 extra = ""
                 if guard is not None:
@@ -973,7 +975,8 @@ def add_code_part(code, guard):
                     maybe_user_stack,
                 )

-            code_parts.append(code)
+            if not log_only:
+                code_parts.append(code)

             # TODO: Maybe better not to repeatedly spam the same guard information
             # for each individual piece? Not sure.
@@ -1039,8 +1042,31 @@ def convert(size_or_stride):
             tensor_check_args = ", ".join(
                 tensor_check_names + ["tensor_check_names=tensor_check_names"]
             )
-            # TODO: we can give stack reporting here
-            add_code_part(f"___check_tensors({tensor_check_args})", None)
+            # Do this manually, to un-stagger the guards in the log message
+            code_parts.append(f"___check_tensors({tensor_check_args})")
+            tensor_check_guards = (
+                local_builder.tensor_check_guards + global_builder.tensor_check_guards
+            )
+            for i, name in enumerate(tensor_check_names):
+                # This is a copy of what guards.cpp checks against.
+                # Keep this in sync with the TensorCheck constructor.
+                t = tensor_check_examples[i]
+                pytype = type(t)
+                dispatch_key = (
+                    torch._C._dispatch_keys(t)
+                    | torch._C._dispatch_tls_local_include_set()
+                ) - torch._C._dispatch_tls_local_exclude_set()
+                dtype = t.dtype
+                device_index = t.device.index
+                requires_grad = t.requires_grad
+                sizes = dynamic_dims_sizes[i]
+                strides = dynamic_dims_strides[i]
+                add_code_part(
+                    f"check_tensor({name}, {pytype.__qualname__}, {dispatch_key}, {dtype}, "
+                    f"device={device_index}, requires_grad={requires_grad}, size={sizes}, stride={strides})",
+                    tensor_check_guards[i],
+                    log_only=True,
+                )

         aotautograd_guards: List[GuardEnvExpr] = (
             self.output_graph.tracing_context.guards_context.aotautograd_guards
@@ -1055,6 +1081,8 @@ def convert(size_or_stride):
             else:
                 raise RuntimeError(f"Unknown GuardEnvExpr: {guard}")

+        # TODO: the "guard" here is actually just the top level SHAPE_ENV
+        # which is useless. Get ShapeEnv to pass in more provenance.
         for gcl in local_builder.shape_env_code:
             for code in gcl.code_list:
                 add_code_part(code, gcl.guard)
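The per-tensor fields in the new message can also be reproduced standalone. The sketch below reuses the same private `torch._C` dispatch-key expression as the diff (these are internal APIs and may change between releases) to print the fields for a toy tensor:

```
# Sketch: recompute the fields that the new check_tensor(...) message
# reports, using the same private dispatch-key expression as the diff.
# torch._C._dispatch_* are internal APIs and may change between releases.
import torch

t = torch.randn(3)
dispatch_key = (
    torch._C._dispatch_keys(t)
    | torch._C._dispatch_tls_local_include_set()
) - torch._C._dispatch_tls_local_exclude_set()

print(
    f"check_tensor(L['y'], {type(t).__qualname__}, {dispatch_key}, {t.dtype}, "
    f"device={t.device.index}, requires_grad={t.requires_grad}, "
    f"size={list(t.shape)}, stride={list(t.stride())})"
)
```

On a CPU tensor this prints essentially the example message from the commit description, including `device=None`, since CPU tensors have no device index.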
