From 658e73d91e6b62c9503fa717546b4bded6e7887b Mon Sep 17 00:00:00 2001 From: Boyuan Feng Date: Fri, 14 Nov 2025 10:07:49 -0800 Subject: [PATCH 01/22] nit --- .../experiments/compiler_toolkit/README.md | 4 + .../compiler_toolkit/common_utils.py | 2 + .../experiments/compiler_toolkit/passes.py | 93 +++++++++++++++++++ torchtitan/models/deepseek_v3/__init__.py | 2 +- .../train_configs/deepseek_v3_16b.toml | 2 +- torchtitan/models/llama3/__init__.py | 2 +- torchtitan/models/llama3/model/model.py | 7 +- .../llama3/train_configs/debug_model.toml | 4 +- .../llama3/train_configs/llama3_70b.toml | 4 +- 9 files changed, 110 insertions(+), 10 deletions(-) diff --git a/torchtitan/experiments/compiler_toolkit/README.md b/torchtitan/experiments/compiler_toolkit/README.md index 61207fc63b..14f6a8d848 100644 --- a/torchtitan/experiments/compiler_toolkit/README.md +++ b/torchtitan/experiments/compiler_toolkit/README.md @@ -44,3 +44,7 @@ NGPU=8 CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./r ```shell NGPU=8 CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4 --job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config --compile.passes autobucketing_reordering,regional_inductor ``` + +NGPU=2 CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.llama3 --parallelism.data_parallel_shard_degree=2 job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config --compile.passes cudagraph_wrapper + +NGPU=8 CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4 --job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config --compile.passes cudagraph_wrapper diff --git a/torchtitan/experiments/compiler_toolkit/common_utils.py b/torchtitan/experiments/compiler_toolkit/common_utils.py index b7499b2f79..35066343c5 100644 --- a/torchtitan/experiments/compiler_toolkit/common_utils.py +++ b/torchtitan/experiments/compiler_toolkit/common_utils.py @@ -28,6 +28,8 @@ def parallelize_inputs(world_mesh, args, kwargs): def to_dtensor(tensor): if isinstance(tensor, torch.Tensor): return DTensor.from_local(tensor, world_mesh["tp"], [Replicate()]) + + # return DTensor.from_local(tensor, world_mesh, [Replicate()]) return tensor dt_args = tree_map(to_dtensor, args) diff --git a/torchtitan/experiments/compiler_toolkit/passes.py b/torchtitan/experiments/compiler_toolkit/passes.py index 1c00fd5c1b..49aea1b51e 100644 --- a/torchtitan/experiments/compiler_toolkit/passes.py +++ b/torchtitan/experiments/compiler_toolkit/passes.py @@ -14,6 +14,7 @@ import torch from torch._inductor.fx_passes.overlap_scheduling import schedule_overlap_bucketing from torch.fx.passes.regional_inductor import regional_inductor +from torch.utils._ordered_set import OrderedSet def autobucketing_reordering_pass( @@ -39,8 +40,100 @@ def regional_inductor_pass( return regional_inductor(gm, example_inputs) +from typing import Callable, Optional + +_global_graph_pool = torch.cuda.graph_pool_handle() + +# TODO: make output and args weakref to allow reuse. 
+ + +class CUDAGraphWrapper: + def __init__( + self, + runnable: Callable, + graph_pool: Optional[torch.cuda._POOL_HANDLE] = None, + static_input_indices: Optional[tuple[int]] = None, + ): + self.runnable = runnable + self.graph_pool = _global_graph_pool # graph_pool if graph_pool is not None else torch.cuda.graph_pool_handle() + self.static_input_indices = OrderedSet( + static_input_indices if static_input_indices is not None else [] + ) + + self.cudagraph: Optional[torch.cuda.CUDAGraph] = None + + # TODO: weak ref + self.output = None + + self.has_warmup = False + + def __call__(self, *args, **kwargs): + # assume that args and kwargs have been copied to + # static tensors + + torch.cuda.synchronize() + print("a new run") + + if not self.has_warmup: + self.has_warmup = True + return self.runnable(*args, **kwargs) + + if self.cudagraph is None: + self.args = args + self.kwargs = kwargs + input_addresses = [ + x.data_ptr() for x in args if isinstance(x, torch.Tensor) + ] + self.input_addresses = input_addresses + + self.cudagraph = torch.cuda.CUDAGraph() + + with torch.cuda.graph(self.cudagraph, pool=self.graph_pool): + # `output` is managed by pytorch's cudagraph pool + # TODO: use weak ref for output to reuse memory + self.output = self.runnable(*args, **kwargs) + + # TODO: add debug address check. + + if True: + # check if the input addresses are the same + new_input_addresses = [ + x.data_ptr() for x in args if isinstance(x, torch.Tensor) + ] + assert new_input_addresses == self.input_addresses, ( + f"Input addresses for cudagraphs are different " + f"during replay. Expected {self.input_addresses}, " + f"got {new_input_addresses}" + ) + + for iter in range(10): + print(f"before iter {iter}") + self.cudagraph.replay() + print(f"after iter {iter}") + torch.cuda.synchronize() + + return self.output + + +def cudagraph_wrapper( + fn, example_inputs +): # , num_partitions: int, partition_id: int) -> Callable: + """ + Wrap a function with CUDAGraphWrapper. 
+ @param fn: the function to be wrapped + @param metadata: the metadata of the function + @return: the wrapped function + """ + gc_disable = False # partition_id != 0 + return CUDAGraphWrapper(fn, gc_disable) + + # Registry mapping pass names to pass functions AVAILABLE_PASSES = { "autobucketing_reordering": autobucketing_reordering_pass, "regional_inductor": regional_inductor_pass, + "cudagraph_wrapper": cudagraph_wrapper, } + + +# TODO: cleanup graph before nccl destroy group diff --git a/torchtitan/models/deepseek_v3/__init__.py b/torchtitan/models/deepseek_v3/__init__.py index 525bd96c13..7f9626480e 100644 --- a/torchtitan/models/deepseek_v3/__init__.py +++ b/torchtitan/models/deepseek_v3/__init__.py @@ -97,7 +97,7 @@ qk_rope_head_dim=64, v_head_dim=128, mscale=0.70, - use_flex_attn=True, + use_flex_attn=False, attn_mask_type="block_causal", ), "236B": DeepSeekV3ModelArgs( diff --git a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml index 00ec53310e..36238286ee 100644 --- a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml +++ b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml @@ -4,7 +4,7 @@ description = "DeepSeek-V3 16B model training" print_config = false [profiling] -enable_profiling = false +enable_profiling = true save_traces_folder = "profile_trace" profile_freq = 10 enable_memory_snapshot = false diff --git a/torchtitan/models/llama3/__init__.py b/torchtitan/models/llama3/__init__.py index 191588ad9e..701be870ff 100644 --- a/torchtitan/models/llama3/__init__.py +++ b/torchtitan/models/llama3/__init__.py @@ -28,7 +28,7 @@ llama3_args = { "debugmodel": TransformerModelArgs( - dim=256, n_layers=6, n_heads=16, vocab_size=2048, rope_theta=500000 + dim=256, n_layers=0, n_heads=16, vocab_size=2048, rope_theta=500000 ), "debugmodel_flex_attn": TransformerModelArgs( dim=256, diff --git a/torchtitan/models/llama3/model/model.py b/torchtitan/models/llama3/model/model.py index 124153f14c..a11d40f33e 100644 --- a/torchtitan/models/llama3/model/model.py +++ b/torchtitan/models/llama3/model/model.py @@ -494,10 +494,11 @@ def forward( """ # passthrough for nonexistent layers, allows easy configuration of pipeline parallel stages h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens + # h = torch.ones([8, 2048, 1024], device="cuda") - for layer in self.layers.values(): - h = layer(h, self.freqs_cis, attention_masks=attention_masks) + # for layer in self.layers.values(): + # h = layer(h, self.freqs_cis, attention_masks=attention_masks) - h = self.norm(h) if self.norm else h + # h = self.norm(h) if self.norm else h output = self.output(h) if self.output else h return output diff --git a/torchtitan/models/llama3/train_configs/debug_model.toml b/torchtitan/models/llama3/train_configs/debug_model.toml index 7760667edd..b191b4814c 100644 --- a/torchtitan/models/llama3/train_configs/debug_model.toml +++ b/torchtitan/models/llama3/train_configs/debug_model.toml @@ -4,9 +4,9 @@ description = "Llama 3 debug training" print_config = false [profiling] -enable_profiling = false +enable_profiling = true save_traces_folder = "profile_trace" -profile_freq = 10 +profile_freq = 5 enable_memory_snapshot = false save_memory_snapshot_folder = "memory_snapshot" diff --git a/torchtitan/models/llama3/train_configs/llama3_70b.toml b/torchtitan/models/llama3/train_configs/llama3_70b.toml index 37fd35b5cb..b87ff39aa1 100644 --- a/torchtitan/models/llama3/train_configs/llama3_70b.toml +++ 
b/torchtitan/models/llama3/train_configs/llama3_70b.toml @@ -7,7 +7,7 @@ description = "Llama 3 70B training" [profiling] enable_profiling = true save_traces_folder = "profile_trace" -profile_freq = 100 +profile_freq = 10 [metrics] log_freq = 10 @@ -30,7 +30,7 @@ warmup_steps = 200 # lr scheduler warm up, normally 20% of the train steps [training] local_batch_size = 8 -seq_len = 8192 +seq_len = 4096 max_norm = 1.0 # grad norm clipping steps = 1000 dataset = "c4" From 8ee5fce3137c72c59bfe80c1effff50bdb2f155a Mon Sep 17 00:00:00 2001 From: Boyuan Feng Date: Fri, 14 Nov 2025 16:15:22 -0800 Subject: [PATCH 02/22] this can run --- .../compiler_toolkit/graph_utils.py | 17 ++-- .../experiments/compiler_toolkit/passes.py | 83 +++++++++++-------- torchtitan/models/llama3/__init__.py | 2 +- torchtitan/models/llama3/model/model.py | 8 +- torchtitan/train.py | 3 + 5 files changed, 67 insertions(+), 46 deletions(-) diff --git a/torchtitan/experiments/compiler_toolkit/graph_utils.py b/torchtitan/experiments/compiler_toolkit/graph_utils.py index cd758438b3..09d86a35c5 100644 --- a/torchtitan/experiments/compiler_toolkit/graph_utils.py +++ b/torchtitan/experiments/compiler_toolkit/graph_utils.py @@ -20,7 +20,7 @@ from torchtitan.config import JobConfig from torchtitan.distributed import ParallelDims from torchtitan.tools.logging import logger - +from torchtitan.experiments.compiler_toolkit.passes import AVAILABLE_PASSES def _dump_gm(dump_folder: str | None, gm: torch.fx.GraphModule, name: str) -> None: # TODO: make the dump rank configurable @@ -239,11 +239,11 @@ def compiler( logger.info(f"Applying pass: {pass_fn.__name__}") gm = pass_fn(gm, example_inputs) - logger.debug(f"{name} after compiler:") - logger.debug( - gm.print_readable(print_output=False, include_stride=True, include_device=True) - ) - _dump_gm(dump_folder, gm, f"{name}_after_compiler") + # logger.debug(f"{name} after compiler:") + # logger.debug( + # gm.print_readable(print_output=False, include_stride=True, include_device=True) + # ) + # _dump_gm(dump_folder, gm, f"{name}_after_compiler") return gm @@ -259,15 +259,14 @@ def make_compiler_with_passes( Returns: Tuple of (fw_compiler, bw_compiler) functions """ - def fw_compiler(gm: torch.fx.GraphModule, example_inputs) -> None: return compiler( - "fwd_gm", gm, example_inputs, passes=passes, dump_folder=dump_folder + "fwd_gm", gm, example_inputs, passes=[AVAILABLE_PASSES["cudagraph_wrapper"]], dump_folder=dump_folder ) def bw_compiler(gm: torch.fx.GraphModule, example_inputs) -> None: return compiler( - "bwd_gm", gm, example_inputs, passes=passes, dump_folder=dump_folder + "bwd_gm", gm, example_inputs, passes=[], dump_folder=dump_folder ) return fw_compiler, bw_compiler diff --git a/torchtitan/experiments/compiler_toolkit/passes.py b/torchtitan/experiments/compiler_toolkit/passes.py index 49aea1b51e..01ed8ee4db 100644 --- a/torchtitan/experiments/compiler_toolkit/passes.py +++ b/torchtitan/experiments/compiler_toolkit/passes.py @@ -74,43 +74,60 @@ def __call__(self, *args, **kwargs): torch.cuda.synchronize() print("a new run") - if not self.has_warmup: - self.has_warmup = True - return self.runnable(*args, **kwargs) - - if self.cudagraph is None: - self.args = args - self.kwargs = kwargs - input_addresses = [ - x.data_ptr() for x in args if isinstance(x, torch.Tensor) - ] - self.input_addresses = input_addresses - - self.cudagraph = torch.cuda.CUDAGraph() - - with torch.cuda.graph(self.cudagraph, pool=self.graph_pool): - # `output` is managed by pytorch's cudagraph pool - # TODO: use weak 
ref for output to reuse memory - self.output = self.runnable(*args, **kwargs) - - # TODO: add debug address check. - - if True: - # check if the input addresses are the same - new_input_addresses = [ - x.data_ptr() for x in args if isinstance(x, torch.Tensor) - ] - assert new_input_addresses == self.input_addresses, ( - f"Input addresses for cudagraphs are different " - f"during replay. Expected {self.input_addresses}, " - f"got {new_input_addresses}" - ) + self.runnable(*args, **kwargs) + + + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g, pool= torch.cuda.graph_pool_handle()): + # `output` is managed by pytorch's cudagraph pool + # TODO: use weak ref for output to reuse memory + self.output = self.runnable(*args, **kwargs) for iter in range(10): print(f"before iter {iter}") - self.cudagraph.replay() + g.replay() print(f"after iter {iter}") - torch.cuda.synchronize() + + return self.runnable(*args, **kwargs) + + + # if not self.has_warmup: + # self.has_warmup = True + # return self.runnable(*args, **kwargs) + + # if self.cudagraph is None: + # self.args = args + # self.kwargs = kwargs + # input_addresses = [ + # x.data_ptr() for x in args if isinstance(x, torch.Tensor) + # ] + # self.input_addresses = input_addresses + + # self.cudagraph = torch.cuda.CUDAGraph() + + # with torch.cuda.graph(self.cudagraph, pool=self.graph_pool): + # # `output` is managed by pytorch's cudagraph pool + # # TODO: use weak ref for output to reuse memory + # self.output = self.runnable(*args, **kwargs) + + # # TODO: add debug address check. + + # if True: + # # check if the input addresses are the same + # new_input_addresses = [ + # x.data_ptr() for x in args if isinstance(x, torch.Tensor) + # ] + # assert new_input_addresses == self.input_addresses, ( + # f"Input addresses for cudagraphs are different " + # f"during replay. Expected {self.input_addresses}, " + # f"got {new_input_addresses}" + # ) + + # for iter in range(10): + # print(f"before iter {iter}") + # self.cudagraph.replay() + # print(f"after iter {iter}") + # torch.cuda.synchronize() return self.output diff --git a/torchtitan/models/llama3/__init__.py b/torchtitan/models/llama3/__init__.py index 701be870ff..952890edca 100644 --- a/torchtitan/models/llama3/__init__.py +++ b/torchtitan/models/llama3/__init__.py @@ -28,7 +28,7 @@ llama3_args = { "debugmodel": TransformerModelArgs( - dim=256, n_layers=0, n_heads=16, vocab_size=2048, rope_theta=500000 + dim=256, n_layers=1, n_heads=16, vocab_size=2048, rope_theta=500000 ), "debugmodel_flex_attn": TransformerModelArgs( dim=256, diff --git a/torchtitan/models/llama3/model/model.py b/torchtitan/models/llama3/model/model.py index a11d40f33e..450ec097f8 100644 --- a/torchtitan/models/llama3/model/model.py +++ b/torchtitan/models/llama3/model/model.py @@ -358,8 +358,11 @@ def forward( torch.Tensor: Output tensor after applying attention and feedforward layers. 
""" - h = x + self.attention(self.attention_norm(x), freqs_cis, attention_masks) - out = h + self.feed_forward(self.ffn_norm(h)) + # h = x + self.attention(self.attention_norm(x), freqs_cis, attention_masks) + + h = x + # out = h + self.feed_forward(self.ffn_norm(h)) + out = h return out def init_weights(self): @@ -494,7 +497,6 @@ def forward( """ # passthrough for nonexistent layers, allows easy configuration of pipeline parallel stages h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens - # h = torch.ones([8, 2048, 1024], device="cuda") # for layer in self.layers.values(): # h = layer(h, self.freqs_cis, attention_masks=attention_masks) diff --git a/torchtitan/train.py b/torchtitan/train.py index 5cfab998b2..dca5f09496 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -645,6 +645,7 @@ def train(self): ): data_iterator = self.batch_generator(self.dataloader) while self.should_continue_training(): + print(f"start step:{self.step}") self.step += 1 self.gc_handler.run(self.step) try: @@ -652,6 +653,8 @@ def train(self): except DataloaderExhaustedError: logger.warning("Ran out of data; last step was canceled.") break + print(f"end step:{self.step}") + self.checkpointer.save( self.step, last_step=(self.step == job_config.training.steps) From 40823fa741b3a5fa8d43464a9db2d72b93672ef2 Mon Sep 17 00:00:00 2001 From: Boyuan Feng Date: Mon, 17 Nov 2025 10:25:57 -0800 Subject: [PATCH 03/22] remove expandable_segments to fix IMA issue --- run_train.sh | 3 ++- torchtitan/experiments/compiler_toolkit/passes.py | 1 + torchtitan/models/llama3/model/model.py | 6 +++--- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/run_train.sh b/run_train.sh index 83319816fe..0f2392843a 100755 --- a/run_train.sh +++ b/run_train.sh @@ -25,7 +25,8 @@ if [ "$DRY_RUN" = "1" ]; then python scripts/dry_run.py --job.config_file ${CONFIG_FILE} "$@" else # Normal training with torchrun - PYTORCH_ALLOC_CONF="expandable_segments:True" \ + # expandable_segments does not work with cg and nccl. + # https://github.com/pytorch/pytorch/issues/158029 TORCHFT_LIGHTHOUSE=${TORCHFT_LIGHTHOUSE} \ torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \ --local-ranks-filter ${LOG_RANK} --role rank --tee 3 \ diff --git a/torchtitan/experiments/compiler_toolkit/passes.py b/torchtitan/experiments/compiler_toolkit/passes.py index 01ed8ee4db..125e1e0dea 100644 --- a/torchtitan/experiments/compiler_toolkit/passes.py +++ b/torchtitan/experiments/compiler_toolkit/passes.py @@ -78,6 +78,7 @@ def __call__(self, *args, **kwargs): g = torch.cuda.CUDAGraph() + # allocate a graph pool for debugging. Will reuse graph pool across cg. 
with torch.cuda.graph(g, pool= torch.cuda.graph_pool_handle()): # `output` is managed by pytorch's cudagraph pool # TODO: use weak ref for output to reuse memory diff --git a/torchtitan/models/llama3/model/model.py b/torchtitan/models/llama3/model/model.py index 450ec097f8..7a5749591a 100644 --- a/torchtitan/models/llama3/model/model.py +++ b/torchtitan/models/llama3/model/model.py @@ -498,9 +498,9 @@ def forward( # passthrough for nonexistent layers, allows easy configuration of pipeline parallel stages h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens - # for layer in self.layers.values(): - # h = layer(h, self.freqs_cis, attention_masks=attention_masks) + for layer in self.layers.values(): + h = layer(h, self.freqs_cis, attention_masks=attention_masks) - # h = self.norm(h) if self.norm else h + h = self.norm(h) if self.norm else h output = self.output(h) if self.output else h return output From 6a297db3ecaeb0afdd98df2752854af6ef85296a Mon Sep 17 00:00:00 2001 From: Boyuan Feng Date: Mon, 17 Nov 2025 11:21:07 -0800 Subject: [PATCH 04/22] nit --- .../experiments/compiler_toolkit/passes.py | 83 +++++++------------ 1 file changed, 30 insertions(+), 53 deletions(-) diff --git a/torchtitan/experiments/compiler_toolkit/passes.py b/torchtitan/experiments/compiler_toolkit/passes.py index 125e1e0dea..4f0f48335d 100644 --- a/torchtitan/experiments/compiler_toolkit/passes.py +++ b/torchtitan/experiments/compiler_toolkit/passes.py @@ -51,11 +51,10 @@ class CUDAGraphWrapper: def __init__( self, runnable: Callable, - graph_pool: Optional[torch.cuda._POOL_HANDLE] = None, static_input_indices: Optional[tuple[int]] = None, ): self.runnable = runnable - self.graph_pool = _global_graph_pool # graph_pool if graph_pool is not None else torch.cuda.graph_pool_handle() + self.graph_pool = _global_graph_pool self.static_input_indices = OrderedSet( static_input_indices if static_input_indices is not None else [] ) @@ -63,55 +62,37 @@ def __init__( self.cudagraph: Optional[torch.cuda.CUDAGraph] = None # TODO: weak ref + self.args = None + self.kwargs = None self.output = None self.has_warmup = False - def __call__(self, *args, **kwargs): - # assume that args and kwargs have been copied to - # static tensors - - torch.cuda.synchronize() - print("a new run") - - self.runnable(*args, **kwargs) - - - g = torch.cuda.CUDAGraph() - # allocate a graph pool for debugging. Will reuse graph pool across cg. 
- with torch.cuda.graph(g, pool= torch.cuda.graph_pool_handle()): - # `output` is managed by pytorch's cudagraph pool - # TODO: use weak ref for output to reuse memory - self.output = self.runnable(*args, **kwargs) - - for iter in range(10): - print(f"before iter {iter}") - g.replay() - print(f"after iter {iter}") - - return self.runnable(*args, **kwargs) - + def copy_static_inputs(self, *args): + for i in range(len(self.args)): + if i not in self.static_input_indices and isinstance(self.args[i], torch.Tensor): + self.args[i].copy_(args[i]) - # if not self.has_warmup: - # self.has_warmup = True - # return self.runnable(*args, **kwargs) - - # if self.cudagraph is None: - # self.args = args - # self.kwargs = kwargs - # input_addresses = [ - # x.data_ptr() for x in args if isinstance(x, torch.Tensor) - # ] - # self.input_addresses = input_addresses - - # self.cudagraph = torch.cuda.CUDAGraph() - - # with torch.cuda.graph(self.cudagraph, pool=self.graph_pool): - # # `output` is managed by pytorch's cudagraph pool - # # TODO: use weak ref for output to reuse memory - # self.output = self.runnable(*args, **kwargs) - - # # TODO: add debug address check. + def __call__(self, *args, **kwargs): + if not self.has_warmup: + self.has_warmup = True + return self.runnable(*args, **kwargs) + + if self.cudagraph is None: + # TODO: weak ref? + self.args = args + self.kwargs = kwargs + input_addresses = [ + x.data_ptr() if isinstance(x, torch.Tensor) else None for x in args + ] + self.input_addresses = input_addresses + + self.cudagraph = torch.cuda.CUDAGraph() + + with torch.cuda.graph(self.cudagraph, pool=self.graph_pool): + # `output` is managed by pytorch's cudagraph pool + # TODO: use weak ref for output to reuse memory + self.output = self.runnable(*args, **kwargs) # if True: # # check if the input addresses are the same @@ -124,12 +105,9 @@ def __call__(self, *args, **kwargs): # f"got {new_input_addresses}" # ) - # for iter in range(10): - # print(f"before iter {iter}") - # self.cudagraph.replay() - # print(f"after iter {iter}") - # torch.cuda.synchronize() + self.copy_static_inputs(*args) + self.cudagraph.replay() return self.output @@ -142,8 +120,7 @@ def cudagraph_wrapper( @param metadata: the metadata of the function @return: the wrapped function """ - gc_disable = False # partition_id != 0 - return CUDAGraphWrapper(fn, gc_disable) + return CUDAGraphWrapper(fn) # Registry mapping pass names to pass functions From 73812f970d91550356c8b76dfdac5c5671ddce87 Mon Sep 17 00:00:00 2001 From: Boyuan Feng Date: Mon, 17 Nov 2025 11:28:03 -0800 Subject: [PATCH 05/22] cleanup --- torchtitan/models/llama3/__init__.py | 2 +- torchtitan/models/llama3/model/model.py | 7 ++----- torchtitan/models/llama3/train_configs/llama3_70b.toml | 4 ++-- torchtitan/train.py | 3 --- 4 files changed, 5 insertions(+), 11 deletions(-) diff --git a/torchtitan/models/llama3/__init__.py b/torchtitan/models/llama3/__init__.py index 952890edca..191588ad9e 100644 --- a/torchtitan/models/llama3/__init__.py +++ b/torchtitan/models/llama3/__init__.py @@ -28,7 +28,7 @@ llama3_args = { "debugmodel": TransformerModelArgs( - dim=256, n_layers=1, n_heads=16, vocab_size=2048, rope_theta=500000 + dim=256, n_layers=6, n_heads=16, vocab_size=2048, rope_theta=500000 ), "debugmodel_flex_attn": TransformerModelArgs( dim=256, diff --git a/torchtitan/models/llama3/model/model.py b/torchtitan/models/llama3/model/model.py index 7a5749591a..124153f14c 100644 --- a/torchtitan/models/llama3/model/model.py +++ b/torchtitan/models/llama3/model/model.py @@ 
-358,11 +358,8 @@ def forward( torch.Tensor: Output tensor after applying attention and feedforward layers. """ - # h = x + self.attention(self.attention_norm(x), freqs_cis, attention_masks) - - h = x - # out = h + self.feed_forward(self.ffn_norm(h)) - out = h + h = x + self.attention(self.attention_norm(x), freqs_cis, attention_masks) + out = h + self.feed_forward(self.ffn_norm(h)) return out def init_weights(self): diff --git a/torchtitan/models/llama3/train_configs/llama3_70b.toml b/torchtitan/models/llama3/train_configs/llama3_70b.toml index b87ff39aa1..37fd35b5cb 100644 --- a/torchtitan/models/llama3/train_configs/llama3_70b.toml +++ b/torchtitan/models/llama3/train_configs/llama3_70b.toml @@ -7,7 +7,7 @@ description = "Llama 3 70B training" [profiling] enable_profiling = true save_traces_folder = "profile_trace" -profile_freq = 10 +profile_freq = 100 [metrics] log_freq = 10 @@ -30,7 +30,7 @@ warmup_steps = 200 # lr scheduler warm up, normally 20% of the train steps [training] local_batch_size = 8 -seq_len = 4096 +seq_len = 8192 max_norm = 1.0 # grad norm clipping steps = 1000 dataset = "c4" diff --git a/torchtitan/train.py b/torchtitan/train.py index dca5f09496..5cfab998b2 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -645,7 +645,6 @@ def train(self): ): data_iterator = self.batch_generator(self.dataloader) while self.should_continue_training(): - print(f"start step:{self.step}") self.step += 1 self.gc_handler.run(self.step) try: @@ -653,8 +652,6 @@ def train(self): except DataloaderExhaustedError: logger.warning("Ran out of data; last step was canceled.") break - print(f"end step:{self.step}") - self.checkpointer.save( self.step, last_step=(self.step == job_config.training.steps) From 5c3da3f6bcce4910f398f10ee9963013c934910a Mon Sep 17 00:00:00 2001 From: Boyuan Feng Date: Mon, 17 Nov 2025 13:21:12 -0800 Subject: [PATCH 06/22] cleanup --- .../experiments/compiler_toolkit/README.md | 4 ---- .../experiments/compiler_toolkit/common_utils.py | 2 -- .../experiments/compiler_toolkit/graph_utils.py | 16 ++++++++-------- .../experiments/compiler_toolkit/passes.py | 5 +++-- 4 files changed, 11 insertions(+), 16 deletions(-) diff --git a/torchtitan/experiments/compiler_toolkit/README.md b/torchtitan/experiments/compiler_toolkit/README.md index 14f6a8d848..61207fc63b 100644 --- a/torchtitan/experiments/compiler_toolkit/README.md +++ b/torchtitan/experiments/compiler_toolkit/README.md @@ -44,7 +44,3 @@ NGPU=8 CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./r ```shell NGPU=8 CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4 --job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config --compile.passes autobucketing_reordering,regional_inductor ``` - -NGPU=2 CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.llama3 --parallelism.data_parallel_shard_degree=2 job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config --compile.passes cudagraph_wrapper - -NGPU=8 CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4 --job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config --compile.passes cudagraph_wrapper diff --git 
a/torchtitan/experiments/compiler_toolkit/common_utils.py b/torchtitan/experiments/compiler_toolkit/common_utils.py index 35066343c5..b7499b2f79 100644 --- a/torchtitan/experiments/compiler_toolkit/common_utils.py +++ b/torchtitan/experiments/compiler_toolkit/common_utils.py @@ -28,8 +28,6 @@ def parallelize_inputs(world_mesh, args, kwargs): def to_dtensor(tensor): if isinstance(tensor, torch.Tensor): return DTensor.from_local(tensor, world_mesh["tp"], [Replicate()]) - - # return DTensor.from_local(tensor, world_mesh, [Replicate()]) return tensor dt_args = tree_map(to_dtensor, args) diff --git a/torchtitan/experiments/compiler_toolkit/graph_utils.py b/torchtitan/experiments/compiler_toolkit/graph_utils.py index 09d86a35c5..e2c79d313a 100644 --- a/torchtitan/experiments/compiler_toolkit/graph_utils.py +++ b/torchtitan/experiments/compiler_toolkit/graph_utils.py @@ -20,7 +20,7 @@ from torchtitan.config import JobConfig from torchtitan.distributed import ParallelDims from torchtitan.tools.logging import logger -from torchtitan.experiments.compiler_toolkit.passes import AVAILABLE_PASSES + def _dump_gm(dump_folder: str | None, gm: torch.fx.GraphModule, name: str) -> None: # TODO: make the dump rank configurable @@ -239,11 +239,11 @@ def compiler( logger.info(f"Applying pass: {pass_fn.__name__}") gm = pass_fn(gm, example_inputs) - # logger.debug(f"{name} after compiler:") - # logger.debug( - # gm.print_readable(print_output=False, include_stride=True, include_device=True) - # ) - # _dump_gm(dump_folder, gm, f"{name}_after_compiler") + logger.debug(f"{name} after compiler:") + logger.debug( + gm.print_readable(print_output=False, include_stride=True, include_device=True) + ) + _dump_gm(dump_folder, gm, f"{name}_after_compiler") return gm @@ -261,12 +261,12 @@ def make_compiler_with_passes( """ def fw_compiler(gm: torch.fx.GraphModule, example_inputs) -> None: return compiler( - "fwd_gm", gm, example_inputs, passes=[AVAILABLE_PASSES["cudagraph_wrapper"]], dump_folder=dump_folder + "fwd_gm", gm, example_inputs, passes=passes, dump_folder=dump_folder ) def bw_compiler(gm: torch.fx.GraphModule, example_inputs) -> None: return compiler( - "bwd_gm", gm, example_inputs, passes=[], dump_folder=dump_folder + "fwd_gm", gm, example_inputs, passes=passes, dump_folder=dump_folder ) return fw_compiler, bw_compiler diff --git a/torchtitan/experiments/compiler_toolkit/passes.py b/torchtitan/experiments/compiler_toolkit/passes.py index 4f0f48335d..b34c3555ff 100644 --- a/torchtitan/experiments/compiler_toolkit/passes.py +++ b/torchtitan/experiments/compiler_toolkit/passes.py @@ -112,7 +112,7 @@ def __call__(self, *args, **kwargs): def cudagraph_wrapper( - fn, example_inputs + gm: torch.fx.GraphModule, example_inputs ): # , num_partitions: int, partition_id: int) -> Callable: """ Wrap a function with CUDAGraphWrapper. 
@@ -120,7 +120,8 @@ def cudagraph_wrapper( @param metadata: the metadata of the function @return: the wrapped function """ - return CUDAGraphWrapper(fn) + gm.forward = CUDAGraphWrapper(gm.forward) + return gm # Registry mapping pass names to pass functions From b2b2b4f43a75e4da49798887d65a35d54700c605 Mon Sep 17 00:00:00 2001 From: Boyuan Feng Date: Mon, 17 Nov 2025 13:22:22 -0800 Subject: [PATCH 07/22] cleanup --- torchtitan/models/deepseek_v3/__init__.py | 2 +- .../models/deepseek_v3/train_configs/deepseek_v3_16b.toml | 2 +- torchtitan/models/llama3/train_configs/debug_model.toml | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/torchtitan/models/deepseek_v3/__init__.py b/torchtitan/models/deepseek_v3/__init__.py index 7f9626480e..525bd96c13 100644 --- a/torchtitan/models/deepseek_v3/__init__.py +++ b/torchtitan/models/deepseek_v3/__init__.py @@ -97,7 +97,7 @@ qk_rope_head_dim=64, v_head_dim=128, mscale=0.70, - use_flex_attn=False, + use_flex_attn=True, attn_mask_type="block_causal", ), "236B": DeepSeekV3ModelArgs( diff --git a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml index 36238286ee..00ec53310e 100644 --- a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml +++ b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml @@ -4,7 +4,7 @@ description = "DeepSeek-V3 16B model training" print_config = false [profiling] -enable_profiling = true +enable_profiling = false save_traces_folder = "profile_trace" profile_freq = 10 enable_memory_snapshot = false diff --git a/torchtitan/models/llama3/train_configs/debug_model.toml b/torchtitan/models/llama3/train_configs/debug_model.toml index b191b4814c..89b90a1166 100644 --- a/torchtitan/models/llama3/train_configs/debug_model.toml +++ b/torchtitan/models/llama3/train_configs/debug_model.toml @@ -6,7 +6,7 @@ print_config = false [profiling] enable_profiling = true save_traces_folder = "profile_trace" -profile_freq = 5 +profile_freq = 10 enable_memory_snapshot = false save_memory_snapshot_folder = "memory_snapshot" From b433910bff4c92e4916e54ebc3c4ea84e409d75a Mon Sep 17 00:00:00 2001 From: Boyuan Feng Date: Mon, 17 Nov 2025 14:17:07 -0800 Subject: [PATCH 08/22] explicit del cudagraph before destroy process group --- .../experiments/compiler_toolkit/graph_utils.py | 3 ++- .../experiments/compiler_toolkit/passes.py | 17 ++++++++++------- torchtitan/train.py | 9 +++++++++ 3 files changed, 21 insertions(+), 8 deletions(-) diff --git a/torchtitan/experiments/compiler_toolkit/graph_utils.py b/torchtitan/experiments/compiler_toolkit/graph_utils.py index e2c79d313a..cd758438b3 100644 --- a/torchtitan/experiments/compiler_toolkit/graph_utils.py +++ b/torchtitan/experiments/compiler_toolkit/graph_utils.py @@ -259,6 +259,7 @@ def make_compiler_with_passes( Returns: Tuple of (fw_compiler, bw_compiler) functions """ + def fw_compiler(gm: torch.fx.GraphModule, example_inputs) -> None: return compiler( "fwd_gm", gm, example_inputs, passes=passes, dump_folder=dump_folder @@ -266,7 +267,7 @@ def fw_compiler(gm: torch.fx.GraphModule, example_inputs) -> None: def bw_compiler(gm: torch.fx.GraphModule, example_inputs) -> None: return compiler( - "fwd_gm", gm, example_inputs, passes=passes, dump_folder=dump_folder + "bwd_gm", gm, example_inputs, passes=passes, dump_folder=dump_folder ) return fw_compiler, bw_compiler diff --git a/torchtitan/experiments/compiler_toolkit/passes.py b/torchtitan/experiments/compiler_toolkit/passes.py index 
b34c3555ff..75bbb7604b 100644 --- a/torchtitan/experiments/compiler_toolkit/passes.py +++ b/torchtitan/experiments/compiler_toolkit/passes.py @@ -111,14 +111,17 @@ def __call__(self, *args, **kwargs): return self.output -def cudagraph_wrapper( +def cudagraph_pass( gm: torch.fx.GraphModule, example_inputs -): # , num_partitions: int, partition_id: int) -> Callable: +) -> torch.fx.GraphModule: """ - Wrap a function with CUDAGraphWrapper. - @param fn: the function to be wrapped - @param metadata: the metadata of the function - @return: the wrapped function + Apply cudagraph. + + This pass wraps the forward function with cudagraph during compilation and does + not record cudagraph until runtime. + - For the first run, it will warm up operators such as nccl. + - For the second run, it will record cudagraph and replay cudagraph. + - For the following runs, it will replay cudagraph. """ gm.forward = CUDAGraphWrapper(gm.forward) return gm @@ -128,7 +131,7 @@ def cudagraph_wrapper( AVAILABLE_PASSES = { "autobucketing_reordering": autobucketing_reordering_pass, "regional_inductor": regional_inductor_pass, - "cudagraph_wrapper": cudagraph_wrapper, + "cudagraph": cudagraph_pass, } diff --git a/torchtitan/train.py b/torchtitan/train.py index 5cfab998b2..eee5f10628 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -9,6 +9,7 @@ import time from datetime import timedelta from typing import Any, Generator, Iterable +import gc import torch @@ -703,6 +704,13 @@ def close(self) -> None: if hasattr(self, "metrics_processor") and self.metrics_processor: self.metrics_processor.close() + # Note [explicit cudagraph close] + # cudagraph holds reference to nccl which prevents destroy nccl group. + # so we need to explicitly delete cudagraph which is held in joint_graph_module. + # An explicit gc.collect() is needed here to clean up reference cycles. + for part in self.model_parts: + part.joint_graph_module = None + gc.collect() def main(trainer_class: type[Trainer]) -> None: """Main entry point for training with a specified trainer class. @@ -736,6 +744,7 @@ def main(trainer_class: type[Trainer]) -> None: else: trainer.close() if torch.distributed.is_initialized(): + torch.distributed.destroy_process_group() logger.info("Process group destroyed") From 5559ae4c1a4d12e55ef08ee1e59429eedb197ac6 Mon Sep 17 00:00:00 2001 From: Boyuan Feng Date: Mon, 17 Nov 2025 14:21:14 -0800 Subject: [PATCH 09/22] add USE_EXPANDABLE_SEGMENTS config --- run_train.sh | 8 ++++++-- torchtitan/experiments/compiler_toolkit/passes.py | 2 -- torchtitan/train.py | 1 - 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/run_train.sh b/run_train.sh index 0f2392843a..d20abe1ff6 100755 --- a/run_train.sh +++ b/run_train.sh @@ -19,14 +19,18 @@ DRY_RUN=${DRY_RUN:-0} TORCHFT_LIGHTHOUSE=${TORCHFT_LIGHTHOUSE:-"http://localhost:29510"} +# need to turn off expandable segments when using cudagraph, since +# it does not work with cg and nccl yet. +# https://github.com/pytorch/pytorch/issues/158029 +USE_EXPANDABLE_SEGMENTS=${USE_EXPANDABLE_SEGMENTS:-True} + if [ "$DRY_RUN" = "1" ]; then # Dry run mode: validate configuration without GPU/distributed setup echo "Running in DRY RUN mode - configuration validation only" python scripts/dry_run.py --job.config_file ${CONFIG_FILE} "$@" else # Normal training with torchrun - # expandable_segments does not work with cg and nccl. 
- # https://github.com/pytorch/pytorch/issues/158029 + PYTORCH_ALLOC_CONF="expandable_segments:${USE_EXPANDABLE_SEGMENTS}" \ TORCHFT_LIGHTHOUSE=${TORCHFT_LIGHTHOUSE} \ torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \ --local-ranks-filter ${LOG_RANK} --role rank --tee 3 \ diff --git a/torchtitan/experiments/compiler_toolkit/passes.py b/torchtitan/experiments/compiler_toolkit/passes.py index 75bbb7604b..76b2be76ba 100644 --- a/torchtitan/experiments/compiler_toolkit/passes.py +++ b/torchtitan/experiments/compiler_toolkit/passes.py @@ -134,5 +134,3 @@ def cudagraph_pass( "cudagraph": cudagraph_pass, } - -# TODO: cleanup graph before nccl destroy group diff --git a/torchtitan/train.py b/torchtitan/train.py index eee5f10628..26c82a9298 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -744,7 +744,6 @@ def main(trainer_class: type[Trainer]) -> None: else: trainer.close() if torch.distributed.is_initialized(): - torch.distributed.destroy_process_group() logger.info("Process group destroyed") From a3ed72c93b465a4fd3e5614d1de01e4120840975 Mon Sep 17 00:00:00 2001 From: Boyuan Feng Date: Mon, 17 Nov 2025 15:26:12 -0800 Subject: [PATCH 10/22] add static input indices --- .../compiler_toolkit/graph_utils.py | 27 ++++++- .../experiments/compiler_toolkit/passes.py | 78 ++++++++++++------- torchtitan/train.py | 3 +- 3 files changed, 76 insertions(+), 32 deletions(-) diff --git a/torchtitan/experiments/compiler_toolkit/graph_utils.py b/torchtitan/experiments/compiler_toolkit/graph_utils.py index cd758438b3..5560823eb9 100644 --- a/torchtitan/experiments/compiler_toolkit/graph_utils.py +++ b/torchtitan/experiments/compiler_toolkit/graph_utils.py @@ -5,6 +5,7 @@ # LICENSE file in the root directory of this source tree. import contextlib +import functools from pathlib import Path from typing import Any, Callable, List, Optional @@ -213,6 +214,7 @@ def compiler( example_inputs, passes: List[Callable] = None, dump_folder: str | None = None, + is_forward: bool = True, ): """ Compile a graph module by applying a sequence of compiler passes. @@ -235,6 +237,17 @@ def compiler( ) _dump_gm(dump_folder, gm, f"{name}_before_compiler") + if len(passes) > 0 and passes[-1].__name__ == "cudagraph_pass": + # cudagraph pass is always the last pass if it is applied + cg_pass = passes[-1] + + # to identify static input indices, cudagraph passes behaves differently for + # forward and backward pass. so we explicitly pass the info. 
+ _cg_pass = functools.partial(cg_pass, is_forward=is_forward) + + # keep the function name to + passes[-1] = functools.wraps(cg_pass)(_cg_pass) + for pass_fn in passes: logger.info(f"Applying pass: {pass_fn.__name__}") gm = pass_fn(gm, example_inputs) @@ -262,12 +275,22 @@ def make_compiler_with_passes( def fw_compiler(gm: torch.fx.GraphModule, example_inputs) -> None: return compiler( - "fwd_gm", gm, example_inputs, passes=passes, dump_folder=dump_folder + "fwd_gm", + gm, + example_inputs, + passes=passes, + dump_folder=dump_folder, + is_forward=True, ) def bw_compiler(gm: torch.fx.GraphModule, example_inputs) -> None: return compiler( - "bwd_gm", gm, example_inputs, passes=passes, dump_folder=dump_folder + "bwd_gm", + gm, + example_inputs, + passes=passes, + dump_folder=dump_folder, + is_forward=False, ) return fw_compiler, bw_compiler diff --git a/torchtitan/experiments/compiler_toolkit/passes.py b/torchtitan/experiments/compiler_toolkit/passes.py index 76b2be76ba..9dbddd4303 100644 --- a/torchtitan/experiments/compiler_toolkit/passes.py +++ b/torchtitan/experiments/compiler_toolkit/passes.py @@ -11,6 +11,8 @@ during compilation. Passes can be selected and configured via job config. """ +from typing import Any, Callable, Optional, Sequence + import torch from torch._inductor.fx_passes.overlap_scheduling import schedule_overlap_bucketing from torch.fx.passes.regional_inductor import regional_inductor @@ -40,17 +42,17 @@ def regional_inductor_pass( return regional_inductor(gm, example_inputs) -from typing import Callable, Optional - _global_graph_pool = torch.cuda.graph_pool_handle() # TODO: make output and args weakref to allow reuse. +# TODO: Check memory consumption class CUDAGraphWrapper: def __init__( self, runnable: Callable, + example_inputs: Sequence[Any], static_input_indices: Optional[tuple[int]] = None, ): self.runnable = runnable @@ -58,32 +60,32 @@ def __init__( self.static_input_indices = OrderedSet( static_input_indices if static_input_indices is not None else [] ) - + self.input_indices_to_copy = [ + i + for i, inp in enumerate(example_inputs) + if isinstance(inp, torch.Tensor) and i not in self.static_input_indices + ] self.cudagraph: Optional[torch.cuda.CUDAGraph] = None + self.has_warmup = False # TODO: weak ref self.args = None - self.kwargs = None self.output = None - self.has_warmup = False - def copy_static_inputs(self, *args): - for i in range(len(self.args)): - if i not in self.static_input_indices and isinstance(self.args[i], torch.Tensor): - self.args[i].copy_(args[i]) + for i in self.input_indices_to_copy: + self.args[i].copy_(args[i]) - def __call__(self, *args, **kwargs): + def __call__(self, *args): if not self.has_warmup: self.has_warmup = True - return self.runnable(*args, **kwargs) + return self.runnable(*args) if self.cudagraph is None: # TODO: weak ref? 
self.args = args - self.kwargs = kwargs input_addresses = [ - x.data_ptr() if isinstance(x, torch.Tensor) else None for x in args + x.data_ptr() if isinstance(x, torch.Tensor) else None for x in args ] self.input_addresses = input_addresses @@ -92,27 +94,45 @@ def __call__(self, *args, **kwargs): with torch.cuda.graph(self.cudagraph, pool=self.graph_pool): # `output` is managed by pytorch's cudagraph pool # TODO: use weak ref for output to reuse memory - self.output = self.runnable(*args, **kwargs) - - # if True: - # # check if the input addresses are the same - # new_input_addresses = [ - # x.data_ptr() for x in args if isinstance(x, torch.Tensor) - # ] - # assert new_input_addresses == self.input_addresses, ( - # f"Input addresses for cudagraphs are different " - # f"during replay. Expected {self.input_addresses}, " - # f"got {new_input_addresses}" - # ) - + self.output = self.runnable(*args) self.copy_static_inputs(*args) self.cudagraph.replay() return self.output +def get_static_input_indices(gm: torch.fx.GraphModule, is_forward: bool) -> list[int]: + """ + Get indices of gm inputs that are static input tensors whose tensor addresses do not + change across runs. Example of static input tensors include weights, buffers, and + outputs of previous cudagraph wrapped functions. + """ + from torch._inductor.utils import count_tangents + + static_input_indices = [] + if ( + is_forward + and (tracing_context := torch._guards.TracingContext.try_get()) + and hasattr(tracing_context, "fw_metadata") + ): + # for forward, we rely on graph capture (i.e., dynamo or export) to provide + # the correct static input indices stored in tracing context. Typical examples + # include weights and buffers. + static_input_indices = tracing_context.fw_metadata.static_input_indices + + elif not is_forward: + # for backward, we identify saved tensors as static inputs, since saved tensors + # are outputs of cudagraph-wrapped forward run. In PT2-generated backward gm, + # saved tensors are always the leading args. So we can get the number of saved + # tensors and generate static input indices. + fixed = count_tangents(gm) + static_input_indices = list(range(fixed)) + + return static_input_indices + + def cudagraph_pass( - gm: torch.fx.GraphModule, example_inputs + gm: torch.fx.GraphModule, example_inputs: Sequence[Any], is_forward: bool ) -> torch.fx.GraphModule: """ Apply cudagraph. @@ -123,7 +143,8 @@ def cudagraph_pass( - For the second run, it will record cudagraph and replay cudagraph. - For the following runs, it will replay cudagraph. """ - gm.forward = CUDAGraphWrapper(gm.forward) + static_input_indices = get_static_input_indices(gm, is_forward) + gm.forward = CUDAGraphWrapper(gm.forward, example_inputs, static_input_indices) return gm @@ -133,4 +154,3 @@ def cudagraph_pass( "regional_inductor": regional_inductor_pass, "cudagraph": cudagraph_pass, } - diff --git a/torchtitan/train.py b/torchtitan/train.py index 26c82a9298..02cb0f80af 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -4,12 +4,12 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import gc import importlib import os import time from datetime import timedelta from typing import Any, Generator, Iterable -import gc import torch @@ -712,6 +712,7 @@ def close(self) -> None: part.joint_graph_module = None gc.collect() + def main(trainer_class: type[Trainer]) -> None: """Main entry point for training with a specified trainer class. 
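For reference, the CUDAGraph mechanism that the patches above build up follows the usual capture/copy/replay pattern: the wrapper runs the compiled graph module once eagerly to warm up state such as NCCL, records a `torch.cuda.CUDAGraph` into a memory pool shared across graphs on a later call, and afterwards only copies fresh inputs into the captured placeholder tensors and replays the graph. The sketch below illustrates that pattern using only public PyTorch APIs; it is a simplified illustration rather than the code from this series: the function and variable names are made up, and it omits the static-input bookkeeping and the lazy first/second-call behaviour that `CUDAGraphWrapper` implements.

```python
import torch

# One memory pool shared by every captured graph, mirroring the global pool used in the series.
_shared_pool = torch.cuda.graph_pool_handle()


def make_replayable(fn, *placeholder_inputs):
    """Capture `fn` once and return a callable that replays it (illustrative only)."""
    # Warm-up run outside the graph, e.g. so NCCL / cuBLAS can set up their state.
    fn(*placeholder_inputs)
    torch.cuda.synchronize()

    # Record the work into a CUDA graph backed by the shared pool. Tensors allocated
    # during capture (including the outputs) keep fixed addresses across replays.
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph, pool=_shared_pool):
        static_outputs = fn(*placeholder_inputs)

    def run(*new_inputs):
        # Replay reuses the captured addresses, so fresh inputs must be copied into
        # the placeholder tensors that were recorded during capture.
        for dst, src in zip(placeholder_inputs, new_inputs):
            dst.copy_(src)
        graph.replay()
        return static_outputs

    return run
```

In the series itself the copy is skipped for static inputs: parameters and buffers of the forward graph, and the saved forward tensors that become the leading inputs of the backward graph, already live at stable addresses, which is what the `static_input_indices` / `get_static_input_indices` logic added in the patch above encodes.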
From 2163f319d063a0f943179237c0e3102987cd642a Mon Sep 17 00:00:00 2001 From: Boyuan Feng Date: Mon, 17 Nov 2025 15:32:50 -0800 Subject: [PATCH 11/22] lint --- torchtitan/train.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/torchtitan/train.py b/torchtitan/train.py index 02cb0f80af..872e339493 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -705,9 +705,10 @@ def close(self) -> None: self.metrics_processor.close() # Note [explicit cudagraph close] - # cudagraph holds reference to nccl which prevents destroy nccl group. - # so we need to explicitly delete cudagraph which is held in joint_graph_module. - # An explicit gc.collect() is needed here to clean up reference cycles. + # cudagraph holds reference to nccl which prevents destroy nccl + # group. so we need to explicitly delete cudagraph which is held + # in joint_graph_module. An explicit gc.collect() is necessary + # to clean up reference cycles. for part in self.model_parts: part.joint_graph_module = None gc.collect() From 3ff3ada31ad4050ce17790169367951a046228b0 Mon Sep 17 00:00:00 2001 From: Boyuan Feng Date: Mon, 17 Nov 2025 15:47:49 -0800 Subject: [PATCH 12/22] cleanup --- torchtitan/experiments/compiler_toolkit/graph_utils.py | 2 +- torchtitan/models/llama3/train_configs/debug_model.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/torchtitan/experiments/compiler_toolkit/graph_utils.py b/torchtitan/experiments/compiler_toolkit/graph_utils.py index 5560823eb9..b3a681387d 100644 --- a/torchtitan/experiments/compiler_toolkit/graph_utils.py +++ b/torchtitan/experiments/compiler_toolkit/graph_utils.py @@ -245,7 +245,7 @@ def compiler( # forward and backward pass. so we explicitly pass the info. _cg_pass = functools.partial(cg_pass, is_forward=is_forward) - # keep the function name to + # keep the function name for debug log passes[-1] = functools.wraps(cg_pass)(_cg_pass) for pass_fn in passes: diff --git a/torchtitan/models/llama3/train_configs/debug_model.toml b/torchtitan/models/llama3/train_configs/debug_model.toml index 89b90a1166..7760667edd 100644 --- a/torchtitan/models/llama3/train_configs/debug_model.toml +++ b/torchtitan/models/llama3/train_configs/debug_model.toml @@ -4,7 +4,7 @@ description = "Llama 3 debug training" print_config = false [profiling] -enable_profiling = true +enable_profiling = false save_traces_folder = "profile_trace" profile_freq = 10 enable_memory_snapshot = false From e49a2f243da55a7ba5a4b62ebfe2704c13b95783 Mon Sep 17 00:00:00 2001 From: Boyuan Feng Date: Mon, 17 Nov 2025 22:52:35 -0800 Subject: [PATCH 13/22] warmup in cudagraph memory pool --- .../experiments/compiler_toolkit/passes.py | 53 +++++++++++++++++-- 1 file changed, 49 insertions(+), 4 deletions(-) diff --git a/torchtitan/experiments/compiler_toolkit/passes.py b/torchtitan/experiments/compiler_toolkit/passes.py index 9dbddd4303..7bee69384d 100644 --- a/torchtitan/experiments/compiler_toolkit/passes.py +++ b/torchtitan/experiments/compiler_toolkit/passes.py @@ -11,9 +11,11 @@ during compilation. Passes can be selected and configured via job config. 
""" +import warnings from typing import Any, Callable, Optional, Sequence import torch +from torch._inductor.cudagraph_trees import _use_cuda_memory_pool_manager from torch._inductor.fx_passes.overlap_scheduling import schedule_overlap_bucketing from torch.fx.passes.regional_inductor import regional_inductor from torch.utils._ordered_set import OrderedSet @@ -42,10 +44,44 @@ def regional_inductor_pass( return regional_inductor(gm, example_inputs) -_global_graph_pool = torch.cuda.graph_pool_handle() +def init_global_graph_pool() -> tuple[ + torch.cuda.CUDAGraph, torch.cuda._POOL_HANDLE, torch.cuda.Stream +]: + dummy_graph = torch.cuda.CUDAGraph() + + # create a global cudagraph memory pool to allow memory reuse across cudagraphs. + graph_pool = torch.cuda.graph_pool_handle() + + # create a global cuda stream for graph capture. we need to use a single stream + # for all allocations to the memory pool, otherwise the allocations to separate streams + # will not be used. + graph_capture_stream = torch.cuda.Stream() + + # use a dummy graph to keep the global graph pool alive + with ( + # suppress an empty cudagraph warning, since we intentionally create + # an empty cudagraph here + warnings.catch_warnings(record=True), + torch.cuda.graph( + dummy_graph, + pool=graph_pool, + stream=graph_capture_stream, + capture_error_mode="thread_local", + ), + ): + pass + + return dummy_graph, graph_pool, graph_capture_stream + + +( + _global_dummy_graph, + _global_graph_pool, + _global_graph_capture_stream, +) = init_global_graph_pool() + # TODO: make output and args weakref to allow reuse. -# TODO: Check memory consumption class CUDAGraphWrapper: @@ -57,6 +93,7 @@ def __init__( ): self.runnable = runnable self.graph_pool = _global_graph_pool + self.stream = _global_graph_capture_stream self.static_input_indices = OrderedSet( static_input_indices if static_input_indices is not None else [] ) @@ -79,7 +116,13 @@ def copy_static_inputs(self, *args): def __call__(self, *args): if not self.has_warmup: self.has_warmup = True - return self.runnable(*args) + device = torch.cuda.current_device() + + # warmup in cudagraph memory pool to avoid fragmentation + # across eager memory pool and cudagraph memory pool. + with _use_cuda_memory_pool_manager(device, self.graph_pool, self.stream): + out = self.runnable(*args) + return out if self.cudagraph is None: # TODO: weak ref? 
@@ -91,7 +134,9 @@ def __call__(self, *args): self.cudagraph = torch.cuda.CUDAGraph() - with torch.cuda.graph(self.cudagraph, pool=self.graph_pool): + with torch.cuda.graph( + self.cudagraph, pool=self.graph_pool, stream=self.stream + ): # `output` is managed by pytorch's cudagraph pool # TODO: use weak ref for output to reuse memory self.output = self.runnable(*args) From fac30dace691aa22c88edff2dd91c2a6ff7d0b23 Mon Sep 17 00:00:00 2001 From: Boyuan Feng Date: Mon, 17 Nov 2025 23:13:28 -0800 Subject: [PATCH 14/22] lint --- torchtitan/experiments/compiler_toolkit/passes.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchtitan/experiments/compiler_toolkit/passes.py b/torchtitan/experiments/compiler_toolkit/passes.py index 831286735a..9e8d7fff8a 100644 --- a/torchtitan/experiments/compiler_toolkit/passes.py +++ b/torchtitan/experiments/compiler_toolkit/passes.py @@ -195,6 +195,7 @@ def cudagraph_pass( gm.forward = CUDAGraphWrapper(gm.forward, example_inputs, static_input_indices) return gm + def validate_flex_attn_annotation_pass( gm: torch.fx.GraphModule, ) -> torch.fx.GraphModule: From 2b5cfbc388883fa447789d2747116e8469d48982 Mon Sep 17 00:00:00 2001 From: Boyuan Feng Date: Tue, 18 Nov 2025 11:44:43 -0800 Subject: [PATCH 15/22] refactor cudagraph to separate file --- .../experiments/compiler_toolkit/cudagraph.py | 145 ++++++++++++++++++ .../experiments/compiler_toolkit/passes.py | 141 +---------------- 2 files changed, 150 insertions(+), 136 deletions(-) create mode 100644 torchtitan/experiments/compiler_toolkit/cudagraph.py diff --git a/torchtitan/experiments/compiler_toolkit/cudagraph.py b/torchtitan/experiments/compiler_toolkit/cudagraph.py new file mode 100644 index 0000000000..64602cea3c --- /dev/null +++ b/torchtitan/experiments/compiler_toolkit/cudagraph.py @@ -0,0 +1,145 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +CUDAGraph pass for the compiler toolkit. + +This module provides a cudagraph pass that can be applied to graph modules +during compilation. +""" + +import warnings +from typing import Any, Callable, Optional, Sequence + +import torch +from torch._inductor.cudagraph_trees import _use_cuda_memory_pool_manager +from torch.utils._ordered_set import OrderedSet + + +def init_global_graph_pool() -> tuple[ + torch.cuda.CUDAGraph, torch.cuda._POOL_HANDLE, torch.cuda.Stream +]: + dummy_graph = torch.cuda.CUDAGraph() + + # create a global cudagraph memory pool to allow memory reuse across cudagraphs. + graph_pool = torch.cuda.graph_pool_handle() + + # create a global cuda stream for graph capture. we need to use a single stream + # for all allocations to the memory pool, otherwise the allocations to separate streams + # will not be used. 
+ graph_capture_stream = torch.cuda.Stream() + + # use a dummy graph to keep the global graph pool alive + with ( + # suppress an empty cudagraph warning, since we intentionally create + # an empty cudagraph here + warnings.catch_warnings(record=True), + torch.cuda.graph( + dummy_graph, + pool=graph_pool, + stream=graph_capture_stream, + capture_error_mode="thread_local", + ), + ): + pass + + return dummy_graph, graph_pool, graph_capture_stream + + +( + _global_dummy_graph, + _global_graph_pool, + _global_graph_capture_stream, +) = init_global_graph_pool() + + +class CUDAGraphWrapper: + def __init__( + self, + runnable: Callable, + example_inputs: Sequence[Any], + static_input_indices: Optional[tuple[int]] = None, + ): + self.runnable = runnable + self.graph_pool = _global_graph_pool + self.stream = _global_graph_capture_stream + self.static_input_indices = OrderedSet( + static_input_indices if static_input_indices is not None else [] + ) + self.input_indices_to_copy = [ + i + for i, inp in enumerate(example_inputs) + if isinstance(inp, torch.Tensor) and i not in self.static_input_indices + ] + self.cudagraph: Optional[torch.cuda.CUDAGraph] = None + self.has_warmup = False + + self.args = None + self.output = None + + def copy_static_inputs(self, *args): + for i in self.input_indices_to_copy: + self.args[i].copy_(args[i]) + + def __call__(self, *args): + if not self.has_warmup: + self.has_warmup = True + device = torch.cuda.current_device() + + # warmup in cudagraph memory pool to avoid fragmentation + # across eager memory pool and cudagraph memory pool. + with _use_cuda_memory_pool_manager(device, self.graph_pool, self.stream): + out = self.runnable(*args) + return out + + if self.cudagraph is None: + self.args = args + input_addresses = [ + x.data_ptr() if isinstance(x, torch.Tensor) else None for x in args + ] + self.input_addresses = input_addresses + + self.cudagraph = torch.cuda.CUDAGraph() + + with torch.cuda.graph( + self.cudagraph, pool=self.graph_pool, stream=self.stream + ): + # `output` is managed by pytorch's cudagraph pool + self.output = self.runnable(*args) + + self.copy_static_inputs(*args) + self.cudagraph.replay() + return self.output + + +def get_static_input_indices(gm: torch.fx.GraphModule, is_forward: bool) -> list[int]: + """ + Get indices of gm inputs that are static input tensors whose tensor addresses do not + change across runs. Example of static input tensors include weights, buffers, and + outputs of previous cudagraph wrapped functions. + """ + from torch._inductor.utils import count_tangents + + static_input_indices = [] + if ( + is_forward + and (tracing_context := torch._guards.TracingContext.try_get()) + and hasattr(tracing_context, "fw_metadata") + ): + # for forward, we rely on graph capture (i.e., dynamo or export) to provide + # the correct static input indices stored in tracing context. Typical examples + # include weights and buffers. + static_input_indices = tracing_context.fw_metadata.static_input_indices + + elif not is_forward: + # for backward, we identify saved tensors as static inputs, since saved tensors + # are outputs of cudagraph-wrapped forward run. In PT2-generated backward gm, + # saved tensors are always the leading args. So we can get the number of saved + # tensors and generate static input indices. 
+ fixed = count_tangents(gm) + static_input_indices = list(range(fixed)) + + return static_input_indices diff --git a/torchtitan/experiments/compiler_toolkit/passes.py b/torchtitan/experiments/compiler_toolkit/passes.py index 9e8d7fff8a..1f4e70074f 100644 --- a/torchtitan/experiments/compiler_toolkit/passes.py +++ b/torchtitan/experiments/compiler_toolkit/passes.py @@ -11,14 +11,15 @@ during compilation. Passes can be selected and configured via job config. """ -import warnings -from typing import Any, Callable, Optional, Sequence +from typing import Any, Sequence import torch -from torch._inductor.cudagraph_trees import _use_cuda_memory_pool_manager from torch._inductor.fx_passes.overlap_scheduling import schedule_overlap_bucketing from torch.fx.passes.regional_inductor import regional_inductor -from torch.utils._ordered_set import OrderedSet +from torchtitan.experiments.compiler_toolkit.cudagraph import ( + CUDAGraphWrapper, + get_static_input_indices, +) from torchtitan.experiments.simple_fsdp.reshard_after_forward import ( annotate_fsdp_all_gather, ) @@ -47,138 +48,6 @@ def regional_inductor_pass( return regional_inductor(gm, example_inputs) -def init_global_graph_pool() -> tuple[ - torch.cuda.CUDAGraph, torch.cuda._POOL_HANDLE, torch.cuda.Stream -]: - dummy_graph = torch.cuda.CUDAGraph() - - # create a global cudagraph memory pool to allow memory reuse across cudagraphs. - graph_pool = torch.cuda.graph_pool_handle() - - # create a global cuda stream for graph capture. we need to use a single stream - # for all allocations to the memory pool, otherwise the allocations to separate streams - # will not be used. - graph_capture_stream = torch.cuda.Stream() - - # use a dummy graph to keep the global graph pool alive - with ( - # suppress an empty cudagraph warning, since we intentionally create - # an empty cudagraph here - warnings.catch_warnings(record=True), - torch.cuda.graph( - dummy_graph, - pool=graph_pool, - stream=graph_capture_stream, - capture_error_mode="thread_local", - ), - ): - pass - - return dummy_graph, graph_pool, graph_capture_stream - - -( - _global_dummy_graph, - _global_graph_pool, - _global_graph_capture_stream, -) = init_global_graph_pool() - - -# TODO: make output and args weakref to allow reuse. - - -class CUDAGraphWrapper: - def __init__( - self, - runnable: Callable, - example_inputs: Sequence[Any], - static_input_indices: Optional[tuple[int]] = None, - ): - self.runnable = runnable - self.graph_pool = _global_graph_pool - self.stream = _global_graph_capture_stream - self.static_input_indices = OrderedSet( - static_input_indices if static_input_indices is not None else [] - ) - self.input_indices_to_copy = [ - i - for i, inp in enumerate(example_inputs) - if isinstance(inp, torch.Tensor) and i not in self.static_input_indices - ] - self.cudagraph: Optional[torch.cuda.CUDAGraph] = None - self.has_warmup = False - - # TODO: weak ref - self.args = None - self.output = None - - def copy_static_inputs(self, *args): - for i in self.input_indices_to_copy: - self.args[i].copy_(args[i]) - - def __call__(self, *args): - if not self.has_warmup: - self.has_warmup = True - device = torch.cuda.current_device() - - # warmup in cudagraph memory pool to avoid fragmentation - # across eager memory pool and cudagraph memory pool. - with _use_cuda_memory_pool_manager(device, self.graph_pool, self.stream): - out = self.runnable(*args) - return out - - if self.cudagraph is None: - # TODO: weak ref? 
- self.args = args - input_addresses = [ - x.data_ptr() if isinstance(x, torch.Tensor) else None for x in args - ] - self.input_addresses = input_addresses - - self.cudagraph = torch.cuda.CUDAGraph() - - with torch.cuda.graph( - self.cudagraph, pool=self.graph_pool, stream=self.stream - ): - # `output` is managed by pytorch's cudagraph pool - # TODO: use weak ref for output to reuse memory - self.output = self.runnable(*args) - - self.copy_static_inputs(*args) - self.cudagraph.replay() - return self.output - - -def get_static_input_indices(gm: torch.fx.GraphModule, is_forward: bool) -> list[int]: - """ - Get indices of gm inputs that are static input tensors whose tensor addresses do not - change across runs. Example of static input tensors include weights, buffers, and - outputs of previous cudagraph wrapped functions. - """ - from torch._inductor.utils import count_tangents - - static_input_indices = [] - if ( - is_forward - and (tracing_context := torch._guards.TracingContext.try_get()) - and hasattr(tracing_context, "fw_metadata") - ): - # for forward, we rely on graph capture (i.e., dynamo or export) to provide - # the correct static input indices stored in tracing context. Typical examples - # include weights and buffers. - static_input_indices = tracing_context.fw_metadata.static_input_indices - - elif not is_forward: - # for backward, we identify saved tensors as static inputs, since saved tensors - # are outputs of cudagraph-wrapped forward run. In PT2-generated backward gm, - # saved tensors are always the leading args. So we can get the number of saved - # tensors and generate static input indices. - fixed = count_tangents(gm) - static_input_indices = list(range(fixed)) - - return static_input_indices - - def cudagraph_pass( gm: torch.fx.GraphModule, example_inputs: Sequence[Any], is_forward: bool ) -> torch.fx.GraphModule: From 5992263d896dddd58309a9100864459908b8ad6f Mon Sep 17 00:00:00 2001 From: Boyuan Feng Date: Tue, 18 Nov 2025 13:19:38 -0800 Subject: [PATCH 16/22] make sure cudagraph is always the last pass to apply --- torchtitan/experiments/compiler_toolkit/graph_utils.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/torchtitan/experiments/compiler_toolkit/graph_utils.py b/torchtitan/experiments/compiler_toolkit/graph_utils.py index d1826bae4e..fd2b567716 100644 --- a/torchtitan/experiments/compiler_toolkit/graph_utils.py +++ b/torchtitan/experiments/compiler_toolkit/graph_utils.py @@ -314,7 +314,12 @@ def get_compiler_passes_from_config(job_config: JobConfig): pass_names = getattr(job_config.compile, "passes", []) compiler_passes = [] + use_cudagraph = "cudagraph" in pass_names + for pass_name in pass_names: + if pass_name == "cudagraph": + continue + if pass_name not in AVAILABLE_COMPILER_PASSES: raise ValueError( f"Unknown compiler pass: {pass_name}. 
" @@ -322,6 +327,10 @@ def get_compiler_passes_from_config(job_config: JobConfig): ) compiler_passes.append(AVAILABLE_COMPILER_PASSES[pass_name]) + if use_cudagraph: + # cudagraph should always be the last fx pass to apply + compiler_passes.append(AVAILABLE_COMPILER_PASSES["cudagraph"]) + if pass_names: logger.info(f"Using compiler passes from config: {pass_names}") From b0feed3cd2a31cb4423bd14c9657bdde7e1cf9e9 Mon Sep 17 00:00:00 2001 From: Boyuan Feng Date: Tue, 18 Nov 2025 16:43:26 -0800 Subject: [PATCH 17/22] add test --- .../compiler_toolkit/tests/integration_tests.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/torchtitan/experiments/compiler_toolkit/tests/integration_tests.py b/torchtitan/experiments/compiler_toolkit/tests/integration_tests.py index e33149fe2f..f4dd36a8f3 100644 --- a/torchtitan/experiments/compiler_toolkit/tests/integration_tests.py +++ b/torchtitan/experiments/compiler_toolkit/tests/integration_tests.py @@ -76,6 +76,20 @@ def build_compiler_toolkit_test_list() -> list[OverrideDefinitions]: "llama3_fsdp_tp_flexattn_autobucketing_regional_inductor", ngpu=4, ), + OverrideDefinitions( + [ + [ + "--model.name compiler_toolkit.llama3", + "--parallelism.data_parallel_shard_degree 2", + "--parallelism.tensor_parallel_degree 2", + "--job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config", + "--compile.passes cudagraph", + ], + ], + "llama3 FSDP+TP+cudagraph", + "llama3_fsdp_tp_cudagraph", + ngpu=4, + ), # deepseek_v3 tests OverrideDefinitions( [ From 3835a14ccec9c9848ec1863c2b87e0ce15f4be9a Mon Sep 17 00:00:00 2001 From: Boyuan Feng Date: Tue, 18 Nov 2025 22:16:03 -0800 Subject: [PATCH 18/22] more docs and tests --- .../experiments/compiler_toolkit/README.md | 6 +++++ .../compiler_toolkit/common_utils.py | 9 +++++++ .../compiler_toolkit/graph_utils.py | 26 +++++++++++++------ .../tests/integration_tests.py | 21 ++++++++++++--- 4 files changed, 51 insertions(+), 11 deletions(-) diff --git a/torchtitan/experiments/compiler_toolkit/README.md b/torchtitan/experiments/compiler_toolkit/README.md index c223d1e658..f5cbb61f00 100644 --- a/torchtitan/experiments/compiler_toolkit/README.md +++ b/torchtitan/experiments/compiler_toolkit/README.md @@ -55,3 +55,9 @@ NGPU=8 CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./r ```shell NGPU=8 CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4 --job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config --compile.passes transformer_block_bucketing,regional_inductor ``` + +**SimpleFSDP + TP + FlexAttention + transformer-block-bucketing + regional-inductor + cudagraph** + +```shell +NGPU=8 CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4 --job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config --compile.passes transformer_block_bucketing,regional_inductor,cudagraph +``` diff --git a/torchtitan/experiments/compiler_toolkit/common_utils.py b/torchtitan/experiments/compiler_toolkit/common_utils.py index 965e027bdb..997af9a2c4 100644 --- a/torchtitan/experiments/compiler_toolkit/common_utils.py +++ b/torchtitan/experiments/compiler_toolkit/common_utils.py @@ -5,6 +5,7 @@ # LICENSE file in the root directory of this source tree. 
from contextlib import contextmanager +from typing import Callable import torch from torch.distributed.tensor import DTensor, Replicate @@ -53,3 +54,11 @@ def register_blockmask_pytree_node(): flatten_with_keys_fn=BlockMask._flatten_with_keys, serialized_type_name="torch.nn.attention.flex_attention.BlockMask", ) + + +def end_with_pass(passes: list[Callable], names: list[str]) -> bool: + return ( + len(passes) > 0 + and (last_pass_name := getattr(passes[-1], "__name__", None)) + and (last_pass_name in names) + ) diff --git a/torchtitan/experiments/compiler_toolkit/graph_utils.py b/torchtitan/experiments/compiler_toolkit/graph_utils.py index f10477d833..36e7ff2124 100644 --- a/torchtitan/experiments/compiler_toolkit/graph_utils.py +++ b/torchtitan/experiments/compiler_toolkit/graph_utils.py @@ -20,6 +20,7 @@ from torch.distributed.tensor import DTensor from torchtitan.config import JobConfig from torchtitan.distributed import ParallelDims +from torchtitan.experiments.compiler_toolkit.common_utils import end_with_pass from torchtitan.tools.logging import logger @@ -240,7 +241,7 @@ def compiler( ) _dump_gm(dump_folder, gm, f"{name}_before_compiler") - if len(passes) > 0 and passes[-1].__name__ == "cudagraph_pass": + if end_with_pass(passes, ["cudagraph_pass"]): # cudagraph pass is always the last pass if it is applied cg_pass = passes[-1] @@ -304,6 +305,21 @@ def bw_compiler(gm: torch.fx.GraphModule, example_inputs) -> None: return fw_compiler, bw_compiler +def validate_pass_names(pass_names: list[str]) -> None: + if "cudagraph" in pass_names: + assert ( + pass_names[-1] == "cudagraph" + ), "cudagraph has to be the last pass to apply" + + if ( + "autobucketing_reordering" in pass_names + and "transformer_block_bucketing" in pass_names + ): + raise ValueError( + "Cannot apply autobucketing_reordering and transformer_block_bucketing at the same time!" + ) + + def get_compiler_passes_from_config(model: torch.nn.Module, job_config: JobConfig): """ Extract and validate compiler passes from job config. @@ -320,13 +336,7 @@ def get_compiler_passes_from_config(model: torch.nn.Module, job_config: JobConfi ) pass_names = getattr(job_config.compile, "passes", []) - if ( - "autobucketing_reordering" in pass_names - and "transformer_block_bucketing" in pass_names - ): - raise ValueError( - "Cannot apply autobucketing_reordering and transformer_block_bucketing at the same time!" 
-        )
+    validate_pass_names(pass_names)
 
     compiler_passes = []
     use_cudagraph = "cudagraph" in pass_names
diff --git a/torchtitan/experiments/compiler_toolkit/tests/integration_tests.py b/torchtitan/experiments/compiler_toolkit/tests/integration_tests.py
index 725dec718c..f01a1c4380 100644
--- a/torchtitan/experiments/compiler_toolkit/tests/integration_tests.py
+++ b/torchtitan/experiments/compiler_toolkit/tests/integration_tests.py
@@ -58,6 +58,20 @@ def build_compiler_toolkit_test_list() -> list[OverrideDefinitions]:
             "llama3_fsdp_tp_manualbucketing",
             ngpu=4,
         ),
+        OverrideDefinitions(
+            [
+                [
+                    "--model.name compiler_toolkit.llama3",
+                    "--parallelism.data_parallel_shard_degree 2",
+                    "--parallelism.tensor_parallel_degree 2",
+                    "--job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config",
+                    "--compile.passes cudagraph",
+                ],
+            ],
+            "llama3 FSDP+TP+cudagraph",
+            "llama3_fsdp_tp_cudagraph",
+            ngpu=4,
+        ),
         OverrideDefinitions(
             [
                 [
@@ -92,12 +106,13 @@ def build_compiler_toolkit_test_list() -> list[OverrideDefinitions]:
                     "--model.name compiler_toolkit.llama3",
                     "--parallelism.data_parallel_shard_degree 2",
                     "--parallelism.tensor_parallel_degree 2",
+                    "--model.flavor debugmodel_flex_attn",
                     "--job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config",
-                    "--compile.passes cudagraph",
+                    "--compile.passes autobucketing_reordering,regional_inductor,cudagraph",
                 ],
             ],
-            "llama3 FSDP+TP+cudagraph",
-            "llama3_fsdp_tp_cudagraph",
+            "llama3 FSDP+TP+FlexAttn autobucketing regional_inductor+cudagraph",
+            "llama3_fsdp_tp_flexattn_autobucketing_regional_inductor_cudagraph",
             ngpu=4,
         ),
         OverrideDefinitions(
             [

From 26414c06a150cc9797e9b0c6387f2dd44c3d788d Mon Sep 17 00:00:00 2001
From: Boyuan Feng
Date: Wed, 19 Nov 2025 10:35:40 -0800
Subject: [PATCH 19/22] add runtime checks

---
 .../experiments/compiler_toolkit/cudagraph.py | 32 ++++++++++++++++---
 1 file changed, 28 insertions(+), 4 deletions(-)

diff --git a/torchtitan/experiments/compiler_toolkit/cudagraph.py b/torchtitan/experiments/compiler_toolkit/cudagraph.py
index 64602cea3c..cd6e4cfc22 100644
--- a/torchtitan/experiments/compiler_toolkit/cudagraph.py
+++ b/torchtitan/experiments/compiler_toolkit/cudagraph.py
@@ -62,6 +62,7 @@ def __init__(
         runnable: Callable,
         example_inputs: Sequence[Any],
         static_input_indices: Optional[tuple[int]] = None,
+        should_check_address: bool = False,
     ):
         self.runnable = runnable
         self.graph_pool = _global_graph_pool
@@ -80,10 +81,30 @@ def __init__(
         self.args = None
         self.output = None
 
-    def copy_static_inputs(self, *args):
+        # (debug only) whether to check static input tensor addresses at runtime
+        self.should_check_address = should_check_address
+
+    def copy_non_static_inputs(self, *args):
         for i in self.input_indices_to_copy:
             self.args[i].copy_(args[i])
 
+    def check_input_types(self, inputs) -> None:
+        for inp in inputs:
+            assert isinstance(inp, (torch.Tensor, int, torch._C.Generator)), (
+                "args must be tensor, integer (for dynamic shapes), "
+                "or Generator (for random number generator), "
+                f"but found {type(inp)}"
+            )
+
+    def check_static_inputs_address(self, args) -> None:
+        for i in self.static_input_indices:
+            actual = args[i].data_ptr()
+            expected = self.input_addresses[i]
+            assert expected == actual, (
+                "Expected the same static tensor address but found "
+                f"{expected} != {actual}"
+            )
+
     def __call__(self, *args):
         if not self.has_warmup:
             self.has_warmup = True
@@ -96,11 +117,11 @@ def __call__(self, *args):
             return out
 
         if self.cudagraph is None:
+            self.check_input_types(args)
             self.args = args
-            input_addresses = [
+            self.input_addresses = [
                 x.data_ptr() if isinstance(x, torch.Tensor) else None for x in args
             ]
-            self.input_addresses = input_addresses
 
             self.cudagraph = torch.cuda.CUDAGraph()
 
@@ -110,7 +131,10 @@ def __call__(self, *args):
                 # `output` is managed by pytorch's cudagraph pool
                 self.output = self.runnable(*args)
 
-        self.copy_static_inputs(*args)
+        if self.should_check_address:
+            self.check_static_inputs_address(args)
+
+        self.copy_non_static_inputs(*args)
         self.cudagraph.replay()
         return self.output
 

From 2d037e4290fde551620125eefd85558d29c0e06b Mon Sep 17 00:00:00 2001
From: Boyuan Feng
Date: Wed, 19 Nov 2025 11:38:08 -0800
Subject: [PATCH 20/22] cleanup

---
 torchtitan/experiments/compiler_toolkit/graph_utils.py | 9 ---------
 torchtitan/train.py                                    | 3 ++-
 2 files changed, 2 insertions(+), 10 deletions(-)

diff --git a/torchtitan/experiments/compiler_toolkit/graph_utils.py b/torchtitan/experiments/compiler_toolkit/graph_utils.py
index 36e7ff2124..e097579cc0 100644
--- a/torchtitan/experiments/compiler_toolkit/graph_utils.py
+++ b/torchtitan/experiments/compiler_toolkit/graph_utils.py
@@ -339,12 +339,7 @@ def get_compiler_passes_from_config(model: torch.nn.Module, job_config: JobConfi
     validate_pass_names(pass_names)
 
     compiler_passes = []
-    use_cudagraph = "cudagraph" in pass_names
-
     for pass_name in pass_names:
-        if pass_name == "cudagraph":
-            continue
-
         if pass_name not in AVAILABLE_COMPILER_PASSES:
             raise ValueError(
                 f"Unknown compiler pass: {pass_name}. "
@@ -360,10 +355,6 @@ def get_compiler_passes_from_config(model: torch.nn.Module, job_config: JobConfi
         else:
             compiler_passes.append(AVAILABLE_COMPILER_PASSES[pass_name])
 
-    if use_cudagraph:
-        # cudagraph should always be the last fx pass to apply
-        compiler_passes.append(AVAILABLE_COMPILER_PASSES["cudagraph"])
-
     if pass_names:
         logger.info(f"Using compiler passes from config: {pass_names}")
 
diff --git a/torchtitan/train.py b/torchtitan/train.py
index 872e339493..3bcbc153f8 100644
--- a/torchtitan/train.py
+++ b/torchtitan/train.py
@@ -710,7 +710,8 @@ def close(self) -> None:
         # in joint_graph_module. An explicit gc.collect() is necessary
         # to clean up reference cycles.
         for part in self.model_parts:
-            part.joint_graph_module = None
+            if hasattr(part, "joint_graph_module"):
+                part.joint_graph_module = None
         gc.collect()
 

From c8e738477796341faedb07dc2e6e8c68d15fb4ef Mon Sep 17 00:00:00 2001
From: Boyuan Feng
Date: Wed, 19 Nov 2025 16:32:13 -0800
Subject: [PATCH 21/22] cleanup

---
 run_train.sh                                      |  7 +------
 torchtitan/experiments/compiler_toolkit/README.md |  2 +-
 torchtitan/experiments/compiler_toolkit/train.py  | 15 ++++++++++++++-
 .../models/llama3/train_configs/llama3_70b.toml   |  2 +-
 torchtitan/train.py                               | 11 -----------
 5 files changed, 17 insertions(+), 20 deletions(-)

diff --git a/run_train.sh b/run_train.sh
index d20abe1ff6..83319816fe 100755
--- a/run_train.sh
+++ b/run_train.sh
@@ -19,18 +19,13 @@ DRY_RUN=${DRY_RUN:-0}
 
 TORCHFT_LIGHTHOUSE=${TORCHFT_LIGHTHOUSE:-"http://localhost:29510"}
 
-# need to turn off expandable segments when using cudagraph, since
-# it does not work with cg and nccl yet.
-# https://github.com/pytorch/pytorch/issues/158029 -USE_EXPANDABLE_SEGMENTS=${USE_EXPANDABLE_SEGMENTS:-True} - if [ "$DRY_RUN" = "1" ]; then # Dry run mode: validate configuration without GPU/distributed setup echo "Running in DRY RUN mode - configuration validation only" python scripts/dry_run.py --job.config_file ${CONFIG_FILE} "$@" else # Normal training with torchrun - PYTORCH_ALLOC_CONF="expandable_segments:${USE_EXPANDABLE_SEGMENTS}" \ + PYTORCH_ALLOC_CONF="expandable_segments:True" \ TORCHFT_LIGHTHOUSE=${TORCHFT_LIGHTHOUSE} \ torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \ --local-ranks-filter ${LOG_RANK} --role rank --tee 3 \ diff --git a/torchtitan/experiments/compiler_toolkit/README.md b/torchtitan/experiments/compiler_toolkit/README.md index 88e43d4e8e..620911ce60 100644 --- a/torchtitan/experiments/compiler_toolkit/README.md +++ b/torchtitan/experiments/compiler_toolkit/README.md @@ -59,5 +59,5 @@ NGPU=8 TRAIN_FILE=torchtitan.experiments.compiler_toolkit.train CONFIG_FILE=./to **SimpleFSDP + TP + FlexAttention + transformer-block-bucketing + regional-inductor + cudagraph** ```shell -NGPU=8 CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4 --job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config --compile.passes transformer_block_bucketing,regional_inductor,cudagraph +NCCL_GRAPH_REGISTER=0 NGPU=8 TRAIN_FILE=torchtitan.experiments.compiler_toolkit.train CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4 --job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config --compile.passes transformer_block_bucketing,regional_inductor,cudagraph ``` diff --git a/torchtitan/experiments/compiler_toolkit/train.py b/torchtitan/experiments/compiler_toolkit/train.py index 26e3245b2b..7b0d58aa5a 100644 --- a/torchtitan/experiments/compiler_toolkit/train.py +++ b/torchtitan/experiments/compiler_toolkit/train.py @@ -4,11 +4,24 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import gc + from torchtitan.train import main, Trainer class CompilerToolkitTrainer(Trainer): - pass + def close(self) -> None: + super().close() + + # Note [explicit cudagraph close] + # cudagraph holds reference to nccl which prevents destroy nccl + # group. so we need to explicitly delete cudagraph which is held + # in joint_graph_module. An explicit gc.collect() is necessary + # to clean up reference cycles. 
+ for part in self.model_parts: + if hasattr(part, "joint_graph_module"): + part.joint_graph_module = None + gc.collect() if __name__ == "__main__": diff --git a/torchtitan/models/llama3/train_configs/llama3_70b.toml b/torchtitan/models/llama3/train_configs/llama3_70b.toml index 37fd35b5cb..8dc993cbc7 100644 --- a/torchtitan/models/llama3/train_configs/llama3_70b.toml +++ b/torchtitan/models/llama3/train_configs/llama3_70b.toml @@ -30,7 +30,7 @@ warmup_steps = 200 # lr scheduler warm up, normally 20% of the train steps [training] local_batch_size = 8 -seq_len = 8192 +seq_len = 2048 max_norm = 1.0 # grad norm clipping steps = 1000 dataset = "c4" diff --git a/torchtitan/train.py b/torchtitan/train.py index 3bcbc153f8..5cfab998b2 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -4,7 +4,6 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import gc import importlib import os import time @@ -704,16 +703,6 @@ def close(self) -> None: if hasattr(self, "metrics_processor") and self.metrics_processor: self.metrics_processor.close() - # Note [explicit cudagraph close] - # cudagraph holds reference to nccl which prevents destroy nccl - # group. so we need to explicitly delete cudagraph which is held - # in joint_graph_module. An explicit gc.collect() is necessary - # to clean up reference cycles. - for part in self.model_parts: - if hasattr(part, "joint_graph_module"): - part.joint_graph_module = None - gc.collect() - def main(trainer_class: type[Trainer]) -> None: """Main entry point for training with a specified trainer class. From 0516fa7739ca3d68b7dab90833253d8568c035b1 Mon Sep 17 00:00:00 2001 From: Boyuan Feng Date: Wed, 19 Nov 2025 16:33:28 -0800 Subject: [PATCH 22/22] nit --- torchtitan/models/llama3/train_configs/llama3_70b.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchtitan/models/llama3/train_configs/llama3_70b.toml b/torchtitan/models/llama3/train_configs/llama3_70b.toml index 8dc993cbc7..37fd35b5cb 100644 --- a/torchtitan/models/llama3/train_configs/llama3_70b.toml +++ b/torchtitan/models/llama3/train_configs/llama3_70b.toml @@ -30,7 +30,7 @@ warmup_steps = 200 # lr scheduler warm up, normally 20% of the train steps [training] local_batch_size = 8 -seq_len = 2048 +seq_len = 8192 max_norm = 1.0 # grad norm clipping steps = 1000 dataset = "c4"
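
A note on the shared memory pool used above: `init_global_graph_pool` keeps one pool (and one capture stream) alive for every captured graph, so their allocations can be reused rather than each graph holding private memory. The snippet below is a minimal, self-contained sketch of that idea using only public `torch.cuda` APIs; it is an illustration, not the torchtitan implementation.

```python
import torch

if torch.cuda.is_available():
    # one pool shared by every capture, mirroring the module-level
    # _global_graph_pool in cudagraph.py
    pool = torch.cuda.graph_pool_handle()
    x = torch.randn(1024, 1024, device="cuda")

    # warm up once outside of capture (lazy cuBLAS init, autotuning, ...)
    (x @ x, x + 1)
    torch.cuda.synchronize()

    g1 = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g1, pool=pool):
        y1 = x @ x  # allocated from the shared pool during capture

    g2 = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g2, pool=pool):
        y2 = x + 1  # captured into the same pool, so memory can be reused

    g1.replay()
    g2.replay()
    torch.cuda.synchronize()
    print(y1.shape, y2.shape)
```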
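The warmup → capture → replay protocol that `CUDAGraphWrapper` follows is easier to see without the bookkeeping. `SimpleGraphRunner` below is a hypothetical, stripped-down analogue (the name and structure are illustrative only): the first call runs eagerly, the second call records a CUDA graph, and later calls copy fresh data into the recorded input buffers and replay.

```python
import torch


class SimpleGraphRunner:
    """Hypothetical minimal analogue of CUDAGraphWrapper (illustration only)."""

    def __init__(self, fn):
        self.fn = fn
        self.graph = None
        self.static_inputs = None  # tensors whose addresses the graph bakes in
        self.static_output = None
        self.warmed_up = False

    def __call__(self, *inputs):
        if not self.warmed_up:
            # phase 1: eager warmup, so lazy initialization is not captured
            self.warmed_up = True
            return self.fn(*inputs)
        if self.graph is None:
            # phase 2: capture; the tensors passed here become the static buffers
            self.static_inputs = inputs
            self.graph = torch.cuda.CUDAGraph()
            with torch.cuda.graph(self.graph):
                self.static_output = self.fn(*self.static_inputs)
        else:
            # phase 3: copy new data into the captured buffers before replay
            for dst, src in zip(self.static_inputs, inputs):
                dst.copy_(src)
        self.graph.replay()
        return self.static_output


if torch.cuda.is_available():
    lin = torch.nn.Linear(16, 16, device="cuda")
    runner = SimpleGraphRunner(lambda x: lin(x).relu())
    x = torch.randn(4, 16, device="cuda")
    for _ in range(4):
        out = runner(x)
    print(out.shape)  # torch.Size([4, 16])
```

Unlike this sketch, the real wrapper shares the global pool and capture stream across graphs, skips the copy for static inputs (weights, buffers, and the saved tensors identified by `get_static_input_indices`), and runs the warmup inside the cudagraph memory pool to avoid fragmentation between the eager and cudagraph allocators.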
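Finally, the ordering rules enforced by `validate_pass_names` for `--compile.passes` can be summarized in a few lines. The sketch below mirrors the two checks (cudagraph last, at most one bucketing strategy); `check_pass_names` is an illustrative name, not the torchtitan API.

```python
def check_pass_names(pass_names: list[str]) -> None:
    # cudagraph wraps the final graph module, so it must be the last pass
    if "cudagraph" in pass_names and pass_names[-1] != "cudagraph":
        raise ValueError("cudagraph has to be the last pass to apply")
    # only one bucketing strategy may be selected at a time
    if {"autobucketing_reordering", "transformer_block_bucketing"} <= set(pass_names):
        raise ValueError(
            "Cannot apply autobucketing_reordering and "
            "transformer_block_bucketing at the same time!"
        )


check_pass_names(["transformer_block_bucketing", "regional_inductor", "cudagraph"])  # ok
# check_pass_names(["cudagraph", "regional_inductor"])  # would raise ValueError
```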