We probably are allowing mutations to happen on fake tensor in VariableTracker #93610

Closed
ezyang opened this issue Nov 4, 2022 · 8 comments
Labels
bug module: dynamo triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module

Comments

@ezyang
Contributor

ezyang commented Nov 4, 2022

🐛 Describe the bug

VariableTracker is supposed to be immutable, so we can do checkpointing and rollback when instruction translation fails and we need to create a graph break.

However, we (indirectly) store fake tensors in VariableTracker, and fake tensors are not treated immutably. In particular, when we run nodes on fake tensors, those nodes may modify the metadata of an input fake tensor, and I don't see any logic that duplicates fake tensors so that we can safely run modifications on them.

The hypothetical bug: we checkpoint, then run the symbolic evaluator on squeeze_, which mutates the fake tensor; then we discover we need to roll back, and subsequent code sees incorrect metadata for the input to squeeze_. Discussing this with jansel, this is unlikely to happen in practice today, because the only time we roll back is right before we generate a graph break, and at that point we're not going to insert any more nodes into the graph. But it would be good to get this correct, in case we end up using checkpointing more seriously in the future. More generally, we probably need some sort of lint to make sure that VariableTrackers do not actually get mutated, because corrupting their state will not turn into a visible bug immediately.
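To make the failure mode concrete, here is a minimal toy model in plain Python. It does not use Dynamo's real classes: `ToyFakeTensor`, `ToyTracker`, and `checkpoint` are hypothetical stand-ins, and the shallow-copy checkpoint is an assumption made only for illustration.

```python
from dataclasses import dataclass


@dataclass
class ToyFakeTensor:
    """Hypothetical stand-in for a fake tensor: metadata only, and mutable."""
    shape: list

    def squeeze_(self):
        # In-place op: rewrites this object's own metadata.
        self.shape[:] = [d for d in self.shape if d != 1]
        return self


@dataclass(frozen=True)
class ToyTracker:
    """Hypothetical stand-in for a VariableTracker: frozen, but it holds mutable state."""
    example_value: ToyFakeTensor


def checkpoint(symbolic_locals):
    # Shallow copy (an assumption of this sketch): trackers and tensors are shared.
    return dict(symbolic_locals)


symbolic_locals = {"x": ToyTracker(ToyFakeTensor([1, 3]))}
saved = checkpoint(symbolic_locals)

# Speculatively evaluate `x.squeeze_()`, then decide to roll back.
symbolic_locals["x"].example_value.squeeze_()
symbolic_locals = saved

# The rollback restored the dict, but the metadata mutation leaked through.
leaked_shape = symbolic_locals["x"].example_value.shape
```

In this toy, `leaked_shape` is the squeezed shape even though the tracker itself is frozen, which is the kind of silent corruption described above.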

cc @soumith @voznesenskym @yanboliang @penguinwu @anijain2305 @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @desertfire @eellison @jansel

Error logs

No response

Minified repro

No response

@ezyang ezyang added the bug label Nov 4, 2022
@ezyang
Contributor Author

ezyang commented Nov 4, 2022

Also note the obvious fix of copying fake tensors is wrong, because that won't preserve leaf-ness correctly on the tensor. You must do the fancy algorithm in meta converter.

@eellison
Contributor

eellison commented Nov 4, 2022

I think we can detect this and re-compute all fake tensors from the real_value_cache as needed
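One way the "detect and recompute" idea could look, as a rough sketch only: record a fingerprint of each fake tensor's metadata at checkpoint time, and on restore recompute just the entries whose fingerprint changed. Everything here is hypothetical (`fingerprint`, `ToyCheckpoint`, `restore`, and the `recompute` callback, which stands in for whatever source the fake tensors would really be rebuilt from); fake tensors are modeled as plain dicts.

```python
def fingerprint(tensor_meta):
    """Hashable summary of the metadata this toy cares about."""
    return (tuple(tensor_meta["shape"]), tuple(tensor_meta["stride"]))


class ToyCheckpoint:
    def __init__(self, fake_tensors):
        # fake_tensors: dict of name -> mutable metadata dict (toy fake tensor)
        self.fingerprints = {n: fingerprint(t) for n, t in fake_tensors.items()}

    def mutated_since(self, fake_tensors):
        """Names whose metadata differs from what was recorded at checkpoint time."""
        return sorted(
            name
            for name, t in fake_tensors.items()
            if self.fingerprints.get(name) != fingerprint(t)
        )


def restore(ckpt, fake_tensors, recompute):
    """Recompute only the entries that were mutated after the checkpoint."""
    for name in ckpt.mutated_since(fake_tensors):
        fake_tensors[name] = recompute(name)
    return fake_tensors


tensors = {
    "x": {"shape": [1, 3], "stride": [3, 1]},
    "y": {"shape": [2], "stride": [1]},
}
ckpt = ToyCheckpoint(tensors)

# Simulate an in-place squeeze on x after the checkpoint.
tensors["x"]["shape"] = [3]
tensors["x"]["stride"] = [1]
dirty = ckpt.mutated_since(tensors)

restore(ckpt, tensors, recompute=lambda name: {"shape": [1, 3], "stride": [3, 1]})
```

The sketch says nothing about where the recomputed values come from or how guards that already reference the old objects would be updated.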

@ezyang
Contributor Author

ezyang commented Nov 4, 2022

I thought real_value_cache is supposed to be an export only thing though

@eellison
Contributor

eellison commented Nov 4, 2022

Well, in either case, recompute all of the fake tensors needed from the input, right? Although it might be a pain swapping out the fake tensors in various guards.

@ezyang
Contributor Author

ezyang commented Nov 4, 2022

We could, but if we were willing to rerun the bytecode from the beginning we wouldn't even need immutable VariableTracker lol.

@mlazos mlazos added the triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module label Nov 7, 2022
ezyang referenced this issue Dec 7, 2022
…rsively call export"


The original implementation of cond() operator support in dynamo operated by recursively calling export() on the inner subgraph.  This is problematic for a number of reasons:

* My original motivating reason: the original implementation had to play tricks to feed real tensors to the recursive export call, which means that it doesn't work well with tracing with dynamic shapes (where we MUST stay in fake tensors to accurately track dynamic shapes across the cond invocation)
* If there are pending side effects, the recursive export() call won't see those side effects (as they are only tracked by Dynamo, not actually applied to the Python environment.) You can see an example where dynamo cond tracing does the wrong thing at #90208
* If there were side effects inside the true/false branch, these side effects were silently lost (as the export only returns the graph of tensor operations, and not any of the residual Python bytecodes necessary to reapply any side effects.) This could have substantive effects on the export of subsequent parts of the model, as those parts of the models could rely on the side effects.
* It was not possible to track NN module accesses inside the true/false branches, necessitating a hack where the NN module was explicitly passed in as an input to cond #87020 (comment) which doesn't really make any sense from a backend compilation perspective
* Guards induced from the inside of the true/false branch were not properly propagated to the top level guards; they were just silently dropped (in fact, the original implementation checked that the true/false branch produce the same guards which... is not useful? Like, I don't think that actually is even necessary for correctness)

This PR replaces the old implementation with a new implementation based on graph state checkpointing. The basic idea is that, to process a cond(), we checkpoint the state of our interpreter, run the true branch, roll back to our checkpoint, run the false branch, roll back to our checkpoint again, and then merge the changes from both branches. I require the true/false branches to have exactly the same side effects, but union their guards.
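The checkpoint / speculate / rollback / merge flow can be sketched with a toy interpreter state. This is not the PR's code: `ToyInterpreter` and `speculate_cond` are hypothetical names, and guards and side effects are modeled as plain strings.

```python
import copy


class ToyInterpreter:
    """Toy stand-in for the tracing state: guards plus pending side effects."""

    def __init__(self):
        self.guards = set()
        self.side_effects = []

    def checkpoint(self):
        return copy.deepcopy((self.guards, self.side_effects))

    def restore(self, state):
        self.guards, self.side_effects = copy.deepcopy(state)


def speculate_cond(interp, true_fn, false_fn):
    start = interp.checkpoint()

    # Run the true branch, record the resulting state, then roll back.
    true_fn(interp)
    true_guards, true_effects = interp.checkpoint()
    interp.restore(start)

    # Same for the false branch.
    false_fn(interp)
    false_guards, false_effects = interp.checkpoint()
    interp.restore(start)

    if true_effects != false_effects:
        raise RuntimeError("branches of cond() must have identical side effects")

    # Merge: side effects are identical, so take either; guards are unioned.
    interp.side_effects = list(true_effects)
    interp.guards = true_guards | false_guards
    return interp


def true_branch(interp):
    interp.guards.add("x.size(0) > 2")
    interp.side_effects.append("counter += 1")


def false_branch(interp):
    interp.guards.add("x.dtype == float32")
    interp.side_effects.append("counter += 1")


interp = ToyInterpreter()
interp.guards.add("isinstance(x, Tensor)")
speculate_cond(interp, true_branch, false_branch)
```

The toy uses deep copies for its checkpoints purely for simplicity; the real checkpoints are described as recording less than that.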

Some of the details:

* Dynamo is too aggressive with tracking side effects when processing closures, c.f. https://github.com/pytorch/torchdynamo/pull/233/files#r1040480078 The basic problem is whenever I define a closure, this immediately counts as a side effect, even if I didn't actually mutate anything. This triggered on the nested cond export example. To prevent this from happening, I optimistically avoid tracking side effects, but if a STORE_DEREF happens, I restart analysis with the relevant Source.name() added to `mutated_closure_cell_contents` so we start tracking on closure allocation. This is enough to fix the relevant test.
* For the most part, I assert that the graph states must be equivalent after applying the true/false branches. During debugging, I found it useful to be able to compare two graph states and give a better description about what the divergence was. You can test this using the `diff()` method I've added to a few structures.
* The implementation now supports NestedUserFunctionVariable, which is nice as it allows the true/false branches to be defined closer to the cond implementation.
* I fixed the naming of the true/false subgraphs; previously they were named `name_0`, `name_1`, now they are named `cond_true_0` and `cond_false_0`
* I added `name_to_input` to the saved graph state. I don't actually know if this is necessary, but it seemed like a good idea.
* I have to play some tricks to get the speculative execution of the true/false branch to record into a subgraph. After a careful read of OutputGraph, I found that what would work is overriding graph with a fresh Graph that we want to write things into, and manually setting up the inputs/outputs. It's a little delicate, as you have to make sure you reset the Graph to its original before you restore a checkpoint: checkpoints don't actually save graph, for efficiency, and just undo changes on the graph. This capability may usefully get refactored into OutputGraph, but I didn't do it in this PR for simplicity.
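The optimistic closure handling in the first bullet follows a restart-on-wrong-guess pattern, which might be sketched as below. This is a toy, not Dynamo's implementation: `RestartAnalysis`, `analyze`, and `analyze_with_restarts` are hypothetical, instructions are modeled as `(opname, cell_name)` pairs, and the `mutated_cells` set plays the role of `mutated_closure_cell_contents`.

```python
class RestartAnalysis(Exception):
    """Raised to abandon the current pass and start over with more information."""


def analyze(instructions, mutated_cells):
    """One pass over toy (opname, cell_name) instructions.

    A cell is tracked as a side effect only if it is already known to be mutated.
    """
    tracked = []
    for opname, cell in instructions:
        if opname == "MAKE_CELL" and cell in mutated_cells:
            tracked.append(cell)
        elif opname == "STORE_DEREF" and cell not in mutated_cells:
            # We guessed wrong: remember the cell and restart from the top.
            mutated_cells.add(cell)
            raise RestartAnalysis(cell)
    return tracked


def analyze_with_restarts(instructions):
    mutated_cells = set()
    restarts = 0
    # Each restart adds a new cell to the set, so the loop terminates.
    while True:
        try:
            return analyze(instructions, mutated_cells), restarts
        except RestartAnalysis:
            restarts += 1


program = [
    ("MAKE_CELL", "a"),
    ("MAKE_CELL", "b"),
    ("STORE_DEREF", "b"),
    ("LOAD_DEREF", "a"),
]
tracked, restarts = analyze_with_restarts(program)
```

Here only `b` ends up tracked, after a single restart, while the closure over `a` never counts as a side effect because nothing stores to it.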

There are some further problems with the cond() implementation that I leave for future work. Most of these were preexisting with the original implementation.

* Not a problem per se, but if an NN module is used by both the true/false branch, it will show up in the final graph twice (since it has to be a submodule of the GraphModule that makes use of it.) I hope the export pipeline can deal with this.
* List of tensor output for cond is not supported.
* The true/false return values may not have consistent sizes/dims/etc, and we don't check them for consistency.
* If we modify fake tensors in the true/false branches, we aren't rolling them back, c.f. https://github.com/pytorch/torchdynamo/issues/1840

Signed-off-by: Edward Z. Yang <ezyangfb.com>

cc mlazos soumith voznesenskym yanboliang penguinwu anijain2305 EikanWang jgong5 Guobing-Chen chunyuan-w XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx desertfire

[ghstack-poisoned]
ezyang referenced this issue Dec 7, 2022
…ort"


The original implementation of cond() operator support in dynamo operated by recursively calling export() on the inner subgraph.  This is problematic for a number of reasons:

* My original motivating reason: the original implementation had to play tricks to feed real tensors to the recursive export call, which means that it doesn't work well with tracing with dynamic shapes (where we MUST stay in fake tensors to accurately track dynamic shapes across the cond invocation)
* If there are pending side effects, the recursive export() call won't see those side effects (as they are only tracked by Dynamo, not actually applied to the Python environment.) You can see an example where dynamo cond tracing does the wrong thing at #90208
* If there were side effects inside the true/false branch, these side effects were silently lost (as the export only returns the graph of tensor operations, and not any of the residual Python bytecodes necessary to reapply any side effects.) This could have substantive effects on the export of subsequent parts of the model, as those parts of the models could rely on the side effects.
* It was not possible to track NN module accesses inside the true/false branches, necessitating a hack where the NN module was explicitly passed in as an input to cond #87020 (comment) which doesn't really make any sense from a backend compilation perspective
* Guards induced from the inside of the true/false branch were not properly propagated to the top level guards; they were just silently dropped (in fact, the original implementation checked that the true/false branch produce the same guards which... is not useful? Like, I don't think that actually is even necessary for correctness)

This PR replaces the old implementation with a new implementation based on graphstate checkpointing. The basic idea is to process a cond(), we checkpoint the state of our interpreter, run the true branch, rollback to our checkpoint, run the false branch, rollback to our checkpoint and then merge the changes from both of the checkpoints. I require the true/false branches to have exactly the same side effects, but union their guards.

Some of the details:

* Dynamo is too aggressive with tracking side effects when processing closures, c.f. https://github.com/pytorch/torchdynamo/pull/233/files#r1040480078 The basic problem is whenever I define a closure, this immediately counts as a side effect, even if I didn't actually mutate anything. This triggered on the nested cond export example. To prevent this from happening, I optimistically avoid tracking side effects, but if a STORE_DEREF happens, I restart analysis with the relevant Source.name() added to `mutated_closure_cell_contents` so we start tracking on closure allocation. This is enough to fix the relevant test.
* For the most part, I assert that the graph states must be equivalent after applying the true/false branches. During debugging, I found it useful to be able to compare two graph states and give a better description about what the divergence was. You can test this using the `diff()` method I've added to a few structures.
* The implementation now supports NestedUserFunctionVariable, which is nice as it allows the true/false branches to be defined closer to the cond implementation.
* I fixed the naming of the true/false subgraphs; previously they were named `name_0`, `name_1`, now they are named `cond_true_0` and `cond_false_0`
* I added `name_to_input` to the saved graph state. I don't actually know if this is necessary, but it seemed like a good idea.
* I have to play some tricks to get the speculating execution of the true/false branch to record into a subgraph. After a careful read of OutputGraph, I found that what would work is overriding graph with a fresh Graph that we want to write things into, and manually setting up the inputs/outputs. It's a little delicate as you have to make sure you reset the Graph to its original before you restore a checkpoint, as checkpoints don't actually save graph for efficiency, and just undo changes on the graph. This capability may usefully get refactored to OutputGraph but I didn't do it in this PR for simplicity.

There are some further problems with the cond() implementation that I leave for future work. Most of these were preexisting with the original implementation.

* Not a problem per se, but if an NN module is used by both the true/false branch, it will show up in the final graph twice (since it has to be a submodule of the GraphModule that makes use of it.) I hope the export pipeline can deal with this.
* List of tensor output for cond is not supported.
* The true/false return values may not have consistent sizes/dims/etc, and we don't check them for consistency.
* If we modify fake tensors in the true/false branches, we aren't rolling them back, c.f. https://github.com/pytorch/torchdynamo/issues/1840

Signed-off-by: Edward Z. Yang <ezyangfb.com>

cc mlazos soumith voznesenskym yanboliang penguinwu anijain2305 EikanWang jgong5 Guobing-Chen chunyuan-w XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx desertfire

[ghstack-poisoned]
ezyang referenced this issue Dec 7, 2022
…rsively call export"


The original implementation of cond() operator support in dynamo operated by recursively calling export() on the inner subgraph.  This is problematic for a number of reasons:

* My original motivating reason: the original implementation had to play tricks to feed real tensors to the recursive export call, which means that it doesn't work well with tracing with dynamic shapes (where we MUST stay in fake tensors to accurately track dynamic shapes across the cond invocation)
* If there are pending side effects, the recursive export() call won't see those side effects (as they are only tracked by Dynamo, not actually applied to the Python environment.) You can see an example where dynamo cond tracing does the wrong thing at #90208
* If there were side effects inside the true/false branch, these side effects were silently lost (as the export only returns the graph of tensor operations, and not any of the residual Python bytecodes necessary to reapply any side effects.) This could have substantive effects on the export of subsequent parts of the model, as those parts of the models could rely on the side effects.
* It was not possible to track NN module accesses inside the true/false branches, necessitating a hack where the NN module was explicitly passed in as an input to cond #87020 (comment) which doesn't really make any sense from a backend compilation perspective
* Guards induced from the inside of the true/false branch were not properly propagated to the top level guards; they were just silently dropped (in fact, the original implementation checked that the true/false branch produce the same guards which... is not useful? Like, I don't think that actually is even necessary for correctness)

This PR replaces the old implementation with a new implementation based on graphstate checkpointing. The basic idea is to process a cond(), we checkpoint the state of our interpreter, run the true branch, rollback to our checkpoint, run the false branch, rollback to our checkpoint and then merge the changes from both of the checkpoints. I require the true/false branches to have exactly the same side effects, but union their guards.

Some of the details:

* Dynamo is too aggressive with tracking side effects when processing closures, c.f. https://github.com/pytorch/torchdynamo/pull/233/files#r1040480078 The basic problem is whenever I define a closure, this immediately counts as a side effect, even if I didn't actually mutate anything. This triggered on the nested cond export example. To prevent this from happening, I optimistically avoid tracking side effects, but if a STORE_DEREF happens, I restart analysis with the relevant Source.name() added to `mutated_closure_cell_contents` so we start tracking on closure allocation. This is enough to fix the relevant test.
* For the most part, I assert that the graph states must be equivalent after applying the true/false branches. During debugging, I found it useful to be able to compare two graph states and give a better description about what the divergence was. You can test this using the `diff()` method I've added to a few structures.
* The implementation now supports NestedUserFunctionVariable, which is nice as it allows the true/false branches to be defined closer to the cond implementation.
* I fixed the naming of the true/false subgraphs; previously they were named `name_0`, `name_1`, now they are named `cond_true_0` and `cond_false_0`
* I added `name_to_input` to the saved graph state. I don't actually know if this is necessary, but it seemed like a good idea.
* I have to play some tricks to get the speculating execution of the true/false branch to record into a subgraph. After a careful read of OutputGraph, I found that what would work is overriding graph with a fresh Graph that we want to write things into, and manually setting up the inputs/outputs. It's a little delicate as you have to make sure you reset the Graph to its original before you restore a checkpoint, as checkpoints don't actually save graph for efficiency, and just undo changes on the graph. This capability may usefully get refactored to OutputGraph but I didn't do it in this PR for simplicity.

There are some further problems with the cond() implementation that I leave for future work. Most of these were preexisting with the original implementation.

* Not a problem per se, but if an NN module is used by both the true/false branch, it will show up in the final graph twice (since it has to be a submodule of the GraphModule that makes use of it.) I hope the export pipeline can deal with this.
* List of tensor output for cond is not supported.
* The true/false return values may not have consistent sizes/dims/etc, and we don't check them for consistency.
* If we modify fake tensors in the true/false branches, we aren't rolling them back, c.f. https://github.com/pytorch/torchdynamo/issues/1840

Signed-off-by: Edward Z. Yang <ezyangfb.com>

cc mlazos soumith voznesenskym yanboliang penguinwu anijain2305 EikanWang jgong5 Guobing-Chen chunyuan-w XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx desertfire

[ghstack-poisoned]
ezyang referenced this issue Dec 7, 2022
…ort"


The original implementation of cond() operator support in dynamo operated by recursively calling export() on the inner subgraph.  This is problematic for a number of reasons:

* My original motivating reason: the original implementation had to play tricks to feed real tensors to the recursive export call, which means that it doesn't work well with tracing with dynamic shapes (where we MUST stay in fake tensors to accurately track dynamic shapes across the cond invocation)
* If there are pending side effects, the recursive export() call won't see those side effects (as they are only tracked by Dynamo, not actually applied to the Python environment.) You can see an example where dynamo cond tracing does the wrong thing at #90208
* If there were side effects inside the true/false branch, these side effects were silently lost (as the export only returns the graph of tensor operations, and not any of the residual Python bytecodes necessary to reapply any side effects.) This could have substantive effects on the export of subsequent parts of the model, as those parts of the models could rely on the side effects.
* It was not possible to track NN module accesses inside the true/false branches, necessitating a hack where the NN module was explicitly passed in as an input to cond #87020 (comment) which doesn't really make any sense from a backend compilation perspective
* Guards induced from the inside of the true/false branch were not properly propagated to the top level guards; they were just silently dropped (in fact, the original implementation checked that the true/false branch produce the same guards which... is not useful? Like, I don't think that actually is even necessary for correctness)

This PR replaces the old implementation with a new implementation based on graphstate checkpointing. The basic idea is to process a cond(), we checkpoint the state of our interpreter, run the true branch, rollback to our checkpoint, run the false branch, rollback to our checkpoint and then merge the changes from both of the checkpoints. I require the true/false branches to have exactly the same side effects, but union their guards.

Some of the details:

* Dynamo is too aggressive with tracking side effects when processing closures, c.f. https://github.com/pytorch/torchdynamo/pull/233/files#r1040480078 The basic problem is whenever I define a closure, this immediately counts as a side effect, even if I didn't actually mutate anything. This triggered on the nested cond export example. To prevent this from happening, I optimistically avoid tracking side effects, but if a STORE_DEREF happens, I restart analysis with the relevant Source.name() added to `mutated_closure_cell_contents` so we start tracking on closure allocation. This is enough to fix the relevant test.
* For the most part, I assert that the graph states must be equivalent after applying the true/false branches. During debugging, I found it useful to be able to compare two graph states and give a better description about what the divergence was. You can test this using the `diff()` method I've added to a few structures.
* The implementation now supports NestedUserFunctionVariable, which is nice as it allows the true/false branches to be defined closer to the cond implementation.
* I fixed the naming of the true/false subgraphs; previously they were named `name_0`, `name_1`, now they are named `cond_true_0` and `cond_false_0`
* I added `name_to_input` to the saved graph state. I don't actually know if this is necessary, but it seemed like a good idea.
* I have to play some tricks to get the speculating execution of the true/false branch to record into a subgraph. After a careful read of OutputGraph, I found that what would work is overriding graph with a fresh Graph that we want to write things into, and manually setting up the inputs/outputs. It's a little delicate as you have to make sure you reset the Graph to its original before you restore a checkpoint, as checkpoints don't actually save graph for efficiency, and just undo changes on the graph. This capability may usefully get refactored to OutputGraph but I didn't do it in this PR for simplicity.

There are some further problems with the cond() implementation that I leave for future work. Most of these were preexisting with the original implementation.

* Not a problem per se, but if an NN module is used by both the true/false branch, it will show up in the final graph twice (since it has to be a submodule of the GraphModule that makes use of it.) I hope the export pipeline can deal with this.
* List of tensor output for cond is not supported.
* The true/false return values may not have consistent sizes/dims/etc, and we don't check them for consistency.
* If we modify fake tensors in the true/false branches, we aren't rolling them back, c.f. https://github.com/pytorch/torchdynamo/issues/1840

Signed-off-by: Edward Z. Yang <ezyangfb.com>

cc mlazos soumith voznesenskym yanboliang penguinwu anijain2305 EikanWang jgong5 Guobing-Chen chunyuan-w XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx desertfire

[ghstack-poisoned]
ezyang referenced this issue Dec 7, 2022
…rsively call export"


The original implementation of cond() operator support in dynamo operated by recursively calling export() on the inner subgraph.  This is problematic for a number of reasons:

* My original motivating reason: the original implementation had to play tricks to feed real tensors to the recursive export call, which means that it doesn't work well with tracing with dynamic shapes (where we MUST stay in fake tensors to accurately track dynamic shapes across the cond invocation)
* If there are pending side effects, the recursive export() call won't see those side effects (as they are only tracked by Dynamo, not actually applied to the Python environment.) You can see an example where dynamo cond tracing does the wrong thing at #90208
* If there were side effects inside the true/false branch, these side effects were silently lost (as the export only returns the graph of tensor operations, and not any of the residual Python bytecodes necessary to reapply any side effects.) This could have substantive effects on the export of subsequent parts of the model, as those parts of the models could rely on the side effects.
* It was not possible to track NN module accesses inside the true/false branches, necessitating a hack where the NN module was explicitly passed in as an input to cond #87020 (comment) which doesn't really make any sense from a backend compilation perspective
* Guards induced from the inside of the true/false branch were not properly propagated to the top level guards; they were just silently dropped (in fact, the original implementation checked that the true/false branch produce the same guards which... is not useful? Like, I don't think that actually is even necessary for correctness)

This PR replaces the old implementation with a new implementation based on graphstate checkpointing. The basic idea is to process a cond(), we checkpoint the state of our interpreter, run the true branch, rollback to our checkpoint, run the false branch, rollback to our checkpoint and then merge the changes from both of the checkpoints. I require the true/false branches to have exactly the same side effects, but union their guards.

Some of the details:

* Dynamo is too aggressive with tracking side effects when processing closures, c.f. https://github.com/pytorch/torchdynamo/pull/233/files#r1040480078 The basic problem is whenever I define a closure, this immediately counts as a side effect, even if I didn't actually mutate anything. This triggered on the nested cond export example. To prevent this from happening, I optimistically avoid tracking side effects, but if a STORE_DEREF happens, I restart analysis with the relevant Source.name() added to `mutated_closure_cell_contents` so we start tracking on closure allocation. This is enough to fix the relevant test.
* For the most part, I assert that the graph states must be equivalent after applying the true/false branches. During debugging, I found it useful to be able to compare two graph states and give a better description about what the divergence was. You can test this using the `diff()` method I've added to a few structures.
* The implementation now supports NestedUserFunctionVariable, which is nice as it allows the true/false branches to be defined closer to the cond implementation.
* I fixed the naming of the true/false subgraphs; previously they were named `name_0`, `name_1`, now they are named `cond_true_0` and `cond_false_0`
* I added `name_to_input` to the saved graph state. I don't actually know if this is necessary, but it seemed like a good idea.
* I have to play some tricks to get the speculating execution of the true/false branch to record into a subgraph. After a careful read of OutputGraph, I found that what would work is overriding graph with a fresh Graph that we want to write things into, and manually setting up the inputs/outputs. It's a little delicate as you have to make sure you reset the Graph to its original before you restore a checkpoint, as checkpoints don't actually save graph for efficiency, and just undo changes on the graph. This capability may usefully get refactored to OutputGraph but I didn't do it in this PR for simplicity.

There are some further problems with the cond() implementation that I leave for future work. Most of these were preexisting with the original implementation.

* Not a problem per se, but if an NN module is used by both the true/false branch, it will show up in the final graph twice (since it has to be a submodule of the GraphModule that makes use of it.) I hope the export pipeline can deal with this.
* List of tensor output for cond is not supported.
* The true/false return values may not have consistent sizes/dims/etc, and we don't check them for consistency.
* If we modify fake tensors in the true/false branches, we aren't rolling them back, c.f. https://github.com/pytorch/torchdynamo/issues/1840

Signed-off-by: Edward Z. Yang <ezyangfb.com>

cc mlazos soumith voznesenskym yanboliang penguinwu anijain2305 EikanWang jgong5 Guobing-Chen chunyuan-w XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx desertfire

[ghstack-poisoned]
ezyang referenced this issue Dec 7, 2022
…ort"


The original implementation of cond() operator support in dynamo operated by recursively calling export() on the inner subgraph.  This is problematic for a number of reasons:

* My original motivating reason: the original implementation had to play tricks to feed real tensors to the recursive export call, which means that it doesn't work well with tracing with dynamic shapes (where we MUST stay in fake tensors to accurately track dynamic shapes across the cond invocation)
* If there are pending side effects, the recursive export() call won't see those side effects (as they are only tracked by Dynamo, not actually applied to the Python environment.) You can see an example where dynamo cond tracing does the wrong thing at #90208
* If there were side effects inside the true/false branch, these side effects were silently lost (as the export only returns the graph of tensor operations, and not any of the residual Python bytecodes necessary to reapply any side effects.) This could have substantive effects on the export of subsequent parts of the model, as those parts of the models could rely on the side effects.
* It was not possible to track NN module accesses inside the true/false branches, necessitating a hack where the NN module was explicitly passed in as an input to cond #87020 (comment) which doesn't really make any sense from a backend compilation perspective
* Guards induced from the inside of the true/false branch were not properly propagated to the top level guards; they were just silently dropped (in fact, the original implementation checked that the true/false branch produce the same guards which... is not useful? Like, I don't think that actually is even necessary for correctness)

This PR replaces the old implementation with a new implementation based on graphstate checkpointing. The basic idea is to process a cond(), we checkpoint the state of our interpreter, run the true branch, rollback to our checkpoint, run the false branch, rollback to our checkpoint and then merge the changes from both of the checkpoints. I require the true/false branches to have exactly the same side effects, but union their guards.

Some of the details:

* Dynamo is too aggressive with tracking side effects when processing closures, c.f. https://github.com/pytorch/torchdynamo/pull/233/files#r1040480078 The basic problem is whenever I define a closure, this immediately counts as a side effect, even if I didn't actually mutate anything. This triggered on the nested cond export example. To prevent this from happening, I optimistically avoid tracking side effects, but if a STORE_DEREF happens, I restart analysis with the relevant Source.name() added to `mutated_closure_cell_contents` so we start tracking on closure allocation. This is enough to fix the relevant test.
* For the most part, I assert that the graph states must be equivalent after applying the true/false branches. During debugging, I found it useful to be able to compare two graph states and give a better description about what the divergence was. You can test this using the `diff()` method I've added to a few structures.
* The implementation now supports NestedUserFunctionVariable, which is nice as it allows the true/false branches to be defined closer to the cond implementation.
* I fixed the naming of the true/false subgraphs; previously they were named `name_0` and `name_1`, now they are named `cond_true_0` and `cond_false_0`.
* I added `name_to_input` to the saved graph state. I don't actually know if this is necessary, but it seemed like a good idea.
* I have to play some tricks to get the speculative execution of the true/false branch to record into a subgraph. After a careful read of OutputGraph, I found that what would work is overriding graph with a fresh Graph that we want to write things into, and manually setting up the inputs/outputs. It's a little delicate, as you have to make sure you reset the Graph to its original before you restore a checkpoint: for efficiency, checkpoints don't actually save the graph, and just undo changes on it. This capability may usefully get refactored into OutputGraph, but I didn't do it in this PR for simplicity.
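
To make the `diff()` idea from the list above concrete, here is a minimal sketch of comparing two graph states and describing where they diverge. The class and its fields are hypothetical; the actual signature and return value of the `diff()` methods added in this PR are not shown here, so treat this purely as an illustration of the technique.

```python
from dataclasses import dataclass, fields
from typing import Optional


@dataclass
class ToyGraphState:
    """Hypothetical, simplified graph state; field names are illustrative."""
    guards: frozenset
    side_effects: tuple
    num_graph_nodes: int

    def diff(self, other: "ToyGraphState") -> Optional[str]:
        """Describe the first field that differs, or return None if equal."""
        for f in fields(self):
            mine, theirs = getattr(self, f.name), getattr(other, f.name)
            if mine != theirs:
                return f"{f.name} mismatch: {mine!r} != {theirs!r}"
        return None


after_true = ToyGraphState(frozenset({"g0"}), ("store x",), 3)
after_false = ToyGraphState(frozenset({"g0"}), ("store y",), 3)
message = after_true.diff(after_false)
```

Returning a description instead of a bare boolean is what makes a failed equivalence assertion debuggable.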

There are some further problems with the cond() implementation that I leave for future work. Most of these were preexisting with the original implementation.

* Not a problem per se, but if an NN module is used by both the true and false branches, it will show up in the final graph twice (since it has to be a submodule of the GraphModule that makes use of it). I hope the export pipeline can deal with this.
* A list of tensors as the output of cond is not supported.
* The true/false return values may not have consistent sizes/dims/etc., and we don't check them for consistency.
* If we modify fake tensors in the true/false branches, we aren't rolling them back, cf. https://github.com/pytorch/torchdynamo/issues/1840

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

cc mlazos soumith voznesenskym yanboliang penguinwu anijain2305 EikanWang jgong5 Guobing-Chen chunyuan-w XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx desertfire

[ghstack-poisoned]
pytorchmergebot referenced this issue Dec 8, 2022
Pull Request resolved: #90286
Approved by: https://github.com/voznesenskym
kulinseth referenced this issue in kulinseth/pytorch Dec 10, 2022
…h#90286)

Pull Request resolved: pytorch#90286
Approved by: https://github.com/voznesenskym
@malfet malfet transferred this issue from pytorch/torchdynamo Feb 1, 2023
ezyang commented Feb 17, 2023

I think this is the true root cause of #94797

@eellison
Contributor

I can get to this in a few weeks if @blzheng doesn't get to it first

voznesenskym added a commit that referenced this issue Apr 26, 2023
…truth"


There's a long-standing, well-known mutability bug in dynamo, #93610 (and more issues, but this is the one I had at hand).

Ops that do in-place mutation of tensors will mutate their corresponding FakeTensors.

So, for example, if you do `t_` on a tensor, you will reverse its strides. This, in turn, means that the FakeTensor's strides are now also reversed. Say you are trying to torch.compile:

```
import torch


class F(torch.nn.Module):
    def forward(self, x, y):
        # t_() transposes in place, which reverses each tensor's strides.
        x = x.t_()
        y = y.t_()
        return (x + y,)
```

However, we recently introduced accessing the fake_tensor memo/cache to get the symbolic shape values for sizes and strides during guard installation time. 

This means that tensors captured with a given size and stride, say, for x above, size (3, 3), stride (3, 1), will get their memo updated to size (3, 3), stride (1, 3). Now, whenever you access this value for anything, it reflects its current state in the tracing, as opposed to the state it had when we initially started tracing.

This causes us to produce guards that are never valid, for the example above, that `x.stride()[0] == 3`. 

The solution is to not allow mutation to affect the fake tensors we use as source of truth here. We can do this by forcing a clone of the fake tensor at builder time, and storing that as the source of truth for our dynamic sizes and strides during guard installation.
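
The effect of cloning at builder time can be sketched without PyTorch at all. The `ToyFakeTensor` class below is a hypothetical, metadata-only stand-in (2-D only), and `copy.copy` stands in for whatever cloning the builder actually does; the point is only that a snapshot taken before tracing is unaffected by later in-place ops.

```python
import copy
from dataclasses import dataclass


@dataclass
class ToyFakeTensor:
    """Metadata-only stand-in for a fake tensor (hypothetical, 2-D only)."""
    size: tuple
    stride: tuple

    def t_(self):
        # In-place transpose: swaps sizes and strides on this very object.
        self.size = self.size[::-1]
        self.stride = self.stride[::-1]
        return self


x = ToyFakeTensor(size=(3, 3), stride=(3, 1))
x_at_build_time = copy.copy(x)  # snapshot taken when the input is first wrapped

x.t_()  # tracing an in-place op mutates the live object

live_stride = x.stride                    # mid-trace state
snapshot_stride = x_at_build_time.stride  # state of the actual input
```

Guards derived from `x_at_build_time` describe the tensor as it was passed in, whereas anything derived from `x` after tracing describes the mutated state.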



cc soumith penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx desertfire

[ghstack-poisoned]
voznesenskym added a commit that referenced this issue Apr 27, 2023
…truth"


voznesenskym added a commit that referenced this issue Apr 27, 2023
…truth"


voznesenskym added a commit that referenced this issue Apr 27, 2023
…truth"


pytorchmergebot pushed a commit that referenced this issue Apr 27, 2023
…table source of truth"


pytorchmergebot pushed a commit that referenced this issue Apr 27, 2023
…truth"


voznesenskym added a commit that referenced this issue Apr 27, 2023
…table source of truth"


voznesenskym added a commit that referenced this issue Apr 27, 2023
…truth"


pytorchmergebot pushed a commit that referenced this issue Apr 27, 2023
)

Pull Request resolved: #100128
Approved by: https://github.com/ezyang
eellison commented Dec 5, 2023

Seems like it's fixed by Voz's pr.

@eellison eellison closed this as completed Dec 5, 2023