Skip to content

Conversation

anshul-si
Copy link
Contributor

@anshul-si anshul-si commented Sep 27, 2025

Summary: In order to test numerics for replicate + pp, stage.py needs to be able to call replicate's backward manually as pipeline parallelism doesn't have this feature.

Test Case

  1. pytest test/distributed/_composable/test_composability/test_pp_composability.py -k test_replicate_pp

Stack from ghstack (oldest at bottom):

cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta @msaroufim @dcci

…on with pipeline parallelism

[ghstack-poisoned]
Copy link

pytorch-bot bot commented Sep 27, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/164031

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit b33e7fc with merge base bec6541 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the oncall: distributed Add this issue/PR to distributed oncall triage queue label Sep 27, 2025
anshul-si added a commit that referenced this pull request Sep 27, 2025
…on with pipeline parallelism

ghstack-source-id: f808f7f
Pull Request resolved: #164031
@anshul-si anshul-si added the topic: not user facing topic category label Sep 27, 2025
…cate function with pipeline parallelism"

cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta msaroufim dcci

[ghstack-poisoned]
anshul-si added a commit that referenced this pull request Sep 27, 2025
…on with pipeline parallelism

ghstack-source-id: b1f52c7
Pull Request resolved: #164031
…cate function with pipeline parallelism"

**Summary:** In order to test numerics for replicate + pp, stage.py needs to be able to call replicate's backward manually as pipeline parallelism doesn't have this feature. 

**Test Case**
1.  pytest test/distributed/_composable/test_composability/test_pp_composability.py -k test_replicate_pp




cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta msaroufim dcci

[ghstack-poisoned]
anshul-si added a commit that referenced this pull request Sep 29, 2025
…on with pipeline parallelism

ghstack-source-id: 2432b6a
Pull Request resolved: #164031
…cate function with pipeline parallelism"

**Summary:** In order to test numerics for replicate + pp, stage.py needs to be able to call replicate's backward manually as pipeline parallelism doesn't have this feature. 

**Test Case**
1.  pytest test/distributed/_composable/test_composability/test_pp_composability.py -k test_replicate_pp




cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta msaroufim dcci

[ghstack-poisoned]
anshul-si added a commit that referenced this pull request Sep 29, 2025
…on with pipeline parallelism

ghstack-source-id: 8f3247d
Pull Request resolved: #164031
@anshul-si anshul-si added the ciflow/trunk Trigger trunk jobs on your pull request label Sep 29, 2025
…cate function with pipeline parallelism"

**Summary:** In order to test numerics for replicate + pp, stage.py needs to be able to call replicate's backward manually as pipeline parallelism doesn't have this feature. 

**Test Case**
1.  pytest test/distributed/_composable/test_composability/test_pp_composability.py -k test_replicate_pp




cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta msaroufim dcci

[ghstack-poisoned]
anshul-si added a commit that referenced this pull request Sep 29, 2025
…on with pipeline parallelism

ghstack-source-id: d3517eb
Pull Request resolved: #164031
@H-Huang H-Huang added the module: pipelining Pipeline Parallelism label Sep 29, 2025
result = perform_backward(backward_type)()

# If submod is a Replicate module
elif isinstance(self.submod, ReplicateModule):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I compared ReplicateModule and FSDPModule line by line, but seems they are almost identical. what's the core diffrerence?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The only difference is we need a way to distinguish modules that have been replicated vs fully-sharded so we know whether to get replicate state or fully shard state

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, agree here. The only difference i see is from fully_shard.state() and replicate.state(). To reduce code duplication, i think you could define a helper method like _handle_fsdp_module_backward that takes in a lamba function.

I'm also not sure why ReplicateModule and FSDPModule are different modules if it seems like they support the same things. We should look into that first

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

got you. is it possible to do this?

I don't mind unit test to be verbose. But I really hope the core code stay dry

elif isinstance(self.submod, FSDPModule):  # ReplicateModule is also a FSDPModule
      ...
      if isinstance(self.submod, ReplicateModule):
          state = replicate.state(self.submod)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh yes, if ReplicateModule is also an FSDP Module, then I like what @weifengpy proposed we should do that!

ref_pipeline_schedule.step(target=labels, losses=ref_losses)

for loss, ref_loss in zip(losses, ref_losses):
self.assertEqual(loss, ref_loss)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In a lot of our tests we compare single-device model grads with pipeline parallel model grads. This test just compares the loss of two pipelines. Might be out of scope to add this though, but maybe can add a TODO

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In this case, my only goals with this were to confirm that replicate + pipeline parallelism actually runs and that adding replicate to a pipeline doesn't affect it. I would prefer to make this a TODO that I can work on while running MAST tests in a few days

…cate function with pipeline parallelism"

**Summary:** In order to test numerics for replicate + pp, stage.py needs to be able to call replicate's backward manually as pipeline parallelism doesn't have this feature. 

**Test Case**
1.  pytest test/distributed/_composable/test_composability/test_pp_composability.py -k test_replicate_pp




cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta msaroufim dcci

[ghstack-poisoned]
anshul-si added a commit that referenced this pull request Sep 30, 2025
…on with pipeline parallelism

ghstack-source-id: d24f57f
Pull Request resolved: #164031
…cate function with pipeline parallelism"

**Summary:** In order to test numerics for replicate + pp, stage.py needs to be able to call replicate's backward manually as pipeline parallelism doesn't have this feature. 

**Test Case**
1.  pytest test/distributed/_composable/test_composability/test_pp_composability.py -k test_replicate_pp




cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta msaroufim dcci

[ghstack-poisoned]
anshul-si added a commit that referenced this pull request Sep 30, 2025
…on with pipeline parallelism

ghstack-source-id: 42a71a9
Pull Request resolved: #164031
Copy link
Contributor

@weifengpy weifengpy left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

looking good on the fsdp part

Copy link
Member

@H-Huang H-Huang left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good!

@anshul-si
Copy link
Contributor Author

@pytorchbot merge

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ciflow/trunk Trigger trunk jobs on your pull request Merged module: pipelining Pipeline Parallelism oncall: distributed Add this issue/PR to distributed oncall triage queue topic: not user facing topic category
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants