
Conversation

heidongxianhua
Contributor

Fixes #ISSUE_NUMBER
1. Add checkpoint support for custom devices.
2. Add a device argument. I wanted to add a device="cuda" parameter to the forward func of CheckpointFunction so that the device type can be specified when using it, but the apply func of torch.autograd.Function does not support kwargs, so I added a variable named _device.
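
Below is a minimal sketch (not the code from this PR) of the workaround described in point 2: since torch.autograd.Function.apply only forwards positional arguments, the device type has to travel as a positional argument and the public wrapper consumes the keyword.

import torch

class CheckpointFunction(torch.autograd.Function):
    # Sketch only: the real forward in torch/utils/checkpoint.py takes more
    # arguments and stashes RNG state; the point here is that the device type
    # is passed positionally, because Function.apply() does not accept kwargs.
    @staticmethod
    def forward(ctx, run_function, preserve_rng_state, device, *args):
        ctx.run_function = run_function
        ctx.preserve_rng_state = preserve_rng_state
        ctx.device = device  # e.g. "cuda", "xpu", or another backend name
        with torch.no_grad():
            outputs = run_function(*args)
        return outputs

def checkpoint(function, *args, device="cuda", preserve_rng_state=True):
    # Keyword arguments are consumed here and re-passed positionally to apply().
    return CheckpointFunction.apply(function, preserve_rng_state, device, *args)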

@pytorch-bot

pytorch-bot bot commented Apr 20, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/99626

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

❌ 1 New Failure

As of commit 3eaa287:

NEW FAILURE - The following job has failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@heidongxianhua
Contributor Author

heidongxianhua commented Apr 20, 2023

@albanD sorry to bother you, could you take a look? The failed check (No module named 'triton') seems unrelated to this change.

@heidongxianhua
Contributor Author

@pytorchbot rebase

@pytorchmergebot
Collaborator

@pytorchbot successfully started a rebase job. Check the current status here

@pytorchmergebot
Collaborator

Successfully rebased fix_checkpoint_main onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout fix_checkpoint_main && git pull --rebase)

Collaborator

albanD left a comment

I am curious what @soulitzer thinks but my personal feeling is that this should always work on all devices and not expect the user to specify the device?

albanD requested a review from soulitzer, April 21, 2023 18:24
@heidongxianhua
Contributor Author

I am curious what @soulitzer thinks but my personal feeling is that this should always work on all devices and not expect the user to specify the device?

Yes, it should always work on all devices. But the current implementation has several device-related funcs hard-coded to CUDA, such as with torch.cuda.device, so I added a func to specify the device up front so that it can also work on other devices.

@soulitzer
Contributor

@heidongxianhua I think what Alban means here is that it should work for all devices without the user having to explicitly pass in the device at all (for example maybe infer the device somehow from the inputs?)

@heidongxianhua
Contributor Author

heidongxianhua commented Apr 21, 2023

@heidongxianhua I think what Alban means here is that it should work for all devices without the user having to explicitly pass in the device at all (for example maybe infer the device somehow from the inputs?)

Yeah, thanks for your reply @soulitzer. The inputs here are not restricted to tensors, so there may be no tensor at all and we cannot infer the device from the inputs. That is why I added a func to specify the device.

@heidongxianhua
Contributor Author

@soulitzer I have reviewed the code; the inputs may contain no tensors or may be empty, so we cannot get a device type from the inputs. I have made the device arg a static argument of CheckpointFunction, and these changes will not affect the existing funcs. Could you have a look again?

Contributor

It appears that in this comment we already make this tradeoff, so I think it's okay to make the same tradeoff here.

We should just have a note on the checkpoint docs that device state is only preserved for devices of the Tensor args, the workaround if there are no Tensor args is just to explicitly pass in a dummy tensor on the correct device.
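
For reference, a hedged usage sketch of that workaround (the roll function below is made up for illustration): when the checkpointed function takes no meaningful Tensor arguments, an otherwise unused dummy tensor on the target device tells checkpoint which device state to preserve.

import torch
from torch.utils.checkpoint import checkpoint

def roll(batch_size, dummy):
    # Hypothetical checkpointed function: no real Tensor inputs, but it uses
    # device randomness, so device RNG state matters on recomputation.
    return torch.randn(batch_size, 16, device=dummy.device)

# The dummy tensor only carries the device information.
dummy = torch.empty(0, device="cuda", requires_grad=True)
out = checkpoint(roll, 8, dummy)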

Contributor Author

It appears that in this comment we already make this tradeoff, so I think it's okay to make the same tradeoff here.

We should just have a note on the checkpoint docs that device state is only preserved for devices of the Tensor args, the workaround if there are no Tensor args is just to explicitly pass in a dummy tensor on the correct device.

Yeah, thanks for the reminder; I hadn't noticed that note. I made some modifications to extract the device information from the input parameters. @soulitzer

mikaylagawarecki added the triaged label (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module), Apr 25, 2023
@soulitzer
Contributor

Thanks for the update. From talking with @albanD, I don't think we should need the constraint that there must only be a single non-CPU device. Let's just loop through all devices and get/set state for all of them?

@heidongxianhua
Contributor Author

heidongxianhua commented Apr 27, 2023

Thanks for the update. From talking with @albanD, I don't think we should need the constraint that there must only be a single non-CPU device. Let's just loop through all devices and get/set state for all of them?

Yeah, thanks for your comments @albanD @soulitzer. A few points: 1. I added a func to extract the device type info from the args; if there are no tensors, or no non-CPU tensors, in the args, it returns the default device. 2. I added a func to set/get the default device referenced above. These changes do not affect the existing code, and various other device types (MPS/XPU/HPU and so on) can also be supported. 3. It does not yet support mixed device types (such as CUDA tensors and XPU tensors simultaneously). Supporting that would require refactoring many funcs, such as CheckpointFunction.backward and _checkpoint_without_reentrant; it would be a big change, but we can do it later if you want.

@soulitzer
Contributor

@heidongxianhua I guess this is bc-breaking because previously one was allowed to pass in tensors on multiple devices. The device state wouldn't be saved, but that may not matter unless the checkpointed functions have randomness.

One way to not make it bc-breaking is just to remove the error (and therefore just take the first device). To make behavior identical to what it was before we'd have to check if cuda is one of the devices and prioritize saving the device state of cuda even if cuda is not the first device type.

Ideally we'd just support saving state of devices from multiple device types though, could you clarify what is difficult about that?

@heidongxianhua
Contributor Author

heidongxianhua commented Apr 27, 2023

@heidongxianhua I guess this is bc-breaking because previously one was allowed to pass in tensors on multiple devices. The device state wouldn't be saved, but that may not matter unless the checkpointed functions have randomness.

One way to not make it bc-breaking is just to remove the error (and therefore just take the first device). To make behavior identical to what it was before we'd have to check if cuda is one of the devices and prioritize saving the device state of cuda even if cuda is not the first device type.

Ideally we'd just support saving state of devices from multiple device types though, could you clarify what is difficult about that?

Yes, I got it. If there are multiple devices, we use the first device type, and if CUDA is among them, using CUDA is ok. That is a good solution, thank you. It now works for multiple devices, and I have updated the changes as you suggested; maybe you could review again. @soulitzer
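
A rough sketch of that selection logic, as an illustration of the behavior agreed on above rather than the exact code in this PR: gather the non-CPU device types from the tensor args, fall back to a default when none are found, and prefer CUDA when several types are present.

import warnings
import torch

_default_device_type = "cuda"  # fallback when args contain no non-CPU tensors

def infer_device_type(*args):
    device_types = [arg.device.type for arg in args
                    if isinstance(arg, torch.Tensor) and arg.device.type != "cpu"]
    if not device_types:
        return _default_device_type
    unique_types = list(dict.fromkeys(device_types))  # keep first-seen order
    if len(unique_types) > 1:
        warnings.warn("Tensor args (excluding CPU tensors) are on more than one "
                      "device type; device state will only be saved for one of them.")
        # Preserve the old behavior: prefer CUDA if it is among the devices.
        if "cuda" in unique_types:
            return "cuda"
    return unique_types[0]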

@soulitzer
Contributor

Thanks, one more thing we want is for it to warn when one of the devices is getting ignored - so just changing that error into a warning would be good.

Contributor

soulitzer left a comment

Looks good, beyond just adding back the check you had as a warning, just had some small comments.

def get_device_type():
    return DefaultDevice._default_device_type

def infer_device_type(*args):
Contributor

let's make this a private function (by prepending underscore to its name?)
or we should add its name to the __all__ list in this file

same applies to DefaultDevice class above

I guess I can imagine DefaultDevice being used elsewhere, but curious what the reasoning would be for infer_device_type.
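
For illustration, the two options mentioned above would look roughly like this; the exact contents of __all__ in torch/utils/checkpoint.py are an assumption here.

# Option 1: keep the helper private to the module
def _infer_device_type(*args):
    ...

# Option 2: export it as part of the module's public surface
__all__ = ["checkpoint", "checkpoint_sequential", "DefaultDevice", "infer_device_type"]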

device_module = _get_device_module(device)
device_autocast_kwargs = {"enabled": device_module.is_autocast_enabled(),
                          "dtype": device_module.get_autocast_dtype(),
                          "cache_enabled": torch.is_autocast_cache_enabled()}
Contributor

Just to make sure, this is intentional right? (is there no such thing as device_module.is_autocast_cache_enabled())

# Cuda was not initialized before running the forward, so we didn't
# stash the CUDA state.
if device_module._initialized and preserve_rng_state and not had_device_in_fwd:
# Deivce was not initialized before running the forward, so we didn't
Contributor

Deivce -> Device

device_types = list({arg.device.type for arg in args
                     if isinstance(arg, torch.Tensor) and not arg.device.type == "cpu"})
if len(device_types) > 1:
    warnings.warn("Tensor args except CPU tensor are on at least two devices ", device_types,
Contributor

soulitzer Apr 28, 2023

Maybe something along the lines of:

Tensor arguments, excluding CPU tensors, are detected on at least two 
types of devices. Device state will only be saved for devices of a single
device type, and the remaining devices will be ignored.  Consequently,
if any checkpointed functions involve randomness, this may result in
incorrect gradients. (Note that if CUDA devices are among the devices
detected, it will be prioritized; otherwise, the first device encountered will
be selected.)

Contributor

soulitzer left a comment

LGTM, just had a suggestion on how to word the warning. Also don't forget about lint.

@linux-foundation-easycla

linux-foundation-easycla bot commented Apr 28, 2023

CLA Signed

The committers listed above are authorized under a signed CLA.

@heidongxianhua
Contributor Author

@pytorchbot merge

pytorch-bot bot added the ciflow/trunk label (Trigger trunk jobs on your pull request), Apr 28, 2023
heidongxianhua force-pushed the fix_checkpoint_main branch from 0fe93fe to 7ea1acf, May 3, 2023 09:30
@heidongxianhua
Contributor Author

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here

@heidongxianhua
Contributor Author

@pytorchbot merge

@pytorchmergebot
Collaborator

The merge job was canceled. If you believe this is a mistake, then you can re-trigger it through pytorch-bot.

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here


@staticmethod
def _get_device_type():
    return _DefaultDevice._default_device_type
Contributor

soulitzer May 3, 2023

Maybe we could do something like this?

class DefaultDevice:
    """
    A class that manages the default device type for checkpointing. 
    If no non-CPU tensors are present, the default device type will be used.
    The default value is 'cuda'. The device type is used in the checkpointing 
    process when determining which device states to save and restore
    for recomputation.
    """

    _default_device_type = "cuda"

    @staticmethod
    def set_device_type(device: str = "cuda"):
        """
        Set the default device type for checkpointing.

        Args:
            device (str): The device type to be set as default. Default is 'cuda'.
        """
        DefaultDevice._default_device_type = device

    @staticmethod
    def get_device_type() -> str:
        """
        Get the current default device type for checkpointing.

        Returns:
            str: The current default device type.
        """
        return DefaultDevice._default_device_type
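
A brief usage sketch of the suggested class (assuming it ends up importable from torch.utils.checkpoint under this name; it is renamed DefaultDeviceType later in this thread):

from torch.utils.checkpoint import DefaultDevice  # assumed import path

# For a custom backend whose tensors may never appear in the checkpointed
# args, tell checkpointing which device type's state to preserve by default.
DefaultDevice.set_device_type("xpu")
assert DefaultDevice.get_device_type() == "xpu"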

@pytorchmergebot
Collaborator

Merge failed

Reason: 1 jobs have failed, first few of them are: windows-binary-libtorch-debug / libtorch-cpu-shared-with-deps-debug-test

Details for Dev Infra team (raised by workflow job)

@soulitzer
Contributor

soulitzer commented May 3, 2023

Thanks for the update @heidongxianhua, I just had one final comment on the docs

Would also be good to have a look at the doc build before land, to make sure things render properly.

@heidongxianhua
Contributor Author

Thanks for the update @heidongxianhua, I just had one final comment on the docs

Would also be good to have a look at the doc build before land, to make sure things render properly.

Thanks for your comment; I have added the detailed docstring @soulitzer. I thought this function was simple and clear, so I did not add detailed comments initially. But your suggested comment is very detailed and user friendly.

@soulitzer
Contributor

Thanks for adding that! Sorry for the back and forth on this one - I just had another random comment on naming - maybe it would be better if DefaultDevice were named DefaultDeviceType? I think it's a super subtle difference, but it might make things slightly clearer.

@heidongxianhua
Contributor Author

Thanks for adding that! Sorry for the back and forth on this one - I just had another random comment on naming - maybe it would be better if DefaultDevice were named DefaultDeviceType? I think it's a super subtle difference, but it might make things slightly clearer.

Yeah, it is better named DefaultDeviceType. I didn't pay attention to these details; thank you for your careful comments making the code better. @soulitzer

@heidongxianhua
Contributor Author

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here


Labels

ciflow/trunk (Trigger trunk jobs on your pull request), Merged, open source, topic: not user facing (topic category), triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
