Skip to content

Commit

Permalink
Add tensor to fake clone snapshot for immutable source of truth (#100128
Browse files Browse the repository at this point in the history
)

There's a longstanding, well known mutability bug in dynamo, #93610 (and more issues, but this is the one I had at hand).

Ops that do in place mutation of tensors will mutate their corresponding FakeTensors.

So, for example, if you do `t_` on a tensor, you will reverse its strides. This, in turn, means that the FakeTensors strides are now also reversed, say, if you are trying to torch.compile:

```
class F(torch.nn.Module):
            def forward(self, x, y):
                x = x.t_()
                y = y.t_()
                return (x + y,)
```

However, we recently introduced accessing the fake_tensor memo/cache to get the symbolic shape values for sizes and strides during guard installation time.

This means that tensors captured with a given size and stride, say, for x above, size:(3,3) stride:(3, 1), will get their memo updates to size(3, 3), stride(1, 3).  Now, whenever you access this value for anything, it reflects it's current state in the tracing, as opposed to the state at which we initially started tracing on.

This causes us to produce guards that are never valid, for the example above, that `x.stride()[0] == 3`.

The solution is to not allow mutation to affect the fake tensors we use as source of truth here. We can do this by forcing a clone of the fake tensor at builder time, and storing that as the source of truth for our dynamic sizes and strides during guard installation.

Pull Request resolved: #100128
Approved by: https://github.com/ezyang
  • Loading branch information
voznesenskym authored and pytorchmergebot committed Apr 27, 2023
1 parent ca1cf43 commit a145a33
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 8 deletions.
16 changes: 8 additions & 8 deletions torch/_dynamo/guards.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
)
from torch.fx.experimental.symbolic_shapes import is_concrete_int, SYMPY_INTERP

from torch.utils.weak import WeakIdRef

from . import config, convert_frame, mutation_guard
from .eval_frame import set_guard_error_hook, set_guard_fail_hook
from .exc import unimplemented
Expand Down Expand Up @@ -754,19 +756,17 @@ def convert(size_or_stride):

dynamic_dims_sizes = [
convert(
self.output_graph.tracing_context.fake_mode.from_tensor(
t,
memoized_only=True,
).size()
self.output_graph.tensor_weakref_to_sizes_strides[WeakIdRef(t)][
"size"
]
)
for t in tensor_check_examples
]
dynamic_dims_strides = [
convert(
self.output_graph.tracing_context.fake_mode.from_tensor(
t,
memoized_only=True,
).stride()
self.output_graph.tensor_weakref_to_sizes_strides[WeakIdRef(t)][
"stride"
]
)
for t in tensor_check_examples
]
Expand Down
5 changes: 5 additions & 0 deletions torch/_dynamo/output_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
TracingContext,
)
from torch.fx.experimental.symbolic_shapes import free_symbols, ShapeEnv
from torch.utils.weak import WeakIdKeyDictionary

from . import config, logging as torchdynamo_logging, variables
from .backends.registry import CompiledFn, CompilerFn
Expand Down Expand Up @@ -87,6 +88,7 @@ class OutputGraphState(NamedTuple):
param_name_to_source: Optional[Dict[str, Source]]
side_effects: SideEffects
timestamp: int
tensor_weakref_to_sizes_strides: WeakIdKeyDictionary

def diff(self, other: "OutputGraphState", *, prefix: str = "") -> Optional[str]:
for k in self._fields:
Expand Down Expand Up @@ -207,6 +209,7 @@ def __init__(
self.export = export
self.export_constraints = export_constraints
self.frame_state = frame_state
self.tensor_weakref_to_sizes_strides: WeakIdKeyDictionary = {}
# In export mode, we force the shape_env to strictly disallow any constraining
# of the user marked dynamic dims
fake_mode = torch._subclasses.FakeTensorMode(
Expand Down Expand Up @@ -308,6 +311,7 @@ def copy_graphstate(self) -> OutputGraphState:
dict(self.param_name_to_source),
self.side_effects.clone(),
self.timestamp,
dict(self.tensor_weakref_to_sizes_strides),
)
self.timestamp += 1
return state
Expand All @@ -322,6 +326,7 @@ def restore_graphstate(self, state: OutputGraphState):
self.param_name_to_source,
self.side_effects,
self.timestamp,
self.tensor_weakref_to_sizes_strides,
) = state
self.tracing_context.guards_context.restore_graphstate(guards_state)
# FX deepcopy doesn't work for a partially created graph, so just remove new nodes
Expand Down
5 changes: 5 additions & 0 deletions torch/_dynamo/variables/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
RelaxedUnspecConstraint,
)
from torch.fx.immutable_collections import immutable_list
from torch.utils.weak import WeakIdRef

from .. import config, mutation_guard, replay_record, skipfiles
from ..allowed_functions import is_allowed, is_builtin_callable, is_numpy
Expand Down Expand Up @@ -1261,6 +1262,10 @@ def wrap_to_fake_tensor_and_record(
)
if is_tensor and not (static_shapes and source.is_nn_module()):
tx.output.tracked_fakes.append(TrackedFake(fake_e, source, constraint_dims))
tx.output.tensor_weakref_to_sizes_strides[WeakIdRef(e)] = {
"size": fake_e.size(),
"stride": fake_e.stride(),
}
return fake_e
else:
return e
21 changes: 21 additions & 0 deletions torch/_dynamo/variables/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -1033,11 +1033,25 @@ def speculate_branch(branch):
true_tracked_fakes = true_cmp.output.tracked_fakes
false_tracked_fakes = false_cmp.output.tracked_fakes
tx.output.tracked_fakes = list({*false_tracked_fakes, *true_tracked_fakes})
true_tensor_weakref_to_sizes_strides = (
true_cmp.output.tensor_weakref_to_sizes_strides
)
false_tensor_weakref_to_sizes_strides = (
false_cmp.output.tensor_weakref_to_sizes_strides
)

# Add guards
tx.output.tracing_context.guards_context.dynamo_guards |= false_guards
tx.output.tracing_context.guards_context.dynamo_guards |= true_guards

# Add tracking
tx.output.tensor_weakref_to_sizes_strides.update(
true_tensor_weakref_to_sizes_strides
)
tx.output.tensor_weakref_to_sizes_strides.update(
false_tensor_weakref_to_sizes_strides
)

true_name = add_subgraph(
"true", torch.fx.GraphModule(true_nn_modules, true_graph)
)
Expand Down Expand Up @@ -1095,9 +1109,16 @@ def speculate_branch(branch):
parent_tracked_fakes = parent_cmp.output.tracked_fakes
body_tracked_fakes = body_cmp.output.tracked_fakes
tx.output.tracked_fakes = list({*parent_tracked_fakes, *body_tracked_fakes})
body_tensor_weakref_to_sizes_strides = (
body_cmp.output.tensor_weakref_to_sizes_strides
)

# Add guards
tx.output.tracing_context.guards_context.dynamo_guards |= body_guards
# Add tracking
tx.output.tensor_weakref_to_sizes_strides.update(
body_tensor_weakref_to_sizes_strides
)

body_name = add_subgraph(
"body", torch.fx.GraphModule(body_nn_modules, body_graph)
Expand Down

0 comments on commit a145a33

Please sign in to comment.