Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Update on "Update save_on_cpu and checkpointing to work with functorc…
…h wrapped tensors" Design doc: https://docs.google.com/document/d/1OX5__xKsZP-natgEnsRrK4gfD0wSN9WS6j4kekfn-IA/edit This approach saves the inner-most tensor. As we unwrap, we also append to a list of rewrap functions that capture the necessary information to restore the metadata on the original wrappers. This PR tries to do most things in Python, but there are probably some APIs that could exist (or maybe already exist) that could simplify this PR. - This PR does very weird things to stash autograd metadata: - The rewrap function needs to capture autograd metadata so that the higher order graphs don't get disconnected, we reuse TensorWrapper to do this, but in a way that is careful not to save the original TensorWrapper's data - During packing, we do a clone on the original TensorWrapper, then replace the value_ with an empty tensor, so this new dataless TensorWrapper gets captured instead by rewrap fn - During unpacking, when we run the rewrap fn, we just replace the value_ with the value we desire (this could either be the recomputed value or value that was previously offloaded) - The API exposed to replace value_ is set_data! - There doesn't seem to be a reliable way to uniquely identify a tensor since id() gets reused, using data_ptr helps but it is also not enough sometimes. In this PR, I'm also using the first element of the Tensor to get a test to pass. Unanswered questions: - Why did we need to enable grad mode while packing (where was it disabled) Other prototypes: - #89159 (alternate approach that saves the outer-most tensor instead and unwraps the necessary number of layers during unpack - the issue is that we cannot tell when we are saving the outer-most tensor) - #88976 (same approach as this PR, but in cpp, unfinished) TODO: - verify in tests that we are actually saving the correct amount of tensors - try a non-zero bdim - make that assert more precise [ghstack-poisoned]
- Loading branch information