
Allow consumer ops to sync on GraphRoot's gradient #45787


Conversation

@mcarilli (Collaborator) commented Oct 2, 2020

Currently, a GraphRoot instance doesn't have an associated stream. The streaming-backward synchronization logic assumes the instance ran on the default stream, so it tells consumer ops to sync with the default stream. If the gradient that GraphRoot passes to its consumer backward ops was populated on a non-default stream, we have a race condition.

The race condition can exist even if the user doesn't give a manually populated gradient:

with torch.cuda.stream(side_stream):
    # loss.backward() implicitly synthesizes a one-element 1.0 tensor on side_stream.
    # GraphRoot passes it to consumers, but the consumers first sync on the default stream, not side_stream.
    loss.backward()

    # Inside backward(), the streaming-backward logic takes over: each op executes on the same stream
    # it ran on in forward, so the side_stream context is irrelevant there.  GraphRoot's interaction with
    # its first consumer(s) is the one spot where the side_stream context causes a problem.

This PR fixes the race condition by associating a GraphRoot instance, at construction time, with the current stream(s) on the device(s) of the grads it will pass to consumers. (I think this relies on GraphRoot executing in the main thread, before the backward thread(s) fork, because the grads were populated on the main thread.)
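
Conceptually, the fix is "remember which stream produced the root gradient, and make its first consumer wait on that stream before reading it." The sketch below is a rough Python-level analogy of that pattern, not the actual change (which lives in the C++ autograd engine); the producer/consumer names and shapes are purely illustrative, and it assumes a CUDA device is available:

import torch

producer = torch.cuda.Stream()          # stream that populates the root grad
consumer = torch.cuda.current_stream()  # stream of the first consumer op

with torch.cuda.stream(producer):
    # grad populated off the default stream, analogous to GraphRoot's incoming gradient
    root_grad = torch.ones(1 << 20, device="cuda")

# What the patched GraphRoot effectively records at construction time: the producing
# stream, so consumers can sync on it instead of assuming the default stream.
consumer.wait_stream(producer)

result = root_grad * 2  # consumer kernels are now ordered after the fill on `producer`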

The test demonstrates the race condition: it fails reliably without the PR's GraphRoot diffs and passes with them.
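
For reference, a stripped-down sketch of the kind of repro the test exercises might look like the following (the real test is test_streaming_backward_sync_graph_root in test/test_cuda.py; the tensor sizes and variable names here are illustrative, and the final synchronize is a deliberately conservative guard before checking the result):

import torch

side_stream = torch.cuda.Stream()

a = torch.randn(1 << 22, device="cuda", requires_grad=True)
b = torch.randn(1 << 22, device="cuda")
loss = a * b  # forward runs on the default stream

side_stream.wait_stream(torch.cuda.current_stream())  # order side_stream after the forward ops
with torch.cuda.stream(side_stream):
    grad = torch.ones_like(loss)   # incoming gradient populated on side_stream
    loss.backward(gradient=grad)   # without the fix, consumers sync on the default stream, racing with the fill

torch.cuda.synchronize()           # conservative: settle all streams before checking
assert torch.allclose(a.grad, grad * b)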

With the GraphRoot diffs, manually populating an incoming-gradient arg and the actual call to backward (or torch.autograd.grad) have the same stream-semantics relationship as any other pair of ops:

# implicit population is safe
with torch.cuda.stream(side_stream):
    loss.backward()

# explicit population in side stream then backward in side stream is safe
with torch.cuda.stream(side_stream):
    kickoff_grad = torch.ones_like(loss)
    loss.backward(gradient=kickoff_grad)

# explicit population in one stream, then backward kickoff in another stream,
# is NOT safe, even with this PR's diffs, but that unsafety is consistent with
# the stream-semantics relationship of any pair of ops
kickoff_grad = torch.ones_like(loss)
with torch.cuda.stream(side_stream):
    loss.backward(gradient=kickoff_grad)

# Safe, as you'd expect for any pair of ops
kickoff_grad = torch.ones_like(loss)
side_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(side_stream):
    loss.backward(gradient=kickoff_grad)

This PR also adds the last three examples above to cuda docs and references them from autograd docstrings.

@mruberry removed the request for review from apaszke on October 4, 2020 22:34.
@albanD (Collaborator) left a comment

Very good catch! Thanks for the fix and the doc update!

@codecov bot commented Oct 6, 2020

Codecov Report

Merging #45787 into master will increase coverage by 0.00%. The diff coverage is n/a.

@@           Coverage Diff           @@
##           master   #45787   +/-   ##
=======================================
  Coverage   68.32%   68.32%
=======================================
  Files         410      410
  Lines       52978    52978
=======================================
+ Hits        36195    36196    +1
+ Misses      16783    16782    -1

Impacted files (coverage Δ):
  torch/autograd/__init__.py             84.28% (no change)
  torch/tensor.py                        88.86% (no change)
  torch/testing/_internal/expecttest.py  78.57% (+1.02%)

Last update e1ff46b...9792303.

@albanD added the triaged label ("This issue has been looked at by a team member, and triaged and prioritized into an appropriate module") on Oct 6, 2020.
@albanD (Collaborator) left a comment

LGTM thanks for the update!

@facebook-github-bot (Contributor) left a comment

@albanD has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot (Contributor) commented
@albanD merged this pull request in 5640b79.

@jeffdaily (Collaborator) commented Oct 27, 2020

@albanD and @mcarilli, if you could lend your insight. The new test_streaming_backward_sync_graph_root is flaky on our ROCm CI. Using a host similar to our CI hosts, I can get it to somewhat reliably fail (approximately 1 out of 5 times) if I run the test in a loop while also running the stress command line tool.

I'd like to know if this indicates a problem in our runtime, or a problem in pytorch. This is the line that reliably fails:

self.assertEqual(a.grad, grad * b)

The assertEqual fails because it eventually runs torch.allclose() on a.grad and grad * b. a.grad is not None (I verified with an assert), so it seems a.grad hasn't been fully written by the time grad * b completes and the comparison kernels start. It looks like a cross-stream sync is missing somewhere.

@mcarilli (Collaborator, Author) commented Oct 27, 2020

@jeffdaily the test might be exposing https://github.com/pytorch/pytorch/pull/45787/files#diff-893b1eea27352f336f4cd832919e48d721e4e90186e63400b8596db6b82e7450R1772-R1773, which would be a separate issue with a separate fix.

Can you put a torch.cuda.synchronize() right after backward() and see if that helps?
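
For illustration, the suggested workaround would slot in roughly like this (the test structure is paraphrased from the discussion above, not copied from test/test_cuda.py, so the variable names are assumptions):

with torch.cuda.stream(side_stream):
    loss.backward(gradient=grad)

torch.cuda.synchronize()            # proposed workaround: wait for all streams before comparing
self.assertEqual(a.grad, grad * b)  # the comparison reported as flaky on ROCm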

@jeffdaily (Collaborator) commented
@mcarilli That seems to help. How should we proceed? File a new issue? Do you have an idea of how to fix this already?

@jeffdaily (Collaborator) commented
Created new issue. #47028
