Replace non-reentrant checkpoint with a rewrite that can be nested and contain grad #90105
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/90105
Note: Links to docs will display an error until the docs builds have been completed.
❌ 1 failure as of commit fd5b2d2. This comment was automatically generated by Dr. CI and updates every 15 minutes.
ghstack-source-id: 167d29c238ed0efee74fca8326598a5469e2095c Pull Request resolved: #90105
Basically works in all situations I've tested: it correctly recomputes all the values when using nested checkpointing, even mixed with higher-order autograd. However, I don't know how much memory it actually saves; it could be leaking everything, which may defeat the purpose of checkpointing at all.

```python
import torch
from torch.autograd.graph import _checkpoint as checkpoint

x = torch.ones(1, requires_grad=True)

def fn(x):
    return x.sin().exp().sin()

def c(fn):
    def wrapped(*args, **kwargs):
        return checkpoint(fn, *args, **kwargs)
    return wrapped

def g(fn):
    def wrapper(x):
        with torch.enable_grad():
            out = fn(x)
            grad_input = torch.autograd.grad(out, inputs=(x,), create_graph=True)[0]
        return grad_input
    return wrapper

def sum(fn):
    def wrapped(x):
        return fn(x).sum()
    return wrapped

grad = g

out = g(c(g(c(fn))))(x)
print("out:", out)

out = g(g(fn))(x)
print("out:", out)
```
…e nested and contain grad"

Changes:
- bc-breaking change: The main difference between this and the old non-reentrant impl that it replaces is that we clear recomputed tensors on backward immediately upon unpack, even if `retain_graph=True`. This has the following additional implications:
  - Accessing `_saved_tensors` multiple times will silently recompute forward multiple times.
  - Accessing `ctx.saved_tensors` twice in the same backward will now raise an error.
- To avoid dealing with the potential consequences, early stopping has been hidden behind a global flag that is False by default and can be enabled via a context manager. We can remove this in a follow-up. As a result, some features of nesting do not work by default.

Before land:
- import to check for more bc-breakingness
- implement any workarounds for the bc-breaking-ness, if we decide on any
- update docs to reflect the new lifetime of recomputed variables
- update docs to mention the early-stop feature

Follow-ups:
- enable early stopping by default
- update docs/tutorial to feature nested use cases

Related docs:
- code comment: https://github.com/pytorch/pytorch/pull/90105/files#diff-9dcd955620b52ce128e18e3567be88edbb238810460d1288a86fabc20e483b30R448
- design doc: https://docs.google.com/document/d/1UDLhTNv6_kvuDTRlsjfj9WdqtNaQNr8ahrvdBIB6914/edit#
- retains_grad <> checkpoint: https://docs.google.com/document/d/1maiGmuFUxysQL0AdYUU88kngAaXh_L0XpDcLDh_5Ors/edit

cc @ezyang @gchanan
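The new saved-tensor lifetime described above can be sketched in plain Python. The names (`Frame`, `unpack`) are illustrative only, not PyTorch's actual internals: outputs are recomputed on demand from the stored inputs and discarded immediately after each unpack, so a second access silently recomputes.

```python
# Hypothetical minimal model of the "clear on unpack, recompute on re-access"
# lifetime described above. Not PyTorch code.

class Frame:
    """Stores only the inputs; recomputes outputs on demand and discards them."""
    def __init__(self, fn, inputs):
        self.fn = fn
        self.inputs = inputs
        self.recomputed = None
        self.recompute_count = 0

    def unpack(self):
        # Recompute lazily if the stash is empty...
        if self.recomputed is None:
            self.recompute_count += 1
            self.recomputed = self.fn(*self.inputs)
        out = self.recomputed
        # ...and clear immediately, even if the graph is retained.
        self.recomputed = None
        return out

frame = Frame(lambda x: x * 2, (21,))
print(frame.unpack(), frame.unpack())  # prints "42 42"; each access recomputes
print(frame.recompute_count)           # prints 2: two accesses, two recomputes
```

This is the mechanism behind the bc-breaking note: with tensors cleared on unpack, every extra access to saved tensors costs another forward recomputation.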
torch/utils/checkpoint.py (outdated)

```python
# generate a temporary id if we trigger unpack outside of a backward call
gid = int(uuid.uuid4())

if frame.is_recomputed[gid]:
```
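As a side note on the temporary-id fallback above: `int(uuid.uuid4())` yields a 128-bit random integer, so collisions with real graph-task ids or with other temporary ids are effectively impossible. A small sketch:

```python
# Sketch of the temporary-id fallback: when no backward call supplies a
# graph-task id, a random UUID serves as a collision-free dict key.
import uuid

gids = {int(uuid.uuid4()) for _ in range(1000)}
print(len(gids))  # 1000: no collisions in practice
```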
Not that it is new in this impl, but we should add an issue to make this thread safe ;)
I may be missing something, but doesn't the GIL protect this?
The GIL ensures that you won't corrupt data and badly crash. But at any point within this code, the GIL might be released and another thread can jump in.
So if there are two calls to backward() in two different threads, we might be running this function twice at the same time from different threads. So we must make sure that no shared data structure is written into while we read it, so that we don't see stale data.
For example https://stackoverflow.com/questions/1312331/using-a-global-dictionary-with-threads-in-python
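A minimal sketch of the kind of locking this would need (illustrative only, not PyTorch code): guard every read-modify-write of the shared dict with a `threading.Lock` so concurrent `backward()` calls never interleave mid-update.

```python
# Illustrative sketch: protecting a hypothetical shared is_recomputed dict
# against concurrent backward() calls running in different threads.
import threading

_lock = threading.Lock()
is_recomputed = {}  # hypothetical shared frame state, keyed by graph-task id

def mark_recomputed(gid):
    with _lock:
        # The read-modify-write below is atomic only while the lock is held;
        # without it, two threads could both observe "not recomputed".
        already = is_recomputed.get(gid, False)
        if not already:
            is_recomputed[gid] = True
        return already

threads = [threading.Thread(target=mark_recomputed, args=(0,)) for _ in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(is_recomputed)  # prints {0: True}
```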
Oh interesting, thanks for the clarification!
```python
if not frame.is_recomputed[gid]:
    ctx = frame.input_saver.grad_fn
    args = ctx.get_args(ctx.saved_tensors)
```
nit: you don't really need to pass in the saved_tensors here, you can easily access them from within the ctx.
Hmmm this causes a leak for some reason, still investigating.
Can be a follow up. Sounds ok.
It's not a true memory leak; it's a cycle that Python can detect: ctx -> get_args -> closure -> ctx. The issue is that we need to call gc.collect() in the test:
```python
import torch
import weakref
import gc

def scope():
    class Test(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            def fn():
                a = ctx
            ctx.fn = fn
            return x

        @staticmethod
        def backward(ctx, x):
            return x

    a = torch.tensor(1., requires_grad=True)
    out = Test.apply(a)
    # Already a weakref tbh
    ref = weakref.ref(out.grad_fn)
    return ref

ref = scope()
gc.collect()
print(ref())
```
However, I think we may prefer the way it is now, just so that things get cleared faster without needing a manual gc.collect().
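The same cycle can be reproduced without torch. This pure-Python analogue (hypothetical names) shows that refcounting alone never frees the object, while the cycle collector does:

```python
# Pure-Python analogue of the ctx -> closure -> ctx cycle discussed above.
import gc
import weakref

class Ctx:
    pass

def scope():
    ctx = Ctx()
    def fn():
        return ctx          # the closure captures ctx ...
    ctx.fn = fn             # ... and ctx holds the closure: a reference cycle
    return weakref.ref(ctx)

ref = scope()
print(ref() is None)   # prints False: refcounting alone cannot free the cycle
gc.collect()
print(ref() is None)   # prints True: the cycle collector reclaims it
```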
ok
@pytorchbot merge -f "Unrelated failures"
Merge started: your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
…d contain grad (#90105)

Pull Request resolved: pytorch/pytorch#90105
Approved by: https://github.com/albanD
cc @ezyang @gchanan