Fix final callbacks for reentrant backwards #35066
Closes #24965

Prior to this commit, final_callbacks_ is cleared on exit of ANY backward. With reentrant backward, a nested backward therefore removes all callbacks from the engine when it exits, which can lead to unexpected behavior. For example, an application could install a final callback after forward and expect it to fire when all gradients are ready. If there is a reentrant backward on a subgraph, that backward fires the callback and deletes it on exit, so the callback runs before all gradients are ready.

The first attempt tried to move the callbacks into the GraphTask in engine::execute(). This failed because more callbacks can be installed during the backward pass.

Final callbacks are stored as part of the context of each reentrant backward. This PR converts final_callbacks_ into a map from reentrant depth to the callbacks installed during the lifetime of the corresponding backward. More specifically:

* Insertion: queue_callback() uses the thread_local total_depth to find the vector to insert into.
* Deletion: the ClearCallbacks guard remembers the depth of the current backward call and erases only the corresponding vector from the map on exit.
* Execution: at the end of a backward with depth x, the engine runs all callbacks installed for depths smaller than or equal to x. This tries to stay consistent with the previous version, which had a single global final_callbacks_ vector that was executed and then cleared by the last reentrant backward.

Note that:

(1) A map is used instead of a vector to avoid creating an empty vector for every depth, since callbacks might be installed in only a small subset of the reentrant backward calls.

(2) The execution order should still match the installation order: the engine does not resume work on depth x before finishing depth x + 1, and all callbacks previously installed at depth x run before callbacks installed at depth x + 1.
ghstack-source-id: 0858bbd
Pull Request resolved: #35066
ghstack-source-id: 34c4715
Pull Request resolved: #35066
Closes #24965

Prior to this commit, final_callbacks_ is cleared on exit of ANY backward. With reentrant backward, a nested backward therefore removes all callbacks from the engine when it exits, which can lead to unexpected behavior. For example, an application could install a final callback after forward and expect it to fire when all gradients are ready. If there is a reentrant backward on a subgraph, that backward fires the callback and deletes it on exit, so the callback runs before all gradients are ready.

### Failed Attempt

The first attempt tried to move the callbacks into the GraphTask in engine::execute(). This failed because more callbacks can be installed during the backward pass.

### Solution in This PR

Final callbacks are stored as part of the context of each reentrant backward. This PR converts final_callbacks_ into a map from reentrant depth to the callbacks installed during the lifetime of the corresponding backward. More specifically:

* Insertion: queue_callback() uses the thread_local total_depth to find the vector to insert into.
* Deletion: the ClearCallbacks guard remembers the depth of the current backward call and erases only the corresponding vector from the map on exit.
* Execution: at the end of a backward with depth x, the engine executes only the callbacks installed during that backward, i.e. those stored for depth x.

Note that a map is used instead of a vector to avoid creating an empty vector for every depth, since callbacks might be installed in only a small subset of the reentrant backward calls.

Differential Revision: [D20546474](https://our.internmc.facebook.com/intern/diff/D20546474)
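The depth-keyed bookkeeping described above can be sketched as follows. This is a minimal toy model, not the engine code: `ToyEngine`, its `backward(body)` helper, the convention that the outermost backward has depth 0, and `run_demo()` are all assumptions made for illustration. The erase-on-exit step is written inline here, whereas the description uses a ClearCallbacks guard.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <map>
#include <string>
#include <utility>
#include <vector>

using Callback = std::function<void()>;

// Hypothetical stand-in for the engine state; not the real autograd engine.
struct ToyEngine {
  // Reentrant depth -> callbacks installed during the backward at that depth.
  std::map<int, std::vector<Callback>> final_callbacks;
  // -1 means no backward is running. (thread_local in the description;
  // a plain member in this single-threaded sketch.)
  int total_depth = -1;

  // Insertion: append to the vector of the backward that is currently running.
  void queue_callback(Callback cb) {
    final_callbacks[total_depth].push_back(std::move(cb));
  }

  // Runs `body` as one backward pass. `body` may queue callbacks or call
  // backward() again, which models a reentrant backward.
  void backward(const std::function<void()>& body) {
    const int depth = ++total_depth;
    body();
    // Execution: run only the callbacks stored for this depth.
    auto it = final_callbacks.find(depth);
    if (it != final_callbacks.end()) {
      // Index loop plus copy: a callback may queue further callbacks,
      // which can reallocate the vector while we iterate.
      for (std::size_t i = 0; i < it->second.size(); ++i) {
        Callback cb = it->second[i];
        cb();
      }
      // Deletion: erase only this depth's vector.
      final_callbacks.erase(it);
    }
    --total_depth;
  }
};

// Hypothetical demo: an outer backward containing one reentrant backward.
std::vector<std::string> run_demo() {
  ToyEngine engine;
  std::vector<std::string> log;
  engine.backward([&] {
    engine.queue_callback([&] { log.push_back("outer callback"); });
    engine.backward([&] {
      engine.queue_callback([&] { log.push_back("inner callback"); });
    });
    log.push_back("outer gradients ready");
  });
  return log;
}
```

In this sketch the callback queued by the outer backward survives the nested call and runs only after the outer pass has finished, which is the behavior the issue asks for.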
Closes #24965

Prior to this commit, final_callbacks_ is cleared on exit of ANY backward. With reentrant backward, a nested backward therefore removes all callbacks from the engine when it exits, which can lead to unexpected behavior. For example, an application could install a final callback after forward and expect it to fire when all gradients are ready. If there is a reentrant backward on a subgraph, that backward fires the callback and deletes it on exit, so the callback runs before all gradients are ready.

**Failed Attempt**

The first attempt tried to move the callbacks into the GraphTask in engine::execute(). This failed because more callbacks can be installed during the backward pass.

**Current Solution**

Final callbacks are stored as a member variable in the GraphTask.

* Insertion: use the thread_local current_graph_task to find the target GraphTask, and append the final callback to it.
* Deletion: final callbacks have the same lifetime as their GraphTask.
* Execution: use the GraphTask passed in as an argument to find the final callbacks.

Differential Revision: [D20546474](https://our.internmc.facebook.com/intern/diff/D20546474)
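A minimal sketch of the GraphTask-owned design, under stated assumptions: `ToyGraphTask`, the free functions `queue_callback()` and `backward(body)`, and `run_demo()` are hypothetical names for illustration, and the real GraphTask carries far more state than a callback vector.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <memory>
#include <string>
#include <utility>
#include <vector>

using Callback = std::function<void()>;

// Hypothetical, stripped-down graph task: only the callback storage.
struct ToyGraphTask {
  std::vector<Callback> final_callbacks;
};

// Stand-in for the thread_local current_graph_task named in the description.
thread_local std::shared_ptr<ToyGraphTask> current_graph_task;

// Insertion: append to whichever task is current on this thread.
// Assumes it is only called while a backward is running.
void queue_callback(Callback cb) {
  current_graph_task->final_callbacks.push_back(std::move(cb));
}

// Runs `body` as one backward pass with its own task. `body` may call
// backward() again, which models a reentrant backward.
void backward(const std::function<void()>& body) {
  auto graph_task = std::make_shared<ToyGraphTask>();
  // Make this task current and remember the previous one.
  auto previous = std::exchange(current_graph_task, graph_task);
  body();
  // Execution: look the callbacks up through the task itself.
  for (std::size_t i = 0; i < graph_task->final_callbacks.size(); ++i) {
    Callback cb = graph_task->final_callbacks[i];  // copy: vector may grow
    cb();
  }
  current_graph_task = std::move(previous);
  // Deletion: the callbacks are destroyed together with graph_task.
}

// Hypothetical demo: an outer backward containing one reentrant backward.
std::vector<std::string> run_demo() {
  std::vector<std::string> log;
  backward([&] {
    queue_callback([&] { log.push_back("outer callback"); });
    backward([&] {
      queue_callback([&] { log.push_back("inner callback"); });
    });
    log.push_back("outer gradients ready");
  });
  return log;
}
```

Because each task owns its callbacks, no explicit clearing step is needed: a nested backward cannot see, run, or delete callbacks that belong to the enclosing task.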
torch/csrc/autograd/engine.cpp

    // The guard that sets and restores current_graph_task.
    struct GraphTaskGuard {
      GraphTaskGuard(const std::shared_ptr<GraphTask>& graph_task) {
You're taking ownership of GraphTask, so you should accept graph_task as shared_ptr by value (not const ref)
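As an illustration of this suggestion, here is a sketch of a guard that accepts the `shared_ptr` by value and moves it into place. The empty `GraphTask` placeholder, the `previous_` member, and the deleted copy operations are assumptions for the example; the snippet under review shows only the constructor signature, so the rest of the real guard may differ.

```cpp
#include <cassert>
#include <memory>
#include <utility>

// Placeholder type for illustration; the real GraphTask is much larger.
struct GraphTask {};

// Stand-in for the engine's thread_local current_graph_task.
thread_local std::shared_ptr<GraphTask> current_graph_task;

// Sets current_graph_task on construction and restores the old value on exit.
struct GraphTaskGuard {
  // By value: the guard stores the pointer, so the caller decides whether to
  // copy (pass an lvalue) or hand over ownership (pass std::move(ptr)).
  explicit GraphTaskGuard(std::shared_ptr<GraphTask> graph_task)
      : previous_(std::exchange(current_graph_task, std::move(graph_task))) {}

  ~GraphTaskGuard() { current_graph_task = std::move(previous_); }

  // A guard restores state exactly once, so copying is disabled.
  GraphTaskGuard(const GraphTaskGuard&) = delete;
  GraphTaskGuard& operator=(const GraphTaskGuard&) = delete;

 private:
  std::shared_ptr<GraphTask> previous_;  // value to restore on destruction
};
```

With a by-value parameter, a caller that no longer needs its own handle can write `GraphTaskGuard guard(std::move(task));` and avoid a reference-count increment, while a caller that passes an lvalue pays for exactly one copy.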
Looks good. Thanks for making all the changes :)