🐛 Describe the bug
Exporting any PyTorch model that contains a BatchNorm2d layer with ExecuTorch fails with the following error:

```
torch._export.verifier.SpecViolationError: Operator torch._ops.aten._native_batch_norm_legit_functional.default is not Aten Canonical.
```
I am running the following code:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch._export import capture_pre_autograd_graph
from torch.export import export, ExportedProgram
from executorch.exir import ExecutorchBackendConfig, ExecutorchProgramManager
from executorch import exir


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.norm = nn.BatchNorm2d(20)
        self.conv2 = nn.Conv2d(20, 20, 5)

    # forward() is missing from the report as posted; this body is an assumption
    # reconstructed from the exported graph printed below and may differ from the original.
    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.norm(x)
        return F.relu(self.conv2(x))


example_args = (torch.randn(1, 1, 224, 224),)

# Disable gradient tracking globally for the export
torch.set_grad_enabled(False)

pre_autograd_aten_dialect = capture_pre_autograd_graph(Model(), example_args)

# Optionally do quantization:
# pre_autograd_aten_dialect = convert_pt2e(prepare_pt2e(pre_autograd_aten_dialect, CustomBackendQuantizer))

aten_dialect: ExportedProgram = export(pre_autograd_aten_dialect, example_args)
print(aten_dialect)

edge_program: exir.EdgeProgramManager = exir.to_edge(aten_dialect)

# Optionally do delegation:
# edge_program = edge_program.to_backend(CustomBackendPartitioner)

executorch_program: exir.ExecutorchProgramManager = edge_program.to_executorch(
    ExecutorchBackendConfig(
        passes=[],  # User-defined passes
    )
)

with open("model.pte", "wb") as file:
    file.write(executorch_program.buffer)

print("Model saved to model.pte")
```
which produces the following output:
```
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: "f32[20, 1, 5, 5]", arg1_1: "f32[20]", arg2_1: "f32[20]", arg3_1: "f32[20]", arg4_1: "f32[20, 20, 5, 5]", arg5_1: "f32[20]", arg6_1: "i64[]", arg7_1: "f32[20]", arg8_1: "f32[20]", arg9_1: "f32[1, 1, 224, 224]"):
            # File: /Users/dino-illumix/Documents/Projects/LightingPredict/executorch_export.py:19 in forward, code: x = F.relu(self.conv1(x))
            convolution: "f32[1, 20, 220, 220]" = torch.ops.aten.convolution.default(arg9_1, arg0_1, arg1_1, [1, 1], [0, 0], [1, 1], False, [0, 0], 1); arg9_1 = arg0_1 = arg1_1 = None
            relu: "f32[1, 20, 220, 220]" = torch.ops.aten.relu.default(convolution); convolution = None

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='arg0_1'), target='conv1_weight', persistent=None), InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='arg1_1'), target='conv1_bias', persistent=None), InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='arg2_1'), target='norm_weight', persistent=None), InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='arg3_1'), target='norm_bias', persistent=None), InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='arg4_1'), target='conv2_weight', persistent=None), InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='arg5_1'), target='conv2_bias', persistent=None), InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='arg6_1'), target='norm_num_batches_tracked', persistent=True), InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='arg7_1'), target='norm_running_mean', persistent=True), InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='arg8_1'), target='norm_running_var', persistent=True), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg9_1'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.BUFFER_MUTATION: 3>, arg=TensorArgument(name='add'), target='norm_num_batches_tracked'), OutputSpec(kind=<OutputKind.BUFFER_MUTATION: 3>, arg=TensorArgument(name='getitem_3'), target='norm_running_mean'), OutputSpec(kind=<OutputKind.BUFFER_MUTATION: 3>, arg=TensorArgument(name='getitem_4'), target='norm_running_var'), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='relu_1'), target=None)])
Range constraints: {}

Traceback (most recent call last):
  File "/Users/dino-illumix/Documents/Projects/LightingPredict/executorch_export.py", line 35, in <module>
    edge_program: exir.EdgeProgramManager = exir.to_edge(aten_dialect)
  File "/Users/dino-illumix/anaconda3/envs/executorch/lib/python3.10/site-packages/executorch/exir/program/_program.py", line 635, in to_edge
    EXIRATenDialectVerifier()(program.graph_module)
  File "/Users/dino-illumix/anaconda3/envs/executorch/lib/python3.10/site-packages/executorch/exir/verification/verifier.py", line 57, in __call__
    return self._check_graph_module(*args, **kwargs)
  File "/Users/dino-illumix/anaconda3/envs/executorch/lib/python3.10/site-packages/torch/_export/verifier.py", line 208, in _check_graph_module
    _check_valid_op(node.target)
  File "/Users/dino-illumix/anaconda3/envs/executorch/lib/python3.10/site-packages/torch/_export/verifier.py", line 191, in _check_valid_op
    self.check_valid_op(op)
  File "/Users/dino-illumix/anaconda3/envs/executorch/lib/python3.10/site-packages/executorch/exir/verification/verifier.py", line 87, in check_valid_op
    raise SpecViolationError(
torch._export.verifier.SpecViolationError: Operator torch._ops.aten._native_batch_norm_legit_functional.default is not Aten Canonical.
```
OS: macOS 14.2.1
torch==2.3.0
executorch==0.2.0
python==3.10.14
I can suppress the error by disabling the IR validity check:

```python
edge_program: exir.EdgeProgramManager = exir.to_edge(
    aten_dialect,
    compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
)
```

However, the export then fails with a different error:

```
RuntimeError: Missing out variants: {'aten::_native_batch_norm_legit_functional'}
```
Versions
Collecting environment information...
PyTorch version: 2.3.0
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A
OS: macOS 14.2.1 (arm64)
GCC version: Could not collect
Clang version: 15.0.0 (clang-1500.0.40.1)
CMake version: version 3.29.3
Libc version: N/A
Python version: 3.10.14 (main, May 6 2024, 14:42:37) [Clang 14.0.6 ] (64-bit runtime)
Python platform: macOS-14.2.1-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Apple M2 Pro
Versions of relevant libraries:
[pip3] executorch==0.2.0
[pip3] numpy==1.26.4
[pip3] onnx==1.16.0
[pip3] torch==2.3.0
[pip3] torchaudio==2.3.0
[pip3] torchsr==1.0.4
[pip3] torchvision==0.18.0
[conda] executorch 0.2.0 pypi_0 pypi
[conda] numpy 1.26.4 pypi_0 pypi
[conda] torch 2.3.0 pypi_0 pypi
[conda] torchaudio 2.3.0 pypi_0 pypi
[conda] torchsr 1.0.4 pypi_0 pypi
[conda] torchvision 0.18.0 pypi_0 pypi
cc @avikchaudhuri @gmagogsfm @zhxchen17 @tugsbayasgalan @angelayi @suo @ydwu4