[Train] Fix lightning checkpoint report callback #42751
Merged
Why are these changes needed?
In `RayTrainReportCallback`, we were using `torch.distributed.barrier()` to coordinate workers without specifying the device argument. If users do not set up the torch CUDA device themselves, the collective calls will all be bound to the default device (`cuda:0`). This can mess up the device map of the barrier call.
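To illustrate the hazard, here is a minimal sketch (not the actual Ray callback code) of the problematic pattern, assuming a `torchrun` launch with the NCCL backend; the `worker` function and variable names are illustrative:

```python
import os
import torch
import torch.distributed as dist

def worker():
    # Assumes launch via `torchrun --nproc-per-node=N`; NCCL backend for GPUs.
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])

    # Hazard: torch.cuda.set_device(local_rank) was never called, so
    # torch.cuda.current_device() is cuda:0 on every rank and the barrier's
    # dummy allreduce binds every rank's communicator to device 0.
    dist.barrier()

    # A later collective that uses each rank's own GPU now disagrees with the
    # device mapping created above and can hang.
    t = torch.ones(1, device=f"cuda:{local_rank}")
    dist.all_reduce(t)

if __name__ == "__main__":
    worker()
```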
Why does it hang?
pytorch/pytorch#53658
Internally, `torch.distributed.barrier` calls `allreduce` on a dummy tensor. The dummy tensor is created on the GPU specified by `barrier(device_ids=)`. NCCL will try to create a communicator for the current process on that device if one doesn't exist yet, which is a blocking operation.

When users don't specify `device_ids`, the barrier uses `torch.cuda.current_device`, which can be `cuda:0` on every rank if no `torch.cuda.set_device` was called beforehand. The first call works, but a subsequent call hangs if it binds processes to different device ids.

Therefore, to avoid deadlock, the rule of thumb is to always explicitly specify a different device id for each worker for all collective calls.
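Following that rule of thumb, a sketch of the safe pattern (again illustrative, not the Ray callback itself) is to bind each rank to its own GPU before any collective, or to pass the device explicitly to the barrier:

```python
import os
import torch
import torch.distributed as dist

def worker():
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])  # set by torchrun

    # Option 1: make this rank's GPU the current device, so the dummy tensor
    # inside barrier() (and every other collective) lands on the right card.
    torch.cuda.set_device(local_rank)
    dist.barrier()

    # Option 2: tell the barrier explicitly which device to use.
    dist.barrier(device_ids=[local_rank])

if __name__ == "__main__":
    worker()
```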
Solution
This PR switches from `torch.distributed.barrier()` to Lightning's `trainer.strategy.barrier()`, which explicitly specifies the CUDA device for each barrier call to ensure a correct device mapping across workers.
Related issue number
Closes #42927
Checks
- I've signed off every commit (by using the -s flag, i.e., `git commit -s`) in this PR.
- I've run `scripts/format.sh` to lint the changes in this PR.
- I've added any new APIs to the API Reference. For example, if I added a method in Tune, I've added it in `doc/source/tune/api/` under the corresponding `.rst` file.