From 528b319520496817b02034d5528c19659c7adc24 Mon Sep 17 00:00:00 2001
From: Yiming Zhou
Date: Mon, 17 Nov 2025 16:31:44 -0800
Subject: [PATCH] add joint_ac_pass

---
 .../compiler_toolkit/common_utils.py          | 10 ---
 .../deepseek_v3/parallelize.py                |  7 +-
 .../compiler_toolkit/graph_utils.py           | 64 ++++++++++++++++---
 .../compiler_toolkit/llama3/parallelize.py    |  9 ++-
 .../experiments/compiler_toolkit/passes.py    | 32 +++++++++-
 5 files changed, 98 insertions(+), 24 deletions(-)

diff --git a/torchtitan/experiments/compiler_toolkit/common_utils.py b/torchtitan/experiments/compiler_toolkit/common_utils.py
index b7499b2f79..965e027bdb 100644
--- a/torchtitan/experiments/compiler_toolkit/common_utils.py
+++ b/torchtitan/experiments/compiler_toolkit/common_utils.py
@@ -53,13 +53,3 @@ def register_blockmask_pytree_node():
         flatten_with_keys_fn=BlockMask._flatten_with_keys,
         serialized_type_name="torch.nn.attention.flex_attention.BlockMask",
     )
-
-
-def validate_flex_attention_annotation(joint_with_descriptors):
-    """Verify user annotations show up in the graph."""
-    for node in joint_with_descriptors.graph_module.graph.nodes:
-        if node.target in {
-            torch.ops.higher_order.flex_attention,
-            torch.ops.higher_order.flex_attention_backward,
-        }:
-            assert "compile_with_inductor" in node.meta.get("custom", {})
diff --git a/torchtitan/experiments/compiler_toolkit/deepseek_v3/parallelize.py b/torchtitan/experiments/compiler_toolkit/deepseek_v3/parallelize.py
index 20ad17f301..982843bb24 100644
--- a/torchtitan/experiments/compiler_toolkit/deepseek_v3/parallelize.py
+++ b/torchtitan/experiments/compiler_toolkit/deepseek_v3/parallelize.py
@@ -17,12 +17,12 @@
     disable_compile,
     parallelize_inputs,
     register_blockmask_pytree_node,
-    validate_flex_attention_annotation,
 )
 
 from torchtitan.experiments.compiler_toolkit.graph_utils import (
     CompiledModule,
     get_compiler_passes_from_config,
+    get_joint_custom_passes_from_config,
     joint_graph_builder,
     make_compiler_with_passes,
 )
@@ -76,6 +76,9 @@ def parallelize_deepseekv3(
     with disable_compile(job_config):
         model = simple_fsdp_parallelize_deepseekv3(model, parallel_dims, job_config)
 
+    # Get joint custom passes from config
+    joint_custom_passes = get_joint_custom_passes_from_config(parallel_dims, job_config)
+
     # Get compiler passes from config
     compiler_passes = get_compiler_passes_from_config(job_config)
 
@@ -89,7 +92,7 @@
         joint_graph_builder,
         fw_compiler=fw_compiler,
         bw_compiler=bw_compiler,
-        joint_custom_pass=validate_flex_attention_annotation,
+        joint_custom_passes=joint_custom_passes,
         dump_folder=job_config.job.dump_folder,
     )
 
diff --git a/torchtitan/experiments/compiler_toolkit/graph_utils.py b/torchtitan/experiments/compiler_toolkit/graph_utils.py
index cd758438b3..fa93b02b63 100644
--- a/torchtitan/experiments/compiler_toolkit/graph_utils.py
+++ b/torchtitan/experiments/compiler_toolkit/graph_utils.py
@@ -5,6 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 import contextlib
+import functools
 from pathlib import Path
 from typing import Any, Callable, List, Optional
 
@@ -86,7 +87,7 @@ def joint_graph_builder(
     model_kwargs: dict,
     fw_compiler: Optional[Callable] = None,
     bw_compiler: Optional[Callable] = None,
-    joint_custom_pass: Optional[Callable] = None,
+    joint_custom_passes: Optional[List[Callable]] = None,
     dump_folder: str | None = None,
 ):
     """
@@ -98,7 +99,7 @@
         model_kwargs: Dict of model input keyword arguments
         fw_compiler: Optional custom forward compiler function
         bw_compiler: Optional custom backward compiler function
-        joint_custom_pass: Optional custom pass to run on the joint graph
+        joint_custom_passes: Optional list of custom passes to run on the joint graph
         dump_folder: Optional folder to dump the graph to
     """
     assert isinstance(model_args, tuple)
@@ -112,8 +113,11 @@
     ) = export_joint(model, model_args, model_kwargs, dump_folder=dump_folder)
 
-    # Optional validation
-    if joint_custom_pass is not None:
-        joint_custom_pass(joint_with_descriptors)
+    # Run custom passes on the joint graph, in list order
+    if joint_custom_passes is not None:
+        for joint_custom_pass in joint_custom_passes:
+            joint_with_descriptors.graph_module = joint_custom_pass(
+                joint_with_descriptors.graph_module
+            )
 
     with tracing(tracing_context):
         fn = aot_compile_joint_with_descriptors(
@@ -283,20 +287,64 @@ def get_compiler_passes_from_config(job_config: JobConfig):
     Returns:
         List of compiler pass functions
     """
-    from torchtitan.experiments.compiler_toolkit.passes import AVAILABLE_PASSES
+    from torchtitan.experiments.compiler_toolkit.passes import AVAILABLE_COMPILER_PASSES
 
     pass_names = getattr(job_config.compile, "passes", [])
     compiler_passes = []
 
     for pass_name in pass_names:
-        if pass_name not in AVAILABLE_PASSES:
+        if pass_name not in AVAILABLE_COMPILER_PASSES:
             raise ValueError(
                 f"Unknown compiler pass: {pass_name}. "
-                f"Available passes: {list(AVAILABLE_PASSES.keys())}"
+                f"Available compiler passes: {list(AVAILABLE_COMPILER_PASSES.keys())}"
             )
-        compiler_passes.append(AVAILABLE_PASSES[pass_name])
+        compiler_passes.append(AVAILABLE_COMPILER_PASSES[pass_name])
 
     if pass_names:
         logger.info(f"Using compiler passes from config: {pass_names}")
 
     return compiler_passes
+
+
+def get_joint_custom_passes_from_config(
+    parallel_dims: ParallelDims, job_config: JobConfig
+):
+    """
+    Build the list of custom passes to run on the joint graph from job config.
+
+    Args:
+        parallel_dims: Parallel dims of the job, used to resolve the "default" reshard policy
+        job_config: Job configuration containing parallelism.fsdp_reshard_after_forward
+
+    Returns:
+        List of joint custom pass functions
+    """
+    from torchtitan.experiments.compiler_toolkit.passes import (
+        fsdp_reshard_after_fwd_pass,
+        validate_flex_attn_annotation_pass,
+    )
+
+    joint_custom_passes = []
+    joint_custom_passes.append(validate_flex_attn_annotation_pass)
+
+    match job_config.parallelism.fsdp_reshard_after_forward:
+        case "always":
+            fsdp_reshard_after_forward = True
+        case "never":
+            fsdp_reshard_after_forward = False
+        case "default":
+            # For PP, by default do not reshard after forward to avoid per-microbatch
+            # all-gathers, which can be expensive and hard to overlap
+            fsdp_reshard_after_forward = not parallel_dims.pp_enabled
+        case _:
+            raise ValueError(
+                f"Invalid fsdp_reshard_after_forward: {job_config.parallelism.fsdp_reshard_after_forward}."
+            )
+
+    joint_custom_passes.append(
+        functools.partial(
+            fsdp_reshard_after_fwd_pass,
+            reshard_after_forward=fsdp_reshard_after_forward,
+        )
+    )
+
+    return joint_custom_passes
diff --git a/torchtitan/experiments/compiler_toolkit/llama3/parallelize.py b/torchtitan/experiments/compiler_toolkit/llama3/parallelize.py
index 62def3ef00..e746c24228 100644
--- a/torchtitan/experiments/compiler_toolkit/llama3/parallelize.py
+++ b/torchtitan/experiments/compiler_toolkit/llama3/parallelize.py
@@ -16,12 +16,12 @@
     disable_compile,
     parallelize_inputs,
     register_blockmask_pytree_node,
-    validate_flex_attention_annotation,
 )
 
 from torchtitan.experiments.compiler_toolkit.graph_utils import (
     CompiledModule,
     get_compiler_passes_from_config,
+    get_joint_custom_passes_from_config,
     joint_graph_builder,
     make_compiler_with_passes,
 )
@@ -63,6 +63,9 @@ def parallelize_llama(
     with disable_compile(job_config):
         model = simple_fsdp_parallelize_llama(model, parallel_dims, job_config)
 
+    # Get joint custom passes from config
+    joint_custom_passes = get_joint_custom_passes_from_config(parallel_dims, job_config)
+
     # Get compiler passes from config
     compiler_passes = get_compiler_passes_from_config(job_config)
 
@@ -71,12 +74,12 @@
         compiler_passes, dump_folder=job_config.job.dump_folder
     )
 
-    # Create custom joint_graph_builder with llama-specific compilers and validation
+    # Create custom joint_graph_builder with llama-specific compilers
     llama_joint_graph_builder = functools.partial(
         joint_graph_builder,
         fw_compiler=fw_compiler,
         bw_compiler=bw_compiler,
-        joint_custom_pass=validate_flex_attention_annotation,
+        joint_custom_passes=joint_custom_passes,
         dump_folder=job_config.job.dump_folder,
     )
 
diff --git a/torchtitan/experiments/compiler_toolkit/passes.py b/torchtitan/experiments/compiler_toolkit/passes.py
index 1c00fd5c1b..c0cec614a9 100644
--- a/torchtitan/experiments/compiler_toolkit/passes.py
+++ b/torchtitan/experiments/compiler_toolkit/passes.py
@@ -14,6 +14,9 @@
 import torch
 from torch._inductor.fx_passes.overlap_scheduling import schedule_overlap_bucketing
 from torch.fx.passes.regional_inductor import regional_inductor
+from torchtitan.experiments.simple_fsdp.reshard_after_forward import (
+    annotate_fsdp_all_gather,
+)
 
 
 def autobucketing_reordering_pass(
@@ -39,8 +42,35 @@ def regional_inductor_pass(
     return regional_inductor(gm, example_inputs)
 
 
+def validate_flex_attn_annotation_pass(
+    gm: torch.fx.GraphModule,
+) -> torch.fx.GraphModule:
+    """Verify user annotations show up in the graph."""
+    for node in gm.graph.nodes:
+        if node.target in {
+            torch.ops.higher_order.flex_attention,
+            torch.ops.higher_order.flex_attention_backward,
+        }:
+            assert "compile_with_inductor" in node.meta.get("custom", {})
+    return gm
+
+
+# Applies activation checkpointing annotations to the joint graph before the partitioner
+def fsdp_reshard_after_fwd_pass(
+    gm: torch.fx.GraphModule, reshard_after_forward: bool
+) -> torch.fx.GraphModule:
+    # This pass implements SimpleFSDP's fsdp_reshard_after_forward behavior.
+    # When reshard_after_forward is True, SimpleFSDP all-gathers are annotated
+    # with CheckpointPolicy.MUST_RECOMPUTE, so parameters are resharded after
+    # forward and re-gathered in backward; when False, they are annotated with
+    # CheckpointPolicy.MUST_SAVE, keeping gathered parameters alive for backward.
+    gm = annotate_fsdp_all_gather(gm, reshard_after_forward)
+    gm.recompile()
+    return gm
+
+
 # Registry mapping pass names to pass functions
-AVAILABLE_PASSES = {
+AVAILABLE_COMPILER_PASSES = {
     "autobucketing_reordering": autobucketing_reordering_pass,
     "regional_inductor": regional_inductor_pass,
 }
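
Note on the joint-pass contract introduced here (not part of the patch): each entry in
joint_custom_passes takes a torch.fx.GraphModule and returns one, and joint_graph_builder
applies the entries in list order before AOT compilation. A minimal sketch of writing an
additional pass under that contract; the pass name and its counting logic are illustrative
only, not defined anywhere in this patch:

    import torch

    def count_all_gathers_pass(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
        # Purely observational example pass: count all-gather nodes in the
        # joint graph, then return the graph module unchanged.
        num_ag = sum(
            1
            for node in gm.graph.nodes
            if node.op == "call_function" and "all_gather" in str(node.target)
        )
        print(f"joint graph contains {num_ag} all_gather nodes")
        return gm

    # Passes compose in list order, mirroring the loop in joint_graph_builder:
    # joint_custom_passes = [validate_flex_attn_annotation_pass, count_all_gathers_pass]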
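
For reference, the policy resolution in get_joint_custom_passes_from_config reduces to the
standalone function below (a sketch; resolve_reshard_after_forward is a hypothetical name,
not in the patch):

    def resolve_reshard_after_forward(policy: str, pp_enabled: bool) -> bool:
        # "always"/"never" force the behavior; "default" avoids per-microbatch
        # all-gathers under pipeline parallelism by not resharding.
        match policy:
            case "always":
                return True
            case "never":
                return False
            case "default":
                return not pp_enabled
            case _:
                raise ValueError(f"Invalid fsdp_reshard_after_forward: {policy}")

    assert resolve_reshard_after_forward("default", pp_enabled=True) is False
    assert resolve_reshard_after_forward("default", pp_enabled=False) is True

The resulting True maps to CheckpointPolicy.MUST_RECOMPUTE on SimpleFSDP all-gathers via
fsdp_reshard_after_fwd_pass, and False to CheckpointPolicy.MUST_SAVE.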