
Replace non-reentrant checkpoint with a rewrite that can be nested and contain grad #90105

Closed
Wanted to merge 61 commits.

Conversation

soulitzer (Contributor) commented Dec 3, 2022

Stack from ghstack (oldest at bottom):

Changes:

  • bc-breaking change: The main difference from the old non-reentrant implementation it replaces is that we now clear recomputed tensors on backward immediately upon unpack, even if retain_graph=True. This has the following additional implications:
    • Accessing _saved_tensors multiple times will silently recompute the forward multiple times.
    • Accessing ctx.saved_tensors twice in the same backward will now raise an error.
  • To avoid dealing with the potential consequences, early stopping is hidden behind a global flag that defaults to False and can be enabled via a context manager; we can remove the flag in a follow-up. As a result, some nesting features do not work by default.
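For orientation, a minimal usage sketch of the kind of nesting this rewrite is meant to support; in released PyTorch the non-reentrant path is selected with `use_reentrant=False` (this snippet is illustrative, not taken from the PR):

```python
import torch
from torch.utils.checkpoint import checkpoint

def inner(x):
    return x.sin().exp()

def outer(x):
    # A checkpointed region nested inside another checkpointed region.
    return checkpoint(inner, x, use_reentrant=False).sin()

x = torch.ones(3, requires_grad=True)
checkpoint(outer, x, use_reentrant=False).sum().backward()

# Gradients match the un-checkpointed computation.
x_ref = torch.ones(3, requires_grad=True)
x_ref.sin().exp().sin().sum().backward()
assert torch.allclose(x.grad, x_ref.grad)
```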

Before land:

  • import to check for additional bc-breaking behavior
  • implement workarounds for the bc-breaking behavior, if we decide on any
  • update docs to reflect new lifetime of recomputed variables
  • update docs to mention the early stop feature

Follow ups:

  • enable early-stopping by default
  • update docs/tutorial to feature nested use cases

Related docs:

  • code comment: https://github.com/pytorch/pytorch/pull/90105/files#diff-9dcd955620b52ce128e18e3567be88edbb238810460d1288a86fabc20e483b30R448
  • design doc: https://docs.google.com/document/d/1UDLhTNv6_kvuDTRlsjfj9WdqtNaQNr8ahrvdBIB6914/edit#
  • retains_grad <> checkpoint: https://docs.google.com/document/d/1maiGmuFUxysQL0AdYUU88kngAaXh_L0XpDcLDh_5Ors/edit

cc @ezyang @gchanan

pytorch-bot bot commented Dec 3, 2022

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/90105

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 Failure

As of commit fd5b2d2:

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

soulitzer added a commit that referenced this pull request Dec 3, 2022
ghstack-source-id: 167d29c238ed0efee74fca8326598a5469e2095c
Pull Request resolved: #90105
@soulitzer soulitzer marked this pull request as draft December 3, 2022 03:08
Basically works in all situations I've tested: it correctly recomputes all the values when using nested checkpointing, even mixed with higher-order autograd. However, I don't know how much memory it actually saves; it could be leaking everything, which would defeat the purpose of checkpointing.

```python
import torch
from torch.autograd.graph import _checkpoint as checkpoint

x = torch.ones(1, requires_grad=True)

def fn(x):
    return x.sin().exp().sin()

# Wrap a function so that it runs under checkpointing.
def c(fn):
    def wrapped(*args, **kwargs):
        return checkpoint(fn, *args, **kwargs)
    return wrapped

# Wrap a function so that it returns the grad of its output w.r.t. x;
# create_graph=True keeps the result differentiable for higher-order grad.
def g(fn):
    def wrapper(x):
        with torch.enable_grad():
            out = fn(x)
            grad_input = torch.autograd.grad(out, inputs=(x,), create_graph=True)[0]
        return grad_input
    return wrapper

# Nested checkpointing mixed with higher-order autograd ...
out = g(c(g(c(fn))))(x)
print("out:", out)

# ... matches the same double grad without checkpointing.
out = g(g(fn))(x)
print("out:", out)
```
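For readers unfamiliar with the mechanism: the non-reentrant implementation builds on `torch.autograd.graph.saved_tensors_hooks`. The pack hook drops each saved tensor and records only an index, and the first unpack during backward reruns the forward to regenerate the tensors in the same order. A simplified, hypothetical sketch of that idea (the real code additionally handles nesting, RNG state, and early stopping; `toy_checkpoint` is my own name, not an actual API):

```python
import torch
from torch.autograd.graph import saved_tensors_hooks

def toy_checkpoint(fn, x):
    storage = {}  # idx -> recomputed tensor, filled lazily at backward time
    counter = 0

    def pack(tensor):
        # Drop the tensor; remember only its position in the save order.
        nonlocal counter
        idx = counter
        counter += 1
        return idx

    def unpack(idx):
        # First unpack triggers one recomputation of the forward; tensors
        # "saved" during the recompute are captured by index.
        if not storage:
            recomputed = []
            def repack(t):
                recomputed.append(t)
                return len(recomputed) - 1
            with torch.enable_grad(), saved_tensors_hooks(repack, lambda i: recomputed[i]):
                fn(x)
            storage.update(enumerate(recomputed))
        return storage[idx]

    with saved_tensors_hooks(pack, unpack):
        return fn(x)

x = torch.ones(3, requires_grad=True)
toy_checkpoint(lambda t: t.sin().exp(), x).sum().backward()

# Gradients match the un-checkpointed computation.
x_ref = torch.ones(3, requires_grad=True)
x_ref.sin().exp().sum().backward()
assert torch.allclose(x.grad, x_ref.grad)
```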

soulitzer added the "topic: not user facing" label Dec 3, 2022
soulitzer added a commit that referenced this pull request Dec 3, 2022
ghstack-source-id: d12f192deff0e1ce0e34d02a6b58ab7b2552e93e
Pull Request resolved: #90105
soulitzer added a commit that referenced this pull request Dec 3, 2022
ghstack-source-id: edc647882164d9b802fa45f6333456cbb37a66af
Pull Request resolved: #90105
soulitzer added a commit that referenced this pull request Dec 4, 2022
ghstack-source-id: 8eb516da5dad6fcbe573de5eb8de5becd68ce45d
Pull Request resolved: #90105
soulitzer added a commit that referenced this pull request Dec 12, 2022
ghstack-source-id: 81eb1fc53bd391424ab85f25ed209989ef2ea7e2
Pull Request resolved: #90105
soulitzer added a commit that referenced this pull request Dec 26, 2022
ghstack-source-id: aa0f8bf5ac7b210375a7e66fffc1e28c27840930
Pull Request resolved: #90105
soulitzer added a commit that referenced this pull request Dec 27, 2022
ghstack-source-id: 1cc1ae46772f7db4a111092a1f2c6c5b7a112e5a
Pull Request resolved: #90105
soulitzer added a commit that referenced this pull request Dec 27, 2022
ghstack-source-id: 5805313f59c92e32e791d10e3021cc3ed100a81d
Pull Request resolved: #90105
soulitzer added a commit that referenced this pull request Dec 27, 2022
ghstack-source-id: 8a1a99874f864eb1f51e567690982f7ba5270ad2
Pull Request resolved: #90105
soulitzer added a commit that referenced this pull request Dec 27, 2022
ghstack-source-id: 722f264d012206c76da8cc0f65723b3322fb9634
Pull Request resolved: #90105
soulitzer added a commit that referenced this pull request Mar 8, 2023
ghstack-source-id: fd516e630cc582240055c6046961cc21db0a79ba
Pull Request resolved: #90105
soulitzer added a commit that referenced this pull request Mar 10, 2023
ghstack-source-id: dcd9d19ea7f342a01074e07a8dc5b3c577676862
Pull Request resolved: #90105
```python
# generate a temporary id if we trigger unpack outside of a backward call
gid = int(uuid.uuid4())

if frame.is_recomputed[gid]:
```
Collaborator: Not that it is new in this impl, but we should add an issue to make this thread safe ;)

soulitzer (Contributor, author): I may be missing something, but doesn't the GIL protect this?

albanD (Collaborator), Mar 14, 2023: The GIL ensures that you won't corrupt data and crash badly. But at any point within this code, the GIL might be released and another thread can jump in. So if backward() is called from two different threads, this function might run concurrently in both, and we must make sure that no shared data structure is written into while we read it, so we never see stale data. For example: https://stackoverflow.com/questions/1312331/using-a-global-dictionary-with-threads-in-python
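The race described here can be sketched in plain Python: without a lock, two threads can both observe `is_recomputed[gid] == False` and both recompute; guarding the check-and-set with a `threading.Lock` makes the recompute single-shot. (The `Frame` class and counter below are hypothetical stand-ins, not the actual checkpoint internals.)

```python
import threading

class Frame:
    def __init__(self):
        self.is_recomputed = {}       # gid -> bool, shared across threads
        self.lock = threading.Lock()  # protects the check-and-set below
        self.recompute_count = 0

    def ensure_recomputed(self, gid):
        # Double-checked locking: fast path without the lock,
        # then re-check under the lock before recomputing.
        if self.is_recomputed.get(gid):
            return
        with self.lock:
            if not self.is_recomputed.get(gid):
                self.recompute_count += 1  # stand-in for the real recompute
                self.is_recomputed[gid] = True

frame = Frame()
threads = [threading.Thread(target=frame.ensure_recomputed, args=(0,)) for _ in range(8)]
for t in threads:
    t.start()
for t in threads:
    t.join()
assert frame.recompute_count == 1  # recomputed exactly once
```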

soulitzer (Contributor, author): Oh interesting, thanks for the clarification!

soulitzer added a commit that referenced this pull request Mar 13, 2023
ghstack-source-id: 663eee148bf7a8b357d867b359a76b3724b508b2
Pull Request resolved: #90105
@soulitzer soulitzer requested a review from albanD March 13, 2023 22:58

```python
if not frame.is_recomputed[gid]:
    ctx = frame.input_saver.grad_fn
    args = ctx.get_args(ctx.saved_tensors)
```
Collaborator: nit: you don't really need to pass in the saved_tensors here; you can easily access them from within the ctx.

soulitzer (Contributor, author): Hmm, this causes a leak for some reason; still investigating.

Collaborator: Can be a follow-up. Sounds ok.

soulitzer (Contributor, author), Mar 14, 2023:

It's not a true memory leak; it's a reference cycle that Python's garbage collector can detect:

ctx -> get_args -> closure -> ctx

The issue is that we need to call gc.collect() in the test:

```python
import torch
import weakref
import gc

def scope():

    class Test(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            def fn():
                a = ctx  # closure captures ctx
            ctx.fn = fn  # ctx -> fn -> closure -> ctx
            return x

        @staticmethod
        def backward(ctx, x):
            return x

    a = torch.tensor(1., requires_grad=True)
    out = Test.apply(a)
    # Already a weakref tbh
    ref = weakref.ref(out.grad_fn)

    return ref

ref = scope()
gc.collect()
print(ref())
```

However, I think we may prefer the way it is now, so that things get cleared sooner without needing a manual gc.collect().
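The same cycle reproduces without torch: an object whose attribute is a closure capturing the object itself forms a reference cycle, so a weakref to it survives until the cycle collector runs. (A minimal stand-alone illustration, not the checkpoint code.)

```python
import gc
import weakref

class Ctx:
    pass

def scope():
    ctx = Ctx()
    def fn():
        return ctx      # closure captures ctx -> reference cycle
    ctx.fn = fn         # ctx -> fn -> closure -> ctx
    return weakref.ref(ctx)

gc.disable()            # keep the automatic collector out of the way
ref = scope()
assert ref() is not None  # the cycle keeps ctx alive past scope()
gc.collect()              # the cycle collector reclaims it
gc.enable()
assert ref() is None
```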

soulitzer added a commit that referenced this pull request Mar 14, 2023
ghstack-source-id: be788002d4a3f15635a57b95c659928df65e5678
Pull Request resolved: #90105
albanD (Collaborator) left a comment: ok

soulitzer added a commit that referenced this pull request Mar 14, 2023
ghstack-source-id: 1907cf81e950f885566a301a3d250f1a8cd89165
Pull Request resolved: #90105
soulitzer (Contributor, author): @pytorchbot merge -f "Unrelated failures"

pytorchmergebot (Collaborator): Merge started. Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes).

cyyever pushed a commit to cyyever/pytorch_private that referenced this pull request Mar 23, 2023
…d contain grad (#90105)

Pull Request resolved: pytorch/pytorch#90105
Approved by: https://github.com/albanD
cyyever pushed a commit to cyyever/pytorch_private that referenced this pull request Mar 27, 2023
…d contain grad (#90105)

Pull Request resolved: pytorch/pytorch#90105
Approved by: https://github.com/albanD
@facebook-github-bot facebook-github-bot deleted the gh/soulitzer/154/head branch June 8, 2023 18:47
Labels
ciflow/trunk, Merged, module: bc-breaking, release notes: autograd, topic: bc breaking
Projects
None yet
Development
Successfully merging this pull request may close these issues: none yet.

3 participants