fix functionalization <> resnet18, make ProxyTensor work with tensor-less decomps #83207
…less decomps [ghstack-poisoned]
| Tangentially, I wonder what we can do to make this easier to debug. My debugging process was: (1) I realized that I can repro the error when using  | 
| add a test sometime thx | 
…ith tensor-less decomps"
This should fix a few of the errors I was seeing when I turned on functionalization in torchbench. It also fixes this AOTAutograd repro with resnet18:
```
import torch
from torchvision.models import resnet18
from functorch._src.compilers import nop
from functorch._src.aot_autograd import aot_module
from functorch.compile import config
config.use_functionalize = True
model = resnet18().cuda().half().to(memory_format=torch.channels_last)
input = torch.randn(256, 3, 224, 224, device='cuda', dtype=torch.float16) \
             .to(memory_format=torch.channels_last).detach().requires_grad_(True)
input_expected = input.clone().detach().requires_grad_(True)
fn = aot_module(model, nop)
out = fn(input)
out_expected = model(input_expected)
print(torch.allclose(out, out_expected))
out.sum().backward()
out_expected.sum().backward()
print(torch.allclose(input.grad, input_expected.grad))
```
The problem was that functorch adds a decomp to the decomp table for `new_zeros`:
```
@register_decomposition(aten.new_zeros, aot_autograd_decompositions)
def new_zeros(inp, size, dtype=None, layout=None, device=None, pin_memory=None):
    return torch.zeros(size, dtype=inp.dtype, device=inp.device)
```
When that decomp is called from inside of `ProxyTensorDispatchMode`, the mode is already disabled, and `torch.zeros` doesn't take any tensor-like arguments, so we never end up dispatching back into Python again.
As a result, the output of `new_zeros()` gets baked into the AOTAutograd FX graph as a constant instead of being recorded as an op.
[ghstack-poisoned]
    | Added a test | 
| @pytorchbot merge | 
| @pytorchbot successfully started a merge job. Check the current status here and land check progress here. | 
…less decomps (#83207)
Pull Request resolved: #83207
Approved by: https://github.com/ezyang
| Hey @bdhirsh. | 
…less decomps (#83207) (#83207)
Pull Request resolved: #83207
Approved by: https://github.com/ezyang
Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/ba90c9f2298433778cc6a7a2008d0299aa2911da
Reviewed By: atalman
Differential Revision: D38658491
Pulled By: bdhirsh
fbshipit-source-id: 63ef387ca5385acc8e1b0cb66bf4b7d8f3bc5636