-
-
Notifications
You must be signed in to change notification settings - Fork 10.6k
[Core][Distributed] remove graph mode function #4818
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
6 commits
Select commit
Hold shift + click to select a range
f7d4195
remove graph mode function
youkaichao 46a60e5
change variable name
youkaichao 1a0721c
set stream
youkaichao 435ad07
Merge branch 'main' into graph
youkaichao b1767f2
remove old SIM117 comment
youkaichao 35b3351
add type annotation for stream and pool
youkaichao File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,6 @@ | ||
from collections import namedtuple | ||
from contextlib import contextmanager, nullcontext | ||
from dataclasses import dataclass | ||
from typing import Any, Dict, List, Optional, Tuple, Union | ||
|
||
import torch | ||
|
@@ -13,45 +14,54 @@ | |
get_tp_pynccl_communicator) | ||
|
||
|
||
@contextmanager | ||
def graph_mode(): | ||
# In graph mode, we have to be very careful about the collective | ||
# operations. The current status is: | ||
# allreduce \ Mode | Eager | Graph | | ||
# -------------------------------------------- | ||
# custom allreduce | enabled | enabled | | ||
# PyNccl | disabled| enabled | | ||
# torch.distributed | enabled | disabled| | ||
# | ||
# Note that custom allreduce will have a runtime check, if the tensor size | ||
# is too large, it will fallback to the next available option. | ||
# In summary: When using CUDA graph, we use | ||
# either custom all-reduce kernel or pynccl. When not using CUDA | ||
# graph, we use either custom all-reduce kernel or PyTorch NCCL. | ||
# We always prioritize using custom all-reduce kernel but fall back | ||
# to PyTorch or pynccl if it is disabled or not supported. | ||
pynccl_comm = get_tp_pynccl_communicator() | ||
if pynccl_comm is None: | ||
context = nullcontext() | ||
else: | ||
context = pynccl_comm.change_state(enable=True, | ||
stream=torch.cuda.current_stream()) | ||
with context: | ||
yield | ||
@dataclass | ||
class GraphCaptureContext: | ||
stream: torch.cuda.Stream | ||
Comment on lines
+17
to
+19
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. How does this work for non-CUDA backends? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For XPU, this will be |
||
|
||
|
||
@contextmanager | ||
def graph_capture(): | ||
""" | ||
`graph_capture` is a context manager which should include the code that | ||
`graph_capture` is a context manager which should surround the code that | ||
is capturing the CUDA graph. Its main purpose is to ensure that the | ||
some operations will be run after the graph is captured, before the graph | ||
is replayed. | ||
is replayed. It returns a `GraphCaptureContext` object which contains the | ||
necessary data for the graph capture. Currently, it only contains the | ||
stream that the graph capture is running on. This stream is set to the | ||
current CUDA stream when the context manager is entered and reset to the | ||
default stream when the context manager is exited. This is to ensure that | ||
the graph capture is running on a separate stream from the default stream, | ||
in order to explicitly distinguish the kernels to capture | ||
from other kernels possibly launched on background in the default stream. | ||
""" | ||
stream = torch.cuda.Stream() | ||
graph_capture_context = GraphCaptureContext(stream) | ||
ca_comm = get_tp_ca_communicator() | ||
context = nullcontext() if ca_comm is None else ca_comm.capture() | ||
with context: | ||
yield | ||
maybe_ca_context = nullcontext() if ca_comm is None else ca_comm.capture() | ||
with torch.cuda.stream(stream), maybe_ca_context: | ||
# In graph mode, we have to be very careful about the collective | ||
# operations. The current status is: | ||
# allreduce \ Mode | Eager | Graph | | ||
# -------------------------------------------- | ||
# custom allreduce | enabled | enabled | | ||
# PyNccl | disabled| enabled | | ||
# torch.distributed | enabled | disabled| | ||
# | ||
# Note that custom allreduce will have a runtime check, if the tensor | ||
# size is too large, it will fallback to the next available option. | ||
# In summary: When using CUDA graph, we use | ||
# either custom all-reduce kernel or pynccl. When not using CUDA | ||
# graph, we use either custom all-reduce kernel or PyTorch NCCL. | ||
# We always prioritize using custom all-reduce kernel but fall back | ||
# to PyTorch or pynccl if it is disabled or not supported. | ||
pynccl_comm = get_tp_pynccl_communicator() | ||
if pynccl_comm is None: | ||
maybe_pynccl_context = nullcontext() | ||
else: | ||
maybe_pynccl_context = pynccl_comm.change_state( | ||
enable=True, stream=torch.cuda.current_stream()) | ||
with maybe_pynccl_context: | ||
yield graph_capture_context | ||
|
||
|
||
def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor: | ||
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Here, how do we make sure it's not using custom all reduce?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actually, even before this PR, we cannot make sure it's not using custom all reduce. It is true in CI because our CI does not have custom allreduce.
To solve this problem, another refactor is needed. We need to expose a new function to create tp groups with different communicators. That's my next PR to come!