Symbol undefined error in the to_out_var_pass for inputs with dynamic dims #8539

@JoshuaGhost

🐛 Describe the bug

Description and Reproduction

I followed the tutorial on building and running ExecuTorch with the XNNPACK backend, substituting my own model. The code to reproduce the issue is as follows:

import torch
import torchvision.models as models

from torch.export import Dim, export, export_for_training, ExportedProgram
from torchvision.models.mobilenetv2 import MobileNet_V2_Weights
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import (
    EdgeCompileConfig,
    EdgeProgramManager,
    ExecutorchProgramManager,
    to_edge_transform_and_lower,
)
from executorch.exir.backend.backend_api import to_backend
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    get_symmetric_quantization_config,
    XNNPACKQuantizer,
)

class SimpleModel(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = torch.nn.Conv1d(1, 3, 10)
        self.relu = torch.nn.ReLU()
        self.linear = torch.nn.Linear(645, 1000)
        self.n_fft = 1000
        
    def forward(self, x):
        batch_size = x.shape[0]
        x = self.conv(x.view(batch_size, 1, -1))
        x = x.view(batch_size, -1)
        x = self.linear(self.relu(x))
        return torch.fft.rfft(x, n=self.n_fft).abs()
    
    
def quantize(model, example_inputs):
    quantizer = XNNPACKQuantizer()
    operator_config = get_symmetric_quantization_config(is_per_channel=False)
    quantizer.set_global(operator_config)
    m = prepare_pt2e(model, quantizer)
    m(*example_inputs)
    return convert_pt2e(m)
    
    
mobilenet_v2 = SimpleModel().eval()
sample_inputs = (torch.randn(2, 224), )
dim1_x = Dim("dim1_x", min=1, max=10)
dynamic_shapes = {"x": {0: dim1_x}}

mobilenet_v2 = export_for_training(mobilenet_v2, sample_inputs, dynamic_shapes=dynamic_shapes).module() # 2-stage export for quantization path
quantized_mobilenetv2 = quantize(mobilenet_v2, sample_inputs)

with torch.no_grad():
    edge = to_edge_transform_and_lower(
        export(quantized_mobilenetv2, sample_inputs, dynamic_shapes=dynamic_shapes),
        compile_config=EdgeCompileConfig(
            _check_ir_validity=False,
            _core_aten_ops_exception_list=[torch.ops.aten._fft_r2c.default]),
        partitioner=[XnnpackPartitioner()]
    )

exec_prog = edge.to_executorch()

with open("qs8_xnnpack_mobilenetv2.pte", "wb") as file:
    exec_prog.write_to_file(file)
     
exec_prog.exported_program().graph.print_tabular()
exec_prog.exported_program().graph_module.print_readable()

print(exec_prog.exported_program().module()(*sample_inputs))

My main modifications are:

  • replacing mobilenet_v2 with my SimpleModel, which has only a convolution layer, a linear layer, and an FFT layer (patched with this commit)
  • supporting a dynamic first dimension (batch_size)

Problem

The last line of the code calls the exported module() with *sample_inputs and fails with the following error, which indicates that the symbol s0 representing the dynamic dimension is never defined:

Traceback (most recent call last):
  File "/disk1/python311-executorch/lib64/python3.11/site-packages/torch/fx/graph_module.py", line 387, in __call__
    return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/disk1/python311-executorch/lib64/python3.11/site-packages/torch/nn/modules/module.py", line 1749, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/disk1/python311-executorch/lib64/python3.11/site-packages/torch/nn/modules/module.py", line 1855, in _call_impl
    return inner()
           ^^^^^^^
  File "/disk1/python311-executorch/lib64/python3.11/site-packages/torch/nn/modules/module.py", line 1803, in inner
    result = forward_call(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<eval_with_key>.845", line 15, in forward
    alloc = executorch_exir_memory_alloc(((s0, 501), torch.complex64))
                                           ^^
NameError: name 's0' is not defined

Call using an FX-traced Module, line 15 of the traced Module's generated forward function:
    getitem_1 = executorch_call_delegate_1[0];  executorch_call_delegate_1 = None
    alloc = executorch_exir_memory_alloc(((s0, 501), torch.complex64))

~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
    aten__fft_r2c_default = torch.ops.aten._fft_r2c.out(getitem_1, [1], 0, True, out = alloc);  getitem_1 = alloc = None

    alloc_1 = executorch_exir_memory_alloc(((s0, 501), torch.float32))

Traceback (most recent call last):
  File "/nfs/home/zizhang/workspace/convert_litert/a.py", line 99, in <module>
    print(exec_prog.exported_program().module()(*sample_inputs))
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/disk1/python311-executorch/lib64/python3.11/site-packages/torch/fx/graph_module.py", line 824, in call_wrapped
    return self._wrapped_call(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/disk1/python311-executorch/lib64/python3.11/site-packages/torch/fx/graph_module.py", line 398, in __call__
    raise e.with_traceback(None)  # noqa: B904
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
NameError: name 's0' is not defined
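
The failure mode in the traceback above can be reproduced without ExecuTorch. The following is a minimal sketch, assuming only that the generated forward function is ordinary Python source in which s0 appears as a free name. The alloc helper and the two source strings are hypothetical stand-ins for illustration, not the real executorch_exir_memory_alloc op or the real generated code.

```python
# Stand-in for executorch_exir_memory_alloc (hypothetical, for illustration only).
def alloc(spec):
    shape, dtype = spec
    return {"shape": tuple(shape), "dtype": dtype}


# Generated-style source where the symbol s0 is used but never bound.
BROKEN_SRC = """
def forward(x_shape):
    return alloc(((s0, 501), "complex64"))
"""

# Same source, but the symbol is first bound from the input,
# which is the role a sym_size node plays in the graph.
FIXED_SRC = """
def forward(x_shape):
    s0 = x_shape[0]
    return alloc(((s0, 501), "complex64"))
"""


def compile_forward(src):
    """Compile a source string into a callable, similar in spirit to FX codegen."""
    namespace = {"alloc": alloc}
    exec(src, namespace)
    return namespace["forward"]


broken = compile_forward(BROKEN_SRC)
fixed = compile_forward(FIXED_SRC)

try:
    broken((2, 224))
    broken_error = None
except NameError as err:
    broken_error = str(err)

fixed_result = fixed((2, 224))
print(broken_error)
print(fixed_result)
```

The broken variant fails only when it is called, not when it is compiled, which is consistent with the export and lowering steps succeeding and the error appearing only when module() is invoked.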

Investigation

I printed the graph module produced by to_executorch using exec_prog.exported_program().graph_module.print_readable(). The output is as follows:

class GraphModule(torch.nn.Module):
    def forward(self, x: "f32[s0, 224]"):
         # 
        sym_size: "Sym(s0)" = torch.ops.aten.sym_size.int(x, 0)
        
         # File: /nfs/home/zizhang/workspace/convert_litert/a.py:23 in forward, code: x = self.conv(x.view(batch_size, 1, -1))
        aten_view_copy_default: "f32[s0, 1, 224]" = executorch_exir_memory_view(x, [sym_size, 1, -1]);  x = None
        
        # No stacktrace found for following nodes
        lowered_module_0 = self.lowered_module_0
        executorch_call_delegate = torch.ops.higher_order.executorch_call_delegate(lowered_module_0, aten_view_copy_default);  lowered_module_0 = aten_view_copy_default = None
        getitem: "f32[s0, 3, 215]" = executorch_call_delegate[0];  executorch_call_delegate = None
        
         # File: /nfs/home/zizhang/workspace/convert_litert/a.py:24 in forward, code: x = x.view(batch_size, -1)
        aten_view_copy_default_1: "f32[s0, 645]" = executorch_exir_memory_view(getitem, [sym_size, -1]);  getitem = sym_size = None
        
        # No stacktrace found for following nodes
        lowered_module_1 = self.lowered_module_1
        executorch_call_delegate_1 = torch.ops.higher_order.executorch_call_delegate(lowered_module_1, aten_view_copy_default_1);  lowered_module_1 = aten_view_copy_default_1 = None
        getitem_1: "f32[s0, 1000]" = executorch_call_delegate_1[0];  executorch_call_delegate_1 = None
        alloc: "c64[s0, 501]" = executorch_exir_memory_alloc(((s0, 501), torch.complex64))
        
         # File: /nfs/home/zizhang/workspace/convert_litert/a.py:26 in forward, code: return torch.fft.rfft(x, n=self.n_fft).abs()
        aten__fft_r2c_default: "c64[s0, 501]" = torch.ops.aten._fft_r2c.out(getitem_1, [1], 0, True, out = alloc);  getitem_1 = alloc = None
        
        # No stacktrace found for following nodes
        alloc_1: "f32[s0, 501]" = executorch_exir_memory_alloc(((s0, 501), torch.float32))
        
         # File: /nfs/home/zizhang/workspace/convert_litert/a.py:26 in forward, code: return torch.fft.rfft(x, n=self.n_fft).abs()
        aten_abs_default: "f32[s0, 501]" = torch.ops.aten.abs.out(aten__fft_r2c_default, out = alloc_1);  aten__fft_r2c_default = alloc_1 = None
        return (aten_abs_default,)
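
The static dimensions in the dump above (215, 645, and 501) can be checked by hand, which confirms that only the batch dimension s0 is symbolic. The sketch below uses the standard Conv1d output-length formula and the fact that a real FFT of length n yields n // 2 + 1 bins; the helper names are made up for this illustration.

```python
def conv1d_out_len(length, kernel_size, stride=1, padding=0, dilation=1):
    """Output length of a 1-D convolution (same formula as in the Conv1d docs)."""
    return (length + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


def rfft_out_len(n):
    """Number of frequency bins returned by a real FFT of length n."""
    return n // 2 + 1


conv_len = conv1d_out_len(224, kernel_size=10)  # Conv1d(1, 3, 10) on 224 samples
flat_features = 3 * conv_len                    # 3 output channels, flattened
fft_bins = rfft_out_len(1000)                   # n_fft = 1000

print(conv_len, flat_features, fft_bins)
```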

The error occurs specifically in this line:

alloc: "c64[s0, 501]" = executorch_exir_memory_alloc(((s0, 501), torch.complex64))

Here s0 corresponds to the dynamic dimension of the input, but it is never defined in the generated forward function. The graph does bind that dimension to the node sym_size, yet the alloc arguments refer to the raw symbol s0 instead. FFT is not the only operation that triggers this problem: any operation that allocates new memory, such as torch.full or torch.ones, can trigger it. After digging through the ExecuTorch code, I found that the executorch_exir_memory_alloc operation is introduced in the ToOutVarPass, specifically in this line. As a temporary workaround for operations like torch.full, we can replace these lines:

                if len(out_args_names) == 1:
                    alloc_node = make_alloc_node(
                        graph_module, node.meta['val'], node.meta["tensor_meta"]
                        # graph_module, size, node.meta["tensor_meta"]
                    )
                    out_var_kwargs[out_args_names[0]] = alloc_node

with these:

                if len(out_args_names) == 1:
                    alloc_node = make_alloc_node(
                        graph_module, node.meta['val'], node.meta["tensor_meta"]
                        # graph_module, size, node.meta["tensor_meta"]
                    )
                    if "aten.full" in str(node.target):  # Monkey patch for full
                        alloc_node.args = ((tuple(node.args[0]), alloc_node.args[0][1]), )
                    out_var_kwargs[out_args_names[0]] = alloc_node

That is, replace the args of the newly added alloc_node with those of the torch.full node that follows it. However, I do not know how to resolve this for the FFT node, let alone how to come up with a universal solution.
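
One possible direction for a more general fix is to map each symbol that appears in an alloc shape back to the input dimension it came from, and to read the concrete value from that input at run time. The following is only a dependency-free sketch of that idea under strong assumptions: shapes are plain tuples, symbols are plain strings such as "s0", and every symbol appears directly as a dimension of some input. The names find_symbol_sources and resolve_shape are hypothetical and are not ExecuTorch APIs.

```python
from typing import Dict, Tuple, Union

DimSpec = Union[int, str]        # int: static size, str: symbol name such as "s0"
SymbolSource = Tuple[str, int]   # (input name, dimension index)


def find_symbol_sources(
    input_shapes: Dict[str, Tuple[DimSpec, ...]],
) -> Dict[str, SymbolSource]:
    """Record, for each symbol, the first input dimension that carries it."""
    sources: Dict[str, SymbolSource] = {}
    for name, shape in input_shapes.items():
        for index, dim in enumerate(shape):
            if isinstance(dim, str):
                sources.setdefault(dim, (name, index))
    return sources


def resolve_shape(
    shape: Tuple[DimSpec, ...],
    sources: Dict[str, SymbolSource],
    runtime_shapes: Dict[str, Tuple[int, ...]],
) -> Tuple[int, ...]:
    """Replace every symbol in an alloc shape with the size read from its source input."""
    resolved = []
    for dim in shape:
        if isinstance(dim, str):
            if dim not in sources:
                raise KeyError(f"symbol {dim!r} is not derivable from any input")
            name, index = sources[dim]
            resolved.append(runtime_shapes[name][index])
        else:
            resolved.append(dim)
    return tuple(resolved)


# Shapes taken from the graph dump: input x is f32[s0, 224], alloc is c64[s0, 501].
sources = find_symbol_sources({"x": ("s0", 224)})
alloc_shape = resolve_shape(("s0", 501), sources, {"x": (2, 224)})
print(sources, alloc_shape)
```

This sketch does not cover derived expressions (for example a dimension equal to 2 * s0) or symbols that never appear as a bare input dimension; a real pass would presumably have to emit graph nodes for those rather than look values up in a table.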

Versions

Collecting environment information...
PyTorch version: 2.7.0.dev20250131+cpu
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: Oracle Linux Server 9.4 (x86_64)
GCC version: (GCC) 11.4.1 20231218 (Red Hat 11.4.1-3.0.1)
Clang version: 14.0.6
CMake version: version 3.31.4
Libc version: glibc-2.34

Python version: 3.11.7 (main, Oct 9 2024, 00:00:00) [GCC 11.4.1 20231218 (Red Hat 11.4.1-3.0.1)] (64-bit runtime)
Python platform: Linux-5.15.0-207.156.6.el9uek.x86_64-x86_64-with-glibc2.34
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 40 bits physical, 57 bits virtual
Byte Order: Little Endian
CPU(s): 8
On-line CPU(s) list: 0-7
Vendor ID: AuthenticAMD
Model name: AMD EPYC 9J14 96-Core Processor
CPU family: 25
Model: 17
Thread(s) per core: 2
Core(s) per socket: 4
Socket(s): 1
Stepping: 1
BogoMIPS: 5192.18
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm rep_good nopl cpuid extd_apicid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm cmp_legacy svm cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw topoext perfctr_core invpcid_single ssbd ibrs ibpb stibp vmmcall fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves nt_good avx512_bf16 clzero xsaveerptr wbnoinvd arat npt nrip_save vgif vnmi avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq la57 rdpid overflow_recov succor fsrm arch_capabilities
Virtualization: AMD-V
Hypervisor vendor: KVM
Virtualization type: full
L1d cache: 256 KiB (4 instances)
L1i cache: 256 KiB (4 instances)
L2 cache: 2 MiB (4 instances)
L3 cache: 16 MiB (1 instance)
NUMA node(s): 1
NUMA node0 CPU(s): 0-7
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec rstack overflow: Mitigation; safe RET, no microcode
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines; IBPB conditional; IBRS_FW; STIBP always-on; RSB filling; PBRSB-eIBRS Not affected; BHI Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected

Versions of relevant libraries:
[pip3] executorch==0.6.0a0+5e4d6b6
[pip3] numpy==2.0.0
[pip3] torch==2.7.0.dev20250131+cpu
[pip3] torchao==0.8.0+git11333ba2
[pip3] torchaudio==2.6.0.dev20250131+cpu
[pip3] torchgen==0.0.1
[pip3] torchsr==1.0.4
[pip3] torchvision==0.22.0.dev20250131+cpu
[pip3] triton==3.2.0
[conda] Could not collect

cc @digantdesai @mcr229 @cbilgin @mergennachin @byjlw

Metadata

Labels

  • module: user experience (Issues related to reducing friction for users)
  • module: xnnpack (Issues related to xnnpack delegation and the code under backends/xnnpack/)
  • triaged (This issue has been looked at a team member, and triaged and prioritized into an appropriate module)
