Update on "Add tensor to fake clone snapshot for immutable source of truth"


There's a longstanding, well-known mutability bug in dynamo, #93610 (one of several related issues, but the one I had at hand).

Ops that mutate tensors in place also mutate their corresponding FakeTensors.

So, for example, if you call `t_()` on a 2-D tensor, you swap its strides. This, in turn, means that the FakeTensor's strides are now also swapped, say, if you are trying to torch.compile:

```python
class F(torch.nn.Module):
    def forward(self, x, y):
        x = x.t_()
        y = y.t_()
        return (x + y,)
```
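The stride swap itself is easy to observe directly (a stand-alone illustration, not part of this change):

```python
import torch

x = torch.empty(3, 3)
print(x.stride())  # (3, 1): contiguous row-major strides
x.t_()             # in-place transpose mutates the tensor's metadata
print(x.stride())  # (1, 3): the strides are now swapped
```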

However, we recently introduced accessing the fake_tensor memo/cache to get the symbolic shape values for sizes and strides during guard installation time. 

This means that tensors captured with a given size and stride, say, for `x` above, size `(3, 3)` and stride `(3, 1)`, will get their memo updated to size `(3, 3)`, stride `(1, 3)`. Now, whenever you access this value for anything, it reflects the tensor's current state in the tracing, as opposed to the state at which we initially started tracing.

This causes us to produce guards that are never valid: for the example above, `x.stride()[0] == 3`.

The solution is to not allow mutation to affect the fake tensors we use as the source of truth here. We can do this by forcing a clone of the fake tensor at builder time, and storing that clone as the source of truth for our dynamic sizes and strides during guard installation.
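The idea can be sketched in isolation (a toy `FakeTensorStandIn` class and `snapshot` variable are hypothetical stand-ins, not dynamo internals):

```python
import copy

class FakeTensorStandIn:
    """Toy stand-in for a fake tensor: only tracks strides."""
    def __init__(self, stride):
        self.stride = stride

    def t_(self):
        # In-place transpose mutates metadata, like the real op does.
        self.stride = tuple(reversed(self.stride))

fake = FakeTensorStandIn((3, 1))

# Builder time: snapshot a clone before tracing can mutate `fake`.
snapshot = copy.deepcopy(fake)

# Tracing then mutates the live fake tensor in place.
fake.t_()

# Guard installation: reading the live fake tensor would bake the
# mutated stride (1, 3) into the guard, rejecting every fresh (3, 1)
# input. The clone preserves the state we started tracing from.
assert fake.stride == (1, 3)
assert snapshot.stride == (3, 1)  # source of truth for guards
```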



cc soumith penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx desertfire

[ghstack-poisoned]
voznesenskym committed Apr 26, 2023
1 parent 623e47f commit 634b845
Showing 1 changed file with 3 additions and 4 deletions.
7 changes: 3 additions & 4 deletions torch/_dynamo/variables/torch.py
```diff
@@ -1015,9 +1015,8 @@ def speculate_branch(branch):
 tx.output.tracing_context.guards_context.dynamo_guards |= true_guards

 # Add tracking
-
-tx.output.tensor_id_to_fake_clone |= true_tensor_id_to_fake_clone
-tx.output.tensor_id_to_fake_clone |= false_tensor_id_to_fake_clone
+tx.output.tensor_id_to_fake_clone.update(true_tensor_id_to_fake_clone)
+tx.output.tensor_id_to_fake_clone.update(false_tensor_id_to_fake_clone)

 true_name = add_subgraph(
     "true", torch.fx.GraphModule(true_nn_modules, true_graph)
@@ -1081,7 +1080,7 @@ def speculate_branch(branch):
 # Add guards
 tx.output.tracing_context.guards_context.dynamo_guards |= body_guards
 # Add tracking
-tx.output.tensor_id_to_fake_clone |= body_tensor_id_to_fake_clone
+tx.output.tensor_id_to_fake_clone.update(body_tensor_id_to_fake_clone)

 body_name = add_subgraph(
     "body", torch.fx.GraphModule(body_nn_modules, body_graph)
```
