CUDAGraph support for SimpleFSDP and TP #2050
Conversation
    def copy_static_inputs(self, *args):
        for i in self.input_indices_to_copy:
            self.args[i].copy_(args[i])
We could replace this for loop with a foreach copy. However, I empirically observed that there is only 1 tensor to copy for forward and 1 tensor to copy for backward, so there is no need to add code complexity here.
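For reference, a minimal sketch of what the foreach variant could look like, assuming `torch._foreach_copy_` is available in the targeted PyTorch version (method and attribute names mirror the snippet above):

```python
import torch

def copy_static_inputs_foreach(self, *args):
    # Batch all per-step input copies into a single foreach kernel
    # instead of issuing one copy_ per tensor.
    dsts = [self.args[i] for i in self.input_indices_to_copy]
    srcs = [args[i] for i in self.input_indices_to_copy]
    if dsts:
        torch._foreach_copy_(dsts, srcs)
```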
            self.cudagraph, pool=self.graph_pool, stream=self.stream
        ):
            # `output` is managed by pytorch's cudagraph pool
            self.output = self.runnable(*args)
We could potentially use a weakref for the output tensor to reduce memory. Will do in a followup PR.
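One rough shape of that idea (hypothetical class and attribute names; replay would then have to rebuild the output from its recorded storage, as cudagraph trees does, rather than returning `self.output`):

```python
import weakref
import torch

class CUDAGraphWrapper:  # stand-in for the wrapper class in this PR
    def capture(self, *args):
        with torch.cuda.graph(self.cudagraph, pool=self.graph_pool, stream=self.stream):
            output = self.runnable(*args)
        # Hold only a weak reference instead of `self.output = output`, so the
        # wrapper does not pin the captured output alive and the pool memory
        # can be reused once the caller drops its reference.
        self._output_ref = weakref.ref(output)
        return output
```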
eellison left a comment:
Looks good! We should handle not persisting inputs and outputs for serious use. I would also add assertions for assumptions that would otherwise manifest as silent incorrectness, at least behind a config. Also, we probably shouldn't globally turn off expandable segments when cudagraphs are not enabled.
        input_addresses = [
            x.data_ptr() if isinstance(x, torch.Tensor) else None for x in args
        ]
I guess we're assuming that the non-tensor inputs are the same every time? Should we just assert they're all tensors if we're not handling the other cases?
IIUC, there would only be tensors and symints (for the MoE layer). Let me add an assertion.
There is also `rng_state: torch._C.Generator`, used by
graphsafe_run_with_rng_state_2 = torch.ops.higher_order.graphsafe_run_with_rng_state(torch.ops.aten._scaled_dot_product_flash_attention.default, transpose_20, transpose_21, transpose_22, 0.0, True, scale = 0.25, rng_state = fwd_rng_state_2);
See the last 3 args in P2047035404.
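A possible shape for that assertion, covering the input types mentioned in this thread (a sketch, not the PR's actual check):

```python
import torch

def assert_supported_input_types(args):
    # Inputs we know how to handle across replays: tensors (whose addresses are
    # recorded / copied), symbolic ints, plain ints, and graph-safe RNG states.
    allowed = (torch.Tensor, torch.SymInt, int, torch._C.Generator)
    for i, x in enumerate(args):
        assert isinstance(x, allowed), (
            f"unsupported cudagraph input type at index {i}: {type(x)}"
        )
```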
        self.copy_static_inputs(*args)
        self.cudagraph.replay()
        return self.output
Persisting the inputs and outputs is not good for memory, as you've noted in the comments.
Yes, will address this in the next PR.
        self.args = None
        self.output = None

    def copy_static_inputs(self, *args):
If any of the static inputs changes, you'll get silent incorrectness. You might consider at least a config to check for this.
Yes, added.
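As a sketch, such a debug check could compare the current addresses against the ones recorded at capture time (function and argument names here are hypothetical):

```python
import torch

def assert_static_input_addresses(args, recorded_addresses):
    # recorded_addresses[i] is the data_ptr captured for args[i] (None for non-tensors).
    # A mismatch means replay would silently read from a stale buffer.
    for i, (arg, addr) in enumerate(zip(args, recorded_addresses)):
        if isinstance(arg, torch.Tensor) and addr is not None:
            assert arg.data_ptr() == addr, (
                f"static input {i} moved: captured at {addr:#x}, now at {arg.data_ptr():#x}"
            )
```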
(
    _global_dummy_graph,
    _global_graph_pool,
    _global_graph_capture_stream,
) = init_global_graph_pool()
Does this work when backward is on a separate stream? Or is that not an issue?
IIUC, this is not an issue currently, since fwd and bwd are on the same CUDA stream by default.
CUDAGraph trees also uses the same graph capture stream for both fwd and bwd:
https://github.com/pytorch/pytorch/blob/7a928397cda89b71c24b0efe9db6df7fb04a46cb/torch/_inductor/cudagraph_trees.py#L1945
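For context, `init_global_graph_pool` presumably does something along these lines (a sketch: capture a trivial dummy graph so its memory pool stays alive, then share that pool and capture stream across all fwd/bwd graphs):

```python
import torch

def init_global_graph_pool():
    stream = torch.cuda.Stream()
    dummy_graph = torch.cuda.CUDAGraph()
    # Capturing an empty region creates the pool; keeping the dummy graph
    # alive keeps the pool alive so later captures can share it.
    with torch.cuda.graph(dummy_graph, stream=stream):
        pass
    return dummy_graph, dummy_graph.pool(), stream
```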
run_train.sh (Outdated)
# need to turn off expandable segments when using cudagraph, since
# it does not work with cg and nccl yet.
# https://github.com/pytorch/pytorch/issues/158029
Can we turn this off only when using cudagraph?
Currently it's on by default. When using cudagraph, we need to explicitly turn it off with `USE_EXPANDABLE_SEGMENTS=False [other commands]`.
Instead of turning off expandable segments, you can turn off NCCL memory registration, as the issue suggests.
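A minimal sketch of making this conditional (the `use_cudagraph` flag is hypothetical, and these env vars must be set before CUDA/NCCL are initialized):

```python
import os

use_cudagraph = True  # hypothetical flag, e.g. read from the job config

if use_cudagraph:
    # Option 1: disable expandable segments so cudagraph capture and NCCL
    # buffer registration don't conflict (see pytorch/pytorch#158029) ...
    os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:False")
    # ... or Option 2: keep expandable segments and disable NCCL graph
    # buffer registration instead, as the issue suggests.
    # os.environ.setdefault("NCCL_GRAPH_REGISTER", "0")
```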
torchtitan/train.py (Outdated)
        # in joint_graph_module. An explicit gc.collect() is necessary
        # to clean up reference cycles.
        for part in self.model_parts:
            part.joint_graph_module = None
I think `joint_graph_module` only exists for compiler toolkit experiments, so for other experiments or training runs, `part` won't have a `joint_graph_module`.
We can clean up only if it has the `joint_graph_module`.
yiming0416 left a comment:
Experiments part LGTM.
tianyu-l left a comment:
Sounds great! Had a suggestion.
torchtitan/train.py (Outdated)
        # in joint_graph_module. An explicit gc.collect() is necessary
        # to clean up reference cycles.
        for part in self.model_parts:
            if hasattr(part, "joint_graph_module"):
`joint_graph_module` is exclusively used for compiler_toolkit, right? If it can't be made general, let's create a Trainer subclass to override this method, e.g. https://github.com/pytorch/torchtitan/blob/main/torchtitan/experiments/torchcomms/train.py
Adding the CUDAGraph pass (#2050) requires some custom logic in Trainer's close() method, so we create a Trainer subclass in the compiler toolkit.
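A sketch of what that subclass could look like (the subclass name is hypothetical; it assumes the `torchtitan.train.Trainer` base class and the attribute names from the diff above):

```python
import gc

from torchtitan.train import Trainer


class CompilerToolkitTrainer(Trainer):  # hypothetical name
    def close(self) -> None:
        # joint_graph_module holds the captured cudagraphs and their static
        # buffers; drop the references and force a collection to break
        # reference cycles before the base class tears everything down.
        for part in self.model_parts:
            if hasattr(part, "joint_graph_module"):
                part.joint_graph_module = None
        gc.collect()
        super().close()
```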
Features
Command:
Note: we use `NCCL_GRAPH_REGISTER=0` due to a known issue where NCCL + cudagraphs + expandable segments results in an IMA (pytorch/pytorch#158029).
trace
Result
Numerics:
Achieved bitwise equivalence with and without the cudagraph pass on llama3.1-8B and llama3.1-70B.
Performance:

Raw log: llama3-8b, llama3-70b
Memory:
On llama3.1-70B, cudagraph adds about 6% memory consumption (143 GiB without vs 153 GiB with).
A few tricks to reduce memory consumption (using llama3.1-70B with cudagraph as an example):
Static input copy:
On llama3.1-70B, for forward we copy 1 tensor of 128 bytes; for backward we copy 1 tensor of 0.98 GB. This shows the static input indices are handled correctly.
Followup PR
In the followup PR, I will enable fx graph partition for DeepSeek V3 (pytorch/pytorch#165945).