Fix selective activation checkpointing with subclasses that override sizes() (#113380)

The problem is that we have a subclass (`FunctionalTensor`) that overrides size/stride calls, causing them to go through `__torch_dispatch__`. But when SAC is active, `_CachingTorchDispatchMode.__torch_dispatch__` is also active, and it intercepts those size/stride calls first and does something different with them, instead of letting `FunctionalTensor.__torch_dispatch__` handle them.

This PR updates the SAC torch dispatch mode to know not to handle metadata calls, and to let its tensor arguments handle them directly. Right now, `FunctionalTensor` has a hardcoded list of metadata ops, but we should probably put them somewhere more general. I'll add better testing before landing this PR.

Pull Request resolved: #113380
Approved by: https://github.com/yf225, https://github.com/wanchaol
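The interception conflict described above can be sketched with a minimal, torch-free mock. Everything here (`METADATA_OPS`, `FakeFunctionalTensor`, `CachingDispatchMode`) is hypothetical stand-in code, not the actual PyTorch internals; it only illustrates the pattern of a dispatch mode declining metadata ops so the tensor subclass can answer them itself:

```python
# Hypothetical mock of the dispatch layering: a caching mode sits "above"
# a subclass, sees every call first, and must step aside for metadata ops.

# Stand-in for the hardcoded metadata-op list mentioned in the commit body.
METADATA_OPS = {"size", "stride"}

class FakeFunctionalTensor:
    """Stand-in for a subclass that answers its own metadata queries."""
    def __init__(self, sizes):
        self._sizes = sizes

    def handle(self, op):
        if op in METADATA_OPS:
            # The subclass knows its own sizes/strides.
            return self._sizes
        return ("subclass_compute", op)

class CachingDispatchMode:
    """Stand-in for _CachingTorchDispatchMode: it records compute ops
    (e.g. for later recomputation) but must not intercept metadata calls."""
    def __init__(self):
        self.cache = []

    def dispatch(self, op, tensor):
        if op in METADATA_OPS:
            # The fix: decline metadata ops and let the tensor argument
            # handle them directly, without caching or other side effects.
            return tensor.handle(op)
        result = tensor.handle(op)
        self.cache.append((op, result))
        return result

t = FakeFunctionalTensor((2, 3))
mode = CachingDispatchMode()
assert mode.dispatch("size", t) == (2, 3)  # answered by the subclass, not cached
mode.dispatch("add", t)
assert mode.cache == [("add", ("subclass_compute", "add"))]
```

Without the early `METADATA_OPS` check in `dispatch`, the mode would cache size/stride queries like compute ops, which is the behavior the real PR removes.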