Conversation

@BoyuanFeng (Contributor) commented Nov 17, 2025

Features

  • Support SimpleFSDP and TP
  • Support static input indices to reduce copies
  • Support memory reuse to reduce memory consumption
  • Clean up cudagraphs when training finishes to avoid an NCCL hang in destroy_process_group

Command:

NCCL_GRAPH_REGISTER=0 NGPU=8 TRAIN_FILE=torchtitan.experiments.compiler_toolkit.train CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4  --job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config --compile.passes cudagraph

Note: we use NCCL_GRAPH_REGISTER=0 due to a known issue where NCCL + cudagraphs + expandable segments results in an illegal memory access (IMA); see pytorch/pytorch#158029.

trace

Result

Numerics:
Achieved bitwise equivalence with and without the cudagraph pass on both llama3.1-8B and llama3.1-70B.

Performance:
[performance comparison image]

Raw log: llama3-8b, llama3-70b

Memory:
On llama3.1-70b, cudagraph uses ~7% more memory (143 GiB without vs 153 GiB with).

A few tricks reduced memory consumption (using llama3.1-70b with cudagraph as an example; see the sketch after this list):

  • Start: 161 GiB
  • + use the same stream for warmup and graph capture of both fwd and bwd: 160 GiB
  • + warm up in the cudagraph memory pool instead of the eager memory pool: 153 GiB
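
For context, here is a minimal sketch of the shared-stream trick using public torch.cuda APIs (fn, args, and the warmup count are illustrative; routing warmup allocations into the cudagraph pool itself relies on allocator internals not shown here):

import torch

def warmup_and_capture(fn, args, stream, pool):
    # Warm up on the same side stream later used for capture, so fwd
    # and bwd share one stream rather than allocating on two streams.
    stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(stream):
        for _ in range(3):
            fn(*args)
    torch.cuda.current_stream().wait_stream(stream)

    # Capture into a shared pool so fwd and bwd graphs can reuse memory.
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph, pool=pool, stream=stream):
        out = fn(*args)
    return graph, out

stream = torch.cuda.Stream()
pool = torch.cuda.graph_pool_handle()
# capture fwd and bwd with the same `stream` and `pool`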

Static input copy:
On llama3.1-70B, the forward graph copies 1 tensor of 128 bytes and the backward graph copies 1 tensor of 0.98 GB, showing that static input indices are handled correctly.

Followup PR

In a follow-up PR, I will enable fx graph partition for DeepSeek-V3 (pytorch/pytorch#165945).

@BoyuanFeng marked this pull request as draft November 17, 2025 23:41

def copy_static_inputs(self, *args):
    for i in self.input_indices_to_copy:
        self.args[i].copy_(args[i])
@BoyuanFeng (author):

We could replace this for loop with a foreach copy. However, I empirically observed there is only 1 tensor to copy for fwd and 1 for bwd, so it is not worth the added code complexity here.
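
For reference, a hedged sketch of the foreach variant (assuming all inputs at these indices are CUDA tensors):

def copy_static_inputs(self, *args):
    # Batch all copies into one foreach kernel launch instead of one
    # copy_ launch per tensor; only worthwhile with many static inputs.
    dsts = [self.args[i] for i in self.input_indices_to_copy]
    srcs = [args[i] for i in self.input_indices_to_copy]
    if dsts:
        torch._foreach_copy_(dsts, srcs)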

with torch.cuda.graph(
    self.cudagraph, pool=self.graph_pool, stream=self.stream
):
    # `output` is managed by pytorch's cudagraph pool
    self.output = self.runnable(*args)
@BoyuanFeng (author):

We could potentially hold the output tensors via weakref to reduce memory. Will do in a follow-up PR.
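
A rough sketch of that idea (method and attribute names hypothetical): hold the cudagraph-pool-managed outputs weakly, so their storage can be recycled once the caller drops them.

import weakref

def stash_outputs(self, outputs):
    # Keep only weak references; once the caller releases an output,
    # its storage in the cudagraph pool becomes reusable.
    self._output_refs = [weakref.ref(t) for t in outputs]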

@BoyuanFeng marked this pull request as ready for review November 18, 2025 21:40
@eellison left a comment:

Looks good! For serious use, we should handle not persisting the input and output. I would also add assertions for assumptions that would otherwise manifest as silent incorrectness, at least behind a config. Also, we probably shouldn't globally turn off expandable segments when cudagraphs is not enabled.

Comment on lines 100 to 102
input_addresses = [
    x.data_ptr() if isinstance(x, torch.Tensor) else None for x in args
]

@eellison:

I guess we're assuming that the non-tensor inputs are the same every time? Should we just assert they're all tensors if we're not handling the other cases?

@BoyuanFeng (author):

IIUC, there would only be tensors and symints (for the moe layer). Let me add an assertion.

@BoyuanFeng (author):

There are also rng_state inputs of type torch._C.Generator, used by:

graphsafe_run_with_rng_state_2 = torch.ops.higher_order.graphsafe_run_with_rng_state(torch.ops.aten._scaled_dot_product_flash_attention.default, transpose_20, transpose_21, transpose_22, 0.0, True, scale = 0.25, rng_state = fwd_rng_state_2);

See the last 3 args in P2047035404
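
Putting this subthread together, the assertion could look roughly like the sketch below (the exact allowed set depends on what the pass supports):

def check_input_types(args):
    # Tensors are recorded by address, ints/symints may vary per step,
    # and generators come from graphsafe_run_with_rng_state; anything
    # else is unsupported and would silently replay stale values.
    for x in args:
        assert isinstance(
            x, (torch.Tensor, int, torch.SymInt, torch.Generator)
        ), f"unsupported cudagraph input type: {type(x)}"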


self.copy_static_inputs(*args)
self.cudagraph.replay()
return self.output

@eellison:

The persistent input and output are not good for memory, as you've commented.

@BoyuanFeng (author):

Yes, will add in the next PR.

self.args = None
self.output = None

def copy_static_inputs(self, *args):

@eellison:

If any of the static inputs changes, you'll get silent incorrectness. You might consider at least a config to check this.

@BoyuanFeng (author):

Yes, added.
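
A sketch of what such a debug check could look like (flag and attribute names hypothetical):

def check_static_inputs(self, *args):
    # Static inputs are baked into the graph by address; if one moves,
    # replay reads stale memory. Verify addresses when the flag is on.
    if not self.debug_check_static_inputs:
        return
    for i, addr in zip(self.static_input_indices, self.static_input_addresses):
        assert args[i].data_ptr() == addr, (
            f"static input {i} changed address; replay would be silently wrong"
        )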

Comment on lines +52 to +56
(
    _global_dummy_graph,
    _global_graph_pool,
    _global_graph_capture_stream,
) = init_global_graph_pool()

@eellison:

Does this work when backward is on a separate stream? Or is that not an issue?

@BoyuanFeng (author) commented Nov 19, 2025:

IIUC, this is not an issue currently, since fwd and bwd are on the same CUDA stream by default.

cudagraph trees also uses the same graph capture stream for both fwd and bwd:
https://github.com/pytorch/pytorch/blob/7a928397cda89b71c24b0efe9db6df7fb04a46cb/torch/_inductor/cudagraph_trees.py#L1945

run_train.sh (outdated)

# need to turn off expandable segments when using cudagraph, since
# it does not work with cg and nccl yet.
# https://github.com/pytorch/pytorch/issues/158029

@eellison:

Can we turn this off only when using cudagraph?

@BoyuanFeng (author):

Currently it's on by default. When using cudagraph, we need to explicitly turn it off with USE_EXPANDABLE_SEGMENTS=False [other commands].

Reply:

Instead of turning off expandable segments, you can turn off NCCL memory registration, as the issue suggests.

# in joint_graph_module. An explicit gc.collect() is necessary
# to clean up reference cycles.
for part in self.model_parts:
    part.joint_graph_module = None
Contributor:

I think joint_graph_module only exists for compiler toolkit experiments, so for other experiments or training runs, part won't have joint_graph_module.

@BoyuanFeng (author):

We can clean up only when a part has the joint_graph_module.
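
The eventual cleanup (moved into a Trainer subclass in #2064) could look roughly like this sketch:

import gc

def close(self) -> None:
    # Drop the captured graphs before destroy_process_group so NCCL
    # does not hang on resources still held by cudagraphs.
    for part in self.model_parts:
        if hasattr(part, "joint_graph_module"):
            part.joint_graph_module = None
    # joint_graph_module participates in reference cycles; collect
    # explicitly so the graphs are actually freed now.
    gc.collect()
    super().close()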

@yiming0416 (Contributor) left a comment:

experiments part LGTM

@tianyu-l (Contributor) left a comment:

Sounds great! Had a suggestion.

# in joint_graph_module. An explicit gc.collect() is necessary
# to clean up reference cycles.
for part in self.model_parts:
    if hasattr(part, "joint_graph_module"):
@tianyu-l:

joint_graph_module is exclusively used for compiler_toolkit, right? If it can't be made general, let's create a Trainer subclass to override this method. E.g. https://github.com/pytorch/torchtitan/blob/main/torchtitan/experiments/torchcomms/train.py

Contributor:

@tianyu-l Trainer subclass added in #2064

@yiming0416 added a commit that referenced this pull request Nov 19, 2025:
"Adding CudaGraph pass (#2050) would require some custom logic in Trainer's close() method, so we create a Trainer subclass in compiler toolkit."

@BoyuanFeng merged commit f5e3a84 into main Nov 20, 2025. 5 checks passed.