Skip to content

Commit

Permalink
[dynamo] Initialize tensor_weakref_to_sizes_strides with a weak dict (#…
Browse files Browse the repository at this point in the history
…113412)

Spotted while working on getting output_graph.py to typecheck.

The type hint indicates that it was intended to be initialized with a
WeakIdKeyDictionary, but the actual runtime value was a regular dict.
Not sure if there's some kind of test we should add for this fix.

Looks like the code was originally added in
#100128.

Pull Request resolved: #113412
Approved by: https://github.com/Skylion007, https://github.com/voznesenskym
ghstack dependencies: #113413, #113518, #113519
  • Loading branch information
int3 authored and pytorchmergebot committed Nov 13, 2023
1 parent 6ed20af commit 68278cf
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 15 deletions.
14 changes: 3 additions & 11 deletions torch/_dynamo/guards.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
)

from torch.utils._traceback import format_frame, report_compile_source_on_error
from torch.utils.weak import TensorWeakRef, WeakIdRef
from torch.utils.weak import TensorWeakRef

from . import config, convert_frame, exc, mutation_guard
from .eval_frame import set_guard_error_hook
Expand Down Expand Up @@ -1068,20 +1068,12 @@ def convert(size_or_stride):
return converted

dynamic_dims_sizes = [
convert(
self.output_graph.tensor_weakref_to_sizes_strides[WeakIdRef(t)][
"size"
]
)
convert(self.output_graph.tensor_weakref_to_sizes_strides[t]["size"])
for t in tensor_check_examples
]

dynamic_dims_strides = [
convert(
self.output_graph.tensor_weakref_to_sizes_strides[WeakIdRef(t)][
"stride"
]
)
convert(self.output_graph.tensor_weakref_to_sizes_strides[t]["stride"])
for t in tensor_check_examples
]

Expand Down
4 changes: 2 additions & 2 deletions torch/_dynamo/output_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
)
from torch._utils_internal import signpost_event
from torch.fx.experimental.symbolic_shapes import free_symbols, is_symbolic, ShapeEnv
from torch.utils.weak import WeakIdKeyDictionary
from torch.utils.weak import WeakTensorKeyDictionary

from . import config, logging as torchdynamo_logging, variables
from .backends.registry import CompiledFn, CompilerFn
Expand Down Expand Up @@ -244,7 +244,7 @@ def __init__(
self.export = export
self.export_constraints = export_constraints
self.frame_state = frame_state
self.tensor_weakref_to_sizes_strides: WeakIdKeyDictionary = {}
self.tensor_weakref_to_sizes_strides = WeakTensorKeyDictionary()
self.cleanup_hooks: List[Callable[[], Any]] = []

# TODO: maybe should just pass the entire f_code in here? Not
Expand Down
4 changes: 2 additions & 2 deletions torch/_dynamo/variables/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from torch.fx.immutable_collections import immutable_list
from torch.nested._internal.nested_tensor import NestedTensor
from torch.utils._python_dispatch import is_traceable_wrapper_subclass
from torch.utils.weak import TensorWeakRef, WeakIdRef
from torch.utils.weak import TensorWeakRef
from .. import config, mutation_guard, replay_record, skipfiles, trace_rules
from ..allowed_functions import (
is_allowed,
Expand Down Expand Up @@ -1746,7 +1746,7 @@ def wrap_to_fake_tensor_and_record(
if is_tensor and not (static_shapes and source.is_nn_module()):
tx.output.tracked_fakes.append(TrackedFake(fake_e, source, constraint_dims))
tx.output.tracked_fakes_id_to_source[id(e)].append(source)
tx.output.tensor_weakref_to_sizes_strides[WeakIdRef(e)] = {
tx.output.tensor_weakref_to_sizes_strides[e] = {
"size": fake_e.size(),
"stride": fake_e.stride(),
}
Expand Down

0 comments on commit 68278cf

Please sign in to comment.