
functorch.compile.aot_function fails on reshape in TorchRefsMode #96055

Closed
gilfree opened this issue Mar 5, 2023 · 16 comments
Labels
module: primTorch triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

gilfree commented Mar 5, 2023

🐛 Describe the bug

When using TorchRefsMode, aot_function cannot handle reshape.

I am not sure that my usage is correct, but my goal is to decompose a graph to prim ops in order to later export it outside PyTorch.

I tried the code below, and an internal assert is triggered, with a request to open a bug:

RuntimeError: !schema.hasAnyAliasInfo() INTERNAL ASSERT FAILED at "../aten/src/ATen/FunctionalizeFallbackKernel.cpp":33, please report a bug to PyTorch. mutating and aliasing ops should all have codegen'd kernels

When not using fake tensor mode, I get an error from an is_quantized check somewhere else.

If there is a "correct" way to get a prim graph, I would very much like to know what it is.

Code:

from typing import List
import torch
from torch import Tensor
from functorch.compile import aot_function
from torch._subclasses.fake_tensor import FakeTensorMode
from torch._prims.context import TorchRefsMode

def func(scale: Tensor, shape):
    return torch.reshape(scale, shape)


def custom_backend(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
    print("custom backend called with FX graph:")
    gm.graph.print_tabular()
    return gm.forward


scale = torch.ones(3)

with FakeTensorMode(allow_non_fake_inputs=True):
    scale = scale.clone()

with TorchRefsMode(strict=True):
    func = aot_function(func, fw_compiler=custom_backend)
    print(func(scale, shape=(-1,)))

Stack trace below:

Traceback (most recent call last):
  File "bug.py", line 39, in <module>
    print(func(scale,shape=(-1,)))
  File ".../torch/_functorch/aot_autograd.py", line 2643, in returned_function
    compiled_fn = create_aot_dispatcher_function(
  File ".../torch/_dynamo/utils.py", line 163, in time_wrapper
    r = func(*args, **kwargs)
  File ".../torch/_functorch/aot_autograd.py", line 2491, in create_aot_dispatcher_function
    compiled_fn = compiler_fn(flat_fn, fake_flat_args, aot_config)
  File ".../torch/_functorch/aot_autograd.py", line 1802, in aot_wrapper_dedupe
    compiled_fn = compiler_fn(wrapped_flat_fn, deduped_flat_args, aot_config)
  File ".../torch/_functorch/aot_autograd.py", line 1278, in aot_dispatch_base
    _fw_metadata, _out = run_functionalized_fw_and_collect_metadata(
  File ".../torch/_functorch/aot_autograd.py", line 606, in inner
    flat_f_outs = f(*flat_f_args)
  File ".../torch/_functorch/aot_autograd.py", line 1800, in wrapped_flat_fn
    return flat_fn(*add_dupe_args(args))
  File ".../torch/_functorch/aot_autograd.py", line 2623, in flat_fn
    tree_out = fn(*args, **kwargs)
  File "/homes/giladf/dlo/dlo2/docs/development/design/bug.py", line 23, in func
    return torch.reshape(scale,shape)
  File ".../torch/_prims/context.py", line 191, in __torch_function__
    return func(*args, **kwargs)
  File ".../torch/_refs/__init__.py", line 3260, in reshape
    return _reshape_view_helper(a, *shape, allow_copy=True)
  File ".../torch/_refs/__init__.py", line 3156, in _reshape_view_helper
    return prims.view_of(a)
  File ".../torch/_ops.py", line 284, in __call__
    return self._op(*args, **kwargs or {})
  File ".../torch/_prims/context.py", line 172, in __torch_function__
    return orig_func(*args, **kwargs)
  File ".../torch/_ops.py", line 284, in __call__
    return self._op(*args, **kwargs or {})
  File ".../torch/_prims/__init__.py", line 286, in _autograd_impl
    return backwards_not_supported(_prim)(*args, **kwargs)
  File ".../torch/_prims_common/wrappers.py", line 320, in _autograd_impl
    return redispatch_prim(args, kwargs)
  File ".../torch/_prims_common/wrappers.py", line 290, in redispatch_prim
    return prim(*args, **kwargs)
  File ".../torch/_ops.py", line 284, in __call__
    return self._op(*args, **kwargs or {})
RuntimeError: !schema.hasAnyAliasInfo() INTERNAL ASSERT FAILED at "../aten/src/ATen/FunctionalizeFallbackKernel.cpp":33, please report a bug to PyTorch. mutating and aliasing ops should all have codegen'd kernels

Versions

Collecting environment information...
PyTorch version: 2.0.0+cu117
Is debug build: False
CUDA used to build PyTorch: 11.7
ROCM used to build PyTorch: N/A

OS: Ubuntu 18.04.6 LTS (x86_64)
GCC version: (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
Clang version: 9.0.0 (tags/RELEASE_900/final)
CMake version: version 3.25.0
Libc version: glibc-2.27

Python version: 3.10.6 (main, Aug 30 2022, 16:00:07) [GCC 7.5.0] (64-bit runtime)
Python platform: Linux-5.15.87-051587-generic-x86_64-with-glibc2.27
Is CUDA available: True
CUDA runtime version: 11.7.64
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA A100-PCIE-40GB
GPU 1: NVIDIA A100-PCIE-40GB

Nvidia driver version: 525.78.01
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 43 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 64
On-line CPU(s) list: 0-63
Vendor ID: AuthenticAMD
Model name: AMD EPYC 7502 32-Core Processor
CPU family: 23
Model: 49
Thread(s) per core: 2
Core(s) per socket: 32
Socket(s): 1
Stepping: 0
BogoMIPS: 5000.28
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 cqm rdt_a rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local clzero irperf xsaveerptr rdpru wbnoinvd arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif v_spec_ctrl umip rdpid overflow_recov succor smca sme sev sev_es
Virtualization: AMD-V
L1d cache: 1 MiB (32 instances)
L1i cache: 1 MiB (32 instances)
L2 cache: 16 MiB (32 instances)
L3 cache: 128 MiB (8 instances)
NUMA node(s): 1
NUMA node0 CPU(s): 0-63
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Retbleed: Mitigation; untrained return thunk; SMT enabled with STIBP protection
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines, IBPB conditional, STIBP always-on, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected

Versions of relevant libraries:
[pip3] mypy==1.0.1
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.24.1
[pip3] pytorch-lightning==1.9.3
[pip3] pytorch-triton==2.0.0+d54c04abe2
[pip3] torch==2.0.0+cu117
[pip3] torchaudio==2.0.0+cu117
[pip3] torchmetrics==0.11.1
[pip3] torchvision==0.15.0+cu117
[conda] Could not collect

cc @ezyang @mruberry @ngimel @lezcano @peterbell10

ngimel added the module: primTorch and triaged labels Mar 7, 2023
ngimel (Collaborator) commented Mar 7, 2023

Our compiler backend, inductor, and the other backends being developed target a reduced set of aten operations, not prims. In the existing lowering stack, calling compile_fx with a controllable set of decompositions will produce forward and backward fx modules that consist of fewer aten ops (depending on which decompositions you use).
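
For illustration, a minimal sketch of that flow using the functorch aot_function entry point (which accepts a decompositions table) rather than calling compile_fx directly; core_aten_decompositions is used here only as an example table:

from typing import List

import torch
from functorch.compile import aot_function
from torch._decomp import core_aten_decompositions


def print_graph(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
    # Inspect the decomposed forward graph, then fall back to eager execution.
    gm.graph.print_tabular()
    return gm.forward


def f(x):
    return torch.nn.functional.silu(x).sum()


# Passing a decomposition table reduces the set of aten ops the compiler sees.
decomposed = aot_function(f, fw_compiler=print_graph,
                          decompositions=core_aten_decompositions())
decomposed(torch.randn(4, requires_grad=True))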

ezyang (Contributor) commented Mar 7, 2023

My suggestion is to not use TorchRefsMode. Why do you want that mode?

gilfree (Author) commented Mar 7, 2023

Hi @ezyang

I am working on quantization of networks for custom hardware.

Up to now I have used plain torch.fx for getting and manipulating the ops graph.

I need to replace some operations, retrain, and export. (Currently I have my own export path from fx.)

In order to have fewer operations to handle, I am looking for a way to get an FX graph of prims.

I couldn't find a documented way to do that. I tried this, and stumbled upon this error, which asks to open a bug.

I am aware this might not be the correct way to do it, but up until now I have found no other way.

ezyang (Contributor) commented Mar 7, 2023

Instead of using aot_function, use make_fx from torch.fx.experimental.proxy_tensor, and I'd guess TorchRefsMode will work.
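
A rough, untested sketch of that suggestion (func is the reshape repro from above):

import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch._prims.context import TorchRefsMode


def func(scale, shape):
    return torch.reshape(scale, shape)


def func_refs(scale):
    # Run under TorchRefsMode so torch.* calls are rerouted to torch._refs,
    # which decompose into prims while make_fx records them.
    with TorchRefsMode():
        return func(scale, (-1,))


gm = make_fx(func_refs)(torch.ones(3))
gm.graph.print_tabular()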

ngimel (Collaborator) commented Mar 7, 2023

But unless you need prims specifically, you can just provide your own backend to torch.compile, similar to what is shown in #93491, and that will use the standard decomposition set, which reduces the number of ops you need to support.
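
A minimal custom-backend sketch along those lines (the backend just prints the captured graph and falls back to eager):

from typing import List

import torch


def my_backend(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
    # gm is the torch-level FX graph captured by dynamo
    gm.graph.print_tabular()
    return gm.forward  # run the captured graph eagerly


@torch.compile(backend=my_backend)
def f(x):
    return torch.reshape(x, (-1,)) * 2


f(torch.ones(3, 1))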

gilfree (Author) commented Mar 7, 2023

But unless you need prims specifically, you can just provide your own backend to torch.compile, similar to what is shown in #93491, and that will use the standard decomposition set, which reduces the number of ops you need to support.

Thanks! I will try these solutions.

Prims look like a nice set of ops to begin with; is there a way to get these decompositions with a custom backend?

Somewhat related: there are modules which I don't want to decompose, things that I want to be treated as leaves.

In standard fx this was easy to achieve, as I could use my own tracer, give it leaf functions and modules, and customize it as I needed. Is there an equivalent in torch.compile?

ngimel (Collaborator) commented Mar 7, 2023

Yes: by adjusting the decomposition list you send to compile_fx, you can exclude the things you don't want to decompose.
Prims are a nice set of ops, but in their current state they don't cover a lot of existing PyTorch ops.
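
A hedged sketch of that adjustment, again via aot_function's decompositions argument; the choice of aten._softmax here is only an example and it may or may not appear in the default table:

import torch
from functorch.compile import aot_function
from torch._decomp import core_aten_decompositions

decomps = dict(core_aten_decompositions())
# Hypothetical choice: keep softmax as a single op by dropping its
# decomposition, if it is present in the table at all.
decomps.pop(torch.ops.aten._softmax.default, None)


def fw_compiler(gm, example_inputs):
    gm.graph.print_tabular()
    return gm.forward


fn = aot_function(lambda x: torch.softmax(x, dim=-1),
                  fw_compiler=fw_compiler,
                  decompositions=decomps)
fn(torch.randn(2, 3))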

gilfree (Author) commented Mar 7, 2023

Thanks, I will try that.

My understanding was that the basic set of operations I will need to support is the IR mentioned here:

https://pytorch.org/docs/master/ir.html

And that at some point in time, the list will contain the ~100-200 operators I will need to support on our stack, and compile & export will use them as a target.

I understand it's not mature yet and that it will take time, but I don't need full support now; I'm also just at the design stage.

Am I wrong about that being a goal?

ngimel (Collaborator) commented Mar 7, 2023

cc @SherlockNoMad for the list of ops in ATen IR. Empirically, if you target the ops currently in lowering.py (there are fewer than 400 of them) plus the standard decompositions that compile_fx uses by default, that'll give you pretty good coverage. That set of ~400 ops can be further reduced by enabling more decompositions, but that hasn't been a priority yet.
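
Building on the aot_function sketch above, a small helper (op_histogram is a made-up name) for measuring empirically which ops a given workload actually hits:

from collections import Counter

import torch
from functorch.compile import aot_function
from torch._decomp import core_aten_decompositions


def op_histogram(gm: torch.fx.GraphModule, example_inputs):
    # Tally call_function targets to estimate which aten ops a backend must cover.
    hist = Counter(str(n.target) for n in gm.graph.nodes if n.op == "call_function")
    for op, count in hist.most_common():
        print(f"{count:4d}  {op}")
    return gm.forward


fn = aot_function(lambda x: torch.nn.functional.gelu(x).sum(),
                  fw_compiler=op_histogram,
                  decompositions=core_aten_decompositions())
fn(torch.randn(4))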

gilfree (Author) commented Mar 9, 2023

I found the talk that @ngimel, @SherlockNoMad, and others gave in the PT 2 Q&A series great (BTW, it's not linked from the docs page), and from there the notebook @SherlockNoMad presented.

It gave me an answer on how to generate the fx graph with dynamo, and using allow_in_graph I was able to retain my own leaf operations in the torch-level graph.
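
A torch-level sketch of that allow_in_graph behavior; my_leaf and show are placeholder names:

import torch
import torch._dynamo as dynamo


def my_leaf(x):
    # stand-in for a complex eager implementation with many graph breaks
    return x * 2


# Tell dynamo to put my_leaf into the captured graph instead of tracing into it.
dynamo.allow_in_graph(my_leaf)


def model(x):
    return my_leaf(x) + 1


def show(gm, example_inputs):
    gm.graph.print_tabular()  # my_leaf appears as a single call_function node
    return gm.forward


torch.compile(model, backend=show)(torch.ones(3))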

But when I use aot_module_simplified to decompose the graph to aten/prims, I was unable to find a way to keep the leaf operations in the graph; the same goes for export mode (torch._dynamo.export).

Those leaf operations have complex internal code in training: custom CUDA kernels and Python code that cause many graph breaks (I am OK with having them run in eager mode during training). Later on I want to export them as-is.

Just to give my motivation:

We have custom hardware for NNs, with hardware support for some operations. We detect patterns in the network the user gives us, replace them with a hardware-compatible version (an approximately, but not exactly, quantized version), and perform an analogue of quantization-aware training. Then we export the graph and weights for further processing in our stack.

Currently we do this with torch.fx. Our goal is to support a much broader set of operations, and to that end we would like, on the one hand, to decompose as much as possible (e.g. broadcasting -> reshape, decompose softmax, etc.) and, on the other hand, to keep some functions as they are, since our hardware has its own implementation for them.

(If this is not a good place to continue this discussion, please let me know, and I'll close the issue).

ezyang (Contributor) commented Mar 13, 2023

cc @zou3519

gilfree (Author) commented Mar 19, 2023

I have found a way to get better control over the ops I get, which was my original goal.
It is equivalent to torch._dynamo.allow_in_graph or torch.fx.wrap, but works all the way to the aten level, although only for functions and not for modules. It's very hacky, and I'm pretty sure it's not intended behavior, but it's the only solution I currently have.
If there is a better way, that would be great. Without this or some other solution, I cannot export what I need, and the unwrapped method has tens of graph breaks, which cause a 5-10x slowdown with torch.compile.

from typing import Optional

import torch
from torch.library import Library
from torch.fx.experimental.proxy_tensor import (
    proxy_call,
    get_innermost_proxy_mode,
    disable_proxy_modes_tracing,
)


def leaf(library: Library, schema: Optional[str] = None):
    def wrapper(func):
        def wrapped(*args, **kwargs):
            proxy_mode = get_innermost_proxy_mode()
            # Outside tracing (or when re-entered), just run the eager implementation.
            if not proxy_mode or getattr(wrapped, "recursive", None):
                with disable_proxy_modes_tracing():
                    return func(*args, **kwargs)
            # Inside tracing, record the registered custom op as a single node.
            wrapped.recursive = True
            res = proxy_call(
                proxy_mode,
                getattr(getattr(torch.ops, library.ns), func.__name__).default,
                args,
                kwargs,
            )
            wrapped.recursive = False
            return res

        library.define(schema)
        library.impl(func.__name__, wrapped, "Autograd")
        return getattr(getattr(torch.ops, library.ns), func.__name__).default

    return wrapper


mylib = Library("mylib", "DEF")


@leaf(mylib, "foo(Tensor x) -> Tensor")
def foo(x):
    return x * 2

Then, when exporting with aten_graph=True and core_aten_decompositions(), foo is still kept in the graph as torch.ops.mylib.foo. The real foo is a very complicated method that causes many graph breaks.
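
A usage sketch of that export path, assuming the leaf registration above (untested):

import torch
import torch._dynamo as dynamo
from torch._decomp import core_aten_decompositions


def model(x):
    # foo is the custom op registered through the leaf decorator above
    return foo(x) + 1


gm, guards = dynamo.export(
    model,
    torch.ones(3),
    aten_graph=True,
    decomposition_table=core_aten_decompositions(),
)
gm.graph.print_tabular()  # torch.ops.mylib.foo remains a single node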

ezyang (Contributor) commented Mar 19, 2023

We are working on an official version of this API. Cc @zou3519

gilfree (Author) commented Mar 19, 2023

That would be great. If there is a place for contribution, I am willing to invest time in it.

@zou3519, if there is something that can be done, please let me know; I'll be glad to help.

gilfree (Author) commented Apr 20, 2023

In proxy_tensor.py, under the ProxyTensor class, the following comment appears:

    # In general, we don't want to make modules leaves. In principle, users of
    # this tracer might want to override this in order to turn a couple specific
    # modules into leaves in the traced graph.

As a user, can I actually turn some modules into leaves? Is it possible / planned? Or do you plan to support only function leaves?

ezyang (Contributor) commented Apr 20, 2023

There is not an official way to do this and we don't currently plan to support it.

gilfree closed this as not planned on Apr 21, 2023