10 changes: 0 additions & 10 deletions torchtitan/experiments/compiler_toolkit/common_utils.py
@@ -53,13 +53,3 @@ def register_blockmask_pytree_node():
flatten_with_keys_fn=BlockMask._flatten_with_keys,
serialized_type_name="torch.nn.attention.flex_attention.BlockMask",
)


def validate_flex_attention_annotation(joint_with_descriptors):
"""Verify user annotations show up in the graph."""
for node in joint_with_descriptors.graph_module.graph.nodes:
if node.target in {
torch.ops.higher_order.flex_attention,
torch.ops.higher_order.flex_attention_backward,
}:
assert "compile_with_inductor" in node.meta.get("custom", {})
torchtitan/experiments/compiler_toolkit/deepseek_v3/parallelize.py
@@ -17,12 +17,12 @@
disable_compile,
parallelize_inputs,
register_blockmask_pytree_node,
validate_flex_attention_annotation,
)

from torchtitan.experiments.compiler_toolkit.graph_utils import (
CompiledModule,
get_compiler_passes_from_config,
get_joint_custom_passes_from_config,
joint_graph_builder,
make_compiler_with_passes,
)
@@ -76,6 +76,9 @@ def parallelize_deepseekv3(
with disable_compile(job_config):
model = simple_fsdp_parallelize_deepseekv3(model, parallel_dims, job_config)

# Get joint custom passes from config
joint_custom_passes = get_joint_custom_passes_from_config(parallel_dims, job_config)

# Get compiler passes from config
compiler_passes = get_compiler_passes_from_config(job_config)

@@ -89,7 +92,7 @@ def parallelize_deepseekv3(
joint_graph_builder,
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
joint_custom_pass=validate_flex_attention_annotation,
joint_custom_passes=joint_custom_passes,
dump_folder=job_config.job.dump_folder,
)

64 changes: 56 additions & 8 deletions torchtitan/experiments/compiler_toolkit/graph_utils.py
@@ -5,6 +5,7 @@
# LICENSE file in the root directory of this source tree.

import contextlib
import functools
from pathlib import Path
from typing import Any, Callable, List, Optional

@@ -86,7 +87,7 @@ def joint_graph_builder(
model_kwargs: dict,
fw_compiler: Optional[Callable] = None,
bw_compiler: Optional[Callable] = None,
joint_custom_pass: Optional[Callable] = None,
joint_custom_passes: Optional[List[Callable]] = None,
dump_folder: str | None = None,
):
"""
@@ -98,7 +99,7 @@
model_kwargs: Dict of model input keyword arguments
fw_compiler: Optional custom forward compiler function
bw_compiler: Optional custom backward compiler function
joint_custom_pass: Optional custom pass to run on the joint graph
joint_custom_passes: Optional list of custom passes to run on the joint graph
dump_folder: Optional folder to dump the graph to
"""
assert isinstance(model_args, tuple)
@@ -112,8 +113,11 @@
) = export_joint(model, model_args, model_kwargs, dump_folder=dump_folder)

# Optional validation
if joint_custom_pass is not None:
joint_custom_pass(joint_with_descriptors)
if joint_custom_passes is not None:
for joint_custom_pass in joint_custom_passes:
joint_with_descriptors.graph_module = joint_custom_pass(
joint_with_descriptors.graph_module
)

with tracing(tracing_context):
fn = aot_compile_joint_with_descriptors(
@@ -283,20 +287,64 @@ def get_compiler_passes_from_config(job_config: JobConfig):
Returns:
List of compiler pass functions
"""
from torchtitan.experiments.compiler_toolkit.passes import AVAILABLE_PASSES
from torchtitan.experiments.compiler_toolkit.passes import AVAILABLE_COMPILER_PASSES

pass_names = getattr(job_config.compile, "passes", [])
compiler_passes = []

for pass_name in pass_names:
if pass_name not in AVAILABLE_PASSES:
if pass_name not in AVAILABLE_COMPILER_PASSES:
raise ValueError(
f"Unknown compiler pass: {pass_name}. "
f"Available passes: {list(AVAILABLE_PASSES.keys())}"
f"Available compiler passes: {list(AVAILABLE_COMPILER_PASSES.keys())}"
)
compiler_passes.append(AVAILABLE_PASSES[pass_name])
compiler_passes.append(AVAILABLE_COMPILER_PASSES[pass_name])

if pass_names:
logger.info(f"Using compiler passes from config: {pass_names}")

return compiler_passes


def get_joint_custom_passes_from_config(
parallel_dims: ParallelDims, job_config: JobConfig
):
"""
Extract and validate joint custom passes from job config.

Args:
parallel_dims: Parallel dims of the job, used to choose the default reshard policy when pipeline parallelism is enabled
job_config: Job configuration containing parallelism.fsdp_reshard_after_forward

Returns:
List of joint custom pass functions
"""
from torchtitan.experiments.compiler_toolkit.passes import (
fsdp_reshard_after_fwd_pass,
validate_flex_attn_annotation_pass,
)

joint_custom_passes = []
joint_custom_passes.append(validate_flex_attn_annotation_pass)
Review comment (Contributor):
just out of curiosity: do we care about the order of passes applied in joint_custom_passes?

Reply (Author):
we probably care about the order of the passes when later we have more joint_custom_passes. Right now validate_flex_attn_annotation_pass just validates whether flex attention nodes are annotated and will error out if they are not, so I put it first.


match job_config.parallelism.fsdp_reshard_after_forward:
case "always":
fsdp_reshard_after_forward = True
case "never":
fsdp_reshard_after_forward = False
case "default":
# For PP, by default do not reshard after forward to avoid per-microbatch
# all-gathers, which can be expensive and non-overlapped
fsdp_reshard_after_forward = not parallel_dims.pp_enabled
case _:
raise ValueError(
f"Invalid fsdp_reshard_after_forward_policy: {job_config.parallelism.fsdp_reshard_after_forward}."
)

joint_custom_passes.append(
functools.partial(
fsdp_reshard_after_fwd_pass,
reshard_after_forward=fsdp_reshard_after_forward,
)
)

return joint_custom_passes
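To make the ordering semantics discussed in the review thread above concrete, here is a minimal, self-contained sketch. It is not torchtitan code: the toy graph module, _validate_only, and _annotate_nodes are illustrative stand-ins. It shows that joint custom passes run strictly in list order, each receiving the graph module returned by the previous one, which is why a validate-only pass placed first sees the untouched joint graph.

```python
import torch
import torch.nn as nn


def _validate_only(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    # Stand-in for validate_flex_attn_annotation_pass: inspects nodes, returns gm unchanged.
    for node in gm.graph.nodes:
        _ = node.meta.get("custom", {})
    return gm


def _annotate_nodes(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    # Stand-in for a transforming pass such as fsdp_reshard_after_fwd_pass.
    for node in gm.graph.nodes:
        node.meta.setdefault("custom", {})["touched"] = True
    gm.recompile()
    return gm


# Validation-style passes go first so they see the graph before any rewriting.
joint_custom_passes = [_validate_only, _annotate_nodes]

gm = torch.fx.symbolic_trace(nn.Sequential(nn.Linear(4, 4), nn.ReLU()))
for joint_custom_pass in joint_custom_passes:  # same loop shape as in joint_graph_builder
    gm = joint_custom_pass(gm)
```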
9 changes: 6 additions & 3 deletions torchtitan/experiments/compiler_toolkit/llama3/parallelize.py
@@ -16,12 +16,12 @@
disable_compile,
parallelize_inputs,
register_blockmask_pytree_node,
validate_flex_attention_annotation,
)

from torchtitan.experiments.compiler_toolkit.graph_utils import (
CompiledModule,
get_compiler_passes_from_config,
get_joint_custom_passes_from_config,
joint_graph_builder,
make_compiler_with_passes,
)
@@ -63,6 +63,9 @@ def parallelize_llama(
with disable_compile(job_config):
model = simple_fsdp_parallelize_llama(model, parallel_dims, job_config)

# Get joint custom passes from config
joint_custom_passes = get_joint_custom_passes_from_config(parallel_dims, job_config)

# Get compiler passes from config
compiler_passes = get_compiler_passes_from_config(job_config)

@@ -71,12 +74,12 @@
compiler_passes, dump_folder=job_config.job.dump_folder
)

# Create custom joint_graph_builder with llama-specific compilers and validation
# Create custom joint_graph_builder with llama-specific compilers
llama_joint_graph_builder = functools.partial(
joint_graph_builder,
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
joint_custom_pass=validate_flex_attention_annotation,
joint_custom_passes=joint_custom_passes,
dump_folder=job_config.job.dump_folder,
)

32 changes: 31 additions & 1 deletion torchtitan/experiments/compiler_toolkit/passes.py
@@ -14,6 +14,9 @@
import torch
from torch._inductor.fx_passes.overlap_scheduling import schedule_overlap_bucketing
from torch.fx.passes.regional_inductor import regional_inductor
from torchtitan.experiments.simple_fsdp.reshard_after_forward import (
annotate_fsdp_all_gather,
)


def autobucketing_reordering_pass(
Expand All @@ -39,8 +42,35 @@ def regional_inductor_pass(
return regional_inductor(gm, example_inputs)


def validate_flex_attn_annotation_pass(
gm: torch.fx.GraphModule,
) -> torch.fx.GraphModule:
"""Verify user annotations show up in the graph."""
for node in gm.graph.nodes:
if node.target in {
torch.ops.higher_order.flex_attention,
torch.ops.higher_order.flex_attention_backward,
}:
assert "compile_with_inductor" in node.meta.get("custom", {})
return gm


# Apply activation checkpointing on joint graph before partitioner
Review comment (Contributor):
We use the "joint_ac_pass" here to implement FSDP's reshard_after_forward. Maybe better to rename it to fsdp_reshard_after_fwd_pass?

def fsdp_reshard_after_fwd_pass(
gm: torch.fx.GraphModule, reshard_after_forward: bool
) -> torch.fx.GraphModule:
# This pass implements SimpleFSDP's fsdp_reshard_after_forward behavior.
# When reshard_after_forward is True, it annotates SimpleFSDP all-gathers as
# CheckpointPolicy.MUST_RECOMPUTE (free the unsharded params after forward
# and re-gather them in the backward pass).
# When reshard_after_forward is False, it annotates them as
# CheckpointPolicy.MUST_SAVE (keep the gathered params for the backward pass).
gm = annotate_fsdp_all_gather(gm, reshard_after_forward)
gm.recompile()
return gm


# Registry mapping pass names to pass functions
AVAILABLE_PASSES = {
AVAILABLE_COMPILER_PASSES = {
"autobucketing_reordering": autobucketing_reordering_pass,
"regional_inductor": regional_inductor_pass,
}
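For reference, a minimal sketch of how the registry above can be consumed. The name-to-callable resolution mirrors get_compiler_passes_from_config; the chained application in apply_compiler_passes is an assumption about what make_compiler_with_passes does, not its actual implementation, and both helper names here are hypothetical.

```python
from typing import Callable, List

import torch


def resolve_compiler_passes(pass_names: List[str]) -> List[Callable]:
    # Mirrors get_compiler_passes_from_config: unknown names fail loudly.
    unknown = [name for name in pass_names if name not in AVAILABLE_COMPILER_PASSES]
    if unknown:
        raise ValueError(
            f"Unknown compiler passes: {unknown}. "
            f"Available compiler passes: {list(AVAILABLE_COMPILER_PASSES.keys())}"
        )
    return [AVAILABLE_COMPILER_PASSES[name] for name in pass_names]


def apply_compiler_passes(
    gm: torch.fx.GraphModule, example_inputs, passes: List[Callable]
) -> torch.fx.GraphModule:
    # Assumed chaining: each pass takes (gm, example_inputs) and returns a GraphModule,
    # matching the signatures of autobucketing_reordering_pass and regional_inductor_pass.
    for compiler_pass in passes:
        gm = compiler_pass(gm, example_inputs)
    return gm


# Example: passes = resolve_compiler_passes(["autobucketing_reordering", "regional_inductor"])
```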