Enable PP and EP overlap for MoE #1721
base: main
Conversation
Running with: CUDA_LAUNCH_BLOCKING
Just landed pytorch/pytorch#162016, so once CI picks up the nightly the errors should be fixed.
Looks very cool! Left some comments and questions.
Also looking forward to benchmarking results with overlapping enabled vs. disabled. In particular, for the 16B model, we should be able to test it out on 8 GPUs, assuming SAC is composable.
[activation_checkpoint]
- mode = "selective" # ["none", "selective", "full"]
+ mode = "none" # ["none", "selective", "full"]
does it not support SAC?
mscale=0.70,
- use_flex_attn=True,
- attn_mask_type="block_causal",
+ use_flex_attn=False,
Is FlexAttention not supported? It sounds unrelated.
return stages, models

# TODO: is there a better place to put this?
How about putting them into distributed/dual_pipe_v.py?
def run_backward():
    # Set the backward thread to use the same stream as forward
    torch.cuda.set_stream(main_cuda_stream)
Similar to the above -- can we change these to device-neutral calls?
def run_backward():
    # Set the backward thread to use the same stream as forward
    torch.cuda.set_stream(main_cuda_stream)
    with record_function(
always enabling this may hurt perf?
if _hook_coordinator._coordination_enabled and hook_name == "D":
    _hook_coordinator._cycle_count += 1
    # print(f"[FORWARD] cycle count: {_hook_coordinator._cycle_count}", "=" * 40)
    if not _hook_coordinator.check_should_continue_coordination():
This check is only called in SyncHook.forward. Is it safe if, for a particular overlap_f_b call, the backward stage has more layers than the forward stage?
    backward_mb_index,
)

def run_forward():
For my education: the run_forward() and run_backward() functions look general and not tied to DualPipe. Do we not have such functions in the PyTorch pipelining code?
    full_backward=True,
    last_backward=last_backward,
)
grad_scale_factor = schedule._n_microbatches if schedule.scale_grads else 1
This may not work well with gradient accumulation. See what we did in #1732
_hook_coordinator.disable_coordination()
return x

_hook_coordinator.barrier()
Strictly speaking, the barrier only has an effect on the CPU threads: it forces the compute and the a2a to be dispatched to the GPU at the same time, but from the GPU's perspective it doesn't guarantee that the compute kernels and the a2a actually overlap.
It may work in cases where there happen to be GPU-CPU syncs in the right places in the MoE layer (e.g. the token-index H2D copy), but I suspect it would fail to overlap once we remove those syncs (the community is working toward more efficient no-sync MoE implementations).
Theoretically we should use CUDA event waits between the compute/comm streams, not a thread wait.
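For reference, a minimal sketch of what event-based ordering between a compute stream and a comm stream could look like. The stream/event names, the dispatch_tokens wrapper, and the a2a placement are illustrative assumptions, not code from this PR:

```python
import torch
import torch.distributed as dist

compute_stream = torch.cuda.current_stream()
comm_stream = torch.cuda.Stream()


def dispatch_tokens(tokens: torch.Tensor) -> torch.Tensor:
    # Order the a2a after the producing compute kernels without blocking the CPU thread.
    ready = torch.cuda.Event()
    ready.record(compute_stream)
    comm_stream.wait_event(ready)

    with torch.cuda.stream(comm_stream):
        out = torch.empty_like(tokens)
        dist.all_to_all_single(out, tokens)  # runs on comm_stream
        done = torch.cuda.Event()
        done.record(comm_stream)

    # Consumers on the compute stream wait on the event (not on a CPU barrier),
    # so independent compute can keep running while the a2a is in flight.
    compute_stream.wait_event(done)
    return out
```

The key difference from the thread barrier is that the ordering constraint lives on the GPU timeline, so overlap does not depend on where GPU-CPU syncs happen to occur.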
Fixed one issue with FSDP last reshard not being called. The rest is mostly refactoring, changing some variables to be class variables so they can be used in pytorch/torchtitan#1721. Pull Request resolved: #165513. Approved by: https://github.com/fegin
]

import fbvscode
This needs to be removed
Option 2 of #1682.

These changes add a custom overlap_callback function to replace the OVERLAP_F_B action that is run during schedule execution. In the custom function, we write run_forward() and run_backward(); run_backward() runs on a separate thread so that forward and backward execute side by side. It looks like this: [figure omitted]. A minimal sketch of the callback idea is included below.

For these changes to work with Expert Parallel, we also need custom autograd functions that act as the boundary points at which we do communication. We added hooks before and after expert-parallel dispatch and combine to signal these boundary points, so the figure from before now turns into: [figure omitted].
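As a rough, hypothetical sketch of the callback idea: the names overlap_callback, run_forward, and run_backward come from this PR, but the signature (two callables) and the bodies here are simplifying assumptions, not the actual implementation.

```python
import threading

import torch


def overlap_callback(fwd_step, bwd_step):
    """Sketch only: run one forward and one backward microbatch side by side.

    fwd_step/bwd_step are assumed to be callables that execute one forward /
    backward chunk; the real callback receives schedule actions instead.
    """
    main_cuda_stream = torch.cuda.current_stream()

    def run_forward():
        # Forward stays on the main thread and the main stream.
        fwd_step()

    def run_backward():
        # Point the backward thread at the same stream as forward so that
        # kernels from both directions are issued onto one GPU timeline.
        torch.cuda.set_stream(main_cuda_stream)
        bwd_step()

    bwd_thread = threading.Thread(target=run_backward)
    bwd_thread.start()
    run_forward()
    bwd_thread.join()
```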
Now, in each of these red blocks, we use a global coordinator: a threading.Barrier(2).wait() ensures that the comm and compute from the forward and backward steps are scheduled in lock-step before continuing. A minimal sketch of such a boundary hook follows.
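The names SyncHook and _hook_coordinator appear in the diffs above, but the class bodies below are simplified assumptions for illustration, not the actual implementation:

```python
import threading

import torch


class HookCoordinator:
    # Simplified stand-in: one barrier shared by the forward and backward threads.
    def __init__(self):
        self._barrier = threading.Barrier(2)
        self._coordination_enabled = True

    def barrier(self):
        if self._coordination_enabled:
            self._barrier.wait()


_hook_coordinator = HookCoordinator()


class SyncHook(torch.autograd.Function):
    """Identity op marking an EP dispatch/combine boundary in forward and backward."""

    @staticmethod
    def forward(ctx, x, hook_name):
        ctx.hook_name = hook_name
        _hook_coordinator.barrier()  # rendezvous with the backward thread
        return x

    @staticmethod
    def backward(ctx, grad_out):
        _hook_coordinator.barrier()  # rendezvous with the forward thread
        return grad_out, None
```

The hook would be inserted around expert-parallel dispatch and combine (e.g. `tokens = SyncHook.apply(tokens, "A")`), so that a compute phase on one thread is dispatched in lock-step with the all-to-all on the other.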
DSv3 16B run command:
Trace examples: