
Conversation

tinglvv (Collaborator) commented Oct 30, 2023

fixes "Duplicate GPU detected : rank 1 and rank 0 both on CUDA device" on test_fsdp_fine_tune.py. Only run the test if GPU number > 1.
cc @eqy @ptrblck
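
For reference, one shape such a guard could take is sketched below. This is illustrative only and may not match the exact diff in this PR; it simply skips the whole test class up front when fewer than two GPUs are visible, before any per-rank process is spawned.

```python
import unittest

import torch
from torch.testing._internal.common_fsdp import FSDPTest


# Illustrative only: skip every test in the class on machines with
# fewer than 2 GPUs, before any per-rank process is spawned.
@unittest.skipIf(torch.cuda.device_count() < 2, "requires at least 2 GPUs")
class TestFSDPFineTune(FSDPTest):
    pass  # the actual fine-tune tests live in test_fsdp_fine_tune.py
```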

pytorch-bot bot commented Oct 30, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/112406

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 6456a48 with merge base 8858eda:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

eqy (Collaborator) commented Oct 30, 2023

It seems a bit scary to add a skip to the run function rather than to individual tests; would it be possible to fix this issue by adding the skip decorator to the individual failing tests?

awgu (Collaborator) commented Oct 30, 2023

Strangely, it looks like every test in that file already has @skip_if_lt_x_gpu(2).

tinglvv (Collaborator, Author) commented Oct 31, 2023

Thanks @eqy and @awgu for reviewing!

It seems a bit scary to add a skip to the run function rather than to individual tests; would it be possible to fix this issue by adding the skip decorator to the individual failing tests?

Currently the decorator is only added in test_fsdp_fine_tune.py, and since we go through super we are not modifying the parent class. As far as I can see this is the only way right now. From the error log

Last error: Duplicate GPU detected : rank 1 and rank 0 both on CUDA device 81000 dist init r=0, world=2 dist init r=1, world=2

and https://github.com/coyotelll/pytorch/blob/f0ae3b73369849ce6e530effc0e5c770472c67aa/torch/testing/_internal/common_fsdp.py#L904, it appears the _run method is called twice (once per rank), and the duplicate-GPU error is raised there during dist init. Since the main purpose of _run is to call run_test, and the tests it runs are already decorated with @skip_if_lt_x_gpu(2), it should be okay to skip the _run method itself: https://github.com/coyotelll/pytorch/blob/f0ae3b73369849ce6e530effc0e5c770472c67aa/torch/testing/_internal/common_fsdp.py#L931.

Strangely, it looks like every test in that file already has @skip_if_lt_x_gpu(2).

Yes. So it would be okay to skip the _run method in the 1-GPU case.
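
For readers following along, here is a rough sketch of what a skip_if_lt_x_gpu-style decorator does. This is not the upstream implementation, just an illustration of why it only kicks in inside the per-rank test body, i.e. after _run has already tried to initialize the process group.

```python
import sys
from functools import wraps

import torch


# Rough sketch only, not the code in torch.testing._internal.common_distributed.
def skip_if_lt_x_gpu_sketch(x):
    def decorator(fn):
        @wraps(fn)
        def wrapper(*args, **kwargs):
            if torch.cuda.is_available() and torch.cuda.device_count() >= x:
                return fn(*args, **kwargs)
            # In the real harness a special exit code marks the rank as
            # skipped; the key point is that this check only runs inside
            # the test body, after _run has already done dist init per rank.
            sys.exit(0)
        return wrapper
    return decorator
```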

awgu (Collaborator) commented Nov 1, 2023

@coyotelll Sorry, I am not sure I followed.

What is special about test_fsdp_fine_tune.py as opposed to the other FSDP unit tests? What setup does it take to reproduce the error?

albanD added the triaged label (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) on Nov 2, 2023
fegin (Contributor) commented Nov 2, 2023

You can change the world_size to min(torch.cuda.device_count(), 2). This will make each skip_if_lt_x_gpu take effect, though I'm also wondering the same thing as @awgu: in what setup do you get the error?
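
Concretely, the suggestion above would look roughly like the following override on the test class; this is an illustrative sketch, not the literal diff.

```python
import torch
from torch.testing._internal.common_fsdp import FSDPTest


class TestFSDPFineTune(FSDPTest):
    @property
    def world_size(self) -> int:
        # Never request more ranks than there are visible GPUs, so on a
        # 1-GPU machine the @skip_if_lt_x_gpu(2) decorators can skip
        # cleanly instead of two ranks colliding on device 0.
        return min(torch.cuda.device_count(), 2)
```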

tinglvv (Collaborator, Author) commented Nov 2, 2023

You can change the world_size to min(torch.cuda.device_count(), 2). This will make each skip_if_lt_x_gpu take effect, though I'm also wondering the same thing as @awgu: in what setup do you get the error?

Thanks for the comment. We get this error when running L0_self_test_distributed on, for example, an H100. To reproduce, you can run test/distributed/fsdp/test_fsdp_fine_tune.py::TestFSDPFineTune::test_parity_with_ddp.

Error is
Last error: Duplicate GPU detected : rank 1 and rank 0 both on CUDA device 101000

tinglvv requested a review from LucasLLC as a code owner on November 3, 2023 23:20
eqy (Collaborator) commented Nov 9, 2023

Rebasing to see if the Android failure is fixed (sorry if you need to update your local branch after this).

eqy (Collaborator) commented Nov 9, 2023

@pytorchmergebot rebase

pytorchmergebot (Collaborator) commented
@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

pytorchmergebot (Collaborator) commented
Successfully rebased tingl-fsdpfix onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout tingl-fsdpfix && git pull --rebase)

eqy (Collaborator) commented Nov 10, 2023

@awgu are your concerns addressed, or did you want to try out the fix suggested by @fegin?

awgu (Collaborator) commented Nov 10, 2023

@eqy Sorry if I missed something, but I am still not clear on:

What is special about test_fsdp_fine_tune.py as opposed to the other FSDP unit tests?

eqy (Collaborator) commented Nov 17, 2023

@coyotelll does this fix the issue as observed on your end (are we ready to merge?)

tinglvv (Collaborator, Author) commented Nov 17, 2023

@eqy Yes the update solves this issue. We can merge now. Thanks.

eqy (Collaborator) commented Nov 17, 2023

@awgu do you mind stamping this?

awgu (Collaborator) commented Nov 17, 2023

I am happy to stamp. I just want to understand the broader implications for our unit tests, since we typically return 2 or min(4, torch.cuda.device_count()) for the world size.

eqy (Collaborator) commented Nov 17, 2023

Could it be that upstream CI only ever runs FSDP tests on runners with enough GPUs (world size >= 2)? That could explain why the failure is not visible upstream. In other words, my understanding is that a hardcoded world size turns the decorators that gate tests on the number of GPUs into hardcoded no-ops. This fix in particular should also only affect TestFSDPFineTune, correct?
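
For illustration, the failure mode being described is roughly the following; the device-assignment line is a simplification of what the multi-process test harness does, not its actual code.

```python
import torch

world_size = 2                          # hardcoded by the test class
num_gpus = torch.cuda.device_count()    # 1 on the failing machines

# Simplified per-rank device assignment used by many distributed harnesses:
for rank in range(world_size):
    device = rank % max(num_gpus, 1)
    print(f"rank {rank} -> cuda:{device}")

# With num_gpus == 1, both ranks map to cuda:0, and process-group init
# fails with "Duplicate GPU detected : rank 1 and rank 0 both on CUDA
# device ..." instead of the tests being skipped.
```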

awgu (Collaborator) commented Nov 17, 2023

This fix in particular should also only affect TestFSDPFineTune, correct?

Yes, I think so. That is why I still have some uncertainty -- it seems like either we need to make a change to all unit tests that override world_size, or something else is up.

If I run

CUDA_VISIBLE_DEVICES=0 python -m pytest test/distributed/fsdp/test_fsdp_fine_tune.py

I can reproduce the error on my machine. I can also reproduce it for test_fsdp_grad_acc.py, and I imagine the same holds for the other tests that override world_size. I am not sure whether this is somehow a regression, though; I recall that the unit tests would previously just be skipped when there were an insufficient number of GPUs.

Either way, let us stamp this one to unblock, and we can investigate further.

eqy (Collaborator) commented Nov 17, 2023

@pytorchmergebot merge

pytorch-bot bot added the ciflow/trunk label (Trigger trunk jobs on your pull request) on Nov 17, 2023
pytorchmergebot (Collaborator) commented
Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

tinglvv (Collaborator, Author) commented Nov 17, 2023

@awgu, thanks for approving! More data on this: we see this error in several places in our CI where world_size is defined to be >= 2. Another example is https://github.com/pytorch/pytorch/blob/main/test/distributed/fsdp/test_fsdp_state_dict.py#L1239. Fixing it fully would probably require changes in several places, though I am not sure whether this counts as a regression.
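
If changes really are needed in several files, one possible shape is a small shared helper like the sketch below; the helper name is hypothetical, not an existing torch.testing utility.

```python
import torch


# Hypothetical shared helper (not an existing utility): clamp a requested
# world size to the number of visible GPUs so every affected suite can
# reuse the same guard instead of hardcoding 2.
def clamped_world_size(requested: int) -> int:
    return max(1, min(requested, torch.cuda.device_count()))
```

Each affected test class could then return clamped_world_size(2) (or 4) from its world_size property instead of hardcoding the value.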


Labels

ciflow/trunk (Trigger trunk jobs on your pull request), Merged, open source, topic: not user facing (topic category), triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
