Makes AccumulateGrad high priority in backwards passes #7604
Conversation
@teng-li This should also fix the bad perf you've been seeing while scaling gloo.
I've attached a minimal working sample that exposes the symptoms and hopefully illustrates our point.
The failure on linux-trusty-py3.6-gcc7.2 seems unrelated:

```
00:58:43 gcc -pthread -B /opt/conda/compiler_compat -Wl,--sysroot=/ -Wsign-compare -DNDEBUG -g -fwrapv -O3 -Wall -fPIC -I/var/lib/jenkins/workspace -I/var/lib/jenkins/workspace/torch/csrc -I/var/lib/jenkins/workspace/third_party/pybind11/include -I/var/lib/jenkins/workspace/torch/lib/tmp_install/include -I/var/lib/jenkins/workspace/torch/lib/tmp_install/include/TH -I/var/lib/jenkins/workspace/torch/lib/tmp_install/include/THNN -I/var/lib/jenkins/workspace/torch/lib/tmp_install/include/ATen -I/opt/conda/lib/python3.6/site-packages/numpy/core/include -I/var/lib/jenkins/workspace/torch/lib/tmp_install/include/THD -I/opt/conda/include/python3.6m -c torch/csrc/jit/script/compiler.cpp -o build/temp.linux-x86_64-3.6/torch/csrc/jit/script/compiler.o -D_THP_CORE -std=c++11 -Wall -Wextra -Wno-unused-parameter -Wno-missing-field-initializers -Wno-write-strings -Wno-zero-length-array -Wno-unknown-pragmas -Wno-deprecated-declarations -fno-strict-aliasing -Wno-missing-braces -Werror -DWITH_NUMPY -DWITH_DISTRIBUTED
```

@mruberry nice debugging to figure out the backward hooks issue. Your reputation is well-calibrated!
Nice catch, but I think the implementation could be greatly simplified.
torch/csrc/autograd/function.h (outdated)

```diff
@@ -303,6 +316,7 @@ struct Function : std::enable_shared_from_this<Function> {
   // Since `Function`s are neither copyable nor moveable, we can have const
   // fields.
   const uint64_t sequence_nr_;
+  const int backwards_priority_;
```
The errors seem unrelated. On Windows: On linux-trusty-py3.6-gcc7.2:
thanks @mruberry!
* upstream/master:
  * Makes AccumulateGrad high priority in backwards passes (pytorch#7604)
  * [C++ API] Implement builder style construction (pytorch#7597)
  * C10D: Added TCPStore to support C10D store interface (pytorch#7560)
  * [auto] Update onnx to ba86ec2 - Protobuf typing (onnx/onnx#982) onnx/onnx@ba86ec2
  * Add LBFGS optimization algorithm to C++ API (pytorch#7596)
* Makes accumulate_grad functions high priority in backwards passes
* Delegating constructor and comments
* Sequence_nr ain't pretty no more
* Sequence_nr ain't pretty no more
This change is a little tricky, and is best understood with a simple example.
Assume you have a simple model like ab = a + b, cd = c + d, final = ab + cd, you compute some loss, and you call backward(). The backward pass evaluates the final summation, then the cd summation, then the ab summation, then the AccumulateGrad functions for c and d (accumulating the gradient into both c and d), and finally the AccumulateGrad functions for a and b (accumulating the gradient into both a and b). So far so good.
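For concreteness, here is a minimal PyTorch sketch of that model; the leaf names mirror the description above, and the sum() loss is just an assumption for illustration:

```python
import torch

# Leaf variables; backward() accumulates gradients into their .grad fields
# via AccumulateGrad nodes.
a = torch.randn(3, requires_grad=True)
b = torch.randn(3, requires_grad=True)
c = torch.randn(3, requires_grad=True)
d = torch.randn(3, requires_grad=True)

ab = a + b
cd = c + d
final = ab + cd

loss = final.sum()  # any scalar loss works; sum() is only for illustration
loss.backward()     # first pass: each AccumulateGrad runs close to the add that uses its leaf
```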
Now, whether or not you call zero_grad(), if you run the model again the order of the backward pass changes and the AccumulateGrad functions are called last.
This is because, when functions are created, they are assigned a (monotonically increasing, thread-local) sequence number, which is used to prioritize them in the backward pass. On the first pass through the graph, the AccumulateGrad functions have sequence numbers near (actually just higher than) those of the first functions where the variables are used. This means the AccumulateGrad functions are called as soon as possible on the first call to backward().
On the second pass, however, the functions other than AccumulateGrad are new functions, so they are all given higher sequence numbers than the AccumulateGrad functions (which are created once per leaf variable and reused across passes). This means the AccumulateGrad functions are all processed last, not as soon as possible.
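One way to observe this from Python is to register tensor hooks: a leaf's hook fires as its gradient is accumulated, and an intermediate's hook fires when the gradient with respect to it is computed. Below is a minimal sketch, restarting with fresh leaves so the first backward really is the first pass; the exact print order may vary by PyTorch version:

```python
import torch

events = []
def trace(name):
    return lambda grad: events.append(name)  # hooks may return None to leave the grad unchanged

a, b, c, d = (torch.randn(3, requires_grad=True) for _ in range(4))
for name, t in [("a", a), ("b", b), ("c", c), ("d", d)]:
    t.register_hook(trace(name))  # fires as each leaf's AccumulateGrad is reached

for label in ("first pass ", "second pass"):
    events.clear()
    ab, cd = a + b, c + d              # new backward nodes (new sequence numbers) every pass
    final = ab + cd
    for name, t in [("ab", ab), ("cd", cd), ("final", final)]:
        t.register_hook(trace(name))   # fires when the gradient w.r.t. each sum is computed
    final.sum().backward()
    print(label, events)               # on the second pass the leaves tend to appear at the end
```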
Processing all the AccumulateGrad functions at the end of the backward pass is unfortunate when trying to overlap communication with the backward pass, as @csarofeen and @mcarilli are doing. This PR adds a new "backwards_priority" to functions so that backward() processes AccumulateGrad functions as soon as possible. It should also make the order of AccumulateGrad calls consistent across passes through the same network.
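The overlap pattern that motivates this looks roughly like the sketch below; launch_async_allreduce is a hypothetical placeholder for whatever communication primitive (for example a Gloo allreduce) the training loop actually uses:

```python
import torch

def launch_async_allreduce(grad):
    # Hypothetical placeholder: in a real setup this would kick off an
    # asynchronous allreduce (e.g. via torch.distributed) and return a
    # handle that is waited on before the optimizer step.
    pass

model = torch.nn.Linear(10, 10)

# Start communication for each parameter as soon as its gradient is produced.
# The earlier AccumulateGrad runs during backward(), the more of this
# communication can overlap with the remaining backward computation.
for p in model.parameters():
    p.register_hook(launch_async_allreduce)
```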
Adding backwards_priority to functions was a little tricky because the existing constructor has a default first argument. To avoid changing the current constructor I added a new one, and the existing constructor now delegates to it to avoid code duplication. The implementation of backwards_priority is otherwise analogous to sequence_nr. Priorities during the backward pass are now determined by backwards_priority first and sequence_nr second.
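Conceptually, the engine's ready queue now orders nodes by backwards_priority first and sequence_nr second. A rough Python model of that ordering (not the actual C++ implementation; the numeric values are made up):

```python
import heapq

# Each entry is (backwards_priority, sequence_nr, name). heapq is a min-heap,
# so the keys are negated to pop the highest-priority, newest node first.
ready = []
def push(priority, seq_nr, name):
    heapq.heappush(ready, (-priority, -seq_nr, name))

push(0, 7, "AddBackward (pass 2)")
push(0, 8, "SumBackward (pass 2)")
push(1, 3, "AccumulateGrad(a)")   # high backwards_priority wins despite its old sequence number

while ready:
    _, _, name = heapq.heappop(ready)
    print(name)  # AccumulateGrad(a), then SumBackward, then AddBackward
```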