add checkpoint support for custom device #99626
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/99626
Note: Links to docs will display an error until the docs builds have been completed. ❗ 1 Active SEV: there is 1 currently active SEV; if your PR is affected, please view it below. ❌ 1 New Failure as of commit 3eaa287. This comment was automatically generated by Dr. CI and updates every 15 minutes.
@albanD Sorry to bother you, could you take a look? And the failed check …
@pytorchbot rebase
@pytorchbot successfully started a rebase job. Check the current status here
Successfully rebased. Force-pushed from 2ed2b21 to de25b5d.
I am curious what @soulitzer thinks, but my personal feeling is that this should always work on all devices and not expect the user to specify the device?
Yes, it should always work on all devices. And in the implementation there are some funcs related to …
@heidongxianhua I think what Alban means here is that it should work for all devices without the user having to explicitly pass in the device at all (for example, maybe infer the device somehow from the inputs?)
Yeah, thanks for your reply @soulitzer. The inputs parameter here has no type restriction, so there may be no tensor and we cannot infer the device from the inputs. So I added a function to specify the device.
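For illustration, a naive inference helper (hypothetical, not the PR's code) shows why this can fail: with no non-CPU tensor arguments there is nothing to read the device from, which is what motivated an explicit device.

import torch

def naive_infer_device(*args):
    # Hypothetical helper: use the device type of the first non-CPU tensor
    # argument; there is nothing to infer from when no tensors are passed.
    for arg in args:
        if isinstance(arg, torch.Tensor) and arg.device.type != "cpu":
            return arg.device.type
    return None

print(naive_infer_device(torch.randn(2), 1.5))  # None: no non-CPU tensor to infer from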
Force-pushed from de25b5d to b564c7a.
@soulitzer I have reviewed the code; the inputs may contain no tensors or may be empty, so we cannot get a device type from the inputs. And I have made the arg …
torch/utils/checkpoint.py
Outdated
It appears that in this comment we already do this tradeoff, so I think it's okay to make the same tradeoff here.
We should just have a note in the checkpoint docs that device state is only preserved for devices of the Tensor args; the workaround if there are no Tensor args is just to explicitly pass in a dummy tensor on the correct device.
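A minimal sketch of that workaround (illustrative function and device; "cuda" is just an example, substitute your backend's device string): passing an otherwise-unused tensor on the target device lets checkpoint detect the right device even though the "real" arguments are not tensors.

import torch
from torch.utils.checkpoint import checkpoint

def fn(scale, dummy):
    # randomness inside the checkpointed region: its device's RNG state is
    # what checkpoint needs to stash and restore for recomputation
    return torch.rand(8, device=dummy.device) * scale

# None of the "real" arguments are tensors, so pass an otherwise-unused
# tensor on the target device purely so its device can be detected.
dummy = torch.empty(0, device="cuda", requires_grad=True)
out = checkpoint(fn, 2.0, dummy, use_reentrant=True)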
yeah, thanks for your reminder, I hadn't noticed the note. And I made some modifications to extract the device information from the input parameters. @soulitzer
Thanks for the update. From talking with @albanD, I don't think we should need this constraint that there must only be a single non-cpu device. Let's just loop through all devices and just get/set state for all of them?
Yeah, thanks for your comments @albanD @soulitzer. There are a few points: 1. we add a func to extract the device type info from …
@heidongxianhua I guess this is bc-breaking because previously one was allowed to pass in tensors on multiple devices. The device state wouldn't be saved, but that may not matter unless the checkpointed functions have randomness. One way to not make it bc-breaking is just to remove the error (and therefore just take the first device). To make behavior identical to what it was before we'd have to check if cuda is one of the devices and prioritize saving the device state of cuda even if cuda is not the first device type. Ideally we'd just support saving state of devices from multiple device types though, could you clarify what is difficult about that?
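To make the selection rule concrete, here is a sketch of the behavior being described (a hypothetical helper, not code from the PR):

def pick_device_type(device_types):
    # Prefer "cuda" when it is among the detected non-CPU device types, so the
    # old behavior is preserved; otherwise fall back to the first device type.
    if "cuda" in device_types:
        return "cuda"
    return device_types[0]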
Force-pushed from 08b49de to 22ba7dc.
Yes, I got it. If there are multiple devices, we use the first device type, and if cuda is among them, using cuda is fine. It is a good solution, thank you. Now it works for multiple devices. And I made the changes again as you said; maybe you could review again. @soulitzer
Thanks, one more thing we want to do is for it to warn when one of the devices is getting ignored - so just changing that error into a warning would be good.
Looks good, beyond just adding back the check you had as a warning, just had some small comments.
torch/utils/checkpoint.py
Outdated
def get_device_type():
    return DefaultDevice._default_device_type

def infer_device_type(*args):
Let's make this a private function (by prepending an underscore to its name?) or we should add its name to the __all__ list in this file. The same applies to the DefaultDevice class above.
I guess I can imagine DefaultDevice being used elsewhere, but curious what the reasoning would be for infer_device_type.
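For readers following along, here is a rough sketch of how infer_device_type and the default device type might interact (an assumption based on the discussion, not the PR's exact implementation; the DefaultDevice stub is included only to make the sketch self-contained):

import torch

class DefaultDevice:
    # minimal stand-in for the class discussed in this thread
    _default_device_type = "cuda"

    @staticmethod
    def get_device_type():
        return DefaultDevice._default_device_type

def infer_device_type(*args):
    # Collect non-CPU device types from the tensor args; if none are found,
    # fall back to the configured default. Priority among multiple device
    # types follows the rule discussed earlier (CUDA first).
    device_types = [arg.device.type for arg in args
                    if isinstance(arg, torch.Tensor) and arg.device.type != "cpu"]
    if not device_types:
        return DefaultDevice.get_device_type()
    return "cuda" if "cuda" in device_types else device_types[0]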
device_module = _get_device_module(device)
device_autocast_kwargs = {"enabled": device_module.is_autocast_enabled(),
                          "dtype": device_module.get_autocast_dtype(),
                          "cache_enabled": torch.is_autocast_cache_enabled()}
Just to make sure, this is intentional right? (Is there no such thing as device_module.is_autocast_cache_enabled()?)
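For intuition, a minimal sketch (with stand-in values, not the PR's backward code) of how such stashed kwargs are typically replayed around the recomputation:

import torch

# Stand-in values for the state captured in forward by the snippet quoted
# above; in the PR they come from the inferred device's module.
device = "cpu"
device_autocast_kwargs = {"enabled": False, "dtype": torch.bfloat16,
                          "cache_enabled": torch.is_autocast_cache_enabled()}

def run_function(x):
    return x * 2

x = torch.randn(4, requires_grad=True)

# During recomputation, the stashed kwargs are replayed so the rerun sees
# the same autocast configuration as the original forward.
with torch.enable_grad(), torch.autocast(device_type=device, **device_autocast_kwargs):
    out = run_function(x)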
torch/utils/checkpoint.py
Outdated
# Cuda was not initialized before running the forward, so we didn't
# stash the CUDA state.
if device_module._initialized and preserve_rng_state and not had_device_in_fwd:
# Deivce was not initialized before running the forward, so we didn't
Deivce -> Device
torch/utils/checkpoint.py
Outdated
device_types = list({arg.device.type for arg in args
                     if isinstance(arg, torch.Tensor) and not arg.device.type == "cpu"})
if len(device_types) > 1:
    warnings.warn("Tensor args except CPU tensor are on at least two devices ", device_types,
Maybe something along the lines of:
Tensor arguments, excluding CPU tensors, are detected on at least two
types of devices. Device state will only be saved for devices of a single
device type, and the remaining devices will be ignored. Consequently,
if any checkpointed functions involve randomness, this may result in
incorrect gradients. (Note that if CUDA devices are among the devices
detected, it will be prioritized; otherwise, the first device encountered will
be selected.)
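Folding that wording into the call (and fixing the call itself, since as written above device_types ends up being passed as warnings.warn's category argument) might look roughly like this; appending the detected device list is an assumption about the intent of the original code:

import warnings

device_types = ["cuda", "xpu"]  # illustrative values

if len(device_types) > 1:
    warnings.warn(
        "Tensor arguments, excluding CPU tensors, are detected on at least two "
        "types of devices. Device state will only be saved for devices of a single "
        "device type, and the remaining devices will be ignored. Consequently, if any "
        "checkpointed functions involve randomness, this may result in incorrect "
        "gradients. (Note that if CUDA devices are among the devices detected, it will "
        "be prioritized; otherwise, the first device encountered will be selected.) "
        f"Detected device types: {device_types}"
    )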
LGTM, just had a suggestion on how to word the warning. Also don't forget about lint.
Force-pushed from 2aabcc9 to f1e852a.
@pytorchbot merge
Force-pushed from 0fe93fe to 7ea1acf.
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
@pytorchbot merge
The merge job was canceled. If you believe this is a mistake, then you can re-trigger it through pytorch-bot.
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
torch/utils/checkpoint.py
Outdated
@staticmethod
def _get_device_type():
    return _DefaultDevice._default_device_type
Maybe we could do something like this?
class DefaultDevice:
    """
    A class that manages the default device type for checkpointing.

    If no non-CPU tensors are present, the default device type will be used.
    The default value is 'cuda'. The device type is used in the checkpointing
    process when determining which device states to save and restore
    for recomputation.
    """
    _default_device_type = "cuda"

    @staticmethod
    def set_device_type(device: str = "cuda"):
        """
        Set the default device type for checkpointing.

        Args:
            device (str): The device type to be set as default. Default is 'cuda'.
        """
        DefaultDevice._default_device_type = device

    @staticmethod
    def get_device_type() -> str:
        """
        Get the current default device type for checkpointing.

        Returns:
            str: The current default device type.
        """
        return DefaultDevice._default_device_type
Merge failed. Reason: 1 job has failed; the first few of them are: windows-binary-libtorch-debug / libtorch-cpu-shared-with-deps-debug-test. Details for Dev Infra team: raised by workflow job.
Thanks for the update @heidongxianhua, I just had one final comment on the docs. It would also be good to have a look at the doc build before landing, to make sure things render properly.
Thanks for your comment, I added the detailed comment @soulitzer. I thought this function was simple and clear, so I did not write detailed comments, but your suggested comment is very detailed and user friendly.
Thanks for adding that! Sorry for the back and forth on this one - I just had another random comment on naming - maybe it would be better if DefaultDevice were named DefaultDeviceType? I think it's a super subtle difference, but might make things slightly clearer.
Yeah, it is better to be named with …
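With that rename, configuring the fallback for a custom backend would look roughly like this (the device string is a placeholder for a real backend's device type name):

from torch.utils.checkpoint import DefaultDeviceType

# "foo" stands in for an actual custom backend's device type string; the
# fallback is only consulted when no non-CPU tensors appear among the args.
DefaultDeviceType.set_device_type("foo")
assert DefaultDeviceType.get_device_type() == "foo"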
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Fixes #ISSUE_NUMBER
1. Add checkpoint support for custom device.
2. Add a device argument: I want to add a device="cuda" parameter to the forward func of CheckpointFunction, and I can specify the device type when using it, but the apply func of torch.autograd.Function does not support kwargs, so I added a variable named _device.
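To illustrate the constraint in point 2, here is a minimal sketch (illustrative names, not the PR's final code) of why the device has to be threaded positionally: torch.autograd.Function.apply does not accept keyword arguments.

import torch

class CheckpointLike(torch.autograd.Function):
    @staticmethod
    def forward(ctx, run_function, preserve_rng_state, device, *args):
        # The device must arrive positionally; apply() rejects keyword
        # arguments, which is why a device="cuda" kwarg cannot be used here.
        ctx.device = device
        with torch.no_grad():
            return run_function(*args)

    @staticmethod
    def backward(ctx, *grads):
        # Real checkpointing would re-run run_function here under the saved
        # RNG/autocast state for ctx.device; this placeholder is not a
        # correct gradient, it only shows the argument layout.
        return (None, None, None) + grads

x = torch.randn(3, requires_grad=True)
out = CheckpointLike.apply(torch.sin, True, "cpu", x)  # positional arguments only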