
Conversation

@bdhirsh (Contributor) commented Aug 10, 2022

This should fix a few of the errors I was seeing when I turned on functionalization in torchbench. It also fixes this AOTAutograd repro with resnet18:

```
import torch
from torchvision.models import resnet18

from functorch._src.compilers import nop
from functorch._src.aot_autograd import aot_module
from functorch.compile import config

config.use_functionalize = True

model = resnet18().cuda().half().to(memory_format=torch.channels_last)
input = torch.randn(256, 3, 224, 224, device='cuda', dtype=torch.float16) \
             .to(memory_format=torch.channels_last).detach().requires_grad_(True)
input_expected = input.clone().detach().requires_grad_(True)

fn = aot_module(model, nop)
out = fn(input)
out_expected = model(input_expected)
print(torch.allclose(out, out_expected))

out.sum().backward()
out_expected.sum().backward()
print(torch.allclose(input.grad, input_expected.grad))
```
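(Before this fix, it is the second `allclose` check, the gradient comparison, that fails, which is consistent with the backward-only behavior described in the debugging notes below; after the fix both checks should print True.)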

The problem was that functorch adds a decomp to the decomp table for `new_zeros`:
```
@register_decomposition(aten.new_zeros, aot_autograd_decompositions)
def new_zeros(inp, size, dtype=None, layout=None, device=None, pin_memory=None):
    return torch.zeros(size, dtype=inp.dtype, device=inp.device)
```

When calling that decomp from inside of `ProxyTensorDispatchMode`, the mode is already disabled, and `torch.zeros` doesn't take in any tensor-like arguments, so we never end up dispatching back into Python again.

The way that manifests is that the output of `new_zeros()` gets baked into the AOTAutograd FX graph as a constant.
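A minimal sketch of how to observe the symptom in isolation (assuming `make_fx` takes a `decomposition_table` argument and that `aot_autograd_decompositions` is importable from the same module as the decomp above):
```
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from functorch._src.aot_autograd import aot_autograd_decompositions

# A function whose only op decomposes into a tensor-less factory call.
def f(x):
    return x.new_zeros(x.shape)

# Under the bug, the zeros tensor shows up in the traced graph as a baked-in
# constant (a get_attr node) instead of an explicit call to a zeros op.
gm = make_fx(f, decomposition_table=aot_autograd_decompositions)(torch.randn(2, 2))
print(gm.graph)
```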

Stack from ghstack (oldest at bottom):

@facebook-github-bot (Contributor) commented Aug 10, 2022

✅ No Failures (0 Pending)

As of commit d04b1c4 (more details on the Dr. CI page):

💚 💚 Looks good so far! There are no failures yet. 💚 💚

This comment was automatically generated by Dr. CI.

@bdhirsh (Contributor, Author) commented Aug 10, 2022

Tangentially, I wonder what we can do to make this easier to debug. My debugging process was:

(1) Realized that I can repro the error when using channels_last tensor inputs, and that the error goes away when I remove them
(2) Printed out the FX graph from AOTAutograd with/without the channels_last change, to see what's different (a sketch of one way to do this is below)
(3) Noticed that there are a few extra as_strided_() calls in the problematic graph
(4) Noticed that the error goes away if I don't trace the backward, and only trace the forward. Hmm...
(5) Identified as_strided_backward() as the culprit. From there, gdb helped me realize that the output of new_zeros() wasn't a ProxyTensor for some reason (code link)
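For step (2), a sketch of the kind of helper that works: `aot_module` takes separate forward and backward compilers, so a compiler that prints the graph it receives and returns it unchanged (analogous to `nop`) dumps both traced graphs. The keyword argument names here are an assumption about the functorch API of the time:
```
import torch

def print_compiler(fx_module: torch.fx.GraphModule, example_inputs):
    print(fx_module.code)  # dump the traced graph as Python source
    return fx_module       # return the module unchanged, like `nop`

# Assumed usage (keyword names may differ across functorch versions):
# fn = aot_module(model, fw_compiler=print_compiler, bw_compiler=print_compiler)
```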

@bdhirsh bdhirsh requested review from Chillee and ezyang August 10, 2022 22:32
@ezyang (Contributor) commented Aug 10, 2022

add a test sometime thx

@bdhirsh (Contributor, Author) commented Aug 11, 2022

Added a test

@bdhirsh (Contributor, Author) commented Aug 11, 2022

@pytorchbot merge

@pytorchmergebot (Collaborator) commented

@pytorchbot successfully started a merge job. Check the current status here and land check progress here.
The merge job was triggered with the land checks (-l) flag. If you did not specify this flag yourself, you are likely enrolled in the land checks rollout. This means that your change will be merged once all checks on your PR and the land checks have passed (ETA 4 Hours). If you need to coordinate lands between different changes and cannot risk a land race, please add the ciflow/trunk label to your PR and wait for signal to complete, and then land your changes in proper order. Having trunk, pull, and Lint pre-run on a PR will bypass land checks and the ETA should be immediate. If this is not the intended behavior, feel free to use some of the other merge options in the wiki.
Please reach out to the PyTorch DevX Team with feedback or questions!

pytorchmergebot pushed a commit that referenced this pull request Aug 11, 2022
…less decomps (#83207)

Pull Request resolved: #83207
Approved by: https://github.com/ezyang
@github-actions (Contributor) commented

Hey @bdhirsh.
You've committed this PR, but it does not have both a 'release notes: ...' and 'topics: ...' label. Please add one of each to the PR. The 'release notes: ...' label should represent the part of PyTorch that this PR changes (fx, autograd, distributed, etc) and the 'topics: ...' label should represent the kind of PR it is (not user facing, new feature, bug fix, perf improvement, etc). The list of valid labels can be found here for the 'release notes: ...' and here for the 'topics: ...'.
For changes that are 'topic: not user facing' there is no need for a release notes label.

facebook-github-bot pushed a commit that referenced this pull request Aug 12, 2022
…less decomps (#83207) (#83207)

Pull Request resolved: #83207
Approved by: https://github.com/ezyang

Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/ba90c9f2298433778cc6a7a2008d0299aa2911da

Reviewed By: atalman

Differential Revision: D38658491

Pulled By: bdhirsh

fbshipit-source-id: 63ef387ca5385acc8e1b0cb66bf4b7d8f3bc5636
@facebook-github-bot deleted the gh/bdhirsh/295/head branch August 15, 2022 14:19