functorch.compile.aot_function fails on reshape in TorchRefsMode #96055
Our compiler backend, inductor, and the other backends being developed target a reduced set of ATen operations, not prims. In the existing lowering stack, see the call at pytorch/torch/_inductor/compile_fx.py, line 396 (commit fe4fec3).
My suggestion is to not use TorchRefsMode. Why do you want that mode?
Hi @ezyang, I am working on quantizing networks for custom hardware. Up to now I have used plain torch.fx for getting and manipulating the ops graph. I need to replace some operations, retrain, and export. (Currently I have my own export path from fx.) In order to have fewer operations to handle, I am looking for a way to get an FX graph of prims. I couldn't find a documented way to do that. I tried this, stumbled upon this error, which requested that I open a bug. I am aware it might not be the correct way to do this, but so far I have found no other way.
Instead of using
But unless you need prims specifically, you can just provide your own backend to torch.compile, similar to what is shown in #93491, and that will use the standard decomposition set, which reduces the number of ops you have to support.
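For context, a custom backend can be a minimal sketch like the following (the names `my_backend` and `f` are just illustrative): Dynamo hands the backend an FX GraphModule of the captured ops plus example inputs, and expects any callable back. To get a fully decomposed ATen-level graph you would typically wrap such a compiler with the aot_autograd helper, as discussed in #93491.

```python
import torch

def my_backend(gm: torch.fx.GraphModule, example_inputs):
    # Dynamo passes the captured FX graph and example inputs;
    # print the graph for inspection and run it unmodified.
    print(gm.graph)
    return gm.forward

@torch.compile(backend=my_backend)
def f(x):
    return torch.relu(x) + 1

print(f(torch.ones(4)))
```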
Thanks! I will try these solutions. Prims looks like a nice set of ops to begin with; is there a way to get these decompositions with a custom backend? Somewhat related: there are modules which I don't want to decompose, things that I want treated as leaves. In standard fx this was easy to achieve, as I could use my own tracer, give it leaf functions and modules, and customize it as needed. Is there an equivalent in torch.compile?
Yes; by adjusting the decomposition list you send to compile_fx, you can exclude the things you don't want decomposed.
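As a sketch of what a restricted decomposition table can look like (hedged: this uses make_fx directly rather than compile_fx, and picking aten.addmm is just an example), any op absent from the table stays in the graph as a single ATen node:

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch._decomp import get_decompositions

# Build a decomposition table containing only the ops we DO want
# decomposed; everything else is left as-is in the traced graph.
decomps = get_decompositions([torch.ops.aten.addmm])

def f(x, w, b):
    return torch.nn.functional.linear(x, w, b)

gm = make_fx(f, decomposition_table=decomps)(
    torch.randn(2, 3), torch.randn(4, 3), torch.randn(4)
)
print(gm.graph)
```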
Thanks, I will try that. My understanding was that the basic set of operations I will need to support are the IRs mentioned here: https://pytorch.org/docs/master/ir.html, and that at some point this list will contain the ~100-200 operators I will need to support in our stack, with compile & export using them as a target. I understand it's not mature yet and will take time, but I don't need full support now; I'm also just at the design stage. Am I wrong about that being a goal?
cc @SherlockNoMad for the list of ops in ATen IR. Empirically, if you target the ops currently in
I've found the talk that @ngimel, @SherlockNoMad, and others gave in the PT 2 Q&A series great (BTW, it's not linked from the docs page), and from there the notebook @SherlockNoMad presented. It gave me an answer on how to generate the fx graph with dynamo, and using

But when I use the

Those leaf operations have complex internal code in training: custom CUDA kernels and Python code that causes many graph breaks (I am OK with having them run in eager mode during training). Later on I want to export them as-is.

Just to give my motivation: we have custom hardware for NNs, with hardware support for some operations. We detect patterns in the network the user gives us, replace them with a hardware-compatible version (an approximately, but not exactly, quantized version), and perform an analogue of quantization-aware training. Then we export the graph and weights for further processing in our stack. Currently we do this with

(If this is not a good place to continue this discussion, please let me know and I'll close the issue.)
cc @zou3519 |
I have found a way to get better control of the ops I get, which was my original goal.

```python
import torch
from typing import Optional
from torch.library import Library
from torch.fx.experimental.proxy_tensor import (
    proxy_call,
    get_innermost_proxy_mode,
    disable_proxy_modes_tracing,
)


def leaf(library: Library, schema: Optional[str] = None):
    def wrapper(func):
        def wrapped(*args, **kwargs):
            proxy_mode = get_innermost_proxy_mode()
            # If we are not tracing, or are already inside this op,
            # just run the original function without tracing.
            if not proxy_mode or getattr(wrapped, "recursive", None):
                with disable_proxy_modes_tracing():
                    return func(*args, **kwargs)
            wrapped.recursive = True
            # Record the op as a single node in the traced graph.
            res = proxy_call(
                proxy_mode,
                getattr(getattr(torch.ops, library.ns), func.__name__).default,
                args,
                kwargs,
            )
            wrapped.recursive = False
            return res

        library.define(schema)
        library.impl(func.__name__, wrapped, "Autograd")
        return getattr(getattr(torch.ops, library.ns), func.__name__).default

    return wrapper


mylib = Library("mylib", "DEF")


@leaf(mylib, "foo(Tensor x)->Tensor")
def foo(x):
    return x * 2
```

Then when exporting with
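A smaller self-contained illustration of the same idea (the `demo` namespace and `double` op here are hypothetical): an op registered through torch.library shows up as a single call_function node when traced with make_fx, rather than being decomposed into its implementation.

```python
import torch
from torch.library import Library
from torch.fx.experimental.proxy_tensor import make_fx

lib = Library("demo", "DEF")
lib.define("double(Tensor x) -> Tensor")
# CompositeExplicitAutograd: one kernel shared by all backends.
lib.impl("double", lambda x: x * 2, "CompositeExplicitAutograd")

def f(x):
    return torch.ops.demo.double(x) + 1

gm = make_fx(f)(torch.randn(3))
print(gm.graph)
```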
We are working on an official version of this API. Cc @zou3519 |
That would be great. If there is a place for contribution, I am willing to invest time in it. @zou3519, if there is something that can be done, please let me know; I'll be glad to help.
In pytorch/torch/fx/experimental/proxy_tensor.py, line 427 (at ccd5ad8), there is this comment:

```python
# In general, we don't want to make modules leaves. In principle, users of
# this tracer might want to override this in order to turn a couple specific
# modules into leaves in the traced graph.
```

As a user, can I actually turn some modules into leaves? Is it possible / planned? Or do you plan to support only function leaves?
There is not an official way to do this and we don't currently plan to support it. |
🐛 Describe the bug
When using `TorchRefsMode`, `aot_function` cannot handle `reshape`. I am not sure that my usage is correct, but my goal is to decompose a graph to prim ops, in order to later export it outside PyTorch.

I tried the code below, and an internal assert is triggered, with a request to open a bug:

RuntimeError: !schema.hasAnyAliasInfo() INTERNAL ASSERT FAILED at "../aten/src/ATen/FunctionalizeFallbackKernel.cpp":33, please report a bug to PyTorch. mutating and aliasing ops should all have codegen'd kernels

When not using fake tensor mode I get some error in an `is_quantized` check somewhere. If there is a "correct" way to get a prim graph, I would very much like to know what it is.
Code:
Stack trace below:
Versions
Collecting environment information...
PyTorch version: 2.0.0+cu117
Is debug build: False
CUDA used to build PyTorch: 11.7
ROCM used to build PyTorch: N/A
OS: Ubuntu 18.04.6 LTS (x86_64)
GCC version: (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
Clang version: 9.0.0 (tags/RELEASE_900/final)
CMake version: version 3.25.0
Libc version: glibc-2.27
Python version: 3.10.6 (main, Aug 30 2022, 16:00:07) [GCC 7.5.0] (64-bit runtime)
Python platform: Linux-5.15.87-051587-generic-x86_64-with-glibc2.27
Is CUDA available: True
CUDA runtime version: 11.7.64
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA A100-PCIE-40GB
GPU 1: NVIDIA A100-PCIE-40GB
Nvidia driver version: 525.78.01
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 43 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 64
On-line CPU(s) list: 0-63
Vendor ID: AuthenticAMD
Model name: AMD EPYC 7502 32-Core Processor
CPU family: 23
Model: 49
Thread(s) per core: 2
Core(s) per socket: 32
Socket(s): 1
Stepping: 0
BogoMIPS: 5000.28
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 cqm rdt_a rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local clzero irperf xsaveerptr rdpru wbnoinvd arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif v_spec_ctrl umip rdpid overflow_recov succor smca sme sev sev_es
Virtualization: AMD-V
L1d cache: 1 MiB (32 instances)
L1i cache: 1 MiB (32 instances)
L2 cache: 16 MiB (32 instances)
L3 cache: 128 MiB (8 instances)
NUMA node(s): 1
NUMA node0 CPU(s): 0-63
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Retbleed: Mitigation; untrained return thunk; SMT enabled with STIBP protection
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines, IBPB conditional, STIBP always-on, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Versions of relevant libraries:
[pip3] mypy==1.0.1
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.24.1
[pip3] pytorch-lightning==1.9.3
[pip3] pytorch-triton==2.0.0+d54c04abe2
[pip3] torch==2.0.0+cu117
[pip3] torchaudio==2.0.0+cu117
[pip3] torchmetrics==0.11.1
[pip3] torchvision==0.15.0+cu117
[conda] Could not collect
cc @ezyang @mruberry @ngimel @lezcano @peterbell10