[Train] Fix Deepspeed device ranks check in Lightning 2.0.5 #37387
Conversation
Signed-off-by: woshiyyya <xiaoyunxuan1998@gmail.com>
```diff
-# Device ranks have already been specified in RayEnvironment
-# Clear parallel_devices to skip deepspeed local rank checks
-self.parallel_devices = []
+self.parallel_devices = list(range(torch.cuda.device_count()))
```
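For context, a minimal sketch of where an assignment like this could live: a custom DeepSpeed strategy that reports every CUDA device on the node rather than only the worker's own device. The class name, the `setup_environment` override point, and the use of `torch.device` objects are assumptions for illustration, not necessarily the PR's actual code:

```python
import torch
from pytorch_lightning.strategies import DeepSpeedStrategy


class RayDeepSpeedStrategy(DeepSpeedStrategy):
    """Illustrative subclass; the real Ray integration may differ."""

    def setup_environment(self) -> None:
        # Report all CUDA devices visible on this node (not only the one
        # assigned to this worker) so that the indices form 0..N-1 and
        # Lightning's DeepSpeed local-rank check passes.
        self.parallel_devices = [
            torch.device("cuda", i) for i in range(torch.cuda.device_count())
        ]
        super().setup_environment()
```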
Would something like this make sense? I'm not sure if the device indices always start from 0.
```diff
-self.parallel_devices = list(range(torch.cuda.device_count()))
+devices = train.torch.get_device()
+if not isinstance(devices, list):
+    devices = [devices]
+self.parallel_devices = [d.index for d in devices if d.index is not None]
```
The `parallel_devices` here needs to be set to all the CUDA device ids for the current node, while `train.torch.get_device()` only returns the one for the current worker.

For example, if we have two workers, each with 1 GPU:
- worker_0: `CUDA_VISIBLE_DEVICES=2,3`, `train.torch.get_device() == torch.device("cuda:0")`
- worker_1: `CUDA_VISIBLE_DEVICES=2,3`, `train.torch.get_device() == torch.device("cuda:1")`

`parallel_devices` here should be `[torch.device("cuda:0"), torch.device("cuda:1")]`.

Lightning added a special check for DeepSpeed, which requires the indices of `parallel_devices` to equal `list(range(len(parallel_devices)))`.
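A small self-contained sketch of that check, assuming the condition is exactly that the device indices equal `list(range(len(parallel_devices)))` as described above (the helper name is mine, not Lightning's):

```python
import torch


def deepspeed_rank_check_ok(parallel_devices) -> bool:
    # Approximation of the Lightning 2.0.5 check described above:
    # the local device indices must be exactly 0..N-1.
    indices = [d.index if isinstance(d, torch.device) else d for d in parallel_devices]
    return indices == list(range(len(parallel_devices)))


# Per-worker device only (as in the suggestion above): on worker_1 the
# device is cuda:1, so indices == [1] != [0] and the check fails.
assert not deepspeed_rank_check_ok([torch.device("cuda", 1)])

# All devices on the node (what this PR sets): indices == [0, 1] == list(range(2)).
assert deepspeed_rank_check_ok([torch.device("cuda", i) for i in range(2)])
```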
Signed-off-by: woshiyyya <xiaoyunxuan1998@gmail.com>
…ect#37387) Signed-off-by: NripeshN <nn2012@hw.ac.uk>
…ect#37387) Signed-off-by: e428265 <arvind.chandramouli@lmco.com>
…ect#37387) Signed-off-by: Victor <vctr.y.m@example.com>
Why are these changes needed?
The LightningTrainer DeepSpeed CI test failed after the Lightning upgrade from 2.0.4 to 2.0.5, which introduced a check on device ranks in `DeepSpeedStrategy` (link). This PR addresses the incompatibility and fixes the test.
Related issue number
Fix #37374
Checks
- I've signed off every commit (by using the -s flag, i.e., `git commit -s`) in this PR.
- I've run `scripts/format.sh` to lint the changes in this PR.
- I've added any new APIs to the API Reference. For example, if I added a method in Tune, I've added it in `doc/source/tune/api/` under the corresponding `.rst` file.