Skip to content

Conversation

@yiming0416
Copy link
Contributor

@yiming0416 yiming0416 commented Nov 18, 2025

This PR integrates the changes in #1970 to compiler toolkit (applying joint_ac_pass on the joint graph graph to tag nodes based on reshard_after_forward flag)

Also did some refactor for applying graph passes in compiler toolkit experiments. We will have two kinds of passes

  1. joint_custom_passes: these are passes to be applied on the captured joint graph before partitioner. By default we validate_flex_attn_annotation_pass and fsdp_reshard_after_fwd_pass

  2. compiler_passes: there are passes to be applied on partitioned fwd and bwd graphs as backend optimizations. By default there is none. We can indicate autobucketing_reordering_pass and regional_inductor_pass using configs.

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Meta Open Source bot. label Nov 18, 2025
@yiming0416 yiming0416 force-pushed the yiming/add_reshard_after_forward_ac_pass branch from 0d70d9e to 88c79b6 Compare November 18, 2025 00:47
@yiming0416 yiming0416 marked this pull request as ready for review November 18, 2025 01:02
return gm


# Apply activation checkpointing on joint graph before partitioner
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We use the "joint_ac_pass" here to implement FSDP's reshard_after_forward. Maybe better to rename it to fsdp_reshard_after_fwd_pass?

)

joint_custom_passes = []
joint_custom_passes.append(validate_flex_attn_annotation_pass)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just out of curiosity: do we care about the order of passes applied in joint_custom_passes?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we probably care the order of passes when later we have more joint_custom_passes.

Right now the validate_flex_attn_annotation_pass is just validating if flex attention nodes are annotated and will error out if it fails. So I put it at first.

@yiming0416 yiming0416 force-pushed the yiming/add_reshard_after_forward_ac_pass branch from 88c79b6 to 528b319 Compare November 18, 2025 02:08
@yiming0416 yiming0416 merged commit 3819737 into main Nov 18, 2025
5 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Meta Open Source bot.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants