
torch.autograd.backward() fails to sync with other stream #47028

Open
jeffdaily opened this issue Oct 28, 2020 · 7 comments
Assignees
Labels
module: autograd Related to torch.autograd, and the autograd engine in general triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

@jeffdaily
Collaborator

jeffdaily commented Oct 28, 2020

The new test written for #45787 suggests a possible failure scenario, which indeed occurs. It is a race condition, most often encountered by ROCm CI.

https://github.com/pytorch/pytorch/pull/45787/files#diff-893b1eea27352f336f4cd832919e48d721e4e90186e63400b8596db6b82e7450R1772-R1773

Putting torch.cuda.synchronize() right after backward() fixes the problem.

Originally posted by @mcarilli in #45787 (comment)

cc @ezyang @albanD @zou3519 @gqchen @pearu @nikitaved
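
For reference, a minimal sketch of the failing pattern and the workaround described above; the tensor shapes and stream setup here are illustrative, not the exact test from #45787:

import torch

# Run forward/backward on a non-default stream, then read the gradients.
# Without a sync, the gradient read below can race with the backward kernels.
side_stream = torch.cuda.Stream()

x = torch.randn(1024, 1024, device="cuda", requires_grad=True)
side_stream.wait_stream(torch.cuda.current_stream())  # make x's allocation visible
with torch.cuda.stream(side_stream):
    loss = (x * 2).sum()
    loss.backward()

# Workaround from the report: a full device sync right after backward()
# makes the subsequent gradient check reliable.
torch.cuda.synchronize()
print(x.grad.abs().sum())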

@albanD
Collaborator

albanD commented Oct 29, 2020

So the proper fix is just to add the synchronize in the test, right?

@albanD albanD added module: autograd Related to torch.autograd, and the autograd engine in general triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module labels Oct 29, 2020
@jeffdaily
Collaborator Author

@albanD, I'd like @mcarilli to weigh in here. It wasn't clear to me if they had something else in mind to fix this.

@mcarilli
Collaborator

mcarilli commented Nov 9, 2020

I'm fairly sure backward() always syncs with the default stream when it's finished. However, if you run backward in a non-default stream context, I'm not sure if it also syncs with the ambient non-default stream instead of/in addition to the default stream. I'd have to look at engine.cpp again.
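
To make the open question concrete, a sketch of the two possible contracts; the stream names are illustrative:

import torch

ambient = torch.cuda.Stream()  # non-default stream the user runs backward under

x = torch.randn(512, device="cuda", requires_grad=True)
ambient.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(ambient):
    (x ** 2).sum().backward()

# Believed contract: backward() syncs its leaf streams with the DEFAULT stream,
# so waiting on the default stream is enough once backward() returns.
torch.cuda.default_stream().synchronize()

# Open question raised here: does the engine also sync with the ambient
# non-default stream, or must the caller do that explicitly?
ambient.synchronize()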

@jeffdaily
Collaborator Author

We're still seeing this issue with our ROCm CI.
https://ci.pytorch.org/jenkins/job/pytorch-builds/job/pytorch-linux-bionic-rocm3.9-py3.6-test1/1163//console

Do we put a sync into the test to work around the issue, or fix the engine?

@jeffdaily
Collaborator Author

I would like to work towards a resolution on this issue. The test continues to be flaky on ROCm, and I'd rather fix it than skip it.

Is this where the fix should be?

// Syncs leaf streams with default streams (if necessary)
// See note "Streaming backwards"
for (const auto& leaf_stream : leaf_streams) {
  const auto guard = c10::impl::VirtualGuardImpl{c10::DeviceType::CUDA};
  const auto default_stream = guard.getDefaultStream(leaf_stream.device());
  if (leaf_stream != default_stream) {
    auto event = c10::Event{c10::DeviceType::CUDA};
    event.record(leaf_stream);
    default_stream.wait(event);
  }
}
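
For readers more familiar with the Python stream API, the engine snippet above amounts to the following event-based handoff; this is a sketch of the pattern, not the engine's actual code path:

import torch

# Make the default stream wait until work queued on a leaf stream has finished,
# mirroring the c10::Event record/wait pair above.
leaf_stream = torch.cuda.Stream()
default_stream = torch.cuda.default_stream()

if leaf_stream != default_stream:
    done = torch.cuda.Event()
    done.record(leaf_stream)         # mark the leaf stream's current position
    default_stream.wait_event(done)  # default stream will not run past this point
                                     # until the recorded work has completed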

jeffdaily added a commit to ROCm/pytorch that referenced this issue Nov 23, 2020
Otherwise, this test will appear flaky for ROCm even though it is a
generic PyTorch issue.
facebook-github-bot pushed a commit that referenced this issue Nov 27, 2020
Summary:
Otherwise, this test will appear flaky for ROCm even though it is a generic PyTorch issue.

CC albanD

Pull Request resolved: #48405

Reviewed By: mrshenli

Differential Revision: D25183473

Pulled By: ngimel

fbshipit-source-id: 0fa19b5497a713cc6c5d251598e57cc7068604be
@ezyang
Contributor

ezyang commented Dec 2, 2020

Sorry about the delay, this is a pretty tricky issue and I had to spend some time reading code.

The relevant comment for the existing stream synchronization logic in the autograd engine is:

// Note [Streaming backwards]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~
// On CUDA devices the autograd engine's device operations are run on the
// same stream that ran them in forward. This requires automatically
// syncing the streams so that function A finishes producing its
// output before function B consumes it.
//
// This synchronization occurs when outputs are placed into input buffers.
// The functions corresponding to input buffer positions have metadata
// recording their streams from forward, and during backward this
// data is used to sync the producer's stream with the consumer's.
//
// When a CUDA function is run either all its inputs were accumulated on the
// stream used to run the function OR the inputs are on different devices
// and the function is responsible for properly acquiring them.
//
// Historically, the autograd engine ran all CUDA operations on their
// device's DEFAULT stream. This meant that syncing (implicitly or
// explicitly) with the default streams was required before and after
// calling backward(). It also meant, however, that syncing with
// the default streams after backward() was sufficient to ensure
// that backward() had finished running. To preserve this historic
// behavior the engine records "leaf streams," the streams of the
// leaf variables, and syncs them with their device's default stream
// at the end of backward. All other streams are already synchronized
// to happen before at least one leaf stream (per the above), so syncing
// the leaf streams with the default streams is sufficient to implement
// the historic behavior.

The leaf streams in this situation, however, are the streams associated with the gradient accumulators, which should all be associated with fwd_bwd_op_stream (when the leaves were used). So based on this analysis, the test is wrong, and we should do a sync with fwd_bwd_op_stream before checking the outputs, as the engine is not contractually obligated to do this sync.

@mcarilli would you agree with this analysis?
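
For concreteness, a sketch of the test-side fix this analysis implies, reusing the fwd_bwd_op_stream name from the #45787 test with an illustrative computation; the key line is the explicit sync before the gradient check:

import torch

fwd_bwd_op_stream = torch.cuda.Stream()

x = torch.randn(4, 4, device="cuda", requires_grad=True)
fwd_bwd_op_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(fwd_bwd_op_stream):
    (x * 3).sum().backward()

# Per the analysis above, the engine is not contractually obligated to sync
# with the stream that ran forward/backward, so the test syncs with it
# explicitly before checking the outputs.
fwd_bwd_op_stream.synchronize()
assert torch.allclose(x.grad, torch.full_like(x, 3.0))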

@mcarilli
Collaborator

mcarilli commented Jun 11, 2021

Looking over old issues I forgot I was assigned, I saw this one and realized #57833 probably fixes it. One of the patterns #57833 wants to fix, the snippet under "Because of the inconsistency, in some cases it's hard to be safe:" in the original submission #54227, looks like it matches the case that was breaking in ROCm CI (https://github.com/pytorch/pytorch/pull/45787/files#diff-893b1eea27352f336f4cd832919e48d721e4e90186e63400b8596db6b82e7450R1772-R1773).
