Saved tensor hooks checkpoint implementation cannot robustly clear storage #82482

@rohan-varma


🐛 Describe the bug

In the saved tensor hooks based checkpointing approach (backed by `storage: Dict[int, Optional[torch.Tensor]] = {}` in torch/utils/checkpoint.py), when autograd needs to unpack an activation, it potentially re-runs the forward to recompute all activations, and then returns the activation for the index it is unpacking.
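
For reference, here is a minimal sketch of that mechanism (illustrative only, not the actual `torch.utils.checkpoint` code; `checkpoint_sketch` is a made-up helper, and the real implementation additionally handles RNG state, kwargs, etc.):

```python
import torch
from typing import Dict, Optional

def checkpoint_sketch(function, *args):
    # Backing store for recomputed activations, keyed by pack order.
    storage: Dict[int, Optional[torch.Tensor]] = {}
    counter = 0

    def pack(x):
        # Drop the activation and remember only an index into `storage`.
        nonlocal counter
        counter += 1
        return counter - 1

    def unpack(idx):
        if len(storage) == 0:
            # Re-run the forward to recompute all activations, capturing every
            # tensor autograd saves into `storage` (in the same order as the
            # original forward).
            inner_counter = 0

            def inner_pack(inner):
                nonlocal inner_counter
                storage[inner_counter] = inner
                inner_counter += 1
                return None

            def inner_unpack(packed):
                raise RuntimeError("unexpected unpack during recomputation")

            with torch.enable_grad(), torch.autograd.graph.saved_tensors_hooks(
                inner_pack, inner_unpack
            ):
                function(*args)

        # pop() so no reference is held once backward is done; this is the call
        # discussed below.
        return storage.pop(idx)

    with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
        return function(*args)
```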

However, unpack currently returns the activation via `storage.pop()` to ensure we don't hold a reference to the tensor after the backward is over. This raises the issue that if the same tensor is unpacked twice, without a pack in between, we'll run into an error. A (silly) example repro is here:

`def test_checkpointing_without_reentrant_custom_function_raises(self):`
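
For illustration, here is a rough approximation of that kind of repro (not the exact test; the precise error type and message depend on the PyTorch version). A custom `autograd.Function` whose backward reads `ctx.saved_tensors` twice triggers two unpacks of the same saved tensors without a recomputation in between:

```python
import torch
from torch.utils.checkpoint import checkpoint

class DoubleAccess(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, y):
        ctx.save_for_backward(x, y)
        return x * y

    @staticmethod
    def backward(ctx, grad_out):
        x, y = ctx.saved_tensors   # first unpack: pops the stored entries
        x, y = ctx.saved_tensors   # second unpack, no pack/recompute in between
        return grad_out * y, grad_out * x

def fn(inp):
    a = inp.sin()                  # sin() also saves a tensor, keeping storage non-empty
    return DoubleAccess.apply(a, a)

inp = torch.randn(4, requires_grad=True)
out = checkpoint(fn, inp, use_reentrant=False)
out.sum().backward()               # errors on the second ctx.saved_tensors access
```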

Another concern is that a tensor could be packed by autograd but never unpacked, leading to `storage` leaking, although we are unsure whether this can occur in practice. A hypothetical scenario is sketched below.
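
One hypothetical scenario (purely illustrative, and not confirmed to matter in practice): if backward only traverses part of the graph produced inside the checkpointed region, the recomputation repopulates `storage` with every saved tensor, but the entries belonging to the untraversed branch are never popped:

```python
import torch
from torch.utils.checkpoint import checkpoint

def fn(x):
    a = x.sin()   # sin() saves its input for backward
    b = x.exp()   # exp() saves its output for backward
    return a, b

x = torch.randn(8, requires_grad=True)
a, b = checkpoint(fn, x, use_reentrant=False)

# Backward only through `a`: the first unpack triggers a recomputation that
# fills `storage` with the saved tensors of *both* branches, but the tensor
# saved by exp() is never unpacked, so under the storage-dict approach it
# stays referenced for as long as `b`'s graph is alive.
torch.autograd.grad(a.sum(), x)
```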

Versions

main

cc @ezyang @albanD @zou3519 @gqchen @pearu @nikitaved @soulitzer @lezcano @Varal7


Labels

module: activation checkpointing (Related to activation checkpointing)
module: autograd (Related to torch.autograd, and the autograd engine in general)
needs design (We want to add this feature but we need to figure out how first)
triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
