Fix selective activation checkpointing with subclasses that override sizes() (#113380)

The problem is that we have a subclass (`FunctionalTensor`) that overrides size/stride calls, causing them to go through `__torch_dispatch__`.

But when SAC is active, `_CachingTorchDispatchMode.__torch_dispatch__` is also active, and dispatch modes run before subclass `__torch_dispatch__` implementations. The mode therefore intercepts those size/stride calls first and does something different with them, instead of letting `FunctionalTensor.__torch_dispatch__` handle them.

This PR updates the SAC torch dispatch mode to skip metadata calls and let its tensor arguments handle them directly.

Right now, `FunctionalTensor` has a hardcoded list of metadata ops; we should probably move that list somewhere more general.
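
For illustration, here is a minimal sketch of the early-return pattern this change introduces (not the actual `torch/utils/checkpoint.py` source; the class name `_SketchCachingMode` is a hypothetical stand-in for `_CachingTorchDispatchMode`, and the caching logic is elided). The mode checks an `_ignored_ops` set, now extended with `FunctionalTensor.metadata_fns`, and simply redispatches those ops so subclass arguments can handle them themselves:

```python
import torch
import torch._subclasses.functional_tensor
from torch.utils._python_dispatch import TorchDispatchMode

# Ops the mode should not handle: the original pair, plus FunctionalTensor's
# metadata ops, mirroring the one-line change in this diff.
_ignored_ops = {
    torch.ops.prim.device.default,
    torch.ops.aten.detach.default,
} | set(torch._subclasses.functional_tensor.FunctionalTensor.metadata_fns)

class _SketchCachingMode(TorchDispatchMode):  # hypothetical name
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func in _ignored_ops:
            # Metadata calls: don't cache or transform them. Inside a mode's
            # __torch_dispatch__ the mode is temporarily popped off the mode
            # stack, so calling func here lets FunctionalTensor's own
            # __torch_dispatch__ (or the plain kernel) handle the op.
            return func(*args, **kwargs)
        # ... SAC caching logic for compute ops would go here ...
        return func(*args, **kwargs)
```

With a check like this in place, a size/stride query issued while the caching mode is active falls through to the tensor's own handler instead of being treated as a cached compute op.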

I'll add better testing before landing this PR.

Pull Request resolved: #113380
Approved by: https://github.com/yf225, https://github.com/wanchaol
bdhirsh authored and pytorchmergebot committed Nov 10, 2023
1 parent cb48f78 commit 7064fbf
Showing 1 changed file with 1 addition and 1 deletion.
torch/utils/checkpoint.py (2 changes: 1 addition & 1 deletion)
@@ -1155,7 +1155,7 @@ def _detach(x):
_ignored_ops = {
    torch.ops.prim.device.default,
    torch.ops.aten.detach.default,
-}
+} | set(torch._subclasses.functional_tensor.FunctionalTensor.metadata_fns)


class _CachingTorchDispatchMode(TorchDispatchMode):
